Commit 51cc030f authored by Anthony Larcher's avatar Anthony Larcher
Browse files

adding wavlm

parent 5bf6d959
......@@ -4,8 +4,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as trans
#from .utils import UpstreamExpert
from .pooling import AttentiveStatsPool
''' Res2Conv1d + BatchNorm1d + ReLU
'''
......@@ -32,6 +32,11 @@ class Res2Conv1dReluBn(nn.Module):
self.bns = nn.ModuleList(self.bns)
def forward(self, x):
"""
:param x:
:return:
"""
out = []
spx = torch.split(x, self.width, 1)
for i in range(self.nums):
......@@ -55,12 +60,20 @@ class Res2Conv1dReluBn(nn.Module):
class Conv1dReluBn(nn.Module):
    """
    1-D convolution followed by ReLU activation and batch normalisation.

    Computes ``BatchNorm1d(ReLU(Conv1d(x)))`` on (batch, channels, time)
    tensors; the temporal length follows from the conv hyper-parameters.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        # Kept as two named sub-modules so parameters appear under the
        # usual ``conv.*`` / ``bn.*`` state_dict keys.
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        """
        :param x: input tensor of shape (batch, in_channels, time)
        :return: tensor of shape (batch, out_channels, time') after
                 conv -> ReLU -> batch-norm
        """
        activated = F.relu(self.conv(x))
        return self.bn(activated)
......@@ -69,12 +82,20 @@ class Conv1dReluBn(nn.Module):
class SE_Connect(nn.Module):
"""
"""
def __init__(self, channels, se_bottleneck_dim=128):
super().__init__()
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
def forward(self, x):
"""
:param x:
:return:
"""
out = x.mean(dim=2)
out = F.relu(self.linear1(out))
out = torch.sigmoid(self.linear2(out))
......@@ -97,6 +118,9 @@ class SE_Connect(nn.Module):
class SE_Res2Block(nn.Module):
"""
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
super().__init__()
self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
......@@ -113,6 +137,11 @@ class SE_Res2Block(nn.Module):
)
def forward(self, x):
"""
:param x:
:return:
"""
residual = x
if self.shortcut:
residual = self.shortcut(x)
......@@ -125,44 +154,15 @@ class SE_Res2Block(nn.Module):
return x + residual
''' Attentive weighted mean and standard deviation pooling.
'''
class AttentiveStatsPool(nn.Module):
    """
    Attentive weighted mean and standard deviation pooling.

    A two-layer bottleneck produces one attention weight per channel and
    frame; the attention-weighted mean and standard deviation over time
    are concatenated into a single (batch, 2 * in_dim) vector.
    """

    def __init__(self, in_dim, attention_channels=128, global_context_att=False):
        super().__init__()
        self.global_context_att = global_context_att
        # Conv1d with kernel_size=1 acts as a frame-wise Linear layer, so
        # (batch, channel, time) inputs need no transposition.
        att_in = in_dim * 3 if global_context_att else in_dim
        self.linear1 = nn.Conv1d(att_in, attention_channels, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        """
        :param x: tensor of shape (batch, in_dim, time)
        :return: tensor of shape (batch, 2 * in_dim) — attention-weighted
                 mean concatenated with attention-weighted std
        """
        if self.global_context_att:
            # Broadcast utterance-level statistics over time as extra channels.
            ctx_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            ctx_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            att_input = torch.cat((x, ctx_mean, ctx_std), dim=1)
        else:
            att_input = x

        # tanh rather than ReLU: ReLU here was found hard to converge.
        hidden = torch.tanh(self.linear1(att_input))
        alpha = torch.softmax(self.linear2(hidden), dim=2)

        mean = torch.sum(alpha * x, dim=2)
        # E[x^2] - E[x]^2; clamp guards the sqrt against tiny negative values.
        var = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
        std = torch.sqrt(var.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)
class ECAPA_TDNN(nn.Module):
def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
def __init__(self,
feat_dim=80,
channels=512,
feat_type='fbank',
sr=16000,
feature_selection="hidden_states",
update_extract=False,
config_path=None):
super().__init__()
self.feat_type = feat_type
......@@ -232,8 +232,11 @@ class ECAPA_TDNN(nn.Module):
self.bn = nn.BatchNorm1d(self.channels[-1])
#self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
def get_feat_num(self):
"""
:return:
"""
self.feature_extract.eval()
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
with torch.no_grad():
......@@ -245,6 +248,11 @@ class ECAPA_TDNN(nn.Module):
return 1
def get_feat(self, x):
"""
:param x:
:return:
"""
if self.update_extract:
x = self.feature_extract([sample for sample in x])
else:
......@@ -271,24 +279,33 @@ class ECAPA_TDNN(nn.Module):
return x
def forward(self, x):
#x = self.get_feat(x)
"""
:param x:
:return:
"""
out1 = self.layer1(x)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
out = torch.cat([out2, out3, out4], dim=1)
out = self.bn(F.relu(self.conv(out)))
#out = self.bn(self.pooling(out))
#out = self.linear(out)
return out
def ECAPA_TDNN_SMALL(feat_dim,
                     feat_type='fbank',
                     sr=16000,
                     feature_selection="hidden_states",
                     update_extract=False,
                     config_path=None):
    """
    Build a small ECAPA-TDNN (512 channels per convolutional layer).

    :param feat_dim: dimension of the input acoustic features
    :param feat_type: front-end feature type, e.g. 'fbank' (passed through)
    :param sr: audio sample rate in Hz
    :param feature_selection: which upstream representation to select
    :param update_extract: whether the feature extractor is fine-tuned
    :param config_path: optional path to an upstream extractor config
    :return: an ECAPA_TDNN instance configured with channels=512
    """
    return ECAPA_TDNN(feat_dim=feat_dim,
                      channels=512,
                      feat_type=feat_type,
                      sr=sr,
                      feature_selection=feature_selection,
                      update_extract=update_extract,
                      config_path=config_path)
if __name__ == '__main__':
x = torch.zeros(2, 32000)
......
......@@ -171,6 +171,44 @@ class AttentivePooling(torch.nn.Module):
return x
class AttentiveStatsPool(torch.nn.Module):
    """
    Attentive weighted mean and standard deviation pooling.

    Collapses the time axis of a (batch, in_dim, time) tensor into a
    (batch, 2 * in_dim) vector: a two-layer bottleneck learns per-channel,
    per-frame attention weights, and the attention-weighted mean and
    standard deviation are concatenated.
    """

    def __init__(self, in_dim, attention_channels=128, global_context_att=False):
        """
        :param in_dim: number of input channels
        :param attention_channels: width of the attention bottleneck
        :param global_context_att: if True, concatenate the utterance-level
               mean and std to every frame before computing attention
        """
        super().__init__()
        self.global_context_att = global_context_att
        # kernel_size=1 convolutions behave like frame-wise Linear layers,
        # so (batch, channel, time) inputs need no transposition.
        if global_context_att:
            self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1)  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        """
        :param x: tensor of shape (batch, in_dim, time)
        :return: pooled tensor of shape (batch, 2 * in_dim)
        """
        if not self.global_context_att:
            attention_input = x
        else:
            # Broadcast utterance statistics back over time as extra channels.
            glob_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            glob_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            attention_input = torch.cat((x, glob_mean, glob_std), dim=1)

        # tanh, not ReLU: ReLU was observed to hurt convergence here.
        scores = self.linear2(torch.tanh(self.linear1(attention_input)))
        alpha = torch.softmax(scores, dim=2)

        weighted_mean = torch.sum(alpha * x, dim=2)
        weighted_sq = torch.sum(alpha * (x ** 2), dim=2)
        # E[x^2] - E[x]^2, clamped so the sqrt never sees a negative value.
        weighted_std = torch.sqrt((weighted_sq - weighted_mean ** 2).clamp(min=1e-9))
        return torch.cat([weighted_mean, weighted_std], dim=1)
class GruPooling(torch.nn.Module):
"""
Pooling done by using a recurrent network
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment