Commit cc57f1bf authored by Anthony Larcher's avatar Anthony Larcher
Browse files

+ rawnet2

parent 3478ea30
......@@ -37,10 +37,200 @@ from torchvision import transforms
from collections import OrderedDict
from .xsets import XvectorMultiDataset, XvectorDataset, StatDataset, VoxDataset
from .xsets import FrequencyMask, CMVN, TemporalMask
from .sincnet import SincNet, SincConv1d
from ..bosaris import IdMap
from ..statserver import StatServer
from torch.utils.data import DataLoader
class FeatureMapScaling(nn.Module):
    """Feature Map Scaling (FMS) attention block used by RawNet2.

    A per-channel gate is computed by average-pooling the feature map over
    time and passing it through a linear layer followed by a sigmoid.  The
    gate is then applied multiplicatively and/or additively to the input.
    """
    def __init__(self, nb_dim, do_add=True, do_mul=True):
        """
        :param nb_dim: number of feature-map channels
        :param do_add: if True, add the gate vector to the features
        :param do_mul: if True, multiply the features by the gate vector
        """
        # BUG FIX: the original called super() with an undefined class name
        # (FFeatureMapScalingRM), which raised NameError at instantiation.
        super(FeatureMapScaling, self).__init__()
        self.fc = torch.nn.Linear(nb_dim, nb_dim)
        self.sig = torch.nn.Sigmoid()
        self.do_add = do_add
        self.do_mul = do_mul

    def forward(self, x):
        """
        :param x: tensor of shape (batch, nb_dim, time)
        :return: gated tensor of the same shape
        """
        # Average over time, then produce one sigmoid gate value per channel.
        y = torch.nn.functional.adaptive_avg_pool1d(x, 1).view(x.size(0), -1)
        y = self.sig(self.fc(y)).view(x.size(0), x.size(1), -1)

        if self.do_mul:
            x = x * y

        if self.do_add:
            x = x + y

        return x
class ResBlockWFMS(nn.Module):
    """Residual 1-D convolution block With Feature Map Scaling (RawNet2).

    Two 3-tap convolutions with BatchNorm/LeakyReLU pre-activation, a
    residual connection (1x1 convolution when channel counts differ),
    followed by max-pooling and an FMS gate.
    """
    def __init__(self, nb_filts, first=False):
        """
        :param nb_filts: pair (in_channels, out_channels)
        :param first: when True, skip the leading BN + activation
        """
        super(ResBlockWFMS, self).__init__()
        self.first = first

        if not self.first:
            self.bn1 = torch.nn.BatchNorm1d(num_features=nb_filts[0])

        self.lrelu = torch.nn.LeakyReLU()
        # negative_slope=0.3 matches the Keras default LeakyReLU
        self.lrelu_keras = torch.nn.LeakyReLU(negative_slope=0.3)

        self.conv1 = torch.nn.Conv1d(in_channels=nb_filts[0],
                                     out_channels=nb_filts[1],
                                     kernel_size=3,
                                     padding=1,
                                     stride=1)
        self.bn2 = torch.nn.BatchNorm1d(num_features=nb_filts[1])
        self.conv2 = torch.nn.Conv1d(in_channels=nb_filts[1],
                                     out_channels=nb_filts[1],
                                     padding=1,
                                     kernel_size=3,
                                     stride=1)

        # A 1x1 convolution aligns channel counts on the skip path.
        self.downsample = nb_filts[0] != nb_filts[1]
        if self.downsample:
            self.conv_downsample = torch.nn.Conv1d(in_channels=nb_filts[0],
                                                   out_channels=nb_filts[1],
                                                   padding=0,
                                                   kernel_size=1,
                                                   stride=1)

        self.mp = torch.nn.MaxPool1d(3)
        self.fms = FeatureMapScaling(nb_dim=nb_filts[1],
                                     do_add=True,
                                     do_mul=True)

    def forward(self, x):
        """
        :param x: tensor (batch, nb_filts[0], time)
        :return: tensor (batch, nb_filts[1], time // 3)
        """
        shortcut = self.conv_downsample(x) if self.downsample else x

        # Pre-activation, except for the very first block of the network.
        out = x if self.first else self.lrelu_keras(self.bn1(x))
        out = self.conv1(out)
        out = self.conv2(self.lrelu_keras(self.bn2(out)))

        out = out + shortcut
        return self.fms(self.mp(out))
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable affine parameters."""

    def __init__(self, features, eps=1e-6):
        """
        :param features: size of the normalised (last) dimension
        :param eps: numerical stabiliser added to the standard deviation
        """
        super(LayerNorm, self).__init__()
        self.gamma = torch.nn.Parameter(torch.ones(features))
        self.beta = torch.nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """
        :param x: input tensor; statistics are taken along its last dimension
        :return: gamma * (x - mean) / (std + eps) + beta
        """
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        normed = (x - mu) / (sigma + self.eps)
        return self.gamma * normed + self.beta
class RawPreprocessor(torch.nn.Module):
    """Waveform front-end for RawNet2.

    Applies layer normalisation to the raw waveform, a sinc-parameterised
    convolution, absolute value + max-pooling, then BatchNorm and LeakyReLU.
    """
    def __init__(self, nb_samp, in_channels, filts, first_conv):
        """
        :param nb_samp: number of samples per input waveform (LayerNorm size)
        :param in_channels: number of input channels (SincConv1d supports 1)
        :param filts: number of sinc filters (output channels)
        :param first_conv: kernel size of the sinc convolution
        """
        # BUG FIX: the original inherited from torch.nn.module (lowercase),
        # which does not exist, so the class definition itself failed.
        super(RawPreprocessor, self).__init__()
        self.ln = LayerNorm(nb_samp)
        self.first_conv = SincConv1d(in_channels=in_channels,
                                     out_channels=filts,
                                     kernel_size=first_conv
                                     )
        self.first_bn = torch.nn.BatchNorm1d(num_features=filts)
        self.lrelu = torch.nn.LeakyReLU()
        # negative_slope=0.3 matches the Keras default LeakyReLU
        self.lrelu_keras = torch.nn.LeakyReLU(negative_slope=0.3)

    def forward(self, x):
        """
        :param x: waveform tensor of shape (batch, nb_samp)
        :return: feature tensor (batch, filts, pooled_time)
        """
        # NOTE(review): despite the name, x.shape[0] is the batch size here.
        nb_samp = x.shape[0]
        len_seq = x.shape[1]
        out = self.ln(x)
        out = out.view(nb_samp, 1, len_seq)
        # Rectify the sinc-filter response before pooling (RawNet2 recipe).
        out = torch.nn.functional.max_pool1d(torch.abs(self.first_conv(out)), 3)
        out = self.first_bn(out)
        out = self.lrelu_keras(out)
        return out
class ResBlock(torch.nn.Module):
"""
......
......@@ -35,16 +35,13 @@
# SLT 2018. https://arxiv.org/abs/1808.00158
from typing import List
import numpy as np
import numpy
import torch
import torch.nn.functional as F
import torch.nn as nn
#import torch.nn.functional as F
import math
#from pyannote.core import SlidingWindow
#from pyannote.audio.train.task import Task
class SincConv1d(nn.Module):
class SincConv1d(torch.nn.Module):
"""Sinc-based 1D convolution
Parameters
......@@ -76,79 +73,66 @@ class SincConv1d(nn.Module):
@staticmethod
def to_mel(hz):
return 2595 * np.log10(1 + hz / 700)
return 2595 * numpy.log10(1 + hz / 700)
@staticmethod
def to_hz(mel):
return 700 * (10 ** (mel / 2595) - 1)
def __init__(
self,
in_channels,
out_channels,
kernel_size,
sample_rate=16000,
min_low_hz=50,
min_band_hz=50,
stride=1,
padding=0,
dilation=1,
bias=False,
groups=1,
):
super().__init__()
def __init__(self,
out_channels,
kernel_size,
sample_rate=16000,
in_channels=1,
stride=1,
padding=0,
dilation=1,
bias=False,
groups=1,
min_low_hz=50,
min_band_hz=50):
super(SincConv1d, self).__init__()
if in_channels != 1:
msg = (
f"SincConv1d only supports one input channel. "
f"Here, in_channels = {in_channels}."
)
msg = f"SincConv1d only supports one input channel. (Here, in_channels = {in_channels})."
raise ValueError(msg)
self.in_channels = in_channels
self.in_channels = in_channels
self.out_channels = out_channels
if kernel_size % 2 == 0:
msg = (
f"SincConv1d only support odd kernel size. "
f"Here, kernel_size = {kernel_size}."
)
raise ValueError(msg)
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.sample_rate = sample_rate
self.min_low_hz = min_low_hz
self.min_band_hz = min_band_hz
self.kernel_size = kernel_size
if kernel_size % 2 == 0:
self.kernel_size=self.kernel_size+1
if bias:
raise ValueError("SincConv1d does not support bias.")
if groups > 1:
raise ValueError("SincConv1d does not support groups.")
self.sample_rate = sample_rate
self.min_low_hz = min_low_hz
self.min_band_hz = min_band_hz
# initialize filterbanks such that they are equally spaced in Mel scale
low_hz = 30
high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
mel = np.linspace(
self.to_mel(low_hz), self.to_mel(high_hz), self.out_channels + 1
)
mel = numpy.linspace(self.to_mel(low_hz),
self.to_mel(high_hz),
self.out_channels + 1)
hz = self.to_hz(mel)
# filter lower frequency (out_channels, 1)
self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
self.low_hz_ = torch.nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
# filter frequency band (out_channels, 1)
self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))
self.band_hz_ = torch.nn.Parameter(torch.Tensor(numpy.diff(hz)).view(-1, 1))
# Half Hamming half window
n_lin = torch.linspace(
0, self.kernel_size / 2 - 1, steps=int((self.kernel_size / 2))
)
# Half Hamming window
n_lin = torch.linspace(0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2)))
self.window_ = 0.54 - 0.46 * torch.cos(2 * math.pi * n_lin / self.kernel_size)
# (kernel_size, 1)
......@@ -169,48 +153,38 @@ class SincConv1d(nn.Module):
features : `torch.Tensor` (batch_size, out_channels, n_samples_out)
Batch of sinc filters activations.
"""
self.n_ = self.n_.to(waveforms.device)
self.window_ = self.window_.to(waveforms.device)
low = self.min_low_hz + torch.abs(self.low_hz_)
high = torch.clamp(
low + self.min_band_hz + torch.abs(self.band_hz_),
self.min_low_hz,
self.sample_rate / 2,
)
high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),
self.min_low_hz,
self.sample_rate / 2)
band = (high - low)[:, 0]
f_times_t_low = torch.matmul(low, self.n_)
f_times_t_high = torch.matmul(high, self.n_)
# Equivalent to Eq.4 of the reference paper
# I just have expanded the sinc and simplified the terms.
# expanded the sinc and simplified the terms.
# This way I avoid several useless computations.
band_pass_left = (
(torch.sin(f_times_t_high) - torch.sin(f_times_t_low)) / (self.n_ / 2)
) * self.window_
band_pass_left = ((torch.sin(f_times_t_high) - torch.sin(f_times_t_low)) / (self.n_ / 2)) * self.window_
band_pass_center = 2 * band.view(-1, 1)
band_pass_right = torch.flip(band_pass_left, dims=[1])
band_pass = torch.cat(
[band_pass_left, band_pass_center, band_pass_right], dim=1
)
band_pass = torch.cat([band_pass_left, band_pass_center, band_pass_right], dim=1)
band_pass = band_pass / (2 * band[:, None])
self.filters = (band_pass).view(self.out_channels, 1, self.kernel_size)
self.filters = band_pass.view(self.out_channels, 1, self.kernel_size)
return F.conv1d(
waveforms,
self.filters,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
bias=None,
groups=1,
)
return torch.nn.functional.conv1d(waveforms,
self.filters,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
bias=None,
groups=1)
class SincNet(nn.Module):
......
......@@ -388,7 +388,7 @@ class SideSet(Dataset):
int(self.sample_rate * (df.iloc[idx].duration - self.duration)),
self.sample_number - int(self.sample_rate * overlap)
)
possible_starts += self.sample_number * df.iloc[idx].start
possible_starts += self.sample_rate * df.iloc[idx].start
# Select max(seg_nb, possible_segments) segments
if chunk_per_segment == -1:
......@@ -475,6 +475,7 @@ class SideSet(Dataset):
# TODO: add data augmentation here!
if self.transform_pipeline:
print(f"shape sig = {sig.shape}")
sig, speaker_idx, _, __ = self.transforms((sig, speaker_idx, self.spec_aug[index], self.temp_aug[index]))
return torch.from_numpy(sig).type(torch.FloatTensor), speaker_idx
......
......@@ -40,11 +40,12 @@ from collections import OrderedDict
from .xsets import XvectorMultiDataset, StatDataset, VoxDataset, SideSet
from .xsets import IdMapSet
from .xsets import FrequencyMask, CMVN, TemporalMask, MFCC
from .res_net import RawPreprocessor, ResBlockWFMS
from ..bosaris import IdMap
from ..statserver import StatServer
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from .sincnet import SincNet
from .sincnet import SincNet, SincConv1d
from tqdm import tqdm
__license__ = "LGPL"
......@@ -80,12 +81,45 @@ def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename
shutil.copyfile(filename, best_filename)
class GruPooling(torch.nn.Module):
    """Temporal pooling for RawNet2: BN + LeakyReLU, then a GRU whose last
    hidden output is used as the fixed-size utterance representation.
    """
    def __init__(self, input_size, gru_node, nb_gru_layer):
        """
        :param input_size: number of input feature maps
        :param gru_node: GRU hidden size
        :param nb_gru_layer: number of stacked GRU layers
        """
        # BUG FIX: the original never called the parent constructor, so the
        # sub-module assignments below raised
        # "cannot assign module before Module.__init__() call".
        super(GruPooling, self).__init__()
        self.lrelu_keras = torch.nn.LeakyReLU(negative_slope=0.3)
        self.bn_before_gru = torch.nn.BatchNorm1d(num_features=input_size)
        self.gru = torch.nn.GRU(input_size=input_size,
                                hidden_size=gru_node,
                                num_layers=nb_gru_layer,
                                batch_first=True)

    def forward(self, x):
        """
        :param x: tensor of shape (batch, filt, time)
        :return: tensor of shape (batch, gru_node) — last GRU time step
        """
        x = self.bn_before_gru(x)
        x = self.lrelu_keras(x)
        x = x.permute(0, 2, 1)  # (batch, filt, time) >> (batch, time, filt)
        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        x = x[:, -1, :]
        return x
class Xtractor(torch.nn.Module):
"""
Class that defines an x-vector extractor based on 5 convolutional layers and a mean standard deviation pooling
"""
def __init__(self, speaker_number, model_archi=None):
def __init__(self, speaker_number, model_archi="xvector", norm_embedding=False):
"""
If config is None, default architecture is created
:param model_archi:
......@@ -93,8 +127,9 @@ class Xtractor(torch.nn.Module):
super(Xtractor, self).__init__()
self.speaker_number = speaker_number
self.feature_size = None
self.norm_embedding = norm_embedding
if model_archi is None:
if model_archi == "xvector":
self.feature_size = 30
self.activation = torch.nn.LeakyReLU(0.2)
......@@ -136,6 +171,35 @@ class Xtractor(torch.nn.Module):
self.before_speaker_embedding_weight_decay = 0.002
self.after_speaker_embedding_weight_decay = 0.002
elif model_archi == "rawnet2":
filts = [128, [128, 128], [128, 256], [256, 256]]
self.norm_embedding = True
self.preprocessor = RawPreprocessor(nb_samp=48000,
in_channels=1,
filts=filts[0],
first_conv=3)
self.sequence_network = torch.nn.Sequential(OrderedDict([
("block0", ResBlockWFMS(nb_filts=filts[1], first=True)),
("block1", ResBlockWFMS(nb_filts=filts[1])),
("block2", ResBlockWFMS(nb_filts=filts[2])),
("block3", ResBlockWFMS(nb_filts=[filts[2][1], filts[2][1]])),
("block4", ResBlockWFMS(nb_filts=[filts[2][1], filts[2][1]])),
("block5", ResBlockWFMS(nb_filts=[filts[2][1], filts[2][1]]))
]))
self.stat_pooling = GruPooling(input_size=filts[2][-1],
gru_node=1024,
nb_gru_layer=1)
self.before_speaker_embedding = torch.nn.Linear(in_features = 1024,
out_features = 1024)
self.after_speaker_embedding = torch.nn.Linear(in_features = 1024,
out_features = int(self.speaker_number),
bias = True)
else:
# Load Yaml configuration
with open(model_archi, 'r') as fh:
......@@ -161,6 +225,8 @@ class Xtractor(torch.nn.Module):
dropout=cfg['preprocessor']["dropout"]
)
self.feature_size = self.preprocessor.dimension
elif cfg['preprocessor']["type"] == "rawnet2":
self.preprocessor =
"""
Prepare sequence network
......@@ -271,6 +337,10 @@ class Xtractor(torch.nn.Module):
if is_eval:
return x
if self.norm_embedding:
x_norm = x.norm(p=2,dim=1, keepdim=True) / 10.
x = torch.div(x, x_norm)
x = self.after_speaker_embedding(x)
return x
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment