Commit 013117aa authored by Anthony Larcher

scheduler option

parent 195c4fc8
@@ -1016,21 +1016,22 @@ def xtrain(speaker_number,
                                                        mode="triangular2",
                                                        step_size_up=75000)
     elif scheduler_type == "MultiStepLR":
-        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
+        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                                          milestones=[10000,50000,100000],
                                                          gamma=0.5)
     elif scheduler_type == "StepLR":
-        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
+        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer,
                                                            step_size=2e3,
                                                            gamma=0.95)
     elif scheduler_type == "StepLR2":
-        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
+        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer,
                                                            step_size=1e5,
-                                                           gamma=0.5,
+                                                           gamma=0.5)
     else:
-        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(**scheduler_params)
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, **scheduler_params)

     if mixed_precision:
         scaler = torch.cuda.amp.GradScaler()
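For context, a minimal runnable sketch of the dispatch this hunk patches, with a hypothetical model and optimizer standing in for the ones xtrain() builds. Note that torch.optim.lr_scheduler.ExponentialLR accepts only gamma (plus last_epoch), not step_size, so the "StepLR"/"StepLR2" branches below use torch.optim.lr_scheduler.StepLR, which does take it; scheduler_type and scheduler_params are assumed example values, not ones from the commit.

import torch

# Hypothetical stand-ins for the model/optimizer built earlier in xtrain()
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

scheduler_type = "StepLR"                            # assumed config value
scheduler_params = {"mode": "min", "factor": 0.5}    # assumed kwargs for the fallback branch

if scheduler_type == "MultiStepLR":
    # Halve the learning rate at fixed iteration milestones
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                                     milestones=[10000, 50000, 100000],
                                                     gamma=0.5)
elif scheduler_type == "StepLR":
    # Decay by 0.95 every 2000 steps; StepLR (not ExponentialLR) takes step_size
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                                step_size=2000,
                                                gamma=0.95)
elif scheduler_type == "StepLR2":
    # Coarser variant: halve every 100000 steps
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                                step_size=100000,
                                                gamma=0.5)
else:
    # The optimizer must be passed explicitly, as the commit does; **scheduler_params
    # only carries ReduceLROnPlateau's own options (mode, factor, patience, ...)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, **scheduler_params)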
@@ -1217,8 +1218,13 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interval
                     pickle.dump(data.cpu(), fh)
                 import sys
                 sys.exit()
+            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                scheduler.step(running_loss)
+            else:
+                scheduler.step()
             running_loss = 0.0
-            scheduler.step()
     return model
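This second hunk makes the per-log-interval step conditional on the scheduler class, since ReduceLROnPlateau.step() expects the monitored metric while the other schedulers take no argument. A self-contained sketch of that pattern, assuming a dummy model and loss and a log interval of 10 batches (neither comes from the commit):

import torch

model = torch.nn.Linear(10, 2)                            # hypothetical model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min",
                                                       factor=0.5, patience=2)

log_interval = 10                                         # assumed value
running_loss = 0.0
for batch_idx in range(100):
    loss = model(torch.randn(8, 10)).pow(2).mean()        # dummy loss for illustration
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    running_loss += loss.item()

    if (batch_idx + 1) % log_interval == 0:
        # ReduceLROnPlateau monitors a metric; every other scheduler steps unconditionally
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(running_loss)
        else:
            scheduler.step()
        running_loss = 0.0

Stepping ReduceLROnPlateau with the accumulated running_loss rather than a single batch loss, as the commit does, gives the plateau detector a smoother signal than per-batch values would.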