Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Ambuj Mehrish
sidekit
Commits
83a75e72
Commit
83a75e72
authored
Sep 25, 2020
by
Anthony Larcher
Browse files
new interface xtrain
parent
48f2656a
Changes
1
Hide whitespace changes
Inline
Side-by-side
nnet/xvector.py
View file @
83a75e72
...
...
@@ -623,20 +623,47 @@ def xtrain(speaker_number,
device
=
torch
.
device
(
"cuda:0"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
# Start from scratch
if
model_name
is
None
:
if
model_name
is
None
and
model_yaml
in
[
"xvector"
,
"rawnet2"
]
:
# Initialize a first model
if
model_yaml
==
"xvector"
:
model
=
Xtractor
(
speaker_number
,
"xvector"
)
elif
model_yaml
==
"rawnet2"
:
model
=
Xtractor
(
speaker_number
,
"rawnet2"
)
else
:
with
open
(
model_yaml
,
'r'
)
as
fh
:
model_archi
=
yaml
.
load
(
fh
,
Loader
=
yaml
.
FullLoader
)
if
epochs
is
None
:
epochs
=
model_archi
[
"training"
][
"epochs"
]
if
patience
is
None
:
patience
=
model_archi
[
"training"
][
"patience"
]
if
opt
is
None
:
opt
=
model_archi
[
"training"
][
"opt"
]
if
lr
is
None
:
lr
=
model_archi
[
"training"
][
"lr"
]
if
loss
is
None
:
loss
=
model_archi
[
"training"
][
"loss"
]
if
aam_margin
is
None
and
model_archi
[
"training"
][
"loss"
]
==
"aam"
:
aam_margin
=
model_archi
[
"training"
][
"aam_margin"
]
if
aam_s
is
None
and
model_archi
[
"training"
][
"loss"
]
==
"aam"
:
aam_s
=
model_archi
[
"training"
][
"aam_s"
]
if
tmp_model_name
is
None
:
tmp_model_name
=
model_archi
[
"training"
][
"tmp_model_name"
]
if
best_model_name
is
None
:
best_model_name
=
model_archi
[
"training"
][
"best_model_name"
]
if
multi_gpu
is
None
:
multi_gpu
=
model_archi
[
"training"
][
"multi_gpu"
]
if
clipping
is
None
:
clipping
=
model_archi
[
"training"
][
"clipping"
]
if
model_name
is
None
model
=
Xtractor
(
speaker_number
,
model_yaml
)
# If we start from an existing model
else
:
# Load the model
logging
.
critical
(
f
"*** Load model from =
{
model_name
}
"
)
checkpoint
=
torch
.
load
(
model_name
)
model
=
Xtractor
(
speaker_number
,
model_yaml
)
# If we start from an existing model
else
:
# Load the model
logging
.
critical
(
f
"*** Load model from =
{
model_name
}
"
)
checkpoint
=
torch
.
load
(
model_name
)
model
=
Xtractor
(
speaker_number
,
model_yaml
)
"""
Here we remove all layers that we don't want to reload
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment