Commit 9593fd26 authored by Anthony Larcher's avatar Anthony Larcher

hot

parent 86d80832
@@ -63,6 +63,18 @@ def read_batch(batch_file):
            data[idx] = (data[idx] - m) / s

        return data, label
def read_hot_batch(batch_file):
    with h5py.File(batch_file, 'r') as h5f:
        data = _read_dataset_percentile(h5f, 'data')
        label = h5f['label'][()]

        # Normalize and reshape to (segments, features, frames)
        data = data.reshape((len(label), data.shape[0] // len(label), data.shape[1])).transpose(0, 2, 1)
        for idx in range(data.shape[0]):
            m = data[idx].mean(axis=0)
            s = data[idx].std(axis=0)
            data[idx] = (data[idx] - m) / s

        return data, label
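Note: read_hot_batch assumes the HDF5 'data' dataset stacks all segments of a batch along the first axis. A minimal sketch of the same reshape and normalization on made-up arrays (the shapes are illustrative assumptions, not taken from the batch files):

import numpy as np

# Hypothetical batch: 4 segments, 200 frames each, 30 features per frame,
# stacked along axis 0 the way read_hot_batch expects.
label = np.zeros(4)
data = np.random.randn(4 * 200, 30)

# (segments, frames, features) -> (segments, features, frames)
data = data.reshape((len(label), data.shape[0] // len(label), data.shape[1])).transpose(0, 2, 1)
assert data.shape == (4, 30, 200)

# Same per-segment normalization along axis 0 as in the function above
for idx in range(data.shape[0]):
    data[idx] = (data[idx] - data[idx].mean(axis=0)) / data[idx].std(axis=0)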
class XvectorDataset(Dataset):
    """
@@ -80,6 +92,21 @@ class XvectorDataset(Dataset):
    def __len__(self):
        return self.len
class XvectorHotDataset(Dataset):
    """
    Object that takes a list of files from a file and initializes a Dataset
    """
    def __init__(self, batch_list, batch_path):
        with open(batch_list, 'r') as fh:
            self.batch_files = [batch_path + '/' + l.rstrip() for l in fh]
            self.len = len(self.batch_files)

    def __getitem__(self, index):
        data, label = read_hot_batch(self.batch_files[index])
        return torch.from_numpy(data).type(torch.FloatTensor), torch.from_numpy(label.astype('long'))

    def __len__(self):
        return self.len
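For reference, a usage sketch of the new dataset (the list file name and batch path are hypothetical); like the existing loaders further down, it is iterated directly, one pre-built HDF5 batch per item:

# Hypothetical list file: one HDF5 batch file name per line
dataset = XvectorHotDataset('batch_list.txt', '/path/to/batches')
for data, label in dataset:
    print(data.shape, label.shape)  # one FloatTensor batch and its LongTensor labels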
class XvectorMultiDataset(Dataset):
    """
@@ -96,6 +123,20 @@ class XvectorMultiDataset(Dataset):
    def __len__(self):
        return self.len
class XvectorMultiDataset_hot(Dataset):
    """
    Object that takes a list of files as a Python list and initializes a Dataset
    """
    def __init__(self, batch_list, batch_path):
        self.batch_files = [batch_path + '/' + l for l in batch_list]
        self.len = len(self.batch_files)

    def __getitem__(self, index):
        data, label = read_hot_batch(self.batch_files[index])
        return torch.from_numpy(data).type(torch.FloatTensor), torch.from_numpy(label.astype('long'))

    def __len__(self):
        return self.len
class StatDataset(Dataset):
    """
@@ -36,7 +36,7 @@ import torch
import torch.optim as optim
import torch.multiprocessing as mp
from collections import OrderedDict
from sidekit.nnet.xsets import XvectorMultiDataset, XvectorDataset, StatDataset
from sidekit.nnet.xsets import XvectorMultiDataset, XvectorDataset, StatDataset, XvectorMultiDataset_hot
from sidekit.bosaris import IdMap
from sidekit.statserver import StatServer
@@ -123,23 +123,12 @@ class Xtractor(torch.nn.Module):
        torch.nn.init.xavier_uniform_(self.seg_lin0.weight)
        torch.nn.init.xavier_uniform_(self.seg_lin1.weight)
        torch.nn.init.xavier_uniform_(self.seg_lin2.weight)
        #torch.nn.init.normal_(self.seg_lin0.weight, mean=-0.5, std=1.)
        #torch.nn.init.normal_(self.seg_lin1.weight, mean=-0.5, std=1.)
        #torch.nn.init.normal_(self.seg_lin2.weight, mean=-0.5, std=1.)
        #torch.nn.init.normal_(self.frame_conv0.bias, mean=-0.5, std=1.)
        #torch.nn.init.normal_(self.frame_conv1.bias, mean=-0.5, std=1.)
        #torch.nn.init.normal_(self.frame_conv2.bias, mean=-0.5, std=1.)
        #torch.nn.init.normal_(self.frame_conv3.bias, mean=-0.5, std=1.)
        #torch.nn.init.normal_(self.frame_conv4.bias, mean=-0.5, std=1.)
        torch.nn.init.constant_(self.frame_conv0.bias, 0.1)
        torch.nn.init.constant_(self.frame_conv1.bias, 0.1)
        torch.nn.init.constant_(self.frame_conv2.bias, 0.1)
        torch.nn.init.constant_(self.frame_conv3.bias, 0.1)
        torch.nn.init.constant_(self.frame_conv4.bias, 0.1)
        #torch.nn.init.normal_(self.seg_lin0.bias, mean=-0.5, std=1.)
        #torch.nn.init.normal_(self.seg_lin1.bias, mean=-0.5, std=1.)
        #torch.nn.init.normal_(self.seg_lin2.bias, mean=-0.5, std=1.)
        torch.nn.init.constant_(self.seg_lin0.bias, 0.1)
        torch.nn.init.constant_(self.seg_lin1.bias, 0.1)
        torch.nn.init.constant_(self.seg_lin2.bias, 0.1)
@@ -181,6 +170,21 @@ def xtrain(args):
        # Decrease learning rate after every epoch
        args.lr = args.lr * 0.9
def xtrain_hot(args):
    # Initialize a first model and save it to disk
    model = Xtractor(args.class_number, args.dropout)
    current_model_file_name = "initial_model"
    torch.save(model.state_dict(), current_model_file_name)

    for epoch in range(1, args.epochs + 1):
        current_model_file_name = train_epoch_hot(epoch, args, current_model_file_name)

        # Add the cross validation here
        accuracy = cross_validation_hot(args, current_model_file_name)
        print("*** Cross validation accuracy = {} %".format(accuracy))

        # Decrease learning rate after every epoch
        args.lr = args.lr * 0.9
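The schedule above decays the learning rate geometrically, lr_e = lr_0 * 0.9**e. A quick illustration with a made-up starting rate:

lr = 0.01                  # hypothetical initial learning rate
for epoch in range(10):
    lr *= 0.9
print(lr)                  # ~0.0035, about 35% of the initial value after 10 epochs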
def train_epoch(epoch, args, initial_model_file_name):
    # Compute the megabatch number
@@ -206,6 +210,30 @@ def train_epoch(epoch, args, initial_model_file_name):
                                           megabatch_number)  # splits the batch list, trains, fuses and writes the new model to disk
    return current_model
def train_epoch_hot(epoch, args, initial_model_file_name):
    # Compute the megabatch number
    with open(args.batch_training_list, 'r') as fh:
        batch_file_list = [l.rstrip() for l in fh]

    # Shorten the batch_file_list to be a multiple of the megabatch size
    megabatch_number = len(batch_file_list) // (args.averaging_step * args.num_processes)
    megabatch_size = args.averaging_step * args.num_processes
    print("Epoch {}, number of megabatches = {}".format(epoch, megabatch_number))

    current_model = initial_model_file_name

    # For each sublist: run an asynchronous training and averaging of the model
    for ii in range(megabatch_number):
        print('Process megabatch [{} / {}]'.format(ii + 1, megabatch_number))
        current_model = train_asynchronous_hot(epoch,
                                               args,
                                               current_model,
                                               batch_file_list[megabatch_size * ii: megabatch_size * (ii + 1)],
                                               ii,
                                               megabatch_number)  # splits the batch list, trains, fuses and writes the new model to disk
    return current_model
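A worked example of the megabatch arithmetic above (the numbers are made up): the integer division silently drops the files that do not fill a complete megabatch.

batch_file_list = ['batch_{:03d}.h5'.format(i) for i in range(103)]  # hypothetical
averaging_step, num_processes = 5, 4
megabatch_size = averaging_step * num_processes             # 20 files per megabatch
megabatch_number = len(batch_file_list) // megabatch_size   # 5 megabatches
# 5 * 20 = 100 files are used; the last 3 never enter the slicing above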
def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_queue):
    model = Xtractor(args.class_number, args.dropout)
@@ -213,7 +241,7 @@ def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_
    model.train()
    torch.manual_seed(args.seed + rank)

    train_loader = XvectorMultiDataset(batch_list, args.batch_path)

    device = torch.device("cuda:{}".format(rank))
    model.to(device)
@@ -258,6 +286,57 @@ def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_
    output_queue.put(model_param)
def train_worker_hot(rank, epoch, args, initial_model_file_name, batch_list, output_queue):
    model = Xtractor(args.class_number, args.dropout)
    model.load_state_dict(torch.load(initial_model_file_name))
    model.train()
    torch.manual_seed(args.seed + rank)

    train_loader = XvectorMultiDataset_hot(batch_list, args.batch_path)

    device = torch.device("cuda:{}".format(rank))
    model.to(device)

    # optimizer = optim.Adam(model.parameters(), lr=args.lr)
    optimizer = optim.Adam([{'params': model.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv1.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv2.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv3.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.frame_conv4.parameters(), 'weight_decay': args.l2_frame},
                            {'params': model.seg_lin0.parameters(), 'weight_decay': args.l2_seg},
                            {'params': model.seg_lin1.parameters(), 'weight_decay': args.l2_seg},
                            {'params': model.seg_lin2.parameters(), 'weight_decay': args.l2_seg}],
                           lr=args.lr)

    #criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    #criterion = torch.nn.NLLLoss()
    #criterion = torch.nn.CrossEntropyLoss()

    accuracy = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data.to(device))
        loss = model.LossFN(output, target.to(device))
        loss.backward()
        optimizer.step()
        accuracy += (torch.argmax(output.data, 1) == torch.argmax(target.to(device), 1)).sum()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                epoch, batch_idx + 1, train_loader.__len__(),
                100. * batch_idx / train_loader.__len__(), loss.item(),
                100.0 * accuracy.item() / ((batch_idx + 1) * args.batch_size)))

    model_param = OrderedDict()
    params = model.state_dict()
    for k in list(params.keys()):
        model_param[k] = params[k].cpu().detach().numpy()
    output_queue.put(model_param)
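The accuracy line in train_worker_hot assumes the 'hot' targets are one-hot (or soft) label vectors, so both output and target go through argmax. A self-contained check with toy tensors:

import torch

target = torch.tensor([[0., 1., 0.], [1., 0., 0.]])         # two one-hot targets
output = torch.tensor([[0.1, 2.0, -1.0], [0.3, 0.9, 0.2]])  # two network outputs
correct = (torch.argmax(output, 1) == torch.argmax(target, 1)).sum()
print(correct.item())  # 1: only the first prediction matches its target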
def train_asynchronous(epoch, args, initial_model_file_name, batch_file_list, megabatch_idx, megabatch_number):
    # Split the list of files for each process
    sub_lists = split_file_list(batch_file_list, args.num_processes)
@@ -305,6 +384,53 @@ def train_asynchronous(epoch, args, initial_model_file_name, batch_file_list, me
    return current_model_file_name
def train_asynchronous_hot(epoch, args, initial_model_file_name, batch_file_list, megabatch_idx, megabatch_number):
    # Split the list of files for each process
    sub_lists = split_file_list(batch_file_list, args.num_processes)

    output_queue = mp.Queue()
    # output_queue = multiprocessing.Queue()

    processes = []
    for rank in range(args.num_processes):
        p = mp.Process(target=train_worker_hot,
                       args=(rank, epoch, args, initial_model_file_name, sub_lists[rank], output_queue)
                       )
        # We first train the model across `num_processes` processes
        p.start()
        processes.append(p)

    # Average the models and write the new one to disk
    asynchronous_model = []
    for ii in range(args.num_processes):
        asynchronous_model.append(dict(output_queue.get()))
    for p in processes:
        p.join()

    av_model = Xtractor(args.class_number, args.dropout)
    tmp = av_model.state_dict()

    average_param = dict()
    for k in list(asynchronous_model[0].keys()):
        average_param[k] = asynchronous_model[0][k]
        for mod in asynchronous_model[1:]:
            average_param[k] += mod[k]
        if 'num_batches_tracked' not in k:
            tmp[k] = torch.FloatTensor(average_param[k] / len(asynchronous_model))

    # Return the file name of the new model
    current_model_file_name = "{}/model_{}_epoch_{}_batch_{}".format(args.model_path, args.expe_id, epoch,
                                                                     megabatch_idx)
    torch.save(tmp, current_model_file_name)

    if megabatch_idx == megabatch_number - 1:
        torch.save(tmp, "{}/model_{}_epoch_{}".format(args.model_path, args.expe_id, epoch))

    return current_model_file_name
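The loop above performs plain parameter averaging over the worker models. A minimal, self-contained sketch of the same idea on toy state dicts (the key name is hypothetical):

import numpy as np
import torch

# Two hypothetical worker snapshots with a single parameter each
workers = [{'lin.weight': np.ones((2, 2))},
           {'lin.weight': 3 * np.ones((2, 2))}]

averaged = {}
for k in workers[0]:
    acc = workers[0][k].copy()
    for w in workers[1:]:
        acc = acc + w[k]
    averaged[k] = torch.FloatTensor(acc / len(workers))
print(averaged['lin.weight'])  # every entry equals 2.0, the mean of 1 and 3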
def cross_validation(args, current_model_file_name):
    """
@@ -346,6 +472,47 @@ def cross_validation(args, current_model_file_name):
    return 100. * accuracy / (total_batch_number * args.batch_size)
def cross_validation_hot(args, current_model_file_name):
    """
    Run cross-validation of the current model across `num_processes` processes.

    :param args: training arguments (lists, paths and process count)
    :param current_model_file_name: name of the model file to evaluate
    :return: cross-validation accuracy in percent
    """
    with open(args.cross_validation_list, 'r') as fh:
        cross_validation_list = [l.rstrip() for l in fh]
        sub_lists = split_file_list(cross_validation_list, args.num_processes)

    output_queue = mp.Queue()

    processes = []
    for rank in range(args.num_processes):
        p = mp.Process(target=cv_worker_hot,
                       args=(rank, args, current_model_file_name, sub_lists[rank], output_queue)
                       )
        # We first evaluate the model across `num_processes` processes
        p.start()
        processes.append(p)

    # Collect the per-worker results
    result = []
    for ii in range(args.num_processes):
        result.append(output_queue.get())
    for p in processes:
        p.join()

    # Compute the global accuracy
    accuracy = 0.0
    total_batch_number = 0
    for bn, acc in result:
        accuracy += acc
        total_batch_number += bn

    return 100. * accuracy / (total_batch_number * args.batch_size)
def cv_worker(rank, args, current_model_file_name, batch_list, output_queue):
    model = Xtractor(args.class_number, args.dropout)
    model.load_state_dict(torch.load(current_model_file_name))
@@ -362,6 +529,24 @@ def cv_worker(rank, args, current_model_file_name, batch_list, output_queue):
        accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
    output_queue.put((cv_loader.__len__(), accuracy.cpu().numpy()))
def cv_worker_hot(rank, args, current_model_file_name, batch_list, output_queue):
    model = Xtractor(args.class_number, args.dropout)
    model.load_state_dict(torch.load(current_model_file_name))
    model.eval()

    cv_loader = XvectorMultiDataset_hot(batch_list, args.batch_path)

    device = torch.device("cuda:{}".format(rank))
    model.to(device)

    accuracy = 0.0
    for batch_idx, (data, target) in enumerate(cv_loader):
        output = model(data.to(device))
        accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
        # TO BE REPLACED HERE
    output_queue.put((cv_loader.__len__(), accuracy.cpu().numpy()))
def extract_idmap(args, device_ID, segment_indices, fs_params, idmap_name, output_queue):
    """
    Function that takes a model and an idmap and extracts all x-vectors based on this model