Commit 195c4fc8 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

schedulers

parent 2a3b74d2
......@@ -235,6 +235,7 @@ def test_metrics(model,
model_filename=model,
data_root_name=data_root_name,
device=device,
loss=model.loss,
transform_pipeline=transform_pipeline,
num_thread=num_thread,
mixed_precision=mixed_precision)
......@@ -246,7 +247,8 @@ def test_metrics(model,
check_missing=True,
device=device)
tar, non = scores.get_tar_non(Key(key_test_filename))
k = Key(key_test_filename)
tar, non = scores.get_tar_non(k)
test_eer = eer(numpy.array(non).astype(numpy.double), numpy.array(tar).astype(numpy.double))
return test_eer
......@@ -770,6 +772,8 @@ def xtrain(speaker_number,
loss=None,
aam_margin=None,
aam_s=None,
scheduler_type="ReduceLROnPlateau",
scheduler_params={},
patience=None,
tmp_model_name=None,
best_model_name=None,
......@@ -1004,9 +1008,29 @@ def xtrain(speaker_number,
param_list.append({'params': model.module.after_speaker_embedding.parameters(), 'weight_decay': model.module.after_speaker_embedding_weight_decay})
optimizer = _optimizer(param_list, **_options)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
milestones=[10000,50000,100000],
gamma=0.5)
if scheduler_type == 'CyclicLR':
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
base_lr=1e-8,
max_lr=1e-3,
mode="triangular2",
step_size_up=75000)
elif scheduler_type == "MultiStepLR":
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
milestones=[10000,50000,100000],
gamma=0.5)
elif scheduler_type == "StepLR":
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
step_size=2e3,
gamma=0.95)
elif scheduler_type == "StepLR2":
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
step_size=1e5,
gamma=0.5,
else:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(**scheduler_params)
if mixed_precision:
scaler = torch.cuda.amp.GradScaler()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment