Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Anthony Larcher
sidekit
Commits
d7b14542
Commit
d7b14542
authored
Apr 12, 2021
by
Anthony Larcher
Browse files
debug
parent
6c627622
Changes
1
Hide whitespace changes
Inline
Side-by-side
nnet/xvector.py
View file @
d7b14542
...
...
@@ -1006,7 +1006,7 @@ def update_training_dictionary(dataset_description,
return
dataset_opts
,
model_opts
,
training_opts
def
get_network
(
model_opts
):
def
get_network
(
model_opts
,
local_rank
):
"""
:param model_opts:
...
...
@@ -1041,18 +1041,19 @@ def get_network(model_opts):
if
name
.
split
(
"."
)[
0
]
in
model_opts
[
"reset_parts"
]:
param
.
requires_grad
=
False
logging
.
critical
(
model
)
logging
.
critical
(
"model_parameters_count: {:d}"
.
format
(
sum
(
p
.
numel
()
for
p
in
model
.
sequence_network
.
parameters
()
if
p
.
requires_grad
)
+
\
sum
(
p
.
numel
()
for
p
in
model
.
before_speaker_embedding
.
parameters
()
if
p
.
requires_grad
)
+
\
sum
(
p
.
numel
()
for
p
in
model
.
stat_pooling
.
parameters
()
if
p
.
requires_grad
)))
if
local_rank
<
1
:
logging
.
info
(
model
)
logging
.
info
(
"Model_parameters_count: {:d}"
.
format
(
sum
(
p
.
numel
()
for
p
in
model
.
sequence_network
.
parameters
()
if
p
.
requires_grad
)
+
\
sum
(
p
.
numel
()
for
p
in
model
.
before_speaker_embedding
.
parameters
()
if
p
.
requires_grad
)
+
\
sum
(
p
.
numel
()
for
p
in
model
.
stat_pooling
.
parameters
()
if
p
.
requires_grad
)))
return
model
...
...
@@ -1306,13 +1307,17 @@ def new_xtrain(dataset_description,
# Initialize the model
model
=
get_network
(
model_opts
)
speaker_number
=
model
.
speaker_number
#
speaker_number = model.speaker_number
embedding_size
=
model
.
embedding_size
# Set the device and manage parallel processing
#device = torch.cuda.device(local_rank)
torch
.
cuda
.
set_device
(
local_rank
)
device
=
torch
.
device
(
local_rank
)
if
training_opts
[
"multi_gpu"
]:
model
=
torch
.
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
model
)
model
.
to
(
device
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment