xvector.py 72.7 KB
Newer Older
Anthony Larcher's avatar
Anthony Larcher committed
1001

Anthony Larcher's avatar
Anthony Larcher committed
1002
1003
1004
1005
    # Load the model if it exists
    if model_opts["initial_model_name"] is not None and os.path.isfile(model_opts["initial_model_name"]):
        logging.critical(f"*** Load model from = {model_opts['initial_model_name']}")
        checkpoint = torch.load(model_opts["initial_model_name"])
Anthony Larcher's avatar
Anthony Larcher committed
1006

Anthony Larcher's avatar
Anthony Larcher committed
1007
1008
        """
        Here we remove all layers that we don't want to reload
Anthony Larcher's avatar
Anthony Larcher committed
1009

Anthony Larcher's avatar
Anthony Larcher committed
1010
1011
1012
1013
        """
        pretrained_dict = checkpoint["model_state_dict"]
        for part in model_opts["reset_parts"]:
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith(part)}
Anthony Larcher's avatar
Anthony Larcher committed
1014

Anthony Larcher's avatar
Anthony Larcher committed
1015
1016
1017
        new_model_dict = model.state_dict()
        new_model_dict.update(pretrained_dict)
        model.load_state_dict(new_model_dict)
Anthony Larcher's avatar
Anthony Larcher committed
1018
1019
1020

        # Freeze required layers
        for name, param in model.named_parameters():
Anthony Larcher's avatar
Anthony Larcher committed
1021
            if name.split(".")[0] in model_opts["reset_parts"]:
Anthony Larcher's avatar
Anthony Larcher committed
1022
1023
                param.requires_grad = False

Anthony Larcher's avatar
debug    
Anthony Larcher committed
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
    if local_rank < 1:
        logging.info(model)
        logging.info("Model_parameters_count: {:d}".format(
            sum(p.numel()
                for p in model.sequence_network.parameters()
                if p.requires_grad) + \
            sum(p.numel()
                for p in model.before_speaker_embedding.parameters()
                if p.requires_grad) + \
            sum(p.numel()
                for p in model.stat_pooling.parameters()
                if p.requires_grad)))
Anthony Larcher's avatar
Anthony Larcher committed
1036
1037
1038
1039

    return model


Anthony Larcher's avatar
Anthony Larcher committed
1040
def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
    """
    Build the training and validation DataLoaders from the dataset description.

    The dataset CSV is first loaded and split into a training and a validation
    part (optionally stratified per speaker). Seeds are offset by
    ``local_rank`` so that each distributed process draws different chunks.
    The target / non-target trial masks used for validation EER are
    precomputed once here to avoid recomputing them at every epoch.

    :param dataset_opts: dict of dataset options (csv path, batch size, sampler options...)
    :param training_opts: dict of training options (seeds, num_cpu, multi_gpu...)
    :param model_opts: dict of model options (loss type, speaker count...)
    :param local_rank: rank of the current process (0 for a single process)

    :return: a tuple (training_loader, validation_loader, side_sampler,
             tar_indices, non_indices)
    """
    # Load the dataframe from CSV file in order to split it
    # for training and validation purpose
    df = pandas.read_csv(dataset_opts["dataset_csv"])

    stratify = None
    if dataset_opts["stratify"]:
        stratify = df["speaker_idx"]
    training_df, validation_df = train_test_split(df,
                                                  test_size=dataset_opts["validation_ratio"],
                                                  stratify=stratify)

    # Offset the seeds with the process rank so each process draws different chunks
    torch.manual_seed(training_opts['torch_seed'] + local_rank)
    torch.cuda.manual_seed(training_opts['torch_seed'] + local_rank)

    training_set = SideSet(dataset_opts,
                           set_type="train",
                           chunk_per_segment=-1,
                           transform_number=dataset_opts['train']['transform_number'],
                           overlap=dataset_opts['train']['overlap'],
                           dataset_df=training_df,
                           output_format="pytorch",
                           )

    validation_set = SideSet(dataset_opts,
                             set_type="validation",
                             dataset_df=validation_df,
                             output_format="pytorch")

    # The 'aps' loss requires two samples per speaker in each batch
    if model_opts["loss"]["type"] == 'aps':
        samples_per_speaker = 2
    else:
        samples_per_speaker = 1

    # The two branches only differ in the effective batch size and the
    # rank given to the sampler; the SideSampler itself is built once below.
    if training_opts["multi_gpu"]:
        assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0
        assert dataset_opts["batch_size"] % samples_per_speaker == 0
        batch_size = dataset_opts["batch_size"] // (torch.cuda.device_count() * dataset_opts["train"]["sampler"]["examples_per_speaker"])
        sampler_rank = local_rank
    else:
        batch_size = dataset_opts["batch_size"] // dataset_opts["train"]["sampler"]["examples_per_speaker"]
        sampler_rank = 0

    side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
                               spk_count=model_opts["speaker_number"],
                               examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
                               samples_per_speaker=dataset_opts["train"]["sampler"]["samples_per_speaker"],
                               batch_size=batch_size,
                               seed=training_opts['torch_seed'],
                               rank=sampler_rank,
                               num_process=torch.cuda.device_count(),
                               num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
                               )

    training_loader = DataLoader(training_set,
                                 batch_size=batch_size * dataset_opts["train"]["sampler"]["augmentation_replica"],
                                 shuffle=False,
                                 drop_last=True,
                                 pin_memory=True,
                                 sampler=side_sampler,
                                 num_workers=training_opts["num_cpu"],
                                 persistent_workers=False,
                                 worker_init_fn=seed_worker)

    validation_loader = DataLoader(validation_set,
                                   batch_size=batch_size,
                                   drop_last=False,
                                   pin_memory=True,
                                   num_workers=training_opts["num_cpu"],
                                   persistent_workers=False,
                                   worker_init_fn=seed_worker)

    # Compute indices for target and non-target trials once only to avoid
    # recomputing for each epoch
    classes = torch.ShortTensor(validation_set.sessions['speaker_idx'].to_numpy())
    mask = classes.unsqueeze(1) == classes.unsqueeze(1).T
    tar_indices = torch.tril(mask, -1).numpy()
    non_indices = torch.tril(~mask, -1).numpy()

    # Select a subset of non-target trials to reduce the number of tests
    tar_non_ratio = numpy.sum(tar_indices) / numpy.sum(non_indices)
    non_indices *= (numpy.random.rand(*non_indices.shape) < tar_non_ratio)

    return training_loader, validation_loader, side_sampler, tar_indices, non_indices
Anthony Larcher's avatar
Anthony Larcher committed
1143
1144


Anthony Larcher's avatar
debug    
Anthony Larcher committed
1145
def get_optimizer(model, model_opts, train_opts, training_loader):
    """
    Instantiate an optimizer and a learning rate scheduler according to the
    training options.

    :param model: the model to train, either an Xtractor or a parallel
        wrapper (e.g. DistributedDataParallel) exposing it as ``model.module``
    :param model_opts: dict of model options (speaker number is used to size CyclicLR steps)
    :param train_opts: dict of training options (optimizer/scheduler types, lr)
    :param training_loader: training DataLoader, used to size StepLR steps

    :return: a tuple (optimizer, scheduler)
    """
    if train_opts["optimizer"]["type"] == 'adam':
        _optimizer = torch.optim.Adam
        _options = {'lr': train_opts["lr"]}
    elif train_opts["optimizer"]["type"] == 'rmsprop':
        _optimizer = torch.optim.RMSprop
        _options = {'lr': train_opts["lr"]}
    else:  # train_opts["optimizer"]["type"] == 'sgd'
        _optimizer = torch.optim.SGD
        _options = {'lr': train_opts["lr"], 'momentum': 0.9}

    # Each sub-network of the Xtractor carries its own weight decay.
    # When the model is wrapped for parallel training the actual Xtractor
    # is reached through model.module; resolve it once instead of
    # duplicating the parameter-group construction.
    if isinstance(model, Xtractor):
        module = model
    else:
        module = model.module

    param_list = []
    if module.preprocessor is not None:
        param_list.append({'params': module.preprocessor.parameters(),
                           'weight_decay': module.preprocessor_weight_decay})
    param_list.append({'params': module.sequence_network.parameters(),
                       'weight_decay': module.sequence_network_weight_decay})
    param_list.append({'params': module.stat_pooling.parameters(),
                       'weight_decay': module.stat_pooling_weight_decay})
    param_list.append({'params': module.before_speaker_embedding.parameters(),
                       'weight_decay': module.before_speaker_embedding_weight_decay})
    param_list.append({'params': module.after_speaker_embedding.parameters(),
                       'weight_decay': module.after_speaker_embedding_weight_decay})

    optimizer = _optimizer(param_list, **_options)

    if train_opts["scheduler"]["type"] == 'CyclicLR':
        # Adam manages its own momentum: CyclicLR must not cycle it
        cycle_momentum = train_opts["optimizer"]["type"] != "adam"
        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
                                                      base_lr=1e-8,
                                                      max_lr=train_opts["lr"],
                                                      step_size_up=model_opts["speaker_number"] * 8,
                                                      step_size_down=None,
                                                      cycle_momentum=cycle_momentum,
                                                      mode="triangular2")
    elif train_opts["scheduler"]["type"] == "MultiStepLR":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                                         milestones=[10000, 50000, 100000],
                                                         gamma=0.5)
    elif train_opts["scheduler"]["type"] == "StepLR":
        # One decay per epoch: an epoch is len(training_loader) steps
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                                    step_size=1 * len(training_loader),
                                                    gamma=0.95)
    elif train_opts["scheduler"]["type"] == "StepLR2":
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                                    step_size=1 * len(training_loader),
                                                    gamma=0.5)
    else:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                               mode='min',
                                                               factor=0.5,
                                                               patience=3000,
                                                               verbose=True)

    return optimizer, scheduler


Anthony Larcher's avatar
Anthony Larcher committed
1228
def save_model(model, training_monitor, model_opts, training_opts, optimizer, scheduler):
    """
    Save a training checkpoint (model weights, optimizer state, scheduler and
    monitoring information) and, if the current model is the best so far,
    update the best-model file as well.

    :param model: the model to save, either an Xtractor or a parallel
        wrapper exposing it as ``model.module``
    :param training_monitor: TrainingMonitor holding current epoch, best accuracy
        and the is_best flag
    :param model_opts: dict of model options, stored as the model architecture
    :param training_opts: dict of training options, provides the checkpoint filenames
    :param optimizer: optimizer whose state is saved for restart
    :param scheduler: scheduler object, stored as-is in the checkpoint
    :return: None
    """
    # TODO: to be reworked
    # The two cases only differ in where the state dict and speaker count live
    if isinstance(model, Xtractor):
        model_state_dict = model.state_dict()
        speaker_number = model.speaker_number
    else:
        model_state_dict = model.module.state_dict()
        speaker_number = model.module.speaker_number

    save_checkpoint({
        'epoch': training_monitor.current_epoch,
        'model_state_dict': model_state_dict,
        'optimizer_state_dict': optimizer.state_dict(),
        'accuracy': training_monitor.best_accuracy,
        'scheduler': scheduler,
        'speaker_number': speaker_number,
        'model_archi': model_opts,
        'loss': model_opts["loss"]["type"]
    }, training_monitor.is_best, filename=training_opts["tmp_model_name"], best_filename=training_opts["best_model_name"])
Anthony Larcher's avatar
Anthony Larcher committed
1262
1263


Anthony Larcher's avatar
Anthony Larcher committed
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
class AAMScheduler():
    """
    Scheduler that progressively increases the margin of an AAM (additive
    angular margin) loss during training. For now only the margin is updated.
    """
    def __init__(self, original_margin, final_margin, final_steps_nb, update_frequency, mode='lin', Tau=1, verbose=True):
        """
        :param original_margin: margin used at the beginning of the training
        :param final_margin: margin targeted at the end of the schedule
        :param final_steps_nb: number of steps over which the margin grows
            (in 'lin' mode the final margin is reached exactly at this step)
        :param update_frequency: number of steps between two margin updates
        :param mode: 'lin' for linear growth, any other value for an
            exponential saturation towards final_margin
        :param Tau: time constant, currently stored but not used in the
            update formulas  # NOTE(review): confirm intended use
        :param verbose: verbosity flag, stored for callers
        """
        self.current_margin = original_margin
        self.original_margin = original_margin
        self.final_margin = final_margin
        self.final_steps_nb = final_steps_nb
        self.update_frequency = update_frequency
        self.mode = mode
        self.Tau = Tau
        self.verbose = verbose
        self._counter = 0

    def __step__(self):
        """
        Advance the internal step counter and return the current margin.

        The margin is only recomputed every ``update_frequency`` steps;
        in between, the last computed value is returned unchanged.

        :return: the current margin value
        """
        self._counter += 1

        if self._counter % self.update_frequency == 0:
            # update the parameters
            if self.mode == "lin":
                self.current_margin = self.original_margin + \
                                      (self.final_margin - self.original_margin) * \
                                      (self._counter / self.final_steps_nb)
            else:
                # exponential saturation with time constant final_steps_nb / 7
                self.current_margin = self.original_margin + \
                                      (self.final_margin - self.original_margin) * \
                                      (1 - numpy.exp(-self._counter / (self.final_steps_nb / 7)))

        return self.current_margin


Anthony Larcher's avatar
Anthony Larcher committed
1303
1304
1305
1306
def xtrain(dataset_description,
           model_description,
           training_description,
           **kwargs):
    """
    Train an x-vector extractor: build the model, the data loaders, the
    optimizer and the scheduler, then run the train / cross-validation loop
    with early stopping based on patience, saving checkpoints along the way.

    :param dataset_description: dataset configuration (YAML file name or dict)
    :param model_description: model configuration (YAML file name or dict)
    :param training_description: training configuration (YAML file name or dict)
    :param kwargs: overrides merged into the three dictionaries

    :return: the best validation EER reached during training

    REFACTORING
    - when restarting from an existing model, reload the optimizer and the scheduler
    """
    # Rank of the current process in distributed training; -1 means
    # single-process (all "local_rank < 1" checks select the master process)
    local_rank = -1
    if "RANK" in os.environ:
        local_rank = int(os.environ['RANK'])

    # Test to optimize
    torch.autograd.profiler.emit_nvtx(enabled=False)

    dataset_opts, model_opts, training_opts = update_training_dictionary(dataset_description,
                                                                         model_description,
                                                                         training_description,
                                                                         kwargs)

    # Initialize the training monitor
    monitor = TrainingMonitor(output_file=training_opts["log_file"],
                              patience=training_opts["patience"],
                              best_accuracy=0.0,
                              best_eer_epoch=1,
                              best_eer=100,
                              compute_test_eer=training_opts["compute_test_eer"])

    # Make PyTorch Deterministic
    torch.backends.cudnn.deterministic = False
    if training_opts["deterministic"]:
        torch.backends.cudnn.deterministic = True

    # Set all the seeds
    numpy.random.seed(training_opts["numpy_seed"]) # Set the random seed of numpy for the data split.
    torch.manual_seed(training_opts["torch_seed"])
    torch.cuda.manual_seed(training_opts["torch_seed"])

    # Display the entire configurations as YAML dictionaries
    if local_rank < 1:
        monitor.logger.info("\n*********************************\nDataset options\n*********************************\n")
        monitor.logger.info(yaml.dump(dataset_opts, default_flow_style=False))
        monitor.logger.info("\n*********************************\nModel options\n*********************************\n")
        monitor.logger.info(yaml.dump(model_opts, default_flow_style=False))
        monitor.logger.info("\n*********************************\nTraining options\n*********************************\n")
        monitor.logger.info(yaml.dump(training_opts, default_flow_style=False))

    # Initialize the model
    model = get_network(model_opts, local_rank)
    embedding_size = model.embedding_size
    aam_scheduler = None
    # NOTE: AAM margin scheduling is currently disabled
    #if model.loss == "aam":
    #    aam_scheduler = AAMScheduler(model_opts["loss"]["aam_margin"],
    #                                 final_margin=0.5,
    #                                 final_steps_nb=120000,
    #                                 update_frequency=25000,
    #                                 mode='exp',
    #                                 Tau=1,
    #                                 verbose=True)

    # Set the device and manage parallel processing
    torch.cuda.set_device(local_rank)
    device = torch.device(local_rank)

    if training_opts["multi_gpu"]:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    model.to(device)

    """ [HOW TO] from https://gist.github.com/sgraaf/5b0caa3a320f28c27c12b5efeb35aa4c
        - Add the following line right after "if __name__ == '__main__':" in your main script :
        parser.add_argument('--local_rank', type=int, default=-1, metavar='N', help='Local process rank.')
        - Then, in your shell :
        export NUM_NODES=1
        export NUM_GPUS_PER_NODE=2
        export NODE_RANK=0
        export WORLD_SIZE=$(($NUM_NODES * $NUM_GPUS_PER_NODE))
        python -m torch.distributed.launch \
            --nproc_per_node=$NUM_GPUS_PER_NODE \
            --nnodes=$NUM_NODES \
            --node_rank $NODE_RANK \
            train_xvector.py ...
    """
    if training_opts["multi_gpu"]:
        if local_rank < 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[local_rank],
                                                          output_device=local_rank)
    else:
        print("Train on a single GPU")

    # Initialise data loaders.
    # Forward the process rank so that each distributed process seeds its
    # sampler with its own rank (previously the default rank 0 was used by
    # every process); max(local_rank, 0) keeps the single-process behavior.
    training_loader, validation_loader,\
    sampler, validation_tar_indices, validation_non_indices = get_loaders(dataset_opts,
                                                                          training_opts,
                                                                          model_opts,
                                                                          local_rank=max(local_rank, 0))

    if local_rank < 1:
        monitor.logger.info(f"Start training process")
        monitor.logger.info(f"Use \t{torch.cuda.device_count()} \tgpus")
        monitor.logger.info(f"Use \t{training_opts['num_cpu']} \tcpus")

        monitor.logger.info(f"Validation EER will be measured using")
        monitor.logger.info(f"\t {numpy.sum(validation_tar_indices)} target trials and")
        monitor.logger.info(f"\t {numpy.sum(validation_non_indices)} non-target trials")

    # Create optimizer and scheduler
    optimizer, scheduler = get_optimizer(model, model_opts, training_opts, training_loader)

    scaler = None
    if training_opts["mixed_precision"]:
        scaler = torch.cuda.amp.GradScaler()

    for epoch in range(1, training_opts["epochs"] + 1):

        monitor.update(epoch=epoch)

        # Process one epoch and return the current model
        if monitor.current_patience == 0:
            print(f"Stopping at epoch {epoch} for cause of patience")
            break

        sampler.set_epoch(epoch)
        if training_opts["multi_gpu"]:
            torch.distributed.barrier()

        model = train_epoch(model,
                            training_opts,
                            monitor,
                            training_loader,
                            optimizer,
                            scheduler,
                            device,
                            scaler=scaler)
        #                    aam_scheduler=aam_scheduler)

        # Cross validation
        if math.fmod(epoch, training_opts["validation_frequency"]) == 0:
            val_acc, val_loss, val_eer = cross_validation(model,
                                                          validation_loader,
                                                          device,
                                                          [validation_loader.dataset.__len__(), embedding_size],
                                                          validation_tar_indices,
                                                          validation_non_indices,
                                                          training_opts["mixed_precision"])

            test_eer = None
            if training_opts["compute_test_eer"] and local_rank < 1:
                test_eer = test_metrics(model, device, model_opts, dataset_opts, training_opts)

            monitor.update(test_eer=test_eer,
                           val_eer=val_eer,
                           val_loss=val_loss,
                           val_acc=val_acc)

            if local_rank < 1:
                monitor.display()

            # Save the current model and if needed update the best one
            # TODO add an option to keep the models at given epochs (e.g. before a LR change)
            if local_rank < 1:
                save_model(model, monitor, model_opts, training_opts, optimizer, scheduler)

    for ii in range(torch.cuda.device_count()):
        monitor.logger.info(torch.cuda.memory_summary(ii))

    # TODO manage the display through the training_monitor
    if local_rank < 1:
        monitor.display_final()

    return monitor.best_eer

Anthony Larcher's avatar
Anthony Larcher committed
1479

Anthony Larcher's avatar
Anthony Larcher committed
1480
1481
1482
1483
1484
1485
1486
1487
def train_epoch(model,
                training_opts,
                training_monitor,
                training_loader,
                optimizer,
                scheduler,
                device,
                scaler=None,
                clipping=False,
                aam_scheduler=None):
    """
    Train an x-vector extractor for one epoch.

    :param model: the Xtractor to train (possibly wrapped in DataParallel/DDP)
    :param training_opts: dict of training options; "log_interval" is used here
    :param training_monitor: monitor used for logging and best-EER tracking
    :param training_loader: DataLoader yielding (data, target) batches
    :param optimizer: torch optimizer
    :param scheduler: learning-rate scheduler, stepped once per batch
    :param device: torch device on which to run
    :param scaler: optional torch.cuda.amp.GradScaler enabling mixed precision
    :param clipping: if True, clip the gradient norm to 1.0 before stepping
    :param aam_scheduler: optional AAM-margin scheduler, stepped once per batch

    :return: the trained model
    """
    model.train()
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')

    # The loss type lives on the model itself; unwrap DataParallel/DDP if needed.
    if isinstance(model, Xtractor):
        loss_criteria = model.loss
    else:
        loss_criteria = model.module.loss

    accuracy = 0.0
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(training_loader):
        data = data.squeeze().to(device)
        target = target.squeeze()
        target = target.to(device)
        optimizer.zero_grad(set_to_none=True)

        if scaler is not None:
            # Mixed-precision forward/backward path.
            with torch.cuda.amp.autocast():
                if loss_criteria == 'aam':
                    output, _ = model(data, target=target)
                    loss = criterion(output, target)
                elif loss_criteria == 'smn':
                    output_tuple, _ = model(data, target=target)
                    loss, output = output_tuple
                    loss += criterion(output, target)
                elif loss_criteria == 'aps':
                    output_tuple, _ = model(data, target=target)
                    # BUGFIX: the model returns (cosine-similarity matrix, output)
                    # for 'aps' (see the non-AMP branch below); the previous code
                    # unpacked the similarity matrix as "loss" and would have
                    # tried to backward a non-scalar tensor.
                    cos_sim_matx, output = output_tuple
                    loss = criterion(cos_sim_matx,
                                     torch.arange(0, int(data.shape[0] / 2), device=device)) \
                           + criterion(output, target)
                else:
                    output, _ = model(data, target=None)
                    loss = criterion(output, target)

            scaler.scale(loss).backward()
            if clipping:
                # Gradients must be unscaled before clipping their norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
            scaler.step(optimizer)
            scaler.update()

        else:
            if loss_criteria == 'aam':
                output, _ = model(data, target=target)
                loss = criterion(output, target)
            elif loss_criteria == 'smn':
                # Consistency fix: mirror the mixed-precision branch so 'smn'
                # no longer silently falls through to the plain CE path.
                output_tuple, _ = model(data, target=target)
                loss, output = output_tuple
                loss += criterion(output, target)
            elif loss_criteria == 'aps':
                output_tuple, _ = model(data, target=target)
                cos_sim_matx, output = output_tuple
                loss = criterion(cos_sim_matx,
                                 torch.arange(0, int(data.shape[0] / 2), device=device)) \
                       + criterion(output, target)
            else:
                output, _ = model(data, target=None)
                loss = criterion(output, target)

            loss.backward()
            optimizer.step()

        running_loss += loss.item()
        accuracy += (torch.argmax(output.data, 1) == target).sum()

        if batch_idx % training_opts["log_interval"] == 0:
            batch_size = target.shape[0]
            training_monitor.update(training_loss=loss.item(),
                                    training_acc=100.0 * accuracy.item() / ((batch_idx + 1) * batch_size))

            training_monitor.logger.info('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                training_monitor.current_epoch,
                batch_idx + 1,
                len(training_loader),
                100. * batch_idx / len(training_loader),
                loss.item(),
                100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))

        # NOTE(review): running_loss is reset every batch, which makes the
        # accumulation above a per-batch value only; kept as-is for
        # backward compatibility.
        running_loss = 0.0

        # ReduceLROnPlateau needs a metric; other schedulers step per batch.
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(training_monitor.best_eer)
        else:
            scheduler.step()

        if aam_scheduler is not None:
            model.after_speaker_embedding.margin = aam_scheduler.step()

    return model

Anthony Larcher's avatar
merge    
Anthony Larcher committed
1585
def cross_validation(model, validation_loader, device, validation_shape, tar_indices, non_indices, mixed_precision=False):
    """
    Run one pass over the validation set.

    Computes the classification accuracy and average loss, extracts one
    embedding per validation segment, scores all embeddings against each
    other (dot products) and derives the speaker-verification equal error
    rate from the target/non-target trial masks.

    :param model: the Xtractor to evaluate (possibly wrapped in DataParallel/DDP)
    :param validation_loader: DataLoader yielding (data, target) batches
    :param device: torch device used for the forward passes
    :param validation_shape: shape (num_segments, embedding_dim) of the embedding matrix
    :param tar_indices: mask selecting target trials in the score matrix
    :param non_indices: mask selecting non-target trials in the score matrix
    :param mixed_precision: run the forward pass under autocast when True

    :return: tuple (accuracy in percent, average loss, equal error rate)
    """
    model.eval()

    # The loss type lives on the model itself; unwrap DataParallel/DDP if needed.
    loss_criteria = model.loss if isinstance(model, Xtractor) else model.module.loss

    criterion = torch.nn.CrossEntropyLoss()
    accuracy = 0.0
    loss = 0.0
    cursor = 0
    embeddings = torch.zeros(validation_shape)

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1, disable=None)):
            if target.dim() != 1:
                target = target.squeeze()
            target = target.to(device)
            batch_size = target.shape[0]
            data = data.squeeze().to(device)

            with torch.cuda.amp.autocast(enabled=mixed_precision):
                output, batch_embeddings = model(data, target=None, is_eval=True)
                if loss_criteria == 'cce':
                    batch_embeddings = l2_norm(batch_embeddings)
                if loss_criteria == 'smn':
                    # 'smn' packs (embeddings, predictions) into the output.
                    batch_embeddings, batch_predictions = output
                else:
                    batch_predictions = output

                accuracy += (torch.argmax(batch_predictions.data, 1) == target).sum()
                loss += criterion(batch_predictions, target)

            embeddings[cursor:cursor + batch_size, :] = batch_embeddings.detach().cpu()
            cursor += batch_size

    # Score on CPU when the embedding matrix is large, to avoid device OOM.
    local_device = "cpu" if embeddings.shape[0] > 3e4 else device
    embeddings = embeddings.to(local_device)
    scores = torch.einsum('ij,kj', embeddings, embeddings).cpu().numpy()
    negatives = scores[non_indices]
    positives = scores[tar_indices]

    pmiss, pfa = rocch(positives, negatives)
    equal_error_rate = rocch2eer(pmiss, pfa)

    # NOTE(review): the loss denominator uses the size of the LAST batch,
    # which slightly biases the average when the final batch is smaller.
    return (100. * accuracy.cpu().numpy() / validation_shape[0],
            loss.cpu().numpy() / ((batch_idx + 1) * batch_size),
            equal_error_rate)
1637
1638


Anthony Larcher's avatar
Anthony Larcher committed
1639
1640
1641
1642
1643
def extract_embeddings(idmap_name,
                       model_filename,
                       data_root_name,
                       device,
                       file_extension="wav",
                       transform_pipeline=None,
                       sliding_window=False,
                       win_duration=3.,
                       win_shift=1.5,
                       num_thread=1,
                       sample_rate=16000,
                       mixed_precision=False,
                       norm_embeddings=True):
    """
    Extract speaker embeddings (x-vectors) for every segment of an IdMap.

    :param idmap_name: an IdMap object, or the filename of an IdMap to load
    :param model_filename: checkpoint filename to load, or an already-built model
    :param data_root_name: root directory of the audio files
    :param device: torch device used for extraction
    :param file_extension: extension of the audio files (default "wav")
    :param transform_pipeline: dict describing the transformation pipeline
        applied when loading the audio; None (default) means no transformation
    :param sliding_window: if True, extract one embedding per sliding window
    :param win_duration: window length in seconds
    :param win_shift: window shift in seconds (used when sliding_window is True)
    :param num_thread: number of DataLoader workers
    :param sample_rate: audio sample rate expected by the model
    :param mixed_precision: run the forward pass under autocast when True
    :param norm_embeddings: ask the model to L2-normalize the embeddings

    :return: a StatServer whose stat1 holds one embedding per row
    """
    # BUGFIX: the default used to be a mutable dict ({}) shared across calls;
    # use a None sentinel instead (observable behavior unchanged).
    if transform_pipeline is None:
        transform_pipeline = {}

    # Load the model from a checkpoint when a filename is given,
    # otherwise use the model object passed in directly.
    if isinstance(model_filename, str):
        checkpoint = torch.load(model_filename, map_location=device)
        speaker_number = checkpoint["speaker_number"]
        model_opts = checkpoint["model_archi"]
        model = Xtractor(speaker_number, model_archi=model_opts["model_type"], loss=model_opts["loss"]["type"])
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model = model_filename

    idmap = idmap_name if isinstance(idmap_name, IdMap) else IdMap(idmap_name)

    # Create the dataset to load the data
    dataset = IdMapSet(idmap_name=idmap,
                       data_path=data_root_name,
                       file_extension=file_extension,
                       transform_pipeline=transform_pipeline,
                       transform_number=0,
                       sliding_window=sliding_window,
                       window_len=win_duration,
                       window_shift=win_shift,
                       sample_rate=sample_rate,
                       min_duration=win_duration
                       )

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            drop_last=False,
                            pin_memory=True,
                            num_workers=num_thread)

    with torch.no_grad():
        model.eval()
        model.to(device)

        embed = []
        modelset = []
        segset = []
        starts = []

        for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader,
                                                                      desc='xvector extraction',
                                                                      mininterval=1,
                                                                      disable=None)):
            if data.dim() > 2:
                data = data.squeeze()

            with torch.cuda.amp.autocast(enabled=mixed_precision):
                # Forward in chunks of roughly 100 windows to bound memory use.
                tmp_data = torch.split(data, data.shape[0] // (max(1, data.shape[0] // 100)))
                for td in tmp_data:
                    _, vec = model(x=td.to(device), is_eval=True, norm_embedding=norm_embeddings)
                    embed.append(vec.detach().cpu())

                # One (model, segment) entry per extracted window.
                modelset.extend(mod * data.shape[0])
                segset.extend(seg * data.shape[0])
                if sliding_window:
                    starts.extend(numpy.arange(start, start + data.shape[0] * win_shift, win_shift))
                else:
                    starts.append(start)

        # Pack all embeddings and their metadata into a StatServer.
        embeddings = StatServer()
        embeddings.stat1 = numpy.concatenate(embed)
        embeddings.modelset = numpy.array(modelset).astype('>U')
        embeddings.segset = numpy.array(segset).astype('>U')
        embeddings.start = numpy.array(starts)
        embeddings.stop = numpy.array(starts) + win_duration
        embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))

    return embeddings


Anthony Larcher's avatar
Anthony Larcher committed
1745
1746
1747
1748
1749
def extract_embeddings_per_speaker(idmap_name,
                                   model_filename,
                                   data_root_name,
                                   device,
                                   file_extension="wav",
                                   transform_pipeline=None,
                                   sample_rate=16000,
                                   mixed_precision=False,
                                   num_thread=1,
                                   extract_after_pooling=False):
    """
    Extract one embedding per speaker: all segments of a speaker are
    concatenated by the dataset and a single x-vector is computed for
    the resulting signal.

    :param idmap_name: an IdMap object, or the filename of an IdMap to load
    :param model_filename: checkpoint filename to load, or an already-built model
    :param data_root_name: root directory of the audio files
    :param device: torch device used for extraction
    :param file_extension: extension of the audio files (default "wav")
    :param transform_pipeline: dict describing the transformation pipeline
        applied when loading the audio; None (default) means no transformation
    :param sample_rate: audio sample rate expected by the model
    :param mixed_precision: kept for interface compatibility; currently unused here
    :param num_thread: number of DataLoader workers
    :param extract_after_pooling: if True, size the output for embeddings taken
        right after the pooling layer instead of the model's embedding layer

    :return: a StatServer with one embedding per speaker in stat1
    """
    # BUGFIX: the default used to be a mutable dict ({}) shared across calls;
    # use a None sentinel instead (observable behavior unchanged).
    if transform_pipeline is None:
        transform_pipeline = {}

    # Load the model from a checkpoint when a filename is given,
    # otherwise use the model object passed in directly.
    if isinstance(model_filename, str):
        checkpoint = torch.load(model_filename, map_location=device)
        speaker_number = checkpoint["speaker_number"]
        model_opts = checkpoint["model_archi"]
        model = Xtractor(speaker_number, model_archi=model_opts["model_type"], loss=model_opts["loss"]["type"])
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model = model_filename

    idmap = idmap_name if isinstance(idmap_name, IdMap) else IdMap(idmap_name)

    # Create the dataset that concatenates all data of each speaker
    dataset = IdMapSetPerSpeaker(idmap_name=idmap,
                                 data_root_path=data_root_name,
                                 file_extension=file_extension,
                                 transform_pipeline=transform_pipeline,
                                 frame_rate=sample_rate,
                                 min_duration=1.)

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            drop_last=False,
                            pin_memory=True,
                            num_workers=num_thread)

    with torch.no_grad():
        model.eval()
        model.to(device)

        # Get the size of embeddings to extract
        name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0] + '.weight'
        # BUGFIX: extract_after_pooling was referenced but never defined
        # (NameError at runtime); it is now an explicit keyword parameter.
        if extract_after_pooling:
            emb_size = model.before_speaker_embedding.state_dict()[name].shape[1]
        else:
            emb_size = model.embedding_size

        # Create the StatServer
        embeddings = StatServer()
        embeddings.modelset = dataset.output_im.leftids
        embeddings.segset = dataset.output_im.rightids
        embeddings.start = dataset.output_im.start
        embeddings.stop = dataset.output_im.stop
        embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
        embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))

        # Process the data (the redundant inner no_grad block was removed)
        for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader, desc='xvector extraction', mininterval=1)):
            # Cap extremely long concatenated signals to bound memory use.
            if data.shape[1] > 20000000:
                data = data[..., :20000000]
            # Consistency fix: in eval mode the model returns the tuple
            # (output, embedding) — see extract_embeddings and
            # cross_validation; keep only the embedding.
            _, vec = model(data.to(device), is_eval=True)
            embeddings.stat1[idx, :] = vec.detach().cpu()

    return embeddings