Commit 5bf6d959 authored by Anthony Larcher

adding wavlm

parent c4684601
......@@ -4,7 +4,7 @@
#
# SIDEKIT is a python package for speaker verification.
# Home page: http://www-lium.univ-lemans.fr/sidekit/
#
......@@ -50,8 +50,8 @@ if 'SIDEKIT' in os.environ:
if val == "true":
SIDEKIT_CONFIG["mpi"] = True
if k == "cuda":
if val == "false":
SIDEKIT_CONFIG["cuda"] = False
if val == "true":
SIDEKIT_CONFIG["cuda"] = True
PARALLEL_MODULE = 'multiprocessing' # can be 'threading' or 'multiprocessing'; MPI support is planned for the future
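# Example of the expected environment variable format (an assumption based
# on the parsing above, which reads comma-separated key=value pairs):
#
#     export SIDEKIT="mpi=true,cuda=true"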
......
......@@ -186,7 +186,7 @@ def data_augmentation(speech,
rir_fn = transform_dict["add_reverb"]["data_path"] + rir_nfo # TODO harmonize with noise
rir, rir_fs = torchaudio.load(rir_fn)
assert rir_fs == sample_rate
# rir = rir[rir_nfo[1], :] #keep selected channel
speech = torch.tensor(signal.convolve(speech, rir, mode='full')[:, :speech.shape[1]])
if "add_noise" in augmentations:
......@@ -261,11 +261,10 @@ def data_augmentation(speech,
)
if "codec" in augmentations:
final_shape = speech.shape[1]
configs = [
({"format": "wav", "encoding": 'ULAW', "bits_per_sample": 8}, "8 bit mu-law"),
({"format": "wav", "encoding": 'ALAW', "bits_per_sample": 8}, "8 bit a-law"),
({"format": "gsm"}, "GSM-FR"),
({"format": "mp3", "compression": -9}, "MP3"),
({"format": "vorbis", "compression": -1}, "Vorbis")
]
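# A hedged sketch of how one of these entries might then be applied
# (assuming torchaudio.functional.apply_codec, which accepts such
# parameter dicts; the random index is illustrative):
#
#     param, title = configs[torch.randint(0, len(configs), (1,)).item()]
#     speech = torchaudio.functional.apply_codec(speech, sample_rate, **param)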
......
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as trans
#from .utils import UpstreamExpert
''' Res2Conv1d + BatchNorm1d + ReLU
'''
class Res2Conv1dReluBn(nn.Module):
'''
in_channels == out_channels == channels
'''
def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
super().__init__()
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
self.scale = scale
self.width = channels // scale
self.nums = scale if scale == 1 else scale - 1
self.convs = []
self.bns = []
for i in range(self.nums):
self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
self.bns.append(nn.BatchNorm1d(self.width))
self.convs = nn.ModuleList(self.convs)
self.bns = nn.ModuleList(self.bns)
def forward(self, x):
out = []
spx = torch.split(x, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = sp + spx[i]
# Order: conv -> relu -> bn
sp = self.convs[i](sp)
sp = self.bns[i](F.relu(sp))
out.append(sp)
if self.scale != 1:
out.append(spx[self.nums])
out = torch.cat(out, dim=1)
return out
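# A minimal usage sketch (added for illustration; all sizes below are
# assumptions, not taken from the original file): the Res2 block preserves
# (batch, channels, time) when the padding matches the dilation.
def _sketch_res2_shapes():
    m = Res2Conv1dReluBn(channels=64, kernel_size=3, padding=1, scale=4)
    y = m(torch.randn(2, 64, 100))
    assert y.shape == (2, 64, 100)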
''' Conv1d + BatchNorm1d + ReLU
'''
class Conv1dReluBn(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
super().__init__()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
self.bn = nn.BatchNorm1d(out_channels)
def forward(self, x):
return self.bn(F.relu(self.conv(x)))
''' The SE (squeeze-and-excitation) connection for the 1D case.
'''
class SE_Connect(nn.Module):
def __init__(self, channels, se_bottleneck_dim=128):
super().__init__()
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
def forward(self, x):
out = x.mean(dim=2)
out = F.relu(self.linear1(out))
out = torch.sigmoid(self.linear2(out))
out = x * out.unsqueeze(2)
return out
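# A minimal usage sketch (added for illustration; sizes are assumptions):
# the block squeezes over time, then rescales each channel, so shapes are
# preserved.
def _sketch_se_connect_shapes():
    se = SE_Connect(channels=512)
    y = se(torch.randn(2, 512, 100))
    assert y.shape == (2, 512, 100)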
''' SE-Res2Block of the ECAPA-TDNN architecture.
'''
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
# return nn.Sequential(
# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
# SE_Connect(channels)
# )
class SE_Res2Block(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
super().__init__()
self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
self.shortcut = None
if in_channels != out_channels:
self.shortcut = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
)
def forward(self, x):
residual = x
if self.shortcut:
residual = self.shortcut(x)
x = self.Conv1dReluBn1(x)
x = self.Res2Conv1dReluBn(x)
x = self.Conv1dReluBn2(x)
x = self.SE_Connect(x)
return x + residual
''' Attentive weighted mean and standard deviation pooling.
'''
class AttentiveStatsPool(nn.Module):
def __init__(self, in_dim, attention_channels=128, global_context_att=False):
super().__init__()
self.global_context_att = global_context_att
# Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
if global_context_att:
self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
else:
self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
def forward(self, x):
if self.global_context_att:
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
x_in = torch.cat((x, context_mean, context_std), dim=1)
else:
x_in = x
# DON'T use ReLU here! In experiments, using ReLU at this point made training converge poorly.
alpha = torch.tanh(self.linear1(x_in))
# alpha = F.relu(self.linear1(x_in))
alpha = torch.softmax(self.linear2(alpha), dim=2)
mean = torch.sum(alpha * x, dim=2)
residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
std = torch.sqrt(residuals.clamp(min=1e-9))
return torch.cat([mean, std], dim=1)
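# A minimal usage sketch (added for illustration; sizes are assumptions):
# attentive statistics pooling maps (B, C, T) to (B, 2 * C) by concatenating
# the attention-weighted mean and standard deviation.
def _sketch_attentive_stats_shapes():
    pool = AttentiveStatsPool(in_dim=1536)
    stats = pool(torch.randn(2, 1536, 200))
    assert stats.shape == (2, 3072)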
class ECAPA_TDNN(nn.Module):
def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
super().__init__()
self.feat_type = feat_type
self.feature_selection = feature_selection
self.update_extract = update_extract
self.sr = sr
if feat_type == "fbank" or feat_type == "mfcc":
self.update_extract = False
win_len = int(sr * 0.025)
hop_len = int(sr * 0.01)
if feat_type == 'fbank':
self.feature_extract = trans.MelSpectrogram(sample_rate=sr, n_fft=512, win_length=win_len,
hop_length=hop_len, f_min=0.0, f_max=sr // 2,
pad=0, n_mels=feat_dim)
elif feat_type == 'mfcc':
melkwargs = {
'n_fft': 512,
'win_length': win_len,
'hop_length': hop_len,
'f_min': 0.0,
'f_max': sr // 2,
'pad': 0
}
self.feature_extract = trans.MFCC(sample_rate=sr, n_mfcc=feat_dim, log_mels=False,
melkwargs=melkwargs)
else:
if config_path is None:
self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
#else:
# self.feature_extract = UpstreamExpert(config_path)
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
self.feat_num = self.get_feat_num()
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
if feat_type != 'fbank' and feat_type != 'mfcc':
freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
for name, param in self.feature_extract.named_parameters():
for freeze_val in freeze_list:
if freeze_val in name:
param.requires_grad = False
break
if not self.update_extract:
for param in self.feature_extract.parameters():
param.requires_grad = False
self.instance_norm = nn.InstanceNorm1d(feat_dim)
# self.channels = [channels] * 4 + [channels * 3]
self.channels = [channels] * 4 + [1536]
self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
# self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
cat_channels = channels * 3
self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
#self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
self.bn = nn.BatchNorm1d(self.channels[-1])
#self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
def get_feat_num(self):
self.feature_extract.eval()
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
with torch.no_grad():
features = self.feature_extract(wav)
select_feature = features[self.feature_selection]
if isinstance(select_feature, (list, tuple)):
return len(select_feature)
else:
return 1
def get_feat(self, x):
if self.update_extract:
x = self.feature_extract([sample for sample in x])
else:
with torch.no_grad():
if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
else:
x = self.feature_extract([sample for sample in x])
if self.feat_type == 'fbank':
x = x.log()
if self.feat_type != "fbank" and self.feat_type != "mfcc":
x = x[self.feature_selection]
if isinstance(x, (list, tuple)):
x = torch.stack(x, dim=0)
else:
x = x.unsqueeze(0)
norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
x = (norm_weights * x).sum(dim=0)
x = torch.transpose(x, 1, 2) + 1e-6
x = self.instance_norm(x)
return x
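# Note on the layer-averaging branch above: the stacked hidden states of
# shape (L, B, T, D) are combined with a learnable softmax over
# self.feature_weight, i.e. one scalar weight per layer:
#
#     w = F.softmax(self.feature_weight, dim=-1)   # (L,)
#     x = (w.view(-1, 1, 1, 1) * x).sum(dim=0)     # (B, T, D)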
def forward(self, x):
#x = self.get_feat(x)
out1 = self.layer1(x)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
out = torch.cat([out2, out3, out4], dim=1)
out = self.bn(F.relu(self.conv(out)))
#out = self.bn(self.pooling(out))
#out = self.linear(out)
return out
def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
if __name__ == '__main__':
x = torch.zeros(2, 32000)
model = ECAPA_TDNN_SMALL(feat_dim=768, emb_dim=256, feat_type='hubert_base', feature_selection="hidden_states",
update_extract=False)
out = model(x)
# print(model)
print(out.shape)
......@@ -435,3 +435,38 @@ class AngularProximityMagnet(torch.nn.Module):
batch_loss = self.criterion(ap_sim_matrix, torch.arange(0, int(out_positive.shape[0]), device=torch.device("cuda:0"))) \
+ self.magnet_criterion(cos_sim_matrix.flatten().unsqueeze(1), mask.flatten().unsqueeze(1))
return batch_loss, cce_prediction
class CircleMargin(torch.nn.Module):
"""
"""
def __init__(self, in_features, out_features, s=256, m=0.25) -> None:
super(CircleMargin, self).__init__()
self.margin = m
self.gamma = s
self.weight = Parameter(torch.FloatTensor(out_features, in_features))
torch.nn.init.xavier_uniform_(self.weight)
def forward(self, x, target=None):
"""
:param x:
:param target:
:return:
"""
cosine = torch.nn.functional.linear(torch.nn.functional.normalize(x),
torch.nn.functional.normalize(self.weight))
if target is None:
return cosine * self.gamma
one_hot = torch.zeros_like(cosine)
one_hot.scatter_(1, target.view(-1, 1), 1)
output = (one_hot * (self.margin ** 2 - (1 - cosine) ** 2)) +\
((1.0 - one_hot) * (cosine ** 2 - self.margin ** 2))
output = output * self.gamma
return output, cosine * self.gamma
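# A minimal usage sketch (added for illustration; sizes and targets are
# assumptions): with labels the layer returns margin-adjusted logits plus
# scaled cosine scores, without labels only the scaled cosine scores.
def _sketch_circle_margin():
    head = CircleMargin(in_features=256, out_features=10)
    logits, cosine = head(torch.randn(4, 256), target=torch.tensor([0, 1, 2, 3]))
    scores = head(torch.randn(4, 256))
    assert logits.shape == scores.shape == (4, 10)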
......@@ -71,6 +71,9 @@ class MeanStdPooling(torch.nn.Module):
class ChannelWiseCorrPooling(torch.nn.Module):
"""
"""
def __init__(self, in_channels=256, out_channels=64, in_freqs=10, channels_dropout=0.25):
super(ChannelWiseCorrPooling, self).__init__()
self.channels_dropout = channels_dropout
......@@ -80,7 +83,7 @@ class ChannelWiseCorrPooling(torch.nn.Module):
self.out_channels = out_channels
self.out_dim = int(self.out_channels*(self.out_channels-1)/2)*self.groups
self.L_proj = torch.nn.Conv2d(in_channels*self.groups, out_channels*self.groups, kernel_size=(1, 1), groups=self.groups)
# self.L_proj = torch.nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
self.mask = torch.tril(torch.ones((out_channels, out_channels)), diagonal=-1).type(torch.BoolTensor)
def forward(self, x):
......@@ -94,34 +97,34 @@ class ChannelWiseCorrPooling(torch.nn.Module):
self.mask = self.mask.to(x.device)
if self.training:
x *= torch.nn.functional.dropout(torch.ones((1, x.shape[1], 1, 1), device=x.device), p=self.channels_dropout)
# [B, T, C, F]
x = x.permute(0, 2, 1, 3)
# [B, T, C, Fr, f]
x = x.reshape(x.shape[0], x.shape[1], x.shape[-2], self.groups, self.merge_freqs_count)
# [B, T, f, Fr, C]
x = x.permute(0, 1, 4, 3, 2)
# [B, T, f, Fr*C]
x = x.flatten(start_dim=3, end_dim=4)
# [B, Fr*C, T, f]
x = x.permute(0, 3, 1, 2)
# [B, Fr*C', T, f]
x = self.L_proj(x)
# [B, Fr, C', Tr]
x = x.reshape(x.shape[0], self.groups, self.out_channels, -1)
x -= torch.mean(x, axis=-1, keepdims=True)
out = x/(torch.std(x, axis=-1, keepdims=True) + 1e-5)
# [B, C', C']
out = torch.einsum('abci,abdi->abcd', out, out)
# [B, C'*(C'-1)/2]
out = torch.masked_select(out, self.mask).reshape(batch_size, -1)
out = out / num_locations
return out
class AttentivePooling(torch.nn.Module):
"""
Mean and Standard deviation attentive pooling
"""
def __init__(self, num_channels, num_freqs=10, attention_channels=128, global_context=False):
"""
"""
......@@ -130,11 +133,11 @@ class AttentivePooling(torch.nn.Module):
super(AttentivePooling, self).__init__()
in_factor = 3 if global_context else 1
self.attention = torch.nn.Sequential(
torch.nn.Conv1d(num_channels * num_freqs * in_factor, attention_channels, kernel_size=1),
torch.nn.ReLU(),
torch.nn.BatchNorm1d(attention_channels),
torch.nn.Tanh(),
torch.nn.Conv1d(attention_channels, num_channels * num_freqs, kernel_size=1),
torch.nn.Softmax(dim=2),
)
self.global_context = global_context
......@@ -162,7 +165,7 @@ class AttentivePooling(torch.nn.Module):
w = self.attention(x)
mu = torch.sum(x * w, dim=2)
rh = torch.sqrt( ( torch.sum((x**2) * w, dim=2) - mu**2 ).clamp(min=1e-9) )
x = torch.cat((mu, rh),1)
x = x.view(x.size()[0], -1)
return x
......
......@@ -124,6 +124,91 @@ class MfccFrontEnd(torch.nn.Module):
return mfcc
class WavLmFrontEnd(torch.nn.Module):
"""
Front-end that extracts WavLM hidden-state features through the s3prl hub.
TODO: add the HOW TO documentation.
"""
def __init__(self):
super(WavLmFrontEnd, self).__init__()
self.feat_type = 'wavlm_large'
self.feature_extract = torch.hub.load('s3prl/s3prl', self.feat_type)
self.update_extract = False
self.feature_selection = 'hidden_states'
self.sr = 16000
self.feat_num = self.get_feat_num()
self.instance_norm = torch.nn.InstanceNorm1d(1024)
self.feature_weight = torch.nn.Parameter(torch.zeros(self.feat_num))
if self.feat_type != 'fbank' and self.feat_type != 'mfcc':
freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
for name, param in self.feature_extract.named_parameters():
for freeze_val in freeze_list:
if freeze_val in name:
param.requires_grad = False
break
if not self.update_extract:
for param in self.feature_extract.parameters():
param.requires_grad = False
def get_feat_num(self):
"""
:return:
"""
self.feature_extract.eval()
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
with torch.no_grad():
features = self.feature_extract(wav)
select_feature = features[self.feature_selection]
if isinstance(select_feature, (list, tuple)):
return len(select_feature)
else:
return 1
def get_feat(self, x):
"""
:param x:
:return:
"""
if self.update_extract:
x = self.feature_extract([sample for sample in x])
else:
with torch.no_grad():
if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
else:
x = self.feature_extract([sample for sample in x])
if self.feat_type == 'fbank':
x = x.log()
if self.feat_type != "fbank" and self.feat_type != "mfcc":
x = x[self.feature_selection]
if isinstance(x, (list, tuple)):
x = torch.stack(x, dim=0)
else:
x = x.unsqueeze(0)
norm_weights = torch.nn.functional.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
x = (norm_weights * x).sum(dim=0)
x = torch.transpose(x, 1, 2) + 1e-6
x = self.instance_norm(x)
return x
def forward(self, x, is_eval=False):
"""
:param x:
:param is_eval:
:return:
"""
return self.get_feat(x)
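# A minimal usage sketch (added for illustration; assumes the s3prl hub
# checkpoint is available and inputs are 16 kHz waveforms):
def _sketch_wavlm_front_end():
    frontend = WavLmFrontEnd()
    feats = frontend(torch.randn(2, 32000))  # two 2-second waveforms
    # after layer averaging and instance norm: (batch, 1024, frames)
    assert feats.shape[0] == 2 and feats.shape[1] == 1024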
class MelSpecFrontEnd(torch.nn.Module):
"""
Module that computes a Mel spectrogram from an audio signal
......
......@@ -438,13 +438,17 @@ class PreResNet34(torch.nn.Module):
"""
Network that contains only the ResNet part up to pooling, with NO classification layers
"""
def __init__(self, block=BasicBlock, num_blocks=(3, 1, 3, 1, 5, 1, 2), speaker_number=10):
super(PreResNet34, self).__init__()
self.in_planes = 128
self.speaker_number = speaker_number
self.conv1 = torch.nn.Conv2d(1,
128,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.bn1 = torch.nn.BatchNorm2d(128)
......@@ -457,7 +461,6 @@ class PreResNet34(torch.nn.Module):
self.layer6 = self._make_layer(block, 256, num_blocks[5], stride=2)
self.layer7 = self._make_layer(block, 256, num_blocks[5], stride=1)
def _make_layer(self, block, planes, num_blocks, stride):
"""
......@@ -498,13 +501,17 @@ class PreHalfResNet34(torch.nn.Module):
"""
Network that contains only the ResNet part up to pooling, with NO classification layers
"""
def __init__(self, block=BasicBlock, num_blocks=[3, 4, 6, 3], speaker_number=10):