Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Anthony Larcher
sidekit
Commits
5da3e20f
Commit
5da3e20f
authored
Nov 13, 2021
by
Le Lan Gaël
Browse files
Merge branch 'dev-gl3lan' into corr_pooling
parents
fa238b80
be3c8478
Changes
4
Hide whitespace changes
Inline
Side-by-side
.gitignore
View file @
5da3e20f
*.pyc
*.DS_Store
docs
.vscode
/settings.json
.vscode
.gitignore
.vscode
.history
nnet/augmentation.py
View file @
5da3e20f
...
...
@@ -173,6 +173,9 @@ def data_augmentation(speech,
aug_idx
=
random
.
sample
(
range
(
len
(
transform_dict
.
keys
())),
k
=
transform_number
)
augmentations
=
numpy
.
array
(
list
(
transform_dict
.
keys
()))[
aug_idx
]
if
"none"
in
augmentations
:
pass
if
"stretch"
in
augmentations
:
strech
=
torchaudio
.
functional
.
TimeStretch
()
rate
=
random
.
uniform
(
0.8
,
1.2
)
...
...
@@ -261,6 +264,7 @@ def data_augmentation(speech,
final_shape
=
speech
.
shape
[
1
]
configs
=
[
({
"format"
:
"wav"
,
"encoding"
:
'ULAW'
,
"bits_per_sample"
:
8
},
"8 bit mu-law"
),
({
"format"
:
"wav"
,
"encoding"
:
'ALAW'
,
"bits_per_sample"
:
8
},
"8 bit a-law"
),
({
"format"
:
"gsm"
},
"GSM-FR"
),
({
"format"
:
"mp3"
,
"compression"
:
-
9
},
"MP3"
),
({
"format"
:
"vorbis"
,
"compression"
:
-
1
},
"Vorbis"
)
...
...
nnet/xsets.py
View file @
5da3e20f
...
...
@@ -106,7 +106,7 @@ class SideSampler(torch.utils.data.Sampler):
self
.
segment_cursors
=
numpy
.
zeros
((
len
(
self
.
labels_to_indices
),),
dtype
=
numpy
.
int
)
def
__iter__
(
self
):
def
__iter__
(
self
):
g
=
torch
.
Generator
()
g
.
manual_seed
(
self
.
seed
+
self
.
epoch
)
numpy
.
random
.
seed
(
self
.
seed
+
self
.
epoch
)
...
...
@@ -175,7 +175,7 @@ class SideSet(Dataset):
overlap
=
0.
,
dataset_df
=
None
,
min_duration
=
0.165
,
output_format
=
"pytorch"
,
output_format
=
"pytorch"
):
"""
...
...
@@ -269,6 +269,8 @@ class SideSet(Dataset):
self
.
transform
[
"codec"
]
=
[]
if
"phone_filtering"
in
transforms
:
self
.
transform
[
"phone_filtering"
]
=
[]
if
"stretch"
in
transforms
:
self
.
transform
[
"stretch"
]
=
[]
self
.
noise_df
=
None
if
"add_noise"
in
self
.
transform
:
...
...
@@ -416,18 +418,27 @@ class IdMapSet(Dataset):
start
=
int
(
self
.
idmap
.
start
[
index
]
*
0.01
*
self
.
sample_rate
)
if
self
.
idmap
.
stop
[
index
]
is
None
:
nfo
=
torchaudio
.
info
(
f
"
{
self
.
data_path
}
/
{
self
.
idmap
.
rightids
[
index
]
}
.
{
self
.
file_extension
}
"
)
speech
,
speech_fs
=
torchaudio
.
load
(
f
"
{
self
.
data_path
}
/
{
self
.
idmap
.
rightids
[
index
]
}
.
{
self
.
file_extension
}
"
)
if
nfo
.
sample_rate
!=
self
.
sample_rate
:
speech
=
torchaudio
.
transforms
.
Resample
(
nfo
.
sample_rate
,
self
.
sample_rate
).
forward
(
speech
)
duration
=
int
(
speech
.
shape
[
1
]
-
start
)
else
:
duration
=
int
(
self
.
idmap
.
stop
[
index
]
*
0.01
*
self
.
sample_rate
)
-
start
# TODO Check if that code is still relevant with torchaudio.load() in case of sample_rate mismatch
nfo
=
torchaudio
.
info
(
f
"
{
self
.
data_path
}
/
{
self
.
idmap
.
rightids
[
index
]
}
.
{
self
.
file_extension
}
"
)
assert
nfo
.
sample_rate
==
self
.
sample_rate
conversion_rate
=
nfo
.
sample_rate
//
self
.
sample_rate
duration
=
(
int
(
self
.
idmap
.
stop
[
index
]
*
0.01
*
self
.
sample_rate
)
-
start
)
# add this in case the segment is too short
if
duration
<=
self
.
min_duration
*
self
.
sample_rate
:
middle
=
start
+
duration
//
2
start
=
int
(
max
(
0
,
int
(
middle
-
(
self
.
min_duration
*
self
.
sample_rate
/
2
))))
duration
=
int
(
self
.
min_duration
*
self
.
sample_rate
)
speech
,
speech_fs
=
torchaudio
.
load
(
f
"
{
self
.
data_path
}
/
{
self
.
idmap
.
rightids
[
index
]
}
.
{
self
.
file_extension
}
"
,
frame_offset
=
start
,
num_frames
=
duration
)
frame_offset
=
start
*
conversion_rate
,
num_frames
=
duration
*
conversion_rate
)
if
nfo
.
sample_rate
!=
self
.
sample_rate
:
speech
=
torchaudio
.
transforms
.
Resample
(
nfo
.
sample_rate
,
self
.
sample_rate
).
forward
(
speech
)
speech
+=
10e-6
*
torch
.
randn
(
speech
.
shape
)
...
...
nnet/xvector.py
View file @
5da3e20f
...
...
@@ -530,7 +530,8 @@ class Xtractor(torch.nn.Module):
m
=
0.2
,
easy_margin
=
False
)
elif
self
.
loss
==
'aps'
:
self
.
after_speaker_embedding
=
SoftmaxAngularProto
(
int
(
self
.
speaker_number
))
self
.
after_speaker_embedding
=
SoftmaxAngularProto
(
int
(
self
.
speaker_number
),
emb_dim
=
self
.
embedding_size
)
self
.
preprocessor_weight_decay
=
0.00002
self
.
sequence_network_weight_decay
=
0.00002
self
.
stat_pooling_weight_decay
=
0.00002
...
...
@@ -1095,7 +1096,8 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
training_df
,
validation_df
=
train_test_split
(
df
,
test_size
=
dataset_opts
[
"validation_ratio"
],
stratify
=
stratify
)
# TODO
torch
.
manual_seed
(
training_opts
[
'torch_seed'
]
+
local_rank
)
torch
.
cuda
.
manual_seed
(
training_opts
[
'torch_seed'
]
+
local_rank
)
...
...
@@ -1105,7 +1107,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
transform_number
=
dataset_opts
[
'train'
][
'transform_number'
],
overlap
=
dataset_opts
[
'train'
][
'overlap'
],
dataset_df
=
training_df
,
output_format
=
"pytorch"
,
output_format
=
"pytorch"
)
validation_set
=
SideSet
(
dataset_opts
,
...
...
@@ -1628,7 +1630,7 @@ def train_epoch(model,
if
math
.
fmod
(
batch_idx
,
training_opts
[
"log_interval"
])
==
0
:
batch_size
=
target
.
shape
[
0
]
training_monitor
.
update
(
training_loss
=
loss
.
item
(),
training_acc
=
100.0
*
accuracy
.
item
()
/
((
batch_idx
+
1
)
*
batch_size
))
training_acc
=
100.0
*
accuracy
/
((
batch_idx
+
1
)
*
batch_size
))
training_monitor
.
logger
.
info
(
'Epoch: {} [{}/{} ({:.0f}%)]
\t
Loss: {:.6f}
\t
Accuracy: {:.3f}'
.
format
(
training_monitor
.
current_epoch
,
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment