Commit 23f92f53 authored by Anthony Larcher

more flexible xvectors

parent 045f0e51
@@ -63,7 +63,7 @@ def read_batch(batch_file):
         data[idx] = (data[idx] - m) / s
     return data, label
 
-def read_hot_batch(batch_file):
+def read_hot_batch(batch_file, spk_nb):
     with h5py.File(batch_file, 'r') as h5f:
         data = _read_dataset_percentile(h5f, 'data')
         label = h5f['label'].value
@@ -75,7 +75,7 @@ def read_hot_batch(batch_file):
         s = data[idx].std(axis=0)
         data[idx] = (data[idx] - m) / s
-    lbl = numpy.zeros((128, 7363))
+    lbl = numpy.zeros((128, spk_nb))
     lbl[numpy.arange(128), label] += 1
     return data, lbl
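The hard-coded 7363 (a speaker count baked in from one specific setup) becomes the spk_nb parameter. A minimal sketch of the one-hot encoding this hunk builds, with a made-up class count and toy labels (the real code assumes pre-built batches of 128 samples):

import numpy

spk_nb = 10                     # hypothetical class count
label = numpy.array([3, 7, 0])  # toy batch of speaker indices
lbl = numpy.zeros((label.shape[0], spk_nb))
lbl[numpy.arange(label.shape[0]), label] += 1  # one 1 per row, at each sample's speaker index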
@@ -99,13 +99,14 @@ class XvectorHotDataset(Dataset):
     """
     Object that takes a list of files from a file and initialize a Dataset
     """
-    def __init__(self, batch_list, batch_path):
+    def __init__(self, batch_list, batch_path, spk_nb):
         with open(batch_list, 'r') as fh:
             self.batch_files = [batch_path + '/' + l.rstrip() for l in fh]
         self.len = len(self.batch_files)
+        self.spk_nb = spk_nb
 
     def __getitem__(self, index):
-        data, label = read_hot_batch(self.batch_files[index])
+        data, label = read_hot_batch(self.batch_files[index], self.spk_nb)
         return torch.from_numpy(data).type(torch.FloatTensor), torch.from_numpy(label.astype(numpy.float32))
 
     def __len__(self):
@@ -130,12 +131,13 @@ class XvectorMultiDataset_hot(Dataset):
     """
     Object that takes a list of files as a Python List and initialize a DataSet
     """
-    def __init__(self, batch_list, batch_path):
+    def __init__(self, batch_list, batch_path, spk_nb):
         self.batch_files = [batch_path + '/' + l for l in batch_list]
         self.len = len(self.batch_files)
+        self.spk_nb = spk_nb
 
     def __getitem__(self, index):
-        data, label = read_hot_batch(self.batch_files[index])
+        data, label = read_hot_batch(self.batch_files[index], self.spk_nb)
         return torch.from_numpy(data).type(torch.FloatTensor), torch.from_numpy(label.astype(numpy.float32))
 
     def __len__(self):
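Both dataset classes now carry spk_nb through to read_hot_batch, so the label width can follow the model's output layer instead of being fixed. A hypothetical construction mirroring how the training workers iterate these datasets directly (file names, path and class count invented for illustration):

# hypothetical file list, path and class count
train_loader = XvectorMultiDataset_hot(['batch_000.h5', 'batch_001.h5'],
                                       '/data/batches', spk_nb=7363)
for batch_idx, (data, target) in enumerate(train_loader):
    pass  # target has shape (128, spk_nb): one pre-built batch of 128 one-hot rows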
@@ -31,6 +31,7 @@ for sharing the source code that strongly inspired this module. Thank you for yo
 import h5py
 import logging
 import sys
+import numpy
 import torch
 import torch.optim as optim
@@ -50,7 +51,7 @@ __status__ = "Production"
 __docformat__ = 'reS'
 
 #logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
 def split_file_list(batch_files, num_processes):
@@ -341,9 +342,9 @@ def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_
                           ], lr=args.lr)
 
-    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
+    #criterion = torch.nn.CrossEntropyLoss(reduction='sum')
     #criterion = torch.nn.NLLLoss()
-    #criterion = torch.nn.CrossEntropyLoss()
+    criterion = torch.nn.CrossEntropyLoss()
 
     accuracy = 0.0
     for batch_idx, (data, target) in enumerate(train_loader):
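The active criterion moves from a summed cross-entropy to PyTorch's default mean reduction, so the loss value no longer scales with batch size. A quick check of the relationship between the two (toy shapes):

import torch

logits = torch.randn(128, 10)             # toy batch: 128 samples, 10 classes
target = torch.randint(0, 10, (128,))

loss_sum = torch.nn.CrossEntropyLoss(reduction='sum')(logits, target)
loss_mean = torch.nn.CrossEntropyLoss()(logits, target)  # default reduction='mean'
assert torch.allclose(loss_mean, loss_sum / 128)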
@@ -375,7 +376,7 @@ def train_worker_hot(rank, epoch, args, initial_model_file_name, batch_list, out
     model.train()
     torch.manual_seed(args.seed + rank)
 
-    train_loader = XvectorMultiDataset_hot(batch_list, args.batch_path)
+    train_loader = XvectorMultiDataset_hot(batch_list, args.batch_path, args.class_number)
 
     device = torch.device("cuda:{}".format(rank))
     model.to(device)
@@ -617,7 +618,7 @@ def cv_worker_hot(rank, args, current_model_file_name, batch_list, output_queue)
     model.load_state_dict(torch.load(current_model_file_name))
     model.eval()
 
-    cv_loader = XvectorMultiDataset_hot(batch_list, args.batch_path)
+    cv_loader = XvectorMultiDataset_hot(batch_list, args.batch_path, args.class_number)
 
     device = torch.device("cuda:{}".format(rank))
     model.to(device)
@@ -667,9 +668,8 @@ def extract_idmap(args, device_ID, segment_indices, fs_params, idmap_name, outpu
     # Loop to extract all x-vectors
     for idx, (model_id, segment_id, data) in enumerate(segment_loader):
-        print('Extract X-vector for {}\t[{} / {}]'.format(segment_id, idx, segment_loader.__len__()))
-        print("shape of data = {}".format(list(data.shape)))
-        print("shape[2] = {}".format(list(data.shape)[2]))
+        logging.critical('Process file {}, [{} / {}]'.format(segment_id, idx, segment_loader.__len__()))
+        #print('Process file {}'.format(segment_id))
 
         if list(data.shape)[2] < 20:
             pass
         else:
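The debug prints collapse into a single logging.critical call; CRITICAL is presumably chosen because it is shown even when logging was never configured (note the commented-out basicConfig near the imports). A conventional alternative, sketched rather than taken from the repo:

import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# routine progress at INFO level, with lazy %-formatting (toy values):
logging.info('Process file %s, [%d / %d]', 'seg_0001', 0, 42)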
@@ -692,6 +692,7 @@ def extract_parallel(args, fs_params, dataset):
     idmap_name = args.back_idmap
     idmap = IdMap(idmap_name)
 
     x_server_A = StatServer(idmap, 1, emb_a_size)
+    x_server_B = StatServer(idmap, 1, emb_b_size)
 
     x_server_A.stat0 = numpy.ones(x_server_A.stat0.shape)
@@ -699,10 +700,16 @@ def extract_parallel(args, fs_params, dataset):
     # Split the indices
     mega_batch_size = idmap.leftids.shape[0] // args.num_processes
 
+    logging.critical("Number of sessions to process: {}".format(idmap.leftids.shape[0]))
+
     segment_idx = []
     for ii in range(args.num_processes):
         segment_idx.append(
-            numpy.arange(ii * mega_batch_size, numpy.max([(ii + 1) * mega_batch_size, idmap.leftids.shape[0]])))
+            numpy.arange(ii * mega_batch_size, numpy.min([(ii + 1) * mega_batch_size, idmap.leftids.shape[0]])))
+
+    for idx, si in enumerate(segment_idx):
+        logging.critical("Number of session on process {}: {}".format(idx, len(si)))
 
     # Extract x-vectors in parallel
     output_queue = mp.Queue()
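The one-character fix here matters: with numpy.max the upper bound of every slice was pulled up to the total session count, so each worker was handed everything from its start index to the end of the idmap; numpy.min clamps each slice where it should stop. A toy check of the fixed split (note that the floor division can still leave a remainder unassigned, and this hunk does not show it being handled):

import numpy

num_sessions, num_processes = 10, 3
mega_batch_size = num_sessions // num_processes  # 3
segment_idx = [numpy.arange(ii * mega_batch_size,
                            numpy.min([(ii + 1) * mega_batch_size, num_sessions]))
               for ii in range(num_processes)]
print(segment_idx)  # [array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8])] -- session 9 left over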
@@ -725,4 +732,11 @@ def extract_parallel(args, fs_params, dataset):
     for p in processes:
         p.join()
 
+    print("Process parallel fini")
+
     return x_server_A, x_server_B