Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Anthony Larcher
sidekit
Commits
d8ac6322
Commit
d8ac6322
authored
Mar 26, 2021
by
Anthony Larcher
Browse files
debug
parent
7bbe5324
Changes
2
Hide whitespace changes
Inline
Side-by-side
nnet/xsets.py
View file @
d8ac6322
...
...
@@ -293,8 +293,8 @@ class IdMapSet(Dataset):
idmap_name
,
data_path
,
file_extension
,
transform_pipeline
=
{}
,
sliding_window
=
Tru
e
,
transform_pipeline
=
""
,
sliding_window
=
Fals
e
,
window_len
=
24000
,
window_shift
=
8000
,
sample_rate
=
16000
,
...
...
@@ -358,9 +358,9 @@ class IdMapSet(Dataset):
start
=
max
(
0
,
int
(
middle
-
(
self
.
min_sample_nb
/
2
)))
duration
=
self
.
min_sample_nb
speech
,
speech_fs
=
torchaudio
.
load
(
f
"
{
self
.
data_path
}
/
{
self
.
idmap
.
rightids
[
index
]
}
.
{
self
.
file_extension
}
"
,
frame_offset
=
start
,
num_frames
=
duration
)
speech
,
speech_fs
=
torchaudio
.
load
(
f
"
{
self
.
data_path
}
/
{
self
.
idmap
.
rightids
[
index
]
}
.
{
self
.
file_extension
}
"
,
frame_offset
=
start
,
num_frames
=
duration
)
speech
+=
10e-6
*
torch
.
randn
(
speech
.
shape
)
...
...
nnet/xvector.py
View file @
d8ac6322
...
...
@@ -226,7 +226,7 @@ def test_metrics(model,
key_test_filename
=
'h5f/key_test.h5'
data_root_name
=
'/lium/scratch/larcher/voxceleb1/test/wav'
transform_pipeline
=
dict
()
transform_pipeline
=
""
xv_stat
=
extract_embeddings
(
idmap_name
=
idmap_test_filename
,
model_filename
=
model
,
...
...
@@ -837,9 +837,6 @@ def xtrain(speaker_number,
new_model_dict
.
update
(
pretrained_dict
)
model
.
load_state_dict
(
new_model_dict
)
print
(
"Modifiy margin: 0.4"
)
model
.
after_speaker_embedding
.
m
=
0.4
# Freeze required layers
for
name
,
param
in
model
.
named_parameters
():
...
...
@@ -1201,7 +1198,6 @@ def cross_validation(model, validation_loader, device, validation_shape, mask, m
:return:
"""
model
.
eval
()
print
(
"In cross validation"
)
if
isinstance
(
model
,
Xtractor
):
loss_criteria
=
model
.
loss
else
:
...
...
@@ -1214,9 +1210,7 @@ def cross_validation(model, validation_loader, device, validation_shape, mask, m
classes
=
torch
.
zeros
([
validation_shape
[
0
]])
with
torch
.
no_grad
():
#for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1)):
print
(
"In cross validation 2"
)
for
batch_idx
,
(
data
,
target
)
in
enumerate
(
validation_loader
):
print
(
"In cross validation 3"
)
batch_size
=
target
.
shape
[
0
]
target
=
target
.
squeeze
().
to
(
device
)
data
=
data
.
squeeze
().
to
(
device
)
...
...
@@ -1254,7 +1248,7 @@ def extract_embeddings(idmap_name,
data_root_name
,
device
,
file_extension
=
"wav"
,
transform_pipeline
=
{}
,
transform_pipeline
=
""
,
frame_shift
=
0.01
,
frame_duration
=
0.025
,
extract_after_pooling
=
False
,
...
...
@@ -1306,7 +1300,6 @@ def extract_embeddings(idmap_name,
data_path
=
data_root_name
,
file_extension
=
file_extension
,
transform_pipeline
=
transform_pipeline
,
frame_rate
=
int
(
1.
/
frame_shift
),
min_duration
=
(
model_cs
+
2
)
*
frame_shift
*
2
)
...
...
@@ -1602,7 +1595,7 @@ def extract_sliding_embedding(idmap_name,
segset
+=
[
seg
,]
*
data
.
shape
[
0
]
starts
+=
[
numpy
.
arange
(
start
,
start
+
embeddings
.
shape
[
0
]
*
window_shift
,
window_shift
),]
REPRENDRE
ICI
#
REPRENDRE ICI
# Create the StatServer
embeddings
=
StatServer
()
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment