# -*- coding: utf-8 -*-
#
# This file is part of SIDEKIT.
#
# SIDEKIT is a python package for speaker verification.
# Home page: http://www-lium.univ-lemans.fr/sidekit/
#
#
# SIDEKIT is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# SIDEKIT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with SIDEKIT.  If not, see <http://www.gnu.org/licenses/>.

"""
Copyright 2014-2019 Yevhenii Prokopalo, Anthony Larcher

The authors would like to thank the BUT Speech@FIT group (http://speech.fit.vutbr.cz) and Lukas BURGET
for sharing the source code that strongly inspired this module. Thank you for your valuable contribution.
"""

import h5py
import logging
import sys
import numpy
import torch
import torch.optim as optim
import torch.multiprocessing as mp
from collections import OrderedDict
from sidekit.nnet.xsets import XvectorMultiDataset, XvectorDataset, StatDataset
from sidekit.bosaris import IdMap
from sidekit.statserver import StatServer


__license__ = "LGPL"
__author__ = "Anthony Larcher"
__copyright__ = "Copyright 2015-2019 Anthony Larcher"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'


#logging.basicConfig(stream=sys.stdout, level=logging.INFO)


def split_file_list(batch_files, num_processes):
    """
    Split a list of batch files into `num_processes` round-robin sub-lists,
    one per process.

    :param batch_files: list of batch file names
    :param num_processes: number of sub-lists to create
    :return: a list of `num_processes` lists of file names
    """
    # Process ii gets files ii, ii + num_processes, ii + 2 * num_processes, ...
    return [batch_files[ii::num_processes] for ii in range(num_processes)]
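
# Example of the round-robin split (hypothetical file names):
#   split_file_list(["b0", "b1", "b2", "b3", "b4"], 2)
#   returns [["b0", "b2", "b4"], ["b1", "b3"]]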


class Xtractor(torch.nn.Module):
    """
    Class that defines an x-vector extractor based on 5 convolutional layers and a mean standard deviation pooling
    """
    def __init__(self, spk_number, dropout):
        """
        :param spk_number: number of speakers to classify (size of the output layer)
        :param dropout: dropout probability applied in the segment-level layers
        """
        super(Xtractor, self).__init__()
        self.frame_conv0 = torch.nn.Conv1d(20, 512, 5, dilation=1)
        self.frame_conv1 = torch.nn.Conv1d(512, 512, 3, dilation=2)
        self.frame_conv2 = torch.nn.Conv1d(512, 512, 3, dilation=3)
        self.frame_conv3 = torch.nn.Conv1d(512, 512, 1)
        self.frame_conv4 = torch.nn.Conv1d(512, 3 * 512, 1)
        self.seg_lin0 = torch.nn.Linear(3 * 512 * 2, 512)
        self.dropout_lin0 = torch.nn.Dropout(p=dropout)
        self.seg_lin1 = torch.nn.Linear(512, 512)
        self.dropout_lin1 = torch.nn.Dropout(p=dropout)
        self.seg_lin2 = torch.nn.Linear(512, spk_number)
        #
        self.norm0 = torch.nn.BatchNorm1d(512)
        self.norm1 = torch.nn.BatchNorm1d(512)
        self.norm2 = torch.nn.BatchNorm1d(512)
        self.norm3 = torch.nn.BatchNorm1d(512)
        self.norm4 = torch.nn.BatchNorm1d(3 * 512)
        self.norm6 = torch.nn.BatchNorm1d(512)
        self.norm7 = torch.nn.BatchNorm1d(512)
        #
        self.activation = torch.nn.LeakyReLU(0.2)

    def produce_embeddings(self, x):
        """
Anthony Larcher's avatar
Anthony Larcher committed
95

96
97
98
        :param x:
        :return:
        """
        frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
        frame_emb_1 = self.norm1(self.activation(self.frame_conv1(frame_emb_0)))
        frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
        frame_emb_3 = self.norm3(self.activation(self.frame_conv3(frame_emb_2)))
        frame_emb_4 = self.norm4(self.activation(self.frame_conv4(frame_emb_3)))

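        # Statistics pooling: collapse the time axis into per-channel mean and
        # standard deviation to get a fixed-size segment representation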
        mean = torch.mean(frame_emb_4, dim=2)
        std = torch.std(frame_emb_4, dim=2)
        seg_emb = torch.cat([mean, std], dim=1)

        embedding_a = self.seg_lin0(seg_emb)
        return embedding_a

    def forward(self, x):
        """

        :param x:
        :return:
        """
        seg_emb_0 = self.produce_embeddings(x)
        # batch-normalisation after this layer
        seg_emb_1 = self.norm6(self.activation(seg_emb_0))
        # Second segment-level layer, with batch normalisation
        seg_emb_2 = self.norm7(self.activation(self.seg_lin1(self.dropout_lin1(seg_emb_1))))
        # No batch-normalisation after this layer
        result = self.activation(self.seg_lin2(seg_emb_2))
        return result

    def extract(self, x):
        """
        Extract x-vector given an input sequence of features

        :param x:
        :return:
        """
        embedding_a = self.produce_embeddings(x)
        embedding_b = self.seg_lin1(self.norm6(self.activation(embedding_a)))

        return embedding_a, embedding_b

    def init_weights(self):
        """
        Initialize the x-vector extractor weights and biases
        """
        torch.nn.init.normal_(self.frame_conv0.weight, mean=-0.5, std=0.1)
        torch.nn.init.normal_(self.frame_conv1.weight, mean=-0.5, std=0.1)
        torch.nn.init.normal_(self.frame_conv2.weight, mean=-0.5, std=0.1)
        torch.nn.init.normal_(self.frame_conv3.weight, mean=-0.5, std=0.1)
        torch.nn.init.normal_(self.frame_conv4.weight, mean=-0.5, std=0.1)
        torch.nn.init.xavier_uniform_(self.seg_lin0.weight)
        torch.nn.init.xavier_uniform_(self.seg_lin1.weight)
        torch.nn.init.xavier_uniform_(self.seg_lin2.weight)

        torch.nn.init.constant_(self.frame_conv0.bias, 0.1)
        torch.nn.init.constant_(self.frame_conv1.bias, 0.1)
        torch.nn.init.constant_(self.frame_conv2.bias, 0.1)
        torch.nn.init.constant_(self.frame_conv3.bias, 0.1)
        torch.nn.init.constant_(self.frame_conv4.bias, 0.1)
        torch.nn.init.constant_(self.seg_lin0.bias, 0.1)
        torch.nn.init.constant_(self.seg_lin1.bias, 0.1)
        torch.nn.init.constant_(self.seg_lin2.bias, 0.1)

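
# Minimal usage sketch for the Xtractor class, assuming 20-dimensional input
# features and hypothetical sizes (1000 speakers, 8 segments of 300 frames):
#
#   model = Xtractor(spk_number=1000, dropout=0.5)
#   model.init_weights()
#   feats = torch.randn(8, 20, 300)      # (batch, feature_dim, frames)
#   scores = model(feats)                # (8, 1000) speaker class scores
#   model.eval()
#   emb_a, emb_b = model.extract(feats)  # two 512-dimensional x-vectors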

def xtrain(args):
    """
    Initialize and train an x-vector in asynchronous manner

    :param args:
    :return:
    """
    # Initialize a first model and save to disk
    model = Xtractor(args.class_number, args.dropout)
    current_model_file_name = "initial_model"
    torch.save(model.state_dict(), current_model_file_name)

    for epoch in range(1, args.epochs + 1):
        current_model_file_name = train_epoch(epoch, args, current_model_file_name)

        # Add the cross validation here
        accuracy = cross_validation(args, current_model_file_name)
        print("*** Cross validation accuracy = {} %".format(accuracy))

        # Decrease learning rate after every epoch
        args.lr = args.lr * 0.9
        print("        Decrease learning rate: {}".format(args.lr))
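
# `xtrain` and the functions below only read plain attributes from `args`, so a
# simple argparse.Namespace is enough to drive them. A hypothetical sketch (all
# file and directory names are placeholders):
#
#   from argparse import Namespace
#   args = Namespace(class_number=1000, dropout=0.5, epochs=10, lr=1e-3,
#                    seed=42, num_processes=4, averaging_step=10, batch_size=128,
#                    log_interval=10, l2_frame=1e-4, l2_seg=1e-4,
#                    batch_training_list="train_batches.lst",
#                    cross_validation_list="cv_batches.lst",
#                    batch_path="batches", model_path="models", expe_id="exp0")
#   xtrain(args)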


def train_epoch(epoch, args, initial_model_file_name):
    """
    Process one training epoch using an asynchronous implementation of the training

    :param epoch:
    :param args:
    :param initial_model_file_name:
    :return:
    """
    # Compute the megabatch number
    with open(args.batch_training_list, 'r') as fh:
        batch_file_list = [l.rstrip() for l in fh]

    # Shorten the batch_file_list to a multiple of (averaging_step * num_processes)

    megabatch_number = len(batch_file_list) // (args.averaging_step * args.num_processes)
    megabatch_size = args.averaging_step * args.num_processes
    print("Epoch {}, number of megabatches = {}".format(epoch, megabatch_number))

    current_model = initial_model_file_name

    # For each sublist: run an asynchronous training and averaging of the model
    for ii in range(megabatch_number):
        print('Process megabatch [{} / {}]'.format(ii + 1, megabatch_number))
        current_model = train_asynchronous(epoch,
                                           args,
                                           current_model,
                                           batch_file_list[megabatch_size * ii: megabatch_size * (ii + 1)],
                                           ii,
                                           megabatch_number)  # splits the training across processes, fuses and writes the new model to disk
    return current_model


def train_asynchronous(epoch, args, initial_model_file_name, batch_file_list, megabatch_idx, megabatch_number):
    """
    Process one mega-batch of data asynchronously, average the model parameters across
    subprocesses and return the updated version of the model

    :param epoch:
    :param args:
    :param initial_model_file_name:
    :param batch_file_list:
    :param megabatch_idx:
    :param megabatch_number:
    :return:
    """
    # Split the list of files for each process
    sub_lists = split_file_list(batch_file_list, args.num_processes)

    # Queue used by the workers to send their updated parameters back
    output_queue = mp.Queue()

    processes = []
    for rank in range(args.num_processes):
        p = mp.Process(target=train_worker,
                       args=(rank, epoch, args, initial_model_file_name, sub_lists[rank], output_queue)
                       )
        # We first train the model across `num_processes` processes
        p.start()
        processes.append(p)

    # Average the models and write the new one to disk
    asynchronous_model = []
    for ii in range(args.num_processes):
        asynchronous_model.append(dict(output_queue.get()))

    for p in processes:
        p.join()

    av_model = Xtractor(args.class_number, args.dropout)
    tmp = av_model.state_dict()

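    # Element-wise average of the parameter copies returned by the workers;
    # integer buffers such as `num_batches_tracked` are kept from the freshly
    # initialized template instead of being averaged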
    average_param = dict()
    for k in list(asynchronous_model[0].keys()):
        average_param[k] = asynchronous_model[0][k]

        for mod in asynchronous_model[1:]:
            average_param[k] += mod[k]

        if 'num_batches_tracked' not in k:
            tmp[k] = torch.FloatTensor(average_param[k] / len(asynchronous_model))

    # Write the new model to disk and return its file name
    current_model_file_name = "{}/model_{}_epoch_{}_batch_{}".format(args.model_path, args.expe_id, epoch,
                                                                     megabatch_idx)
    torch.save(tmp, current_model_file_name)
    if megabatch_idx == megabatch_number - 1:
        torch.save(tmp, "{}/model_{}_epoch_{}".format(args.model_path, args.expe_id, epoch))

    return current_model_file_name


def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_queue):
    """


    :param rank:
    :param epoch:
    :param args:
    :param initial_model_file_name:
    :param batch_list:
    :param output_queue:
    :return:
    """
    model = Xtractor(args.class_number, args.dropout)
    model.load_state_dict(torch.load(initial_model_file_name))
    model.train()

    torch.manual_seed(args.seed + rank)
    train_loader = XvectorMultiDataset(batch_list, args.batch_path)

    device = torch.device("cuda:{}".format(rank))
    model.to(device)

    # Adam with per-layer weight decay: frame-level and segment-level layers
    # are regularized with different L2 coefficients
    optimizer = optim.Adam([{'params': model.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv1.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv2.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv3.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv4.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.seg_lin0.parameters(), 'weight_decay': args.l2_seg},
                            {'params': model.seg_lin1.parameters(), 'weight_decay': args.l2_seg},
                            {'params': model.seg_lin2.parameters(), 'weight_decay': args.l2_seg}
                            ], lr=args.lr)

    criterion = torch.nn.CrossEntropyLoss()

    accuracy = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data.to(device))
        loss = criterion(output, target.to(device))
        loss.backward()
        optimizer.step()

        accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                epoch, batch_idx + 1, len(train_loader),
                100. * batch_idx / len(train_loader), loss.item(),
                100.0 * accuracy.item() / ((batch_idx + 1) * args.batch_size)))

    model_param = OrderedDict()
    params = model.state_dict()

    # Move the parameters to CPU numpy arrays so they can be pickled through the queue
    for k in list(params.keys()):
        model_param[k] = params[k].cpu().detach().numpy()
    output_queue.put(model_param)


def cross_validation(args, current_model_file_name):
    """

Anthony Larcher's avatar
Anthony Larcher committed
343
344
    :param args:
    :param current_model_file_name:
Anthony Larcher's avatar
Anthony Larcher committed
345
346
    :return:
    """
    with open(args.cross_validation_list, 'r') as fh:
        cross_validation_list = [l.rstrip() for l in fh]
        sub_lists = split_file_list(cross_validation_list, args.num_processes)

    # Queue used by the workers to send their accuracy counts back
    output_queue = mp.Queue()

    processes = []
    for rank in range(args.num_processes):
        p = mp.Process(target=cv_worker,
                       args=(rank, args, current_model_file_name, sub_lists[rank], output_queue)
                       )
        # We first evaluate the model across `num_processes` processes
        p.start()
        processes.append(p)

    # Collect the per-process results
    result = []
    for ii in range(args.num_processes):
        result.append(output_queue.get())

    for p in processes:
        p.join()

    # Compute the global accuracy
    accuracy = 0.0
    total_batch_number = 0
    for bn, acc in result:
        accuracy += acc
        total_batch_number += bn
    return 100. * accuracy / (total_batch_number * args.batch_size)
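
# Worked example of the reduction above (hypothetical counts): two workers
# returning (batch_number, correct) = (10, 1200) and (8, 900) with
# batch_size = 128 give 100 * (1200 + 900) / ((10 + 8) * 128) ≈ 91.1 %.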


def cv_worker(rank, args, current_model_file_name, batch_list, output_queue):
    """
    Evaluate the model on one sub-list of cross-validation batches and push
    (number of batches, number of correct predictions) to the output queue
    """
    model = Xtractor(args.class_number, args.dropout)
    model.load_state_dict(torch.load(current_model_file_name))
    model.eval()

    cv_loader = XvectorMultiDataset(batch_list, args.batch_path)

    device = torch.device("cuda:{}".format(rank))
    model.to(device)

    accuracy = 0.0
    for batch_idx, (data, target) in enumerate(cv_loader):
        output = model(data.to(device))
        accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
    output_queue.put((len(cv_loader), accuracy.cpu().numpy()))


def extract_idmap(args, device_ID, segment_indices, fs_params, idmap_name, output_queue):
    """
Anthony Larcher's avatar
Anthony Larcher committed
400
401
    Function that takes a model and an idmap and extract all x-vectors based on this model
    and return a StatServer containing the x-vectors
Anthony Larcher's avatar
Anthony Larcher committed
402
    """
    #device = torch.device("cuda:{}".format(device_ID))
    device = torch.device('cpu')

    # Create the dataset
    tmp_idmap = IdMap(idmap_name)
    idmap = IdMap()
    idmap.leftids = tmp_idmap.leftids[segment_indices]
    idmap.rightids = tmp_idmap.rightids[segment_indices]
    idmap.start = tmp_idmap.start[segment_indices]
    idmap.stop = tmp_idmap.stop[segment_indices]

    segment_loader = StatDataset(idmap, fs_params)

    # Load the model
    model_file_name = '/'.join([args.model_path, args.model_name])
    model = Xtractor(args.class_number, args.dropout)
    model.load_state_dict(torch.load(model_file_name))
    model.eval()

    # Get the size of embeddings
    emb_a_size = model.seg_lin0.weight.data.shape[0]
    emb_b_size = model.seg_lin1.weight.data.shape[0]

    # Create arrays to store all x-vectors
    # Xtractor.extract returns two embeddings per segment
    emb_a = numpy.zeros((idmap.leftids.shape[0], emb_a_size)).astype(numpy.float32)
    emb_b = numpy.zeros((idmap.leftids.shape[0], emb_b_size)).astype(numpy.float32)

    # Send on selected device
    model.to(device)

    # Loop to extract all x-vectors
    for idx, (model_id, segment_id, data) in enumerate(segment_loader):
        logging.critical('Process file {}, [{} / {}]'.format(segment_id, idx, len(segment_loader)))
        if data.shape[2] < 20:
            # Segment too short for the convolutional stack, skip it
            continue
        seg_a, seg_b = model.extract(data.to(device))
        emb_a[idx, :] = seg_a.detach().cpu()
        emb_b[idx, :] = seg_b.detach().cpu()

    output_queue.put((segment_indices, emb_a, emb_b))


def xtrain_single(args):
    """
    Initialize and train an x-vector on a single GPU

    :param args:
    :return:
    """
    # Initialize a first model and save to disk
    model = Xtractor(args.class_number, args.dropout)
    model.train()
    model.cuda()

    current_model_file_name = "initial_model"
    torch.save(model.state_dict(), current_model_file_name)

    for epoch in range(1, args.epochs + 1):
        # Process one epoch and return the current model
        model = train_epoch_single(model, epoch, args)

        # Add the cross validation here
        #accuracy = cross_validation_single(args, current_model_file_name)
        #print("*** Cross validation accuracy = {} %".format(accuracy))

        # Decrease learning rate after every epoch
        args.lr = args.lr * 0.9
        print("        Decrease learning rate: {}".format(args.lr))

        # return the file name of the new model
        current_model_file_name = "{}/model_{}_epoch_{}".format(args.model_path, args.expe_id, epoch)
        torch.save(model, current_model_file_name)


def train_epoch_single(model, epoch, args):
    """

    :param model:
    :param epoch:
    :param args:
    :param batch_list:
    :param output_queue:
    :return:
    """
    device = torch.device("cuda:0")

    torch.manual_seed(args.seed)

    # Get the list of batches
    print(args.batch_training_list)

    with open(args.batch_training_list, 'r') as fh:
        batch_list = [l.rstrip() for l in fh]

    train_loader = XvectorMultiDataset(batch_list, args.batch_path)

    optimizer = optim.Adam([{'params': model.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv1.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv2.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv3.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv4.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.seg_lin0.parameters(), 'weight_decay': args.l2_seg},
                            {'params': model.seg_lin1.parameters(), 'weight_decay': args.l2_seg},
                            {'params': model.seg_lin2.parameters(), 'weight_decay': args.l2_seg}
                            ], lr=args.lr)
    criterion = torch.nn.CrossEntropyLoss()

    accuracy = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data.to(device))
        loss = criterion(output, target.to(device))
        loss.backward()
        optimizer.step()

        accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                epoch, batch_idx + 1, len(train_loader),
                100. * batch_idx / len(train_loader), loss.item(),
                100.0 * accuracy.item() / ((batch_idx + 1) * args.batch_size)))

    return model

def extract_parallel(args, fs_params):
    """

    :param args:
    :param fs_params:
    :return:
    """
    emb_a_size = 512
    emb_b_size = 512

    idmap = IdMap(args.idmap)

    # One StatServer per embedding returned by Xtractor.extract
    x_server_a = StatServer(idmap, 1, emb_a_size)
    x_server_b = StatServer(idmap, 1, emb_b_size)

    x_server_a.stat0 = numpy.ones(x_server_a.stat0.shape)
    x_server_b.stat0 = numpy.ones(x_server_b.stat0.shape)

    # Split the indices
    mega_batch_size = idmap.leftids.shape[0] // args.num_processes

    logging.critical("Number of sessions to process: {}".format(idmap.leftids.shape[0]))

    segment_idx = []
    for ii in range(args.num_processes):
        segment_idx.append(
            numpy.arange(ii * mega_batch_size, numpy.min([(ii + 1) * mega_batch_size, idmap.leftids.shape[0]])))

    for idx, si in enumerate(segment_idx):
        logging.critical("Number of session on process {}: {}".format(idx, len(si)))

    # Extract x-vectors in parallel
    output_queue = mp.Queue()

    processes = []
    for rank in range(args.num_processes):
        p = mp.Process(target=extract_idmap,
                       args=(args, rank, segment_idx[rank], fs_params, args.idmap, output_queue)
                       )
        # Launch the extraction across `num_processes` processes
        p.start()
        processes.append(p)

    # Get the x-vectors and fill the StatServer
    for ii in range(args.num_processes):
        indices, seg_a, seg_b = output_queue.get()
        x_server_a.stat1[indices, :] = seg_a
        x_server_b.stat1[indices, :] = seg_b

    for p in processes:
        p.join()

    return x_server_a, x_server_b
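

# Usage sketch (hypothetical file names; `fs_params` is the features-server
# configuration expected by StatDataset):
#
#   xv_a, xv_b = extract_parallel(args, fs_params)
#   xv_a.write('xvectors_a.h5')   # StatServer.write saves the server to HDF5
#   xv_b.write('xvectors_b.h5')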