Commit 5f3a133b authored by Anthony Larcher's avatar Anthony Larcher
Browse files

?

parent 307a61a8
......@@ -37,7 +37,7 @@ import torch
import torch.optim as optim
import torch.multiprocessing as mp
from collections import OrderedDict
from sidekit.nnet.xsets import XvectorMultiDataset, XvectorDataset, StatDataset, XvectorMultiDataset_hot
from sidekit.nnet.xsets import XvectorMultiDataset, XvectorDataset, StatDataset
from sidekit.bosaris import IdMap
from sidekit.statserver import StatServer
......@@ -258,22 +258,8 @@ def xtrain(args):
# Decrease learning rate after every epoch
#args.lr = args.lr * 0.9
#args.lr = args.lr * 0.9
def xtrain_hot(args):
# Initialize a first model and save to disk
model = XtractorHot(args.class_number, args.dropout)
current_model_file_name = "initial_model"
torch.save(model.state_dict(), current_model_file_name)
for epoch in range(1, args.epochs + 1):
current_model_file_name = train_epoch_hot(epoch, args, current_model_file_name)
# Add the cross validation here
accuracy = cross_validation_hot(args, current_model_file_name)
print("*** Cross validation accuracy = {} %".format(accuracy))
# Decrease learning rate after every epoch
#args.lr = args.lr * 0.9
def train_epoch(epoch, args, initial_model_file_name):
# Compute the megabatch number
......@@ -299,30 +285,6 @@ def train_epoch(epoch, args, initial_model_file_name):
megabatch_number) # function that split train, fuse and write the new model to disk
return current_model
def train_epoch_hot(epoch, args, initial_model_file_name):
# Compute the megabatch number
with open(args.batch_training_list, 'r') as fh:
batch_file_list = [l.rstrip() for l in fh]
# Shorten the batch_file_list to be a multiple of
megabatch_number = len(batch_file_list) // (args.averaging_step * args.num_processes)
megabatch_size = args.averaging_step * args.num_processes
print("Epoch {}, number of megabatches = {}".format(epoch, megabatch_number))
current_model = initial_model_file_name
# For each sublist: run an asynchronous training and averaging of the model
for ii in range(megabatch_number):
print('Process megabatch [{} / {}]'.format(ii + 1, megabatch_number))
current_model = train_asynchronous_hot(epoch,
args,
current_model,
batch_file_list[megabatch_size * ii: megabatch_size * (ii + 1)],
ii,
megabatch_number) # function that split train, fuse and write the new model to disk
return current_model
def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_queue):
model = Xtractor(args.class_number, args.dropout)
......@@ -375,57 +337,6 @@ def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_
output_queue.put(model_param)
def train_worker_hot(rank, epoch, args, initial_model_file_name, batch_list, output_queue):
model = XtractorHot(args.class_number, args.dropout)
model.load_state_dict(torch.load(initial_model_file_name))
model.train()
torch.manual_seed(args.seed + rank)
train_loader = XvectorMultiDataset_hot(batch_list, args.batch_path, args.class_number)
device = torch.device("cuda:{}".format(rank))
model.to(device)
# optimizer = optim.Adam(model.parameters(), lr = args.lr)
optimizer = optim.Adam([{'params': model.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
{'params': model.frame_conv1.parameters(), 'weight_decay': args.l2_frame},
{'params': model.frame_conv2.parameters(), 'weight_decay': args.l2_frame},
{'params': model.frame_conv3.parameters(), 'weight_decay': args.l2_frame},
{'params': model.frame_conv4.parameters(), 'weight_decay': args.l2_frame},
{'params': model.seg_lin0.parameters(), 'weight_decay': args.l2_seg},
{'params': model.seg_lin1.parameters(), 'weight_decay': args.l2_seg},
{'params': model.seg_lin2.parameters(), 'weight_decay': args.l2_seg}
], lr=args.lr)
#criterion = torch.nn.CrossEntropyLoss(reduction='sum')
#criterion = torch.nn.NLLLoss()
#criterion = torch.nn.CrossEntropyLoss()
accuracy = 0.0
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data.to(device))
loss = model.LossFN(output, target.to(device))
loss.backward()
optimizer.step()
accuracy += (torch.argmax(output.data, 1) == torch.argmax(target.to(device), 1)).sum()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
epoch, batch_idx + 1, train_loader.__len__(),
100. * batch_idx / train_loader.__len__(), loss.item(),
100.0 * accuracy.item() / ((batch_idx + 1) * args.batch_size)))
model_param = OrderedDict()
params = model.state_dict()
for k in list(params.keys()):
model_param[k] = params[k].cpu().detach().numpy()
output_queue.put(model_param)
def train_asynchronous(epoch, args, initial_model_file_name, batch_file_list, megabatch_idx, megabatch_number):
# Split the list of files for each process
sub_lists = split_file_list(batch_file_list, args.num_processes)
......@@ -473,53 +384,6 @@ def train_asynchronous(epoch, args, initial_model_file_name, batch_file_list, me
return current_model_file_name
def train_asynchronous_hot(epoch, args, initial_model_file_name, batch_file_list, megabatch_idx, megabatch_number):
# Split the list of files for each process
sub_lists = split_file_list(batch_file_list, args.num_processes)
#
output_queue = mp.Queue()
# output_queue = multiprocessing.Queue()
processes = []
for rank in range(args.num_processes):
p = mp.Process(target=train_worker_hot,
args=(rank, epoch, args, initial_model_file_name, sub_lists[rank], output_queue)
)
# We first train the model across `num_processes` processes
p.start()
processes.append(p)
# Average the models and write the new one to disk
asynchronous_model = []
for ii in range(args.num_processes):
asynchronous_model.append(dict(output_queue.get()))
for p in processes:
p.join()
av_model = XtractorHot(args.class_number, args.dropout)
tmp = av_model.state_dict()
average_param = dict()
for k in list(asynchronous_model[0].keys()):
average_param[k] = asynchronous_model[0][k]
for mod in asynchronous_model[1:]:
average_param[k] += mod[k]
if 'num_batches_tracked' not in k:
tmp[k] = torch.FloatTensor(average_param[k] / len(asynchronous_model))
# return the file name of the new model
current_model_file_name = "{}/model_{}_epoch_{}_batch_{}".format(args.model_path, args.expe_id, epoch,
megabatch_idx)
torch.save(tmp, current_model_file_name)
if megabatch_idx == megabatch_number:
torch.save(tmp, "{}/model_{}_epoch_{}".format(args.model_path, args.expe_id, epoch))
return current_model_file_name
def cross_validation(args, current_model_file_name):
"""
......@@ -561,47 +425,6 @@ def cross_validation(args, current_model_file_name):
return 100. * accuracy / (total_batch_number * args.batch_size)
def cross_validation_hot(args, current_model_file_name):
"""
:param args:
:param current_model_file_name:
:return:
"""
with open(args.cross_validation_list, 'r') as fh:
cross_validation_list = [l.rstrip() for l in fh]
sub_lists = split_file_list(cross_validation_list, args.num_processes)
#
output_queue = mp.Queue()
processes = []
for rank in range(args.num_processes):
p = mp.Process(target=cv_worker_hot,
args=(rank, args, current_model_file_name, sub_lists[rank], output_queue)
)
# We first evaluate the model across `num_processes` processes
p.start()
processes.append(p)
# Average the models and write the new one to disk
result = []
for ii in range(args.num_processes):
result.append(output_queue.get())
for p in processes:
p.join()
# Compute the global accuracy
accuracy = 0.0
total_batch_number = 0
for bn, acc in result:
accuracy += acc
total_batch_number += bn
return 100. * accuracy / (total_batch_number * args.batch_size)
def cv_worker(rank, args, current_model_file_name, batch_list, output_queue):
model = Xtractor(args.class_number, args.dropout)
model.load_state_dict(torch.load(current_model_file_name))
......@@ -618,24 +441,6 @@ def cv_worker(rank, args, current_model_file_name, batch_list, output_queue):
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
output_queue.put((cv_loader.__len__(), accuracy.cpu().numpy()))
def cv_worker_hot(rank, args, current_model_file_name, batch_list, output_queue):
model = XtractorHot(args.class_number, args.dropout)
model.load_state_dict(torch.load(current_model_file_name))
model.eval()
cv_loader = XvectorMultiDataset_hot(batch_list, args.batch_path, args.class_number)
device = torch.device("cuda:{}".format(rank))
model.to(device)
accuracy = 0.0
for batch_idx, (data, target) in enumerate(cv_loader):
output = model(data.to(device))
#accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
accuracy += (torch.argmax(output.data, 1) == torch.argmax(target.to(device), 1)).sum()
output_queue.put((cv_loader.__len__(), accuracy.cpu().numpy()))
def extract_idmap(args, device_ID, segment_indices, fs_params, idmap_name, output_queue):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment