Anthony Larcher / sidekit

Commit 5bf6d959, authored Jan 03, 2022 by Anthony Larcher

adding wavlm

Parent: c4684601
Changes: 10 files
__init__.py
@@ -4,7 +4,7 @@
#
# SIDEKIT is a python package for speaker verification.
# Home page: http://www-lium.univ-lemans.fr/sidekit/
#
@@ -50,8 +50,8 @@ if 'SIDEKIT' in os.environ:
        if val == "true":
            SIDEKIT_CONFIG["mpi"] = True
    if k == "cuda":
        if val == "false":
            SIDEKIT_CONFIG["cuda"] = False
        if val == "true":
            SIDEKIT_CONFIG["cuda"] = True

PARALLEL_MODULE = 'multiprocessing'  # can be threading or multiprocessing; MPI is planned in the future
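For reference, this block parses a comma-separated key=value string from the SIDEKIT environment variable; a minimal sketch using only the keys visible in this hunk (the variable must be set before the package is imported):

import os

# Enable CUDA support before importing the package; the parsing above then sets SIDEKIT_CONFIG["cuda"] = True.
os.environ["SIDEKIT"] = "cuda=true"

import sidekit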
nnet/augmentation.py
@@ -186,7 +186,7 @@ def data_augmentation(speech,
        rir_fn = transform_dict["add_reverb"]["data_path"] + rir_nfo  # TODO harmonize with noise
        rir, rir_fs = torchaudio.load(rir_fn)
        assert rir_fs == sample_rate
        # rir = rir[rir_nfo[1], :] #keep selected channel
        speech = torch.tensor(signal.convolve(speech, rir, mode='full')[:, :speech.shape[1]])

    if "add_noise" in augmentations:
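Taken on its own, the reverberation branch convolves the waveform with a room impulse response and trims the result back to the input length; a standalone sketch of the same operation (editorial; random tensors stand in for the files loaded with torchaudio in the real code):

import torch
from scipy import signal

speech = torch.randn(1, 32000)     # stand-in for a loaded 16 kHz utterance, shape [channels, samples]
rir = torch.randn(1, 4000)         # stand-in for a room impulse response at the same sample rate
# Convolve with the impulse response, then trim back to the original number of samples.
reverberated = torch.tensor(signal.convolve(speech, rir, mode='full')[:, :speech.shape[1]])
print(reverberated.shape)          # torch.Size([1, 32000])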
@@ -261,11 +261,10 @@ def data_augmentation(speech,
        )

    if "codec" in augmentations:
        final_shape = speech.shape[1]
        configs = [
            ({"format": "wav", "encoding": 'ULAW', "bits_per_sample": 8}, "8 bit mu-law"),
            ({"format": "wav", "encoding": 'ALAW', "bits_per_sample": 8}, "8 bit a-law"),
            ({"format": "gsm"}, "GSM-FR"),
            ({"format": "mp3", "compression": -9}, "MP3"),
            ({"format": "vorbis", "compression": -1}, "Vorbis")
        ]
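The configs entries line up with the keyword arguments of torchaudio.functional.apply_codec; the selection and trimming code sits outside this hunk, so the following is only a sketch of how one entry would likely be applied, reusing speech, sample_rate, configs and final_shape from above:

import random
import torchaudio

# Pick one codec configuration at random and apply it (a sketch, not the exact sidekit code).
param, title = random.choice(configs)
speech = torchaudio.functional.apply_codec(speech, sample_rate, **param)
# Codecs such as GSM or MP3 can change the number of samples, so trim back to the original length.
speech = speech[:, :final_shape]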
nnet/ecapa_tdnn.py (new file, mode 0 → 100644)
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as trans
#from .utils import UpstreamExpert
''' Res2Conv1d + BatchNorm1d + ReLU
'''
class Res2Conv1dReluBn(nn.Module):
    '''
    in_channels == out_channels == channels
    '''
    def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
        super().__init__()
        assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
        self.scale = scale
        self.width = channels // scale
        self.nums = scale if scale == 1 else scale - 1

        self.convs = []
        self.bns = []
        for i in range(self.nums):
            self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
            self.bns.append(nn.BatchNorm1d(self.width))
        self.convs = nn.ModuleList(self.convs)
        self.bns = nn.ModuleList(self.bns)

    def forward(self, x):
        out = []
        spx = torch.split(x, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            # Order: conv -> relu -> bn
            sp = self.convs[i](sp)
            sp = self.bns[i](F.relu(sp))
            out.append(sp)

        if self.scale != 1:
            out.append(spx[self.nums])
        out = torch.cat(out, dim=1)
        return out
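A quick shape check of the block above (an editorial sketch, not part of the committed file): channels must be divisible by scale, and with matching padding the frame axis is preserved.

import torch

layer = Res2Conv1dReluBn(channels=512, kernel_size=3, stride=1, padding=1, dilation=1, scale=8)
x = torch.randn(4, 512, 200)          # [batch, channels, frames]
print(layer(x).shape)                 # torch.Size([4, 512, 200])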
''' Conv1d + BatchNorm1d + ReLU
'''
class Conv1dReluBn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x)))
''' The SE connection of 1D case.
'''
class SE_Connect(nn.Module):
    def __init__(self, channels, se_bottleneck_dim=128):
        super().__init__()
        self.linear1 = nn.Linear(channels, se_bottleneck_dim)
        self.linear2 = nn.Linear(se_bottleneck_dim, channels)

    def forward(self, x):
        out = x.mean(dim=2)
        out = F.relu(self.linear1(out))
        out = torch.sigmoid(self.linear2(out))
        out = x * out.unsqueeze(2)
        return out
''' SE-Res2Block of the ECAPA-TDNN architecture.
'''
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
#     return nn.Sequential(
#         Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
#         Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
#         Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
#         SE_Connect(channels)
#     )
class SE_Res2Block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
        super().__init__()
        self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
        self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)

        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x):
        residual = x
        if self.shortcut:
            residual = self.shortcut(x)

        x = self.Conv1dReluBn1(x)
        x = self.Res2Conv1dReluBn(x)
        x = self.Conv1dReluBn2(x)
        x = self.SE_Connect(x)

        return x + residual
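Unlike the commented-out functional version, this class accepts different input and output channel counts and inserts a 1x1 shortcut when they differ; a small sketch exercising that path (editorial, values arbitrary):

import torch

# With 80 input and 512 output channels the shortcut projection keeps the residual sum shape-compatible.
block = SE_Res2Block(in_channels=80, out_channels=512, kernel_size=3, stride=1,
                     padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
x = torch.randn(4, 80, 200)           # [batch, channels, frames]
print(block(x).shape)                 # torch.Size([4, 512, 200])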
''' Attentive weighted mean and standard deviation pooling.
'''
class AttentiveStatsPool(nn.Module):
    def __init__(self, in_dim, attention_channels=128, global_context_att=False):
        super().__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
        if global_context_att:
            self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1)  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        if self.global_context_att:
            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            x_in = torch.cat((x, context_mean, context_std), dim=1)
        else:
            x_in = x

        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        alpha = torch.tanh(self.linear1(x_in))
        # alpha = F.relu(self.linear1(x_in))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
        std = torch.sqrt(residuals.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)
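The pooling block maps a [batch, channels, frames] tensor to a fixed [batch, 2 x channels] vector, the attention-weighted mean concatenated with the standard deviation; a minimal sketch (editorial, not part of the committed file; 1536 matches the last ECAPA channel count below):

import torch

pool = AttentiveStatsPool(in_dim=1536, attention_channels=128, global_context_att=False)
x = torch.randn(4, 1536, 200)         # [batch, channels, frames]
print(pool(x).shape)                  # torch.Size([4, 3072])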
class ECAPA_TDNN(nn.Module):
    def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
                 feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False,
                 config_path=None):
        super().__init__()

        self.feat_type = feat_type
        self.feature_selection = feature_selection
        self.update_extract = update_extract
        self.sr = sr

        if feat_type == "fbank" or feat_type == "mfcc":
            self.update_extract = False

        win_len = int(sr * 0.025)
        hop_len = int(sr * 0.01)

        if feat_type == 'fbank':
            self.feature_extract = trans.MelSpectrogram(sample_rate=sr, n_fft=512, win_length=win_len,
                                                        hop_length=hop_len, f_min=0.0, f_max=sr // 2,
                                                        pad=0, n_mels=feat_dim)
        elif feat_type == 'mfcc':
            melkwargs = {
                'n_fft': 512,
                'win_length': win_len,
                'hop_length': hop_len,
                'f_min': 0.0,
                'f_max': sr // 2,
                'pad': 0
            }
            self.feature_extract = trans.MFCC(sample_rate=sr, n_mfcc=feat_dim, log_mels=False,
                                              melkwargs=melkwargs)
        else:
            if config_path is None:
                self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
            #else:
            #    self.feature_extract = UpstreamExpert(config_path)
            if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
                self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
            if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
                self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False

            self.feat_num = self.get_feat_num()
            self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

        if feat_type != 'fbank' and feat_type != 'mfcc':
            freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
            for name, param in self.feature_extract.named_parameters():
                for freeze_val in freeze_list:
                    if freeze_val in name:
                        param.requires_grad = False
                        break

        if not self.update_extract:
            for param in self.feature_extract.parameters():
                param.requires_grad = False

        self.instance_norm = nn.InstanceNorm1d(feat_dim)
        # self.channels = [channels] * 4 + [channels * 3]
        self.channels = [channels] * 4 + [1536]

        self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
        self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
        self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
        self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)

        # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
        cat_channels = channels * 3
        self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
        #self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
        self.bn = nn.BatchNorm1d(self.channels[-1])
        #self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)

    def get_feat_num(self):
        self.feature_extract.eval()
        wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
        with torch.no_grad():
            features = self.feature_extract(wav)
        select_feature = features[self.feature_selection]
        if isinstance(select_feature, (list, tuple)):
            return len(select_feature)
        else:
            return 1

    def get_feat(self, x):
        if self.update_extract:
            x = self.feature_extract([sample for sample in x])
        else:
            with torch.no_grad():
                if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
                    x = self.feature_extract(x) + 1e-6  # B x feat_dim x time_len
                else:
                    x = self.feature_extract([sample for sample in x])

        if self.feat_type == 'fbank':
            x = x.log()

        if self.feat_type != "fbank" and self.feat_type != "mfcc":
            x = x[self.feature_selection]
            if isinstance(x, (list, tuple)):
                x = torch.stack(x, dim=0)
            else:
                x = x.unsqueeze(0)
            norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            x = (norm_weights * x).sum(dim=0)
            x = torch.transpose(x, 1, 2) + 1e-6

        x = self.instance_norm(x)
        return x

    def forward(self, x):
        #x = self.get_feat(x)

        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)

        out = torch.cat([out2, out3, out4], dim=1)
        out = self.bn(F.relu(self.conv(out)))
        #out = self.bn(self.pooling(out))
        #out = self.linear(out)

        return out
def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='fbank', sr=16000, feature_selection="hidden_states",
                     update_extract=False, config_path=None):
    return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
                      feat_type=feat_type, sr=sr, feature_selection=feature_selection,
                      update_extract=update_extract, config_path=config_path)
if __name__ == '__main__':
    x = torch.zeros(2, 32000)
    model = ECAPA_TDNN_SMALL(feat_dim=768, emb_dim=256, feat_type='hubert_base',
                             feature_selection="hidden_states", update_extract=False)

    out = model(x)
    # print(model)
    print(out.shape)
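Since the commit adds WavLM support, the same entry point can presumably be driven by the s3prl WavLM-Large upstream; the sketch below is editorial, feat_dim=1024 is an assumption (the WavLM-Large hidden size), and features are extracted explicitly because the get_feat call in forward() is commented out in this version.

import torch

# Hedged sketch, not part of the commit: ECAPA-TDNN on top of WavLM-Large features.
model = ECAPA_TDNN_SMALL(feat_dim=1024, emb_dim=256, feat_type='wavlm_large',
                         feature_selection="hidden_states", update_extract=False)
waveforms = [torch.randn(32000), torch.randn(32000)]   # raw 16 kHz utterances
feats = model.get_feat(waveforms)                       # [2, 1024, T] layer-weighted WavLM features
print(model(feats).shape)                               # [2, 1536, T]: frame-level map, pooling is commented out above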
nnet/loss.py
@@ -435,3 +435,38 @@ class AngularProximityMagnet(torch.nn.Module):
        batch_loss = self.criterion(ap_sim_matrix, torch.arange(0, int(out_positive.shape[0]), device=torch.device("cuda:0"))) \
                     + self.magnet_criterion(cos_sim_matrix.flatten().unsqueeze(1), mask.flatten().unsqueeze(1))
        return batch_loss, cce_prediction


class CircleMargin(torch.nn.Module):
    """
    """
    def __init__(self, in_features, out_features, s=256, m=0.25) -> None:
        super(CircleMargin, self).__init__()
        self.margin = m
        self.gamma = s
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, x, target=None):
        """

        :param x:
        :param target:
        :return:
        """
        cosine = torch.nn.functional.linear(torch.nn.functional.normalize(x),
                                            torch.nn.functional.normalize(self.weight))

        if target is None:
            return cosine * self.gamma

        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, target.view(-1, 1), 1)

        output = (one_hot * (self.margin ** 2 - (1 - cosine) ** 2)) + \
                 ((1.0 - one_hot) * (cosine ** 2 - self.margin ** 2))
        output = output * self.gamma

        return output, cosine * self.gamma
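The new CircleMargin head returns margin-adjusted logits together with the scaled cosine scores, so it can feed a standard cross-entropy criterion; a minimal training-step sketch (editorial; batch size, feature and class counts are made up):

import torch

head = CircleMargin(in_features=256, out_features=1000, s=256, m=0.25)
criterion = torch.nn.CrossEntropyLoss()

embeddings = torch.randn(8, 256)                  # speaker embeddings from the network
target = torch.randint(0, 1000, (8,))             # speaker labels

# The margin term raises the target-class score and lowers the others before scaling by gamma.
logits, scaled_cosine = head(embeddings, target)
loss = criterion(logits, target)
loss.backward()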
nnet/pooling.py
@@ -71,6 +71,9 @@ class MeanStdPooling(torch.nn.Module):
class ChannelWiseCorrPooling(torch.nn.Module):
    """
    """
    def __init__(self, in_channels=256, out_channels=64, in_freqs=10, channels_dropout=0.25):
        super(ChannelWiseCorrPooling, self).__init__()
        self.channels_dropout = channels_dropout
@@ -80,7 +83,7 @@ class ChannelWiseCorrPooling(torch.nn.Module):
        self.out_channels = out_channels
        self.out_dim = int(self.out_channels * (self.out_channels - 1) / 2) * self.groups
        self.L_proj = torch.nn.Conv2d(in_channels * self.groups, out_channels * self.groups, kernel_size=(1, 1), groups=self.groups)
        # self.L_proj = torch.nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        self.mask = torch.tril(torch.ones((out_channels, out_channels)), diagonal=-1).type(torch.BoolTensor)

    def forward(self, x):
@@ -94,34 +97,34 @@ class ChannelWiseCorrPooling(torch.nn.Module):
        self.mask = self.mask.to(x.device)
        if self.training:
            x *= torch.nn.functional.dropout(torch.ones((1, x.shape[1], 1, 1), device=x.device), p=self.channels_dropout)
        # [B, T, C, F]
        x = x.permute(0, 2, 1, 3)
        # [B, T, C, Fr, f]
        x = x.reshape(x.shape[0], x.shape[1], x.shape[-2], self.groups, self.merge_freqs_count)
        # [B, T, f, Fr, C]
        x = x.permute(0, 1, 4, 3, 2)
        # [B, T, f, Fr*C]
        x = x.flatten(start_dim=3, end_dim=4)
        # [B, Fr*C, T, f]
        x = x.permute(0, 3, 1, 2)
        # [B, Fr*C', T, f]
        x = self.L_proj(x)
        # [B, Fr, C', Tr]
        x = x.reshape(x.shape[0], self.groups, self.out_channels, -1)
        x -= torch.mean(x, axis=-1, keepdims=True)
        out = x / (torch.std(x, axis=-1, keepdims=True) + 1e-5)
        # [B, C', C']
        out = torch.einsum('abci,abdi->abcd', out, out)
        # [B, C'*(C'-1)/2]
        out = torch.masked_select(out, self.mask).reshape(batch_size, -1)
        out = out / num_locations
        return out


class AttentivePooling(torch.nn.Module):
    """
    Mean and Standard deviation attentive pooling
    """
-   def __init__(self, num_channels, n_mels, reduction=2, global_context=False):
+   def __init__(self, num_channels, num_freqs=10, attention_channels=128, global_context=False):
        """
        """
@@ -130,11 +133,11 @@ class AttentivePooling(torch.nn.Module):
        super(AttentivePooling, self).__init__()
        in_factor = 3 if global_context else 1
        self.attention = torch.nn.Sequential(
-           torch.nn.Conv1d(num_channels * (n_mels // 8) * in_factor, num_channels // reduction, kernel_size=1),
+           torch.nn.Conv1d(num_channels * num_freqs * in_factor, attention_channels, kernel_size=1),
            torch.nn.ReLU(),
-           torch.nn.BatchNorm1d(num_channels // reduction),
+           torch.nn.BatchNorm1d(attention_channels),
            torch.nn.Tanh(),
-           torch.nn.Conv1d(num_channels // reduction, num_channels * (n_mels // 8), kernel_size=1),
+           torch.nn.Conv1d(attention_channels, num_channels * num_freqs, kernel_size=1),
            torch.nn.Softmax(dim=2),
        )
        self.global_context = global_context
@@ -162,7 +165,7 @@ class AttentivePooling(torch.nn.Module):
        w = self.attention(x)
        mu = torch.sum(x * w, dim=2)
-       rh = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
+       rh = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-9))
        x = torch.cat((mu, rh), 1)
        x = x.view(x.size()[0], -1)
        return x
nnet/preprocessor.py
@@ -124,6 +124,91 @@ class MfccFrontEnd(torch.nn.Module):
        return mfcc


class WavLmFrontEnd(torch.nn.Module):
    """
    TODO: add the HOW TO...
    """

    def __init__(self):
        super(WavLmFrontEnd, self).__init__()
        self.feat_type = 'wavlm_large'
        self.feature_extract = torch.hub.load('s3prl/s3prl', self.feat_type)
        self.update_extract = False
        self.feature_selection = 'hidden_states'
        self.sr = 16000
        self.feat_num = self.get_feat_num()
        self.instance_norm = torch.nn.InstanceNorm1d(1024)
        self.feature_weight = torch.nn.Parameter(torch.zeros(self.feat_num))

        if self.feat_type != 'fbank' and self.feat_type != 'mfcc':
            freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
            for name, param in self.feature_extract.named_parameters():
                for freeze_val in freeze_list:
                    if freeze_val in name:
                        param.requires_grad = False
                        break

        if not self.update_extract:
            for param in self.feature_extract.parameters():
                param.requires_grad = False

    def get_feat_num(self):
        """

        :return:
        """
        self.feature_extract.eval()
        wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
        with torch.no_grad():
            features = self.feature_extract(wav)
        select_feature = features[self.feature_selection]
        if isinstance(select_feature, (list, tuple)):
            return len(select_feature)
        else:
            return 1

    def get_feat(self, x):
        """

        :param x:
        :return:
        """
        if self.update_extract:
            x = self.feature_extract([sample for sample in x])
        else:
            with torch.no_grad():
                if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
                    x = self.feature_extract(x) + 1e-6  # B x feat_dim x time_len
                else:
                    x = self.feature_extract([sample for sample in x])

        if self.feat_type == 'fbank':
            x = x.log()

        if self.feat_type != "fbank" and self.feat_type != "mfcc":
            x = x[self.feature_selection]
            if isinstance(x, (list, tuple)):
                x = torch.stack(x, dim=0)
            else:
                x = x.unsqueeze(0)
            norm_weights = torch.nn.functional.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            x = (norm_weights * x).sum(dim=0)
            x = torch.transpose(x, 1, 2) + 1e-6

        x = self.instance_norm(x)
        return x

    def forward(self, x, is_eval=False):
        """

        :param x:
        :param is_eval:
        :return:
        """
        return self.get_feat(x)
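A sketch of how the new front end would be called on a batch of raw waveforms (editorial; it assumes the s3prl WavLM-Large upstream is fetched through torch.hub on first use and exposes 1024-dimensional hidden states):

import torch

front_end = WavLmFrontEnd()                       # downloads the s3prl upstream the first time it runs
batch = torch.randn(2, 32000)                     # two 2-second waveforms sampled at 16 kHz
features = front_end(batch)                       # layer-weighted, instance-normalised WavLM features
print(features.shape)                             # expected: torch.Size([2, 1024, T])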
class MelSpecFrontEnd(torch.nn.Module):
    """
    Module that computes a Mel spectrogram from an audio signal
nnet/res_net.py
@@ -438,13 +438,17 @@ class PreResNet34(torch.nn.Module):
    """
    Network that contains only the ResNet part up to pooling, with NO classification layers
    """
-   def __init__(self, block=BasicBlock, num_blocks=[3, 1, 3, 1, 5, 1, 2], speaker_number=10):
+   def __init__(self, block=BasicBlock, num_blocks=(3, 1, 3, 1, 5, 1, 2), speaker_number=10):
        super(PreResNet34, self).__init__()
        self.in_planes = 128
        self.speaker_number = speaker_number
        self.conv1 = torch.nn.Conv2d(1, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(128)
@@ -457,7 +461,6 @@ class PreResNet34(torch.nn.Module):
        self.layer6 = self._make_layer(block, 256, num_blocks[5], stride=2)
        self.layer7 = self._make_layer(block, 256, num_blocks[5], stride=1)

    def _make_layer(self, block, planes, num_blocks, stride):
        """
@@ -498,13 +501,17 @@ class PreHalfResNet34(torch.nn.Module):
    """
    Network that contains only the ResNet part up to pooling, with NO classification layers
    """
    def __init__(self, block=BasicBlock, num_blocks=[3, 4, 6, 3], speaker_number=10):