seqtoseq.py 17.8 KB
Newer Older
Anthony Larcher's avatar
Anthony Larcher committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# -*- coding: utf-8 -*-
#
# This file is part of s4d.
#
# s4d is a python package for speaker diarization.
# Home page: http://www-lium.univ-lemans.fr/s4d/
#
# s4d is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# s4d is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with s4d.  If not, see <http://www.gnu.org/licenses/>.


"""
Copyright 2014-2020 Anthony Larcher
"""

import os
import sys
Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
28
29
import logging
import pandas
Anthony Larcher's avatar
Anthony Larcher committed
30
import numpy
Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
31
from collections import OrderedDict
Anthony Larcher's avatar
Anthony Larcher committed
32
33
import random
import h5py
Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
34
import shutil
Anthony Larcher's avatar
Anthony Larcher committed
35
36
import torch
import torch.nn as nn
37
import yaml
Martin Lebourdais's avatar
Martin Lebourdais committed
38
from sklearn.model_selection import train_test_split
Anthony Larcher's avatar
Anthony Larcher committed
39
40
41
from torch import optim
from torch.utils.data import Dataset

Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
42
from .wavsets import SeqSet
43
from sidekit.nnet.sincnet import SincNet
Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
44
from torch.utils.data import DataLoader
Anthony Larcher's avatar
Anthony Larcher committed
45
46
47
48
49
50
51
52
53

__license__ = "LGPL"
__author__ = "Anthony Larcher"
__copyright__ = "Copyright 2015-2020 Anthony Larcher"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reS'

Anthony Larcher's avatar
Anthony Larcher committed
54

Martin Lebourdais's avatar
Martin Lebourdais committed
55
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    """
    Write a training checkpoint to disk and keep a copy of the best one.

    :param state: picklable object (typically a dict) describing the current training state
    :param is_best: True when the current model is the best seen so far
    :param filename: path of the checkpoint file written at every call
    :param best_filename: path of the file holding the best checkpoint
    :return: None
    """
    # Always persist the latest state; duplicate it when it is the best so far.
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, best_filename)
Anthony Larcher's avatar
Anthony Larcher committed
68
69
70


class BLSTM(nn.Module):
    """
    Bi-directional LSTM model used for voice activity detection or speaker turn detection.

    The network is a single torch.nn.LSTM with two stacked bidirectional layers;
    the output dimension is twice the hidden size because both directions are
    concatenated.
    """

    def __init__(self,
                 input_size,
                 blstm_sizes):
        """
        :param input_size: dimension of the input features
        :param blstm_sizes: list of hidden sizes; only the first entry is used,
            both stacked LSTM layers share this hidden size
        """
        super(BLSTM, self).__init__()
        self.input_size = input_size
        self.blstm_sizes = blstm_sizes
        # Forward and backward directions are concatenated, hence the factor 2.
        self.output_size = blstm_sizes[0] * 2
        self.blstm_layers = nn.LSTM(input_size,
                                    blstm_sizes[0],
                                    bidirectional=True,
                                    batch_first=True,
                                    num_layers=2)
        # Kept for API compatibility with earlier multi-layer versions; the
        # hidden state is not carried over between calls to forward().
        self.hidden = None

    def forward(self, inputs):
        """
        Run the BLSTM over a batch of sequences.

        :param inputs: tensor of shape (batch, time, input_size)
        :return: output sequence, tensor of shape (batch, time, output_size)
        """
        # Hidden/cell states use their default zero initialization.  The dead
        # code that pre-built a per-layer ``hiddens`` list (a leftover from the
        # former stacked-LSTM implementation) has been removed, as has the
        # ``output_size`` method that was unreachable: the instance attribute
        # of the same name, set in __init__, always shadowed it.
        output, _ = self.blstm_layers(inputs)
        return output

Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
130

Anthony Larcher's avatar
Anthony Larcher committed
131
class SeqToSeq(nn.Module):
    """
    Model used for voice activity detection or speaker turn detection.

    The architecture is read from a YAML file and may chain three stages:
    an optional pre-processor applied to raw waveform (SincNet),
    a BLSTM module processing the sequence-to-sequence part,
    and a post-processing stack of linear / activation / batch-norm / dropout layers.
    """

    def __init__(self,
                 model_archi):
        """
        :param model_archi: path to the YAML file describing the architecture
        """
        super(SeqToSeq, self).__init__()

        # Load Yaml configuration
        with open(model_archi, 'r') as fh:
            cfg = yaml.load(fh, Loader=yaml.FullLoader)

        self.loss = cfg["loss"]
        self.feature_size = None

        # Optional raw-waveform pre-processor; when present it determines the
        # feature size fed to the sequence-to-sequence stage.
        self.preprocessor = None
        if "preprocessor" in cfg:
            if cfg['preprocessor']["type"] == "sincnet":
                pp_cfg = cfg['preprocessor']
                self.preprocessor = SincNet(
                    waveform_normalize=pp_cfg["waveform_normalize"],
                    sample_rate=pp_cfg["sample_rate"],
                    min_low_hz=pp_cfg["min_low_hz"],
                    min_band_hz=pp_cfg["min_band_hz"],
                    out_channels=pp_cfg["out_channels"],
                    kernel_size=pp_cfg["kernel_size"],
                    stride=pp_cfg["stride"],
                    max_pool=pp_cfg["max_pool"],
                    instance_normalize=pp_cfg["instance_normalize"],
                    activation=pp_cfg["activation"],
                    dropout=pp_cfg["dropout"]
                )
                self.feature_size = self.preprocessor.dimension

        # Sequence-to-sequence stage: without a pre-processor, the feature
        # size comes from the configuration file.
        if self.feature_size is None:
            self.feature_size = cfg["feature_size"]

        input_size = self.feature_size
        self.sequence_to_sequence = BLSTM(input_size=input_size,
                                          blstm_sizes=cfg["sequence_to_sequence"]["blstm_sizes"])
        input_size = self.sequence_to_sequence.output_size

        # Post-processing stack, assembled in the order the keys appear in
        # the configuration; layer kind is selected by key prefix.
        self.post_processing_activation = torch.nn.Tanh()
        post_processing_layers = []
        for key in cfg["post_processing"].keys():

            if key.startswith("lin"):
                layer = torch.nn.Linear(input_size, cfg["post_processing"][key]["output"])
                post_processing_layers.append((key, layer))
                input_size = cfg["post_processing"][key]["output"]

            elif key.startswith("activation"):
                post_processing_layers.append((key, self.post_processing_activation))

            elif key.startswith('batch_norm'):
                post_processing_layers.append((key, torch.nn.BatchNorm1d(input_size)))

            elif key.startswith('dropout'):
                post_processing_layers.append((key, torch.nn.Dropout(p=cfg["post_processing"][key])))

        self.post_processing = torch.nn.Sequential(OrderedDict(post_processing_layers))
        #self.before_speaker_embedding_weight_decay = cfg["post_processing"]["weight_decay"]

    def forward(self, inputs):
        """
        :param inputs: raw waveform when a pre-processor is configured,
            acoustic features otherwise
        :return: network output sequence
        """
        if self.preprocessor is not None:
            x = self.preprocessor(inputs)
        else:
            x = inputs
        x = self.sequence_to_sequence(x)
        return self.post_processing(x)


Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
225
def seqTrain(dataset_yaml,
             val_dataset_yaml,
             norm_dataset_yaml,
             over_dataset_yaml,
             model_yaml,
             mode,
             epochs=100,
             lr=0.0001,
             patience=10,
             model_name=None,
             tmp_model_name=None,
             best_model_name=None,
             multi_gpu=True,
             opt='sgd',
             filter_type="gate",
             collar_duration=0.1,
             framerate=16000,
             output_rate=100,
             batch_size=32,
             log_interval=10,
             num_thread=10
             ):
    """
    Train a SeqToSeq model for voice activity / overlap / speaker turn detection.

    :param dataset_yaml: YAML file with global dataset parameters (wav/mdtm
        directories, batch size, seed, validation ratio, dataset description)
    :param val_dataset_yaml: YAML file describing the validation set
    :param norm_dataset_yaml: YAML file describing the regular training set
    :param over_dataset_yaml: YAML file describing the noise-augmented
        (overlap) training set
    :param model_yaml: YAML file describing the model architecture
    :param mode: task mode, forwarded to the SeqSet datasets
    :param epochs: maximum number of training epochs
    :param lr: initial learning rate
    :param patience: number of epochs without improvement before early stopping
    :param model_name: checkpoint to resume from; None to start from scratch
    :param tmp_model_name: basename of the checkpoint written at every epoch
    :param best_model_name: basename of the best checkpoint kept so far
    :param multi_gpu: use DataParallel when several GPUs are available
    :param opt: optimizer name, one of 'sgd', 'adam', 'rmsprop'
    :param filter_type: label filtering method forwarded to the datasets
    :param collar_duration: collar (in seconds) forwarded to the datasets
    :param framerate: audio sample rate in Hz
    :param output_rate: label frame rate in Hz
    :param batch_size: unused; the batch size is read from dataset_yaml
    :param log_interval: number of batches between two training log lines
    :param num_thread: number of DataLoader workers
    :return: None
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Start from scratch
    if model_name is None:
        model = SeqToSeq(model_yaml)
    # If we start from an existing model
    else:
        # Load the model
        logging.critical(f"*** Load model from = {model_name}")
        checkpoint = torch.load(model_name)
        model = SeqToSeq(model_yaml)
        # BUG FIX: the checkpoint was loaded but its weights were never
        # applied, so training silently restarted from scratch.  The key
        # matches the one written by save_checkpoint() below.
        model.load_state_dict(checkpoint['model_state_dict'])

    if torch.cuda.device_count() > 1 and multi_gpu:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)
    else:
        print("Train on a single GPU")
    model.to(device)

    """
    Create two dataloaders for training and evaluation
    """
    with open(dataset_yaml, "r") as fh:
        dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
        df = pandas.read_csv(dataset_params["dataset_description"])
    # NOTE(review): this split is never used below; it is kept because
    # train_test_split consumes the global numpy RNG and dropping it would
    # change the random state seen downstream — confirm before removing.
    training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
    _wav_dir = dataset_params['wav_dir']
    _mdtm_dir = dataset_params['mdtm_dir']

    torch.manual_seed(dataset_params['seed'])

    # Regular training examples (MFCC features only).  filter_type and
    # collar_duration are now forwarded instead of being shadowed by
    # hard-coded literals (the defaults are identical, so behavior is
    # unchanged for existing callers).
    training_set_norm = SeqSet(norm_dataset_yaml,
                               wav_dir=_wav_dir,
                               mdtm_dir=_mdtm_dir,
                               mode=mode,
                               duration=2.,
                               filter_type=filter_type,
                               collar_duration=collar_duration,
                               audio_framerate=framerate,
                               output_framerate=output_rate,
                               transform_pipeline="MFCC")

    # Noise-augmented training examples.
    training_set_overlap = SeqSet(over_dataset_yaml,
                                  wav_dir=_wav_dir,
                                  mdtm_dir=_mdtm_dir,
                                  mode=mode,
                                  duration=2.,
                                  filter_type=filter_type,
                                  collar_duration=collar_duration,
                                  audio_framerate=framerate,
                                  output_framerate=output_rate,
                                  transform_pipeline="add_noise,MFCC")
    training_set = torch.utils.data.ConcatDataset([training_set_norm, training_set_overlap])
    training_loader = DataLoader(training_set,
                                 batch_size=dataset_params["batch_size"],
                                 drop_last=True,
                                 shuffle=True,
                                 pin_memory=True,
                                 num_workers=num_thread)

    validation_set = SeqSet(val_dataset_yaml,
                            wav_dir=_wav_dir,
                            mdtm_dir=_mdtm_dir,
                            mode=mode,
                            duration=2.,
                            filter_type=filter_type,
                            collar_duration=collar_duration,
                            audio_framerate=framerate,
                            output_framerate=output_rate,
                            set_type="validation",
                            transform_pipeline="MFCC")

    validation_loader = DataLoader(validation_set,
                                   batch_size=dataset_params["batch_size"],
                                   drop_last=True,
                                   shuffle=True,
                                   pin_memory=True,
                                   num_workers=num_thread)

    """
    Set the training options
    """
    if opt == 'sgd':
        _optimizer = torch.optim.SGD
        _options = {'lr': lr, 'momentum': 0.9}
    elif opt == 'adam':
        _optimizer = torch.optim.Adam
        _options = {'lr': lr}
    elif opt == 'rmsprop':
        _optimizer = torch.optim.RMSprop
        _options = {'lr': lr}
    else:
        # Fail early with a clear message instead of a NameError below.
        raise ValueError(f"Unknown optimizer: {opt}")

    optimizer = _optimizer([{'params': model.parameters()}, ], **_options)

    # Reduce the learning rate when the validation loss stops improving.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)

    best_fmes = 0.0
    best_fmes_epoch = 1
    curr_patience = patience
    for epoch in range(1, epochs + 1):
        # Process one epoch and return the current model
        if curr_patience == 0:
            print(f"Stopping at epoch {epoch} for cause of patience")
            break
        model = train_epoch(model,
                            epoch,
                            training_loader,
                            optimizer,
                            log_interval,
                            device=device)

        # Cross validation here
        fmes, val_loss = cross_validation(model, validation_loader, device=device)
        logging.critical("*** Validation f-Measure = {}".format(fmes))

        # Decrease learning rate according to the scheduler policy
        scheduler.step(val_loss)
        print(f"Learning rate is {optimizer.param_groups[0]['lr']}")

        # Remember best f-measure and save checkpoint
        is_best = fmes > best_fmes
        best_fmes = max(fmes, best_fmes)

        # A DataParallel wrapper keeps the real model in .module; the two
        # formerly duplicated save branches are merged here.
        if type(model) is SeqToSeq:
            state_dict = model.state_dict()
        else:
            state_dict = model.module.state_dict()
        save_checkpoint({
            'epoch': epoch,
            'model_state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'accuracy': best_fmes,
            'scheduler': scheduler
        }, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')

        if is_best:
            best_fmes_epoch = epoch
            curr_patience = patience
        else:
            curr_patience -= 1

    logging.critical(f"Best F-Mesure {best_fmes * 100.} obtained at epoch {best_fmes_epoch}")
Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
438

Martin Lebourdais's avatar
Martin Lebourdais committed
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
def calc_recall(output, target, device):
    """
    Accumulate recall and precision of binary frame labels over a batch.

    :param output: model scores, tensor of shape (time, n_classes, batch);
        the predicted class per frame is the argmax over the class axis
    :param target: binary reference labels, tensor of shape (time, batch)
    :param device: torch device on which the comparison is performed
    :return: tuple (recall_sum, precision_sum) — each is the SUM of the
        per-sequence values over the batch, not an average
    """
    y_trueb = target.to(device)
    y_predb = output
    rc = 0.0
    pr = 0.0
    # BUG FIX: iterate over the actual batch dimension instead of a
    # hard-coded range(64); the real batch size comes from the dataset
    # configuration and may differ from 64.
    for b in range(y_trueb.shape[1]):
        y_true = y_trueb[:, b]
        y_pred = y_predb[:, :, b]
        assert y_true.ndim == 1
        assert y_pred.ndim == 1 or y_pred.ndim == 2

        if y_pred.ndim == 2:
            y_pred = y_pred.argmax(dim=1)

        # Confusion counts for this sequence (true negatives are not needed
        # for precision/recall, so they are no longer computed).
        tp = (y_true * y_pred).sum().to(torch.float32)
        fp = ((1 - y_true) * y_pred).sum().to(torch.float32)
        fn = (y_true * (1 - y_pred)).sum().to(torch.float32)

        # epsilon avoids division by zero on degenerate sequences.
        epsilon = 1e-7

        precision = tp / (tp + fp + epsilon)
        recall = tp / (tp + fn + epsilon)
        rc += recall
        pr += precision
    return rc, pr
Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
466
467
468
469
470
471
472
473
474
475
476
477
def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
    """
    Train the model for one epoch and return it.

    :param model: the SeqToSeq (possibly DataParallel-wrapped) model to train
    :param epoch: index of the current epoch, used for logging only
    :param training_loader: DataLoader yielding (data, target) batches
    :param optimizer: optimizer stepped after each batch
    :param log_interval: number of batches between two log lines
    :param device: torch device on which computation is run
    :return: the trained model
    """
    model.to(device)
    model.train()
    # Frame-wise classification loss averaged over all frames of the batch.
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    recall = 0.0
    accuracy = 0.0
    for batch_idx, (data, target) in enumerate(training_loader):

        target = target.squeeze()
        optimizer.zero_grad()
        data = data.to(device)

        output = model(data)
        # Reorder to (time, classes, batch) / (time, batch), the layout
        # expected by CrossEntropyLoss and by calc_recall below.
        output = output.permute(1, 2, 0)
        target = target.permute(1, 0)
        #print(output.shape)
        #print(torch.argmax(output[:,:,0],1))
        #print(target[:,0])
        loss = criterion(output, target.to(device))
        # NOTE(review): retain_graph=True looks unnecessary for a single
        # backward pass per batch — confirm before removing.
        loss.backward(retain_graph=True)
        optimizer.step()

        # `accuracy` actually accumulates precision; both metrics are summed
        # over batches and normalized only when logged.
        rc,pr = calc_recall(output.data,target,device)
        accuracy += pr
        recall += rc
        if batch_idx % log_interval == 0:
            # NOTE(review): after the permute above, target.shape[0] is the
            # time dimension, not the batch size — confirm the intended
            # normalization.  The hard-coded 198 presumably is the number of
            # output frames per sequence; verify against the dataset config.
            batch_size = target.shape[0]
            # print(100.0 * accuracy.item() / ((batch_idx + 1) * batch_size * 198))
            logging.critical('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}\tRecall: {:.3f}'.format(
                epoch, batch_idx + 1, training_loader.__len__(),
                100. * batch_idx / training_loader.__len__(), loss.item(),
                100.0 * accuracy.item() / ((batch_idx + 1) * batch_size * 198),
                100.0 * recall.item() / ((batch_idx+1) *batch_size * 198)
            ))

    return model


def cross_validation(model, validation_loader, device):
    """
    Evaluate the model on the validation set.

    :param model: the model to evaluate
    :param validation_loader: DataLoader yielding (data, target) validation batches
    :param device: torch device on which computation is run
    :return: tuple (f-measure, validation loss), both normalized by the
        number of processed samples
    """
    model.eval()
    recall = 0.0
    accuracy = 0.0
    loss = 0.0
    criterion = torch.nn.CrossEntropyLoss()
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(validation_loader):
            batch_size = target.shape[0]
            target = target.squeeze()
            output = model(data.to(device))
            # Reorder to (time, classes, batch) / (time, batch), the layout
            # expected by CrossEntropyLoss and by calc_recall.
            output = output.permute(1, 2, 0)
            target = target.permute(1, 0)
            # NOTE(review): nbpoint is never used — candidate for removal.
            nbpoint = output.shape[0]

            # `accuracy` accumulates precision; both are summed over batches.
            rc,pr = calc_recall(output.data,target,device)
            accuracy+= pr
            recall += rc

            loss += criterion(output, target.to(device))
        # Harmonic mean of the accumulated precision and recall.
        fmes = 2*(accuracy*recall)/(recall+accuracy)
    # batch_idx and batch_size keep their values from the last iteration;
    # the validation loader is built with drop_last=True, so every batch
    # has the same size.
    return fmes / ((batch_idx + 1) * batch_size), \
           loss.cpu().numpy() / ((batch_idx + 1) * batch_size)