Commit ee92acf2 authored by Anthony Larcher

debug noise

parent 14d2f470
@@ -242,7 +242,7 @@ class SideSet(Dataset):
self.rir_df = None
if "add_reverb" in self.transform:
tmp_rir_df = pandas.read_csv(self.transformation["add_reverb"]["rir_db_csv"])
- tmp_rir_df = noise_df.loc[tmp_rir_df["type"] > "simulated_rirs"]
+ tmp_rir_df = tmp_rir_df.loc[tmp_rir_df["type"] == "simulated_rirs"]
# load the RIR database
self.rir_df = tmp_rir_df.set_index(tmp_rir_df.type)
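The corrected filter keeps only simulated room impulse responses by testing the "type" column for equality instead of applying an ordering comparison to the unrelated noise_df. A minimal pandas sketch of the same selection, with a made-up two-column layout standing in for the real rir_db_csv (whose schema is not shown in this diff):

    import pandas

    # Stand-in for pandas.read_csv(self.transformation["add_reverb"]["rir_db_csv"]);
    # column values below are illustrative, not real entries.
    tmp_rir_df = pandas.DataFrame({
        "type": ["simulated_rirs", "real_rirs", "simulated_rirs"],
        "file_id": ["smallroom/Room001", "rwcp/rir01", "mediumroom/Room042"],
    })

    # Boolean mask on the "type" column keeps only the simulated RIRs.
    tmp_rir_df = tmp_rir_df.loc[tmp_rir_df["type"] == "simulated_rirs"]

    # Mirrors self.rir_df = tmp_rir_df.set_index(tmp_rir_df.type)
    rir_df = tmp_rir_df.set_index(tmp_rir_df.type)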
@@ -250,6 +250,7 @@ def test_metrics(model,
def new_test_metrics(model,
device,
model_opts,
+ data_opts,
train_opts):
"""Compute model metrics
@@ -274,7 +275,7 @@ def new_test_metrics(model,
model_filename=model,
data_root_name=data_opts["test"]["data_path"],
device=device,
- loss=model.loss,
+ loss=model_opts["loss"]["type"],
transform_pipeline=transform_pipeline,
num_thread=train_opts["num_cpu"],
mixed_precision=train_opts["mixed_precision"])
@@ -343,13 +344,14 @@ class TrainingMonitor():
# Initialize the logger
logging_format = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=logging_format, datefmt='%m-%d %H:%M')
- logger = logging.getLogger('Monitoring')
- logger.setLevel(logging.DEBUG)
+ self.logger = logging.getLogger('Monitoring')
+ self.logger.setLevel(logging.DEBUG)
# create file handler which logs even debug messages
fh = logging.FileHandler(output_file)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
fh.setFormatter(formatter)
fh.setLevel(logging.DEBUG)
+ self.logger.addHandler(fh)
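Keeping the logger on the instance lets the display methods below reuse it, and the added addHandler call is what actually routes records into output_file. Outside of this class, the same standard-library pattern looks roughly like this (the function name and default file name are illustrative only):

    import logging

    def build_monitor_logger(output_file="monitoring.log"):
        # Console output comes from the root handler set up by basicConfig.
        logging.basicConfig(level=logging.DEBUG,
                            format='%(asctime)-15s %(message)s',
                            datefmt='%m-%d %H:%M')
        logger = logging.getLogger('Monitoring')
        logger.setLevel(logging.DEBUG)
        # File handler that also records DEBUG-level messages.
        fh = logging.FileHandler(output_file)
        fh.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)  # forgetting this call leaves the log file empty
        return logger

    logger = build_monitor_logger()
    logger.critical("***Test metrics - Test EER = 2.10 %")

Since getLogger('Monitoring') always returns the same object, creating several monitors would stack file handlers; guarding the addHandler call with a check on logger.handlers is a common way to avoid duplicated lines.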
def display(self):
"""
@@ -357,8 +359,8 @@
:return:
"""
# TODO
- self.logger.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
- self.logger.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Test EER = {test_eer * 100} %")
+ self.logger.critical(f"***Validation metrics - Cross validation accuracy = {self.val_acc[-1]} %, EER = {self.val_eer[-1] * 100} %")
+ self.logger.critical(f"***Test metrics - Test EER = {self.test_eer[-1] * 100} %")
def display_final(self):
"""
@@ -918,7 +920,7 @@ def update_training_dictionary(dataset_description,
dataset_opts["train"]["transformation"]["add_noise"]["noise_db_csv"] = ""
dataset_opts["train"]["transformation"]["add_noise"]["data_path"] = ""
dataset_opts["train"]["transformation"]["add_reverb"] = dict()
- dataset_opts["train"]["transformation"]["add_reverb"]["noise_db_csv"] = ""
+ dataset_opts["train"]["transformation"]["add_reverb"]["rir_db_csv"] = ""
dataset_opts["train"]["transformation"]["add_reverb"]["data_path"] = ""
dataset_opts["valid"] = dict()
@@ -1180,10 +1182,14 @@ def get_optimizer(model, model_opts, train_opts):
elif train_opts["scheduler"]["type"] == "StepLR2":
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
- step_size=1e5,
+ step_size=2000,
gamma=0.5)
else:
- scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer)
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
+ mode='min',
+ factor=0.5,
+ patience=3000,
+ verbose=True)
return optimizer, scheduler
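For the two branches touched here: StepLR now halves the learning rate every 2000 calls to step(), and the fallback ReduceLROnPlateau halves it once the monitored value has stopped improving for 3000 consecutive step(metric) calls (patience is counted in scheduler steps, not epochs). A standalone sketch with a placeholder model and optimizer; the diff's verbose=True merely prints a message when the rate drops and is deprecated in recent PyTorch releases, so it is left out below:

    import torch

    model = torch.nn.Linear(10, 2)                 # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # "StepLR2" branch: multiply the LR by 0.5 every 2000 scheduler steps.
    step_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                                     step_size=2000,
                                                     gamma=0.5)

    # Default branch: halve the LR when the monitored metric (EER, lower is
    # better, hence mode='min') has not improved for 3000 evaluations.
    plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                                   mode='min',
                                                                   factor=0.5,
                                                                   patience=3000)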
@@ -1209,8 +1215,8 @@ def save_model(model, training_monitor, model_opts, training_opts, optimizer, sc
'scheduler': scheduler,
'speaker_number' : model.speaker_number,
'model_archi': model_opts,
- 'loss': training_monitor.loss
- }, training_monitor.is_best, filename=training_opts.tmp_model_name+".pt", best_filename=training_opts.best_model_name+'.pt')
+ 'loss': model_opts["loss"]["type"]
+ }, training_monitor.is_best, filename=training_opts["tmp_model_name"], best_filename=training_opts["best_model_name"])
else:
save_checkpoint({
'epoch': training_monitor.current_epoch,
@@ -1220,8 +1226,8 @@ def save_model(model, training_monitor, model_opts, training_opts, optimizer, sc
'scheduler': scheduler,
'speaker_number': model.module.speaker_number,
'model_archi': model_opts,
- 'loss': training_monitor.loss
- }, training_monitor.is_best, filename=training_opts.tmp_model_name+".pt", best_filename=training_opts.best_model_name+'.pt')
+ 'loss': model_opts["loss"]["type"]
+ }, training_monitor.is_best, filename=training_opts["tmp_model_name"], best_filename=training_opts["best_model_name"])
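save_checkpoint itself is not part of this diff; the usual shape of such a helper, consistent with the filename/best_filename/is_best arguments used above, is roughly the following (the shutil copy is an assumption about its body, not the project's actual code):

    import shutil
    import torch

    def save_checkpoint(state, is_best, filename, best_filename):
        # Persist the full training state (epoch, weights, optimizer, scheduler, ...).
        torch.save(state, filename)
        # Keep a separate copy whenever this epoch is the best one so far.
        if is_best:
            shutil.copyfile(filename, best_filename)

Dropping the +".pt" concatenation fits the switch from attribute access to training_opts["tmp_model_name"]: training_opts is handled as a plain dictionary whose values presumably already contain the complete file names.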
def new_xtrain(dataset_description,
@@ -1329,7 +1335,7 @@ def new_xtrain(dataset_description,
test_eer = None
if training_opts["compute_test_eer"]:
- test_eer = new_test_metrics(model, device, dataset_opts, new_test_metrics)
+ test_eer = new_test_metrics(model, device, model_opts, dataset_opts, training_opts)
monitor.update(test_eer=test_eer,
val_eer=val_eer,
@@ -1896,7 +1902,10 @@ def new_train_epoch(model,
sys.exit()
running_loss = 0.0
- scheduler.step()
+ if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+     scheduler.step(training_monitor.best_eer)
+ else:
+     scheduler.step()
return model
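The isinstance guard is needed because ReduceLROnPlateau.step() expects the monitored value, while step-based schedulers take no argument. In loop form, with toy objects standing in for the real model and monitor:

    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                           factor=0.5, patience=2)

    best_eer = 0.05  # stand-in for training_monitor.best_eer
    for _ in range(5):
        optimizer.step()  # the real loop updates the weights here
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(best_eer)   # plateau scheduler consumes the metric
        else:
            scheduler.step()           # StepLR, CyclicLR, ... take no argument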
def cross_validation(model, validation_loader, device, validation_shape, tar_indices, non_indices, mixed_precision=False):