Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Anthony Larcher
sidekit
Commits
8e8fb525
Commit
8e8fb525
authored
Dec 13, 2021
by
Anthony Larcher
Browse files
update anthony
parent
88f4d2b9
Changes
2
Hide whitespace changes
Inline
Side-by-side
nnet/xsets.py
View file @
8e8fb525
...
...
@@ -429,7 +429,7 @@ class IdMapSet(Dataset):
frame_offset
=
start
,
num_frames
=
duration
)
speech
+=
10e-6
*
torch
.
randn
(
speech
.
shape
)
#
speech += 10e-6 * torch.randn(speech.shape)
if
self
.
sliding_window
:
speech
=
speech
.
squeeze
().
unfold
(
0
,
self
.
window_len
,
self
.
window_shift
)
...
...
@@ -564,7 +564,7 @@ class IdMapSetPerSpeaker(Dataset):
tmp_data
.
append
(
speech
)
speech
=
torch
.
cat
(
tmp_data
,
dim
=
1
)
speech
+=
10e-6
*
torch
.
randn
(
speech
.
shape
)
#
speech += 10e-6 * torch.randn(speech.shape)
if
len
(
self
.
transformation
.
keys
())
>
0
:
speech
=
data_augmentation
(
speech
,
...
...
nnet/xvector.py
View file @
8e8fb525
...
...
@@ -978,6 +978,8 @@ def update_training_dictionary(dataset_description,
fill_dict
(
model_opts
,
tmp_model_dict
)
fill_dict
(
training_opts
,
tmp_train_dict
)
print
(
model_opts
)
# Overwrite with manually given parameters
if
"lr"
in
kwargs
:
training_opts
[
"lr"
]
=
kwargs
[
'lr'
]
...
...
@@ -1010,26 +1012,28 @@ def get_network(model_opts, local_rank):
model
=
Xtractor
(
model_opts
[
"speaker_number"
],
model_opts
,
loss
=
model_opts
[
"loss"
][
"type"
],
embedding_size
=
model_opts
[
"embedding_size"
])
# Load the model if it exists
if
model_opts
[
"initial_model_name"
]
is
not
None
and
os
.
path
.
isfile
(
model_opts
[
"initial_model_name"
]):
logging
.
critical
(
f
"*** Load model from =
{
model_opts
[
'initial_model_name'
]
}
"
)
checkpoint
=
torch
.
load
(
model_opts
[
"initial_model_name"
])
if
model_opts
[
"initial_model_name"
]
is
not
None
:
if
os
.
path
.
isfile
(
model_opts
[
"initial_model_name"
]):
print
(
f
"model_opts['initial_model_name'] =
{
model_opts
[
'initial_model_name'
]
}
et os.path.isfile(model_opts['initial_model_name']):
{
os
.
path
.
isfile
(
model_opts
[
'initial_model_name'
])
}
"
)
logging
.
critical
(
f
"*** Load model from =
{
model_opts
[
'initial_model_name'
]
}
"
)
checkpoint
=
torch
.
load
(
model_opts
[
"initial_model_name"
],
map_location
=
{
"cuda:0"
:
"cuda:%d"
%
local_rank
})
"""
Here we remove all layers that we don't want to reload
"""
Here we remove all layers that we don't want to reload
"""
pretrained_dict
=
checkpoint
[
"model_state_dict"
]
for
part
in
model_opts
[
"reset_parts"
]:
pretrained_dict
=
{
k
:
v
for
k
,
v
in
pretrained_dict
.
items
()
if
not
k
.
startswith
(
part
)}
"""
pretrained_dict
=
checkpoint
[
"model_state_dict"
]
for
part
in
model_opts
[
"reset_parts"
]:
pretrained_dict
=
{
k
:
v
for
k
,
v
in
pretrained_dict
.
items
()
if
not
k
.
startswith
(
part
)}
new_model_dict
=
model
.
state_dict
()
new_model_dict
.
update
(
pretrained_dict
)
model
.
load_state_dict
(
new_model_dict
)
new_model_dict
=
model
.
state_dict
()
new_model_dict
.
update
(
pretrained_dict
)
model
.
load_state_dict
(
new_model_dict
)
# Freeze required layers
for
name
,
param
in
model
.
named_parameters
():
if
name
.
split
(
"."
)[
0
]
in
model_opts
[
"reset_parts"
]:
param
.
requires_grad
=
False
# Freeze required layers
for
name
,
param
in
model
.
named_parameters
():
if
name
.
split
(
"."
)[
0
]
in
model_opts
[
"reset_parts"
]:
param
.
requires_grad
=
False
if
model_opts
[
"loss"
][
"type"
]
==
"aam"
and
not
(
model_opts
[
"loss"
][
"aam_margin"
]
==
0.2
and
model_opts
[
"loss"
][
"aam_s"
]
==
30
):
model
.
after_speaker_embedding
.
change_params
(
model_opts
[
"loss"
][
"aam_s"
],
model_opts
[
"loss"
][
"aam_margin"
])
...
...
@@ -1755,6 +1759,7 @@ def extract_embeddings(idmap_name,
modelset
=
[]
segset
=
[]
starts
=
[]
stops
=
[]
for
idx
,
(
data
,
mod
,
seg
,
start
,
stop
)
in
enumerate
(
tqdm
.
tqdm
(
dataloader
,
desc
=
'xvector extraction'
,
...
...
@@ -1774,15 +1779,20 @@ def extract_embeddings(idmap_name,
if
sliding_window
:
tmp_start
=
numpy
.
arange
(
0
,
data
.
shape
[
0
]
*
win_shift
,
win_shift
)
starts
.
extend
(
tmp_start
*
sample_rate
+
start
.
detach
().
cpu
().
numpy
())
win_duration
=
int
(
len
(
tmp_data
))
else
:
starts
.
append
(
start
.
numpy
())
stops
.
append
(
tmp_data
[
0
].
shape
[
1
])
embeddings
=
StatServer
()
embeddings
.
stat1
=
numpy
.
concatenate
(
embed
)
embeddings
.
modelset
=
numpy
.
array
(
modelset
).
astype
(
'>U'
)
embeddings
.
segset
=
numpy
.
array
(
segset
).
astype
(
'>U'
)
embeddings
.
start
=
numpy
.
array
(
starts
).
squeeze
()
embeddings
.
stop
=
embeddings
.
start
+
win_duration
if
sliding_window
:
embeddings
.
stop
=
embeddings
.
start
+
win_duration
else
:
embeddings
.
stop
=
embeddings
.
start
+
numpy
.
array
(
stops
).
squeeze
()
embeddings
.
stat0
=
numpy
.
ones
((
embeddings
.
modelset
.
shape
[
0
],
1
))
return
embeddings
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment