Commit 3a559ff9 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

refactoring

parent 39e464a5
...@@ -268,31 +268,24 @@ def new_test_metrics(model, ...@@ -268,31 +268,24 @@ def new_test_metrics(model,
Returns: Returns:
[type]: [description] [type]: [description]
""" """
# TODO modifier les parametres pour utiliser le dataset_description a la place de :
#idmap_test_filename,
#ndx_test_filename,
#key_test_filename,
#data_root_name,
transform_pipeline = dict() transform_pipeline = dict()
xv_stat = extract_embeddings(idmap_name=data_opts["idmap_test_filename"], xv_stat = extract_embeddings(idmap_name=data_opts["test"]["idmap"],
model_filename=model, model_filename=model,
data_root_name=data_opts["data_root_name"], data_root_name=data_opts["test"]["data_path"],
device=device, device=device,
loss=model.loss, loss=model.loss,
transform_pipeline=transform_pipeline, transform_pipeline=transform_pipeline,
num_thread=train_opts["num_thread"], num_thread=train_opts["num_cpu"],
mixed_precision=train_opts["mixed_precision"]) mixed_precision=train_opts["mixed_precision"])
tar, non = cosine_scoring(xv_stat, tar, non = cosine_scoring(xv_stat,
xv_stat, xv_stat,
Ndx(data_opts["ndx_test_filename"]), Ndx(data_opts["test"]["ndx"]),
wccn=None, wccn=None,
check_missing=True, check_missing=True,
device=device device=device
).get_tar_non(Key(data_opts["key_test_filename"])) ).get_tar_non(Key(data_opts["test"]["key"]))
#test_eer = eer(numpy.array(non).astype(numpy.double), numpy.array(tar).astype(numpy.double)) #test_eer = eer(numpy.array(non).astype(numpy.double), numpy.array(tar).astype(numpy.double))
pmiss, pfa = rocch(tar, non) pmiss, pfa = rocch(tar, non)
...@@ -314,9 +307,10 @@ def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename ...@@ -314,9 +307,10 @@ def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename
shutil.copyfile(filename, best_filename) shutil.copyfile(filename, best_filename)
class TrainingMonitor(): class TrainingMonitor():
"""
"""
def __init__(self, def __init__(self,
output_file, output_file,
log_interval=10, log_interval=10,
...@@ -326,7 +320,7 @@ class TrainingMonitor(): ...@@ -326,7 +320,7 @@ class TrainingMonitor():
best_eer=100, best_eer=100,
compute_test_eer=False compute_test_eer=False
): ):
# Stocker plutot des listes pour conserver l'historique complet
self.current_epoch = 0 self.current_epoch = 0
self.log_interval = log_interval self.log_interval = log_interval
self.init_patience = patience self.init_patience = patience
...@@ -347,13 +341,9 @@ class TrainingMonitor(): ...@@ -347,13 +341,9 @@ class TrainingMonitor():
self.is_best = True self.is_best = True
# Initialize the logger # Initialize the logger
logging.basicConfig(level=logging.DEBUG, logging_format = '%(asctime)-15s-8s %(message)s'
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s', logging.basicConfig(level=logging.DEBUG, format=logging_format, datefmt='%m-%d %H:%M')
datefmt='%m-%d %H:%M', self.logger = logging.getLogger('Monitoring')
filename=output_file,
filemode='w')
self.logger = logging.getLogger('monitoring')
self.logger.setLevel(logging.INFO) self.logger.setLevel(logging.INFO)
# create file handler which logs even debug messages # create file handler which logs even debug messages
fh = logging.FileHandler(output_file) fh = logging.FileHandler(output_file)
...@@ -1002,9 +992,9 @@ def update_training_dictionary(dataset_description, ...@@ -1002,9 +992,9 @@ def update_training_dictionary(dataset_description,
def get_network(model_opts): def get_network(model_opts):
""" """
:param model_opts: :param model_opts:
:return: :return:
""" """
if model_opts["model_type"] in ["xvector", "rawnet2", "resnet34", "fastresnet34"]: if model_opts["model_type"] in ["xvector", "rawnet2", "resnet34", "fastresnet34"]:
...@@ -2008,8 +1998,6 @@ def extract_embeddings(idmap_name, ...@@ -2008,8 +1998,6 @@ def extract_embeddings(idmap_name,
else: else:
min_duration = (model.module.context_size() - 1) * frame_shift + frame_duration min_duration = (model.module.context_size() - 1) * frame_shift + frame_duration
model_cs = model.module.context_size() model_cs = model.module.context_size()
# Create dataset to load the data # Create dataset to load the data
dataset = IdMapSet(idmap_name=idmap_name, dataset = IdMapSet(idmap_name=idmap_name,
...@@ -2042,7 +2030,7 @@ def extract_embeddings(idmap_name, ...@@ -2042,7 +2030,7 @@ def extract_embeddings(idmap_name,
if name != 'bias': if name != 'bias':
name = name + '.weight' name = name + '.weight'
emb_size = model.module.before_speaker_embedding.state_dict()[name].shape[0] emb_size = model.module.before_speaker_embedding.state_dict()[name].shape[0]
# Create the StatServer # Create the StatServer
embeddings = StatServer() embeddings = StatServer()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment