Anthony Larcher / sidekit / Commits

Commit 9ab7a6b8, authored Jul 29, 2021 by Le Lan Gaël
Parent: 3de11dde

    local_rank fix

3 changed files
nnet/augmentation.py
@@ -164,6 +164,9 @@ def data_augmentation(speech,
     aug_idx = random.sample(range(len(transform_dict.keys())), k=transform_number)
     augmentations = numpy.array(list(transform_dict.keys()))[aug_idx]
 
+    if "none" in augmentations:
+        pass
+
     if "stretch" in augmentations:
         stretched_length = int(speech.shape[1] * random.uniform(0.95, 1.05))
         speech = torch.zeros_like(speech)
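For context, the selection logic in this hunk can be exercised on its own. A minimal sketch (the transform_dict contents here are hypothetical stand-ins, not sidekit's full augmentation table):

import random
import numpy

# Hypothetical transform table; sidekit's real dict maps names to parameters.
transform_dict = {"none": {}, "stretch": {}, "add_noise": {}, "add_reverb": {}}
transform_number = 2

# Draw transform_number distinct augmentation names, as data_augmentation() does.
aug_idx = random.sample(range(len(transform_dict.keys())), k=transform_number)
augmentations = numpy.array(list(transform_dict.keys()))[aug_idx]
print(augmentations)  # e.g. ['stretch' 'none']; "none" is handled as an explicit no-op

The added "none" branch makes the no-op case explicit, so drawing "none" consumes one of the transform_number slots instead of falling through silently.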
nnet/xsets.py
@@ -177,7 +177,7 @@ class SideSet(Dataset):
                  overlap=0.,
                  dataset_df=None,
                  min_duration=0.165,
-                 output_format="pytorch",
+                 output_format="pytorch"):
         """
nnet/xvector.py
@@ -536,12 +536,13 @@ class Xtractor(torch.nn.Module):
                                                                  m=0.20,
                                                                  easy_margin=False)
             elif self.loss == 'aps':
-                self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number))
+                self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number),
+                                                                   emb_dim=self.embedding_size)
 
             self.preprocessor_weight_decay = 0.00002
             self.sequence_network_weight_decay = 0.00002
             self.stat_pooling_weight_decay = 0.00002
             self.before_speaker_embedding_weight_decay = 0.00002
-            self.after_speaker_embedding_weight_decay = 0 #0.0002
+            self.after_speaker_embedding_weight_decay = 0.0002
 
         elif model_archi == "rawnet2":
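The per-block *_weight_decay attributes set above (including the 0.0002 this hunk re-enables on the loss head) only take effect when the optimizer is built with one parameter group per sub-network. A minimal sketch of that pattern, with toy modules standing in for the real sub-networks (illustrative only, not the actual Xtractor/get_optimizer code):

import torch

backbone = torch.nn.Linear(64, 32)   # stand-in for e.g. sequence_network
head = torch.nn.Linear(32, 10)       # stand-in for after_speaker_embedding

param_groups = [
    {"params": backbone.parameters(), "weight_decay": 0.00002},
    {"params": head.parameters(), "weight_decay": 0.0002},  # value this commit restores
]
optimizer = torch.optim.Adam(param_groups, lr=1e-3)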
@@ -1094,7 +1095,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
                            transform_number=dataset_opts['train']['transform_number'],
                            overlap=dataset_opts['train']['overlap'],
                            dataset_df=training_df,
-                           output_format="pytorch",
+                           output_format="pytorch")
 
     validation_set = SideSet(dataset_opts,
@@ -1109,15 +1110,14 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
         samples_per_speaker = 1
 
-    if training_opts["multi_gpu"]:
-        assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0
-        assert dataset_opts["batch_size"] % samples_per_speaker == 0
-        batch_size = dataset_opts["batch_size"] // (torch.cuda.device_count() * dataset_opts["train"]["sampler"]["examples_per_speaker"])
+    #assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0
+    #assert dataset_opts["batch_size"] % samples_per_speaker == 0
 
     side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
                                spk_count=model_opts["speaker_number"],
                                examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
                                samples_per_speaker=dataset_opts["train"]["sampler"]["samples_per_speaker"],
-                               batch_size=batch_size,
+                               batch_size=dataset_opts["batch_size"] * torch.cuda.device_count(),
                                seed=training_opts['torch_seed'],
                                rank=local_rank,
                                num_process=torch.cuda.device_count(),
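A quick sanity check of the new batch_size argument, with hypothetical numbers and assuming SideSampler treats batch_size as the global batch across processes:

# Illustrative values only:
per_process_batch = 128      # dataset_opts["batch_size"]
gpus = 4                     # torch.cuda.device_count()
examples_per_speaker = 2

global_batch = per_process_batch * gpus                           # 512 indices scheduled per step
speakers_per_process = per_process_batch // examples_per_speaker  # 64 distinct speakers per GPU batch

The multi-GPU divisibility asserts are commented out because the per-GPU split is now derived directly from device_count() rather than precomputed into a local batch_size variable.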
@@ -1131,7 +1131,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
                                          batch_size=dataset_opts["batch_size"],
                                          seed=training_opts['torch_seed'],
                                          rank=0,
-                                         num_process=torch.cuda.device_count(),
+                                         num_process=1,
                                          num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
                                          )
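Pinning the validation sampler to rank=0 with num_process=1 keeps the validation trial list identical on every process. Assuming a rank-strided partition (an illustration of the usual pattern, not SideSampler's actual code):

indices = list(range(8))
print(indices[0::1])   # rank=0, num_process=1 -> [0, 1, 2, 3, 4, 5, 6, 7] (full list)
print(indices[1::4])   # rank=1, num_process=4 -> [1, 5] (a per-GPU shard)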
@@ -1221,7 +1221,7 @@ def get_optimizer(model, model_opts, train_opts, training_loader):
         scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
                                                       base_lr=1e-8,
                                                       max_lr=train_opts["lr"],
-                                                      step_size_up=training_loader.__len__() * 16,
+                                                      step_size_up=training_loader.__len__() * 10,
                                                       step_size_down=None,
                                                       cycle_momentum=cycle_momentum,
                                                       mode="triangular2")
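The schedule itself is standard PyTorch. With step_size_up = len(training_loader) * 10, the learning rate now climbs from base_lr to max_lr over 10 epochs' worth of batches instead of 16, then descends symmetrically, and "triangular2" halves the peak after each full cycle. A self-contained sketch with placeholder values (model, optimizer, and steps_per_epoch are illustrative):

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
steps_per_epoch = 100  # stand-in for len(training_loader)

scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
                                              base_lr=1e-8,
                                              max_lr=1e-3,
                                              step_size_up=steps_per_epoch * 10,
                                              step_size_down=None,
                                              cycle_momentum=False,
                                              mode="triangular2")

for batch in range(steps_per_epoch):
    optimizer.step()   # forward/backward omitted
    scheduler.step()   # advance the cyclic schedule once per batch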
@@ -1427,7 +1427,8 @@ def xtrain(dataset_description,
     training_loader, validation_loader,\
         sampler, validation_tar_indices, validation_non_indices = get_loaders(dataset_opts,
                                                                               training_opts,
-                                                                              model_opts)
+                                                                              model_opts,
+                                                                              local_rank)
 
     if local_rank < 1:
         monitor.logger.info(f"Start training process")
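This last hunk is the fix named in the commit message: get_loaders() now receives the caller's local_rank instead of silently using its default of 0, so the training sampler shards data by the true process rank while only rank 0 logs. For reference, a common way each DDP process obtains its local rank (a generic sketch, not sidekit-specific):

import os

# torchrun / torch.distributed.launch export LOCAL_RANK per process;
# only the process with local_rank == 0 should log or write checkpoints.
local_rank = int(os.environ.get("LOCAL_RANK", 0))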