Anthony Larcher / sidekit / Commits

Commit b20c6a25
authored Mar 12, 2021 by Gaël Le Lan

add_noise

parent 71e8b37c
4 changed files
nnet/augmentation.py

```diff
@@ -105,6 +105,8 @@ class AddNoise(object):
         noises = []
         left = original_duration
+        noise = numpy.zeros_like(data)
         while left > 0:
             # select noise file at random
             file = random.choice(self.noises)
...
@@ -117,23 +119,23 @@ class AddNoise(object):
             duration = noise_signal.shape[0]
             # if noise file is longer than what is needed, crop it
-            if duration > left:
-                noise = crop(noise_signal, left)
+            if duration >= left:
+                noise[-left:] = normalize(crop(noise_signal, left))
                 left = 0
             # otherwise, take the whole file
             else:
-                noise = noise_signal
+                noise[-left:-left + duration] = normalize(noise_signal)
                 left -= duration
             # Todo Downsample if needed
             # if sample_rate > fs:
             #
-            noise = normalize(noise)
-            noises.append(noise.squeeze())
+            # noise = normalize(noise)
+            # noises.append(noise.squeeze())
         # concatenate
-        noise = numpy.hstack(noises)
+        # noise = numpy.hstack(noises)
         # select SNR at random
         snr = (self.snr_max - self.snr_min) * numpy.random.random_sample() + self.snr_min
...
```
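Read together, the two hunks change AddNoise from stacking cropped noise files into a list to writing each normalized segment directly into a pre-allocated buffer the size of the input. A minimal sketch of the resulting fill logic; `load_noise` and the peak-normalizing `normalize()` are assumptions standing in for sidekit's own helpers:

```python
import random

import numpy


def normalize(x):
    # assumption: peak normalization; sidekit's normalize() may differ
    return x / (numpy.abs(x).max() + 1e-8)


def build_noise(data, noise_files, load_noise):
    """Fill a buffer shaped like `data` with random noise segments.

    Mirrors the new AddNoise loop: segments are normalized and written
    right-to-left into a zeros buffer instead of being hstack-ed.
    """
    noise = numpy.zeros_like(data)
    left = data.shape[0]
    while left > 0:
        noise_signal = load_noise(random.choice(noise_files))  # 1-D array
        duration = noise_signal.shape[0]
        if duration >= left:   # longer than needed: crop and finish
            noise[-left:] = normalize(noise_signal[:left])
            left = 0
        else:                  # shorter: use the whole file and continue
            noise[-left:-left + duration] = normalize(noise_signal)
            left -= duration
    # the caller then draws an SNR uniformly in [snr_min, snr_max]:
    # snr = (snr_max - snr_min) * numpy.random.random_sample() + snr_min
    return noise
```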
nnet/res_net.py

```diff
@@ -607,21 +607,21 @@ class PreFastResNet34(torch.nn.Module):
         self.in_planes = 16
         self.speaker_number = speaker_number
-        #torchaudio.transforms.MelSpectrogram(sample_rate=16000,
-        #                                     n_fft=2048,
-        #                                     f_min=133.333,
-        #                                     f_max=6855.4976,
-        #                                     win_length=400,
-        #                                     hop_length=160,
-        #                                     window_fn=torch.hann_window,
-        #                                     power=2,
-        #                                     n_mels=80)
+        self.MelSpec = torchaudio.transforms.MelSpectrogram(sample_rate=melkwargs['sample_rate'],
+                                                            n_fft=melkwargs['n_fft'],
+                                                            f_min=melkwargs['f_min'],
+                                                            f_max=melkwargs['f_max'],
+                                                            win_length=melkwargs['win_length'],
+                                                            hop_length=melkwargs['hop_length'],
+                                                            window_fn=melkwargs['window_fn'],
+                                                            n_mels=melkwargs['n_mels'])
         self.PreEmphasis = PreEmphasis()
-        self.MFCC = torchaudio.transforms.MFCC(sample_rate=sample_rate,
-                                               n_mfcc=1 + n_mfcc,
-                                               log_mels=True,
-                                               melkwargs=melspec_dict)
+        #self.MFCC = torchaudio.transforms.MFCC(sample_rate=sample_rate,
+        #                                       n_mfcc=1 + n_mfcc,
+        #                                       log_mels=True,
+        #                                       melkwargs = melkwargs)
         self.CMVN = torch.nn.InstanceNorm1d(n_mfcc)
...
@@ -648,9 +648,10 @@ class PreFastResNet34(torch.nn.Module):
         with torch.no_grad():
             with torch.cuda.amp.autocast(enabled=False):
                 out = self.PreEmphasis(x)
-                out = self.MFCC(x)[:, 1:, :]
-                out = self.CMVN(x)
-        out = torch.nn.functional.relu(self.bn1(self.conv1(x)))
+                out = self.MelSpec(out) + 1e-6
+                out = torch.log(out)
+                out = self.CMVN(out).unsqueeze(1)
+        out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
         out = self.layer1(out)
         out = self.layer2(out)
         out = self.layer3(out)
...
```
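With the MFCC branch commented out, PreFastResNet34 now computes pre-emphasized log-Mel features under no_grad with autocast disabled, floors the spectrogram at 1e-6 before the log, and normalizes with InstanceNorm1d. A self-contained sketch of that front-end, with the assumption that sidekit's PreEmphasis is a first-order filter with coefficient near 0.97:

```python
import torch
import torchaudio


class LogMelFrontEnd(torch.nn.Module):
    """Sketch of PreFastResNet34's feature extraction after this commit."""

    def __init__(self, melkwargs, pre_emph=0.97):
        super().__init__()
        self.pre_emph = pre_emph  # assumption: stands in for sidekit's PreEmphasis
        self.MelSpec = torchaudio.transforms.MelSpectrogram(**melkwargs)
        # per-utterance mean/variance normalization over the mel bins
        self.CMVN = torch.nn.InstanceNorm1d(melkwargs['n_mels'])

    def forward(self, x):
        # mirrors the no_grad / autocast(enabled=False) block in the diff:
        # features are computed without gradients, in full precision
        with torch.no_grad():
            out = torch.cat([x[..., :1],
                             x[..., 1:] - self.pre_emph * x[..., :-1]], dim=-1)
            out = self.MelSpec(out) + 1e-6     # floor avoids log(0)
            out = torch.log(out)
            out = self.CMVN(out).unsqueeze(1)  # (batch, 1, n_mels, frames)
        return out
```

With the melkwargs set in nnet/xvector.py below (16 kHz, 1024-point FFT, hop 256, 80 mels), a (batch, samples) waveform tensor yields (batch, 1, 80, frames) features ready for conv1.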
nnet/xsets.py

```diff
@@ -274,7 +274,6 @@ class MFCC(object):
         """
         sig = sample[0][:, numpy.newaxis]  # ajout
-        framed = framing(sample[0], self.window_length, win_shift=self.window_length - self.overlap).copy()
         framed = framing(sample[0], self.window_length, win_shift=self.window_length - self.overlap).copy()
         # Pre-emphasis filtering is applied after framing to be consistent with stream processing
         framed = pre_emphasis(framed, self.prefac)
         # Windowing has been changed to hanning which is supposed to have less noisy sidelobes
...
@@ -453,9 +452,13 @@ class SpkSet(Dataset):
         current_speaker = self._spk_index[int(math.fmod(index, len(self._spk_index)))]
         segment_index = numpy.random.choice(self._spk_dict[current_speaker]['num_segs'],
                                             p=self._spk_dict[current_speaker]['p'])
-        self._spk_dict[current_speaker]['p'][segment_index] /= 2
+        self._spk_dict[current_speaker]['p'][segment_index] = 0  # /= 2
         current_segment = self._spk_dict[current_speaker]['segments'][segment_index]
-        self._spk_dict[current_speaker]['p'] = self._spk_dict[current_speaker]['p'] / numpy.sum(self._spk_dict[current_speaker]['p'])
+        if numpy.sum(self._spk_dict[current_speaker]['p']) > 0:
+            self._spk_dict[current_speaker]['p'] = self._spk_dict[current_speaker]['p'] / numpy.sum(self._spk_dict[current_speaker]['p'])
+        else:
+            self._spk_dict[current_speaker]['p'] += 1 / self._spk_dict[current_speaker]['num_segs']
         nfo = soundfile.info(f"{self.data_path}/{current_segment['file_id']}{self.data_file_extension}")
         if self._windowed:
...
```
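Per-speaker segment sampling now zeroes the probability of a drawn segment (instead of halving it) and resets the distribution to uniform once every segment has been seen, which avoids dividing by a zero sum. A minimal sketch of that scheme in isolation, with assumed names:

```python
import numpy


def sample_without_replacement(p, num_segs):
    """Sketch of SpkSet's per-speaker segment draw after this commit.

    `p` is the speaker's probability vector over its segments. Each draw
    zeroes the chosen entry; once everything has been drawn, the vector
    is reset to uniform instead of being divided by a zero sum.
    """
    segment_index = numpy.random.choice(num_segs, p=p)
    p[segment_index] = 0
    if numpy.sum(p) > 0:
        p = p / numpy.sum(p)       # renormalize the remaining mass
    else:
        p += 1 / num_segs          # all segments seen: reset to uniform
    return segment_index, p


# draws cycle through all four segments before any repeats
p = numpy.full(4, 0.25)
for _ in range(6):
    idx, p = sample_without_replacement(p, 4)
```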
```diff
@@ -465,16 +468,16 @@ class SpkSet(Dataset):
             start_frame = int(current_segment['start'] * self.sample_rate)
             stop_frame = int(current_segment['duration'] * self.sample_rate)
-            sig, sample_rate2 = torchaudio.load(f"{self.data_path}/{current_segment['file_id']}{self.data_file_extension}",
-                                                frame_offset=start_frame,
-                                                num_frames=stop_frame)
-            # sig, _ = soundfile.read(f"{self.data_path}/{current_segment['file_id']}{self.data_file_extension}",
-            #                         start=start_frame,
-            #                         stop=stop_frame,
-            #                         dtype=wav_type
-            #                         )
-            # sig = sig.astype(numpy.float32)
-            # sig += 0.0001 * numpy.random.randn(sig.shape[0])
+            # sig, sample_rate2 = torchaudio.load(f"{self.data_path}/{current_segment['file_id']}{self.data_file_extension}",
+            #                                     frame_offset=start_frame,
+            #                                     num_frames=stop_frame)
+            sig, _ = soundfile.read(f"{self.data_path}/{current_segment['file_id']}{self.data_file_extension}",
+                                    start=start_frame,
+                                    stop=stop_frame,
+                                    dtype=wav_type)
+            sig = sig.astype(numpy.float32)
+            sig += 0.0001 * numpy.random.randn(sig.shape[0])
         speaker_idx = self._spk_dict[current_speaker]["speaker_idx"]
...
```
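The loader returns to soundfile.read, whose start/stop arguments are sample indices into the file (torchaudio.load takes frame_offset/num_frames instead), and adds a small dither to the signal. A minimal standalone sketch, assuming wav_type is a dtype string such as 'float32':

```python
import numpy
import soundfile


def load_segment(path, start_frame, stop_frame, wav_type="float32"):
    # read only the requested sample range, as in the new SpkSet code
    sig, _ = soundfile.read(path, start=start_frame, stop=stop_frame,
                            dtype=wav_type)
    sig = sig.astype(numpy.float32)
    # tiny additive noise, e.g. to avoid exactly-zero samples downstream
    sig += 0.0001 * numpy.random.randn(sig.shape[0])
    return sig
```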
nnet/xvector.py

```diff
@@ -424,25 +424,27 @@ class Xtractor(torch.nn.Module):
             self.input_nbdim = 2
             self.preprocessor = None
-            melkargs = dict()
-            melkargs['n_fft'] = 2048
-            melkargs['f_min'] = 133.333
-            melkargs['f_max'] = 6855.4976
-            melkargs['win_length'] = 400
-            melkargs['hop_length'] = 160
-            melkargs['window_fn'] = torch.hann_window
-            melkargs['power'] = 2
-            melkargs['n_mels'] = 81
-            self.sequence_network = PreFastResNet34(n_mfcc=80, melkargs=melkargs)
+            melkwargs = dict()
+            melkwargs['sample_rate'] = 16000
+            melkwargs['n_fft'] = 1024 #2048
+            melkwargs['f_min'] = 90 #133.333
+            melkwargs['f_max'] = 7600 #6855.4976
+            melkwargs['win_length'] = 1024 #400
+            melkwargs['hop_length'] = 256 #160
+            melkwargs['window_fn'] = torch.hann_window
+            #melkwargs['power'] = 2
+            melkwargs['n_mels'] = 80
+            self.sequence_network = PreFastResNet34(melkwargs=melkwargs)
+            self.embedding_size = 256
             self.before_speaker_embedding = torch.nn.Linear(in_features=2560,
-                                                            out_features=512)
+                                                            out_features=self.embedding_size)
             self.stat_pooling = MeanStdPooling()
             self.stat_pooling_weight_decay = 0
-            self.embedding_size = 512
             self.loss = "aam"
             self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
...
@@ -818,7 +820,7 @@ def xtrain(speaker_number,
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     # Start from scratch
-    if model_name is None and model_yaml in ["xvector", "rawnet2", "resnet34"]:
+    if model_name is None and model_yaml in ["xvector", "rawnet2", "resnet34", "fastresnet34"]:
         # Initialize a first model
         if model_yaml == "xvector":
             model = Xtractor(speaker_number, "xvector", loss=loss)
...
@@ -826,6 +828,8 @@ def xtrain(speaker_number,
             model = Xtractor(speaker_number, "rawnet2")
         elif model_yaml == "resnet34":
             model = Xtractor(speaker_number, "resnet34")
+        elif model_yaml == "fastresnet34":
+            model = Xtractor(speaker_number, "fastresnet34")
         model_archi = model_yaml
     else:
         with open(model_yaml, 'r') as fh:
...
```
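With "fastresnet34" now accepted by xtrain, the model can also be built directly; doing so uses the configuration the first hunk hard-codes (80 log-Mel bins, 256-dimensional embeddings, ArcMarginProduct loss). A minimal sketch; the import path and the speaker_number value are assumptions:

```python
import torch
from sidekit.nnet.xvector import Xtractor  # assumed import path

# Mel-spectrogram settings hard-coded for "fastresnet34" in the hunk above:
# sample_rate=16000, n_fft=1024, f_min=90, f_max=7600,
# win_length=1024, hop_length=256, window_fn=torch.hann_window, n_mels=80

# builds the PreFastResNet34 sequence network with 256-dim embeddings
# trained with an additive angular margin (ArcMarginProduct) head
model = Xtractor(1000, "fastresnet34")  # 1000 speakers: hypothetical value
```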
```diff
@@ -1009,18 +1013,18 @@ def xtrain(speaker_number,
     best_eer = 100
     curr_patience = patience
-    logging.critical("Compute EER before starting")
-    val_acc, val_loss, val_eer = cross_validation(model,
-                                                  validation_loader,
-                                                  device,
-                                                  [validation_set.__len__(),
-                                                   embedding_size],
-                                                  mixed_precision)
-    test_eer = test_metrics(model, device, speaker_number, num_thread, mixed_precision)
-    logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Initial metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
-    logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Initial metrics - Test EER = {test_eer * 100} %")
+    #logging.critical("Compute EER before starting")
+    #val_acc, val_loss, val_eer = cross_validation(model,
+    #                                              validation_loader,
+    #                                              device,
+    #                                              [validation_set.__len__(),
+    #                                               embedding_size],
+    #                                              mixed_precision)
+    #test_eer = test_metrics(model, device, speaker_number, num_thread, mixed_precision)
+    test_eer = 100
+    #logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Initial metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
+    #logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Initial metrics - Test EER = {test_eer * 100} %")
     for epoch in range(1, epochs + 1):
         # Process one epoch and return the current model
...
@@ -1782,7 +1786,7 @@ def eer(negatives, positives):
         n_index = n_index - next_n_jump
         if next_p_jump == 0 and next_n_jump == 0:
             break
         p_score = positives[p_index]
         n_score = negatives[n_index]
         next_p_jump = next_p_jump // 2
...
```
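For context, eer() walks the sorted score lists with jump sizes that halve each iteration, stopping when both jumps reach zero. A brute-force reference for the same quantity, useful for sanity-checking; this is a generic computation, not a reconstruction of sidekit's search:

```python
import numpy


def eer_reference(negatives, positives):
    """Brute-force equal error rate: the operating point where the
    miss rate (rejected targets) equals the false-alarm rate
    (accepted impostors)."""
    thresholds = numpy.sort(numpy.concatenate([negatives, positives]))
    fnr = numpy.array([(positives < t).mean() for t in thresholds])
    fpr = numpy.array([(negatives >= t).mean() for t in thresholds])
    i = numpy.argmin(numpy.abs(fnr - fpr))
    return (fnr[i] + fpr[i]) / 2
```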