        model_opts["loss"]["aam_margin"] = kwargs["margin"]
    if "aam_s" in kwargs:
        model_opts["loss"]["aam_s"] = kwargs["aam_s"]

    return dataset_opts, model_opts, training_opts


def get_network(model_opts):
    """
    Instantiate the network described in model_opts and optionally load the
    weights from an existing checkpoint.

    :param model_opts: dictionary of model options (architecture, loss, initial model name, ...)
    :return: the initialized model
    """

    if model_opts["model_type"] in ["xvector", "rawnet2", "resnet34", "fastresnet34"]:
        model = Xtractor(model_opts["speaker_number"], model_opts["model_type"], loss=model_opts["loss"]["type"])
    else:
        # Custom type of model
        model = Xtractor(model_opts["speaker_number"], model_opts, loss=model_opts["loss"]["type"])

    # Load the model if it exists
    if model_opts["initial_model_name"] is not None and os.path.isfile(model_opts["initial_model_name"]):
        logging.critical(f"*** Load model from = {model_opts['initial_model_name']}")
        checkpoint = torch.load(model_opts["initial_model_name"])

        # Remove from the pretrained state dict all the layers that we do not want to reload
        pretrained_dict = checkpoint["model_state_dict"]
        for part in model_opts["reset_parts"]:
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith(part)}
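        # e.g. with reset_parts = ["after_speaker_embedding"], every key starting
        # with "after_speaker_embedding" is dropped so that this layer keeps its
        # fresh random initialization (illustrative value)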

        new_model_dict = model.state_dict()
        new_model_dict.update(pretrained_dict)
        model.load_state_dict(new_model_dict)

        # Freeze required layers
        for name, param in model.named_parameters():
        if name.split(".")[0] in model_opts["freeze_parts"]:
            param.requires_grad = False

    logging.critical(model)

    logging.critical("model_parameters_count: {:d}".format(
        sum(p.numel()
            for p in model.sequence_network.parameters()
            if p.requires_grad) + \
        sum(p.numel()
            for p in model.before_speaker_embedding.parameters()
            if p.requires_grad) + \
        sum(p.numel()
            for p in model.stat_pooling.parameters()
            if p.requires_grad)))

    return model
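
# Minimal usage sketch of get_network (the option values are illustrative,
# the dictionary is normally built by update_training_dictionary):
#
#     model_opts = {"model_type": "fastresnet34",
#                   "speaker_number": 5994,
#                   "loss": {"type": "aam"},
#                   "initial_model_name": None,
#                   "reset_parts": [],
#                   "freeze_parts": []}
#     model = get_network(model_opts)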


def get_loaders(dataset_opts, training_opts, model_opts, speaker_number):
    """

    :param dataset_yaml:
    :return:
    """

    """
    Set the dataloaders according to the dataset_yaml
    
    First we load the dataframe from CSV file in order to split it for training and validation purpose
    Then we provide those two 
    """
    df = pandas.read_csv(dataset_opts["dataset_csv"])

    training_df, validation_df = train_test_split(df,
                                                  test_size=dataset_opts["validation_ratio"],
                                                  stratify=df["speaker_idx"])

    training_set = SideSet(dataset_opts,
                           set_type="train",
                           chunk_per_segment=-1,
                           transform_number=dataset_opts['train']['transform_number'],
                           overlap=dataset_opts['train']['overlap'],
                           dataset_df=training_df,
                           output_format="pytorch",
                           )

    validation_set = SideSet(dataset_opts,
                             set_type="validation",
                             dataset_df=validation_df,
                             output_format="pytorch")

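    # The 'aps' loss builds positive pairs within each batch, hence two samples
    # per speaker are drawn; the other losses only need one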
    if model_opts["loss"]["type"] == 'aps':
        samples_per_speaker = 2
    else:
        samples_per_speaker = 1

    if training_opts["multi_gpu"]:
        assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0
        assert dataset_opts["batch_size"] % samples_per_speaker == 0
        batch_size = dataset_opts["batch_size"]//torch.cuda.device_count()

        side_sampler = SideSampler(training_set.sessions['speaker_idx'],
                                   speaker_number,
                                   dataset_opts["train"]["sampler"]["examples_per_speaker"],
                                   dataset_opts["train"]["sampler"]["samples_per_speaker"],
                                   dataset_opts["batch_size"])
    else:
        batch_size = dataset_opts["batch_size"]
        side_sampler = SideSampler(training_set.sessions['speaker_idx'],
                                   speaker_number,
                                   samples_per_speaker,
                                   batch_size,
                                   batch_size,
                                   seed=dataset_opts['seed'])

    training_loader = DataLoader(training_set,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 drop_last=True,
                                 pin_memory=True,
                                 sampler=side_sampler,
                                 num_workers=training_opts["num_cpu"],
                                 persistent_workers=False)

    validation_loader = DataLoader(validation_set,
                                   batch_size=batch_size,
                                   drop_last=False,
                                   pin_memory=True,
                                   num_workers=training_opts["num_cpu"],
                                   persistent_workers=False)

    # Compute indices for target and non-target trials once only to avoid recomputing for each epoch
    classes = torch.ShortTensor(validation_set.sessions['speaker_idx'].to_numpy())
    mask = classes.unsqueeze(1) == classes.unsqueeze(1).T
    tar_indices = torch.tril(mask, -1).numpy()
    non_indices = torch.tril(~mask, -1).numpy()
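    # Illustrative example: for classes = [0, 0, 1] the broadcast comparison yields
    # a 3x3 boolean matrix whose strict lower triangle enumerates each unordered
    # pair once; pair (1, 0) is a target trial (same speaker) while (2, 0) and
    # (2, 1) are non-target trials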

    # Select a subset of non-target trials to reduce the number of tests
    tar_non_ratio = numpy.sum(tar_indices)/numpy.sum(non_indices)
    non_indices *= (numpy.random.rand(*non_indices.shape) < tar_non_ratio)
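    # After this random masking the expected number of non-target trials matches
    # the number of target trials, yielding a balanced validation trial list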

    return training_loader, validation_loader, side_sampler, tar_indices, non_indices


def get_optimizer(model, model_opts, train_opts):
    """

    :param model:
    :param model_yaml:
    :return:
    """
    """
    Set the training options
    """
    if train_opts["optimizer"]["type"] == 'adam':
        _optimizer = torch.optim.Adam
        _options = {'lr': train_opts["lr"]}
    elif train_opts["optimizer"]["type"] == 'rmsprop':
        _optimizer = torch.optim.RMSprop
        _options = {'lr': train_opts["lr"]}
    else:  # train_opts["optimizer"]["type"] == 'sgd'
        _optimizer = torch.optim.SGD
        _options = {'lr': train_opts["lr"], 'momentum': 0.9}

    param_list = []
    if type(model) is Xtractor:
        if model.preprocessor is not None:
            param_list.append({'params': model.preprocessor.parameters(),
                               'weight_decay': model.preprocessor_weight_decay})
        param_list.append({'params': model.sequence_network.parameters(),
                           'weight_decay': model.sequence_network_weight_decay})
        param_list.append({'params': model.stat_pooling.parameters(),
                           'weight_decay': model.stat_pooling_weight_decay})
        param_list.append({'params': model.before_speaker_embedding.parameters(),
                           'weight_decay': model.before_speaker_embedding_weight_decay})
        param_list.append({'params': model.after_speaker_embedding.parameters(),
                           'weight_decay': model.after_speaker_embedding_weight_decay})

    else:
        if model.module.preprocessor is not None:
            param_list.append({'params': model.module.preprocessor.parameters(),
                               'weight_decay': model.module.preprocessor_weight_decay})
        param_list.append({'params': model.module.sequence_network.parameters(),
                           'weight_decay': model.module.sequence_network_weight_decay})
        param_list.append({'params': model.module.stat_pooling.parameters(),
                           'weight_decay': model.module.stat_pooling_weight_decay})
        param_list.append({'params': model.module.before_speaker_embedding.parameters(),
                           'weight_decay': model.module.before_speaker_embedding_weight_decay})
        param_list.append({'params': model.module.after_speaker_embedding.parameters(),
                           'weight_decay': model.module.after_speaker_embedding_weight_decay})

    optimizer = _optimizer(param_list, **_options)

    if train_opts["scheduler"]["type"] == 'CyclicLR':
        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
                                                      base_lr=1e-8,
                                                      max_lr=train_opts["lr"],
                                                      step_size_up=model_opts["speaker_number"] * 2,
                                                      step_size_down=None,
                                                      mode="triangular2")
    elif train_opts["scheduler"]["type"] == "MultiStepLR":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                                         milestones=[10000, 50000, 100000],
                                                         gamma=0.5)

    elif train_opts["scheduler"]["type"] == "StepLR":
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                                    step_size=2000,
                                                    gamma=0.95)

    elif train_opts["scheduler"]["type"] == "StepLR2":
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                                    step_size=2000,
                                                    gamma=0.5)
    else:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                               mode='min',
                                                               factor=0.5,
                                                               patience=3000,
                                                               verbose=True)

    return optimizer, scheduler
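
# Minimal usage sketch of get_optimizer (option values are illustrative, the
# dictionaries are normally built by update_training_dictionary):
#
#     train_opts = {"optimizer": {"type": "sgd"},
#                   "lr": 0.2,
#                   "scheduler": {"type": "MultiStepLR"}}
#     optimizer, scheduler = get_optimizer(model, model_opts, train_opts)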


def save_model(model, training_monitor, model_opts, training_opts, optimizer, scheduler):
    """

    :param model:
    :param training_monitor:
    :param model_opts:
Anthony Larcher's avatar
Anthony Larcher committed
1231
1232
    :param training_opts:
    :param optimizer:
Anthony Larcher's avatar
Anthony Larcher committed
1233
1234
1235
1236
1237
1238
    :param scheduler:
    :return:
    """
    # TODO: to be reworked
    if type(model) is Xtractor:
        save_checkpoint({
            'epoch': training_monitor.current_epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'accuracy': training_monitor.best_accuracy,
            'scheduler': scheduler,
            'speaker_number': model.speaker_number,
            'model_archi': model_opts,
            'loss': model_opts["loss"]["type"]
        }, training_monitor.is_best, filename=training_opts["tmp_model_name"], best_filename=training_opts["best_model_name"])
    else:
        save_checkpoint({
            'epoch': training_monitor.current_epoch,
            'model_state_dict': model.module.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'accuracy': training_monitor.best_accuracy,
            'scheduler': scheduler,
            'speaker_number': model.module.speaker_number,
            'model_archi': model_opts,
            'loss': model_opts["loss"]["type"]
        }, training_monitor.is_best, filename=training_opts["tmp_model_name"], best_filename=training_opts["best_model_name"])
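
# A checkpoint written by save_model can be reloaded with plain torch.load;
# hedged sketch (the file name is illustrative):
#
#     checkpoint = torch.load("best_model.pt")
#     model.load_state_dict(checkpoint["model_state_dict"])
#     optimizer.load_state_dict(checkpoint["optimizer_state_dict"])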


def new_xtrain(dataset_description,
               model_description,
               training_description,
               local_rank=-1,
               **kwargs):
    """
    REFACTORING
Anthony Larcher's avatar
Anthony Larcher committed
1268
    - affiner les loggings
Anthony Larcher's avatar
Anthony Larcher committed
1269
    - en cas de redemarrage à partir d'un modele existant, recharger l'optimize et le scheduler
Anthony Larcher's avatar
Anthony Larcher committed
1270
    """
    # Test to optimize
    torch.autograd.profiler.emit_nvtx(enabled=False)

    dataset_opts, model_opts, training_opts = update_training_dictionary(dataset_description,
                                                                         model_description,
                                                                         training_description,
                                                                         kwargs)

    # Initialize the training monitor
    monitor = TrainingMonitor(output_file=training_opts["log_file"],
                              patience=training_opts["patience"],
                              best_accuracy=0.0,
                              best_eer_epoch=1,
                              best_eer=100,
                              compute_test_eer=training_opts["compute_test_eer"])

    # Make PyTorch Deterministic
    torch.backends.cudnn.deterministic = False
    if training_opts["deterministic"]:
        torch.backends.cudnn.deterministic = True

    # Set all the seeds
    numpy.random.seed(training_opts["seed"]) # Set the random seed of numpy for the data split.
    torch.manual_seed(training_opts["seed"])
    torch.cuda.manual_seed(training_opts["seed"])

    # Display the entire configuration as YAML dictionaries
    if local_rank < 1:
        monitor.logger.info("\n*********************************\nDataset options\n*********************************\n")
        monitor.logger.info(yaml.dump(dataset_opts, default_flow_style=False))
        monitor.logger.info("\n*********************************\nModel options\n*********************************\n")
        monitor.logger.info(yaml.dump(model_opts, default_flow_style=False))
        monitor.logger.info("\n*********************************\nTraining options\n*********************************\n")
        monitor.logger.info(yaml.dump(training_opts, default_flow_style=False))

    # Initialize the model
    model = get_network(model_opts)
    speaker_number = model.speaker_number
    embedding_size = model.embedding_size

    # Set the device and manage parallel processing
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda")
    model.to(device)

    # If multi-gpu
    """ [HOW TO] from https://gist.github.com/sgraaf/5b0caa3a320f28c27c12b5efeb35aa4c
        - Add the following line right after "if __name__ == '__main__':" in your main script:
        parser.add_argument('--local_rank', type=int, default=-1, metavar='N', help='Local process rank.')
        - Then, in your shell:
        export NUM_NODES=1
        export NUM_GPUS_PER_NODE=2
        export NODE_RANK=0
        export WORLD_SIZE=$(($NUM_NODES * $NUM_GPUS_PER_NODE))
        python -m torch.distributed.launch \
            --nproc_per_node=$NUM_GPUS_PER_NODE \
            --nnodes=$NUM_NODES \
            --node_rank $NODE_RANK \
            train_xvector.py ...
    """
    if training_opts["multi_gpu"]:
        if local_rank < 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank
        )
    else:
        print("Train on a single GPU")

    # Initialise data loaders
    training_loader, validation_loader, sampler, validation_tar_indices, validation_non_indices = get_loaders(dataset_opts,
                                                                                                              training_opts,
                                                                                                              model_opts,
                                                                                                              speaker_number)

    if local_rank < 1:
        monitor.logger.info("Start training process")
        monitor.logger.info(f"Use \t{torch.cuda.device_count()} \tgpus")
        monitor.logger.info(f"Use \t{training_opts['num_cpu']} \tcpus")

        monitor.logger.info("Validation EER will be measured using")
        monitor.logger.info(f"\t {numpy.sum(validation_tar_indices)} target trials and")
        monitor.logger.info(f"\t {numpy.sum(validation_non_indices)} non-target trials")

    # Create optimizer and scheduler
    optimizer, scheduler = get_optimizer(model, model_opts, training_opts)

    scaler = None
    if training_opts["mixed_precision"]:
        scaler = torch.cuda.amp.GradScaler()

    for epoch in range(1, training_opts["epochs"] + 1):

        monitor.update(epoch=epoch)

        # Process one epoch and return the current model
        if monitor.current_patience == 0:
            print(f"Stopping at epoch {epoch} because the patience is exhausted")
            break

        sampler.set_epoch(epoch)

        model = new_train_epoch(model,
                                training_opts,
                                monitor,
                                training_loader,
                                optimizer,
                                scheduler,
                                device,
                                scaler=scaler)

        # Cross validation
        if math.fmod(epoch, training_opts["validation_frequency"]) == 0:
            val_acc, val_loss, val_eer = cross_validation(model,
                                                          validation_loader,
                                                          device,
                                                          [len(validation_loader.dataset), embedding_size],
                                                          validation_tar_indices,
                                                          validation_non_indices,
                                                          training_opts["mixed_precision"])

            test_eer = None
            if training_opts["compute_test_eer"] and local_rank < 1:
                test_eer = new_test_metrics(model, device, model_opts, dataset_opts, training_opts)

            monitor.update(test_eer=test_eer,
                           val_eer=val_eer,
                           val_loss=val_loss,
                           val_acc=val_acc)

            if local_rank < 1:
                monitor.display()

            # Save the current model and if needed update the best one
            # TODO: add an option to keep the models at given epochs (e.g. before a learning rate change)
            if local_rank < 1:
                save_model(model, monitor, model_opts, training_opts, optimizer, scheduler)

    for ii in range(torch.cuda.device_count()):
        monitor.logger.info(torch.cuda.memory_summary(ii))

    # TODO: handle the display through the training_monitor
    if local_rank < 1:
        monitor.display_final()

    return monitor.best_eer


def xtrain(speaker_number,
           dataset_yaml,
           epochs=None,
           lr=None,
           model_yaml=None,
           model_name=None,
           loss=None,
           aam_margin=None,
           aam_s=None,
           patience=None,
           tmp_model_name=None,
           best_model_name=None,
           multi_gpu=True,
           device=None,
           mixed_precision=False,
           clipping=False,
           opt=None,
           reset_parts=[],
           freeze_parts=[],
           num_thread=None,
           compute_test_eer=True):
    """

    :param speaker_number:
    :param dataset_yaml:
    :param epochs:
    :param lr:
    :param model_yaml:
    :param model_name:
    :param loss:
    :param aam_margin:
    :param aam_s:
    :param patience:
    :param tmp_model_name:
    :param best_model_name:
    :param multi_gpu:
    :param device:
    :param mixed_precision:
    :param clipping:
    :param opt:
    :param reset_parts:
    :param freeze_parts:
    :param num_thread:
    :param compute_test_eer:
    :return:
    """
    # Test to optimize
    torch.autograd.profiler.emit_nvtx(enabled=False)

    if num_thread is None:
        import multiprocessing
        num_thread = multiprocessing.cpu_count()

    logging.critical(f"Use {num_thread} cpus")
    logging.critical(f"Start process at {time.strftime('%H:%M:%S', time.localtime())}")

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Use a predefined architecture
    if model_yaml in ["xvector", "rawnet2", "resnet34", "fastresnet34", "halfresnet34"]:

        if model_name is None:
            model = Xtractor(speaker_number, model_yaml, loss=loss)

        else:
            logging.critical(f"*** Load model from = {model_name}")
            checkpoint = torch.load(model_name)
            model = Xtractor(speaker_number, model_yaml)

            """
            Here we remove all layers that we don't want to reload

            """
            pretrained_dict = checkpoint["model_state_dict"]
            for part in reset_parts:
                pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith(part)}

            new_model_dict = model.state_dict()
            new_model_dict.update(pretrained_dict)
            model.load_state_dict(new_model_dict)

        # Freeze required layers
        for name, param in model.named_parameters():
            if name.split(".")[0] in freeze_parts:
                param.requires_grad = False

        model_archi = model_yaml

    # Here use a config file to build the architecture
    else:
        with open(model_yaml, 'r') as fh:
            model_archi = yaml.load(fh, Loader=yaml.FullLoader)
            if epochs is None:
                epochs = model_archi["training"]["epochs"]
            if patience is None:
                patience = model_archi["training"]["patience"]
            if opt is None:
                opt = model_archi["training"]["opt"]
            if lr is None:
                lr = model_archi["training"]["lr"]
            if loss is None:
                loss = model_archi["training"]["loss"]
            if aam_margin is None and model_archi["training"]["loss"] == "aam":
                aam_margin = model_archi["training"]["aam_margin"]
            if aam_s is None and model_archi["training"]["loss"] == "aam":
                aam_s = model_archi["training"]["aam_s"]
            if tmp_model_name is None:
                tmp_model_name = model_archi["training"]["tmp_model_name"]
            if best_model_name is None:
                best_model_name = model_archi["training"]["best_model_name"]
            if multi_gpu is None:
                multi_gpu = model_archi["training"]["multi_gpu"]
            if clipping is None:
                clipping = model_archi["training"]["clipping"]

        if model_name is None:
            model = Xtractor(speaker_number, model_yaml)

        # If we start from an existing model
        else:
            # Load the model
            logging.critical(f"*** Load model from = {model_name}")
            checkpoint = torch.load(model_name, map_location=device)
            model = Xtractor(speaker_number, model_yaml, loss=loss)

            """
            Here we remove all layers that we don't want to reload
Anthony Larcher's avatar
Anthony Larcher committed
1550
        
Anthony Larcher's avatar
fix API    
Anthony Larcher committed
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
            """
            pretrained_dict = checkpoint["model_state_dict"]
            for part in reset_parts:
                pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith(part)}

            new_model_dict = model.state_dict()
            new_model_dict.update(pretrained_dict)
            model.load_state_dict(new_model_dict)

        # Freeze required layers
        for name, param in model.named_parameters():
            if name.split(".")[0] in freeze_parts:
                param.requires_grad = False

    logging.critical(model)

    logging.critical("model_parameters_count: {:d}".format(
        sum(p.numel()
            for p in model.sequence_network.parameters()
            if p.requires_grad) + \
        sum(p.numel()
            for p in model.before_speaker_embedding.parameters()
            if p.requires_grad) + \
        sum(p.numel()
            for p in model.stat_pooling.parameters()
            if p.requires_grad)))

    embedding_size = model.embedding_size

    if torch.cuda.device_count() > 1 and multi_gpu:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)

    else:
        print("Train on a single GPU")
    model.to(device)

    with open(dataset_yaml, "r") as fh:
        dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
        df = pandas.read_csv(dataset_params["dataset_description"])

    """
    Set the dataloaders according to the dataset_yaml
    
    First we load the dataframe from CSV file in order to split it for training and validation purpose
    Then we provide those two 
    """
    training_df, validation_df = train_test_split(df,
                                                  test_size=dataset_params["validation_ratio"],
                                                  stratify=df["speaker_idx"])

    torch.manual_seed(dataset_params['seed'])

    training_set = SideSet(dataset_yaml,
                           set_type="train",
                           chunk_per_segment=-1,
                           overlap=dataset_params['train']['overlap'],
                           dataset_df=training_df,
                           output_format="pytorch",
                           )

    validation_set = SideSet(dataset_yaml,
                             set_type="validation",
                             chunk_per_segment=1,
                             dataset_df=validation_df,
                             output_format="pytorch")

    side_sampler = SideSampler(training_set.sessions['speaker_idx'],
                               speaker_number,
                               1,
                               128,
                               dataset_params["batch_size"])

    training_loader = DataLoader(training_set,
                                 batch_size=dataset_params["batch_size"],
                                 shuffle=False,
                                 drop_last=True,
                                 pin_memory=True,
                                 sampler=side_sampler,
                                 num_workers=num_thread,
                                 persistent_workers=True)

    validation_loader = DataLoader(validation_set,
                                   batch_size=dataset_params["batch_size"],
                                   drop_last=False,
                                   pin_memory=True,
                                   num_workers=num_thread,
                                   persistent_workers=False)

    """
    Set the training options
    """
    if opt == 'adam':
        _optimizer = torch.optim.Adam
        _options = {'lr': lr}
    elif opt == 'rmsprop':
        _optimizer = torch.optim.RMSprop
        _options = {'lr': lr}
    else: # opt == 'sgd'
        _optimizer = torch.optim.SGD
        _options = {'lr': lr, 'momentum': 0.9}

    param_list = []
    if type(model) is Xtractor:
        if model.preprocessor is not None:
            param_list.append({'params': model.preprocessor.parameters(), 'weight_decay': model.preprocessor_weight_decay})
        param_list.append({'params': model.sequence_network.parameters(), 'weight_decay': model.sequence_network_weight_decay})
        param_list.append({'params': model.stat_pooling.parameters(), 'weight_decay': model.stat_pooling_weight_decay})
        param_list.append({'params': model.before_speaker_embedding.parameters(), 'weight_decay': model.before_speaker_embedding_weight_decay})
        param_list.append({'params': model.after_speaker_embedding.parameters(), 'weight_decay': model.after_speaker_embedding_weight_decay})

    else:
        if model.module.preprocessor is not None:
            param_list.append({'params': model.module.preprocessor.parameters(), 'weight_decay': model.module.preprocessor_weight_decay})
        param_list.append({'params': model.module.sequence_network.parameters(), 'weight_decay': model.module.sequence_network_weight_decay})
        param_list.append({'params': model.module.stat_pooling.parameters(), 'weight_decay': model.module.stat_pooling_weight_decay})
        param_list.append({'params': model.module.before_speaker_embedding.parameters(), 'weight_decay': model.module.before_speaker_embedding_weight_decay})
        param_list.append({'params': model.module.after_speaker_embedding.parameters(), 'weight_decay': model.module.after_speaker_embedding_weight_decay})

    optimizer = _optimizer(param_list, **_options)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=5 * len(training_loader),
                                                gamma=0.75)

    if mixed_precision:
        scaler = torch.cuda.amp.GradScaler()
    else:
        scaler = None

    best_accuracy = 0.0
    best_accuracy_epoch = 1
    best_eer = 100
    curr_patience = patience

    test_eer = 100.

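    # Compute once the target / non-target trial masks used for the validation EER;
    # non-target trials are then randomly subsampled to roughly balance the two counts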
    classes = torch.ShortTensor(validation_set.sessions['speaker_idx'].to_numpy())
    mask = classes.unsqueeze(1) == classes.unsqueeze(1).T
    tar_indices = torch.tril(mask, -1).numpy()
    non_indices = torch.tril(~mask, -1).numpy()
    tar_non_ratio = numpy.sum(tar_indices)/numpy.sum(non_indices)
    non_indices *= numpy.random.choice([False, True], size=non_indices.shape, p=[1-tar_non_ratio, tar_non_ratio])
    #tar_indices *= numpy.random.choice([False, True], size=tar_indices.shape, p=[0.9, 0.1])
    #non_indices *= numpy.random.choice([False, True], size=non_indices.shape, p=[0.9, 0.1])

    logging.critical("val tar count: {:d}, non count: {:d}".format(numpy.sum(tar_indices), numpy.sum(non_indices)))

    for epoch in range(1, epochs + 1):
        # Process one epoch and return the current model
        if curr_patience == 0:
            print(f"Stopping at epoch {epoch} because the patience is exhausted")
            break
        model = train_epoch(model,
                            epoch,
                            training_loader,
                            optimizer,
                            scheduler,
                            dataset_params["log_interval"],
                            device,
                            scaler=scaler,
                            clipping=clipping)

        # Cross validation
        if math.fmod(epoch, 1) == 0:
            val_acc, val_loss, val_eer = cross_validation(model,
                                                          validation_loader,
                                                          device,
                                                          [len(validation_set), embedding_size],
                                                          tar_indices,
                                                          non_indices,
                                                          mixed_precision)
            logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")

            if compute_test_eer:
                test_eer = test_metrics(model, device, speaker_number, num_thread, mixed_precision)
                #logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Reversed Test EER = {rev_eer * 100} %")
                logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Test EER = {test_eer * 100} %")

            # remember best accuracy and save checkpoint
            if compute_test_eer:
                is_best = test_eer < best_eer
                best_eer = min(test_eer, best_eer)
            else:
                is_best = val_eer < best_eer
                best_eer = min(val_eer, best_eer)

            best_accuracy = max(val_acc, best_accuracy)

            if tmp_model_name is None:
                tmp_model_name = "tmp_model"
            if best_model_name is None:
                best_model_name = "best_model"

            if type(model) is Xtractor:
                save_checkpoint({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'accuracy': best_accuracy,
                    'scheduler': scheduler,
                    'speaker_number': speaker_number,
                    'model_archi': model_archi,
                    'loss': loss
                }, is_best, filename=tmp_model_name+".pt", best_filename=best_model_name+'.pt')
            else:
                save_checkpoint({
                    'epoch': epoch,
                    'model_state_dict': model.module.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'accuracy': best_accuracy,
                    'scheduler': scheduler,
                    'speaker_number': speaker_number,
                    'model_archi': model_archi,
                    'loss': loss
                }, is_best, filename=tmp_model_name+".pt", best_filename=best_model_name+'.pt')

            if is_best:
                best_accuracy_epoch = epoch
                curr_patience = patience
            else:
                curr_patience -= 1

    for ii in range(torch.cuda.device_count()):
        print(torch.cuda.memory_summary(ii))

    logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")


def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interval, device, scaler=None, clipping=False):
    """

    :param model:
    :param epoch:
Anthony Larcher's avatar
Anthony Larcher committed
1777
    :param training_loader:
1778
    :param optimizer:
Anthony Larcher's avatar
Anthony Larcher committed
1779
1780
1781
    :param log_interval:
    :param device:
    :param clipping:
1782
1783
    :return:
    """
    model.train()
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')

    if isinstance(model, Xtractor):
        loss_criteria = model.loss
    else:
        loss_criteria = model.module.loss

    accuracy = 0.0
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(training_loader):
        data = data.squeeze().to(device)
        target = target.squeeze()
        target = target.to(device)
        optimizer.zero_grad(set_to_none=True)

        if scaler is not None:
            with torch.cuda.amp.autocast():
                if loss_criteria == 'aam':
                    output, _ = model(data, target=target)
                    loss = criterion(output, target)
                elif loss_criteria == 'smn':
                    output_tuple, _ = model(data, target=target)
                    loss, output = output_tuple
                    loss += criterion(output, target)
                elif loss_criteria == 'aps':
                    output_tuple, _ = model(data, target=target)
                    cos_sim_matx, output = output_tuple
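                    # positives lie on the diagonal of the half-batch similarity
                    # matrix, so the target "class" of row i is simply i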
                    loss = criterion(cos_sim_matx, torch.arange(0, int(data.shape[0]/2), device=device)) + criterion(output, target)
                else:
                    output, _ = model(data, target=None)
                    loss = criterion(output, target)
        else:
            if loss_criteria == 'aam':
                output, _ = model(data, target=target)
                loss = criterion(output, target)
            elif loss_criteria == 'aps':
                output_tuple, _ = model(data, target=target)
                cos_sim_matx, output = output_tuple
                loss = criterion(cos_sim_matx, torch.arange(0, int(data.shape[0]/2), device=device)) + criterion(output, target)
            else:
                output, _ = model(data, target=None)
                loss = criterion(output, target)
        if not torch.isnan(loss):
            if scaler is not None:
                scaler.scale(loss).backward()
                if clipping:
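                    # gradients must be unscaled back to their true magnitude
                    # before the norm is measured and clipped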
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            running_loss += loss.item()
            accuracy += (torch.argmax(output.data, 1) == target).sum()
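            # top-1 accuracy, accumulated over the epoch and normalised by the
            # number of examples seen so far when logging below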

            if batch_idx % log_interval == 0:
                batch_size = target.shape[0]
                logging.critical('{}, Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                    time.strftime('%H:%M:%S', time.localtime()),
                    epoch, batch_idx + 1, len(training_loader),
                    100. * batch_idx / len(training_loader), loss.item(),
                    100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))

        else:
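            # the loss went NaN: dump a checkpoint and the offending batch for
            # post-mortem inspection, then abort the run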
            save_checkpoint({
                             'epoch': epoch,
                             'model_state_dict': model.state_dict(),
                             'optimizer_state_dict': optimizer.state_dict(),
                             'accuracy': 0.0,
                             'scheduler': 0.0
                             }, False, filename="model_loss_NAN.pt", best_filename='toto.pt')
            with open("batch_loss_NAN.pkl", "wb") as fh:
                pickle.dump(data.cpu(), fh)
            import sys
            sys.exit()
        running_loss = 0.0
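        # the scheduler is stepped once per batch (suitable for cyclic-style
        # schedules)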
        scheduler.step()
    return model
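
# A minimal driver sketch (assumed, not part of this module) showing how an
# epoch loop might call train_epoch above; the exact signature and argument
# order may differ from the names used here:
#
#     scaler = torch.cuda.amp.GradScaler()
#     for epoch in range(1, n_epochs + 1):
#         model = train_epoch(model, epoch, training_loader, optimizer,
#                             scheduler, log_interval, device,
#                             scaler=scaler, clipping=True)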


def new_train_epoch(model,
                    training_opts,
                    training_monitor,
                    training_loader,
                    optimizer,
                    scheduler,
                    device,
                    scaler=None,
                    clipping=False):
    """

    :param model:
Anthony Larcher's avatar
Anthony Larcher committed
1879
    :param training_opts:
Anthony Larcher's avatar
Anthony Larcher committed
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
    :param training_monitor:
    :param training_loader:
    :param optimizer:
    :param scheduler:
    :param device:
    :param scaler:
    :param clipping:
    :return:
    """
    model.train()
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')

    if isinstance(model, Xtractor):
        loss_criteria = model.loss
    else:
        loss_criteria = model.module.loss

    accuracy = 0.0
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(training_loader):
        data = data.squeeze().to(device)
        target = target.squeeze()
        target = target.to(device)
        optimizer.zero_grad(set_to_none=True)

        if scaler is not None:
            with torch.cuda.amp.autocast():
                if loss_criteria == 'aam':
                    output, _ = model(data, target=target)
                    loss = criterion(output, target)
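                # 'smn' models return (internal_loss, logits); the
                # cross-entropy on the logits is added to that internal loss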
                elif loss_criteria == 'smn':
                    output_tuple, _ = model(data, target=target)
                    loss, output = output_tuple
                    loss += criterion(output, target)
                elif loss_criteria == 'aps':
                    output_tuple, _ = model(data, target=target)
                    cos_sim_matx, output = output_tuple
                    loss = criterion(cos_sim_matx, torch.arange(0, int(data.shape[0]/2), device=device)) + criterion(output, target)
                else:
                    output, _ = model(data, target=None)
                    loss = criterion(output, target)
        else:
            if loss_criteria == 'aam':
                output, _ = model(data, target=target)
                loss = criterion(output, target)
            elif loss_criteria == 'aps':
                output_tuple, _ = model(data, target=target)
                cos_sim_matx, output = output_tuple
                loss = criterion(cos_sim_matx, torch.arange(0, int(data.shape[0]/2), device=device)) + criterion(output, target)
            else:
                output, _ = model(data, target=None)
                loss = criterion(output, target)

        # NaN guard intentionally bypassed here; the original check and its
        # handler are kept below, commented out, for reference
        #if not torch.isnan(loss):
        if True:
            if scaler is not None:
                scaler.scale(loss).backward()
                if clipping:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            running_loss += loss.item()
            accuracy += (torch.argmax(output.data, 1) == target).sum()

            if batch_idx % training_opts["log_interval"] == 0:
                batch_size = target.shape[0]
                training_monitor.update(training_loss=loss.item(),
                                        training_acc=100.0 * accuracy.item() / ((batch_idx + 1) * batch_size))

                training_monitor.logger.info('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                    training_monitor.current_epoch,
                    batch_idx + 1,
                    len(training_loader),
                    100. * batch_idx / len(training_loader),
                    loss.item(),
                    100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))

        #else:
        #    save_checkpoint({
        #        'epoch': training_monitor.current_epoch,
        #        'model_state_dict': model.state_dict(),
        #        'optimizer_state_dict': optimizer.state_dict(),
        #        'accuracy': 0.0,
        #        'scheduler': 0.0
        #    }, False, filename="model_loss_NAN.pt", best_filename='toto.pt')
        #    with open("batch_loss_NAN.pkl", "wb") as fh:
        #        pickle.dump(data.cpu(), fh)
        #    import sys
        #    sys.exit()

        running_loss = 0.0
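        # ReduceLROnPlateau expects a validation metric, here the best EER
        # seen so far; any other scheduler is stepped once per batch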
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(training_monitor.best_eer)
        else:
            scheduler.step()
    return model
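
# Sketch of the assumed scoring step inside cross_validation below (names are
# illustrative, not this module's API): embeddings are length-normalised,
# scored against each other by cosine similarity, and tar_indices /
# non_indices select the target and non-target trial scores from which the
# EER is derived.
#
#     normed = torch.nn.functional.normalize(embeddings, dim=1)
#     scores = torch.mm(normed, normed.t())
#     tar_scores = scores[tar_indices]
#     non_scores = scores[non_indices]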

def cross_validation(model, validation_loader, device, validation_shape, tar_indices, non_indices, mixed_precision=False):
    """

    :param model:
Anthony Larcher's avatar
Anthony Larcher committed
1984
1985
    :param validation_loader:
    :param device:
1986
    :param validation_shape:
1987
1988
1989
    :return:
    """
    model.eval()
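    # eval() freezes dropout and batch-norm statistics; gradients are
    # disabled below with torch.no_grad()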
    if isinstance(model, Xtractor):
        loss_criteria = model.loss
    else:
        loss_criteria = model.module.loss

    accuracy = 0.0
    loss = 0.0
    criterion = torch.nn.CrossEntropyLoss()
    embeddings = torch.zeros(validation_shape)
    cursor = 0
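    # embeddings of the whole validation set are collected in a single
    # tensor; cursor tracks the next write position, batch after batch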
    with torch.no_grad():