Commit ba2f184a authored by Anthony Larcher's avatar Anthony Larcher

xv cleaning

parent 291c55a9
......@@ -73,16 +73,16 @@ class Xtractor(torch.nn.Module):
Class that defines an x-vector extractor based on 5 convolutional layers and a mean and standard deviation pooling
"""
def __init__(self, speaker_number, config=None):
def __init__(self, speaker_number, model_archi=None):
"""
If model_archi is None, the default architecture is created
:param config:
:param model_archi:
"""
super(Xtractor, self).__init__()
self.speaker_number = speaker_number
self.feature_size = 24
if config is None:
if model_archi is None:
self.activation = torch.nn.ReLU()
self.sequence_network = torch.nn.Sequential(OrderedDict([
......@@ -113,12 +113,12 @@ class Xtractor(torch.nn.Module):
("linear7", torch.nn.Linear(512, 512)),
("activation7", torch.nn.LeakyReLU(0.2)),
("norm7", torch.nn.BatchNorm1d(512)),
("linear8", torch.nn.Linear(512, self.speaker_number ))
("linear8", torch.nn.Linear(512, int(self.speaker_number)))
]))
else:
# Load Yaml configuration
with open(config, 'r') as fh:
with open(model_archi, 'r') as fh:
cfg = yaml.load(fh, Loader=yaml.FullLoader)
# Get Feature size
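The architecture file itself is not visible in this diff; a minimal sketch of the loading path follows, where every key except the implied feature size entry is hypothetical:

import yaml

# Hypothetical architecture description: the surrounding code only implies
# that a feature size is read from the file; the remaining keys are illustrative.
example_archi = """
feature_size: 24
activation: LeakyReLU
"""
cfg = yaml.load(example_archi, Loader=yaml.FullLoader)
assert cfg["feature_size"] == 24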
......@@ -226,7 +226,14 @@ def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename
shutil.copyfile(filename, best_filename)
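The function above follows the usual PyTorch checkpoint idiom: persist the full training state every time, then copy the file when the monitored metric improves. A minimal self-contained sketch, assuming the visible signature (the best_filename default is truncated in this view and assumed here):

import shutil
import torch

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar',
                    best_filename='model_best.pth.tar'):  # default assumed
    # Save the complete state (model, optimizer, epoch, ...) so training
    # can resume, then keep a separate copy of the best model so far.
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)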
#def xtrain(args):
def xtrain(speaker_number, config=None, model_name=None)
def xtrain(speaker_number,
dataset_config,
epochs=10,
lr=0.01,
model_archi=None,
model_name=None,
log_interval=10,
seed=1234):
"""
Initialize and train an x-vector extractor on a single GPU
......@@ -238,14 +245,14 @@ def xtrain(speaker_number, config=None, model_name=None)
# Load the model
logging.critical(f"*** Load model from = {model_name}")
checkpoint = torch.load(model_name)
model = Xtractor(speaker_number, config)
model = Xtractor(speaker_number, model_archi)
model.load_state_dict(checkpoint["model_state_dict"])
else:
# Initialize a first model and save to disk
if config is None:
if model_archi is None:
model = Xtractor(speaker_number)
else:
model = Xtractor(speaker_number, config)
model = Xtractor(speaker_number, model_archi)
model.train()
......@@ -255,8 +262,16 @@ def xtrain(speaker_number, config=None, model_name=None)
model.cuda()
"""
Set the dataloaders according to the dataset_config
"""
# Split the training data in train and cv
total_seg_df = pickle.load(open(args.batch_training_list, "rb"))
# temporary...
with open(dataset_config, "r") as fh:
dataCfg = yaml.load(fh, Loader=yaml.FullLoader)
with open(dataCfg["batch_training_list"], "rb") as fh:
total_seg_df = pickle.load(fh)
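Judging from the dataCfg lookups in this function, the dataset_config YAML is expected to expose at least the keys below; the values are placeholders, and the duration is apparently given in seconds since it is multiplied by 100 (frames) before being passed to train_epoch:

import yaml

# Illustrative dataset configuration inferred from the dataCfg accesses
# in xtrain; none of these values come from the actual repository.
example_dataset_cfg = yaml.load("""
batch_training_list: list/training_segments.pkl
seed: 1234
batch_size: 128
log_interval: 10
train:
  duration: 3
  transformation:
    pipeline: CMVN,TemporalMask(100)
    spec_aug: 0.5
    temp_aug: 0.5
validation:
  transformation:
    pipeline: CMVN
    spec_aug: 0.0
    temp_aug: 0.0
""", Loader=yaml.FullLoader)
assert example_dataset_cfg["train"]["duration"] * 100 == 300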
speaker_dict = {}
tmp = total_seg_df.speaker_id.unique()
......@@ -271,23 +286,42 @@ def xtrain(speaker_number, config=None, model_name=None)
train_seg_df = total_seg_df.iloc[idx[:int((1 - cv_portion) * len(idx))]].reset_index()
cv_seg_df = total_seg_df.iloc[idx[int((1 - cv_portion) * len(idx)):]].reset_index()
current_model_file_name = "initial_model"
torch.save(model.state_dict(), current_model_file_name)
optimizer = torch.optim.SGD([
{'params': model.sequence_network.parameters(), 'weight_decay': self.sequence_network_weight_decay},
{'params': model.before_speaker_embedding.parameters(), 'weight_decay': self.before_speaker_embedding_weight_decay},
{'params': model.after_speaker_embedding.parameters(), 'weight_decay': self.after_speaker_embedding_weight_decay}],
lr=args.lr, momentum=0.9
)
"""
Set the training options
"""
#optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
if isinstance(model, Xtractor):
optimizer = torch.optim.SGD([
{'params': model.sequence_network.parameters(), 'weight_decay': model.sequence_network_weight_decay},
{'params': model.before_speaker_embedding.parameters(), 'weight_decay': model.before_speaker_embedding_weight_decay},
{'params': model.after_speaker_embedding.parameters(), 'weight_decay': model.after_speaker_embedding_weight_decay}],
lr=lr, momentum=0.9
)
else:
optimizer = torch.optim.SGD([
{'params': model.module.sequence_network.parameters(), 'weight_decay': model.module.sequence_network_weight_decay},
{'params': model.module.before_speaker_embedding.parameters(), 'weight_decay': model.module.before_speaker_embedding_weight_decay},
{'params': model.module.after_speaker_embedding.parameters(), 'weight_decay': model.module.after_speaker_embedding_weight_decay}],
lr=lr, momentum=0.9
)
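The check above separates a bare Xtractor from one wrapped in torch.nn.DataParallel, which exposes the wrapped network's attributes only through .module; hence the two otherwise identical parameter-group branches. A small illustration of the wrapping behaviour:

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sequence_network = torch.nn.Linear(4, 4)

model = Toy()
wrapped = torch.nn.DataParallel(model)
# The submodule is reachable only via .module on the wrapper:
assert wrapped.module.sequence_network is model.sequence_network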
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
for epoch in range(1, args.epochs + 1):
for epoch in range(1, epochs + 1):
# Process one epoch and return the current model
model = train_epoch(model, epoch, train_seg_df, speaker_dict, optimizer, args)
model = train_epoch(model, epoch, train_seg_df, speaker_dict, optimizer, seed,
dataCfg["train"]["transformation"]["pipeline"],
dataCfg["train"]["duration"]*100,
dataCfg["train"]["transformation"]["spec_aug"], dataCfg["train"]["transformation"]["temp_aug"],
dataCfg["batch_size"], log_interval)
# Add the cross validation here
accuracy, val_loss = cross_validation(args, model, cv_seg_df, speaker_dict)
accuracy, val_loss = cross_validation(model, cv_seg_df, speaker_dict, dataCfg["validation"]["transformation"]["pipeline"],
dataCfg["validation"]["transfortmation"]["spec_aug"],
dataCfg["validation"]["transfortmation"]["temp_aug"],
dataCfg["batch_size"])
logging.critical("*** Cross validation accuracy = {} %".format(accuracy))
# Decrease learning rate according to the scheduler policy
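The elided lines presumably feed the validation loss back to the scheduler; with ReduceLROnPlateau the monitored value is passed explicitly, roughly:

# Sketch of the elided scheduler update: ReduceLROnPlateau steps on the
# monitored quantity (here the cross-validation loss), not the epoch index.
scheduler.step(val_loss)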
......@@ -309,7 +343,15 @@ def xtrain(speaker_number, config=None, model_name=None)
best_accuracy_epoch = epoch
def train_epoch(model, epoch, train_seg_df, speaker_dict, optimizer, args):
def train_epoch(model, epoch, train_seg_df, speaker_dict, optimizer,
seed,
transformation_pipeline,
duration,
spec_aug,
temp_aug,
batch_size,
log_interval
):
"""
:param model:
......@@ -322,11 +364,11 @@ def train_epoch(model, epoch, train_seg_df, speaker_dict, optimizer, args):
"""
device = torch.device("cuda:0")
torch.manual_seed(args.seed)
torch.manual_seed(seed)
train_transform = []
if not args.train_transformation == '':
trans = args.train_transformation.split(',')
if transformation_pipeline != '':
trans = transformation_pipeline.split(',')
for t in trans:
if "CMVN" in t:
train_transform.append(CMVN())
......@@ -337,9 +379,9 @@ def train_epoch(model, epoch, train_seg_df, speaker_dict, optimizer, args):
if "TemporalMask" in t:
a = int(t.split("(")[1].split(")")[0])
train_transform.append(TemporalMask(a))
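Under this scheme, a transformation pipeline is a comma-separated string of transform names with optional arguments in parentheses; a hypothetical value matching the parser above:

# "CMVN" adds cepstral mean/variance normalisation; "TemporalMask(100)"
# adds a temporal mask whose width is parsed from the parentheses.
transformation_pipeline = "CMVN,TemporalMask(100)"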
train_set = VoxDataset(train_seg_df, speaker_dict, args.duration, transform=transforms.Compose(train_transform),
spec_aug_ratio=args.spec_aug, temp_aug_ratio=args.temp_aug)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=15)
train_set = VoxDataset(train_seg_df, speaker_dict, duration, transform=transforms.Compose(train_transform),
spec_aug_ratio=spec_aug, temp_aug_ratio=temp_aug)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=15)
criterion = torch.nn.CrossEntropyLoss()
......@@ -353,11 +395,11 @@ def train_epoch(model, epoch, train_seg_df, speaker_dict, optimizer, args):
optimizer.step()
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
if batch_idx % args.log_interval == 0:
if batch_idx % log_interval == 0:
logging.critical('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
epoch, batch_idx + 1, len(train_loader),
100. * batch_idx / len(train_loader), loss.item(),
100.0 * accuracy.item() / ((batch_idx + 1) * args.batch_size)))
100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))
return model
......@@ -380,7 +422,12 @@ def train_epoch(model, epoch, train_seg_df, speaker_dict, optimizer, args):
# return 100. * accuracy.cpu().numpy() / ((bi + 1) * args.batch_size)
def cross_validation(args, model, cv_seg_df, speaker_dict):
def cross_validation(model, cv_seg_df, speaker_dict,
transformation_pipeline,
spec_aug,
temp_aug,
batch_size
):
"""
:param transformation_pipeline:
:param spec_aug:
:param temp_aug:
:param batch_size:
......@@ -389,8 +436,8 @@ def cross_validation(args, model, cv_seg_df, speaker_dict):
:return:
"""
cv_transform = []
if not args.cv_transformation == '':
trans = args.cv_transformation.split(',')
if transformation_pipeline != '':
trans = transformation_pipeline.split(',')
for t in trans:
if "CMVN" in t:
cv_transform.append(CMVN())
......@@ -402,8 +449,8 @@ def cross_validation(args, model, cv_seg_df, speaker_dict):
a = t.split(",")[0].split("(")[1]
cv_transform.append(TemporalMask(a, b))
cv_set = VoxDataset(cv_seg_df, speaker_dict, 500, transform=transforms.Compose(cv_transform),
spec_aug_ratio=args.spec_aug, temp_aug_ratio=args.temp_aug)
cv_loader = DataLoader(cv_set, batch_size=args.batch_size, shuffle=False, num_workers=15)
spec_aug_ratio=spec_aug, temp_aug_ratio=temp_aug)
cv_loader = DataLoader(cv_set, batch_size=batch_size, shuffle=False, num_workers=15)
model.eval()
device = torch.device("cuda:0")
model.to(device)
......@@ -418,7 +465,7 @@ def cross_validation(args, model, cv_seg_df, speaker_dict):
loss = criterion(output, target.to(device))
return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * args.batch_size), loss
return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * batch_size), loss
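With argparse removed, a run is now driven entirely by explicit arguments and the two YAML files; a hypothetical invocation with placeholder paths and speaker count:

# Hypothetical call to the refactored xtrain; the paths and the speaker
# count are placeholders, not values from the repository.
xtrain(speaker_number=1000,
       dataset_config="cfg/dataset.yaml",
       epochs=20,
       lr=0.01,
       model_archi="cfg/archi.yaml",
       model_name=None,
       log_interval=10,
       seed=1234)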
def xtrain_asynchronous(args):