Commit 3cf12ccc authored by Anthony Larcher

new options for the cyclic learning-rate scheduler

parent a4125887
@@ -511,8 +511,14 @@ class Xtractor(torch.nn.Module):
n_mels=80)
self.sequence_network = PreHalfResNet34()
self.embedding_size = 256
self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
out_features = self.embedding_size)
#self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
# out_features = self.embedding_size)
self.before_speaker_embedding = torch.nn.Sequential(OrderedDict([
("lin_be", torch.nn.Linear(in_features = 5120, out_features = self.embedding_size, bias=False)),
("bn_be", torch.nn.BatchNorm1d(self.embedding_size))
]))
self.stat_pooling = AttentivePooling(256, 80, global_context=True)
self.loss = loss
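
Below is a minimal, self-contained sketch (not part of the commit) of the new embedding head introduced above: a bias-free Linear followed by BatchNorm1d in place of the previous plain Linear. The 5120-dimensional input matches the in_features used here; the batch size of 8 is purely illustrative.

from collections import OrderedDict
import torch

embedding_size = 256
before_speaker_embedding = torch.nn.Sequential(OrderedDict([
    ("lin_be", torch.nn.Linear(in_features=5120, out_features=embedding_size, bias=False)),
    ("bn_be", torch.nn.BatchNorm1d(embedding_size))
]))

pooled_stats = torch.randn(8, 5120)                  # (batch, pooled statistics), illustrative shape
embedding = before_speaker_embedding(pooled_stats)   # -> shape (8, 256)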
@@ -953,7 +959,9 @@ def update_training_dictionary(dataset_description,
training_opts["scheduler"] = dict()
training_opts["scheduler"]["type"] = "ReduceLROnPlateau"
training_opts["scheduler"]["options"] = None
training_opts["scheduler"]["step_size_up"] = 10
training_opts["scheduler"]["base_lr"] = 1e-8
training_opts["scheduler"]["mode"] = "triangular2"
training_opts["compute_test_eer"] = False
training_opts["log_interval"] = 10
@@ -1026,6 +1034,8 @@ def get_network(model_opts, local_rank):
print(f"Modified AAM: margin = {model.after_speaker_embedding.m} and s = {model.after_speaker_embedding.s}")
if local_rank < 1:
logging.info(model)
logging.info("Model_parameters_count: {:d}".format(
sum(p.numel()
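
The parameter-count logging above is truncated by the hunk boundary. For reference only, the standard PyTorch idiom for counting trainable parameters (a hedged sketch with a placeholder model, not the verbatim continuation of the diff) is:

import logging
import torch

model = torch.nn.Linear(10, 2)   # placeholder model, not the sidekit Xtractor
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logging.info("Model_parameters_count: {:d}".format(n_params))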
@@ -1199,24 +1209,24 @@ def get_optimizer(model, model_opts, train_opts, training_loader):
if train_opts["optimizer"]["type"] == "adam":
cycle_momentum = False
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
base_lr=1e-8,
max_lr=train_opts["lr"],
step_size_up=model_opts["speaker_number"] * 8,
step_size_down=None,
cycle_momentum=cycle_momentum,
mode="triangular2")
elif train_opts["scheduler"]["type"] == 'CyclicLR1':
cycle_momentum = True
if train_opts["optimizer"]["type"] == "adam":
cycle_momentum = False
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
base_lr=1e-8,
base_lr=train_opts["scheduler"]["base_lr"],
max_lr=train_opts["lr"],
step_size_up=model_opts["speaker_number"] * 2,
step_size_up=train_opts["scheduler"]["step_size_up"],
step_size_down=None,
cycle_momentum=cycle_momentum,
mode="triangular")
mode=train_opts["scheduler"]["mode"])
#elif train_opts["scheduler"]["type"] == 'CyclicLR1':
# cycle_momentum = True
# if train_opts["optimizer"]["type"] == "adam":
# cycle_momentum = False
# scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
# base_lr=1e-8,
# max_lr=train_opts["lr"],
# step_size_up=model_opts["speaker_number"] * 4,
# step_size_down=None,
# cycle_momentum=cycle_momentum,
# mode="triangular")
elif train_opts["scheduler"]["type"] == "MultiStepLR":
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
@@ -1242,7 +1252,7 @@ def get_optimizer(model, model_opts, train_opts, training_loader):
return optimizer, scheduler
def save_model(model, training_monitor, model_opts, training_opts, optimizer, scheduler):
def save_model(model, training_monitor, model_opts, training_opts, optimizer, scheduler, epoch):
"""
:param model:
@@ -1253,6 +1263,13 @@ def save_model(model, training_monitor, model_opts, training_opts, optimizer, sc
:param scheduler:
:return:
"""
best_name = training_opts["best_model_name"]
tmp_name = training_opts["tmp_model_name"]
if epoch is not None:
best_name = best_name + f"_epoch{epoch}"
# TODO: to be revisited
if type(model) is Xtractor:
save_checkpoint({
@@ -1264,7 +1281,7 @@ def save_model(model, training_monitor, model_opts, training_opts, optimizer, sc
'speaker_number' : model.speaker_number,
'model_archi': model_opts,
'loss': model_opts["loss"]["type"]
}, training_monitor.is_best, filename=training_opts["tmp_model_name"], best_filename=training_opts["best_model_name"])
}, training_monitor.is_best, filename=tmp_name, best_filename=best_name)
else:
save_checkpoint({
'epoch': training_monitor.current_epoch,
@@ -1275,7 +1292,7 @@ def save_model(model, training_monitor, model_opts, training_opts, optimizer, sc
'speaker_number': model.module.speaker_number,
'model_archi': model_opts,
'loss': model_opts["loss"]["type"]
}, training_monitor.is_best, filename=training_opts["tmp_model_name"], best_filename=training_opts["best_model_name"])
}, training_monitor.is_best, filename=tmp_name, best_filename=best_name)
class AAMScheduler():
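
A hedged illustration of the epoch suffix added to save_model in the hunks above (file names invented for the example): when an epoch is passed, only the best-model filename receives the suffix, while the temporary checkpoint name is left unchanged.

training_opts = {"tmp_model_name": "tmp_xtractor", "best_model_name": "best_xtractor"}
epoch = 12

best_name = training_opts["best_model_name"]
tmp_name = training_opts["tmp_model_name"]
if epoch is not None:
    best_name = best_name + f"_epoch{epoch}"

print(tmp_name, best_name)   # tmp_xtractor best_xtractor_epoch12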
@@ -1369,6 +1386,9 @@ def xtrain(dataset_description,
# Initialize the model
model = get_network(model_opts, local_rank)
if local_rank < 1:
monitor.logger.info(model)
embedding_size = model.embedding_size
aam_scheduler = None
#if model.loss == "aam":
@@ -1487,7 +1507,8 @@ def xtrain(dataset_description,
# Save the current model and if needed update the best one
# TODO: add an option that keeps the models at certain epochs (for instance before the LR change)
if local_rank < 1:
save_model(model, monitor, model_opts, training_opts, optimizer, scheduler)
save_model(model, monitor, model_opts, training_opts, optimizer, scheduler, epoch)
for ii in range(torch.cuda.device_count()):
monitor.logger.info(torch.cuda.memory_summary(ii))