Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Anthony Larcher
sidekit
Commits
8ccb5265
Commit
8ccb5265
authored
Apr 13, 2021
by
Anthony Larcher
Browse files
track bug
parent
80656610
Changes
2
Hide whitespace changes
Inline
Side-by-side
nnet/xsets.py
View file @
8ccb5265
...
@@ -62,7 +62,6 @@ class SideSampler(torch.utils.data.Sampler):
...
@@ -62,7 +62,6 @@ class SideSampler(torch.utils.data.Sampler):
batch_size
,
batch_size
,
seed
=
0
,
seed
=
0
,
rank
=
0
,
rank
=
0
,
num_process
=
1
,
num_replicas
=
1
):
num_replicas
=
1
):
"""[summary]
"""[summary]
...
@@ -82,33 +81,29 @@ class SideSampler(torch.utils.data.Sampler):
...
@@ -82,33 +81,29 @@ class SideSampler(torch.utils.data.Sampler):
self
.
epoch
=
0
self
.
epoch
=
0
self
.
seed
=
seed
self
.
seed
=
seed
self
.
rank
=
rank
self
.
rank
=
rank
self
.
num_process
=
num_process
self
.
num_replicas
=
num_replicas
self
.
num_replicas
=
num_replicas
assert
batch_size
%
examples_per_speaker
==
0
assert
batch_size
%
examples_per_speaker
==
0
assert
(
self
.
samples_per_speaker
*
self
.
spk_count
*
self
.
examples_per_speaker
)
%
self
.
num_
proces
s
==
0
assert
(
self
.
samples_per_speaker
*
self
.
spk_count
*
self
.
examples_per_speaker
)
%
self
.
num_
replica
s
==
0
self
.
batch_size
=
batch_size
//
(
self
.
examples_per_speaker
*
self
.
num_replicas
)
self
.
batch_size
=
batch_size
//
examples_per_speaker
# reference all segment indexes per speaker
# reference all segment indexes per speaker
for
idx
in
range
(
self
.
spk_count
):
for
idx
in
range
(
self
.
spk_count
):
self
.
labels_to_indices
[
idx
]
=
list
()
self
.
labels_to_indices
[
idx
]
=
list
()
for
idx
,
value
in
enumerate
(
self
.
train_sessions
):
for
idx
,
value
in
enumerate
(
self
.
train_sessions
):
self
.
labels_to_indices
[
value
].
append
(
idx
)
self
.
labels_to_indices
[
value
].
append
(
idx
)
# s
h
uffle segments per speaker
# suffle segments per speaker
g
=
torch
.
Generator
()
g
=
torch
.
Generator
()
g
.
manual_seed
(
self
.
seed
+
self
.
epoch
)
g
.
manual_seed
(
self
.
seed
+
self
.
epoch
)
for
idx
,
ldlist
in
enumerate
(
self
.
labels_to_indices
.
values
()):
for
idx
,
ldlist
in
enumerate
(
self
.
labels_to_indices
.
values
()):
ldlist
=
numpy
.
array
(
ldlist
)
ldlist
=
numpy
.
array
(
ldlist
)
self
.
labels_to_indices
[
idx
]
=
ldlist
[
torch
.
randperm
(
ldlist
.
shape
[
0
]
,
generator
=
g
).
numpy
()]
self
.
labels_to_indices
[
idx
]
=
ldlist
[
torch
.
randperm
(
ldlist
.
shape
[
0
]).
numpy
()]
self
.
segment_cursors
=
numpy
.
zeros
((
len
(
self
.
labels_to_indices
),),
dtype
=
numpy
.
int
)
self
.
segment_cursors
=
numpy
.
zeros
((
len
(
self
.
labels_to_indices
),),
dtype
=
numpy
.
int
)
def
__iter__
(
self
):
g
=
torch
.
Generator
()
g
.
manual_seed
(
self
.
seed
+
self
.
epoch
)
numpy
.
random
.
seed
(
self
.
seed
+
self
.
epoch
)
def
__iter__
(
self
):
# Generate batches per speaker
# Generate batches per speaker
straight
=
numpy
.
arange
(
self
.
spk_count
)
straight
=
numpy
.
arange
(
self
.
spk_count
)
indices
=
numpy
.
ones
((
self
.
samples_per_speaker
,
self
.
spk_count
),
dtype
=
numpy
.
int
)
*
straight
indices
=
numpy
.
ones
((
self
.
samples_per_speaker
,
self
.
spk_count
),
dtype
=
numpy
.
int
)
*
straight
...
@@ -139,6 +134,9 @@ class SideSampler(torch.utils.data.Sampler):
...
@@ -139,6 +134,9 @@ class SideSampler(torch.utils.data.Sampler):
# we want to convert the speaker indexes into segment indexes
# we want to convert the speaker indexes into segment indexes
self
.
index_iterator
=
numpy
.
zeros_like
(
batch_matrix
)
self
.
index_iterator
=
numpy
.
zeros_like
(
batch_matrix
)
g
=
torch
.
Generator
()
g
.
manual_seed
(
self
.
seed
+
self
.
epoch
)
# keep track of next segment index to sample for each speaker
# keep track of next segment index to sample for each speaker
for
idx
,
value
in
enumerate
(
batch_matrix
):
for
idx
,
value
in
enumerate
(
batch_matrix
):
if
self
.
segment_cursors
[
value
]
>
len
(
self
.
labels_to_indices
[
value
])
-
1
:
if
self
.
segment_cursors
[
value
]
>
len
(
self
.
labels_to_indices
[
value
])
-
1
:
...
@@ -146,15 +144,13 @@ class SideSampler(torch.utils.data.Sampler):
...
@@ -146,15 +144,13 @@ class SideSampler(torch.utils.data.Sampler):
self
.
segment_cursors
[
value
]
=
0
self
.
segment_cursors
[
value
]
=
0
self
.
index_iterator
[
idx
]
=
self
.
labels_to_indices
[
value
][
self
.
segment_cursors
[
value
]]
self
.
index_iterator
[
idx
]
=
self
.
labels_to_indices
[
value
][
self
.
segment_cursors
[
value
]]
self
.
segment_cursors
[
value
]
+=
1
self
.
segment_cursors
[
value
]
+=
1
self
.
index_iterator
=
self
.
index_iterator
.
reshape
(
-
1
,
self
.
num_replicas
*
self
.
examples_per_speaker
)[:,
self
.
rank
*
self
.
examples_per_speaker
:(
self
.
rank
+
1
)
*
self
.
examples_per_speaker
].
flatten
()
self
.
index_iterator
=
numpy
.
repeat
(
self
.
index_iterator
,
self
.
num_replicas
)
self
.
index_iterator
=
self
.
index_iterator
.
reshape
(
-
1
,
self
.
num_process
*
self
.
examples_per_speaker
*
self
.
num_replicas
)[:,
self
.
rank
*
self
.
examples_per_speaker
*
self
.
num_replicas
:(
self
.
rank
+
1
)
*
self
.
examples_per_speaker
*
self
.
num_replicas
].
flatten
()
return
iter
(
self
.
index_iterator
)
return
iter
(
self
.
index_iterator
)
def
__len__
(
self
)
->
int
:
def
__len__
(
self
)
->
int
:
return
(
self
.
samples_per_speaker
*
self
.
spk_count
*
self
.
examples_per_speaker
*
self
.
num_replicas
)
//
self
.
num_process
return
(
self
.
samples_per_speaker
*
self
.
spk_count
*
self
.
examples_per_speaker
)
//
self
.
num_replicas
def
set_epoch
(
self
,
epoch
:
int
)
->
None
:
def
set_epoch
(
self
,
epoch
:
int
)
->
None
:
self
.
epoch
=
epoch
self
.
epoch
=
epoch
...
...
nnet/xvector.py
View file @
8ccb5265
...
@@ -1559,9 +1559,11 @@ def extract_embeddings(idmap_name,
...
@@ -1559,9 +1559,11 @@ def extract_embeddings(idmap_name,
device
,
device
,
file_extension
=
"wav"
,
file_extension
=
"wav"
,
transform_pipeline
=
""
,
transform_pipeline
=
""
,
frame_shift
=
1.5
,
sliding_window
=
False
,
frame_duration
=
3.
,
win_duration
=
3.
,
win_shift
=
1.5
,
num_thread
=
1
,
num_thread
=
1
,
sample_rate
=
16000
,
mixed_precision
=
False
):
mixed_precision
=
False
):
"""
"""
...
@@ -1569,14 +1571,13 @@ def extract_embeddings(idmap_name,
...
@@ -1569,14 +1571,13 @@ def extract_embeddings(idmap_name,
:param model_filename:
:param model_filename:
:param data_root_name:
:param data_root_name:
:param device:
:param device:
:param model_yaml:
:param speaker_number:
:param file_extension:
:param file_extension:
:param transform_pipeline:
:param transform_pipeline:
:param
frame_shift
:
:param
sliding_window
:
:param
frame
_duration:
:param
win
_duration:
:param
extract_after_pooling
:
:param
win_shift
:
:param num_thread:
:param num_thread:
:param sample_rate:
:param mixed_precision:
:param mixed_precision:
:return:
:return:
"""
"""
...
@@ -1595,21 +1596,27 @@ def extract_embeddings(idmap_name,
...
@@ -1595,21 +1596,27 @@ def extract_embeddings(idmap_name,
else
:
else
:
idmap
=
IdMap
(
idmap_name
)
idmap
=
IdMap
(
idmap_name
)
if
type
(
model
)
is
Xtractor
:
#
if type(model) is Xtractor:
min_duration
=
(
model
.
context_size
()
-
1
)
*
frame
_shift
+
frame
_duration
#
min_duration = (model.context_size() - 1) *
win
_shift +
win
_duration
model_cs
=
model
.
context_size
()
#
model_cs = model.context_size()
else
:
#
else:
min_duration
=
(
model
.
module
.
context_size
()
-
1
)
*
frame
_shift
+
frame
_duration
#
min_duration = (model.module.context_size() - 1) *
win
_shift +
win
_duration
model_cs
=
model
.
module
.
context_size
()
#
model_cs = model.module.context_size()
# Create dataset to load the data
# Create dataset to load the data
dataset
=
IdMapSet
(
idmap_name
=
idmap_name
,
dataset
=
IdMapSet
(
idmap_name
=
idmap_name
,
data_path
=
data_root_name
,
data_path
=
data_root_name
,
file_extension
=
file_extension
,
file_extension
=
file_extension
,
transform_pipeline
=
transform_pipeline
,
transform_pipeline
=
transform_pipeline
,
min_duration
=
1.5
transform_number
=
0
,
sliding_window
=
sliding_window
,
window_len
=
win_duration
,
window_shift
=
win_shift
,
sample_rate
=
sample_rate
,
min_duration
=
0.165
)
)
dataloader
=
DataLoader
(
dataset
,
dataloader
=
DataLoader
(
dataset
,
batch_size
=
1
,
batch_size
=
1
,
shuffle
=
False
,
shuffle
=
False
,
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment