Anthony Larcher / sidekit / Commits

Commit 51cc030f
authored Jan 03, 2022 by Anthony Larcher

adding wavlm

parent 5bf6d959
Changes 2
nnet/ecapa_tdnn.py
...
...
@@ -4,8 +4,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as trans
#from .utils import UpstreamExpert
from .pooling import AttentiveStatsPool


''' Res2Conv1d + BatchNorm1d + ReLU
'''
...
...
@@ -32,6 +32,11 @@ class Res2Conv1dReluBn(nn.Module):
        self.bns = nn.ModuleList(self.bns)

    def forward(self, x):
        """
        :param x:
        :return:
        """
        out = []
        spx = torch.split(x, self.width, 1)
        for i in range(self.nums):
...
...
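For context, the `torch.split(x, self.width, 1)` call in the hunk above slices the channel dimension into `scale` groups of `self.width` channels each, which the Res2Net block then processes hierarchically. A minimal sketch of that split, with illustrative sizes that are assumptions rather than values from this commit:

import torch

# Illustrative sizes only: 512 channels split into 8 groups of 64.
x = torch.randn(2, 512, 100)        # (batch, channels, frames)
width, scale = 64, 8
spx = torch.split(x, width, dim=1)  # tuple of `scale` tensors, each (2, 64, 100)
assert len(spx) == scale and spx[0].shape == (2, 64, 100)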
@@ -55,12 +60,20 @@ class Res2Conv1dReluBn(nn.Module):
class Conv1dReluBn(nn.Module):
    """
    """
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        """
        :param x:
        :return:
        """
        return self.bn(F.relu(self.conv(x)))
...
...
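A hedged usage sketch of the `Conv1dReluBn` block above (conv, then ReLU, then batch norm); the shapes below are illustrative assumptions, not values taken from this file:

import torch

layer = Conv1dReluBn(in_channels=80, out_channels=512, kernel_size=5, padding=2)
x = torch.randn(4, 80, 200)       # (batch, feat_dim, frames)
y = layer(x)                      # self.bn(F.relu(self.conv(x)))
assert y.shape == (4, 512, 200)   # padding=2 keeps the frame count for kernel_size=5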
@@ -69,12 +82,20 @@ class Conv1dReluBn(nn.Module):
class SE_Connect(nn.Module):
    """
    """
    def __init__(self, channels, se_bottleneck_dim=128):
        super().__init__()
        self.linear1 = nn.Linear(channels, se_bottleneck_dim)
        self.linear2 = nn.Linear(se_bottleneck_dim, channels)

    def forward(self, x):
        """
        :param x:
        :return:
        """
        out = x.mean(dim=2)
        out = F.relu(self.linear1(out))
        out = torch.sigmoid(self.linear2(out))
...
...
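The `SE_Connect.forward` above is cut off by the collapsed context, but it computes a per-channel sigmoid gate from the time-averaged activations. In a typical squeeze-and-excitation block that gate rescales every channel of the input; the sketch below shows that rescaling step as an assumption, not as the file's exact remaining line:

import torch

x = torch.randn(4, 512, 200)                 # (batch, channels, frames)
gate = torch.sigmoid(torch.randn(4, 512))    # stands in for the SE_Connect gate `out`
y = x * gate.unsqueeze(2)                    # broadcast the gate over the time axis
assert y.shape == x.shape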
@@ -97,6 +118,9 @@ class SE_Connect(nn.Module):
class SE_Res2Block(nn.Module):
    """
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
        super().__init__()
        self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
...
...
@@ -113,6 +137,11 @@ class SE_Res2Block(nn.Module):
        )

    def forward(self, x):
        """
        :param x:
        :return:
        """
        residual = x
        if self.shortcut:
            residual = self.shortcut(x)
...
...
@@ -125,44 +154,15 @@ class SE_Res2Block(nn.Module):
        return x + residual


''' Attentive weighted mean and standard deviation pooling.
'''
class AttentiveStatsPool(nn.Module):
    def __init__(self, in_dim, attention_channels=128, global_context_att=False):
        super().__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
        if global_context_att:
            self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1)  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        if self.global_context_att:
            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            x_in = torch.cat((x, context_mean, context_std), dim=1)
        else:
            x_in = x

        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        alpha = torch.tanh(self.linear1(x_in))
        # alpha = F.relu(self.linear1(x_in))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
        std = torch.sqrt(residuals.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)


class ECAPA_TDNN(nn.Module):
    def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False, feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
    def __init__(self, feat_dim=80, channels=512, feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
        super().__init__()

        self.feat_type = feat_type
...
...
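For reference, a minimal usage sketch of `AttentiveStatsPool` (which this commit appears to move into nnet/pooling.py and import from there): it maps a (batch, channels, frames) feature map to a (batch, 2 * channels) vector holding the attention-weighted mean and standard deviation. The sizes below are illustrative assumptions:

import torch

pool = AttentiveStatsPool(in_dim=1536, attention_channels=128)
x = torch.randn(4, 1536, 200)    # (batch, channels, frames)
stats = pool(x)                  # concatenated weighted mean and std
assert stats.shape == (4, 2 * 1536)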
@@ -232,8 +232,11 @@ class ECAPA_TDNN(nn.Module):
        self.bn = nn.BatchNorm1d(self.channels[-1])
        #self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)

    def get_feat_num(self):
        """
        :return:
        """
        self.feature_extract.eval()
        wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
        with torch.no_grad():
...
...
@@ -245,6 +248,11 @@ class ECAPA_TDNN(nn.Module):
            return 1

    def get_feat(self, x):
        """
        :param x:
        :return:
        """
        if self.update_extract:
            x = self.feature_extract([sample for sample in x])
        else:
...
...
@@ -271,24 +279,33 @@ class ECAPA_TDNN(nn.Module):
        return x

    def forward(self, x):
        #x = self.get_feat(x)
        """
        :param x:
        :return:
        """
        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)

        out = torch.cat([out2, out3, out4], dim=1)
        out = self.bn(F.relu(self.conv(out)))

        #out = self.bn(self.pooling(out))
        #out = self.linear(out)

        return out


def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
    return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim, feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)


def ECAPA_TDNN_SMALL(feat_dim, feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
    return ECAPA_TDNN(feat_dim=feat_dim, channels=512, feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)


if __name__ == '__main__':
    x = torch.zeros(2, 32000)
...
...
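A hedged usage sketch of the updated `ECAPA_TDNN_SMALL` factory above. Because `forward` no longer calls `get_feat` in this version, the input here is assumed to be a pre-computed (batch, feat_dim, frames) feature tensor; whether the truncated `__main__` block feeds the raw waveform through `get_feat` first is not visible in this diff:

import torch

# Sketch under the assumptions stated above, not the file's own test.
model = ECAPA_TDNN_SMALL(feat_dim=80, feat_type='fbank', sr=16000)
model.eval()

feats = torch.randn(2, 80, 200)   # assumed pre-computed fbank features
with torch.no_grad():
    out = model(feats)            # frame-level output; the pooling head is commented out here
print(out.shape)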
nnet/pooling.py
...
...
@@ -171,6 +171,44 @@ class AttentivePooling(torch.nn.Module):
        return x


class AttentiveStatsPool(torch.nn.Module):
    """
    Attentive weighted mean and standard deviation pooling.
    """
    def __init__(self, in_dim, attention_channels=128, global_context_att=False):
        super().__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
        if global_context_att:
            self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1)  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        """
        :param x:
        :return:
        """
        if self.global_context_att:
            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            x_in = torch.cat((x, context_mean, context_std), dim=1)
        else:
            x_in = x

        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        alpha = torch.tanh(self.linear1(x_in))
        # alpha = F.relu(self.linear1(x_in))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
        std = torch.sqrt(residuals.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)


class GruPooling(torch.nn.Module):
    """
    Pooling done by using a recurrent network
...
...
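The statistics computed in `AttentiveStatsPool.forward` follow the attentive-statistics formulation: with frame weights alpha that sum to one over time (softmax along dim=2), the pooled mean is sum_t alpha_t * h_t and the pooled standard deviation is sqrt(sum_t alpha_t * h_t^2 - mean^2). A small numeric check of that identity on a toy tensor, with illustrative values only:

import torch

torch.manual_seed(0)
h = torch.randn(1, 4, 10)                             # (batch, channels, frames)
alpha = torch.softmax(torch.randn(1, 4, 10), dim=2)   # weights sum to 1 over frames

mean = torch.sum(alpha * h, dim=2)
std = torch.sqrt((torch.sum(alpha * (h ** 2), dim=2) - mean ** 2).clamp(min=1e-9))

# Same quantity written as a weighted central second moment.
std_ref = torch.sqrt(torch.sum(alpha * (h - mean.unsqueeze(2)) ** 2, dim=2).clamp(min=1e-9))
assert torch.allclose(std, std_ref, atol=1e-5)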