Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Anthony Larcher
sidekit
Commits
5da3e20f
Commit
5da3e20f
authored
Nov 13, 2021
by
Le Lan Gaël
Browse files
Merge branch 'dev-gl3lan' into corr_pooling
parents
fa238b80
be3c8478
Changes
4
Show whitespace changes
Inline
Side-by-side
.gitignore
View file @
5da3e20f
*.pyc
*.pyc
*.DS_Store
*.DS_Store
docs
docs
.vscode
/settings.json
.vscode
.gitignore
.gitignore
.vscode
.history
nnet/augmentation.py
View file @
5da3e20f
...
@@ -173,6 +173,9 @@ def data_augmentation(speech,
...
@@ -173,6 +173,9 @@ def data_augmentation(speech,
aug_idx
=
random
.
sample
(
range
(
len
(
transform_dict
.
keys
())),
k
=
transform_number
)
aug_idx
=
random
.
sample
(
range
(
len
(
transform_dict
.
keys
())),
k
=
transform_number
)
augmentations
=
numpy
.
array
(
list
(
transform_dict
.
keys
()))[
aug_idx
]
augmentations
=
numpy
.
array
(
list
(
transform_dict
.
keys
()))[
aug_idx
]
if
"none"
in
augmentations
:
pass
if
"stretch"
in
augmentations
:
if
"stretch"
in
augmentations
:
strech
=
torchaudio
.
functional
.
TimeStretch
()
strech
=
torchaudio
.
functional
.
TimeStretch
()
rate
=
random
.
uniform
(
0.8
,
1.2
)
rate
=
random
.
uniform
(
0.8
,
1.2
)
...
@@ -261,6 +264,7 @@ def data_augmentation(speech,
...
@@ -261,6 +264,7 @@ def data_augmentation(speech,
final_shape
=
speech
.
shape
[
1
]
final_shape
=
speech
.
shape
[
1
]
configs
=
[
configs
=
[
({
"format"
:
"wav"
,
"encoding"
:
'ULAW'
,
"bits_per_sample"
:
8
},
"8 bit mu-law"
),
({
"format"
:
"wav"
,
"encoding"
:
'ULAW'
,
"bits_per_sample"
:
8
},
"8 bit mu-law"
),
({
"format"
:
"wav"
,
"encoding"
:
'ALAW'
,
"bits_per_sample"
:
8
},
"8 bit a-law"
),
({
"format"
:
"gsm"
},
"GSM-FR"
),
({
"format"
:
"gsm"
},
"GSM-FR"
),
({
"format"
:
"mp3"
,
"compression"
:
-
9
},
"MP3"
),
({
"format"
:
"mp3"
,
"compression"
:
-
9
},
"MP3"
),
({
"format"
:
"vorbis"
,
"compression"
:
-
1
},
"Vorbis"
)
({
"format"
:
"vorbis"
,
"compression"
:
-
1
},
"Vorbis"
)
...
...
nnet/xsets.py
View file @
5da3e20f
...
@@ -175,7 +175,7 @@ class SideSet(Dataset):
...
@@ -175,7 +175,7 @@ class SideSet(Dataset):
overlap
=
0.
,
overlap
=
0.
,
dataset_df
=
None
,
dataset_df
=
None
,
min_duration
=
0.165
,
min_duration
=
0.165
,
output_format
=
"pytorch"
,
output_format
=
"pytorch"
):
):
"""
"""
...
@@ -269,6 +269,8 @@ class SideSet(Dataset):
...
@@ -269,6 +269,8 @@ class SideSet(Dataset):
self
.
transform
[
"codec"
]
=
[]
self
.
transform
[
"codec"
]
=
[]
if
"phone_filtering"
in
transforms
:
if
"phone_filtering"
in
transforms
:
self
.
transform
[
"phone_filtering"
]
=
[]
self
.
transform
[
"phone_filtering"
]
=
[]
if
"stretch"
in
transforms
:
self
.
transform
[
"stretch"
]
=
[]
self
.
noise_df
=
None
self
.
noise_df
=
None
if
"add_noise"
in
self
.
transform
:
if
"add_noise"
in
self
.
transform
:
...
@@ -416,18 +418,27 @@ class IdMapSet(Dataset):
...
@@ -416,18 +418,27 @@ class IdMapSet(Dataset):
start
=
int
(
self
.
idmap
.
start
[
index
]
*
0.01
*
self
.
sample_rate
)
start
=
int
(
self
.
idmap
.
start
[
index
]
*
0.01
*
self
.
sample_rate
)
if
self
.
idmap
.
stop
[
index
]
is
None
:
if
self
.
idmap
.
stop
[
index
]
is
None
:
nfo
=
torchaudio
.
info
(
f
"
{
self
.
data_path
}
/
{
self
.
idmap
.
rightids
[
index
]
}
.
{
self
.
file_extension
}
"
)
speech
,
speech_fs
=
torchaudio
.
load
(
f
"
{
self
.
data_path
}
/
{
self
.
idmap
.
rightids
[
index
]
}
.
{
self
.
file_extension
}
"
)
speech
,
speech_fs
=
torchaudio
.
load
(
f
"
{
self
.
data_path
}
/
{
self
.
idmap
.
rightids
[
index
]
}
.
{
self
.
file_extension
}
"
)
if
nfo
.
sample_rate
!=
self
.
sample_rate
:
speech
=
torchaudio
.
transforms
.
Resample
(
nfo
.
sample_rate
,
self
.
sample_rate
).
forward
(
speech
)
duration
=
int
(
speech
.
shape
[
1
]
-
start
)
duration
=
int
(
speech
.
shape
[
1
]
-
start
)
else
:
else
:
duration
=
int
(
self
.
idmap
.
stop
[
index
]
*
0.01
*
self
.
sample_rate
)
-
start
# TODO Check if that code is still relevant with torchaudio.load() in case of sample_rate mismatch
nfo
=
torchaudio
.
info
(
f
"
{
self
.
data_path
}
/
{
self
.
idmap
.
rightids
[
index
]
}
.
{
self
.
file_extension
}
"
)
assert
nfo
.
sample_rate
==
self
.
sample_rate
conversion_rate
=
nfo
.
sample_rate
//
self
.
sample_rate
duration
=
(
int
(
self
.
idmap
.
stop
[
index
]
*
0.01
*
self
.
sample_rate
)
-
start
)
# add this in case the segment is too short
# add this in case the segment is too short
if
duration
<=
self
.
min_duration
*
self
.
sample_rate
:
if
duration
<=
self
.
min_duration
*
self
.
sample_rate
:
middle
=
start
+
duration
//
2
middle
=
start
+
duration
//
2
start
=
int
(
max
(
0
,
int
(
middle
-
(
self
.
min_duration
*
self
.
sample_rate
/
2
))))
start
=
int
(
max
(
0
,
int
(
middle
-
(
self
.
min_duration
*
self
.
sample_rate
/
2
))))
duration
=
int
(
self
.
min_duration
*
self
.
sample_rate
)
duration
=
int
(
self
.
min_duration
*
self
.
sample_rate
)
speech
,
speech_fs
=
torchaudio
.
load
(
f
"
{
self
.
data_path
}
/
{
self
.
idmap
.
rightids
[
index
]
}
.
{
self
.
file_extension
}
"
,
speech
,
speech_fs
=
torchaudio
.
load
(
f
"
{
self
.
data_path
}
/
{
self
.
idmap
.
rightids
[
index
]
}
.
{
self
.
file_extension
}
"
,
frame_offset
=
start
,
frame_offset
=
start
*
conversion_rate
,
num_frames
=
duration
)
num_frames
=
duration
*
conversion_rate
)
if
nfo
.
sample_rate
!=
self
.
sample_rate
:
speech
=
torchaudio
.
transforms
.
Resample
(
nfo
.
sample_rate
,
self
.
sample_rate
).
forward
(
speech
)
speech
+=
10e-6
*
torch
.
randn
(
speech
.
shape
)
speech
+=
10e-6
*
torch
.
randn
(
speech
.
shape
)
...
...
nnet/xvector.py
View file @
5da3e20f
...
@@ -530,7 +530,8 @@ class Xtractor(torch.nn.Module):
...
@@ -530,7 +530,8 @@ class Xtractor(torch.nn.Module):
m
=
0.2
,
m
=
0.2
,
easy_margin
=
False
)
easy_margin
=
False
)
elif
self
.
loss
==
'aps'
:
elif
self
.
loss
==
'aps'
:
self
.
after_speaker_embedding
=
SoftmaxAngularProto
(
int
(
self
.
speaker_number
))
self
.
after_speaker_embedding
=
SoftmaxAngularProto
(
int
(
self
.
speaker_number
),
emb_dim
=
self
.
embedding_size
)
self
.
preprocessor_weight_decay
=
0.00002
self
.
preprocessor_weight_decay
=
0.00002
self
.
sequence_network_weight_decay
=
0.00002
self
.
sequence_network_weight_decay
=
0.00002
self
.
stat_pooling_weight_decay
=
0.00002
self
.
stat_pooling_weight_decay
=
0.00002
...
@@ -1096,6 +1097,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
...
@@ -1096,6 +1097,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
test_size
=
dataset_opts
[
"validation_ratio"
],
test_size
=
dataset_opts
[
"validation_ratio"
],
stratify
=
stratify
)
stratify
=
stratify
)
# TODO
torch
.
manual_seed
(
training_opts
[
'torch_seed'
]
+
local_rank
)
torch
.
manual_seed
(
training_opts
[
'torch_seed'
]
+
local_rank
)
torch
.
cuda
.
manual_seed
(
training_opts
[
'torch_seed'
]
+
local_rank
)
torch
.
cuda
.
manual_seed
(
training_opts
[
'torch_seed'
]
+
local_rank
)
...
@@ -1105,7 +1107,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
...
@@ -1105,7 +1107,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
transform_number
=
dataset_opts
[
'train'
][
'transform_number'
],
transform_number
=
dataset_opts
[
'train'
][
'transform_number'
],
overlap
=
dataset_opts
[
'train'
][
'overlap'
],
overlap
=
dataset_opts
[
'train'
][
'overlap'
],
dataset_df
=
training_df
,
dataset_df
=
training_df
,
output_format
=
"pytorch"
,
output_format
=
"pytorch"
)
)
validation_set
=
SideSet
(
dataset_opts
,
validation_set
=
SideSet
(
dataset_opts
,
...
@@ -1628,7 +1630,7 @@ def train_epoch(model,
...
@@ -1628,7 +1630,7 @@ def train_epoch(model,
if
math
.
fmod
(
batch_idx
,
training_opts
[
"log_interval"
])
==
0
:
if
math
.
fmod
(
batch_idx
,
training_opts
[
"log_interval"
])
==
0
:
batch_size
=
target
.
shape
[
0
]
batch_size
=
target
.
shape
[
0
]
training_monitor
.
update
(
training_loss
=
loss
.
item
(),
training_monitor
.
update
(
training_loss
=
loss
.
item
(),
training_acc
=
100.0
*
accuracy
.
item
()
/
((
batch_idx
+
1
)
*
batch_size
))
training_acc
=
100.0
*
accuracy
/
((
batch_idx
+
1
)
*
batch_size
))
training_monitor
.
logger
.
info
(
'Epoch: {} [{}/{} ({:.0f}%)]
\t
Loss: {:.6f}
\t
Accuracy: {:.3f}'
.
format
(
training_monitor
.
logger
.
info
(
'Epoch: {} [{}/{} ({:.0f}%)]
\t
Loss: {:.6f}
\t
Accuracy: {:.3f}'
.
format
(
training_monitor
.
current_epoch
,
training_monitor
.
current_epoch
,
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment