Commit 291c55a9 authored by Anthony Larcher

xvectors cleaning

parent 690ba2cc
......@@ -28,6 +28,7 @@ Copyright 2014-2020 Yevhenii Prokopalo, Anthony Larcher
import logging
import numpy
import pickle
import shutil
import torch
import torch.optim as optim
import torch.multiprocessing as mp
......@@ -151,6 +152,7 @@ class Xtractor(torch.nn.Module):
segmental_layers.append((k, torch.nn.BatchNorm1d(input_size)))
self.sequence_network = torch.nn.Sequential(OrderedDict(segmental_layers))
self.sequence_network_weight_decay = cfg["segmental"]["weight_decay"]
# Create sequential object for the second part of the network
input_size = input_size * 2
......@@ -173,6 +175,7 @@ class Xtractor(torch.nn.Module):
before_embedding_layers.append((k, torch.nn.Dropout(p=cfg["before_embedding"][k])))
self.before_speaker_embedding = torch.nn.Sequential(OrderedDict(before_embedding_layers))
self.before_speaker_embedding_weight_decay = cfg["before_embedding"]["weight_decay"]
# Create sequential object for the second part of the network
after_embedding_layers = []
......@@ -194,6 +197,7 @@ class Xtractor(torch.nn.Module):
after_embedding_layers.append((k, torch.nn.Dropout(p=cfg["after_embedding"][k])))
self.after_speaker_embedding = torch.nn.Sequential(OrderedDict(after_embedding_layers))
self.after_speaker_embedding_weight_decay = cfg["after_embedding"]["weight_decay"]
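# For reference, each section of the cfg dict read above is expected to carry
# its own "weight_decay" entry. A hypothetical minimal YAML sketch (section
# names are taken from the code; the numeric values are assumptions):
#
#   segmental:
#     weight_decay: 0.0002
#   before_embedding:
#     weight_decay: 0.0002
#   after_embedding:
#     weight_decay: 0.0002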
def forward(self, x, is_eval=False):
"""
......@@ -216,7 +220,13 @@ class Xtractor(torch.nn.Module):
return x
def xtrain(args):
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, best_filename)
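# Illustrative only (not part of this commit): a checkpoint written by
# save_checkpoint() can be restored to resume training, assuming model,
# optimizer and scheduler are built as in xtrain() below:
#
#   checkpoint = torch.load("checkpoint.pth.tar")
#   model.load_state_dict(checkpoint["model_state_dict"])
#   optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
#   scheduler = checkpoint["scheduler"]  # the scheduler object is saved whole
#   start_epoch = checkpoint["epoch"] + 1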
#def xtrain(args):
def xtrain(speaker_number, config=None, model_name=None):
"""
Initialize and train an x-vector on a single GPU
......@@ -224,19 +234,20 @@ def xtrain(args):
:return:
"""
# If we start from an existing model
if not args.init_model_name == '':
if model_name is not None:
# Load the model
logging.critical("*** Load model from = {}/{}".format(args.model_path, args.init_model_name))
model_file_name = '/'.join([args.model_path, args.init_model_name])
model = torch.load(model_file_name)
model.train()
logging.critical(f"*** Load model from = {model_name}")
checkpoint = torch.load(model_name)
model = Xtractor(speaker_number, config)
model.load_state_dict(checkpoint["model_state_dict"])
else:
# Initialize a first model and save to disk
if args.yaml is None:
model = Xtractor(args.class_number)
if config is None:
model = Xtractor(speaker_number)
else:
model = Xtractor(args.class_number, args.yaml)
model.train()
model = Xtractor(speaker_number, config)
model.train()
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
......@@ -263,8 +274,12 @@ def xtrain(args):
current_model_file_name = "initial_model"
torch.save(model.state_dict(), current_model_file_name)
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
optimizer = torch.optim.SGD([
{'params': model.sequence_network.parameters(), 'weight_decay': model.sequence_network_weight_decay},
{'params': model.before_speaker_embedding.parameters(), 'weight_decay': model.before_speaker_embedding_weight_decay},
{'params': model.after_speaker_embedding.parameters(), 'weight_decay': model.after_speaker_embedding_weight_decay}],
lr=args.lr, momentum=0.9
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
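# The parameter groups above apply the per-section weight decay loaded from
# the YAML config, and ReduceLROnPlateau lowers the learning rate once the
# validation loss plateaus (see scheduler.step(val_loss) below).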
for epoch in range(1, args.epochs + 1):
......@@ -278,12 +293,20 @@ def xtrain(args):
# Decrease learning rate according to the scheduler policy
scheduler.step(val_loss)
# return the file name of the new model
base_name = "model"
if not args.init_model_name == "":
base_name = args.init_model_name
current_model_file_name = "{}/{}_{}_epoch_{}".format(args.model_path, base_name, args.expe_id, epoch)
torch.save(model, current_model_file_name)
# remember best accuracy and save checkpoint
is_best = accuracy > best_accuracy
best_accuracy = max(accuracy, best_accuracy)
save_checkpoint({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': best_accuracy,
'scheduler': scheduler
}, is_best, filename=tmp_model_name + ".pt", best_filename=output_model_name + ".pt")
if is_best:
best_accuracy_epoch = epoch
def train_epoch(model, epoch, train_seg_df, speaker_dict, optimizer, args):
......