xvector.py 72.8 KB
Newer Older
Anthony Larcher's avatar
debug    
Anthony Larcher committed
1001
1002
1003
1004
1005
1006
1007
                if p.requires_grad) + \
            sum(p.numel()
                for p in model.before_speaker_embedding.parameters()
                if p.requires_grad) + \
            sum(p.numel()
                for p in model.stat_pooling.parameters()
                if p.requires_grad)))
Anthony Larcher's avatar
Anthony Larcher committed
1008
1009
1010
1011

    return model


Anthony Larcher's avatar
Anthony Larcher committed
1012
def get_loaders(dataset_opts, training_opts, model_opts):
    """
    Set the dataloaders according to the dataset_yaml.

    First we load the dataframe from the CSV file in order to split it for
    training and validation purposes, then we build one SideSet (and its
    DataLoader) for each of the two splits.

    :param dataset_opts: dict of dataset options (dataset_csv path,
        validation_ratio, batch_size, seed, per-set sampler options, ...)
    :param training_opts: dict of training options (uses multi_gpu and num_cpu)
    :param model_opts: dict of model options (uses loss type and speaker_number)
    :return: tuple (training_loader, validation_loader, side_sampler,
        tar_indices, non_indices) where the last two are boolean masks over
        validation-trial pairs (target and non-target respectively)
    """
    df = pandas.read_csv(dataset_opts["dataset_csv"])

    # Stratify on speaker_idx so both splits keep the same speaker distribution
    training_df, validation_df = train_test_split(df,
                                                  test_size=dataset_opts["validation_ratio"],
                                                  stratify=df["speaker_idx"])

    training_set = SideSet(dataset_opts,
                           set_type="train",
                           chunk_per_segment=-1,
                           transform_number=dataset_opts['train']['transform_number'],
                           overlap=dataset_opts['train']['overlap'],
                           dataset_df=training_df,
                           output_format="pytorch",
                           )

    validation_set = SideSet(dataset_opts,
                             set_type="validation",
                             dataset_df=validation_df,
                             output_format="pytorch")

    # The APS loss requires pairs of samples from the same speaker in a batch
    if model_opts["loss"]["type"] == 'aps':
        samples_per_speaker = 2
    else:
        samples_per_speaker = 1

    if training_opts["multi_gpu"]:
        # The global batch is split evenly across the available GPUs
        assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0
        assert dataset_opts["batch_size"] % samples_per_speaker == 0
        batch_size = dataset_opts["batch_size"]//torch.cuda.device_count()

        # NOTE(review): unlike the single-GPU branch below, this call takes its
        # sampler parameters from the dataset configuration and passes no seed
        # — confirm the asymmetry is intentional
        side_sampler = SideSampler(training_set.sessions['speaker_idx'],
                                   model_opts["speaker_number"],
                                   dataset_opts["train"]["sampler"]["examples_per_speaker"],
                                   dataset_opts["train"]["sampler"]["samples_per_speaker"],
                                   dataset_opts["batch_size"])
    else:
        batch_size = dataset_opts["batch_size"]
        side_sampler = SideSampler(training_set.sessions['speaker_idx'],
                                   model_opts["speaker_number"],
                                   samples_per_speaker,
                                   batch_size,
                                   batch_size,
                                   seed=dataset_opts['seed'])

    training_loader = DataLoader(training_set,
                                 batch_size=batch_size,
                                 shuffle=False,  # ordering is fully driven by the sampler
                                 drop_last=True,
                                 pin_memory=True,
                                 sampler=side_sampler,
                                 num_workers=training_opts["num_cpu"],
                                 persistent_workers=False)

    validation_loader = DataLoader(validation_set,
                                   batch_size=batch_size,
                                   drop_last=False,
                                   pin_memory=True,
                                   num_workers=training_opts["num_cpu"],
                                   persistent_workers=False)

    # Compute indices for target and non-target trials once only to avoid recomputing for each epoch
    classes = torch.ShortTensor(validation_set.sessions['speaker_idx'].to_numpy())
    mask = classes.unsqueeze(1) == classes.unsqueeze(1).T
    # Keep the strict lower triangle so every pair is counted exactly once
    tar_indices = torch.tril(mask, -1).numpy()
    non_indices = torch.tril(~mask, -1).numpy()

    # Select a subset of non-target trials to reduce the number of tests
    # (randomly thinned so target and non-target counts are comparable)
    tar_non_ratio = numpy.sum(tar_indices)/numpy.sum(non_indices)
    non_indices *= (numpy.random.rand(*non_indices.shape) < tar_non_ratio)

    return training_loader, validation_loader, side_sampler, tar_indices, non_indices
Anthony Larcher's avatar
Anthony Larcher committed
1098
1099


Anthony Larcher's avatar
Anthony Larcher committed
1100
def get_optimizer(model, model_opts, train_opts):
    """
    Instantiate the optimizer and the learning-rate scheduler.

    One parameter group is created per sub-network of the model so that each
    sub-network can use its own weight decay.

    :param model: the Xtractor to train, possibly wrapped in a parallel
        container (in which case the sub-networks are reached via model.module)
    :param model_opts: dict of model options (uses speaker_number for CyclicLR)
    :param train_opts: dict of training options; train_opts["optimizer"]["type"]
        selects the optimizer ('adam', 'rmsprop' or 'sgd') and
        train_opts["scheduler"]["type"] the scheduler ('CyclicLR',
        'MultiStepLR', 'StepLR', 'StepLR2' or ReduceLROnPlateau by default)
    :return: a tuple (optimizer, scheduler)
    """
    if train_opts["optimizer"]["type"] == 'adam':
        _optimizer = torch.optim.Adam
        _options = {'lr': train_opts["lr"]}
    elif train_opts["optimizer"]["type"] == 'rmsprop':
        _optimizer = torch.optim.RMSprop
        _options = {'lr': train_opts["lr"]}
    else:  # train_opts["optimizer"]["type"] == 'sgd'
        _optimizer = torch.optim.SGD
        _options = {'lr': train_opts["lr"], 'momentum': 0.9}

    # Unwrap the parallel container if needed, then build one parameter group
    # per sub-network with its own weight decay.
    if type(model) is Xtractor:
        _model = model
    else:
        _model = model.module

    param_list = []
    if _model.preprocessor is not None:
        param_list.append({'params': _model.preprocessor.parameters(),
                           'weight_decay': _model.preprocessor_weight_decay})
    param_list.append({'params': _model.sequence_network.parameters(),
                       'weight_decay': _model.sequence_network_weight_decay})
    param_list.append({'params': _model.stat_pooling.parameters(),
                       'weight_decay': _model.stat_pooling_weight_decay})
    param_list.append({'params': _model.before_speaker_embedding.parameters(),
                       'weight_decay': _model.before_speaker_embedding_weight_decay})
    param_list.append({'params': _model.after_speaker_embedding.parameters(),
                       'weight_decay': _model.after_speaker_embedding_weight_decay})

    optimizer = _optimizer(param_list, **_options)

    if train_opts["scheduler"]["type"] == 'CyclicLR':
        # cycle_momentum defaults to True, which raises a ValueError for
        # optimizers without a momentum parameter (e.g. Adam): only cycle the
        # momentum when the optimizer actually exposes one (SGD, RMSprop).
        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
                                                      base_lr=1e-8,
                                                      max_lr=train_opts["lr"],
                                                      step_size_up=model_opts["speaker_number"] * 2,
                                                      step_size_down=None,
                                                      cycle_momentum='momentum' in optimizer.defaults,
                                                      mode="triangular2")
    elif train_opts["scheduler"]["type"] == "MultiStepLR":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                                         milestones=[10000, 50000, 100000],
                                                         gamma=0.5)
    elif train_opts["scheduler"]["type"] == "StepLR":
        # step_size is documented as an int in the StepLR API
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                                    step_size=2000,
                                                    gamma=0.95)
    elif train_opts["scheduler"]["type"] == "StepLR2":
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                                    step_size=2000,
                                                    gamma=0.5)
    else:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                               mode='min',
                                                               factor=0.5,
                                                               patience=3000,
                                                               verbose=True)

    return optimizer, scheduler


Anthony Larcher's avatar
Anthony Larcher committed
1180
def save_model(model, training_monitor, model_opts, training_opts, optimizer, scheduler):
    """
    Save the current training state as a checkpoint and, when the monitor says
    so, update the best-model file.

    :param model: the model to serialize, either a raw Xtractor or a parallel
        wrapper around one (accessed through model.module)
    :param training_monitor: TrainingMonitor providing current_epoch,
        best_accuracy and the is_best flag
    :param model_opts: dict of model options, stored in the checkpoint
    :param training_opts: dict of training options providing tmp_model_name
        and best_model_name
    :param optimizer: optimizer whose state_dict is stored in the checkpoint
    :param scheduler: scheduler object stored in the checkpoint
    :return: None
    """
    # TODO: to be reworked
    # Unwrap the parallel container so the raw network state is serialized.
    if type(model) is Xtractor:
        network = model
    else:
        network = model.module

    checkpoint = {
        'epoch': training_monitor.current_epoch,
        'model_state_dict': network.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'accuracy': training_monitor.best_accuracy,
        'scheduler': scheduler,
        'speaker_number': network.speaker_number,
        'model_archi': model_opts,
        'loss': model_opts["loss"]["type"]
    }
    save_checkpoint(checkpoint,
                    training_monitor.is_best,
                    filename=training_opts["tmp_model_name"],
                    best_filename=training_opts["best_model_name"])
Anthony Larcher's avatar
Anthony Larcher committed
1214
1215


Anthony Larcher's avatar
Anthony Larcher committed
1216
1217
1218
1219
def xtrain(dataset_description,
           model_description,
           training_description,
           **kwargs):
    """
    Train an x-vector extractor: build the model, the data loaders, the
    optimizer and the scheduler from the three configurations, then run the
    training loop with periodic cross-validation and checkpointing.

    REFACTORING
    - when restarting from an existing model, reload the optimizer and the scheduler

    :param dataset_description: dataset configuration passed to
        update_training_dictionary
    :param model_description: model configuration passed to
        update_training_dictionary
    :param training_description: training configuration passed to
        update_training_dictionary
    :param kwargs: extra options forwarded to update_training_dictionary
    :return: the best validation EER reached during training
    """

    # Rank of this process in distributed training; -1 means single process.
    # NOTE(review): this reads the global RANK variable, not LOCAL_RANK —
    # confirm this is intended for multi-node setups
    local_rank = -1
    if "RANK" in os.environ:
        local_rank = int(os.environ['RANK'])

    # Test to optimize
    torch.autograd.profiler.emit_nvtx(enabled=False)

    dataset_opts, model_opts, training_opts = update_training_dictionary(dataset_description,
                                                                         model_description,
                                                                         training_description,
                                                                         kwargs)

    # Initialize the training monitor
    monitor = TrainingMonitor(output_file=training_opts["log_file"],
                              patience=training_opts["patience"],
                              best_accuracy=0.0,
                              best_eer_epoch=1,
                              best_eer=100,
                              compute_test_eer=training_opts["compute_test_eer"])

    # Make PyTorch Deterministic
    torch.backends.cudnn.deterministic = False
    if training_opts["deterministic"]:
        torch.backends.cudnn.deterministic = True

    # Set all the seeds
    numpy.random.seed(training_opts["seed"]) # Set the random seed of numpy for the data split.
    torch.manual_seed(training_opts["seed"])
    torch.cuda.manual_seed(training_opts["seed"])

    # Display the entire configurations as YAML dictionaries
    # (main process only, to avoid duplicated logs in distributed runs)
    if local_rank < 1:
        monitor.logger.info("\n*********************************\nDataset options\n*********************************\n")
        monitor.logger.info(yaml.dump(dataset_opts, default_flow_style=False))
        monitor.logger.info("\n*********************************\nModel options\n*********************************\n")
        monitor.logger.info(yaml.dump(model_opts, default_flow_style=False))
        monitor.logger.info("\n*********************************\nTraining options\n*********************************\n")
        monitor.logger.info(yaml.dump(training_opts, default_flow_style=False))

    # Initialize the model
    model = get_network(model_opts, local_rank)
    embedding_size = model.embedding_size

    # Set the device and manage parallel processing
    torch.cuda.set_device(local_rank)
    device = torch.device(local_rank)

    if training_opts["multi_gpu"]:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    model.to(device)

    """ [HOW TO] from https://gist.github.com/sgraaf/5b0caa3a320f28c27c12b5efeb35aa4c
        - Add the following line right after "if __name__ == '__main__':" in your main script :
        parser.add_argument('--local_rank', type=int, default=-1, metavar='N', help='Local process rank.')
        - Then, in your shell :
        export NUM_NODES=1
        export NUM_GPUS_PER_NODE=2
        export NODE_RANK=0
        export WORLD_SIZE=$(($NUM_NODES * $NUM_GPUS_PER_NODE))
        python -m torch.distributed.launch \
            --nproc_per_node=$NUM_GPUS_PER_NODE \
            --nnodes=$NUM_NODES \
            --node_rank $NODE_RANK \
            train_xvector.py ...
    """
    if training_opts["multi_gpu"]:
        if local_rank < 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[local_rank],
                                                          output_device=local_rank)
    else:
        print("Train on a single GPU")

    # Initialise data loaders
    training_loader, validation_loader, \
    sampler, validation_tar_indices, validation_non_indices = get_loaders(dataset_opts,
                                                                          training_opts,
                                                                          model_opts)

    if local_rank < 1:
        monitor.logger.info(f"Start training process")
        monitor.logger.info(f"Use \t{torch.cuda.device_count()} \tgpus")
        monitor.logger.info(f"Use \t{training_opts['num_cpu']} \tcpus")

        monitor.logger.info(f"Validation EER will be measured using")
        monitor.logger.info(f"\t {numpy.sum(validation_tar_indices)} target trials and")
        monitor.logger.info(f"\t {numpy.sum(validation_non_indices)} non-target trials")

    # Create optimizer and scheduler
    optimizer, scheduler = get_optimizer(model, model_opts, training_opts)

    # Gradient scaler used only when mixed-precision training is enabled
    scaler = None
    if training_opts["mixed_precision"]:
        scaler = torch.cuda.amp.GradScaler()

    for epoch in range(1, training_opts["epochs"] + 1):

        monitor.update(epoch=epoch)

        # Process one epoch and return the current model
        if monitor.current_patience == 0:
            print(f"Stopping at epoch {epoch} for cause of patience")
            break

        # Re-seed the sampler so each epoch draws a different ordering
        sampler.set_epoch(epoch)
        if training_opts["multi_gpu"]:
            # Keep all distributed workers in sync before starting the epoch
            torch.distributed.barrier()

        model = train_epoch(model,
                            training_opts,
                            monitor,
                            training_loader,
                            optimizer,
                            scheduler,
                            device,
                            scaler=scaler)

        # Cross validation
        if math.fmod(epoch, training_opts["validation_frequency"]) == 0:
            val_acc, val_loss, val_eer = cross_validation(model,
                                                          validation_loader,
                                                          device,
                                                          [validation_loader.dataset.__len__(), embedding_size],
                                                          validation_tar_indices,
                                                          validation_non_indices,
                                                          training_opts["mixed_precision"])

            test_eer = None
            if training_opts["compute_test_eer"] and local_rank < 1:
                test_eer = test_metrics(model, device, model_opts, dataset_opts, training_opts)

            monitor.update(test_eer=test_eer,
                           val_eer=val_eer,
                           val_loss=val_loss,
                           val_acc=val_acc)

            if local_rank < 1:
                monitor.display()

            # Save the current model and if needed update the best one
            # TODO: add an option keeping the models at given epochs (e.g. before a LR change)
            if local_rank < 1:
                save_model(model, monitor, model_opts, training_opts, optimizer, scheduler)

    # Log per-GPU memory usage at the end of the run
    for ii in range(torch.cuda.device_count()):
        monitor.logger.info(torch.cuda.memory_summary(ii))

    # TODO: handle the display through the training_monitor
    if local_rank < 1:
        monitor.display_final()

    return monitor.best_eer

Anthony Larcher's avatar
Anthony Larcher committed
1382

Anthony Larcher's avatar
Anthony Larcher committed
1383
1384
1385
1386
1387
1388
1389
1390
1391
def train_epoch(model,
                training_opts,
                training_monitor,
                training_loader,
                optimizer,
                scheduler,
                device,
                scaler=None,
                clipping=False):
    """
    Train the model over one epoch of the training set.

    :param model: Xtractor (or parallel wrapper around one) to train
    :param training_opts: dict of training options (uses "log_interval")
    :param training_monitor: TrainingMonitor used for logging and bookkeeping
    :param training_loader: DataLoader yielding (data, target) batches
    :param optimizer: optimizer stepped after each batch
    :param scheduler: learning-rate scheduler stepped after each batch
    :param device: torch.device batches are moved to
    :param scaler: optional torch.cuda.amp.GradScaler; when given, the
        forward/backward passes run in mixed precision
    :param clipping: when True (and a scaler is used), clip the gradient norm to 1
    :return: the trained model
    """
    model.train()
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')

    # The loss type drives how the model output is interpreted below
    if isinstance(model, Xtractor):
        loss_criteria = model.loss
    else:
        loss_criteria = model.module.loss

    accuracy = 0.0
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(training_loader):
        data = data.squeeze().to(device)
        target = target.squeeze()
        target = target.to(device)
        optimizer.zero_grad(set_to_none=True)

        if scaler is not None:
            # Mixed-precision forward pass
            with torch.cuda.amp.autocast():
                if loss_criteria == 'aam':
                    output, _ = model(data, target=target)
                    loss = criterion(output, target)
                elif loss_criteria == 'smn':
                    output_tuple, _ = model(data, target=target)
                    loss, output = output_tuple
                    loss += criterion(output, target)
                elif loss_criteria == 'aps':
                    output_tuple, _ = model(data, target=target)
                    loss, output = output_tuple
                else:
                    output, _ = model(data, target=None)
                    loss = criterion(output, target)
        else:
            if loss_criteria == 'aam':
                output, _ = model(data, target=target)
                loss = criterion(output, target)
            elif loss_criteria == 'aps':
                # NOTE(review): the 'smn' case handled in the mixed-precision
                # branch above has no full-precision counterpart here, and the
                # two 'aps' loss formulas differ between branches — confirm
                # this is intentional
                output_tuple, _ = model(data, target=target)
                cos_sim_matx, output = output_tuple
                loss = criterion(cos_sim_matx, torch.arange(0, int(data.shape[0]/2), device=device)) + criterion(output, target)
            else:
                output, _ = model(data, target=None)
                loss = criterion(output, target)

        #if not torch.isnan(loss):
        # NaN guard currently disabled; the commented-out else branch below
        # used to dump the faulty batch and abort the run
        if True:
            if scaler is not None:
                scaler.scale(loss).backward()
                if clipping:
                    # Unscale before clipping so the norm is computed on the
                    # true gradients
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            running_loss += loss.item()
            accuracy += (torch.argmax(output.data, 1) == target).sum()

            if math.fmod(batch_idx, training_opts["log_interval"]) == 0:
                # Accuracy is averaged over all batches seen so far this epoch
                batch_size = target.shape[0]
                training_monitor.update(training_loss=loss.item(),
                                        training_acc=100.0 * accuracy.item() / ((batch_idx + 1) * batch_size))

                training_monitor.logger.info('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                    training_monitor.current_epoch,
                    batch_idx + 1,
                    training_loader.__len__(),
                    100. * batch_idx / training_loader.__len__(),
                    loss.item(),
                    100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))

        #else:
        #    save_checkpoint({
        #        'epoch': training_monitor.current_epoch,
        #        'model_state_dict': model.state_dict(),
        #        'optimizer_state_dict': optimizer.state_dict(),
        #        'accuracy': 0.0,
        #        'scheduler': 0.0
        #    }, False, filename="model_loss_NAN.pt", best_filename='toto.pt')
        #    with open("batch_loss_NAN.pkl", "wb") as fh:
        #        pickle.dump(data.cpu(), fh)
        #    import sys
        #    sys.exit()

        # NOTE(review): running_loss is reset on every iteration, so it never
        # accumulates across batches and is effectively unused — confirm
        running_loss = 0.0

        # ReduceLROnPlateau needs a metric; other schedulers step once per batch
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(training_monitor.best_eer)
        else:
            scheduler.step()
    return model

Anthony Larcher's avatar
merge    
Anthony Larcher committed
1496
def cross_validation(model, validation_loader, device, validation_shape, tar_indices, non_indices, mixed_precision=False):
    """
    Evaluate a speaker-embedding model on a validation set.

    Computes the classification accuracy, the average cross-entropy loss and
    the speaker-verification Equal Error Rate obtained by scoring all
    validation embeddings against each other with a dot product.

    :param model: Xtractor (possibly wrapped in DataParallel) to evaluate
    :param validation_loader: DataLoader yielding (data, target) batches
    :param device: torch device on which the forward passes are run
    :param validation_shape: shape (num_segments, embedding_dim) of the
        matrix used to accumulate all validation embeddings
    :param tar_indices: mask selecting target trials in the score matrix
    :param non_indices: mask selecting non-target trials in the score matrix
    :param mixed_precision: if True, run the forward pass under autocast

    :return: tuple (accuracy in percent, mean loss per sample, EER)
    """
    model.eval()

    # The loss type is stored on the model itself (on .module under DataParallel)
    if isinstance(model, Xtractor):
        loss_criteria = model.loss
    else:
        loss_criteria = model.module.loss

    accuracy = 0.0
    loss = 0.0
    criterion = torch.nn.CrossEntropyLoss()

    embeddings = torch.zeros(validation_shape)
    cursor = 0  # number of samples written into `embeddings` so far

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1, disable=None)):

            # NOTE(review): squeeze() on a batch of size 1 would drop the
            # batch dimension — assumes validation batches contain > 1 sample
            target = target.squeeze()
            target = target.to(device)
            batch_size = target.shape[0]
            data = data.squeeze().to(device)

            with torch.cuda.amp.autocast(enabled=mixed_precision):
                output, batch_embeddings = model(data, target=None, is_eval=True)
                if loss_criteria == 'cce':
                    batch_embeddings = l2_norm(batch_embeddings)
                if loss_criteria == 'smn':
                    # 'smn' models return (embeddings, predictions) as output
                    batch_embeddings, batch_predictions = output
                else:
                    batch_predictions = output

                accuracy += (torch.argmax(batch_predictions.data, 1) == target).sum()
                loss += criterion(batch_predictions, target)

            embeddings[cursor:cursor + batch_size, :] = batch_embeddings.detach().cpu()
            cursor += batch_size

    # Score on CPU when the embedding matrix is large to avoid GPU memory issues
    local_device = "cpu" if embeddings.shape[0] > 3e4 else device
    embeddings = embeddings.to(local_device)
    # Full pairwise dot-product score matrix
    scores = torch.einsum('ij,kj', embeddings, embeddings).cpu().numpy()
    negatives = scores[non_indices]
    positives = scores[tar_indices]

    # Faster EER computation available here : https://github.com/gl3lan/fast_eer
    pmiss, pfa = rocch(positives, negatives)
    equal_error_rate = rocch2eer(pmiss, pfa)

    # BUGFIX: normalize the summed per-batch losses by `cursor`, the true
    # number of processed samples.  The previous (batch_idx + 1) * batch_size
    # used the size of the LAST batch and was wrong whenever the final batch
    # was smaller than the others (drop_last=False).
    return (100. * accuracy.cpu().numpy() / validation_shape[0],
            loss.cpu().numpy() / cursor,
            equal_error_rate)
1550
1551


Anthony Larcher's avatar
Anthony Larcher committed
1552
1553
1554
1555
def extract_embeddings(idmap_name,
                       model_filename,
                       data_root_name,
                       device,
                       loss,
                       file_extension="wav",
                       transform_pipeline="",
                       frame_shift=0.01,
                       frame_duration=0.025,
                       extract_after_pooling=False,
                       num_thread=1,
                       mixed_precision=False):
    """
    Extract one x-vector per segment of an IdMap and return them in a StatServer.

    :param idmap_name: an IdMap object, or the path of an IdMap file, listing
        the segments to process
    :param model_filename: path of a checkpoint to load, or an already-loaded model
    :param data_root_name: root directory of the audio files
    :param device: torch device to run the extraction on
    :param loss: unused here; kept for backward compatibility of the signature
    :param file_extension: extension of the audio files (default "wav")
    :param transform_pipeline: transformation pipeline applied by the dataset
    :param frame_shift: frame shift in seconds, used to set the dataset's
        minimum segment duration
    :param frame_duration: unused here; kept for backward compatibility
    :param extract_after_pooling: unused here; kept for backward compatibility
    :param num_thread: number of DataLoader workers
    :param mixed_precision: if True, run the forward pass under autocast

    :return: a StatServer whose stat1 holds one embedding per segment
    """
    # Load the model from a checkpoint when a filename is given
    if isinstance(model_filename, str):
        checkpoint = torch.load(model_filename, map_location=device)
        speaker_number = checkpoint["speaker_number"]
        model_archi = checkpoint["model_archi"]
        model = Xtractor(speaker_number, model_archi=model_archi, loss=checkpoint["loss"])
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model = model_filename

    if isinstance(idmap_name, IdMap):
        idmap = idmap_name
    else:
        idmap = IdMap(idmap_name)

    # Context size of the network (on .module when wrapped in DataParallel).
    # CLEANUP: the previously computed `min_duration` local was never used.
    if type(model) is Xtractor:
        model_cs = model.context_size()
    else:
        model_cs = model.module.context_size()

    # Create dataset to load the data
    dataset = IdMapSet(idmap_name=idmap_name,
                       data_path=data_root_name,
                       file_extension=file_extension,
                       transform_pipeline=transform_pipeline,
                       min_duration=(model_cs + 2) * frame_shift * 2
                       )

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            drop_last=False,
                            pin_memory=True,
                            num_workers=num_thread)

    with torch.no_grad():

        model.eval()
        model.to(device)

        # Get the size of embeddings to extract from the last layer of
        # before_speaker_embedding (its weight matrix output dimension)
        if type(model) is Xtractor:
            name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
            if name != 'bias':
                name = name + '.weight'
            emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]
        else:
            name = list(model.module.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
            if name != 'bias':
                name = name + '.weight'
            emb_size = model.module.before_speaker_embedding.state_dict()[name].shape[0]

        # Create the StatServer
        embeddings = StatServer()
        embeddings.modelset = idmap.leftids
        embeddings.segset = idmap.rightids
        embeddings.start = idmap.start
        embeddings.stop = idmap.stop
        embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
        embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))

        # Process the data.
        # CLEANUP: removed a redundant nested `torch.no_grad()` context —
        # the whole extraction already runs inside the outer one.
        for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader,
                                                                      desc='xvector extraction',
                                                                      mininterval=1,
                                                                      disable=None)):
            # Truncate extremely long signals to bound memory usage
            if data.shape[1] > 20000000:
                data = data[..., :20000000]
            with torch.cuda.amp.autocast(enabled=mixed_precision):
                _, vec = model(x=data.to(device), is_eval=True)
            embeddings.stat1[idx, :] = vec.detach().cpu()

    return embeddings


Anthony Larcher's avatar
Anthony Larcher committed
1660
1661
1662
1663
1664
1665
1666
1667
def extract_embeddings_per_speaker(idmap_name,
                                   model_filename,
                                   data_root_name,
                                   device,
                                   file_extension="wav",
                                   transform_pipeline=None,
                                   frame_shift=0.01,
                                   frame_duration=0.025,
                                   extract_after_pooling=False,
                                   num_thread=1):
    """
    Extract one embedding per SPEAKER (all segments of a speaker concatenated
    by the dataset) and return them in a StatServer.

    :param idmap_name: IdMap object or path of an IdMap file
    :param model_filename: path of a checkpoint to load, or an already-loaded model
    :param data_root_name: root directory of the audio files
    :param device: torch device to run the extraction on
    :param file_extension: extension of the audio files (default "wav")
    :param transform_pipeline: transformation pipeline applied by the dataset
    :param frame_shift: frame shift in seconds, used to set frame rate and
        minimum duration of the dataset
    :param frame_duration: unused here; kept for backward compatibility
    :param extract_after_pooling: if True, the embedding size is the input
        dimension of the last before_speaker_embedding layer, otherwise the
        model's embedding_size is used
    :param num_thread: number of DataLoader workers

    :return: a StatServer whose stat1 holds one embedding per speaker
    """
    # Load the model from a checkpoint when a filename is given.
    # BUGFIX: load with map_location=device (consistent with
    # extract_embeddings) so GPU checkpoints can be loaded on CPU-only hosts.
    if isinstance(model_filename, str):
        checkpoint = torch.load(model_filename, map_location=device)
        model_archi = checkpoint["model_archi"]
        model = Xtractor(checkpoint["speaker_number"], model_archi=model_archi, loss="aam")
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model = model_filename

    model = model.to(memory_format=torch.channels_last)

    # NOTE(review): unlike extract_embeddings, this assumes `model` is a bare
    # Xtractor (model.context_size(), not model.module.context_size()).
    # CLEANUP: the previously computed `min_duration` local was never used.

    # Create dataset to load the data
    dataset = IdMapSetPerSpeaker(idmap_name=idmap_name,
                                 data_root_path=data_root_name,
                                 file_extension=file_extension,
                                 transform_pipeline=transform_pipeline,
                                 frame_rate=int(1. / frame_shift),
                                 min_duration=(model.context_size() + 2) * frame_shift * 2)

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            drop_last=False,
                            pin_memory=True,
                            num_workers=num_thread)

    with torch.no_grad():
        model.eval()
        model.to(device)

        # Get the size of embeddings to extract
        name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0] + '.weight'
        if extract_after_pooling:
            emb_size = model.before_speaker_embedding.state_dict()[name].shape[1]
        else:
            emb_size = model.embedding_size

        # Create the StatServer
        embeddings = StatServer()
        embeddings.modelset = dataset.output_im.leftids
        embeddings.segset = dataset.output_im.rightids
        embeddings.start = dataset.output_im.start
        embeddings.stop = dataset.output_im.stop
        embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
        embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))

        # Process the data.
        # CLEANUP: removed a redundant nested `torch.no_grad()` context —
        # the whole extraction already runs inside the outer one.
        for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader, desc='xvector extraction', mininterval=1)):
            # Truncate extremely long signals to bound memory usage
            if data.shape[1] > 20000000:
                data = data[..., :20000000]
            vec = model(data.to(device), is_eval=True)
            embeddings.stat1[idx, :] = vec.detach().cpu()

    return embeddings

Anthony Larcher's avatar
Anthony Larcher committed
1730

Anthony Larcher's avatar
Anthony Larcher committed
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
def extract_sliding_embedding(idmap_name,
                              window_len,
                              window_shift,
                              model_filename,
                              data_root_name ,
                              device,
                              sample_rate=16000,
                              file_extension="wav",
                              transform_pipeline=None,
                              num_thread=1,
                              mixed_precision=False):
    """
    Extract x-vectors on a sliding window over each segment of an IdMap and
    return them in a StatServer (one line per window position).

    :param idmap_name: IdMap object or path of an IdMap file
    :param window_len: length of the sliding window (same unit as the dataset)
    :param window_shift: shift between two consecutive windows
    :param model_filename: path of a checkpoint to load, or an already-loaded model
    :param data_root_name: root directory of the audio files
    :param device: torch device to run the extraction on
    :param sample_rate: sampling rate of the audio (default 16000)
    :param file_extension: extension of the audio files (default "wav")
    :param transform_pipeline: transformation pipeline applied by the dataset
    :param num_thread: number of DataLoader workers
    :param mixed_precision: if True, run the forward pass under autocast

    :return: a StatServer whose stat1 holds one embedding per window
    """
    # From the original IdMap, create the new one to extract x-vectors
    if not isinstance(idmap_name, IdMap):
        input_idmap = IdMap(idmap_name)
    else:
        input_idmap = idmap_name

    # Load the model
    if isinstance(model_filename, str):
        checkpoint = torch.load(model_filename, map_location=device)
        speaker_number = checkpoint["speaker_number"]
        model_archi = checkpoint["model_archi"]
        model = Xtractor(speaker_number, model_archi=model_archi)
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model = model_filename

    if isinstance(idmap_name, IdMap):
        idmap = idmap_name
    else:
        idmap = IdMap(idmap_name)

    # Create dataset to load the data
    dataset = IdMapSet(idmap_name=idmap_name,
                       data_path=data_root_name,
                       file_extension=file_extension,
                       transform_pipeline=transform_pipeline,
                       sliding_window=True,
                       window_len=window_len,
                       window_shift=window_shift,
                       sample_rate=sample_rate,
                       min_duration=0.1
                       )

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            drop_last=False,
                            pin_memory=True,
                            num_workers=num_thread)

    with torch.no_grad():

        model.eval()
        model.to(device)

        # Get the size of embeddings to extract (kept for parity with the
        # other extraction functions; stat1 is built from the vectors themselves)
        if type(model) is Xtractor:
            name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
            if name != 'bias':
                name = name + '.weight'
            emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]
        else:
            name = list(model.module.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
            if name != 'bias':
                name = name + '.weight'
            emb_size = model.module.before_speaker_embedding.state_dict()[name].shape[0]

        # BUGFIX: the previous code accumulated vectors in a list named
        # `embeddings`, called .shape[0] on that list (AttributeError), then
        # rebound `embeddings` to a StatServer and concatenated the StatServer
        # itself.  The vectors are now kept in a separate list (`xvectors`)
        # and the window count is taken from the data tensor.
        xvectors = []
        modelset = []
        segset = []
        starts = []

        # Process the data
        for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader,
                                                                      desc='xvector extraction',
                                                                      mininterval=1)):
            with torch.cuda.amp.autocast(enabled=mixed_precision):
                data = data.squeeze()
                # Split the windows in chunks of ~100 to bound memory usage
                tmp_data = torch.split(data, data.shape[0] // (data.shape[0] // 100))
                for td in tmp_data:
                    vec = model(x=td.to(device), is_eval=True)
                    xvectors.append(vec.detach().cpu())
                window_number = data.shape[0]
                modelset += [mod, ] * window_number
                segset += [seg, ] * window_number
                # One start time per window for this segment
                starts.append(numpy.arange(start, start + window_number * window_shift, window_shift))

        # Create the StatServer
        embeddings = StatServer()
        embeddings.modelset = numpy.array(modelset).astype('>U')
        embeddings.segset = numpy.array(segset).astype('>U')
        embeddings.start = numpy.concatenate(starts)
        embeddings.stop = embeddings.start + window_len
        embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
        embeddings.stat1 = torch.cat(xvectors, dim=0).numpy()

    return embeddings