Commit a846c826 authored by Anthony Larcher

hot version

parent 9593fd26
@@ -29,8 +29,8 @@ Copyright 2014-2019 Anthony Larcher and Sylvain Meignier
 from sidekit.nnet.feed_forward import FForwardNetwork
 from sidekit.nnet.feed_forward import kaldi_to_hdf5
-from sidekit.nnet.xsets import XvectorMultiDataset, XvectorDataset, StatDataset
-from sidekit.nnet.xvector import Xtractor, xtrain, extract_idmap, extract_parallel
+from sidekit.nnet.xsets import XvectorMultiDataset, XvectorDataset, StatDataset, XvectorMultiDataset_hot
+from sidekit.nnet.xvector import Xtractor, xtrain, extract_idmap, extract_parallel, xtrain_hot
 __author__ = "Anthony Larcher and Sylvain Meignier"
......
@@ -74,7 +74,10 @@ def read_hot_batch(batch_file):
     m = data[idx].mean(axis=0)
     s = data[idx].std(axis=0)
     data[idx] = (data[idx] - m) / s
-    return data, label
+    lbl = numpy.zeros((128, 7363))
+    lbl[numpy.arange(128), label] += 1
+    return data, lbl
 class XvectorDataset(Dataset):
     """
@@ -103,7 +106,7 @@ class XvectorHotDataset(Dataset):
     def __getitem__(self, index):
         data, label = read_hot_batch(self.batch_files[index])
-        return torch.from_numpy(data).type(torch.FloatTensor), torch.from_numpy(label.astype('long'))
+        return torch.from_numpy(data).type(torch.FloatTensor), torch.from_numpy(label.astype(numpy.float32))
     def __len__(self):
         return self.len
@@ -132,8 +135,8 @@ class XvectorMultiDataset_hot(Dataset):
         self.len = len(self.batch_files)
     def __getitem__(self, index):
-        data, label = read_batch(self.batch_files[index])
-        return torch.from_numpy(data).type(torch.FloatTensor), torch.from_numpy(label.astype('long'))
+        data, label = read_hot_batch(self.batch_files[index])
+        return torch.from_numpy(data).type(torch.FloatTensor), torch.from_numpy(label.astype(numpy.float32))
     def __len__(self):
         return self.len
......
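Both datasets now cast the one-hot labels to numpy.float32 rather than 'long': LossFN multiplies the targets against log-probabilities with torch.mm, which needs floating-point tensors on both sides. A quick illustrative check of the round-trip, reusing the sizes from read_hot_batch (all names here are hypothetical):

import numpy
import torch

labels = numpy.random.randint(0, 7363, size=128)      # integer class indices
hot = numpy.zeros((128, 7363))
hot[numpy.arange(128), labels] = 1.0
target = torch.from_numpy(hot.astype(numpy.float32))  # FloatTensor, ready for torch.mm
assert target.dtype == torch.float32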
@@ -105,11 +105,12 @@ class Xtractor(torch.nn.Module):
         seg_emb_4 = self.norm7(self.activation(self.seg_lin1(seg_emb_3)))
         # No batch-normalisation after this layer
         seg_emb_5 = self.seg_lin2(seg_emb_4)
-        #seg_output = torch.nn.LogSoftmax(seg_emb_5)
-        return seg_emb_5
+        result = torch.nn.functional.softmax(self.activation(seg_emb_5), dim=1)
+        #return seg_emb_5
+        return result
-    def LossFN(self, x, lable):
-        loss = - torch.trace(torch.mm(torch.log10(x), torch.t(lable)))
+    def LossFN(self, x, label):
+        loss = - torch.trace(torch.mm(torch.log10(x), torch.t(label)))
         return loss
     def init_weights(self):
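With one-hot targets, LossFN is a plain cross-entropy written as a trace: since each row of label contains a single 1, -trace(log10(x) · labelᵀ) collapses to -Σᵢ log10 x[i, yᵢ], the summed log10-probability of the true classes (i.e. natural-log cross-entropy scaled by 1/ln 10). A small sketch verifying that identity on hypothetical tensors:

import torch

x = torch.softmax(torch.randn(4, 6), dim=1)   # (batch, classes) probabilities
y_idx = torch.tensor([2, 0, 5, 1])            # true class per row
y = torch.eye(6)[y_idx]                       # one-hot targets

trace_loss = -torch.trace(torch.mm(torch.log10(x), torch.t(y)))
index_loss = -torch.log10(x[torch.arange(4), y_idx]).sum()
assert torch.allclose(trace_loss, index_loss)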
@@ -241,7 +242,7 @@ def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_
     model.train()
     torch.manual_seed(args.seed + rank)
-    train_loader = XvectorMultiDataset(batch_list, args.batch_path)s
+    train_loader = XvectorMultiDataset(batch_list, args.batch_path)
     device = torch.device("cuda:{}".format(rank))
     model.to(device)
@@ -542,8 +543,9 @@ def cv_worker_hot(rank, args, current_model_file_name, batch_list, output_queue)
     accuracy = 0.0
     for batch_idx, (data, target) in enumerate(cv_loader):
         output = model(data.to(device))
-        accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
-        TO BE REPLACED HERE
+        #accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
+        accuracy += (torch.argmax(output.data, 1) == torch.argmax(target.to(device), 1)).sum()
     output_queue.put((cv_loader.__len__(), accuracy.cpu().numpy()))
......
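Because the validation targets are now one-hot rows rather than integer indices, both sides are reduced with argmax before comparing. An illustrative check of the new accuracy count:

import torch

output = torch.tensor([[0.1, 0.7, 0.2],
                       [0.8, 0.1, 0.1]])   # model output probabilities
target = torch.tensor([[0., 1., 0.],
                       [0., 0., 1.]])      # one-hot ground truth

correct = (torch.argmax(output, 1) == torch.argmax(target, 1)).sum()
# correct == 1: row 0 predicts class 1 (a hit), row 1 predicts class 0 (a miss)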