Commit 3a559ff9 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

refactoring

parent 39e464a5
......@@ -268,31 +268,24 @@ def new_test_metrics(model,
Returns:
[type]: [description]
"""
# TODO modifier les parametres pour utiliser le dataset_description a la place de :
#idmap_test_filename,
#ndx_test_filename,
#key_test_filename,
#data_root_name,
transform_pipeline = dict()
xv_stat = extract_embeddings(idmap_name=data_opts["idmap_test_filename"],
xv_stat = extract_embeddings(idmap_name=data_opts["test"]["idmap"],
model_filename=model,
data_root_name=data_opts["data_root_name"],
data_root_name=data_opts["test"]["data_path"],
device=device,
loss=model.loss,
transform_pipeline=transform_pipeline,
num_thread=train_opts["num_thread"],
num_thread=train_opts["num_cpu"],
mixed_precision=train_opts["mixed_precision"])
tar, non = cosine_scoring(xv_stat,
xv_stat,
Ndx(data_opts["ndx_test_filename"]),
Ndx(data_opts["test"]["ndx"]),
wccn=None,
check_missing=True,
device=device
).get_tar_non(Key(data_opts["key_test_filename"]))
).get_tar_non(Key(data_opts["test"]["key"]))
#test_eer = eer(numpy.array(non).astype(numpy.double), numpy.array(tar).astype(numpy.double))
pmiss, pfa = rocch(tar, non)
......@@ -314,9 +307,10 @@ def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename
shutil.copyfile(filename, best_filename)
class TrainingMonitor():
"""
"""
def __init__(self,
output_file,
log_interval=10,
......@@ -326,7 +320,7 @@ class TrainingMonitor():
best_eer=100,
compute_test_eer=False
):
# Stocker plutot des listes pour conserver l'historique complet
self.current_epoch = 0
self.log_interval = log_interval
self.init_patience = patience
......@@ -347,13 +341,9 @@ class TrainingMonitor():
self.is_best = True
# Initialize the logger
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
datefmt='%m-%d %H:%M',
filename=output_file,
filemode='w')
self.logger = logging.getLogger('monitoring')
logging_format = '%(asctime)-15s-8s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=logging_format, datefmt='%m-%d %H:%M')
self.logger = logging.getLogger('Monitoring')
self.logger.setLevel(logging.INFO)
# create file handler which logs even debug messages
fh = logging.FileHandler(output_file)
......@@ -2009,8 +1999,6 @@ def extract_embeddings(idmap_name,
min_duration = (model.module.context_size() - 1) * frame_shift + frame_duration
model_cs = model.module.context_size()
# Create dataset to load the data
dataset = IdMapSet(idmap_name=idmap_name,
data_path=data_root_name,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment