Project: Ambuj Mehrish / sidekit
Commit 43927406, authored Mar 05, 2021 by Anthony Larcher

    resnet

Parent: 0ff214ca
Changes: 2 files
nnet/res_net.py
@@ -420,10 +420,10 @@ class BasicBlock(torch.nn.Module):
         )

     def forward(self, x):
-        out = F.relu(self.bn1(self.conv1(x)))
+        out = torch.nn.functional.relu(self.bn1(self.conv1(x)))
         out = self.bn2(self.conv2(out))
         out += self.shortcut(x)
-        out = F.relu(out)
+        out = torch.nn.functional.relu(out)
         return out
...
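Both hunks in this file make the same mechanical change: the `F.relu` alias is replaced by the fully qualified `torch.nn.functional.relu`, removing the dependency on an `import torch.nn.functional as F` binding; behaviour is unchanged. For context, a minimal runnable sketch of the residual pattern this forward pass implements. Layer names (conv1, bn1, conv2, bn2, shortcut) follow the diff; kernel sizes and channel counts are assumptions:

    import torch

    class TinyBasicBlock(torch.nn.Module):
        """conv-bn-relu, conv-bn, shortcut add, final relu."""

        def __init__(self, in_planes, planes, stride=1):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(in_planes, planes, 3, stride=stride, padding=1, bias=False)
            self.bn1 = torch.nn.BatchNorm2d(planes)
            self.conv2 = torch.nn.Conv2d(planes, planes, 3, stride=1, padding=1, bias=False)
            self.bn2 = torch.nn.BatchNorm2d(planes)
            # identity shortcut unless the spatial size or channel count changes
            self.shortcut = torch.nn.Sequential()
            if stride != 1 or in_planes != planes:
                self.shortcut = torch.nn.Sequential(
                    torch.nn.Conv2d(in_planes, planes, 1, stride=stride, bias=False),
                    torch.nn.BatchNorm2d(planes))

        def forward(self, x):
            out = torch.nn.functional.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            out += self.shortcut(x)
            return torch.nn.functional.relu(out)

    y = TinyBasicBlock(16, 32, stride=2)(torch.randn(4, 16, 40, 100))
    print(y.shape)  # torch.Size([4, 32, 20, 50])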
@@ -488,7 +488,7 @@ class ResNet(torch.nn.Module):
         return torch.nn.Sequential(*layers)

     def forward(self, x):
-        out = F.relu(self.bn1(self.conv1(x)))
+        out = torch.nn.functional.relu(self.bn1(self.conv1(x)))
         out = self.layer1(out)
         out = self.layer2(out)
         out = self.layer3(out)
...
@@ -534,7 +534,8 @@ class PreResNet34(torch.nn.Module):
         return torch.nn.Sequential(*layers)

     def forward(self, x):
-        out = F.relu(self.bn1(self.conv1(x)))
+        out = x.unsqueeze(1)
+        out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
         out = self.layer1(out)
         out = self.layer2(out)
         out = self.layer3(out)
...
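The new `unsqueeze(1)` line is the substantive change in this hunk: PreResNet34 now takes a batch of 2-D feature matrices and inserts the singleton channel axis that its first Conv2d expects, so `conv1` is applied to `out` rather than `x`. A minimal shape check, with the feature layout assumed:

    import torch

    x = torch.randn(4, 30, 200)             # (batch, freq_bins, frames), layout assumed
    out = x.unsqueeze(1)                    # (4, 1, 30, 200): channel axis for Conv2d
    conv1 = torch.nn.Conv2d(1, 16, kernel_size=3, padding=1)
    print(conv1(out).shape)                 # torch.Size([4, 16, 30, 200])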
nnet/xvector.py
@@ -344,6 +344,8 @@ class Xtractor(torch.nn.Module):
                 ("linear6", torch.nn.Linear(3072, 512))
             ]))

             self.embedding_size = 512

+            if self.loss == "aam":
+                self.after_speaker_embedding = ArcMarginProduct(512, int(self.speaker_number),
...
@@ -375,21 +377,17 @@ class Xtractor(torch.nn.Module):
             self.stat_pooling = MeanStdPooling()
             self.stat_pooling_weight_decay = 0

             if loss not in ["cce", 'aam']:
                 raise NotImplementedError(f"The valid loss are for now cce and aam ")
             else:
                 self.loss = loss

-            self.loss = "aam"
-            if self.loss == "aam":
-                self.after_speaker_embedding = ArcLinear(256,
-                                                         int(self.speaker_number),
-                                                         margin=aam_margin, s=aam_s)
+            if loss == 'aam':
+                self.after_speaker_embedding = ArcMarginProduct(256,
+                                                                int(self.speaker_number),
+                                                                s=64, m=0.2, easy_margin=True)
             elif self.loss == "cce":
                 self.after_speaker_embedding = torch.nn.Linear(in_features=256,
                                                                out_features=int(self.speaker_number),
                                                                bias=True)

             self.embedding_size = 256

             self.preprocessor_weight_decay = 0.000
             self.sequence_network_weight_decay = 0.000
             self.stat_pooling_weight_decay = 0.000
...
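This hunk removes the line that forced `self.loss = "aam"` and swaps the ArcLinear head for ArcMarginProduct with fixed hyper-parameters (s=64, m=0.2, easy_margin=True). For readers unfamiliar with additive angular margin (AAM) softmax, below is a minimal self-contained sketch of what an ArcMarginProduct-style layer computes; the class name AAMHead is hypothetical and sidekit's actual implementation may differ in detail:

    import math
    import torch

    class AAMHead(torch.nn.Module):
        """AAM-softmax: scaled cosine logits with an additive angular
        margin of m on the target class (ArcFace-style)."""

        def __init__(self, in_features, out_features, s=64.0, m=0.2, easy_margin=True):
            super().__init__()
            self.s, self.easy_margin = s, easy_margin
            self.cos_m, self.sin_m = math.cos(m), math.sin(m)
            self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
            torch.nn.init.xavier_uniform_(self.weight)

        def forward(self, x, target=None):
            # cosine between L2-normalised embeddings and class centres
            cosine = torch.nn.functional.linear(torch.nn.functional.normalize(x),
                                                torch.nn.functional.normalize(self.weight))
            if target is None:                      # evaluation: plain scaled cosines
                return self.s * cosine
            sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(0, 1))
            phi = cosine * self.cos_m - sine * self.sin_m        # cos(theta + m)
            if self.easy_margin:                    # only penalise when cos(theta) > 0
                phi = torch.where(cosine > 0, phi, cosine)
            one_hot = torch.zeros_like(cosine).scatter_(1, target.view(-1, 1), 1.0)
            return self.s * (one_hot * phi + (1.0 - one_hot) * cosine)

    head = AAMHead(256, 1000)                       # 256-dim embeddings, 1000 speakers
    x, target = torch.randn(8, 256), torch.randint(0, 1000, (8,))
    loss = torch.nn.functional.cross_entropy(head(x, target), target)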
@@ -750,12 +748,14 @@ def xtrain(speaker_number,
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

     # Start from scratch
-    if model_name is None and model_yaml in ["xvector", "rawnet2"]:
+    if model_name is None and model_yaml in ["xvector", "rawnet2", "resnet34"]:
         # Initialize a first model
         if model_yaml == "xvector":
             model = Xtractor(speaker_number, "xvector", loss=loss)
         elif model_yaml == "rawnet2":
             model = Xtractor(speaker_number, "rawnet2")
+        elif model_yaml == "resnet34":
+            model = Xtractor(speaker_number, "resnet34")
         model_archi = model_yaml
     else:
         with open(model_yaml, 'r') as fh:
...
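With this change, "resnet34" joins the architecture names that xtrain can build from scratch. A hypothetical call mirroring the dispatch above (the import path is an assumption; the positional arguments follow the diff):

    from sidekit.nnet.xvector import Xtractor   # module path assumed

    model = Xtractor(1000, "resnet34")          # speaker_number, then architecture name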
@@ -988,6 +988,11 @@ def xtrain(speaker_number,
         is_best = val_acc > best_accuracy
         best_accuracy = max(val_acc, best_accuracy)

+        if tmp_model_name is None:
+            tmp_model_name = "tmp_model"
+        if best_model_name is None:
+            best_model_name = "best_model"
+
         if type(model) is Xtractor:
             save_checkpoint({
                 'epoch': epoch,
...
@@ -1153,6 +1158,131 @@ def cross_validation(model, validation_loader, device, validation_shape):
     return loss.cpu().numpy() / ((batch_idx + 1) * batch_size), equal_error_rate


+class XtractorTop(torch.nn.Module):
+
+    def __init__(self, model_filename, loss=None, aam_margin=None, aam_s=None):
+        """
+        :param model_filename:
+        :param loss:
+        :param aam_margin:
+        :param aam_s:
+        """
+        super(XtractorTop, self).__init__()
+
+        # Load the model and only use the last part of it (not to use sequence_network)
+        checkpoint = torch.load(model_filename, map_location='cpu')
+        cfg = checkpoint["model_archi"]
+        self.speaker_number = checkpoint["speaker_number"]
+
+        # Get activation function
+        if cfg["activation"] == 'LeakyReLU':
+            self.activation = torch.nn.LeakyReLU(0.2)
+        elif cfg["activation"] == 'PReLU':
+            self.activation = torch.nn.PReLU()
+        elif cfg["activation"] == 'ReLU6':
+            self.activation = torch.nn.ReLU6()
+        else:
+            self.activation = torch.nn.ReLU()
+
+        model_layers = []
+        for k in cfg["before_embedding"].keys():
+            if k.startswith("lin"):
+                input_size = checkpoint["model_state_dict"]["before_speaker_embedding." + k + ".weight"].shape[1]
+                output_size = checkpoint["model_state_dict"]["before_speaker_embedding." + k + ".weight"].shape[0]
+                model_layers.append((k, torch.nn.Linear(input_size, output_size)))
+            elif k.startswith("activation"):
+                model_layers.append((k, self.activation))
+            elif k.startswith('batch_norm'):
+                model_layers.append((k, torch.nn.BatchNorm1d(input_size)))
+            elif k.startswith('dropout'):
+                model_layers.append((k, torch.nn.Dropout(p=cfg["before_embedding"][k])))
+        self.before_speaker_embedding = torch.nn.Sequential(OrderedDict(model_layers))
+
+        # if loss_criteria is "cce"
+        # Create sequential object for the second part of the network
+        if checkpoint["model_archi"]["training"]["loss"] == "cce":
+            for k in cfg["after_embedding"].keys():
+                if k.startswith("lin"):
+                    if cfg["after_embedding"][k]["output"] == "speaker_number":
+                        model_layers.append((k, torch.nn.Linear(input_size, self.speaker_number)))
+                    else:
+                        model_layers.append((k, torch.nn.Linear(input_size, cfg["after_embedding"][k]["output"])))
+                        input_size = cfg["after_embedding"][k]["output"]
+                elif k.startswith('arc'):
+                    model_layers.append((k, ArcLinear(output_size, self.speaker_number, margin=aam_margin, s=aam_s)))
+                elif k.startswith("activation"):
+                    model_layers.append((k, self.activation))
+                elif k.startswith('batch_norm'):
+                    model_layers.append((k, torch.nn.BatchNorm1d(input_size)))
+                elif k.startswith('dropout'):
+                    model_layers.append((k, torch.nn.Dropout(p=cfg["after_embedding"][k])))
+            model = torch.nn.Sequential(OrderedDict(model_layers))
+
+        elif checkpoint["model_archi"]["training"]["loss"] == "aam":
+            self.after_speaker_embedding = ArcMarginProduct(output_size,
+                                                            int(self.speaker_number),
+                                                            s=64, m=0.2, easy_margin=True)
+
+        # Now load layers from the file
+        new_model_dict = self.state_dict()
+        pretrained_dict = checkpoint["model_state_dict"]
+        for k, v in pretrained_dict.items():
+            if k in new_model_dict:
+                new_model_dict[k] = v
+        self.load_state_dict(new_model_dict)
+
+    def forward(self, x, is_eval=False, target=None,):
+        """
+        :param x:
+        :param is_eval: False for training
+        :return:
+        """
+        x = self.before_speaker_embedding(x)
+
+        if self.loss == "cce":
+            if is_eval:
+                return self.after_speaker_embedding(x), x
+            else:
+                return self.after_speaker_embedding(x)
+
+        elif self.loss == "aam":
+            if not is_eval:
+                x = self.after_speaker_embedding(l2_norm(x), target=target), l2_norm(x)
+            else:
+                x = self.after_speaker_embedding(l2_norm(x), target=None), l2_norm(x)
+
+        return x
+
+
 def extract_embeddings(idmap_name,
                        model_filename,
                        data_root_name,
...
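The bulk of this commit is the new XtractorTop class: it rebuilds only the post-pooling layers described in a saved Xtractor checkpoint's "model_archi" config, then copies over whichever pretrained weights have matching keys. That selective-loading idiom is reusable on its own; a minimal sketch with generic names:

    import torch

    # a fresh module whose state_dict keys partially overlap the checkpoint's
    new_model = torch.nn.Sequential(torch.nn.Linear(10, 4))

    checkpoint = {"model_state_dict": {"0.weight": torch.zeros(4, 10),
                                       "0.bias": torch.zeros(4),
                                       "unused.weight": torch.ones(2, 2)}}

    new_dict = new_model.state_dict()
    for k, v in checkpoint["model_state_dict"].items():
        if k in new_dict:
            new_dict[k] = v                 # keep only matching pretrained entries
    new_model.load_state_dict(new_dict)     # "unused.weight" is simply ignored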