Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Anthony Larcher
sidekit
Commits
80656610
Commit
80656610
authored
Apr 13, 2021
by
Anthony Larcher
Browse files
update extract_embeddings
parent
5491cb86
Changes
2
Show whitespace changes
Inline
Side-by-side
nnet/xsets.py
View file @
80656610
...
...
@@ -360,7 +360,7 @@ class IdMapSet(Dataset):
transform_number
=
1
,
sliding_window
=
False
,
window_len
=
3.
,
window_shift
=
1.
,
window_shift
=
1.
5
,
sample_rate
=
16000
,
min_duration
=
0.165
):
...
...
@@ -378,18 +378,19 @@ class IdMapSet(Dataset):
self
.
file_extension
=
file_extension
self
.
len
=
self
.
idmap
.
leftids
.
shape
[
0
]
self
.
transformation
=
transform_pipeline
self
.
min_
sample_nb
=
min_duration
*
sample_rate
self
.
min_
duration
=
min_duration
self
.
sample_rate
=
sample_rate
self
.
sliding_window
=
sliding_window
self
.
window_len
=
window_len
self
.
window_shift
=
window_shift
self
.
window_len
=
int
(
window_len
*
self
.
sample_rate
)
self
.
window_shift
=
int
(
window_shift
*
self
.
sample_rate
)
self
.
transform_number
=
transform_number
self
.
noise_df
=
None
if
"add_noise"
in
self
.
transformation
:
# Load the noise dataset, filter according to the duration
noise_df
=
pandas
.
read_csv
(
self
.
transformation
[
"add_noise"
][
"noise_db_csv"
])
self
.
noise_df
=
noise_df
.
set_index
(
noise_df
.
type
)
tmp_df
=
noise_df
.
loc
[
noise_df
[
'duration'
]
>
self
.
duration
]
self
.
noise_df
=
tmp_df
[
'file_id'
].
tolist
()
self
.
rir_df
=
None
if
"add_reverb"
in
self
.
transformation
:
...
...
@@ -403,18 +404,19 @@ class IdMapSet(Dataset):
:param index:
:return:
"""
# Read start and stop and convert to time in seconds
if
self
.
idmap
.
start
[
index
]
is
None
:
start
=
0
else
:
start
=
int
(
self
.
idmap
.
start
[
index
]
)
*
160
start
=
int
(
self
.
idmap
.
start
[
index
]
*
0.01
*
self
.
sample_rate
)
if
self
.
idmap
.
stop
[
index
]
is
None
:
speech
,
speech_fs
=
torchaudio
.
load
(
f
"
{
self
.
data_path
}
/
{
self
.
idmap
.
rightids
[
index
]
}
.
{
self
.
file_extension
}
"
)
duration
=
int
(
speech
.
shape
[
1
]
-
start
)
else
:
duration
=
int
(
self
.
idmap
.
stop
[
index
]
)
*
160
-
start
duration
=
int
(
self
.
idmap
.
stop
[
index
]
*
0.01
)
*
self
.
sample_rate
-
start
# add this in case the segment is too short
if
duration
<=
self
.
min_
sample_
nb
:
if
duration
<=
self
.
self
.
min_duration
*
self
.
sample_
rate
:
middle
=
start
+
duration
//
2
start
=
max
(
0
,
int
(
middle
-
(
self
.
min_sample_nb
/
2
)))
duration
=
int
(
self
.
min_sample_nb
)
...
...
@@ -426,7 +428,7 @@ class IdMapSet(Dataset):
speech
+=
10e-6
*
torch
.
randn
(
speech
.
shape
)
if
self
.
sliding_window
:
speech
=
speech
.
squeeze
().
unfold
(
0
,
self
.
window_len
,
self
.
window_shift
)
speech
=
speech
.
squeeze
().
unfold
(
0
,
self
.
window_len
,
self
.
window_shift
)
if
len
(
self
.
transformation
.
keys
())
>
0
:
speech
=
data_augmentation
(
speech
,
...
...
nnet/xvector.py
View file @
80656610
...
...
@@ -1449,6 +1449,14 @@ def train_epoch(model,
else
:
output
,
_
=
model
(
data
,
target
=
None
)
loss
=
criterion
(
output
,
target
)
scaler
.
scale
(
loss
).
backward
()
if
clipping
:
scaler
.
unscale_
(
optimizer
)
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
1.
)
scaler
.
step
(
optimizer
)
scaler
.
update
()
else
:
if
loss_criteria
==
'aam'
:
output
,
_
=
model
(
data
,
target
=
target
)
...
...
@@ -1461,18 +1469,9 @@ def train_epoch(model,
output
,
_
=
model
(
data
,
target
=
None
)
loss
=
criterion
(
output
,
target
)
#if not torch.isnan(loss):
if
True
:
if
scaler
is
not
None
:
scaler
.
scale
(
loss
).
backward
()
if
clipping
:
scaler
.
unscale_
(
optimizer
)
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
1.
)
scaler
.
step
(
optimizer
)
scaler
.
update
()
else
:
loss
.
backward
()
optimizer
.
step
()
running_loss
+=
loss
.
item
()
accuracy
+=
(
torch
.
argmax
(
output
.
data
,
1
)
==
target
).
sum
()
...
...
@@ -1489,18 +1488,6 @@ def train_epoch(model,
loss
.
item
(),
100.0
*
accuracy
.
item
()
/
((
batch_idx
+
1
)
*
batch_size
)))
#else:
# save_checkpoint({
# 'epoch': training_monitor.current_epoch,
# 'model_state_dict': model.state_dict(),
# 'optimizer_state_dict': optimizer.state_dict(),
# 'accuracy': 0.0,
# 'scheduler': 0.0
# }, False, filename="model_loss_NAN.pt", best_filename='toto.pt')
# with open("batch_loss_NAN.pkl", "wb") as fh:
# pickle.dump(data.cpu(), fh)
# import sys
# sys.exit()
running_loss
=
0.0
if
isinstance
(
scheduler
,
torch
.
optim
.
lr_scheduler
.
ReduceLROnPlateau
):
...
...
@@ -1570,12 +1557,10 @@ def extract_embeddings(idmap_name,
model_filename
,
data_root_name
,
device
,
loss
,
file_extension
=
"wav"
,
transform_pipeline
=
""
,
frame_shift
=
0.01
,
frame_duration
=
0.025
,
extract_after_pooling
=
False
,
frame_shift
=
1.5
,
frame_duration
=
3.
,
num_thread
=
1
,
mixed_precision
=
False
):
"""
...
...
@@ -1622,7 +1607,7 @@ def extract_embeddings(idmap_name,
data_path
=
data_root_name
,
file_extension
=
file_extension
,
transform_pipeline
=
transform_pipeline
,
min_duration
=
(
model_cs
+
2
)
*
frame_shift
*
2
min_duration
=
1.5
)
dataloader
=
DataLoader
(
dataset
,
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment