Commit 29f13393 authored by Anthony Larcher

new xtractor single GPU training

parent c6ccacfd
@@ -64,6 +64,9 @@ def split_file_list(batch_files, num_processes):
class Xtractor(torch.nn.Module):
"""
Class that defines an x-vector extractor based on 5 convolutional layers and a mean and standard deviation pooling layer
"""
def __init__(self, spk_number, dropout):
super(Xtractor, self).__init__()
self.frame_conv0 = torch.nn.Conv1d(20, 512, 5, dilation=1)
@@ -87,50 +90,12 @@ class Xtractor(torch.nn.Module):
#
self.activation = torch.nn.LeakyReLU(0.2)
def forward(self, x):
frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
frame_emb_1 = self.norm1(self.activation(self.frame_conv1(frame_emb_0)))
frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
frame_emb_3 = self.norm3(self.activation(self.frame_conv3(frame_emb_2)))
frame_emb_4 = self.norm4(self.activation(self.frame_conv4(frame_emb_3)))
# Pooling layer that computes the mean and standard deviation of frame-level embeddings
# The output of the pooling layer is the first segment-level representation
mean = torch.mean(frame_emb_4, dim=2)
std = torch.std(frame_emb_4, dim=2)
seg_emb_0 = torch.cat([mean, std], dim=1)
# batch-normalisation after this layer
seg_emb_1 = self.dropout_lin0(seg_emb_0)
seg_emb_2 = self.norm6(self.activation(self.seg_lin0(seg_emb_1)))
# new layer with batch normalisation
seg_emb_3 = self.dropout_lin1(seg_emb_2)
seg_emb_4 = self.norm7(self.activation(self.seg_lin1(seg_emb_3)))
# No batch-normalisation after this layer
seg_emb_5 = self.seg_lin2(seg_emb_4)
result = self.activation(seg_emb_5)
return result
def init_weights(self):
def produce_embeddings(self, x):
"""
"""
torch.nn.init.normal_(self.frame_conv0.weight, mean=-0.5, std=0.1)
torch.nn.init.normal_(self.frame_conv1.weight, mean=-0.5, std=0.1)
torch.nn.init.normal_(self.frame_conv2.weight, mean=-0.5, std=0.1)
torch.nn.init.normal_(self.frame_conv3.weight, mean=-0.5, std=0.1)
torch.nn.init.normal_(self.frame_conv4.weight, mean=-0.5, std=0.1)
torch.nn.init.xavier_uniform(self.seg_lin0.weight)
torch.nn.init.xavier_uniform(self.seg_lin1.weight)
torch.nn.init.xavier_uniform(self.seg_lin2.weight)
torch.nn.init.constant(self.frame_conv0.bias, 0.1)
torch.nn.init.constant(self.frame_conv1.bias, 0.1)
torch.nn.init.constant(self.frame_conv2.bias, 0.1)
torch.nn.init.constant(self.frame_conv3.bias, 0.1)
torch.nn.init.constant(self.frame_conv4.bias, 0.1)
torch.nn.init.constant(self.seg_lin0.bias, 0.1)
torch.nn.init.constant(self.seg_lin1.bias, 0.1)
torch.nn.init.constant(self.seg_lin2.bias, 0.1)
def extract(self, x):
:param x:
:return:
"""
frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
frame_emb_1 = self.norm1(self.activation(self.frame_conv1(frame_emb_0)))
frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
@@ -139,46 +104,41 @@ class Xtractor(torch.nn.Module):
mean = torch.mean(frame_emb_4, dim=2)
std = torch.std(frame_emb_4, dim=2)
seg_emb = torch.cat([mean, std], dim=1)
seg_emb_0 = torch.cat([mean, std], dim=1)
# batch-normalisation after this layer
seg_emb_1 = self.seg_lin0(seg_emb_0)
seg_emb_2 = self.activation(seg_emb_1)
seg_emb_3 = self.norm6(seg_emb_2)
seg_emb_4 = self.seg_lin1(seg_emb_3)
seg_emb_5 = self.activation(seg_emb_4)
seg_emb_6 = self.norm7(seg_emb_5)
return seg_emb_1, seg_emb_2, seg_emb_3, seg_emb_4, seg_emb_5, seg_emb_6
embedding_a = self.seg_lin0(seg_emb)
return embedding_a
def forward(self, x):
frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
frame_emb_1 = self.norm1(self.activation(self.frame_conv1(frame_emb_0)))
frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
frame_emb_3 = self.norm3(self.activation(self.frame_conv3(frame_emb_2)))
frame_emb_4 = self.norm4(self.activation(self.frame_conv4(frame_emb_3)))
# Pooling layer that computes the mean and standard deviation of frame-level embeddings
# The output of the pooling layer is the first segment-level representation
mean = torch.mean(frame_emb_4, dim=2)
std = torch.std(frame_emb_4, dim=2)
seg_emb_0 = torch.cat([mean, std], dim=1)
"""
:param x:
:return:
"""
seg_emb_0 = self.produce_embeddings(x)
# batch-normalisation after this layer
seg_emb_1 = self.dropout_lin0(seg_emb_0)
seg_emb_2 = self.norm6(self.activation(self.seg_lin0(seg_emb_1)))
seg_emb_1 = self.norm6(self.activation(seg_emb_0))
# new layer with batch normalisation
seg_emb_3 = self.dropout_lin1(seg_emb_2)
seg_emb_4 = self.norm7(self.activation(self.seg_lin1(seg_emb_3)))
seg_emb_2 = self.norm7(self.activation(self.seg_lin1(self.dropout_lin1(seg_emb_1))))
# No batch-normalisation after this layer
seg_emb_5 = self.seg_lin2(seg_emb_4)
result = self.activation(seg_emb_5)
result = self.activation(self.seg_lin2(seg_emb_2))
return result
def LossFN(self, x, label):
loss = - torch.trace(torch.mm(torch.log10(x), torch.t(label)))
return loss
def extract(self, x):
"""
Extract x-vector given an input sequence of features
:param x:
:return:
"""
embedding_a = self.produce_embeddings(x)
embedding_b = self.seg_lin1(self.norm6(self.activation(embedding_a)))
return embedding_a, embedding_b
def init_weights(self):
"""
Initialize the x-vector extractor weights and biases
"""
torch.nn.init.normal_(self.frame_conv0.weight, mean=-0.5, std=0.1)
torch.nn.init.normal_(self.frame_conv1.weight, mean=-0.5, std=0.1)
@@ -198,24 +158,14 @@ class Xtractor(torch.nn.Module):
torch.nn.init.constant(self.seg_lin1.bias, 0.1)
torch.nn.init.constant(self.seg_lin2.bias, 0.1)
def extract(self, x):
frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
frame_emb_1 = self.norm1(self.activation(self.frame_conv1(frame_emb_0)))
frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
frame_emb_3 = self.norm3(self.activation(self.frame_conv3(frame_emb_2)))
frame_emb_4 = self.norm4(self.activation(self.frame_conv4(frame_emb_3)))
mean = torch.mean(frame_emb_4, dim=2)
std = torch.std(frame_emb_4, dim=2)
seg_emb = torch.cat([mean, std], dim=1)
embedding_A = self.seg_lin0(seg_emb)
embedding_B = self.seg_lin1(self.norm6(self.activation(embedding_A)))
return embedding_A, embedding_B
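For orientation, here is a minimal usage sketch of the refactored Xtractor interface (not part of the commit). The 20-dimensional feature input is taken from the Conv1d(20, 512, 5) front end; the speaker count, dropout rate, batch size and segment length are placeholder assumptions.

import torch

# Placeholder hyper-parameters, for illustration only.
num_speakers = 100
dropout_rate = 0.25

model = Xtractor(num_speakers, dropout_rate)
model.init_weights()
model.eval()

# Dummy batch: 8 segments, 20-dimensional features, 300 frames each.
features = torch.randn(8, 20, 300)

with torch.no_grad():
    posteriors = model(features)            # segment-level output used for the training loss
    emb_a, emb_b = model.extract(features)  # x-vectors taken after seg_lin0 and seg_lin1

Both entry points share produce_embeddings: forward continues through the normalised linear layers up to seg_lin2, while extract stops at the two embedding layers.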
def xtrain(args):
"""
Initialize and train an x-vector extractor in an asynchronous manner
:param args:
:return:
"""
# Initialize a first model and save to disk
model = Xtractor(args.class_number, args.dropout)
current_model_file_name = "initial_model"
@@ -232,10 +182,17 @@ def xtrain(args):
#args.lr = args.lr * 0.9
args.lr = args.lr * 0.9
print(" Decrease learning rate: {}".format(args.lr))
def train_epoch(epoch, args, initial_model_file_name):
"""
Process one training epoch using an asynchronous implementation
:param epoch:
:param args:
:param initial_model_file_name:
:return:
"""
# Compute the megabatch number
with open(args.batch_training_list, 'r') as fh:
batch_file_list = [l.rstrip() for l in fh]
@@ -260,7 +217,78 @@ def train_epoch(epoch, args, initial_model_file_name):
return current_model
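train_asynchronous below hands each worker one of the sub-lists produced by split_file_list, whose body is collapsed out of this diff in the first hunk. A plausible sketch of such a splitter, stated only as an assumption about its behaviour:

def split_file_list(batch_files, num_processes):
    # Assumed behaviour: distribute the batch files into num_processes
    # roughly equal sub-lists, one per worker process.
    sub_lists = [[] for _ in range(num_processes)]
    for idx, file_name in enumerate(batch_files):
        sub_lists[idx % num_processes].append(file_name)
    return sub_lists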
def train_asynchronous(epoch, args, initial_model_file_name, batch_file_list, megabatch_idx, megabatch_number):
"""
Process one mega-batch of data asynchronously, average the model parameters across
subprocesses and return the updated version of the model
:param epoch:
:param args:
:param initial_model_file_name:
:param batch_file_list:
:param megabatch_idx:
:param megabatch_number:
:return:
"""
# Split the list of files for each process
sub_lists = split_file_list(batch_file_list, args.num_processes)
#
output_queue = mp.Queue()
# output_queue = multiprocessing.Queue()
processes = []
for rank in range(args.num_processes):
p = mp.Process(target=train_worker,
args=(rank, epoch, args, initial_model_file_name, sub_lists[rank], output_queue)
)
# We first train the model across `num_processes` processes
p.start()
processes.append(p)
# Average the models and write the new one to disk
asynchronous_model = []
for ii in range(args.num_processes):
asynchronous_model.append(dict(output_queue.get()))
for p in processes:
p.join()
av_model = Xtractor(args.class_number, args.dropout)
tmp = av_model.state_dict()
average_param = dict()
for k in list(asynchronous_model[0].keys()):
average_param[k] = asynchronous_model[0][k]
for mod in asynchronous_model[1:]:
average_param[k] += mod[k]
if 'num_batches_tracked' not in k:
tmp[k] = torch.FloatTensor(average_param[k] / len(asynchronous_model))
# return the file name of the new model
current_model_file_name = "{}/model_{}_epoch_{}_batch_{}".format(args.model_path, args.expe_id, epoch,
megabatch_idx)
torch.save(tmp, current_model_file_name)
if megabatch_idx == megabatch_number:
torch.save(tmp, "{}/model_{}_epoch_{}".format(args.model_path, args.expe_id, epoch))
return current_model_file_name
def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_queue):
"""
:param rank:
:param epoch:
:param args:
:param initial_model_file_name:
:param batch_list:
:param output_queue:
:return:
"""
model = Xtractor(args.class_number, args.dropout)
model.load_state_dict(torch.load(initial_model_file_name))
model.train()
@@ -282,9 +310,6 @@ def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_
{'params': model.seg_lin2.parameters(), 'weight_decay': args.l2_seg}
], lr=args.lr)
#criterion = torch.nn.CrossEntropyLoss(reduction='sum')
#criterion = torch.nn.NLLLoss()
criterion = torch.nn.CrossEntropyLoss()
accuracy = 0.0
@@ -311,52 +336,6 @@ def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_
output_queue.put(model_param)
def train_asynchronous(epoch, args, initial_model_file_name, batch_file_list, megabatch_idx, megabatch_number):
# Split the list of files for each process
sub_lists = split_file_list(batch_file_list, args.num_processes)
#
output_queue = mp.Queue()
# output_queue = multiprocessing.Queue()
processes = []
for rank in range(args.num_processes):
p = mp.Process(target=train_worker,
args=(rank, epoch, args, initial_model_file_name, sub_lists[rank], output_queue)
)
# We first train the model across `num_processes` processes
p.start()
processes.append(p)
# Average the models and write the new one to disk
asynchronous_model = []
for ii in range(args.num_processes):
asynchronous_model.append(dict(output_queue.get()))
for p in processes:
p.join()
av_model = Xtractor(args.class_number, args.dropout)
tmp = av_model.state_dict()
average_param = dict()
for k in list(asynchronous_model[0].keys()):
average_param[k] = asynchronous_model[0][k]
for mod in asynchronous_model[1:]:
average_param[k] += mod[k]
if 'num_batches_tracked' not in k:
tmp[k] = torch.FloatTensor(average_param[k] / len(asynchronous_model))
# return the file name of the new model
current_model_file_name = "{}/model_{}_epoch_{}_batch_{}".format(args.model_path, args.expe_id, epoch,
megabatch_idx)
torch.save(tmp, current_model_file_name)
if megabatch_idx == megabatch_number:
torch.save(tmp, "{}/model_{}_epoch_{}".format(args.model_path, args.expe_id, epoch))
return current_model_file_name
def cross_validation(args, current_model_file_name):
"""
@@ -473,7 +452,86 @@ def extract_idmap(args, device_ID, segment_indices, fs_params, idmap_name, outpu
output_queue.put((segment_indices, emb_1, emb_2, emb_3, emb_4, emb_5, emb_6))
def xtrain_single(args):
"""
Initialize and train an x-vector on a single GPU
:param args:
:return:
"""
# Initialize a first model and save to disk
model = Xtractor(args.class_number, args.dropout)
model.train()
model.cuda()
current_model_file_name = "initial_model"
torch.save(model.state_dict(), current_model_file_name)
for epoch in range(1, args.epochs + 1):
# Process one epoch and return the current model
model = train_epoch_single(model, epoch, args, current_model_file_name)
# Add the cross validation here
#accuracy = cross_validation_single(args, current_model_file_name)
#print("*** Cross validation accuracy = {} %".format(accuracy))
# Decrease learning rate after every epoch
#args.lr = args.lr * 0.9
#args.lr = args.lr * 0.9
#print(" Decrease learning rate: {}".format(args.lr))
def train_epoch_single(model, epoch, args, batch_list, output_queue):
"""
:param model:
:param epoch:
:param args:
:param batch_list:
:param output_queue:
:return:
"""
device = torch.device("cuda:0")
torch.manual_seed(args.seed)
train_loader = XvectorMultiDataset(batch_list, args.batch_path)
optimizer = optim.Adam([{'params': model.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
{'params': model.frame_conv1.parameters(), 'weight_decay': args.l2_frame},
{'params': model.frame_conv2.parameters(), 'weight_decay': args.l2_frame},
{'params': model.frame_conv3.parameters(), 'weight_decay': args.l2_frame},
{'params': model.frame_conv4.parameters(), 'weight_decay': args.l2_frame},
{'params': model.seg_lin0.parameters(), 'weight_decay': args.l2_seg},
{'params': model.seg_lin1.parameters(), 'weight_decay': args.l2_seg},
{'params': model.seg_lin2.parameters(), 'weight_decay': args.l2_seg}
], lr=args.lr)
criterion = torch.nn.CrossEntropyLoss()
accuracy = 0.0
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data.to(device))
loss = criterion(output, target.to(device))
loss.backward()
optimizer.step()
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
epoch, batch_idx + 1, train_loader.__len__(),
100. * batch_idx / train_loader.__len__(), loss.item(),
100.0 * accuracy.item() / ((batch_idx + 1) * args.batch_size)))
return model
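xtrain_single references a cross_validation_single helper that is commented out and not included in this commit. A hypothetical sketch of such a helper, mirroring the training loop above without gradient updates; the args.cross_validation_list attribute and the returned percentage are assumptions.

def cross_validation_single(args, model_file_name):
    # Hypothetical sketch: evaluate a saved model on a held-out batch list.
    device = torch.device("cuda:0")
    model = Xtractor(args.class_number, args.dropout)
    model.load_state_dict(torch.load(model_file_name))
    model.cuda()
    model.eval()
    # args.cross_validation_list is an assumed attribute, not defined in this commit.
    with open(args.cross_validation_list, 'r') as fh:
        batch_list = [l.rstrip() for l in fh]
    val_loader = XvectorMultiDataset(batch_list, args.batch_path)
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data.to(device))
            correct += (torch.argmax(output.data, 1) == target.to(device)).sum().item()
            total += target.shape[0]
    return 100.0 * correct / total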
def extract_parallel(args, fs_params):
"""
:param args:
:param fs_params:
:return:
"""
emb_a_size = 512
emb_b_size = 512
......