Commit 90dbe9a8 authored by Anthony Larcher

debugging

parent cbb5ffc5
@@ -87,7 +87,9 @@ class Xtractor(torch.nn.Module):
if model_archi is None:
self.feature_size = 30
self.activation = torch.nn.ReLU()
self.activation = torch.nn.LeakyReLU(0.2)
self.preprocessor = None
self.sequence_network = torch.nn.Sequential(OrderedDict([
("conv1", torch.nn.Conv1d(self.feature_size, 512, 5, dilation=1)),
@@ -108,18 +110,23 @@ class Xtractor(torch.nn.Module):
]))
self.before_speaker_embedding = torch.nn.Sequential(OrderedDict([
("linear6", torch.nn.Linear(1536, 512))
("linear6", torch.nn.Linear(3072, 512))
]))
self.after_speaker_embedding = torch.nn.Sequential(OrderedDict([
("activation6", torch.nn.LeakyReLU(0.2)),
("norm6", torch.nn.BatchNorm1d(512)),
("dropout6", torch.nn.Dropout(p=0.05)),
("linear7", torch.nn.Linear(512, 512)),
("activation7", torch.nn.LeakyReLU(0.2)),
("norm7", torch.nn.BatchNorm1d(512)),
("linear8", torch.nn.Linear(512, int(self.speaker_number)))
]))
self.sequence_network_weight_decay = 0.0002
self.before_speaker_embedding_weight_decay = 0.002
self.after_speaker_embedding_weight_decay = 0.002
else:
# Load Yaml configuration
with open(model_archi, 'r') as fh:
@@ -171,8 +178,8 @@ class Xtractor(torch.nn.Module):
if k.startswith("conv"):
segmental_layers.append((k, torch.nn.Conv1d(input_size,
cfg["segmental"][k]["output_channels"],
cfg["segmental"][k]["kernel_size"],
cfg["segmental"][k]["dilation"])))
kernel_size=cfg["segmental"][k]["kernel_size"],
dilation=cfg["segmental"][k]["dilation"])))
input_size = cfg["segmental"][k]["output_channels"]
elif k.startswith("activation"):
@@ -332,26 +339,26 @@ def xtrain(speaker_number,
Set the training options
"""
if type(model) is Xtractor:
optimizer = torch.optim.SGD([
optimizer = torch.optim.Adam([
{'params': model.sequence_network.parameters(),
'weight_decay': model.sequence_network_weight_decay},
{'params': model.before_speaker_embedding.parameters(),
'weight_decay': model.before_speaker_embedding_weight_decay},
{'params': model.after_speaker_embedding.parameters(),
'weight_decay': model.after_speaker_embedding_weight_decay}],
lr=lr, momentum=0.9
lr=lr
)
else:
optimizer = torch.optim.SGD([
optimizer = torch.optim.Adam([
{'params': model.module.sequence_network.parameters(),
'weight_decay': model.module.sequence_network_weight_decay},
{'params': model.module.before_speaker_embedding.parameters(),
'weight_decay': model.module.before_speaker_embedding_weight_decay},
{'params': model.module.after_speaker_embedding.parameters(),
'weight_decay': model.module.after_speaker_embedding_weight_decay}],
lr=lr, momentum=0.9
lr=lr
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
best_accuracy = 0.0
for epoch in range(1, epochs + 1):
@@ -364,6 +371,7 @@ def xtrain(speaker_number,
# Decrease learning rate according to the scheduler policy
scheduler.step(val_loss)
print(f"Learning rate is {optimizer.param_groups[0]['lr']}")
# remember best accuracy and save checkpoint
is_best = accuracy > best_accuracy
@@ -375,7 +383,7 @@ def xtrain(speaker_number,
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': best_accuracy,
'scheduler': scheduler
}, is_best, filename = tmp_model_name+".pt", best_filename=output_model_name+'.pt')
}, is_best, filename=tmp_model_name, best_filename=output_model_name)
if is_best:
best_accuracy_epoch = epoch
@@ -435,7 +443,7 @@ def cross_validation(model, validation_loader, device):
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
loss = criterion(output, target.to(device))
model.train()
return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * batch_size), loss
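For context, here is a minimal sketch (not part of the commit) of the optimizer/scheduler pattern the updated xtrain switches to: torch.optim.Adam with per-module weight-decay groups and a ReduceLROnPlateau scheduler stepped on the validation loss. The toy modules, learning rate, and training loop below are illustrative assumptions, not the project's actual model.

import torch

# Toy stand-ins for the three Xtractor sub-networks (assumed sizes).
model = torch.nn.ModuleDict({
    "sequence_network": torch.nn.Linear(30, 512),
    "before_speaker_embedding": torch.nn.Linear(512, 512),
    "after_speaker_embedding": torch.nn.Linear(512, 10),
})

# Adam replaces SGD-with-momentum; each sub-network keeps its own weight decay.
optimizer = torch.optim.Adam([
    {"params": model["sequence_network"].parameters(), "weight_decay": 0.0002},
    {"params": model["before_speaker_embedding"].parameters(), "weight_decay": 0.002},
    {"params": model["after_speaker_embedding"].parameters(), "weight_decay": 0.002},
], lr=0.001)

# The learning rate is reduced only after `patience` epochs without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=5, verbose=True)

for epoch in range(1, 4):
    val_loss = 1.0 / epoch    # placeholder validation loss
    scheduler.step(val_loss)  # step on the metric, as in the patched training loop
    print(f"Learning rate is {optimizer.param_groups[0]['lr']}")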