Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Anthony Larcher
sidekit
Commits
748a7a3e
Commit
748a7a3e
authored
Apr 07, 2021
by
Anthony Larcher
Browse files
cleaning
parent
ae14ad83
Changes
2
Hide whitespace changes
Inline
Side-by-side
nnet/xsets.py
View file @
748a7a3e
...
...
@@ -120,7 +120,6 @@ class SideSampler(torch.utils.data.Sampler):
if
self
.
segment_cursors
[
value
]
>
len
(
self
.
labels_to_indices
[
value
])
-
1
:
random
.
shuffle
(
self
.
labels_to_indices
[
value
])
self
.
segment_cursors
[
value
]
=
0
self
.
index_iterator
[
idx
]
=
self
.
labels_to_indices
[
value
][
self
.
segment_cursors
[
value
]]
self
.
segment_cursors
[
value
]
+=
1
return
iter
(
self
.
index_iterator
)
...
...
nnet/xvector.py
View file @
748a7a3e
...
...
@@ -223,10 +223,10 @@ def test_metrics(model,
[type]: [description]
"""
idmap_test_filename
=
'
h5f/vox1_test_cleaned
_idmap.h5'
ndx_test_filename
=
'
h5f/vox1_test_cleaned
_ndx.h5'
key_test_filename
=
'
h5f/vox1_test_cleaned
_key.h5'
data_root_name
=
'/
hdd/data/vox1/test
/wav'
idmap_test_filename
=
'
/lium/raid01_c/larcher/data/allies_dev_verif
_idmap.h5'
ndx_test_filename
=
'
/lium/raid01_c/larcher/data/allies_dev_verif
_ndx.h5'
key_test_filename
=
'
/lium/raid01_c/larcher/data/allies_dev_verif
_key.h5'
data_root_name
=
'/
lium/corpus/base/ALLIES
/wav'
transform_pipeline
=
dict
()
...
...
@@ -234,10 +234,10 @@ def test_metrics(model,
model_filename
=
model
,
data_root_name
=
data_root_name
,
device
=
device
,
loss
=
"aam"
,
transform_pipeline
=
transform_pipeline
,
num_thread
=
num_thread
,
mixed_precision
=
mixed_precision
,
backward
=
False
)
mixed_precision
=
mixed_precision
)
tar
,
non
=
cosine_scoring
(
xv_stat
,
xv_stat
,
...
...
@@ -891,7 +891,7 @@ class Xtractor(torch.nn.Module):
def
update_training_dictionary
(
dataset_description
,
model_description
,
kwargs
)
kwargs
)
:
"""
speaker_number,
dataset_yaml,
...
...
@@ -1562,7 +1562,7 @@ def xtrain(speaker_number,
First we load the dataframe from CSV file in order to split it for training and validation purpose
Then we provide those two
"""
training_df
,
validation_df
=
train_test_split
(
df
,
test_size
=
dataset_params
[
"validation_ratio"
]
)
#
, stratify=df["speaker_idx"])
training_df
,
validation_df
=
train_test_split
(
df
,
test_size
=
dataset_params
[
"validation_ratio"
]
,
stratify
=
df
[
"speaker_idx"
])
torch
.
manual_seed
(
dataset_params
[
'seed'
])
...
...
@@ -1602,6 +1602,7 @@ def xtrain(speaker_number,
num_workers
=
num_thread
,
persistent_workers
=
False
)
"""
Set the training options
"""
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment