Commit 2e629255 authored by Anthony Larcher's avatar Anthony Larcher

new voxceleb dataloader

parent 06bacb32
......@@ -166,6 +166,7 @@ if CUDA:
from sidekit.nnet import Xtractor
from sidekit.nnet import xtrain
from sidekit.nnet import xtrain_single
from sidekit.nnet import xtrain_new
from sidekit.nnet import extract_idmap
from sidekit.nnet import extract_parallel
#from sidekit.nnet import SAD_RNN
......
......@@ -31,7 +31,7 @@ Copyright 2014-2019 Anthony Larcher and Sylvain Meignier
from sidekit.nnet.feed_forward import FForwardNetwork
from sidekit.nnet.feed_forward import kaldi_to_hdf5
from sidekit.nnet.xsets import XvectorMultiDataset, XvectorDataset, StatDataset
from sidekit.nnet.xvector import Xtractor, xtrain, extract_idmap, extract_parallel, xtrain_single
from sidekit.nnet.xvector import Xtractor, xtrain, extract_idmap, extract_parallel, xtrain_single, xtrain_new
__author__ = "Anthony Larcher and Sylvain Meignier"
......
......@@ -24,7 +24,6 @@
"""
Copyright 2014-2019 Anthony Larcher
The authors would like to thank the BUT Speech@FIT group (http://speech.fit.vutbr.cz) and Lukas BURGET
for sharing the source code that strongly inspired this module. Thank you for your valuable contribution.
"""
......@@ -54,7 +53,6 @@ __docformat__ = 'reStructuredText'
def read_batch(batch_file):
"""
:param batch_file: path to the batch file to read
:return: a (data, label) tuple loaded from the file
"""
......@@ -78,7 +76,7 @@ class XvectorDataset(Dataset):
def __init__(self, batch_list, batch_path):
with open(batch_list, 'r') as fh:
self.batch_files = [batch_path + '/' + l.rstrip() for l in fh]
self.len = len(self.batch_files)
self.len = len(self.batch_files)
def __getitem__(self, index):
data, label = read_batch(self.batch_files[index])
......@@ -127,7 +125,7 @@ class VoxDataset(Dataset):
"""
"""
def __init__(self, pickle_file_list, batch_size, min_duration=200, max_duration=500):
def __init__(self, pickle_file_list, duration=500):
"""
:param pickle_file_list:
:param duration:
......@@ -147,43 +145,33 @@ class VoxDataset(Dataset):
for idx, spk in enumerate(self.session_list.speaker_id.unique()):
self.speaker_dict[spk] = idx
self.len = len(self.session_list) // batch_size
self.current_session = 0
self.current_hdf5_file = None
self.min_duration = min_duration
self.max_duration = max_duration
self.batch_size = batch_size
self.len = len(self.session_list)
self.duration = duration
# Open the first file and get the feature size from there
self.fh = h5py.File(self.session_list.loc[self.current_session].hdf5_file, 'r')
self.feature_size = self.fh[self.session_list.session_id[0]].shape[1]
def get_item(self):
def __getitem__(self, index):
"""
:return:
"""
# Randomly pick a segment duration
duration = random.randrange(self.min_duration, self.max_duration)
data = numpy.empty((self.batch_size, self.feature_size, duration), dtype=numpy.float32)
label = numpy.empty(self.batch_size, dtype=numpy.long)
# Load the next batch_size sessions and create the labels
for ii in range(self.batch_size):
# When moving to a new HDF5 file, close the previous one and open the new one
if not self.session_list.loc[self.current_session].hdf5_file == self.fh.filename:
self.fh.close()
self.fh = h5py.File(self.session_list.loc[self.current_session].hdf5_file, 'r')
# WARNING: the data must be decompressed here
start = int(self.session_list.start[self.current_session])
data[ii] = read_dataset_percentile(self.fh, self.session_list.session_id[self.current_session])[start:start + duration, :].T
#data[ii] = self.fh[self.session_list.session_id[self.current_session]][start:start + duration, :].T
label[ii] = self.speaker_dict[self.session_list.speaker_id[self.current_session]]
self.current_session += 1
#duration = random.randrange(self.min_duration, self.max_duration)
fh = h5py.File(self.session_list.loc[index].hdf5_file, 'r')
feature_size = fh[self.session_list.session_id[index]].shape[1]
return torch.from_numpy(data).type(torch.FloatTensor), torch.from_numpy(label.astype('long'))
start = int(self.session_list.start[index])
data = read_dataset_percentile(fh, self.session_list.session_id[index]).T
if data.shape[1] < start + self.duration:
print("probleme {}, {}".format(data.shape, start+ self.duration))
data = data[:, start:start + self.duration]
m = data.mean(axis=0)
s = data.std(axis=0)
data = (data - m) / s
label = self.speaker_dict[self.session_list.speaker_id[index]]
fh.close()
return torch.from_numpy(data).type(torch.FloatTensor), torch.from_numpy(numpy.array([label,]).astype('long'))
def __len__(self):
"""
......@@ -192,3 +180,4 @@ class VoxDataset(Dataset):
:return:
"""
return self.len
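For reference, a minimal sketch of how this new dataset is consumed; it mirrors the wiring in train_epoch_new further down (the file name "clean_segment_list.pkl" and the loader parameters are taken from that function; the shape comments are assumptions based on __getitem__ above):

from torch.utils.data import DataLoader

# Sketch: build the dataset from one session list and batch it.
train_set = VoxDataset(["clean_segment_list.pkl"], duration=500)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=10)

for data, target in train_loader:
    # data is a (batch, feature_size, duration) float tensor;
    # target is (batch, 1) and is squeezed before the loss.
    print(data.shape, target.squeeze().shape)
    break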
......@@ -33,10 +33,10 @@ import torch
import torch.optim as optim
import torch.multiprocessing as mp
from collections import OrderedDict
from sidekit.nnet.xsets import XvectorMultiDataset, XvectorDataset, StatDataset
from sidekit.nnet.xsets import XvectorMultiDataset, XvectorDataset, StatDataset, VoxDataset
from sidekit.bosaris import IdMap
from sidekit.statserver import StatServer
from torch.utils.data import DataLoader
__license__ = "LGPL"
__author__ = "Anthony Larcher"
......@@ -65,7 +65,7 @@ class Xtractor(torch.nn.Module):
"""
def __init__(self, spk_number, dropout):
super(Xtractor, self).__init__()
self.frame_conv0 = torch.nn.Conv1d(20, 512, 5, dilation=1)
self.frame_conv0 = torch.nn.Conv1d(30, 512, 5, dilation=1)
self.frame_conv1 = torch.nn.Conv1d(512, 512, 3, dilation=2)
self.frame_conv2 = torch.nn.Conv1d(512, 512, 3, dilation=3)
self.frame_conv3 = torch.nn.Conv1d(512, 512, 1)
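The input feature dimension of the first convolution changes from 20 to 30 here. As a quick sanity check, a hedged sketch of the frame-level stack's shape arithmetic (standalone Conv1d layers with the hyperparameters from the diff; the real Xtractor interleaves activations, which do not change the shapes):

import torch

# Same kernel sizes and dilations as the layers above.
conv0 = torch.nn.Conv1d(30, 512, 5, dilation=1)
conv1 = torch.nn.Conv1d(512, 512, 3, dilation=2)
conv2 = torch.nn.Conv1d(512, 512, 3, dilation=3)
conv3 = torch.nn.Conv1d(512, 512, 1)

x = torch.randn(4, 30, 500)    # (batch, features, frames)
y = conv3(conv2(conv1(conv0(x))))
print(y.shape)                 # torch.Size([4, 512, 486])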
......@@ -481,6 +481,40 @@ def xtrain_single(args):
torch.save(model, current_model_file_name)
def xtrain_new(args):
"""
Initialize and train an x-vector extractor on a single GPU
:param args:
:return:
"""
# Initialize a first model and save to disk
model = Xtractor(args.class_number, args.dropout)
model.train()
model.cuda()
current_model_file_name = "initial_model"
torch.save(model.state_dict(), current_model_file_name)
for epoch in range(1, args.epochs + 1):
# Process one epoch and return the current model
model = train_epoch_new(model, epoch, args)
# Add the cross validation here
accuracy = cross_validation(args, model)
logging.critical("*** Cross validation accuracy = {} %".format(accuracy))
# Decrease learning rate after every epoch
args.lr = args.lr * 0.9
logging.critical(" Decrease learning rate: {}".format(args.lr))
# return the file name of the new model
current_model_file_name = "{}/model_{}_epoch_{}".format(args.model_path, args.expe_id, epoch)
torch.save(model, current_model_file_name)
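The per-epoch decay above rescales args.lr by 0.9 by hand. A hedged equivalent using PyTorch's built-in scheduler (the model and optimizer below are placeholders, not the commit's code):

import torch

model = torch.nn.Linear(512, 10)    # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

for epoch in range(1, 11):
    # ... train one epoch, run cross-validation ...
    scheduler.step()                # lr <- lr * 0.9 after each epoch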
def train_epoch_single(model, epoch, args):
"""
......@@ -544,29 +578,16 @@ def train_epoch_new(model, epoch, args):
:return:
"""
pickle_file_list = ["/lium/raid01_c/larcher/voxceleb/vox1_1_dur500.pkl",
"/lium/raid01_c/larcher/voxceleb/vox2_1_dur500.pkl",
"/lium/raid01_c/larcher/voxceleb/vox2_2_dur500.pkl",
"/lium/raid01_c/larcher/voxceleb/vox2_3_dur500.pkl",
"/lium/raid01_c/larcher/voxceleb/vox2_4_dur500.pkl",
"/lium/raid01_c/larcher/voxceleb/vox2_5_dur500.pkl",
"/lium/raid01_c/larcher/voxceleb/vox2_6_dur500.pkl",
"/lium/raid01_c/larcher/voxceleb/vox2_7_dur500.pkl",
"/lium/raid01_c/larcher/voxceleb/vox2_8_dur500.pkl",
"/lium/raid01_c/larcher/voxceleb/vox1_0_dur500.pkl"]
device = torch.device("cuda:0")
torch.manual_seed(args.seed)
# Get the list of batches
print(args.batch_training_list)
#train_loader = XvectorMultiDataset(batch_list, args.batch_path)
train_set = VoxDataset(["clean_segment_list.pkl",], 500)
with open(args.batch_training_list, 'r') as fh:
batch_list = [l.rstrip() for l in fh]
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=10)
#train_loader = XvectorMultiDataset(batch_list, args.batch_path)
train_loader = VoxDataset(pickle_file_list, 128, 350, 500)
optimizer = optim.Adam([{'params': model.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
{'params': model.frame_conv1.parameters(), 'weight_decay': args.l2_frame},
......@@ -581,6 +602,8 @@ def train_epoch_new(model, epoch, args):
accuracy = 0.0
for batch_idx, (data, target) in enumerate(train_loader):
target = target.squeeze()
print("shape data = {} , label = {}".format(data.shape, target.shape))
optimizer.zero_grad()
output = model(data.to(device))
loss = criterion(output, target.to(device))
......
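The hunk is cut off after the loss computation. For readability, a hypothetical continuation of such a training loop (a sketch only, continuing the loop body above; args.log_interval is an assumed attribute, and the commit's actual code is not shown here):

# Hypothetical remainder of the batch loop (not from the diff).
loss.backward()
optimizer.step()
accuracy += (torch.argmax(output.data, 1).cpu() == target).sum().item()
if batch_idx % args.log_interval == 0:
    logging.info("epoch {} batch {} loss {:.4f}".format(
        epoch, batch_idx, loss.item()))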