Commit a8aa2443 authored by Anthony Larcher

cleaning in xvectors

parent 53274eaf
@@ -62,13 +62,10 @@ def get_lr(optimizer):
        return param_group['lr']
def split_file_list(batch_files, num_processes):
    # Cut the list of files into num_processes round-robin sub-lists
    # Note: [[]] * num_processes would create num_processes references to the
    # same list, so independent lists are built instead
    batch_sub_lists = [[] for _ in range(num_processes)]
    x = [ii for ii in range(len(batch_files))]
    for ii in range(num_processes):
        # Process ii takes files ii, ii + num_processes, ii + 2 * num_processes, ...
        batch_sub_lists[ii] = [batch_files[z + ii] for z in x[::num_processes] if (z + ii) < len(batch_files)]
    return batch_sub_lists
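
# A worked example of the round-robin split above (hypothetical file names,
# shown as a doctest-style sketch; not part of the original file):
#
# >>> split_file_list(["f0", "f1", "f2", "f3", "f4"], 2)
# [['f0', 'f2', 'f4'], ['f1', 'f3']]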
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)
class Xtractor(torch.nn.Module):
@@ -263,10 +260,6 @@ class Xtractor(torch.nn.Module):
        return x
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)
def xtrain(speaker_number,
@@ -444,239 +437,6 @@ def cross_validation(model, validation_loader, device):
    return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * batch_size), loss
def xtrain_asynchronous(args):
    """
    Initialize and train an x-vector extractor in an asynchronous manner

    :param args:
    :return:
    """
    # Initialize a first model and save it to disk
    model = Xtractor(args.class_number, args.dropout)
    current_model_file_name = "initial_model"
    torch.save(model.state_dict(), current_model_file_name)

    for epoch in range(1, args.epochs + 1):
        current_model_file_name = train_asynchronous_epoch(epoch, args, current_model_file_name)

        # Run the cross validation on the model obtained after this epoch
        accuracy = cross_asynchronous_validation(args, current_model_file_name)
        print("*** Cross validation accuracy = {} %".format(accuracy))

        # Decrease the learning rate after every epoch
        args.lr = args.lr * 0.9
        print("    Decrease learning rate: {}".format(args.lr))
def train_asynchronous_epoch(epoch, args, initial_model_file_name):
    """
    Process one training epoch using an asynchronous implementation of the training

    :param epoch:
    :param args:
    :param initial_model_file_name:
    :return:
    """
    # Compute the number of megabatches
    with open(args.batch_training_list, 'r') as fh:
        batch_file_list = [l.rstrip() for l in fh]

    # Shorten the batch_file_list to a multiple of averaging_step * num_processes
    megabatch_number = len(batch_file_list) // (args.averaging_step * args.num_processes)
    megabatch_size = args.averaging_step * args.num_processes
    print("Epoch {}, number of megabatches = {}".format(epoch, megabatch_number))

    current_model = initial_model_file_name

    # For each megabatch: run an asynchronous training and average the models
    for ii in range(megabatch_number):
        print('Process megabatch [{} / {}]'.format(ii + 1, megabatch_number))
        current_model = train_asynchronous(epoch,
                                           args,
                                           current_model,
                                           batch_file_list[megabatch_size * ii: megabatch_size * (ii + 1)],
                                           ii,
                                           megabatch_number)  # trains across processes, averages and writes the new model
    return current_model
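
# Worked example of the megabatch arithmetic above (illustrative values only):
# with 103 files in batch_training_list, args.averaging_step = 5 and
# args.num_processes = 4, megabatch_size is 20 and megabatch_number is
# 103 // 20 = 5, so 100 files are consumed and the last 3 are dropped
# for this epoch.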
def train_asynchronous(epoch, args, initial_model_file_name, batch_file_list, megabatch_idx, megabatch_number):
    """
    Process one mega-batch of data asynchronously, average the model parameters across
    subprocesses and return the updated version of the model

    :param epoch:
    :param args:
    :param initial_model_file_name:
    :param batch_file_list:
    :param megabatch_idx:
    :param megabatch_number:
    :return:
    """
    # Split the list of files between processes
    sub_lists = split_file_list(batch_file_list, args.num_processes)

    output_queue = mp.Queue()

    processes = []
    for rank in range(args.num_processes):
        p = mp.Process(target=train_asynchronous_worker,
                       args=(rank, epoch, args, initial_model_file_name, sub_lists[rank], output_queue)
                       )
        # We first train the model across `num_processes` processes
        p.start()
        processes.append(p)

    # Collect the parameters sent back by each worker
    asynchronous_model = []
    for ii in range(args.num_processes):
        asynchronous_model.append(dict(output_queue.get()))

    for p in processes:
        p.join()

    # Average the models and write the new one to disk
    av_model = Xtractor(args.class_number, args.dropout)
    tmp = av_model.state_dict()

    average_param = dict()
    for k in list(asynchronous_model[0].keys()):
        average_param[k] = asynchronous_model[0][k]

        for mod in asynchronous_model[1:]:
            average_param[k] += mod[k]

        # BatchNorm 'num_batches_tracked' counters are bookkeeping integers and are not averaged
        if 'num_batches_tracked' not in k:
            tmp[k] = torch.FloatTensor(average_param[k] / len(asynchronous_model))

    # Return the file name of the new model
    current_model_file_name = "{}/model_{}_epoch_{}_batch_{}".format(args.model_path, args.expe_id, epoch,
                                                                     megabatch_idx)
    torch.save(tmp, current_model_file_name)

    # megabatch_idx runs from 0 to megabatch_number - 1, so test against the last index
    if megabatch_idx == megabatch_number - 1:
        torch.save(tmp, "{}/model_{}_epoch_{}".format(args.model_path, args.expe_id, epoch))

    return current_model_file_name
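
# A minimal sketch of the averaging rule used above, applied to two hypothetical
# worker snapshots (numpy arrays, as sent back by train_asynchronous_worker):
#
# >>> import numpy
# >>> snapshots = [{'weight': numpy.full((2, 2), 2.0)}, {'weight': numpy.zeros((2, 2))}]
# >>> avg = {k: torch.FloatTensor(sum(s[k] for s in snapshots) / len(snapshots))
# ...        for k in snapshots[0] if 'num_batches_tracked' not in k}
# >>> avg['weight']
# tensor([[1., 1.],
#         [1., 1.]])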
def train_asynchronous_worker(rank, epoch, args, initial_model_file_name, batch_list, output_queue):
    """
    Train the model on one sub-list of batches and send the resulting parameters
    back to the main process through the output queue

    :param rank:
    :param epoch:
    :param args:
    :param initial_model_file_name:
    :param batch_list:
    :param output_queue:
    :return:
    """
    model = Xtractor(args.class_number, args.dropout)
    model.load_state_dict(torch.load(initial_model_file_name))
    model.train()

    torch.manual_seed(args.seed + rank)
    train_loader = XvectorMultiDataset(batch_list, args.batch_path)

    device = torch.device("cuda:{}".format(rank))
    model.to(device)

    # L2 regularization is applied per layer group: one weight_decay for the
    # frame-level convolutions, another one for the segment-level linear layers
    optimizer = optim.Adam([{'params': model.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv1.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv2.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv3.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv4.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.seg_lin0.parameters(), 'weight_decay': args.l2_seg},
                            {'params': model.seg_lin1.parameters(), 'weight_decay': args.l2_seg},
                            {'params': model.seg_lin2.parameters(), 'weight_decay': args.l2_seg}
                            ], lr=args.lr)

    criterion = torch.nn.CrossEntropyLoss()

    accuracy = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data.to(device))
        loss = criterion(output, target.to(device))
        loss.backward()
        optimizer.step()
        accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                epoch, batch_idx + 1, len(train_loader),
                100. * batch_idx / len(train_loader), loss.item(),
                100.0 * accuracy.item() / ((batch_idx + 1) * args.batch_size)))

    # Move the parameters to CPU as numpy arrays before sending them back
    model_param = OrderedDict()
    params = model.state_dict()

    for k in list(params.keys()):
        model_param[k] = params[k].cpu().detach().numpy()
    output_queue.put(model_param)
def cross_asynchronous_validation(args, current_model_file_name):
    """
    Evaluate the current model on the cross validation set across several processes

    :param args:
    :param current_model_file_name:
    :return:
    """
    with open(args.cross_validation_list, 'r') as fh:
        cross_validation_list = [l.rstrip() for l in fh]

    sub_lists = split_file_list(cross_validation_list, args.num_processes)

    output_queue = mp.Queue()

    processes = []
    for rank in range(args.num_processes):
        p = mp.Process(target=cv_asynchronous_worker,
                       args=(rank, args, current_model_file_name, sub_lists[rank], output_queue)
                       )
        # We first evaluate the model across `num_processes` processes
        p.start()
        processes.append(p)

    # Collect the batch counts and correct-prediction counts from each process
    result = []
    for ii in range(args.num_processes):
        result.append(output_queue.get())

    for p in processes:
        p.join()

    # Compute the global accuracy
    accuracy = 0.0
    total_batch_number = 0
    for bn, acc in result:
        accuracy += acc
        total_batch_number += bn

    return 100. * accuracy / (total_batch_number * args.batch_size)
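
# Worked example of the aggregation above (illustrative numbers only): if two
# workers return (batch_count, correct_count) pairs (10, 230.) and (8, 190.)
# with args.batch_size = 32, the global accuracy is
# 100. * (230. + 190.) / ((10 + 8) * 32), i.e. about 72.92 %.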
def cv_asynchronous_worker(rank, args, current_model_file_name, batch_list, output_queue):
    model = Xtractor(args.class_number, args.dropout)
    model.load_state_dict(torch.load(current_model_file_name))
    model.eval()

    cv_loader = XvectorMultiDataset(batch_list, args.batch_path)

    device = torch.device("cuda:{}".format(rank))
    model.to(device)

    accuracy = 0.0
    # no_grad avoids building the autograd graph during evaluation
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(cv_loader):
            output = model(data.to(device))
            accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()

    # Send back the number of batches processed and the number of correct predictions
    output_queue.put((len(cv_loader), accuracy.cpu().numpy()))
def extract_idmap(args, device_id, segment_indices, fs_params, idmap_name, output_queue):
    """
    Function that takes a model and an idmap and extracts all x-vectors based on this model
@@ -800,45 +560,3 @@ def extract_parallel(args, fs_params):
    return x_server_1, x_server_2, x_server_3, x_server_4, x_server_5, x_server_6
def extract_embeddings(args):
    """
    Extract the x-vectors for a complete list of segments

    :param args:
    :return:
    """
    device = torch.device("cuda:0")

    # Load the model
    logging.critical("*** Load model from = {}/{}".format(args.model_path, args.init_model_name))
    model_file_name = '/'.join([args.model_path, args.init_model_name])
    model = torch.load(model_file_name)
    model = torch.nn.DataParallel(model)
    model.eval()
    model.to(device)

    # Get the list of files and map each speaker_id to a class index
    total_seg_df = pickle.load(open(args.batch_training_list, "rb"))

    speaker_dict = {}
    tmp = total_seg_df.speaker_id.unique()
    tmp.sort()
    for idx, spk in enumerate(tmp):
        speaker_dict[spk] = idx

    extract_transform = [CMVN(), ]
    extract_set = VoxDataset(total_seg_df, speaker_dict, None, transform=transforms.Compose(extract_transform),
                             spec_aug_ratio=args.spec_aug, temp_aug_ratio=args.temp_aug)
    extract_loader = DataLoader(extract_set, batch_size=1, shuffle=False, num_workers=5)

    # TODO: create a tensor of the right size to store the x-vectors
    for batch_idx, (data, target, _, __) in enumerate(extract_loader):
        print("extracting x-vector number {}".format(batch_idx))
        # DataParallel does not expose custom methods of the wrapped model,
        # so produce_embeddings has to be reached through .module
        embedding = model.module.produce_embeddings(data.to(device))
        # TODO: fill the tensor with the new x-vector
    # TODO: match the speaker ids with the x-vectors
    # TODO: return the tensor of x-vectors on CPU, or write it to disk
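
# One possible completion of the TODOs above (a hedged sketch, not the committed
# code). It pre-allocates one matrix for all x-vectors, fills it batch by batch,
# keeps the speaker labels aligned with the rows and writes everything to disk.
# The embedding_size argument is an assumption: the real dimension must come
# from the Xtractor configuration.
def _sketch_collect_embeddings(model, extract_loader, device, embedding_size, output_file="xvectors.pt"):
    embeddings = torch.empty(len(extract_loader.dataset), embedding_size)
    labels = torch.empty(len(extract_loader.dataset), dtype=torch.long)
    with torch.no_grad():
        for batch_idx, (data, target, _, __) in enumerate(extract_loader):
            # batch_size is 1, so every batch yields exactly one x-vector
            embeddings[batch_idx] = model.module.produce_embeddings(data.to(device)).squeeze().cpu()
            labels[batch_idx] = target
    torch.save({"embeddings": embeddings, "labels": labels}, output_file)
    return embeddings, labels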