Ambuj Mehrish / sidekit · Commits

Commit 3cfe03cc authored Mar 15, 2019 by Anthony Larcher
xtractorHot
parent 1b768454

Changes: 1 file

nnet/xvector.py
...
@@ -105,8 +105,7 @@ class Xtractor(torch.nn.Module):
         seg_emb_4 = self.norm7(self.activation(self.seg_lin1(seg_emb_3)))
         # No batch-normalisation after this layer
         seg_emb_5 = self.seg_lin2(seg_emb_4)
-        result = torch.nn.functional.softmax(self.activation(seg_emb_5), dim=1)
-        # return seg_emb_5
+        result = self.activation(seg_emb_5)
         return result

     def LossFN(self, x, label):
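With this change, forward() returns the LeakyReLU activations of the last linear layer rather than softmax probabilities, while LossFN still takes a log of its input; a caller computing the loss therefore has to apply the softmax itself. A minimal sketch of that calling pattern, assuming hypothetical sizes (20-dimensional input features, 100 speakers) that are not taken from the commit:

    import torch

    model = Xtractor(100, 0.25)                  # hypothetical class_number / dropout
    x = torch.randn(8, 20, 300)                  # (batch, feature_dim, frames)
    activations = model(x)                       # forward() no longer applies softmax
    probs = torch.nn.functional.softmax(activations, dim=1)
    one_hot = torch.eye(100)[torch.randint(0, 100, (8,))]   # hypothetical one-hot labels
    loss = model.LossFN(probs, one_hot)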
...
@@ -140,19 +139,102 @@ class Xtractor(torch.nn.Module):
         frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
         frame_emb_3 = self.norm3(self.activation(self.frame_conv3(frame_emb_2)))
         frame_emb_4 = self.norm4(self.activation(self.frame_conv4(frame_emb_3)))

         mean = torch.mean(frame_emb_4, dim=2)
         std = torch.std(frame_emb_4, dim=2)
         seg_emb = torch.cat([mean, std], dim=1)

         embedding_A = self.seg_lin0(seg_emb)
         embedding_B = self.seg_lin1(self.norm6(self.activation(embedding_A)))
         return embedding_A, embedding_B

+class XtractorHot(Xtractor):
+
+    def __init__(self, spk_number, dropout):
+        # note: super(Xtractor, self) deliberately skips Xtractor.__init__
+        # and initialises torch.nn.Module directly, since every layer is
+        # redefined below
+        super(Xtractor, self).__init__()
+        # frame-level TDNN layers (dilated 1-D convolutions over 20-dim features)
+        self.frame_conv0 = torch.nn.Conv1d(20, 512, 5, dilation=1)
+        self.frame_conv1 = torch.nn.Conv1d(512, 512, 3, dilation=2)
+        self.frame_conv2 = torch.nn.Conv1d(512, 512, 3, dilation=3)
+        self.frame_conv3 = torch.nn.Conv1d(512, 512, 1)
+        self.frame_conv4 = torch.nn.Conv1d(512, 3 * 512, 1)
+        # segment-level layers; seg_lin0 takes the concatenated mean and
+        # standard deviation from the pooling layer, hence 3 * 512 * 2 inputs
+        self.seg_lin0 = torch.nn.Linear(3 * 512 * 2, 512)
+        self.dropout_lin0 = torch.nn.Dropout(p=dropout)
+        self.seg_lin1 = torch.nn.Linear(512, 512)
+        self.dropout_lin1 = torch.nn.Dropout(p=dropout)
+        self.seg_lin2 = torch.nn.Linear(512, spk_number)
+        #
+        self.norm0 = torch.nn.BatchNorm1d(512)
+        self.norm1 = torch.nn.BatchNorm1d(512)
+        self.norm2 = torch.nn.BatchNorm1d(512)
+        self.norm3 = torch.nn.BatchNorm1d(512)
+        self.norm4 = torch.nn.BatchNorm1d(3 * 512)
+        self.norm6 = torch.nn.BatchNorm1d(512)
+        self.norm7 = torch.nn.BatchNorm1d(512)
+        #
+        self.activation = torch.nn.LeakyReLU(0.2)
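For orientation (not part of the commit): the dilated convolutions give each output frame a receptive field of 15 input frames, and the mean and standard deviation pooled over the 3 * 512 output channels account for the 3 * 512 * 2 input size of seg_lin0. A hedged shape check with dummy 20-dimensional features and hypothetical sizes:

    import torch

    net = XtractorHot(spk_number=100, dropout=0.25)   # hypothetical sizes
    feats = torch.randn(4, 20, 200)                   # (batch, feature_dim, frames)
    out = net(feats)
    print(out.shape)                                  # torch.Size([4, 100])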
+    def forward(self, x):
+        frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
+        frame_emb_1 = self.norm1(self.activation(self.frame_conv1(frame_emb_0)))
+        frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
+        frame_emb_3 = self.norm3(self.activation(self.frame_conv3(frame_emb_2)))
+        frame_emb_4 = self.norm4(self.activation(self.frame_conv4(frame_emb_3)))
+        # Pooling layer that computes the mean and standard deviation of the
+        # frame-level embeddings; its output is the first segment-level
+        # representation
+        mean = torch.mean(frame_emb_4, dim=2)
+        std = torch.std(frame_emb_4, dim=2)
+        seg_emb_0 = torch.cat([mean, std], dim=1)
+        # batch-normalisation after this layer
+        seg_emb_1 = self.dropout_lin0(seg_emb_0)
+        seg_emb_2 = self.norm6(self.activation(self.seg_lin0(seg_emb_1)))
+        # new layer with batch-normalisation
+        seg_emb_3 = self.dropout_lin1(seg_emb_2)
+        seg_emb_4 = self.norm7(self.activation(self.seg_lin1(seg_emb_3)))
+        # No batch-normalisation after this layer
+        # seg_emb_1 = self.activation(self.seg_lin0(seg_emb_0))
+        seg_emb_5 = self.seg_lin2(seg_emb_4)
+        result = torch.nn.functional.softmax(self.activation(seg_emb_5), dim=1)
+        return result
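The pooling step is what removes the time axis: whatever the number of frames, the mean and standard deviation over dim=2 yield one fixed-size vector per utterance. A standalone sketch of just that operation:

    import torch

    frame_emb = torch.randn(4, 3 * 512, 186)   # (batch, channels, frames)
    mean = torch.mean(frame_emb, dim=2)        # (4, 1536)
    std = torch.std(frame_emb, dim=2)          # (4, 1536)
    seg = torch.cat([mean, std], dim=1)        # (4, 3072), frame count gone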
+    def LossFN(self, x, label):
+        # negative log-likelihood: x holds class probabilities, label is
+        # one-hot; the trace sums the log10-probability of each true class
+        loss = -torch.trace(torch.mm(torch.log10(x), torch.t(label)))
+        return loss
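The trace formulation is a compact negative log-likelihood: with x a (batch, classes) matrix of probabilities and label one-hot, the diagonal of torch.mm(torch.log10(x), torch.t(label)) holds the log10-probability of each sample's true class. A sketch of the equivalence with the more usual elementwise form (sizes hypothetical):

    import torch

    x = torch.nn.functional.softmax(torch.randn(8, 100), dim=1)
    label = torch.eye(100)[torch.randint(0, 100, (8,))]
    loss_trace = -torch.trace(torch.mm(torch.log10(x), torch.t(label)))
    loss_sum = -(torch.log10(x) * label).sum()
    assert torch.allclose(loss_trace, loss_sum)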
+    def init_weights(self):
+        """
+        Initialise the convolutional and linear layers.
+        """
+        torch.nn.init.normal_(self.frame_conv0.weight, mean=-0.5, std=0.1)
+        torch.nn.init.normal_(self.frame_conv1.weight, mean=-0.5, std=0.1)
+        torch.nn.init.normal_(self.frame_conv2.weight, mean=-0.5, std=0.1)
+        torch.nn.init.normal_(self.frame_conv3.weight, mean=-0.5, std=0.1)
+        torch.nn.init.normal_(self.frame_conv4.weight, mean=-0.5, std=0.1)
+        # xavier_uniform and constant are deprecated in favour of the
+        # in-place variants used here
+        torch.nn.init.xavier_uniform_(self.seg_lin0.weight)
+        torch.nn.init.xavier_uniform_(self.seg_lin1.weight)
+        torch.nn.init.xavier_uniform_(self.seg_lin2.weight)
+        torch.nn.init.constant_(self.frame_conv0.bias, 0.1)
+        torch.nn.init.constant_(self.frame_conv1.bias, 0.1)
+        torch.nn.init.constant_(self.frame_conv2.bias, 0.1)
+        torch.nn.init.constant_(self.frame_conv3.bias, 0.1)
+        torch.nn.init.constant_(self.frame_conv4.bias, 0.1)
+        torch.nn.init.constant_(self.seg_lin0.bias, 0.1)
+        torch.nn.init.constant_(self.seg_lin1.bias, 0.1)
+        torch.nn.init.constant_(self.seg_lin2.bias, 0.1)
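init_weights() is not invoked from __init__, so it only takes effect when called explicitly after construction, as in this hedged two-line sketch (sizes hypothetical):

    net = XtractorHot(spk_number=100, dropout=0.25)
    net.init_weights()   # replaces PyTorch's default initialisation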
+    def extract(self, x):
+        frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
+        frame_emb_1 = self.norm1(self.activation(self.frame_conv1(frame_emb_0)))
+        frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
+        frame_emb_3 = self.norm3(self.activation(self.frame_conv3(frame_emb_2)))
+        frame_emb_4 = self.norm4(self.activation(self.frame_conv4(frame_emb_3)))
+        # the pooled statistics must be computed before the segment-level
+        # layers consume them
+        mean = torch.mean(frame_emb_4, dim=2)
+        std = torch.std(frame_emb_4, dim=2)
+        seg_emb = torch.cat([mean, std], dim=1)
+        seg_emb_A = self.seg_lin0(seg_emb)
+        seg_emb_B = self.seg_lin1(self.activation(seg_emb_A))
+        # return torch.nn.functional.softmax(seg_emb_3,dim=1)
+        return seg_emb_A, seg_emb_B

 def xtrain(args):
...
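A hedged usage sketch (names and sizes assumed, not taken from the commit) for pulling the two segment-level embeddings out of a trained XtractorHot:

    import torch

    net = XtractorHot(spk_number=100, dropout=0.25)
    net.eval()                                     # freeze BatchNorm statistics and Dropout
    with torch.no_grad():
        emb_a, emb_b = net.extract(torch.randn(1, 20, 300))
    print(emb_a.shape, emb_b.shape)                # both torch.Size([1, 512])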
@@ -169,11 +251,11 @@ def xtrain(args):
         print("*** Cross validation accuracy = {} %".format(accuracy))

         # Decrease learning rate after every epoch
-        args.lr = args.lr * 0.9
+        # args.lr = args.lr * 0.9


 def xtrain_hot(args):
     # Initialize a first model and save it to disk
-    model = Xtractor(args.class_number, args.dropout)
+    model = XtractorHot(args.class_number, args.dropout)
     current_model_file_name = "initial_model"
     torch.save(model.state_dict(), current_model_file_name)
...
@@ -185,7 +267,7 @@ def xtrain_hot(args):
     print("*** Cross validation accuracy = {} %".format(accuracy))

     # Decrease learning rate after every epoch
-    args.lr = args.lr * 0.9
+    # args.lr = args.lr * 0.9


 def train_epoch(epoch, args, initial_model_file_name):
     # Compute the megabatch number
...
@@ -288,7 +370,7 @@ def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_

 def train_worker_hot(rank, epoch, args, initial_model_file_name, batch_list, output_queue):
-    model = Xtractor(args.class_number, args.dropout)
+    model = XtractorHot(args.class_number, args.dropout)
     model.load_state_dict(torch.load(initial_model_file_name))
     model.train()
...
@@ -410,7 +492,7 @@ def train_asynchronous_hot(epoch, args, initial_model_file_name, batch_file_list
     for p in processes:
         p.join()

-    av_model = Xtractor(args.class_number, args.dropout)
+    av_model = XtractorHot(args.class_number, args.dropout)
     tmp = av_model.state_dict()
     average_param = dict()
...
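The elided lines after this hunk average the workers' parameters into av_model via average_param. The commit's actual loop is not shown; one hedged way such state-dict averaging is commonly written (worker_state_dicts is a hypothetical list of state dicts returned by the workers):

    average_param = {key: sum(sd[key] for sd in worker_state_dicts) / len(worker_state_dicts)
                     for key in worker_state_dicts[0]}
    av_model.load_state_dict(average_param)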
@@ -531,7 +613,7 @@ def cv_worker(rank, args, current_model_file_name, batch_list, output_queue):
     output_queue.put((cv_loader.__len__(), accuracy.cpu().numpy()))


 def cv_worker_hot(rank, args, current_model_file_name, batch_list, output_queue):
-    model = Xtractor(args.class_number, args.dropout)
+    model = XtractorHot(args.class_number, args.dropout)
     model.load_state_dict(torch.load(current_model_file_name))
     model.eval()
...