Commit dcfd5acc authored by Anthony Larcher's avatar Anthony Larcher
Browse files

debug

parent 0df15d91
# coding: utf-8 -*-
#coding: utf-8 -*-
#
# This file is part of SIDEKIT.
#
......@@ -1492,22 +1492,22 @@ def xtrain(dataset_description,
validation_non_indices,
training_opts["mixed_precision"])
#test_eer = None
#if training_opts["compute_test_eer"] and local_rank < 1:
# test_eer = test_metrics(model, device, model_opts, dataset_opts, training_opts)
test_eer = None
if training_opts["compute_test_eer"] and local_rank < 1:
test_eer = test_metrics(model, device, model_opts, dataset_opts, training_opts)
monitor.update(test_eer=test_eer,
val_eer=val_eer,
val_loss=val_loss,
val_acc=val_acc)
#monitor.update(test_eer=test_eer,
# val_eer=val_eer,
# val_loss=val_loss,
# val_acc=val_acc)
#if local_rank < 1:
# monitor.display()
# Save the current model and if needed update the best one
# TODO ajouter une option qui garde les modèles à certaines époques (par exemple avant le changement de LR
if local_rank < 1:
save_model(model, monitor, model_opts, training_opts, optimizer, scheduler, epoch)
#if local_rank < 1:
# save_model(model, monitor, model_opts, training_opts, optimizer, scheduler, epoch)
for ii in range(torch.cuda.device_count()):
......@@ -1681,6 +1681,7 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
return (100. * accuracy.cpu().numpy() / validation_shape[0],
loss.cpu().numpy() / ((batch_idx + 1) * batch_size),
equal_error_rate)
return 0, 0, 0
def extract_embeddings(idmap_name,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment