Commit 62f9e7c4 authored by Anthony Larcher

seqtrain

parent e6c25a87
@@ -275,13 +275,20 @@ def seqTrain(dataset_yaml,
     """
     with open(dataset_yaml, "r") as fh:
         dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
-        df = pandas.read_csv(dataset_params["dataset_description"])
-        training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
+        #df = pandas.read_csv(dataset_params["dataset_description"])
+        #training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
 
     torch.manual_seed(dataset_params['seed'])
     training_set = SeqSet(dataset_yaml,
-                          set_type="train",
-                          dataset_df=training_df)
+                          wav_dir="data/wav/",
+                          mdtm_dir="data/mdtm",
+                          mode="vad",
+                          duration=2.,
+                          filter_type="gate",
+                          collar_duration=0.1,
+                          audio_framerate=16000,
+                          output_framerate=100,
+                          transform_pipeline="MFCC")
 
     training_loader = DataLoader(training_set,
                                  batch_size=dataset_params["batch_size"],
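Note on the hunk above: the CSV-based train/validation split (pandas plus train_test_split) is commented out, and the training SeqSet is now built from hard-coded VAD parameters. Since seqTrain already loads dataset_yaml into dataset_params, these values could instead come from the YAML file. A minimal sketch, assuming hypothetical keys that mirror the hard-coded arguments (they are not yet part of the dataset description):

```python
# A sketch only: read the SeqSet options from the YAML already loaded into
# dataset_params. The keys below (wav_dir, mdtm_dir, mode, ...) are
# hypothetical additions to the dataset description file; the commit's
# hard-coded values are kept as defaults.
training_set = SeqSet(dataset_yaml,
                      wav_dir=dataset_params.get("wav_dir", "data/wav/"),
                      mdtm_dir=dataset_params.get("mdtm_dir", "data/mdtm"),
                      mode=dataset_params.get("mode", "vad"),
                      duration=dataset_params.get("duration", 2.),
                      filter_type=dataset_params.get("filter_type", "gate"),
                      collar_duration=dataset_params.get("collar_duration", 0.1),
                      audio_framerate=dataset_params.get("audio_framerate", 16000),
                      output_framerate=dataset_params.get("output_framerate", 100),
                      transform_pipeline=dataset_params.get("transform_pipeline", "MFCC"))
```

This would keep the commit's defaults while letting experiments change paths and framerates without editing the code.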
@@ -290,15 +297,15 @@ def seqTrain(dataset_yaml,
                                  pin_memory=True,
                                  num_workers=num_thread)
 
-    validation_set = SeqSet(dataset_yaml,
-                            set_type="validation",
-                            dataset_df=validation_df)
+    #validation_set = SeqSet(dataset_yaml,
+    #                        set_type="validation",
+    #                        dataset_df=validation_df)
 
-    validation_loader = DataLoader(validation_set,
-                                   batch_size=dataset_params["batch_size"],
-                                   drop_last=True,
-                                   pin_memory=True,
-                                   num_workers=num_thread)
+    #validation_loader = DataLoader(validation_set,
+    #                               batch_size=dataset_params["batch_size"],
+    #                               drop_last=True,
+    #                               pin_memory=True,
+    #                               num_workers=num_thread)
 
     """
     Set the training options
@@ -358,41 +365,41 @@ def seqTrain(dataset_yaml,
                             device=device)
 
         # Cross validation here
-        accuracy, val_loss = cross_validation(model, validation_loader, device=device)
-        logging.critical("*** Cross validation accuracy = {} %".format(accuracy))
+        #accuracy, val_loss = cross_validation(model, validation_loader, device=device)
+        #logging.critical("*** Cross validation accuracy = {} %".format(accuracy))
 
         # Decrease learning rate according to the scheduler policy
-        scheduler.step(val_loss)
-        print(f"Learning rate is {optimizer.param_groups[0]['lr']}")
+        #scheduler.step(val_loss)
+        #print(f"Learning rate is {optimizer.param_groups[0]['lr']}")
 
         # remember best accuracy and save checkpoint
-        is_best = accuracy > best_accuracy
-        best_accuracy = max(accuracy, best_accuracy)
-        if type(model) is SeqToSeq:
-            save_checkpoint({
-                'epoch': epoch,
-                'model_state_dict': model.state_dict(),
-                'optimizer_state_dict': optimizer.state_dict(),
-                'accuracy': best_accuracy,
-                'scheduler': scheduler
-            }, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
-        else:
-            save_checkpoint({
-                'epoch': epoch,
-                'model_state_dict': model.module.state_dict(),
-                'optimizer_state_dict': optimizer.state_dict(),
-                'accuracy': best_accuracy,
-                'scheduler': scheduler
-            }, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
-        if is_best:
-            best_accuracy_epoch = epoch
-            curr_patience = patience
-        else:
-            curr_patience -= 1
-    logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")
+        #is_best = accuracy > best_accuracy
+        #best_accuracy = max(accuracy, best_accuracy)
+        #if type(model) is SeqToSeq:
+        #    save_checkpoint({
+        #        'epoch': epoch,
+        #        'model_state_dict': model.state_dict(),
+        #        'optimizer_state_dict': optimizer.state_dict(),
+        #        'accuracy': best_accuracy,
+        #        'scheduler': scheduler
+        #    }, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
+        #else:
+        #    save_checkpoint({
+        #        'epoch': epoch,
+        #        'model_state_dict': model.module.state_dict(),
+        #        'optimizer_state_dict': optimizer.state_dict(),
+        #        'accuracy': best_accuracy,
+        #        'scheduler': scheduler
+        #    }, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
+        #if is_best:
+        #    best_accuracy_epoch = epoch
+        #    curr_patience = patience
+        #else:
+        #    curr_patience -= 1
+    #logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")
 
 
 def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
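Taken together, this commit disables cross-validation, learning-rate scheduling, best-model tracking, and checkpointing, leaving only the training pass in each epoch. A minimal sketch of the loop as it now behaves, assuming the surrounding seqTrain loop structure (epochs and log_interval are taken to be seqTrain arguments; the loop itself sits outside this diff):

```python
# Sketch of the epoch loop after this commit; `epochs` and `log_interval`
# are assumed seqTrain arguments, not shown in the diff.
for epoch in range(1, epochs + 1):
    # Only the training pass remains active.
    train_epoch(model, epoch, training_loader, optimizer, log_interval, device=device)
    # cross_validation(...), scheduler.step(...), and save_checkpoint(...)
    # are all commented out above, so nothing is evaluated or saved, and
    # the best_accuracy / patience bookkeeping is skipped entirely.
```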