Commit 1c1f05fb authored by Anthony Larcher's avatar Anthony Larcher
Browse files

xvector add cross validation

parent 1cda872e
......@@ -30,7 +30,7 @@ Copyright 2014-2019 Anthony Larcher and Sylvain Meignier
from sidekit.nnet.feed_forward import FForwardNetwork
from sidekit.nnet.feed_forward import kaldi_to_hdf5
from sidekit.nnet.xsets import XvectorMultiDataset, XvectorDataset, StatDataset
from sidekit.nnet.xvectors import Xtractor, xtrain, extract_idmap, extract_parallel
from sidekit.nnet.xvector import Xtractor, xtrain, extract_idmap, extract_parallel
__author__ = "Anthony Larcher and Sylvain Meignier"
......
......@@ -160,7 +160,7 @@ def xtrain(args):
# Add the cross validation here
accuracy = cross_validation(args, current_model_file_name)
print("*** Cross validation accuracy = {.02f} %".format(accuracy))
print("*** Cross validation accuracy = {} %".format(accuracy))
def train_epoch(epoch, args, initial_model_file_name):
......@@ -293,7 +293,7 @@ def cross_validation(args, current_model_file_name):
:return:
"""
with open(args.cross_validation_list, 'r') as fh:
cross_validation_list = [l.rstrip() for l in fh
cross_validation_list = [l.rstrip() for l in fh]
sub_lists = split_file_list(cross_validation_list, args.num_processes)
#
......@@ -319,10 +319,10 @@ def cross_validation(args, current_model_file_name):
# Compute the global accuracy
accuracy = 0.0
total_batch_number = 0
for batch_number, acc in result:
for bn, acc in result:
accuracy += acc
total_batch_number += batch_number
total_batch_number += bn
return 100. * accuracy / (total_batch_number * args.batch_size)
......@@ -337,11 +337,10 @@ def cv_worker(rank, args, current_model_file_name, batch_list, output_queue):
model.to(device)
accuracy = 0.0
for batch_idx, (data, target) in enumerate(train_loader):
for batch_idx, (data, target) in enumerate(cv_loader):
output = model(data.to(device))
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
output_queue.put((train_loader.__len__(), accuracy))
output_queue.put((cv_loader.__len__(), accuracy.cpu().numpy()))
def extract_idmap(args, device_ID, segment_indices, fs_params, idmap_name, output_queue):
"""
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment