# -*- coding: utf-8 -*-
#
# This file is part of SIDEKIT.
#
# SIDEKIT is a python package for speaker verification.
# Home page: http://www-lium.univ-lemans.fr/sidekit/
#
# SIDEKIT is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# SIDEKIT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with SIDEKIT.  If not, see <http://www.gnu.org/licenses/>.

"""
Copyright 2014-2019 Yevhenii Prokopalo, Anthony Larcher


The authors would like to thank the BUT Speech@FIT group (http://speech.fit.vutbr.cz) and Lukas BURGET
for sharing the source code that strongly inspired this module. Thank you for your valuable contribution.
"""

import h5py
import logging
import sys
import numpy
import torch
import torch.optim as optim
import torch.multiprocessing as mp
from collections import OrderedDict
from sidekit.nnet.xsets import XvectorMultiDataset, XvectorDataset, StatDataset
from sidekit.bosaris import IdMap
from sidekit.statserver import StatServer


__license__ = "LGPL"
__author__ = "Anthony Larcher"
__copyright__ = "Copyright 2015-2019 Anthony Larcher"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'


#logging.basicConfig(stream=sys.stdout, level=logging.INFO)
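
# Minimal training driver sketch (illustrative only): the attribute names below are the ones
# consumed by xtrain(), train_epoch(), train_worker() and cross_validation() in this module;
# all values are placeholders.
#
#     from argparse import Namespace
#     args = Namespace(class_number=1000, dropout=0.5, epochs=10, lr=0.001, seed=42,
#                      num_processes=2, averaging_step=5, batch_size=128, log_interval=10,
#                      l2_frame=1e-4, l2_seg=1e-4,
#                      batch_training_list='train_batches.lst',
#                      cross_validation_list='cv_batches.lst',
#                      batch_path='./batches', model_path='./models', expe_id='xv0')
#     xtrain(args)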


def split_file_list(batch_files, num_processes):
    # Split the list of files into num_processes sub-lists, each taking every num_processes-th file
    batch_sub_lists = [[]] * num_processes
    x = [ii for ii in range(len(batch_files))]
    for ii in range(num_processes):
        batch_sub_lists[ii - 1] = [batch_files[z + ii] for z in x[::num_processes] if (z + ii) < len(batch_files)]
    return batch_sub_lists
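
# For illustration (made-up file names): with batch_files = ['b0', 'b1', 'b2', 'b3', 'b4']
# and num_processes = 2, split_file_list returns [['b1', 'b3'], ['b0', 'b2', 'b4']]:
# each sub-list contains every num_processes-th file at a fixed offset.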


class Xtractor(torch.nn.Module):
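    """
    x-vector extractor network: five 1D-convolution layers (two of them dilated) applied to
    20-dimensional acoustic features, a statistics pooling layer that concatenates the mean
    and standard deviation over time, two linear segment-level layers and a final linear
    layer with one output per speaker class.
    """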
    def __init__(self, spk_number, dropout):
        super(Xtractor, self).__init__()
        self.frame_conv0 = torch.nn.Conv1d(20, 512, 5, dilation=1)
        self.frame_conv1 = torch.nn.Conv1d(512, 512, 3, dilation=2)
        self.frame_conv2 = torch.nn.Conv1d(512, 512, 3, dilation=3)
        self.frame_conv3 = torch.nn.Conv1d(512, 512, 1)
        self.frame_conv4 = torch.nn.Conv1d(512, 3 * 512, 1)
        self.seg_lin0 = torch.nn.Linear(3 * 512 * 2, 512)
        self.dropout_lin0 = torch.nn.Dropout(p=dropout)
        self.seg_lin1 = torch.nn.Linear(512, 512)
        self.dropout_lin1 = torch.nn.Dropout(p=dropout)
        self.seg_lin2 = torch.nn.Linear(512, spk_number)
        #
        self.norm0 = torch.nn.BatchNorm1d(512)
        self.norm1 = torch.nn.BatchNorm1d(512)
        self.norm2 = torch.nn.BatchNorm1d(512)
        self.norm3 = torch.nn.BatchNorm1d(512)
        self.norm4 = torch.nn.BatchNorm1d(3 * 512)
        self.norm6 = torch.nn.BatchNorm1d(512)
        self.norm7 = torch.nn.BatchNorm1d(512)
        #
        self.activation = torch.nn.LeakyReLU(0.2)

    def forward(self, x):
        frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
        frame_emb_1 = self.norm1(self.activation(self.frame_conv1(frame_emb_0)))
        frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
        frame_emb_3 = self.norm3(self.activation(self.frame_conv3(frame_emb_2)))
        frame_emb_4 = self.norm4(self.activation(self.frame_conv4(frame_emb_3)))
        # Pooling layer that computes the mean and standard deviation of the frame-level embeddings
        # The output of the pooling layer is the first segment-level representation
        mean = torch.mean(frame_emb_4, dim=2)
        std = torch.std(frame_emb_4, dim=2)
        seg_emb_0 = torch.cat([mean, std], dim=1)
        # batch-normalisation after this layer
        seg_emb_1 = self.dropout_lin0(seg_emb_0)
        seg_emb_2 = self.norm6(self.activation(self.seg_lin0(seg_emb_1)))
        # new layer with batch Normalization
        seg_emb_3 = self.dropout_lin1(seg_emb_2)
        seg_emb_4 = self.norm7(self.activation(self.seg_lin1(seg_emb_3)))
        # No batch-normalisation after this layer
        seg_emb_5 = self.seg_lin2(seg_emb_4)
        result = self.activation(seg_emb_5)
        return result
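
    # Illustrative forward pass (hypothetical sizes; `net` stands for an Xtractor instance):
    #   x = torch.randn(8, 20, 300)   # batch of 8 segments, 20 features, 300 frames
    #   y = net(x)                    # -> (8, spk_number) speaker scores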

    def init_weights(self):
        """Initialize the weights and biases of the convolutional and linear layers."""
        torch.nn.init.normal_(self.frame_conv0.weight, mean=-0.5, std=0.1)
        torch.nn.init.normal_(self.frame_conv1.weight, mean=-0.5, std=0.1)
        torch.nn.init.normal_(self.frame_conv2.weight, mean=-0.5, std=0.1)
        torch.nn.init.normal_(self.frame_conv3.weight, mean=-0.5, std=0.1)
        torch.nn.init.normal_(self.frame_conv4.weight, mean=-0.5, std=0.1)
        torch.nn.init.xavier_uniform_(self.seg_lin0.weight)
        torch.nn.init.xavier_uniform_(self.seg_lin1.weight)
        torch.nn.init.xavier_uniform_(self.seg_lin2.weight)

        torch.nn.init.constant_(self.frame_conv0.bias, 0.1)
        torch.nn.init.constant_(self.frame_conv1.bias, 0.1)
        torch.nn.init.constant_(self.frame_conv2.bias, 0.1)
        torch.nn.init.constant_(self.frame_conv3.bias, 0.1)
        torch.nn.init.constant_(self.frame_conv4.bias, 0.1)
        torch.nn.init.constant_(self.seg_lin0.bias, 0.1)
        torch.nn.init.constant_(self.seg_lin1.bias, 0.1)
        torch.nn.init.constant_(self.seg_lin2.bias, 0.1)

    def extract(self, x):
        frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
        frame_emb_1 = self.norm1(self.activation(self.frame_conv1(frame_emb_0)))
        frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
        frame_emb_3 = self.norm3(self.activation(self.frame_conv3(frame_emb_2)))
        frame_emb_4 = self.norm4(self.activation(self.frame_conv4(frame_emb_3)))

        mean = torch.mean(frame_emb_4, dim=2)
        std = torch.std(frame_emb_4, dim=2)

        seg_emb_0 = torch.cat([mean, std], dim=1)
        # batch-normalisation after this layer
        seg_emb_1 = self.seg_lin0(seg_emb_0)
        seg_emb_2 = self.activation(seg_emb_1)
        seg_emb_3 = self.norm6(seg_emb_2)
        seg_emb_4 = self.seg_lin1(seg_emb_3)
        seg_emb_5 = self.activation(seg_emb_4)
        seg_emb_6 = self.norm7(seg_emb_5)

        return seg_emb_1, seg_emb_2, seg_emb_3, seg_emb_4, seg_emb_5, seg_emb_6

    def LossFN(self, x, label):
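        # A base-10 negative log-likelihood: assuming x contains class probabilities and
        # label is a one-hot matrix, trace(log10(x) . label^T) sums the log10-probability
        # assigned to the target class of each sample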
        loss = - torch.trace(torch.mm(torch.log10(x), torch.t(label)))
        return loss

    def extract(self, x):
        frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
        frame_emb_1 = self.norm1(self.activation(self.frame_conv1(frame_emb_0)))
        frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
        frame_emb_3 = self.norm3(self.activation(self.frame_conv3(frame_emb_2)))
        frame_emb_4 = self.norm4(self.activation(self.frame_conv4(frame_emb_3)))

        mean = torch.mean(frame_emb_4, dim=2)
        std = torch.std(frame_emb_4, dim=2)
        seg_emb = torch.cat([mean, std], dim=1)

        embedding_A = self.seg_lin0(seg_emb)
        embedding_B = self.seg_lin1(self.norm6(self.activation(embedding_A)))

        return embedding_A, embedding_B


def xtrain(args):
    # Initialize a first model and save to disk
    model = Xtractor(args.class_number, args.dropout)
    current_model_file_name = "initial_model"
    torch.save(model.state_dict(), current_model_file_name)

    for epoch in range(1, args.epochs + 1):
        current_model_file_name = train_epoch(epoch, args, current_model_file_name)

        # Cross-validation after each training epoch
        accuracy = cross_validation(args, current_model_file_name)
        print("*** Cross validation accuracy = {} %".format(accuracy))

        # Decrease learning rate after every epoch
        args.lr = args.lr * 0.9
        print("        Decrease learning rate: {}".format(args.lr))


def train_epoch(epoch, args, initial_model_file_name):
    # Compute the megabatch number
    with open(args.batch_training_list, 'r') as fh:
        batch_file_list = [l.rstrip() for l in fh]

    # The number of batch files used per epoch is implicitly truncated to a multiple of
    # (averaging_step * num_processes); leftover files are ignored
    megabatch_number = len(batch_file_list) // (args.averaging_step * args.num_processes)
    megabatch_size = args.averaging_step * args.num_processes
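    # For illustration (hypothetical numbers): with 1000 batch files, averaging_step=5 and
    # num_processes=4, megabatch_size is 20 and megabatch_number is 50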
    print("Epoch {}, number of megabatches = {}".format(epoch, megabatch_number))

    current_model = initial_model_file_name

    # For each sublist: run an asynchronous training and averaging of the model
    for ii in range(megabatch_number):
        print('Process megabatch [{} / {}]'.format(ii + 1, megabatch_number))
        current_model = train_asynchronous(epoch,
                                           args,
                                           current_model,
                                           batch_file_list[megabatch_size * ii: megabatch_size * (ii + 1)],
                                           ii,
                                           megabatch_number)  # trains across processes, averages the models and writes the new model to disk
    return current_model


def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_queue):
    model = Xtractor(args.class_number, args.dropout)
    model.load_state_dict(torch.load(initial_model_file_name))
    model.train()

    torch.manual_seed(args.seed + rank)
    train_loader = XvectorMultiDataset(batch_list, args.batch_path)

    device = torch.device("cuda:{}".format(rank))
    model.to(device)

    # optimizer = optim.Adam(model.parameters(), lr = args.lr)
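    # One parameter group per layer so that frame-level convolutions and segment-level
    # linear layers can use different weight decays (args.l2_frame and args.l2_seg)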
    optimizer = optim.Adam([{'params': model.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv1.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv2.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv3.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv4.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.seg_lin0.parameters(), 'weight_decay': args.l2_seg},
                            {'params': model.seg_lin1.parameters(), 'weight_decay': args.l2_seg},
                            {'params': model.seg_lin2.parameters(), 'weight_decay': args.l2_seg}
                            ], lr=args.lr)


    #criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    #criterion = torch.nn.NLLLoss()
    criterion = torch.nn.CrossEntropyLoss()

    accuracy = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data.to(device))
        loss = criterion(output, target.to(device))
        loss.backward()
        optimizer.step()

        accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                epoch, batch_idx + 1, train_loader.__len__(),
                       100. * batch_idx / train_loader.__len__(), loss.item(),
                       100.0 * accuracy.item() / ((batch_idx + 1) * args.batch_size)))

    model_param = OrderedDict()
    params = model.state_dict()

    for k in list(params.keys()):
        model_param[k] = params[k].cpu().detach().numpy()
    output_queue.put(model_param)


def train_asynchronous(epoch, args, initial_model_file_name, batch_file_list, megabatch_idx, megabatch_number):
    # Split the list of files for each process
    sub_lists = split_file_list(batch_file_list, args.num_processes)

    #
    output_queue = mp.Queue()
    # output_queue = multiprocessing.Queue()

    processes = []
    for rank in range(args.num_processes):
        p = mp.Process(target=train_worker,
                       args=(rank, epoch, args, initial_model_file_name, sub_lists[rank], output_queue)
                       )
        # We first train the model across `num_processes` processes
        p.start()
        processes.append(p)

    # Average the models and write the new one to disk
    asynchronous_model = []
    for ii in range(args.num_processes):
        asynchronous_model.append(dict(output_queue.get()))

    for p in processes:
        p.join()

    av_model = Xtractor(args.class_number, args.dropout)
    tmp = av_model.state_dict()

    average_param = dict()
    for k in list(asynchronous_model[0].keys()):
        average_param[k] = asynchronous_model[0][k]

        for mod in asynchronous_model[1:]:
            average_param[k] += mod[k]

        # BatchNorm's integer num_batches_tracked entries are not averaged (the value from the
        # freshly initialised model is kept); every other parameter is averaged across workers
        if 'num_batches_tracked' not in k:
            tmp[k] = torch.FloatTensor(average_param[k] / len(asynchronous_model))

    # return the file name of the new model
    current_model_file_name = "{}/model_{}_epoch_{}_batch_{}".format(args.model_path, args.expe_id, epoch,
                                                                     megabatch_idx)
    torch.save(tmp, current_model_file_name)
    if megabatch_idx == megabatch_number - 1:
        torch.save(tmp, "{}/model_{}_epoch_{}".format(args.model_path, args.expe_id, epoch))

    return current_model_file_name

def cross_validation(args, current_model_file_name):
    """

Anthony Larcher's avatar
Anthony Larcher committed
364
365
    :param args:
    :param current_model_file_name:
Anthony Larcher's avatar
Anthony Larcher committed
366
367
    :return:
    """
    with open(args.cross_validation_list, 'r') as fh:
        cross_validation_list = [l.rstrip() for l in fh]
        sub_lists = split_file_list(cross_validation_list, args.num_processes)

    #
    output_queue = mp.Queue()

    processes = []
    for rank in range(args.num_processes):
        p = mp.Process(target=cv_worker,
                       args=(rank, args, current_model_file_name, sub_lists[rank], output_queue)
                       )
        # We first evaluate the model across `num_processes` processes
        p.start()
        processes.append(p)

    # Collect the (batch count, accuracy) results returned by each process
    result = []
    for ii in range(args.num_processes):
        result.append(output_queue.get())

    for p in processes:
        p.join()

    # Compute the global accuracy
    accuracy = 0.0
    total_batch_number = 0
    for bn, acc in result:
        accuracy += acc
        total_batch_number += bn
    
    return 100. * accuracy / (total_batch_number * args.batch_size)


def cv_worker(rank, args, current_model_file_name, batch_list, output_queue):
    model = Xtractor(args.class_number, args.dropout)
    model.load_state_dict(torch.load(current_model_file_name))
    model.eval()

    cv_loader = XvectorMultiDataset(batch_list, args.batch_path)

    device = torch.device("cuda:{}".format(rank))
    model.to(device)

    accuracy = 0.0
    for batch_idx, (data, target) in enumerate(cv_loader):
        output = model(data.to(device))
        accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
    output_queue.put((cv_loader.__len__(), accuracy.cpu().numpy()))


def extract_idmap(args, device_ID, segment_indices, fs_params, idmap_name, output_queue):
    """
    Worker function that loads a model, extracts the x-vectors of the segments selected by
    segment_indices from the idmap and puts the resulting embedding matrices on output_queue
    """
    #device = torch.device("cuda:{}".format(device_ID))
    device = torch.device('cpu')

    # Create the dataset
    tmp_idmap = IdMap(idmap_name)
    idmap = IdMap()
    idmap.leftids = tmp_idmap.leftids[segment_indices]
    idmap.rightids = tmp_idmap.rightids[segment_indices]
    idmap.start = tmp_idmap.start[segment_indices]
    idmap.stop = tmp_idmap.stop[segment_indices]

    segment_loader = StatDataset(idmap, fs_params)

    # Load the model
    model_file_name = '/'.join([args.model_path, args.model_name])
    model = Xtractor(args.class_number, args.dropout)
    model.load_state_dict(torch.load(model_file_name))
    model.eval()

    # Get the size of embeddings
    emb_a_size = model.seg_lin0.weight.data.shape[0]
    emb_b_size = model.seg_lin1.weight.data.shape[0]

    # Create numpy arrays to store all the x-vectors
    emb_1 = numpy.zeros((idmap.leftids.shape[0], emb_a_size)).astype(numpy.float32)
    emb_2 = numpy.zeros((idmap.leftids.shape[0], emb_b_size)).astype(numpy.float32)
    emb_3 = numpy.zeros((idmap.leftids.shape[0], emb_b_size)).astype(numpy.float32)
    emb_4 = numpy.zeros((idmap.leftids.shape[0], emb_b_size)).astype(numpy.float32)
    emb_5 = numpy.zeros((idmap.leftids.shape[0], emb_b_size)).astype(numpy.float32)
    emb_6 = numpy.zeros((idmap.leftids.shape[0], emb_b_size)).astype(numpy.float32)

    # Send on selected device
    model.to(device)

    # Loop to extract all x-vectors
    for idx, (model_id, segment_id, data) in enumerate(segment_loader):
        logging.critical('Process file {}, [{} / {}]'.format(segment_id, idx, segment_loader.__len__()))
        #print('Process file {}'.format(segment_id))
        # Skip segments with fewer than 20 frames
        if data.shape[2] >= 20:
            seg_1, seg_2, seg_3, seg_4, seg_5, seg_6 = model.extract(data.to(device))
            emb_1[idx, :] = seg_1.detach().cpu()
            emb_2[idx, :] = seg_2.detach().cpu()
            emb_3[idx, :] = seg_3.detach().cpu()
            emb_4[idx, :] = seg_4.detach().cpu()
            emb_5[idx, :] = seg_5.detach().cpu()
            emb_6[idx, :] = seg_6.detach().cpu()

    output_queue.put((segment_indices, emb_1, emb_2, emb_3, emb_4, emb_5, emb_6))


def extract_parallel(args, fs_params):
    emb_a_size = 512
    emb_b_size = 512

    idmap = IdMap(args.idmap)

    x_server_1 = StatServer(idmap, 1, emb_a_size)
    x_server_2 = StatServer(idmap, 1, emb_b_size)
    x_server_3 = StatServer(idmap, 1, emb_b_size)
    x_server_4 = StatServer(idmap, 1, emb_b_size)
    x_server_5 = StatServer(idmap, 1, emb_b_size)
    x_server_6 = StatServer(idmap, 1, emb_b_size)

    x_server_1.stat0 = numpy.ones(x_server_1.stat0.shape)
    x_server_2.stat0 = numpy.ones(x_server_2.stat0.shape)
    x_server_3.stat0 = numpy.ones(x_server_3.stat0.shape)
    x_server_4.stat0 = numpy.ones(x_server_4.stat0.shape)
    x_server_5.stat0 = numpy.ones(x_server_5.stat0.shape)
    x_server_6.stat0 = numpy.ones(x_server_6.stat0.shape)

    # Split the indices
    mega_batch_size = idmap.leftids.shape[0] // args.num_processes

    logging.critical("Number of sessions to process: {}".format(idmap.leftids.shape[0]))

    segment_idx = []
    for ii in range(args.num_processes):
        segment_idx.append(
            numpy.arange(ii * mega_batch_size, numpy.min([(ii + 1) * mega_batch_size, idmap.leftids.shape[0]])))

    for idx, si in enumerate(segment_idx):
        logging.critical("Number of session on process {}: {}".format(idx, len(si)))

    # Extract x-vectors in parallel
    output_queue = mp.Queue()

    processes = []
    for rank in range(args.num_processes):
        p = mp.Process(target=extract_idmap,
                       args=(args, rank, segment_idx[rank], fs_params, args.idmap, output_queue)
                       )
        # Extract x-vectors across `num_processes` processes
        p.start()
        processes.append(p)

    # Get the x-vectors and fill the StatServer
    for ii in range(args.num_processes):
        indices, seg_1, seg_2, seg_3, seg_4, seg_5, seg_6 = output_queue.get()
        x_server_1.stat1[indices, :] = seg_1
        x_server_2.stat1[indices, :] = seg_2
        x_server_3.stat1[indices, :] = seg_3
        x_server_4.stat1[indices, :] = seg_4
        x_server_5.stat1[indices, :] = seg_5
        x_server_6.stat1[indices, :] = seg_6

    for p in processes:
        p.join()

    return x_server_1, x_server_2, x_server_3, x_server_4, x_server_5, x_server_6
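
# Illustrative call (assumes `args` provides idmap, num_processes, model_path, model_name,
# class_number and dropout, and that `fs_params` carries the feature-server parameters
# expected by StatDataset):
#     xv_1, xv_2, xv_3, xv_4, xv_5, xv_6 = extract_parallel(args, fs_params)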