Commit a62eb9f7 authored by Hubert Nourtel

Change scheduler call location depending on scheduler type

parent e2ff9075
@@ -1016,6 +1016,10 @@ def xtrain(dataset_description,
         if local_rank < 1:
             save_model(model, monitor, model_opts, training_opts, optimizer, scheduler, epoch)
 
+        # Call scheduler step when validation data is needed for epoch scheduler
+        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+            scheduler.step(monitor.best_eer)
+
     # TODO manage display using Training Monitor
     if local_rank < 1:
         monitor.display_final()
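For context, ReduceLROnPlateau is the one standard PyTorch scheduler whose step() takes the monitored quantity as an argument, which is why this call now lives in xtrain, where the validation EER is available after each epoch. A minimal sketch of that pattern, with a placeholder model and metric (nothing below is taken from the repository):

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# mode="min" lowers the learning rate when the monitored error rate stops improving
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2)

for epoch in range(10):
    # ... training and validation passes would run here ...
    val_eer = 0.05  # placeholder for the EER computed on validation data
    scheduler.step(val_eer)  # unlike other schedulers, step() needs the metric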
@@ -1099,10 +1103,14 @@ def train_epoch(model,
                 running_loss = 0.0
                 accuracy = 0.0
                 batch_count = 0
 
-    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
-        scheduler.step(training_monitor.best_eer)
-    else:
-        scheduler.step()
+        # Call scheduler step only for batch scheduler
+        if isinstance(scheduler, (torch.optim.lr_scheduler.OneCycleLR,
+                                  torch.optim.lr_scheduler.CyclicLR)):
+            scheduler.step()
+
+    # Call scheduler step for epoch scheduler without validation data needed
+    if isinstance(scheduler, (torch.optim.lr_scheduler.MultiStepLR,
+                              torch.optim.lr_scheduler.StepLR)):
+        scheduler.step()
 
     return model
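Taken together, the two hunks dispatch scheduler.step() to three locations depending on the scheduler type: once per batch for cyclical schedulers, once per epoch for milestone schedulers, and after validation for ReduceLROnPlateau. A self-contained sketch of that dispatch, using placeholder data rather than the actual training loop:

import torch

model = torch.nn.Linear(20, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5)

batch_level = (torch.optim.lr_scheduler.OneCycleLR,
               torch.optim.lr_scheduler.CyclicLR)
epoch_level = (torch.optim.lr_scheduler.MultiStepLR,
               torch.optim.lr_scheduler.StepLR)

for epoch in range(10):
    for _ in range(100):  # stand-in for the batch loop
        optimizer.zero_grad()
        loss = model(torch.randn(8, 20)).sum()
        loss.backward()
        optimizer.step()
        if isinstance(scheduler, batch_level):  # once per optimizer update
            scheduler.step()
    if isinstance(scheduler, epoch_level):      # once per epoch, no metric needed
        scheduler.step()
    val_eer = 0.05  # placeholder validation metric
    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        scheduler.step(val_eer)                 # needs the monitored metric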