Commit dcfd5acc authored by Anthony Larcher

debug

parent 0df15d91
# -*- coding: utf-8 -*-
#
# This file is part of SIDEKIT.
#
@@ -1492,22 +1492,22 @@ def xtrain(dataset_description,
                                                      validation_non_indices,
                                                      training_opts["mixed_precision"])
-        #test_eer = None
-        #if training_opts["compute_test_eer"] and local_rank < 1:
-        #    test_eer = test_metrics(model, device, model_opts, dataset_opts, training_opts)
+        test_eer = None
+        if training_opts["compute_test_eer"] and local_rank < 1:
+            test_eer = test_metrics(model, device, model_opts, dataset_opts, training_opts)
-        monitor.update(test_eer=test_eer,
-                       val_eer=val_eer,
-                       val_loss=val_loss,
-                       val_acc=val_acc)
+        #monitor.update(test_eer=test_eer,
+        #               val_eer=val_eer,
+        #               val_loss=val_loss,
+        #               val_acc=val_acc)
         #if local_rank < 1:
         #    monitor.display()
         # Save the current model and if needed update the best one
         # TODO: add an option that keeps the models at certain epochs (for example before the LR change)
-        if local_rank < 1:
-            save_model(model, monitor, model_opts, training_opts, optimizer, scheduler, epoch)
+        #if local_rank < 1:
+        #    save_model(model, monitor, model_opts, training_opts, optimizer, scheduler, epoch)
         for ii in range(torch.cuda.device_count()):
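This hunk uncomments the per-epoch test EER computation and disables the monitor update and checkpoint saving, presumably for debugging. For orientation, below is a minimal, self-contained sketch of the epoch-end bookkeeping pattern that `monitor.update(...)` and `save_model(...)` implement here; the names `ToyMonitor` and `save_checkpoint` are hypothetical and this is not SIDEKIT's actual code.

```python
import torch

# Hypothetical sketch of epoch-end bookkeeping: record validation metrics,
# then save a checkpoint and keep a copy of the best model so far.
class ToyMonitor:
    def __init__(self):
        self.history = []
        self.best_eer = float("inf")

    def update(self, test_eer=None, val_eer=None, val_loss=None, val_acc=None):
        self.history.append(dict(test_eer=test_eer, val_eer=val_eer,
                                 val_loss=val_loss, val_acc=val_acc))
        if val_eer is not None and val_eer < self.best_eer:
            self.best_eer = val_eer
            return True   # this epoch is the best seen so far
        return False

def save_checkpoint(model, optimizer, epoch, is_best, path="checkpoint.pt"):
    state = {"epoch": epoch,
             "model_state_dict": model.state_dict(),
             "optimizer_state_dict": optimizer.state_dict()}
    torch.save(state, path)
    if is_best:
        torch.save(state, "best_" + path)

# Toy usage with a dummy model
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
monitor = ToyMonitor()
is_best = monitor.update(test_eer=None, val_eer=0.05, val_loss=1.2, val_acc=81.3)
save_checkpoint(model, optimizer, epoch=0, is_best=is_best)
```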
@@ -1681,6 +1681,7 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
     return (100. * accuracy.cpu().numpy() / validation_shape[0],
             loss.cpu().numpy() / ((batch_idx + 1) * batch_size),
             equal_error_rate)
+    return 0, 0, 0
 def extract_embeddings(idmap_name,
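Both `test_eer` and the `equal_error_rate` returned by `cross_validation` are equal error rates computed from target and non-target trial scores. As a reference, here is a minimal sketch of a generic EER estimate: a simple threshold sweep over random scores with hypothetical inputs, not the metric code used by `test_metrics` or `cross_validation` in SIDEKIT.

```python
import numpy

def equal_error_rate(tar_scores, non_scores):
    """Generic EER estimate: sweep thresholds over all observed scores and
    return the operating point where false rejection and false acceptance
    rates are closest."""
    thresholds = numpy.sort(numpy.concatenate([tar_scores, non_scores]))
    eer, best_gap = 1.0, numpy.inf
    for th in thresholds:
        frr = numpy.mean(tar_scores < th)   # targets rejected
        far = numpy.mean(non_scores >= th)  # non-targets accepted
        gap = abs(frr - far)
        if gap < best_gap:
            best_gap, eer = gap, (frr + far) / 2
    return eer

# Hypothetical usage with synthetic scores
rng = numpy.random.default_rng(0)
tar = rng.normal(1.0, 1.0, 1000)     # target trial scores
non = rng.normal(-1.0, 1.0, 10000)   # non-target trial scores
print(equal_error_rate(tar, non))
```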