Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Anthony Larcher
sidekit
Commits
c11e4144
Commit
c11e4144
authored
Apr 13, 2021
by
Anthony Larcher
Browse files
merge and sidesampler
parent
acb426e0
Changes
9
Hide whitespace changes
Inline
Side-by-side
__init__.py
View file @
c11e4144
...
...
@@ -35,7 +35,7 @@ import sys
# Read environment variable if it exists
SIDEKIT_CONFIG
=
{
"libsvm"
:
Tru
e
,
SIDEKIT_CONFIG
=
{
"libsvm"
:
Fals
e
,
"mpi"
:
False
,
"cuda"
:
True
}
...
...
iv_scoring.py
View file @
c11e4144
...
...
@@ -103,10 +103,10 @@ def cosine_scoring(enroll, test, ndx, wccn=None, check_missing=True, device=None
device
=
torch
.
device
(
"cuda:0"
if
torch
.
cuda
.
is_available
()
and
s_size_in_bytes
<
3e9
else
"cpu"
)
else
:
device
=
device
if
torch
.
cuda
.
is_available
()
and
s_size_in_bytes
<
3e9
else
torch
.
device
(
"cpu"
)
s
=
torch
.
mm
(
torch
.
FloatTensor
(
enroll_copy
.
stat1
).
to
(
device
),
torch
.
FloatTensor
(
test_copy
.
stat1
).
to
(
device
).
T
).
cpu
().
numpy
()
score
=
Scores
()
score
.
scoremat
=
s
score
.
scoremat
=
torch
.
einsum
(
'ij,kj'
,
torch
.
FloatTensor
(
enroll_copy
.
stat1
).
to
(
device
),
torch
.
FloatTensor
(
test_copy
.
stat1
).
to
(
device
)).
cpu
().
numpy
()
score
.
modelset
=
clean_ndx
.
modelset
score
.
segset
=
clean_ndx
.
segset
score
.
scoremask
=
clean_ndx
.
trialmask
...
...
nnet/__init__.py
View file @
c11e4144
...
...
@@ -33,9 +33,16 @@ from .feed_forward import kaldi_to_hdf5
from
.xsets
import
IdMapSetPerSpeaker
from
.xsets
import
SideSet
from
.xsets
import
SideSampler
from
.xvector
import
Xtractor
,
xtrain
,
extract_embeddings
,
extract_sliding_embedding
,
MeanStdPooling
from
.res_net
import
ResBlock
,
PreResNet34
from
.rawnet
import
prepare_voxceleb1
,
Vox1Set
,
PreEmphasis
from
.xvector
import
Xtractor
from
.xvector
import
xtrain
from
.xvector
import
extract_embeddings
from
.xvector
import
extract_sliding_embedding
from
.pooling
import
MeanStdPooling
from
.pooling
import
AttentivePooling
from
.pooling
import
GruPooling
from
.res_net
import
ResBlock
from
.res_net
import
PreResNet34
from
.res_net
import
PreFastResNet34
from
.sincnet
import
SincNet
from
.preprocessor
import
RawPreprocessor
from
.preprocessor
import
MfccFrontEnd
...
...
nnet/augmentation.py
View file @
c11e4144
...
...
@@ -164,26 +164,6 @@ def data_augmentation(speech,
aug_idx
=
random
.
sample
(
range
(
len
(
transform_dict
.
keys
())),
k
=
transform_number
)
augmentations
=
numpy
.
array
(
list
(
transform_dict
.
keys
()))[
aug_idx
]
if
"phone_filtering"
in
augmentations
:
speech
,
sample_rate
=
torchaudio
.
sox_effects
.
apply_effects_tensor
(
speech
,
sample_rate
,
effects
=
[
[
"lowpass"
,
"4000"
],
[
"compand"
,
"0.02,0.05"
,
"-60,-60,-30,-10,-20,-8,-5,-8,-2,-8"
,
"-8"
,
"-7"
,
"0.05"
],
[
"rate"
,
"16000"
],
])
if
"filtering"
in
augmentations
:
effects
=
[
[
"bandpass"
,
"2000"
,
"3500"
],
[
"bandstop"
,
"200"
,
"500"
]]
speech
,
sample_rate
=
torchaudio
.
sox_eefects
.
apply_effects_tensor
(
speech
,
sample_rate
,
effects
=
[
effects
[
random
.
randint
(
0
,
1
)]],
)
if
"stretch"
in
augmentations
:
strech
=
torchaudio
.
functional
.
TimeStretch
()
rate
=
random
.
uniform
(
0.8
,
1.2
)
...
...
@@ -242,6 +222,28 @@ def data_augmentation(speech,
scale
=
snr
*
noise_power
/
speech_power
speech
=
(
scale
*
speech
+
noise
)
/
2
if
"phone_filtering"
in
augmentations
:
final_shape
=
speech
.
shape
[
1
]
speech
,
sample_rate
=
torchaudio
.
sox_effects
.
apply_effects_tensor
(
speech
,
sample_rate
,
effects
=
[
[
"lowpass"
,
"4000"
],
[
"compand"
,
"0.02,0.05"
,
"-60,-60,-30,-10,-20,-8,-5,-8,-2,-8"
,
"-8"
,
"-7"
,
"0.05"
],
[
"rate"
,
"16000"
],
])
speech
=
speech
[:,
:
final_shape
]
if
"filtering"
in
augmentations
:
effects
=
[
[
"bandpass"
,
"2000"
,
"3500"
],
[
"bandstop"
,
"200"
,
"500"
]]
speech
,
sample_rate
=
torchaudio
.
sox_eefects
.
apply_effects_tensor
(
speech
,
sample_rate
,
effects
=
[
effects
[
random
.
randint
(
0
,
1
)]],
)
if
"codec"
in
augmentations
:
final_shape
=
speech
.
shape
[
1
]
configs
=
[
...
...
@@ -273,7 +275,8 @@ def load_noise_seg(noise_row, speech_shape, sample_rate, data_path):
if
noise_duration
*
sample_rate
>
speech_shape
[
1
]:
# It is recommended to split noise files (especially speech noise type) in shorter subfiles
# When frame_offset is too high, loading the segment can take much longer
frame_offset
=
random
.
randrange
(
noise_start
*
sample_rate
,
int
((
noise_start
+
noise_duration
)
*
sample_rate
-
speech_shape
[
1
]))
frame_offset
=
random
.
randrange
(
noise_start
*
sample_rate
,
int
((
noise_start
+
noise_duration
)
*
sample_rate
-
speech_shape
[
1
]))
else
:
frame_offset
=
noise_start
*
sample_rate
...
...
@@ -281,10 +284,10 @@ def load_noise_seg(noise_row, speech_shape, sample_rate, data_path):
if
noise_duration
*
sample_rate
>
speech_shape
[
1
]:
noise_seg
,
noise_sr
=
torchaudio
.
load
(
noise_fn
,
frame_offset
=
int
(
frame_offset
),
num_frames
=
int
(
speech_shape
[
1
]))
else
:
noise_seg
,
noise_sr
=
torchaudio
.
load
(
noise_fn
,
frame_offset
=
int
(
frame_offset
),
num_frames
=
int
(
noise_duration
*
sample_rate
))
noise_seg
,
noise_sr
=
torchaudio
.
load
(
noise_fn
,
frame_offset
=
int
(
frame_offset
),
num_frames
=
int
(
noise_duration
*
sample_rate
))
assert
noise_sr
==
sample_rate
#if numpy.random.randint(0, 2) == 1:
# noise = torch.flip(noise, dims=[0, 1])
if
noise_seg
.
shape
[
1
]
<
speech_shape
[
1
]:
noise_seg
=
torch
.
tensor
(
numpy
.
resize
(
noise_seg
.
numpy
(),
speech_shape
))
...
...
nnet/loss.py
View file @
c11e4144
...
...
@@ -304,5 +304,62 @@ class SoftmaxAngularProto(torch.nn.Module):
cos_sim_matrix
=
torch
.
nn
.
functional
.
cosine_similarity
(
out_positive
.
unsqueeze
(
-
1
),
out_anchor
.
unsqueeze
(
-
1
).
transpose
(
0
,
2
))
torch
.
clamp
(
self
.
w
,
1e-6
)
cos_sim_matrix
=
cos_sim_matrix
*
self
.
w
+
self
.
b
loss
=
self
.
criterion
(
cos_sim_matrix
,
torch
.
arange
(
0
,
cos_sim_matrix
.
shape
[
0
],
device
=
x
.
device
))
+
self
.
criterion
(
cce_prediction
,
target
)
return
loss
,
cce_prediction
return
cos_sim_matrix
,
cce_prediction
class AngularProximityMagnet(torch.nn.Module):
    """
    Angular-prototypical loss combined with a binary "magnet" term.

    Adapted from
    https://github.com/clovaai/voxceleb_trainer/blob/3bfd557fab5a3e6cd59d717f5029b3a20d22a281/loss/angleproto.py

    The module owns a linear classification head (``cce_backend``) whose
    predictions are always returned next to the loss, a learnable scale/bias
    pair (``w``, ``b1``) applied to the cosine-similarity matrix, and a second
    learnable bias (``b2``) applied to the dot-product matrix fed to the
    binary cross-entropy criterion.
    """

    def __init__(self, spk_count, emb_dim=256, batch_size=512, init_w=10.0, init_b=-5.0, **kwargs):
        """
        :param spk_count: number of output classes of the linear classification head
        :param emb_dim: size of the input embeddings
        :param batch_size: accepted for interface compatibility, not used by this module
        :param init_w: initial value of the learnable similarity scale
        :param init_b: initial value of the learnable similarity bias
        :param kwargs: ignored
        """
        super(AngularProximityMagnet, self).__init__()

        self.test_normalize = True

        self.w = torch.nn.Parameter(torch.tensor(init_w))
        self.b1 = torch.nn.Parameter(torch.tensor(init_b))
        self.b2 = torch.nn.Parameter(torch.tensor(+5.54))

        #last_linear = torch.nn.Linear(512, 1)
        #last_linear.bias.data += 1

        #self.magnitude = torch.nn.Sequential(OrderedDict([
        #    ("linear9", torch.nn.Linear(emb_dim, 512)),
        #    ("relu9", torch.nn.ReLU()),
        #    ("linear10", torch.nn.Linear(512, 512)),
        #    ("relu10", torch.nn.ReLU()),
        #    ("linear11", last_linear),
        #    ("relu11", torch.nn.ReLU())
        #    ]))

        self.cce_backend = torch.nn.Sequential(OrderedDict([
            ("linear8", torch.nn.Linear(emb_dim, spk_count))
        ]))

        self.criterion = torch.nn.CrossEntropyLoss()
        self.magnet_criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')

    def forward(self, x, target=None):
        """
        Compute the classification predictions and, when a target is given, the batch loss.

        :param x: 2-D tensor of embeddings; rows are regrouped two by two
            (``reshape(-1, 2, dim)``), so the number of rows must be even
        :param target: only tested against ``None`` to choose between inference
            and training behaviour; its values are not used in the loss computed here
        :return: ``(x, cce_prediction)`` when ``target`` is ``None``,
            ``(batch_loss, cce_prediction)`` otherwise
        """
        assert x.size()[1] >= 2

        cce_prediction = self.cce_backend(x)

        #x = self.magnitude(x) * torch.nn.functional.normalize(x)

        if target is None:
            return x, cce_prediction

        # Group consecutive rows in pairs: index 0 is the "positive",
        # the remaining row(s) are averaged into the "anchor".
        x = x.reshape(-1, 2, x.size()[-1]).squeeze(1)
        out_anchor = torch.mean(x[:, 1:, :], 1)
        out_positive = x[:, 0, :]
        pair_count = int(out_positive.shape[0])

        # All-pairs cosine similarity between positives and anchors
        ap_sim_matrix = torch.nn.functional.cosine_similarity(out_positive.unsqueeze(-1),
                                                              out_anchor.unsqueeze(-1).transpose(0, 2))
        # torch.clamp is not in-place: its return value has to be used,
        # otherwise the scale is never actually kept above 1e-6.
        scale = torch.clamp(self.w, 1e-6)
        ap_sim_matrix = ap_sim_matrix * scale + self.b1

        # Build index tensors on the device of the input rather than on a
        # hard-coded "cuda:0", so the loss also works on CPU and on other GPUs.
        labels = torch.arange(0, pair_count, device=x.device)

        # Raw dot products between positives and anchors (no normalisation here)
        cos_sim_matrix = torch.mm(out_positive, out_anchor.T)
        cos_sim_matrix = cos_sim_matrix + self.b2
        # Constant offset log((1/N) / (1 - 1/N)); the expression divides by zero when N == 1
        cos_sim_matrix = cos_sim_matrix + numpy.log(1 / pair_count / (1 - 1 / pair_count))

        # 1.0 where row index == column index, 0.0 elsewhere
        mask = (labels.unsqueeze(1) == labels.unsqueeze(0)).float()

        batch_loss = self.criterion(ap_sim_matrix, labels) \
            + self.magnet_criterion(cos_sim_matrix.flatten().unsqueeze(1), mask.flatten().unsqueeze(1))
        return batch_loss, cce_prediction
nnet/preprocessor.py
View file @
c11e4144
...
...
@@ -162,9 +162,9 @@ class MelSpecFrontEnd(torch.nn.Module):
n_fft
=
1024
,
f_min
=
90
,
f_max
=
7600
,
win_length
=
400
,
win_length
=
1024
,
window_fn
=
torch
.
hann_window
,
hop_length
=
160
,
hop_length
=
256
,
power
=
2.0
,
n_mels
=
80
):
...
...
@@ -227,7 +227,6 @@ class MelSpecFrontEnd(torch.nn.Module):
return
out
class
RawPreprocessor
(
torch
.
nn
.
Module
):
"""
...
...
nnet/res_net.py
View file @
c11e4144
...
...
@@ -268,7 +268,28 @@ class ResBlock(torch.nn.Module):
return
out
class SELayer(torch.nn.Module):
    """
    Channel gating layer: each channel of a 4-D input is rescaled by a weight
    in (0, 1) computed from the spatial average of all channels.
    """

    def __init__(self, channel, reduction=16):
        """
        :param channel: number of channels of the input tensor
        :param reduction: divisor applied to ``channel`` to size the hidden layer of the gate
        """
        super(SELayer, self).__init__()
        bottleneck = channel // reduction

        # Global average over the two trailing dimensions
        self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)

        # Bias-free two-layer gate: channel -> bottleneck -> channel, squashed by a sigmoid
        gate_layers = [
            torch.nn.Linear(channel, bottleneck, bias=False),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(bottleneck, channel, bias=False),
            torch.nn.Sigmoid(),
        ]
        self.fc = torch.nn.Sequential(*gate_layers)

    def forward(self, x):
        """
        :param x: 4-D tensor (the size is unpacked into exactly four values)
        :return: ``x`` multiplied channel-wise by the computed gate, same shape as ``x``
        """
        batch, channels, _height, _width = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)
class
BasicBlock
(
torch
.
nn
.
Module
):
"""
"""
expansion
=
1
def
__init__
(
self
,
in_planes
,
planes
,
stride
=
1
):
...
...
@@ -280,6 +301,8 @@ class BasicBlock(torch.nn.Module):
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
bn2
=
torch
.
nn
.
BatchNorm2d
(
planes
)
self
.
se
=
SELayer
(
planes
)
self
.
shortcut
=
torch
.
nn
.
Sequential
()
if
stride
!=
1
or
in_planes
!=
self
.
expansion
*
planes
:
self
.
shortcut
=
torch
.
nn
.
Sequential
(
...
...
@@ -291,6 +314,7 @@ class BasicBlock(torch.nn.Module):
def forward(self, x):
    """
    Apply the block to ``x``.

    The main path is conv1 -> bn1 -> relu -> conv2 -> bn2 -> se; the output of
    ``self.shortcut(x)`` is then added to it and a final relu is applied.

    :param x: input tensor, given to both the main path and the shortcut
    :return: the activated sum of the main path and the shortcut
    """
    relu = torch.nn.functional.relu

    hidden = self.conv1(x)
    hidden = relu(self.bn1(hidden))
    hidden = self.conv2(hidden)
    hidden = self.bn2(hidden)
    hidden = self.se(hidden)

    # Accumulate the shortcut in place, then activate
    hidden += self.shortcut(x)
    return relu(hidden)
...
...
@@ -463,11 +487,13 @@ class PreFastResNet34(torch.nn.Module):
def forward(self, x):
    """
    Run the network on ``x``.

    A singleton dimension is inserted at position 1, the tensor is switched to
    the channels-last memory format for the convolutional stages, switched
    back to the default contiguous format, and dimensions 1 and 2 are finally
    merged.

    :param x: input tensor; it must be 3-D so that it is 4-D after ``unsqueeze(1)``,
        as required by the channels-last memory format
    :return: output of the last stage with dimensions 1 and 2 flattened together
    """
    out = x.unsqueeze(1).contiguous(memory_format=torch.channels_last)
    out = torch.nn.functional.relu(self.bn1(self.conv1(out)))

    for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
        out = stage(out)

    out = out.contiguous(memory_format=torch.contiguous_format)
    return torch.flatten(out, start_dim=1, end_dim=2)
...
...
@@ -475,119 +501,3 @@ class PreFastResNet34(torch.nn.Module):
def ResNet34():
    """
    Build a ``ResNet`` made of ``BasicBlock`` units.

    :return: ``ResNet(BasicBlock, [3, 1, 3, 1, 5, 1, 2])``
    """
    block_counts = [3, 1, 3, 1, 5, 1, 2]
    return ResNet(BasicBlock, block_counts)
def restrain(args):
    """
    Initialize and train a ResNet for Speaker Recognition.

    A ResNet18 is created and its state dictionary is saved to the file
    "initial_model"; then, for each epoch, one training pass and one
    cross-validation pass are run and ``args.lr`` is multiplied by 0.9.

    :param args: object providing at least ``class_number``, ``epochs`` and ``lr``;
        it is also forwarded to the training and cross-validation functions
    :return: nothing
    """
    # Initialize a first model and save to disk
    model = ResNet18(args.class_number,
                     entry_conv_kernel_size=(7, 7),
                     entry_conv_out_channels=64,
                     megablock_out_channels=(64, 128, 128, 128),
                     megablock_size=(2, 2, 2, 2),
                     block_type=ResBlock)

    current_model_file_name = "initial_model"
    torch.save(model.state_dict(), current_model_file_name)

    for epoch in range(1, args.epochs + 1):
        # NOTE(review): train_resnet_epoch as defined in this file takes
        # (model, epoch, train_seg_df, speaker_dict, args) and returns a model,
        # not a file name -- this call does not match that signature; confirm.
        current_model_file_name = train_resnet_epoch(epoch, args, current_model_file_name)

        # Add the cross validation here
        # NOTE(review): resnet_cross_validation as defined in this file takes
        # (args, model, cv_seg_df, speaker_dict) -- two arguments are missing here; confirm.
        accuracy = resnet_cross_validation(args, current_model_file_name)
        print("*** Cross validation accuracy = {} %".format(accuracy))

        # Decrease learning rate after every epoch
        # NOTE(review): train_resnet_epoch builds its Adam optimizer without reading
        # args.lr, so this decay does not appear to reach the optimizer; confirm.
        args.lr = args.lr * 0.9
        print(" Decrease learning rate: {}".format(args.lr))
def train_resnet_epoch(model, epoch, train_seg_df, speaker_dict, args):
    """
    Train ``model`` for one pass over the training set and return it.

    :param model: network to train; it is called on batches moved to "cuda:0"
        but is not moved to that device by this function
    :param epoch: epoch index, only used in the log message
    :param train_seg_df: description of the training segments, passed to ``VoxDataset``
    :param speaker_dict: passed to ``VoxDataset`` together with ``train_seg_df``
    :param args: object providing ``seed``, ``train_transformation``, ``spec_aug``,
        ``temp_aug``, ``batch_size`` and ``log_interval``
    :return: the trained model
    """
    device = torch.device("cuda:0")

    torch.manual_seed(args.seed)

    # Build the list of transformations from the comma-separated description
    train_transform = []
    if not args.train_transformation == '':
        trans = args.train_transformation.split(',')
        for t in trans:
            if "CMVN" in t:
                train_transform.append(CMVN())
            if "FrequencyMask" in t:
                # NOTE(review): t comes from a split on ',' and so contains no comma;
                # t.split(",")[1] would then raise IndexError -- confirm the expected format.
                a = t.split(",")[0].split("(")[1]
                b = t.split(",")[1].split("(")[0]
                train_transform.append(FrequencyMask(a, b))
            if "TemporalMask" in t:
                a = t.split(",")[0].split("(")[1]
                # NOTE(review): b is only assigned in the FrequencyMask branch above; it is
                # unbound (or stale) here -- confirm what TemporalMask should receive.
                train_transform.append(TemporalMask(a, b))

    # presumably 500 is a segment length -- TODO confirm against VoxDataset
    train_set = VoxDataset(train_seg_df,
                           speaker_dict,
                           500,
                           transform=transforms.Compose(train_transform),
                           spec_aug_ratio=args.spec_aug,
                           temp_aug_ratio=args.temp_aug)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=15)

    # Adam is created with its default hyper-parameters (args.lr is not read here)
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.CrossEntropyLoss()

    # Running count of correctly classified samples
    accuracy = 0.0
    for batch_idx, (data, target, _, __) in enumerate(train_loader):
        target = target.squeeze()
        optimizer.zero_grad()
        output = model(data.to(device))
        loss = criterion(output, target.to(device))
        loss.backward()
        optimizer.step()
        accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()

        if batch_idx % args.log_interval == 0:
            # The logged accuracy assumes every batch seen so far held args.batch_size samples
            logging.critical('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                epoch, batch_idx + 1, train_loader.__len__(),
                100. * batch_idx / train_loader.__len__(), loss.item(),
                100.0 * accuracy.item() / ((batch_idx + 1) * args.batch_size)))
    return model
def resnet_cross_validation(args, model, cv_seg_df, speaker_dict):
    """
    Compute the classification accuracy of ``model`` on the cross-validation set.

    :param args: object providing ``cv_transformation``, ``spec_aug``, ``temp_aug``
        and ``batch_size``
    :param model: network to evaluate; it is switched to evaluation mode and moved to "cuda:0"
    :param cv_seg_df: description of the cross-validation segments, passed to ``VoxDataset``
    :param speaker_dict: passed to ``VoxDataset`` together with ``cv_seg_df``
    :return: accuracy in percent, computed as correct predictions over
        ``(number of batches) * args.batch_size``
    """
    # Build the list of transformations from the comma-separated description
    cv_transform = []
    if not args.cv_transformation == '':
        trans = args.cv_transformation.split(',')
        for t in trans:
            if "CMVN" in t:
                cv_transform.append(CMVN())
            if "FrequencyMask" in t:
                # NOTE(review): t comes from a split on ',' and so contains no comma;
                # t.split(",")[1] would then raise IndexError -- confirm the expected format.
                a = t.split(",")[0].split("(")[1]
                b = t.split(",")[1].split("(")[0]
                cv_transform.append(FrequencyMask(a, b))
            if "TemporalMask" in t:
                a = t.split(",")[0].split("(")[1]
                # NOTE(review): b is only assigned in the FrequencyMask branch above; it is
                # unbound (or stale) here -- confirm what TemporalMask should receive.
                cv_transform.append(TemporalMask(a, b))

    # presumably 500 is a segment length -- TODO confirm against VoxDataset
    cv_set = VoxDataset(cv_seg_df,
                        speaker_dict,
                        500,
                        transform=transforms.Compose(cv_transform),
                        spec_aug_ratio=args.spec_aug,
                        temp_aug_ratio=args.temp_aug)
    cv_loader = DataLoader(cv_set, batch_size=args.batch_size, shuffle=False, num_workers=15)

    model.eval()
    device = torch.device("cuda:0")
    model.to(device)

    # Running count of correctly classified samples
    accuracy = 0.0
    print(cv_set.__len__())
    # NOTE(review): the loop is not wrapped in torch.no_grad(); confirm whether
    # gradient tracking is intended during evaluation.
    for batch_idx, (data, target, _, __) in enumerate(cv_loader):
        target = target.squeeze()
        output = model(data.to(device))
        accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()

    # NOTE(review): the denominator assumes every batch is full; if the last batch is
    # smaller the accuracy is under-estimated. batch_idx is unbound if the loader is empty.
    return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * args.batch_size)
nnet/xsets.py
View file @
c11e4144
...
...
@@ -389,7 +389,6 @@ class IdMapSet(Dataset):
if
"add_noise"
in
self
.
transformation
:
# Load the noise dataset, filter according to the duration
noise_df
=
pandas
.
read_csv
(
self
.
transformation
[
"add_noise"
][
"noise_db_csv"
])
#tmp_df = noise_df.loc[noise_df['duration'] > self.duration]
self
.
noise_df
=
noise_df
.
set_index
(
noise_df
.
type
)
self
.
rir_df
=
None
...
...
nnet/xvector.py
View file @
c11e4144
...
...
@@ -30,13 +30,10 @@ import logging
import
math
import
os
import
numpy
import
random
import
pandas
import
pickle
import
shutil
import
tabulate
import
time
import
torch
import
torchaudio
import
tqdm
import
yaml
...
...
@@ -64,9 +61,11 @@ from ..statserver import StatServer
from
..iv_scoring
import
cosine_scoring
from
.sincnet
import
SincNet
from
..bosaris.detplot
import
rocch
,
rocch2eer
from
.loss
import
SoftmaxAngularProto
,
ArcLinear
from
.loss
import
SoftmaxAngularProto
from
.loss
import
l2_norm
from
.loss
import
ArcMarginProduct
from
.loss
import
ArcLinear
from
.loss
import
AngularProximityMagnet
os
.
environ
[
'MKL_THREADING_LAYER'
]
=
'GNU'
...
...
@@ -80,17 +79,25 @@ __status__ = "Production"
__docformat__
=
'reS'
def
ee
r
(
negatives
,
positives
):
"""
Logarithmic complexity EER computation
def seed_worker(worker_id=None):
    """
    Seed the ``numpy`` and ``random`` generators of a DataLoader worker.

    The seed is derived from ``torch.initial_seed()`` reduced modulo 2**32.
    The function is meant to be passed as ``worker_init_fn`` to a
    ``DataLoader``, which calls it with the worker index as its single
    argument; the parameter is therefore required for that use and defaults
    to ``None`` so that a call without argument still works.

    :param worker_id: index of the worker, supplied by the DataLoader; not used
    :return: nothing
    """
    worker_seed = torch.initial_seed() % 2 ** 32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)
Returns:
float: Equal Error Rate (EER)
def
eer
(
negatives
,
positives
):
"""
Logarithmic complexity EER computation
:param negatives: negative_scores (numpy array): impostor scores
:param positives: positive_scores (numpy array): genuine scores
:return: float: Equal Error Rate (EER)
"""
positives
=
numpy
.
sort
(
positives
)
negatives
=
numpy
.
sort
(
negatives
)[::
-
1
]
...
...
@@ -234,7 +241,6 @@ def test_metrics(model,
device
=
device
).
get_tar_non
(
Key
(
data_opts
[
"test"
][
"key"
]))
#test_eer = eer(numpy.array(non).astype(numpy.double), numpy.array(tar).astype(numpy.double))
pmiss
,
pfa
=
rocch
(
tar
,
non
)
return
rocch2eer
(
pmiss
,
pfa
)
...
...
@@ -476,7 +482,7 @@ class Xtractor(torch.nn.Module):
self
.
before_speaker_embedding
=
torch
.
nn
.
Linear
(
in_features
=
2560
,
out_features
=
self
.
embedding_size
)
self
.
stat_pooling
=
MeanStdPooling
(
)
self
.
stat_pooling
=
AttentivePooling
(
128
,
80
,
global_context
=
False
)
self
.
stat_pooling_weight_decay
=
0
self
.
loss
=
loss
...
...
@@ -489,6 +495,8 @@ class Xtractor(torch.nn.Module):
elif
self
.
loss
==
'aps'
:
self
.
after_speaker_embedding
=
SoftmaxAngularProto
(
int
(
self
.
speaker_number
))
elif
self
.
loss
==
'smn'
:
self
.
after_speaker_embedding
=
AngularProximityMagnet
(
int
(
self
.
speaker_number
))
self
.
preprocessor_weight_decay
=
0.00002
self
.
sequence_network_weight_decay
=
0.00002
...
...
@@ -908,7 +916,9 @@ def update_training_dictionary(dataset_description,
# Initialize training options
training_opts
[
"log_file"
]
=
"sidekit.log"
training_opts
[
"seed"
]
=
42
training_opts
[
"numpy_seed"
]
=
0
training_opts
[
"torch_seed"
]
=
0
training_opts
[
"random_seed"
]
=
0
training_opts
[
"deterministic"
]
=
False
training_opts
[
"epochs"
]
=
100
training_opts
[
"lr"
]
=
1e-3
...
...
@@ -995,7 +1005,6 @@ def get_network(model_opts, local_rank):
if
local_rank
<
1
:
logging
.
info
(
model
)
logging
.
info
(
"Model_parameters_count: {:d}"
.
format
(
sum
(
p
.
numel
()
for
p
in
model
.
sequence_network
.
parameters
()
...
...
@@ -1010,7 +1019,7 @@ def get_network(model_opts, local_rank):
return
model
def
get_loaders
(
dataset_opts
,
training_opts
,
model_opts
):
def
get_loaders
(
dataset_opts
,
training_opts
,
model_opts
,
local_rank
=
0
):
"""
:param dataset_opts:
...
...
@@ -1055,19 +1064,28 @@ def get_loaders(dataset_opts, training_opts, model_opts):
assert
dataset_opts
[
"batch_size"
]
%
samples_per_speaker
==
0
batch_size
=
dataset_opts
[
"batch_size"
]
//
torch
.
cuda
.
device_count
()
side_sampler
=
SideSampler
(
training_set
.
sessions
[
'speaker_idx'
],
model_opts
[
"speaker_number"
],
dataset_opts
[
"train"
][
"sampler"
][
"examples_per_speaker"
],
dataset_opts
[
"train"
][
"sampler"
][
"samples_per_speaker"
],
dataset_opts
[
"batch_size"
])
side_sampler
=
SideSampler
(
data_source
=
training_set
.
sessions
[
'speaker_idx'
],
spk_count
=
model_opts
[
"speaker_number"
],
examples_per_speaker
=
dataset_opts
[
"train"
][
"sampler"
][
"examples_per_speaker"
],
samples_per_speaker
=
dataset_opts
[
"train"
][
"sampler"
][
"samples_per_speaker"
],
batch_size
=
batch_size
,
seed
=
training_opts
[
'torch_seed'
],
rank
=
local_rank
,
num_process
=
torch
.
cuda
.
device_count
(),
num_replicas
=
dataset_opts
[
"train"
][
"sampler"
][
"augmentation_replicas"
]
)
else
:
batch_size
=
dataset_opts
[
"batch_size"
]
side_sampler
=
SideSampler
(
training_set
.
sessions
[
'speaker_idx'
],
model_opts
[
"speaker_number"
],
samples_per_speaker
,
batch_size
,
batch_size
,
seed
=
dataset_opts
[
'seed'
])
side_sampler
=
SideSampler
(
data_source
=
training_set
.
sessions
[
'speaker_idx'
],
spk_count
=
model_opts
[
"speaker_number"
],
examples_per_speaker
=
dataset_opts
[
"train"
][
"sampler"
][
"examples_per_speaker"
],
samples_per_speaker
=
dataset_opts
[
"train"
][
"sampler"
][
"samples_per_speaker"
],
batch_size
=
batch_size
,
seed
=
training_opts
[
'torch_seed'
],
rank
=
0
,
num_process
=
torch
.
cuda
.
device_count
(),
num_replicas
=
dataset_opts
[
"train"
][
"sampler"
][
"augmentation_replicas"
]
)
training_loader
=
DataLoader
(
training_set
,
batch_size
=
batch_size
,
...
...
@@ -1076,14 +1094,16 @@ def get_loaders(dataset_opts, training_opts, model_opts):
pin_memory
=
True
,
sampler
=
side_sampler
,
num_workers
=
training_opts
[
"num_cpu"
],
persistent_workers
=
False
)
persistent_workers
=
False
,
worker_init_fn
=
seed_worker
)
validation_loader
=
DataLoader
(
validation_set
,
batch_size
=
batch_size
,
drop_last
=
False
,
pin_memory
=
True
,
num_workers
=
training_opts
[
"num_cpu"
],
persistent_workers
=
False
)
persistent_workers
=
False
,
worker_init_fn
=
seed_worker
)
# Compute indices for target and non-target trials once only to avoid recomputing for each epoch
classes
=
torch
.
ShortTensor
(
validation_set
.
sessions
[
'speaker_idx'
].
to_numpy
())
...
...
@@ -1161,12 +1181,12 @@ def get_optimizer(model, model_opts, train_opts):
elif
train_opts
[
"scheduler"
][
"type"
]
==
"StepLR"
:
scheduler
=
torch
.
optim
.
lr_scheduler
.
StepLR
(
optimizer
=
optimizer
,
step_size
=
2e3
,
step_size
=
1
*
training_loader
.
__len__
()
,
gamma
=
0.95
)
elif
train_opts
[
"scheduler"
][
"type"
]
==
"StepLR2"
:
scheduler
=
torch
.
optim
.
lr_scheduler
.
StepLR
(
optimizer
=
optimizer
,
step_size
=
2000
,
step_size
=
1
*
training_loader
.
__len__
()
,
gamma
=
0.5
)
else
:
scheduler
=
torch
.
optim
.
lr_scheduler
.
ReduceLROnPlateau
(
optimizer
=
optimizer
,
...
...
@@ -1249,9 +1269,9 @@ def xtrain(dataset_description,
torch
.
backends
.
cudnn
.
deterministic
=
True
# Set all the seeds
numpy
.
random
.
seed
(
training_opts
[
"seed"
])
# Set the random seed of numpy for the data split.
torch
.
manual_seed
(
training_opts
[
"seed"
])
torch
.
cuda
.
manual_seed
(
training_opts
[
"seed"
])
numpy
.
random
.
seed
(
training_opts
[
"
numpy_
seed"
])
# Set the random seed of numpy for the data split.
torch
.
manual_seed
(
training_opts
[
"
torch_
seed"
])
torch
.
cuda
.
manual_seed
(
training_opts
[
"
torch_
seed"
])
# Display the entire configurations as YAML dictionaries
if
local_rank
<
1
:
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment