seqtoseq.py 18.8 KB
Newer Older
Anthony Larcher's avatar
Anthony Larcher committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# -*- coding: utf-8 -*-
#
# This file is part of s4d.
#
# s4d is a python package for speaker diarization.
# Home page: http://www-lium.univ-lemans.fr/s4d/
#
# s4d is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# s4d is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with s4d.  If not, see <http://www.gnu.org/licenses/>.


"""
Copyright 2014-2020 Anthony Larcher
"""

import os
import sys
Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
28
29
import logging
import pandas
Anthony Larcher's avatar
Anthony Larcher committed
30
import numpy
Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
31
from collections import OrderedDict
Anthony Larcher's avatar
Anthony Larcher committed
32
33
import random
import h5py
Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
34
import shutil
Anthony Larcher's avatar
Anthony Larcher committed
35
36
import torch
import torch.nn as nn
37
import yaml
Martin Lebourdais's avatar
Martin Lebourdais committed
38
from sklearn.model_selection import train_test_split
Anthony Larcher's avatar
Anthony Larcher committed
39
40
41
from torch import optim
from torch.utils.data import Dataset

Martin Lebourdais's avatar
Martin Lebourdais committed
42
from .loss import ConcordanceCorCoeff
Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
43
from .wavsets import SeqSet
44
from sidekit.nnet.sincnet import SincNet
Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
45
from torch.utils.data import DataLoader
Anthony Larcher's avatar
Anthony Larcher committed
46
47
48
49
50
51
52
53
54

__license__ = "LGPL"
__author__ = "Anthony Larcher"
__copyright__ = "Copyright 2015-2020 Anthony Larcher"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reS'

Anthony Larcher's avatar
Anthony Larcher committed
55

Martin Lebourdais's avatar
Martin Lebourdais committed
56
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    """
    Serialize a training checkpoint and keep a copy of the best one.

    :param state: picklable object (typically a dict) holding the training state
    :param is_best: when True, the freshly written checkpoint is also
        duplicated as the best-model file
    :param filename: destination path of the rolling checkpoint
    :param best_filename: destination path of the best-model copy
    :return: None
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, best_filename)
Anthony Larcher's avatar
Anthony Larcher committed
69
70
71


class BLSTM(nn.Module):
    """
    Bi-directional LSTM model used for voice activity detection or speaker
    turn detection.

    Two stacked bidirectional LSTM layers; the output dimension is twice the
    hidden size because forward and backward states are concatenated.

    Fixes over the previous version: the class docstring was a stray string
    statement after ``__init__``; ``forward`` carried dead hidden-state
    bookkeeping that was never used; the ``output_size`` method was shadowed
    by the instance attribute of the same name set in ``__init__`` and was
    therefore unreachable — the attribute is kept, the dead method removed.
    """

    def __init__(self,
                 input_size,
                 blstm_sizes):
        """
        :param input_size: dimension of the input feature vectors
        :param blstm_sizes: list of hidden sizes; only the first entry is
            used to size the two stacked bidirectional LSTM layers
        """
        super(BLSTM, self).__init__()
        self.input_size = input_size
        self.blstm_sizes = blstm_sizes
        # forward + backward hidden states are concatenated at each time step
        self.output_size = blstm_sizes[0] * 2
        self.blstm_layers = nn.LSTM(input_size,
                                    blstm_sizes[0],
                                    bidirectional=True,
                                    batch_first=True,
                                    num_layers=2)
        # Kept for API compatibility; the hidden state is never carried over
        # between calls to forward() (it always starts from zeros).
        self.hidden = None

    def forward(self, inputs):
        """
        Process a batch of sequences.

        :param inputs: tensor of shape (batch, seq_len, input_size)
        :return: tensor of shape (batch, seq_len, output_size)
        """
        output, _ = self.blstm_layers(inputs)
        return output

Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
131

Anthony Larcher's avatar
Anthony Larcher committed
132
class SeqToSeq(nn.Module):
    """
    Sequence-to-sequence model for voice activity detection or speaker
    turn detection.

    The architecture is read from a YAML file and chains up to three stages:
    an optional SincNet pre-processor operating on raw waveform, a BLSTM
    sequence-to-sequence module, and a post-processing stack of linear,
    activation, batch-norm and dropout layers.
    """
    def __init__(self,
                 model_archi):
        """
        :param model_archi: path to the YAML file describing the architecture
        """
        super(SeqToSeq, self).__init__()

        # Load the YAML configuration describing the whole network
        with open(model_archi, 'r') as fh:
            cfg = yaml.load(fh, Loader=yaml.FullLoader)

        self.loss = cfg["loss"]
        self.feature_size = None

        # Optional raw-waveform pre-processor
        self.preprocessor = None
        if "preprocessor" in cfg:
            if cfg['preprocessor']["type"] == "sincnet":
                sinc_cfg = cfg['preprocessor']
                self.preprocessor = SincNet(
                    waveform_normalize=sinc_cfg["waveform_normalize"],
                    sample_rate=sinc_cfg["sample_rate"],
                    min_low_hz=sinc_cfg["min_low_hz"],
                    min_band_hz=sinc_cfg["min_band_hz"],
                    out_channels=sinc_cfg["out_channels"],
                    kernel_size=sinc_cfg["kernel_size"],
                    stride=sinc_cfg["stride"],
                    max_pool=sinc_cfg["max_pool"],
                    instance_normalize=sinc_cfg["instance_normalize"],
                    activation=sinc_cfg["activation"],
                    dropout=sinc_cfg["dropout"]
                )
                self.feature_size = self.preprocessor.dimension

        # Sequence-to-sequence network; when no pre-processor is configured
        # the feature size comes straight from the configuration file
        if self.feature_size is None:
            self.feature_size = cfg["feature_size"]

        input_size = self.feature_size
        self.sequence_to_sequence = BLSTM(input_size=input_size,
                                          blstm_sizes=cfg["sequence_to_sequence"]["blstm_sizes"])
        input_size = self.sequence_to_sequence.output_size

        # Post-processing network, assembled in the order the keys appear
        # in the configuration; key prefixes select the layer type.
        self.post_processing_activation = torch.nn.Tanh()
        pp_cfg = cfg["post_processing"]
        post_processing_layers = []
        for k in pp_cfg.keys():
            if k.startswith("lin"):
                post_processing_layers.append((k, torch.nn.Linear(input_size,
                                                                  pp_cfg[k]["output"])))
                # the next layer consumes this layer's output dimension
                input_size = pp_cfg[k]["output"]
            elif k.startswith("activation"):
                post_processing_layers.append((k, self.post_processing_activation))
            elif k.startswith('batch_norm'):
                post_processing_layers.append((k, torch.nn.BatchNorm1d(input_size)))
            elif k.startswith('dropout'):
                # for dropout the YAML value is the probability itself
                post_processing_layers.append((k, torch.nn.Dropout(p=pp_cfg[k])))

        self.post_processing = torch.nn.Sequential(OrderedDict(post_processing_layers))

    def forward(self, inputs):
        """
        Run the full pipeline on a batch.

        :param inputs: raw waveform (when a pre-processor is configured) or
            pre-computed feature sequences
        :return: frame-level network output
        """
        x = self.preprocessor(inputs) if self.preprocessor is not None else inputs
        x = self.sequence_to_sequence(x)
        return self.post_processing(x)


Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
226
def seqTrain(dataset_yaml,
             val_dataset_yaml,
             model_yaml,
             mode,
             epochs=100,
             lr=0.0001,
             patience=10,
             model_name=None,
             tmp_model_name=None,
             best_model_name=None,
             multi_gpu=True,
             opt='sgd',
             filter_type="gate",
             collar_duration=0.1,
             framerate=16000,
             output_rate=100,
             batch_size=32,
             log_interval=10,
             num_thread=10,
             non_overlap_dataset=None,
             overlap_dataset=None
             ):
    """
    Train a SeqToSeq model for voice activity or speaker turn detection.

    :param dataset_yaml: YAML description of the training dataset
    :param val_dataset_yaml: YAML description of the validation dataset
    :param model_yaml: YAML description of the model architecture
    :param mode: task mode forwarded to the SeqSet datasets
    :param epochs: maximum number of training epochs
    :param lr: initial learning rate
    :param patience: number of epochs without improvement before stopping
    :param model_name: checkpoint to restart from; None trains from scratch
    :param tmp_model_name: basename of the rolling checkpoint file
    :param best_model_name: basename of the best-model checkpoint file
    :param multi_gpu: use DataParallel when several GPUs are available
    :param opt: optimizer name: 'sgd', 'adam' or 'rmsprop'
    :param filter_type: label filtering type forwarded to the datasets
    :param collar_duration: collar duration (s) forwarded to the datasets
    :param framerate: audio sampling rate in Hz
    :param output_rate: label rate of the network output (frames per second)
    :param batch_size: unused; the batch size is read from the dataset YAML
    :param log_interval: number of batches between two training log lines
    :param num_thread: number of DataLoader worker processes
    :param non_overlap_dataset: unused, kept for interface compatibility
    :param overlap_dataset: unused, kept for interface compatibility
    :return: None
    :raises ValueError: if ``opt`` is not a known optimizer name
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = SeqToSeq(model_yaml)
    if model_name is not None:
        # Restart from an existing checkpoint. Previously the checkpoint was
        # loaded but never applied to the model; restore its weights here.
        logging.critical(f"*** Load model from = {model_name}")
        checkpoint = torch.load(model_name)
        model.load_state_dict(checkpoint['model_state_dict'])

    if torch.cuda.device_count() > 1 and multi_gpu:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)
    else:
        print("Train on a single GPU")
    model.to(device)

    # Create two dataloaders for training and evaluation
    with open(dataset_yaml, "r") as fh:
        dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
        df = pandas.read_csv(dataset_params["dataset_description"])
    training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
    _wav_dir = dataset_params['wav_dir']
    _mdtm_dir = dataset_params['mdtm_dir']
    torch.manual_seed(dataset_params['seed'])

    # filter_type and collar_duration are now forwarded to the datasets
    # instead of being silently replaced by hard-coded "gate" / 0.1
    training_set = SeqSet(dataset_yaml,
                          wav_dir=_wav_dir,
                          mdtm_dir=_mdtm_dir,
                          mode=mode,
                          duration=2.,
                          filter_type=filter_type,
                          collar_duration=collar_duration,
                          audio_framerate=framerate,
                          output_framerate=output_rate,
                          transform_pipeline="MFCC")

    training_loader = DataLoader(training_set,
                                 batch_size=dataset_params["batch_size"],
                                 drop_last=True,
                                 shuffle=True,
                                 pin_memory=True,
                                 num_workers=num_thread)

    validation_set = SeqSet(val_dataset_yaml,
                            wav_dir=_wav_dir,
                            mdtm_dir=_mdtm_dir,
                            mode=mode,
                            duration=2.,
                            filter_type=filter_type,
                            collar_duration=collar_duration,
                            audio_framerate=framerate,
                            output_framerate=output_rate,
                            set_type="validation",
                            transform_pipeline="MFCC")

    validation_loader = DataLoader(validation_set,
                                   batch_size=dataset_params["batch_size"],
                                   drop_last=True,
                                   shuffle=True,
                                   pin_memory=True,
                                   num_workers=num_thread)

    # Set the training options
    if opt == 'sgd':
        _optimizer = torch.optim.SGD
        _options = {'lr': lr, 'momentum': 0.9}
    elif opt == 'adam':
        _optimizer = torch.optim.Adam
        _options = {'lr': lr}
    elif opt == 'rmsprop':
        _optimizer = torch.optim.RMSprop
        _options = {'lr': lr}
    else:
        # previously an unknown name crashed later with an UnboundLocalError;
        # fail early with an explicit message instead
        raise ValueError(f"Unknown optimizer name: {opt}")

    optimizer = _optimizer([{'params': model.parameters()}, ], **_options)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)

    best_fmes = 0.0
    best_fmes_epoch = 1
    curr_patience = patience
    for epoch in range(1, epochs + 1):
        # Early stopping once patience is exhausted
        if curr_patience == 0:
            print(f"Stopping at epoch {epoch} for cause of patience")
            break
        model = train_epoch(model,
                            epoch,
                            training_loader,
                            optimizer,
                            log_interval,
                            device=device)

        # Cross validation
        fmes, val_loss = cross_validation(model, validation_loader, device=device)
        logging.critical("*** Validation f-Measure = {}".format(fmes))

        # Decrease learning rate according to the scheduler policy
        scheduler.step(val_loss)
        print(f"Learning rate is {optimizer.param_groups[0]['lr']}")

        # Remember best f-measure and save checkpoint
        is_best = fmes > best_fmes
        best_fmes = max(fmes, best_fmes)

        # DataParallel wraps the model; unwrap it so that the stored
        # state_dict can be reloaded on a bare SeqToSeq instance
        if type(model) is SeqToSeq:
            state_dict = model.state_dict()
        else:
            state_dict = model.module.state_dict()
        save_checkpoint({
            'epoch': epoch,
            'model_state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'accuracy': best_fmes,
            'scheduler': scheduler
        }, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')

        if is_best:
            best_fmes_epoch = epoch
            curr_patience = patience
        else:
            curr_patience -= 1

    logging.critical(f"Best F-Mesure {best_fmes * 100.} obtained at epoch {best_fmes_epoch}")
Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
428

Martin Lebourdais's avatar
Martin Lebourdais committed
429
430
431
432
433
def calc_recall(output, target, device):
    """
    Accumulate precision and recall over the elements of a batch.

    :param output: class scores; expected shape (N, n_classes, B) — the
        argmax over the class axis gives the predicted binary label per frame
    :param target: binary ground-truth labels of shape (N, B)
    :param device: device on which the targets are moved before comparison
    :return: tuple (summed recall, summed precision) over the B elements
    """
    truth_batch = target.to(device)
    score_batch = output
    rc = 0.0
    pr = 0.0
    epsilon = 1e-7  # avoids division by zero when there are no positives

    for b in range(truth_batch.shape[1]):
        y_true = truth_batch[:, b]
        y_pred = score_batch[:, :, b]
        assert y_true.ndim == 1
        assert y_pred.ndim == 1 or y_pred.ndim == 2
        # collapse per-class scores into hard label decisions
        if y_pred.ndim == 2:
            y_pred = y_pred.argmax(dim=1)

        tp = (y_true * y_pred).sum().to(torch.float32)
        fp = ((1 - y_true) * y_pred).sum().to(torch.float32)
        fn = (y_true * (1 - y_pred)).sum().to(torch.float32)

        pr += tp / (tp + fp + epsilon)
        rc += tp / (tp + fn + epsilon)

    return rc, pr
Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
457
458
459
460
461
462
463
464
465
466
467
468
def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
    """
    Run one training epoch over the whole loader and return the model.

    :param model: network to train
    :param epoch: current epoch index (used for logging only)
    :param training_loader: DataLoader yielding (data, target) batches
    :param optimizer: optimizer stepped after every batch
    :param log_interval: number of batches between two log lines
    :param device: device on which the model and data are placed
    :return: the trained model
    """
    model.to(device)
    model.train()
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')

    recall = 0.0
    accuracy = 0.0
    for batch_idx, (data, target) in enumerate(training_loader):
        target = target.squeeze()
        optimizer.zero_grad()

        output = model(data.to(device))
        # reorder to (batch, classes, time) / (batch, time) as expected
        # by CrossEntropyLoss
        output = output.permute(1, 2, 0)
        target = target.permute(1, 0)

        loss = criterion(output, target.to(device))
        loss.backward(retain_graph=True)
        optimizer.step()

        rc, pr = calc_recall(output.data, target, device)
        accuracy += pr
        recall += rc

        if batch_idx % log_interval == 0:
            batch_size = target.shape[0]
            # NOTE(review): 198 looks like the number of output frames per
            # sequence used to normalize the running metrics — TODO confirm
            logging.critical('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}\tRecall: {:.3f}'.format(
                epoch, batch_idx + 1, len(training_loader),
                100. * batch_idx / len(training_loader), loss.item(),
                100.0 * accuracy.item() / ((batch_idx + 1) * batch_size * 198),
                100.0 * recall.item() / ((batch_idx + 1) * batch_size * 198)
            ))

    return model
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
def pearsonr(x, y):
    """
    Pearson correlation coefficient between two 1-D tensors.

    :param x: first sample, 1-D tensor
    :param y: second sample, 1-D tensor of the same length as x
    :return: scalar tensor holding the correlation in [-1, 1]
    """
    centered_x = x - torch.mean(x)
    centered_y = y - torch.mean(y)
    numerator = centered_x.dot(centered_y)
    denominator = torch.norm(centered_x, 2) * torch.norm(centered_y, 2)
    return numerator / denominator

def ccc_loss(pred, gt):
    """
    Concordance correlation coefficient (CCC) loss averaged over a batch.

    :param pred: class scores of shape (time, n_classes, batch); the argmax
        of their absolute value over the class axis is the predicted label
    :param gt: ground-truth labels of shape (time, batch)
    :return: mean of (1 - CCC) over the batch elements
    """
    # Derive the batch size from the data; the previous hard-coded 64
    # raised an IndexError on any smaller batch (e.g. the last one).
    batch_size = gt.shape[1]
    labels = torch.argmax(torch.abs(pred), dim=1)
    acc_ccc = 0
    for n in range(batch_size):
        curr_pred = labels[:, n].float()
        curr_gt = gt[:, n].float()
        true_m = torch.mean(curr_gt)
        pred_m = torch.mean(curr_pred)
        rho = pearsonr(curr_pred, curr_gt)
        std_pred = torch.std(curr_pred)
        std_true = torch.std(curr_gt)
        # Lin's concordance correlation coefficient
        ccc = (
                2
                * rho
                * std_true
                * std_pred
                / (std_pred ** 2 + std_true ** 2 + (pred_m - true_m) ** 2)
                )
        acc_ccc += (1 - ccc)
    return (acc_ccc / batch_size)



'''
def llincc(x, y):
    true_m = np.mean(y)
    true_var = np.var(y)
    pred_m = np.mean(x)
    pred_var = np.var(x)
    rho, _ = pearsonr(x, y)
    std_pred = np.std(x)
    std_true = np.std(y)
    ccc = (
        2
        * rho
        * std_true
        * std_pred
        / (std_pred ** 2 + std_true ** 2 + (pred_m - true_m) ** 2)
    )
    return ccc:
'''

    
Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
564
565
566
567
568
569
570
571
572
def cross_validation(model, validation_loader, device):
    """
    Evaluate the model on the validation set.

    :param model: network to evaluate (switched to eval mode)
    :param validation_loader: DataLoader yielding (data, target) batches
    :param device: device on which the model runs
    :return: tuple (normalized f-measure, normalized validation loss)
    """
    model.eval()
    recall = 0.0
    accuracy = 0.0
    loss = 0.0
    criterion = torch.nn.CrossEntropyLoss()

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(validation_loader):
            batch_size = target.shape[0]
            target = target.squeeze()

            output = model(data.to(device))
            # reorder to (batch, classes, time) / (batch, time)
            output = output.permute(1, 2, 0)
            target = target.permute(1, 0)

            rc, pr = calc_recall(output.data, target, device)
            accuracy += pr
            recall += rc
            loss += criterion(output, target.to(device))

        # harmonic mean of the accumulated precision and recall
        fmes = 2 * (accuracy * recall) / (recall + accuracy)

    denom = (batch_idx + 1) * batch_size
    return fmes / denom, loss.cpu().numpy() / denom