Commit 90fd5ec9 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

xtractor single

parent 29f13393
......@@ -37,7 +37,7 @@ import importlib
# Read environment variable if it exists
if 'SIDEKIT' in os.environ:
......@@ -165,9 +165,10 @@ if CUDA:
from sidekit.nnet import StatDataset
from sidekit.nnet import Xtractor
from sidekit.nnet import xtrain
from sidekit.nnet import xtrain_single
from sidekit.nnet import extract_idmap
from sidekit.nnet import extract_parallel
from sidekit.nnet import SAD_RNN
#from sidekit.nnet import SAD_RNN
print("Don't import Torch")
......@@ -27,11 +27,11 @@ Copyright 2014-2019 Anthony Larcher and Sylvain Meignier
:mod:`nnet` provides methods to manage Neural Networks using PyTorch
from sidekit.nnet.sad_rnn import SAD_RNN
#from sidekit.nnet.sad_rnn import SAD_RNN
from sidekit.nnet.feed_forward import FForwardNetwork
from sidekit.nnet.feed_forward import kaldi_to_hdf5
from sidekit.nnet.xsets import XvectorMultiDataset, XvectorDataset, StatDataset
from sidekit.nnet.xvector import Xtractor, xtrain, extract_idmap, extract_parallel
from sidekit.nnet.xvector import Xtractor, xtrain, extract_idmap, extract_parallel, xtrain_single
__author__ = "Anthony Larcher and Sylvain Meignier"
......@@ -469,19 +469,23 @@ def xtrain_single(args):
for epoch in range(1, args.epochs + 1):
# Process one epoch and return the current model
model = train_epoch_single(model, epoch, args, current_model_file_name)
model = train_epoch_single(model, epoch, args)
# Add the cross validation here
#accuracy = cross_validation_single(args, current_model_file_name)
#print("*** Cross validation accuracy = {} %".format(accuracy))
# Decrease learning rate after every epoch = * 0.9 = * 0.9
#print(" Decrease learning rate: {}".format( = * 0.9 = * 0.9
print(" Decrease learning rate: {}".format(
# return the file name of the new model
current_model_file_name = "{}/model_{}_epoch_{}".format(args.model_path, args.expe_id, epoch), current_model_file_name)
def train_epoch_single(model, epoch, args, batch_list, output_queue):
def train_epoch_single(model, epoch, args):
:param model:
......@@ -494,6 +498,13 @@ def train_epoch_single(model, epoch, args, batch_list, output_queue):
device = device = torch.device("cuda:0")
# Get the list of batches
with open(args.batch_training_list, 'r') as fh:
batch_list = [l.rstrip() for l in fh]
train_loader = XvectorMultiDataset(batch_list, args.batch_path)
optimizer = optim.Adam([{'params': model.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment