xvector.py 25.2 KB
Newer Older
Anthony Larcher's avatar
Anthony Larcher committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# -*- coding: utf-8 -*-
#
# This file is part of SIDEKIT.
#
# SIDEKIT is a python package for speaker verification.
# Home page: http://www-lium.univ-lemans.fr/sidekit/
#
# SIDEKIT is a python package for speaker verification.
# Home page: http://www-lium.univ-lemans.fr/sidekit/
#
# SIDEKIT is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# SIDEKIT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with SIDEKIT.  If not, see <http://www.gnu.org/licenses/>.

"""
Anthony Larcher's avatar
Anthony Larcher committed
25
Copyright 2014-2019 Yevhenii Prokopalo, Anthony Larcher
Anthony Larcher's avatar
Anthony Larcher committed
26
27
28
29
30


The authors would like to thank the BUT Speech@FIT group (http://speech.fit.vutbr.cz) and Lukas BURGET
for sharing the source code that strongly inspired this module. Thank you for your valuable contribution.
"""
Anthony Larcher's avatar
Anthony Larcher committed
31

Anthony Larcher's avatar
Anthony Larcher committed
32
33
34
35
import h5py
import logging
import numpy
import torch
Anthony Larcher's avatar
Anthony Larcher committed
36
37
38
import torch.optim as optim
import torch.multiprocessing as mp
from collections import OrderedDict
Anthony Larcher's avatar
hot    
Anthony Larcher committed
39
from sidekit.nnet.xsets import XvectorMultiDataset, XvectorDataset, StatDataset, XvectorMultiDataset_hot
Anthony Larcher's avatar
Anthony Larcher committed
40
41
42
from sidekit.bosaris import IdMap
from sidekit.statserver import StatServer

Anthony Larcher's avatar
Anthony Larcher committed
43

Anthony Larcher's avatar
Anthony Larcher committed
44
45
46
47
48
49
50
# Module metadata.
__license__ = "LGPL"
__author__ = "Anthony Larcher"
__copyright__ = "Copyright 2015-2019 Anthony Larcher"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
# Fixed typo: 'reS' -> 'reStructuredText' (the name tools actually recognize).
__docformat__ = 'reStructuredText'
Anthony Larcher's avatar
Anthony Larcher committed
51
52
53
54
55





Anthony Larcher's avatar
Anthony Larcher committed
56
57
58
59
60
61
62
def split_file_list(batch_files, num_processes):
    """Split a list of batch files into ``num_processes`` interleaved sub-lists.

    File ``i`` is assigned to sub-list ``i % num_processes`` so the work is
    spread evenly even when ``len(batch_files)`` is not a multiple of
    ``num_processes``.  Some sub-lists may be empty when there are fewer
    files than processes.

    :param batch_files: list of batch file names
    :param num_processes: number of sub-lists to produce
    :return: a list of ``num_processes`` lists that partition ``batch_files``

    Fixes the original implementation, which pre-allocated ``[[]] *
    num_processes`` (aliased inner lists) and indexed with ``ii - 1``,
    rotating the result so that process 0 received the files of offset 1.
    """
    return [batch_files[ii::num_processes] for ii in range(num_processes)]
Anthony Larcher's avatar
Anthony Larcher committed
63
64
65


class Xtractor(torch.nn.Module):
    """X-vector extractor network (TDNN-style).

    Five dilated 1-D convolution layers operate at frame level, a statistics
    pooling layer (mean and standard deviation over time) produces a
    segment-level representation, and three linear layers map it to speaker
    posteriors.

    :param spk_number: number of speaker classes in the output layer
    :param dropout: dropout probability applied before each of the two hidden
        segment-level linear layers
    """

    def __init__(self, spk_number, dropout):
        super(Xtractor, self).__init__()
        # Frame-level layers: the input features have 20 coefficients per frame.
        self.frame_conv0 = torch.nn.Conv1d(20, 512, 5, dilation=1)
        self.frame_conv1 = torch.nn.Conv1d(512, 512, 3, dilation=2)
        self.frame_conv2 = torch.nn.Conv1d(512, 512, 3, dilation=3)
        self.frame_conv3 = torch.nn.Conv1d(512, 512, 1)
        self.frame_conv4 = torch.nn.Conv1d(512, 3 * 512, 1)
        # Segment-level layers; seg_lin0 input is the mean+std concatenation
        # from the pooling layer, hence 2 * (3 * 512).
        self.seg_lin0 = torch.nn.Linear(3 * 512 * 2, 512)
        self.dropout_lin0 = torch.nn.Dropout(p=dropout)
        self.seg_lin1 = torch.nn.Linear(512, 512)
        self.dropout_lin1 = torch.nn.Dropout(p=dropout)
        self.seg_lin2 = torch.nn.Linear(512, spk_number)
        # Batch-normalisation layers, one per hidden layer.
        self.norm0 = torch.nn.BatchNorm1d(512)
        self.norm1 = torch.nn.BatchNorm1d(512)
        self.norm2 = torch.nn.BatchNorm1d(512)
        self.norm3 = torch.nn.BatchNorm1d(512)
        self.norm4 = torch.nn.BatchNorm1d(3 * 512)
        self.norm6 = torch.nn.BatchNorm1d(512)
        self.norm7 = torch.nn.BatchNorm1d(512)
        # Shared non-linearity.
        self.activation = torch.nn.LeakyReLU(0.2)

    def forward(self, x):
        """Compute speaker posteriors for a batch of feature sequences.

        :param x: input tensor; indexing below implies shape
            (batch, 20, frames) -- frames must be >= 15 given the kernel sizes
        :return: tensor of shape (batch, spk_number) whose rows sum to 1
        """
        frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
        frame_emb_1 = self.norm1(self.activation(self.frame_conv1(frame_emb_0)))
        frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
        frame_emb_3 = self.norm3(self.activation(self.frame_conv3(frame_emb_2)))
        frame_emb_4 = self.norm4(self.activation(self.frame_conv4(frame_emb_3)))
        # Pooling Layer that computes mean and standard deviation of frame level embeddings
        # The output of the pooling layer is the first segment-level representation
        mean = torch.mean(frame_emb_4, dim=2)
        std = torch.std(frame_emb_4, dim=2)
        seg_emb_0 = torch.cat([mean, std], dim=1)
        # batch-normalisation after this layer
        seg_emb_1 = self.dropout_lin0(seg_emb_0)
        seg_emb_2 = self.norm6(self.activation(self.seg_lin0(seg_emb_1)))
        # new layer with batch Normalization
        seg_emb_3 = self.dropout_lin1(seg_emb_2)
        seg_emb_4 = self.norm7(self.activation(self.seg_lin1(seg_emb_3)))
        # No batch-normalisation after this layer
        seg_emb_5 = self.seg_lin2(seg_emb_4)
        # NOTE(review): softmax probabilities are returned here while
        # train_worker feeds this output to CrossEntropyLoss, which expects
        # raw logits (it applies log-softmax itself) -- confirm this
        # double-softmax is intended.
        result = torch.nn.functional.softmax(self.activation(seg_emb_5), dim=1)
        return result

    def LossFN(self, x, label):
        """Cross-entropy-like loss against one-hot labels.

        :param x: model output (softmax probabilities, see forward())
        :param label: one-hot target matrix with the same shape as ``x``
        :return: scalar loss (uses log10 rather than the natural log --
            a constant factor away from standard cross-entropy)
        """
        loss = - torch.trace(torch.mm(torch.log10(x), torch.t(label)))
        return loss

    def init_weights(self):
        """Initialize convolution and linear layers.

        Convolution weights are drawn from N(-0.5, 0.1), linear weights use
        Xavier-uniform, and all biases are set to 0.1.

        Fixed: the deprecated ``torch.nn.init.xavier_uniform`` and
        ``torch.nn.init.constant`` aliases were replaced by their in-place
        underscore variants, which are the supported API in current torch.
        """
        torch.nn.init.normal_(self.frame_conv0.weight, mean=-0.5, std=0.1)
        torch.nn.init.normal_(self.frame_conv1.weight, mean=-0.5, std=0.1)
        torch.nn.init.normal_(self.frame_conv2.weight, mean=-0.5, std=0.1)
        torch.nn.init.normal_(self.frame_conv3.weight, mean=-0.5, std=0.1)
        torch.nn.init.normal_(self.frame_conv4.weight, mean=-0.5, std=0.1)
        torch.nn.init.xavier_uniform_(self.seg_lin0.weight)
        torch.nn.init.xavier_uniform_(self.seg_lin1.weight)
        torch.nn.init.xavier_uniform_(self.seg_lin2.weight)

        torch.nn.init.constant_(self.frame_conv0.bias, 0.1)
        torch.nn.init.constant_(self.frame_conv1.bias, 0.1)
        torch.nn.init.constant_(self.frame_conv2.bias, 0.1)
        torch.nn.init.constant_(self.frame_conv3.bias, 0.1)
        torch.nn.init.constant_(self.frame_conv4.bias, 0.1)
        torch.nn.init.constant_(self.seg_lin0.bias, 0.1)
        torch.nn.init.constant_(self.seg_lin1.bias, 0.1)
        torch.nn.init.constant_(self.seg_lin2.bias, 0.1)

    def extract(self, x):
        """Extract the two x-vector embeddings for a batch of sequences.

        Runs the frame-level stack and statistics pooling, then returns the
        pre-activation outputs of the first two segment-level linear layers
        (dropout and batch-norm are bypassed on purpose at extraction time).

        :param x: input tensor, same layout as in forward()
        :return: tuple (emb_A, emb_B) of shape (batch, 512) each
        """
        frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
        frame_emb_1 = self.norm1(self.activation(self.frame_conv1(frame_emb_0)))
        frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
        frame_emb_3 = self.norm3(self.activation(self.frame_conv3(frame_emb_2)))
        frame_emb_4 = self.norm4(self.activation(self.frame_conv4(frame_emb_3)))
        # Pooling Layer that computes mean and standard deviation of frame level embeddings
        # The output of the pooling layer is the first segment-level representation
        mean = torch.mean(frame_emb_4, dim=2)
        std = torch.std(frame_emb_4, dim=2)
        seg_emb = torch.cat([mean, std], dim=1)

        seg_emb_A = self.seg_lin0(seg_emb)
        seg_emb_B = self.seg_lin1(self.activation(seg_emb_A))

        return seg_emb_A, seg_emb_B


def xtrain(args):
    """Full training loop: seed a fresh model, then alternate training epochs
    with cross-validation, decaying the learning rate after every epoch.

    :param args: configuration namespace (class_number, dropout, epochs, lr, ...)
    """
    # Write an initial, untrained model to disk so the first epoch can load it.
    seed_model = Xtractor(args.class_number, args.dropout)
    model_file = "initial_model"
    torch.save(seed_model.state_dict(), model_file)

    epoch = 1
    while epoch <= args.epochs:
        model_file = train_epoch(epoch, args, model_file)

        # Evaluate the freshly written checkpoint on the held-out lists.
        accuracy = cross_validation(args, model_file)
        print("*** Cross validation accuracy = {} %".format(accuracy))

        # Geometric learning-rate decay, applied once per epoch.
        args.lr = args.lr * 0.9
        epoch += 1

Anthony Larcher's avatar
hot    
Anthony Larcher committed
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
def xtrain_hot(args):
    """One-hot-target variant of xtrain(): same loop, but delegates to the
    ``*_hot`` training and validation routines.

    :param args: configuration namespace (class_number, dropout, epochs, lr, ...)
    """
    # Persist an untrained model so the first epoch has a checkpoint to load.
    seed_model = Xtractor(args.class_number, args.dropout)
    model_file = "initial_model"
    torch.save(seed_model.state_dict(), model_file)

    epoch = 1
    while epoch <= args.epochs:
        model_file = train_epoch_hot(epoch, args, model_file)

        # Evaluate the freshly written checkpoint on the held-out lists.
        accuracy = cross_validation_hot(args, model_file)
        print("*** Cross validation accuracy = {} %".format(accuracy))

        # Geometric learning-rate decay, applied once per epoch.
        args.lr = args.lr * 0.9
        epoch += 1
Anthony Larcher's avatar
Anthony Larcher committed
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213

def train_epoch(epoch, args, initial_model_file_name):
    """Run one training epoch as a sequence of asynchronous mega-batches.

    :param epoch: current epoch index (used for logging and checkpoint names)
    :param args: configuration namespace
    :param initial_model_file_name: checkpoint file to start the epoch from
    :return: file name of the checkpoint produced by the last mega-batch
    """
    with open(args.batch_training_list, 'r') as fh:
        batch_file_list = [line.rstrip() for line in fh]

    # A mega-batch feeds args.num_processes workers, each consuming
    # args.averaging_step batches; trailing files that do not fill a complete
    # mega-batch are dropped.
    megabatch_size = args.averaging_step * args.num_processes
    megabatch_number = len(batch_file_list) // megabatch_size
    print("Epoch {}, number of megabatches = {}".format(epoch, megabatch_number))

    current_model = initial_model_file_name
    for idx in range(megabatch_number):
        print('Process megabatch [{} / {}]'.format(idx + 1, megabatch_number))
        chunk = batch_file_list[megabatch_size * idx: megabatch_size * (idx + 1)]
        # Trains in parallel, averages the workers' parameters and writes the
        # merged model to disk; returns the new checkpoint's file name.
        current_model = train_asynchronous(epoch, args, current_model, chunk,
                                           idx, megabatch_number)
    return current_model

Anthony Larcher's avatar
hot    
Anthony Larcher committed
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
def train_epoch_hot(epoch, args, initial_model_file_name):
    """One-hot-target variant of train_epoch(): identical mega-batch loop but
    dispatches to train_asynchronous_hot().

    :param epoch: current epoch index (used for logging and checkpoint names)
    :param args: configuration namespace
    :param initial_model_file_name: checkpoint file to start the epoch from
    :return: file name of the checkpoint produced by the last mega-batch
    """
    with open(args.batch_training_list, 'r') as fh:
        batch_file_list = [line.rstrip() for line in fh]

    # A mega-batch feeds args.num_processes workers, each consuming
    # args.averaging_step batches; leftover files are dropped.
    megabatch_size = args.averaging_step * args.num_processes
    megabatch_number = len(batch_file_list) // megabatch_size
    print("Epoch {}, number of megabatches = {}".format(epoch, megabatch_number))

    current_model = initial_model_file_name
    for idx in range(megabatch_number):
        print('Process megabatch [{} / {}]'.format(idx + 1, megabatch_number))
        chunk = batch_file_list[megabatch_size * idx: megabatch_size * (idx + 1)]
        # Trains in parallel, averages the workers' parameters and writes the
        # merged model to disk; returns the new checkpoint's file name.
        current_model = train_asynchronous_hot(epoch, args, current_model, chunk,
                                               idx, megabatch_number)
    return current_model

Anthony Larcher's avatar
Anthony Larcher committed
238
239

def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_queue):
    """Training worker for one mega-batch, pinned to GPU ``cuda:rank``.

    Loads the current checkpoint, trains on the batches of ``batch_list``
    with Adam and a summed cross-entropy criterion, then pushes the trained
    parameters (as numpy arrays) onto ``output_queue`` for averaging in the
    parent process.

    :param rank: worker index; selects the CUDA device and offsets the seed
    :param epoch: current epoch index (used for logging only)
    :param args: configuration namespace
    :param initial_model_file_name: checkpoint file to start from
    :param batch_list: batch files assigned to this worker
    :param output_queue: multiprocessing queue receiving an OrderedDict of
        parameter-name -> numpy array
    """
    model = Xtractor(args.class_number, args.dropout)
    model.load_state_dict(torch.load(initial_model_file_name))
    model.train()

    # Distinct seed per rank so randomness differs across workers.
    torch.manual_seed(args.seed + rank)
    train_loader = XvectorMultiDataset(batch_list, args.batch_path)

    device = torch.device("cuda:{}".format(rank))
    model.to(device)

    # Separate parameter groups so frame-level and segment-level layers can
    # use different L2 (weight-decay) strengths.
    # optimizer = optim.Adam(model.parameters(), lr = args.lr)
    optimizer = optim.Adam([{'params': model.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv1.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv2.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv3.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv4.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.seg_lin0.parameters(), 'weight_decay': args.l2_seg},
                            {'params': model.seg_lin1.parameters(), 'weight_decay': args.l2_seg},
                            {'params': model.seg_lin2.parameters(), 'weight_decay': args.l2_seg}
                            ], lr=args.lr)

    # NOTE(review): the model's forward() already applies softmax while
    # CrossEntropyLoss expects raw logits -- confirm this is intended.
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    #criterion = torch.nn.NLLLoss()
    #criterion = torch.nn.CrossEntropyLoss()

    accuracy = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data.to(device))
        loss = criterion(output, target.to(device))
        loss.backward()
        optimizer.step()

        # Running count of correct predictions; becomes a tensor after the
        # first += because the comparison .sum() returns one.
        accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                epoch, batch_idx + 1, train_loader.__len__(),
                       100. * batch_idx / train_loader.__len__(), loss.item(),
                       100.0 * accuracy.item() / ((batch_idx + 1) * args.batch_size)))

    # Ship the parameters back as CPU numpy arrays so the parent can average
    # them without sharing CUDA tensors across processes.
    model_param = OrderedDict()
    params = model.state_dict()

    for k in list(params.keys()):
        model_param[k] = params[k].cpu().detach().numpy()
    output_queue.put(model_param)
Anthony Larcher's avatar
Anthony Larcher committed
288
289


Anthony Larcher's avatar
hot    
Anthony Larcher committed
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
def train_worker_hot(rank, epoch, args, initial_model_file_name, batch_list, output_queue):
    """One-hot-target training worker pinned to GPU ``cuda:rank``.

    Same flow as train_worker(), but the dataset yields one-hot target
    matrices and the loss is the model's own LossFN() instead of a torch
    criterion.

    :param rank: worker index; selects the CUDA device and offsets the seed
    :param epoch: current epoch index (used for logging only)
    :param args: configuration namespace
    :param initial_model_file_name: checkpoint file to start from
    :param batch_list: batch files assigned to this worker
    :param output_queue: multiprocessing queue receiving an OrderedDict of
        parameter-name -> numpy array
    """
    model = Xtractor(args.class_number, args.dropout)
    model.load_state_dict(torch.load(initial_model_file_name))
    model.train()

    # Distinct seed per rank so randomness differs across workers.
    torch.manual_seed(args.seed + rank)
    train_loader = XvectorMultiDataset_hot(batch_list, args.batch_path)

    device = torch.device("cuda:{}".format(rank))
    model.to(device)

    # Separate parameter groups so frame-level and segment-level layers can
    # use different L2 (weight-decay) strengths.
    # optimizer = optim.Adam(model.parameters(), lr = args.lr)
    optimizer = optim.Adam([{'params': model.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv1.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv2.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv3.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv4.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.seg_lin0.parameters(), 'weight_decay': args.l2_seg},
                            {'params': model.seg_lin1.parameters(), 'weight_decay': args.l2_seg},
                            {'params': model.seg_lin2.parameters(), 'weight_decay': args.l2_seg}
                            ], lr=args.lr)

    # The loss is computed by model.LossFN (one-hot cross-entropy variant).
    #criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    #criterion = torch.nn.NLLLoss()
    #criterion = torch.nn.CrossEntropyLoss()

    accuracy = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data.to(device))
        loss = model.LossFN(output, target.to(device))
        loss.backward()
        optimizer.step()

        # Targets are one-hot, so compare argmax to argmax.
        accuracy += (torch.argmax(output.data, 1) == torch.argmax(target.to(device), 1)).sum()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                epoch, batch_idx + 1, train_loader.__len__(),
                       100. * batch_idx / train_loader.__len__(), loss.item(),
                       100.0 * accuracy.item() / ((batch_idx + 1) * args.batch_size)))

    # Ship the parameters back as CPU numpy arrays so the parent can average
    # them without sharing CUDA tensors across processes.
    model_param = OrderedDict()
    params = model.state_dict()

    for k in list(params.keys()):
        model_param[k] = params[k].cpu().detach().numpy()
    output_queue.put(model_param)


Anthony Larcher's avatar
Anthony Larcher committed
341
342
343
def train_asynchronous(epoch, args, initial_model_file_name, batch_file_list, megabatch_idx, megabatch_number):
    """Train one mega-batch across several worker processes and average the
    resulting models.

    Each of ``args.num_processes`` workers trains a copy of the current model
    on its own sub-list of batch files; their parameters are then averaged
    and the merged model is written to disk.

    :param epoch: current epoch index (forwarded to workers, used in names)
    :param args: configuration namespace
    :param initial_model_file_name: checkpoint the workers start from
    :param batch_file_list: batch files to distribute over the workers
    :param megabatch_idx: index of this mega-batch within the epoch
    :param megabatch_number: total number of mega-batches in the epoch
    :return: file name of the averaged model written to disk
    """
    # Split the list of files for each process
    sub_lists = split_file_list(batch_file_list, args.num_processes)

    #
    output_queue = mp.Queue()
    # output_queue = multiprocessing.Queue()

    processes = []
    for rank in range(args.num_processes):
        p = mp.Process(target=train_worker,
                       args=(rank, epoch, args, initial_model_file_name, sub_lists[rank], output_queue)
                       )
        # We first train the model across `num_processes` processes
        p.start()
        processes.append(p)

    # Drain the queue BEFORE joining: children may block on a full pipe if
    # the parent joins without consuming their results first.
    asynchronous_model = []
    for ii in range(args.num_processes):
        asynchronous_model.append(dict(output_queue.get()))

    for p in processes:
        p.join()

    # A fresh model provides a state_dict template that receives the averages.
    av_model = Xtractor(args.class_number, args.dropout)
    tmp = av_model.state_dict()

    # Sum each parameter over the workers, then divide by the worker count.
    # NOTE(review): the += accumulates in-place into the first worker's arrays.
    average_param = dict()
    for k in list(asynchronous_model[0].keys()):
        average_param[k] = asynchronous_model[0][k]

        for mod in asynchronous_model[1:]:
            average_param[k] += mod[k]

        # BatchNorm 'num_batches_tracked' counters are integer bookkeeping and
        # are kept from the template rather than averaged.
        if 'num_batches_tracked' not in k:
            tmp[k] = torch.FloatTensor(average_param[k] / len(asynchronous_model))

    # return the file name of the new model
    current_model_file_name = "{}/model_{}_epoch_{}_batch_{}".format(args.model_path, args.expe_id, epoch,
                                                                     megabatch_idx)
    torch.save(tmp, current_model_file_name)
    # Also save an epoch-level checkpoint after the last mega-batch.
    if megabatch_idx == megabatch_number:
        torch.save(tmp, "{}/model_{}_epoch_{}".format(args.model_path, args.expe_id, epoch))

    return current_model_file_name
Anthony Larcher's avatar
Anthony Larcher committed
387

Anthony Larcher's avatar
hot    
Anthony Larcher committed
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
def train_asynchronous_hot(epoch, args, initial_model_file_name, batch_file_list, megabatch_idx, megabatch_number):
    """One-hot-target variant of train_asynchronous(): same fan-out /
    average / save flow, but dispatches to train_worker_hot().

    :param epoch: current epoch index (forwarded to workers, used in names)
    :param args: configuration namespace
    :param initial_model_file_name: checkpoint the workers start from
    :param batch_file_list: batch files to distribute over the workers
    :param megabatch_idx: index of this mega-batch within the epoch
    :param megabatch_number: total number of mega-batches in the epoch
    :return: file name of the averaged model written to disk
    """
    # Split the list of files for each process
    sub_lists = split_file_list(batch_file_list, args.num_processes)

    #
    output_queue = mp.Queue()
    # output_queue = multiprocessing.Queue()

    processes = []
    for rank in range(args.num_processes):
        p = mp.Process(target=train_worker_hot,
                       args=(rank, epoch, args, initial_model_file_name, sub_lists[rank], output_queue)
                       )
        # We first train the model across `num_processes` processes
        p.start()
        processes.append(p)

    # Drain the queue BEFORE joining: children may block on a full pipe if
    # the parent joins without consuming their results first.
    asynchronous_model = []
    for ii in range(args.num_processes):
        asynchronous_model.append(dict(output_queue.get()))

    for p in processes:
        p.join()

    # A fresh model provides a state_dict template that receives the averages.
    av_model = Xtractor(args.class_number, args.dropout)
    tmp = av_model.state_dict()

    # Sum each parameter over the workers, then divide by the worker count.
    # NOTE(review): the += accumulates in-place into the first worker's arrays.
    average_param = dict()
    for k in list(asynchronous_model[0].keys()):
        average_param[k] = asynchronous_model[0][k]

        for mod in asynchronous_model[1:]:
            average_param[k] += mod[k]

        # BatchNorm 'num_batches_tracked' counters are integer bookkeeping and
        # are kept from the template rather than averaged.
        if 'num_batches_tracked' not in k:
            tmp[k] = torch.FloatTensor(average_param[k] / len(asynchronous_model))

    # return the file name of the new model
    current_model_file_name = "{}/model_{}_epoch_{}_batch_{}".format(args.model_path, args.expe_id, epoch,
                                                                     megabatch_idx)
    torch.save(tmp, current_model_file_name)
    # Also save an epoch-level checkpoint after the last mega-batch.
    if megabatch_idx == megabatch_number:
        torch.save(tmp, "{}/model_{}_epoch_{}".format(args.model_path, args.expe_id, epoch))

    return current_model_file_name

Anthony Larcher's avatar
Anthony Larcher committed
435
def cross_validation(args, current_model_file_name):
    """Evaluate a checkpoint on the cross-validation set, one worker per GPU.

    :param args: configuration namespace
    :param current_model_file_name: checkpoint file to evaluate
    :return: accuracy in percent over the whole cross-validation set
    """
    with open(args.cross_validation_list, 'r') as fh:
        validation_files = [line.rstrip() for line in fh]
        worker_lists = split_file_list(validation_files, args.num_processes)

    result_queue = mp.Queue()

    # Launch one evaluation worker per process.
    workers = []
    for rank in range(args.num_processes):
        proc = mp.Process(target=cv_worker,
                          args=(rank, args, current_model_file_name, worker_lists[rank], result_queue))
        proc.start()
        workers.append(proc)

    # Drain the queue before joining so children never block on a full pipe.
    partial_results = [result_queue.get() for _ in range(args.num_processes)]

    for proc in workers:
        proc.join()

    # Aggregate the per-worker (batch_count, correct_count) pairs.
    correct = 0.0
    batch_count = 0
    for bn, acc in partial_results:
        correct += acc
        batch_count += bn

    return 100. * correct / (batch_count * args.batch_size)
Anthony Larcher's avatar
Anthony Larcher committed
474
475


Anthony Larcher's avatar
hot    
Anthony Larcher committed
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
def cross_validation_hot(args, current_model_file_name):
    """One-hot-target variant of cross_validation(): identical fan-out and
    aggregation, but dispatches to cv_worker_hot().

    :param args: configuration namespace
    :param current_model_file_name: checkpoint file to evaluate
    :return: accuracy in percent over the whole cross-validation set
    """
    with open(args.cross_validation_list, 'r') as fh:
        validation_files = [line.rstrip() for line in fh]
        worker_lists = split_file_list(validation_files, args.num_processes)

    result_queue = mp.Queue()

    # Launch one evaluation worker per process.
    workers = []
    for rank in range(args.num_processes):
        proc = mp.Process(target=cv_worker_hot,
                          args=(rank, args, current_model_file_name, worker_lists[rank], result_queue))
        proc.start()
        workers.append(proc)

    # Drain the queue before joining so children never block on a full pipe.
    partial_results = [result_queue.get() for _ in range(args.num_processes)]

    for proc in workers:
        proc.join()

    # Aggregate the per-worker (batch_count, correct_count) pairs.
    correct = 0.0
    batch_count = 0
    for bn, acc in partial_results:
        correct += acc
        batch_count += bn

    return 100. * correct / (batch_count * args.batch_size)


Anthony Larcher's avatar
Anthony Larcher committed
517
def cv_worker(rank, args, current_model_file_name, batch_list, output_queue):
    """Cross-validation worker: score its batches on GPU ``cuda:rank`` and
    report (number of batches, number of correct predictions) via the queue.

    :param rank: worker index, selects the CUDA device
    :param args: configuration namespace
    :param current_model_file_name: checkpoint file to evaluate
    :param batch_list: batch files assigned to this worker
    :param output_queue: multiprocessing queue receiving the result tuple
    """
    net = Xtractor(args.class_number, args.dropout)
    net.load_state_dict(torch.load(current_model_file_name))
    net.eval()

    loader = XvectorMultiDataset(batch_list, args.batch_path)

    device = torch.device("cuda:{}".format(rank))
    net.to(device)

    # Count correct predictions over all batches (becomes a tensor after the
    # first += since the comparison sum is a tensor).
    correct = 0.0
    for data, target in loader:
        prediction = net(data.to(device))
        correct += (torch.argmax(prediction.data, 1) == target.to(device)).sum()

    output_queue.put((len(loader), correct.cpu().numpy()))
Anthony Larcher's avatar
Anthony Larcher committed
532

Anthony Larcher's avatar
hot    
Anthony Larcher committed
533
534
535
536
537
538
539
540
541
542
543
544
545
def cv_worker_hot(rank, args, current_model_file_name, batch_list, output_queue):
    """One-hot-target cross-validation worker on GPU ``cuda:rank``: reports
    (number of batches, number of correct predictions) via the queue.

    :param rank: worker index, selects the CUDA device
    :param args: configuration namespace
    :param current_model_file_name: checkpoint file to evaluate
    :param batch_list: batch files assigned to this worker
    :param output_queue: multiprocessing queue receiving the result tuple
    """
    net = Xtractor(args.class_number, args.dropout)
    net.load_state_dict(torch.load(current_model_file_name))
    net.eval()

    loader = XvectorMultiDataset_hot(batch_list, args.batch_path)

    device = torch.device("cuda:{}".format(rank))
    net.to(device)

    # Targets are one-hot matrices, so compare argmax to argmax.
    correct = 0.0
    for data, target in loader:
        prediction = net(data.to(device))
        correct += (torch.argmax(prediction.data, 1) == torch.argmax(target.to(device), 1)).sum()

    output_queue.put((len(loader), correct.cpu().numpy()))


Anthony Larcher's avatar
Anthony Larcher committed
552
def extract_idmap(args, device_ID, segment_indices, fs_params, idmap_name, output_queue):
    """
    Function that takes a model and an idmap and extract all x-vectors based on this model
    and return a StatServer containing the x-vectors

    :param args: configuration namespace (model_path, model_name, class_number, dropout)
    :param device_ID: CUDA device index this worker runs on
    :param segment_indices: indices of the idmap rows this worker processes
    :param fs_params: feature-server parameters forwarded to StatDataset
    :param idmap_name: file name of the full IdMap to subset
    :param output_queue: multiprocessing queue receiving
        (segment_indices, emb_A, emb_B)
    """
    device = torch.device("cuda:{}".format(device_ID))

    # Create the dataset: subset the full idmap to this worker's segments.
    tmp_idmap = IdMap(idmap_name)
    idmap = IdMap()
    idmap.leftids = tmp_idmap.leftids[segment_indices]
    idmap.rightids = tmp_idmap.rightids[segment_indices]
    idmap.start = tmp_idmap.start[segment_indices]
    idmap.stop = tmp_idmap.stop[segment_indices]

    segment_loader = StatDataset(idmap, fs_params)

    # Load the model
    model_file_name = '/'.join([args.model_path, args.model_name])
    model = Xtractor(args.class_number, args.dropout)
    model.load_state_dict(torch.load(model_file_name))
    model.eval()

    # Get the size of embeddings from the segment-level layers' output dims.
    emb_a_size = model.seg_lin0.weight.data.shape[0]
    emb_b_size = model.seg_lin1.weight.data.shape[0]

    # Pre-allocate one row per segment for each embedding (on the CPU).
    emb_A = numpy.zeros((idmap.leftids.shape[0], emb_a_size)).astype(numpy.float32)
    emb_B = numpy.zeros((idmap.leftids.shape[0], emb_b_size)).astype(numpy.float32)

    # Send on selected device
    model.to(device)

    # Loop to extract all x-vectors.
    # NOTE(review): segments shorter than 20 frames are skipped (too short for
    # the convolution stack) and their rows stay zero -- confirm downstream
    # code tolerates zero vectors.
    for idx, (model_id, segment_id, data) in enumerate(segment_loader):
        print('Extract X-vector for {}\t[{} / {}]'.format(segment_id, idx, segment_loader.__len__()))
        print("shape of data = {}".format(list(data.shape)))
        print("shape[2] = {}".format(list(data.shape)[2]))
        if list(data.shape)[2] < 20:
            pass
        else:
            A, B = model.extract(data.to(device))
            emb_A[idx, :] = A.detach().cpu()
            emb_B[idx, :] = B.detach().cpu()

    output_queue.put((segment_indices, emb_A, emb_B))


def extract_parallel(args, fs_params, dataset):
    """Extract x-vectors for a whole dataset in parallel and return them as
    two StatServers (one per embedding layer).

    :param args: configuration namespace (idmap names, num_processes, ...)
    :param fs_params: feature-server parameters forwarded to the workers
    :param dataset: one of 'enroll', 'test' or 'back'; selects the idmap
    :return: tuple (x_server_A, x_server_B) of StatServers filled with the
        embeddings of seg_lin0 and seg_lin1 respectively
    """
    emb_a_size = 512
    emb_b_size = 512

    if dataset == 'enroll':
        idmap_name = args.enroll_idmap
    elif dataset == 'test':
        idmap_name = args.test_idmap
    elif dataset == 'back':
        idmap_name = args.back_idmap

    idmap = IdMap(idmap_name)
    x_server_A = StatServer(idmap, 1, emb_a_size)
    x_server_B = StatServer(idmap, 1, emb_b_size)
    x_server_A.stat0 = numpy.ones(x_server_A.stat0.shape)
    x_server_B.stat0 = numpy.ones(x_server_B.stat0.shape)

    # Split the segment indices into disjoint, near-equal chunks, one per
    # process.  Fixes the original code, which used numpy.max instead of
    # numpy.min so that every chunk ran from its start to the END of the
    # idmap: segments were extracted several times and the remainder
    # handling was broken.
    segment_idx = numpy.array_split(numpy.arange(idmap.leftids.shape[0]),
                                    args.num_processes)

    # Extract x-vectors in parallel
    output_queue = mp.Queue()

    processes = []
    for rank in range(args.num_processes):
        p = mp.Process(target=extract_idmap,
                       args=(args, rank, segment_idx[rank], fs_params, idmap_name, output_queue)
                       )
        # We first train the model across `num_processes` processes
        p.start()
        processes.append(p)

    # Get the x-vectors and fill the StatServer (drain the queue before join
    # so children never block on a full pipe).
    for ii in range(args.num_processes):
        indices, A, B = output_queue.get()
        x_server_A.stat1[indices, :] = A
        x_server_B.stat1[indices, :] = B

    for p in processes:
        p.join()

    return x_server_A, x_server_B