Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Anthony Larcher
sidekit
Commits
90fd5ec9
Commit
90fd5ec9
authored
Nov 05, 2019
by
Anthony Larcher
Browse files
xtractor single
parent
29f13393
Changes
3
Hide whitespace changes
Inline
Side-by-side
__init__.py
View file @
90fd5ec9
...
...
@@ -37,7 +37,7 @@ import importlib
# Read environment variable if it exists
SIDEKIT_CONFIG
=
{
"libsvm"
:
True
,
"mpi"
:
False
,
"cuda"
:
Fals
e
"cuda"
:
Tru
e
}
if
'SIDEKIT'
in
os
.
environ
:
...
...
@@ -165,9 +165,10 @@ if CUDA:
from
sidekit.nnet
import
StatDataset
from
sidekit.nnet
import
Xtractor
from
sidekit.nnet
import
xtrain
from
sidekit.nnet
import
xtrain_single
from
sidekit.nnet
import
extract_idmap
from
sidekit.nnet
import
extract_parallel
from
sidekit.nnet
import
SAD_RNN
#
from sidekit.nnet import SAD_RNN
else
:
print
(
"Don't import Torch"
)
...
...
nnet/__init__.py
View file @
90fd5ec9
...
...
@@ -27,11 +27,11 @@ Copyright 2014-2019 Anthony Larcher and Sylvain Meignier
:mod:`nnet` provides methods to manage Neural Networks using PyTorch
"""
from
sidekit.nnet.sad_rnn
import
SAD_RNN
#
from sidekit.nnet.sad_rnn import SAD_RNN
from
sidekit.nnet.feed_forward
import
FForwardNetwork
from
sidekit.nnet.feed_forward
import
kaldi_to_hdf5
from
sidekit.nnet.xsets
import
XvectorMultiDataset
,
XvectorDataset
,
StatDataset
from
sidekit.nnet.xvector
import
Xtractor
,
xtrain
,
extract_idmap
,
extract_parallel
from
sidekit.nnet.xvector
import
Xtractor
,
xtrain
,
extract_idmap
,
extract_parallel
,
xtrain_single
__author__
=
"Anthony Larcher and Sylvain Meignier"
...
...
nnet/xvector.py
View file @
90fd5ec9
...
...
@@ -469,19 +469,23 @@ def xtrain_single(args):
for
epoch
in
range
(
1
,
args
.
epochs
+
1
):
# Process one epoch and return the current model
model
=
train_epoch_single
(
model
,
epoch
,
args
,
current_model_file_name
)
model
=
train_epoch_single
(
model
,
epoch
,
args
)
# Add the cross validation here
#accuracy = cross_validation_single(args, current_model_file_name)
#print("*** Cross validation accuracy = {} %".format(accuracy))
# Decrease learning rate after every epoch
#
args.lr = args.lr * 0.9
#
args.lr = args.lr * 0.9
#
print(" Decrease learning rate: {}".format(args.lr))
args
.
lr
=
args
.
lr
*
0.9
args
.
lr
=
args
.
lr
*
0.9
print
(
" Decrease learning rate: {}"
.
format
(
args
.
lr
))
# return the file name of the new model
current_model_file_name
=
"{}/model_{}_epoch_{}"
.
format
(
args
.
model_path
,
args
.
expe_id
,
epoch
)
torch
.
save
(
model
,
current_model_file_name
)
def
train_epoch_single
(
model
,
epoch
,
args
,
batch_list
,
output_queue
):
def
train_epoch_single
(
model
,
epoch
,
args
):
"""
:param model:
...
...
@@ -494,6 +498,13 @@ def train_epoch_single(model, epoch, args, batch_list, output_queue):
device
=
device
=
torch
.
device
(
"cuda:0"
)
torch
.
manual_seed
(
args
.
seed
)
# Get the list of batches
print
(
args
.
batch_training_list
)
with
open
(
args
.
batch_training_list
,
'r'
)
as
fh
:
batch_list
=
[
l
.
rstrip
()
for
l
in
fh
]
train_loader
=
XvectorMultiDataset
(
batch_list
,
args
.
batch_path
)
optimizer
=
optim
.
Adam
([{
'params'
:
model
.
frame_conv0
.
parameters
(),
'weight_decay'
:
args
.
l2_frame
},
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment