Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Anthony Larcher
sidekit
Commits
e7c23b4f
Commit
e7c23b4f
authored
Mar 25, 2021
by
Anthony Larcher
Browse files
debug
parent
03d0e5e1
Changes
2
Hide whitespace changes
Inline
Side-by-side
nnet/xsets.py
View file @
e7c23b4f
...
...
@@ -32,6 +32,7 @@ import pandas
import
random
import
torch
import
torchaudio
torchaudio
.
set_audio_backend
(
"sox_io"
)
import
tqdm
import
soundfile
import
yaml
...
...
@@ -347,7 +348,7 @@ class IdMapSet(Dataset):
if
self
.
idmap
.
stop
[
index
]
is
None
:
speech
,
speech_fs
=
torchaudio
.
load
(
f
"
{
self
.
data_path
}
/
{
self
.
idmap
.
rightids
[
index
]
}
.
{
self
.
file_extension
}
"
)
duration
=
len
(
speech
)
-
start
duration
=
speech
.
shape
[
1
]
-
start
else
:
start
=
int
(
self
.
idmap
.
start
[
index
])
duration
=
int
(
self
.
idmap
.
stop
[
index
])
-
start
...
...
nnet/xvector.py
View file @
e7c23b4f
...
...
@@ -1193,7 +1193,7 @@ def cross_validation(model, validation_loader, device, validation_shape, mixed_p
:return:
"""
model
.
eval
()
print
(
"In cross validation"
)
if
isinstance
(
model
,
Xtractor
):
loss_criteria
=
model
.
loss
else
:
...
...
@@ -1205,7 +1205,10 @@ def cross_validation(model, validation_loader, device, validation_shape, mixed_p
embeddings
=
torch
.
zeros
(
validation_shape
)
classes
=
torch
.
zeros
([
validation_shape
[
0
]])
with
torch
.
no_grad
():
for
batch_idx
,
(
data
,
target
)
in
enumerate
(
tqdm
.
tqdm
(
validation_loader
,
desc
=
'validation compute'
,
mininterval
=
1
)):
#for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1)):
print
(
"In cross validation 2"
)
for
batch_idx
,
(
data
,
target
)
in
enumerate
(
validation_loader
):
print
(
"In cross validation 3"
)
batch_size
=
target
.
shape
[
0
]
target
=
target
.
squeeze
().
to
(
device
)
data
=
data
.
squeeze
().
to
(
device
)
...
...
@@ -1364,7 +1367,7 @@ def extract_embeddings_per_speaker(idmap_name,
model_archi
=
checkpoint
[
"model_archi"
]
model
=
Xtractor
(
checkpoint
[
"speaker_number"
],
model_archi
=
model_archi
)
model
=
Xtractor
(
checkpoint
[
"speaker_number"
],
model_archi
=
model_archi
,
loss
=
"aam"
)
model
.
load_state_dict
(
checkpoint
[
"model_state_dict"
])
else
:
model
=
model_filename
...
...
@@ -1429,7 +1432,7 @@ def extract_sliding_embedding(idmap_name,
file_extension="wav",
transform_pipeline=None,
num_thread=1):
"""
:param idmap_name:
:param window_length:
...
...
@@ -1443,7 +1446,7 @@ def extract_sliding_embedding(idmap_name,
:param file_extension:
:param transform_pipeline:
:return:
"""
# From the original IdMap, create the new one to extract x-vectors
if not isinstance(idmap_name, IdMap):
input_idmap = IdMap(idmap_name)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment