Commit be2b8c7e authored by Anthony Larcher

New WavSet and SeqToSeq model

parent 03d0b79e
@@ -55,10 +55,10 @@ from .scoring import DER
__author__ = "Sylvain Meignier and Anthony Larcher"
__copyright__ = "Copyright 2014-20120 Sylvain Meignier and Anthony Larcher"
__copyright__ = "Copyright 2014-2020 Sylvain Meignier and Anthony Larcher"
__license__ = "LGPL"
__maintainer__ = "Sylvain Meignier"
__email__ = "sylvain.meignierr@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'
__version__ = "0.1.4.6"
__version__ = "0.1.4.7"
@@ -25,4 +25,5 @@ Copyright 2014-2020 Anthony Larcher
from .wavsets import SeqSet
from .seqtoseq import PreNet
from .seqtoseq import BLSTM
\ No newline at end of file
from .seqtoseq import BLSTM
from .seqtoseq import SeqToSeq
\ No newline at end of file
@@ -26,16 +26,18 @@ Copyright 2014-2020 Anthony Larcher
import os
import sys
import numpy
from collections import OrderedDict
import random
import h5py
import shutil
import torch
import torch.nn as nn
import yaml
from torch import optim
from torch.utils.data import Dataset
import logging
from sidekit.nnet.vad_rnn import BLSTM
from sidekit.nnet.sincnet import SincNet
from torch.utils.data import DataLoader
__license__ = "LGPL"
@@ -47,81 +49,42 @@ __status__ = "Production"
__docformat__ = 'reStructuredText'
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
"""
:param state:
:param is_best:
:param filename:
:param best_filename:
:return:
"""
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, best_filename)
class PreNet(nn.Module):
def __init__(self,
sample_rate=16000,
windows_duration=0.2,
frame_shift=0.01):
super(PreNet, self).__init__()
windows_length = int(sample_rate * windows_duration)
if windows_length % 2:
windows_length += 1
stride_0 = int(sample_rate * frame_shift)
self.conv0 = torch.nn.Conv1d(1, 64, windows_length, stride=stride_0, dilation=1)
self.conv1 = torch.nn.Conv1d(64, 64, 3, dilation=1)
self.conv2 = torch.nn.Conv1d(64, 64, 3, dilation=1)
self.norm0 = torch.nn.BatchNorm1d(64)
self.norm1 = torch.nn.BatchNorm1d(64)
self.norm2 = torch.nn.BatchNorm1d(64)
self.activation = torch.nn.LeakyReLU(0.2)
def forward(self, input):
x = self.norm0(self.activation(self.conv0(input)))
x = self.norm1(self.activation(self.conv1(x)))
output = self.norm2(self.activation(self.conv2(x)))
return output
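# A minimal sketch (not part of the commit) illustrating the frame rate implied by
# PreNet's defaults: the 0.2 s analysis window is 3200 samples at 16 kHz and the 0.01 s
# frame shift a stride of 160 samples, so conv0 emits roughly one frame every 10 ms;
# the two kernel-3 convolutions then trim two frames each.
import torch

prenet = PreNet(sample_rate=16000, windows_duration=0.2, frame_shift=0.01)
waveform = torch.randn(4, 1, 32000)   # batch of four 2-second raw chunks
frames = prenet(waveform)
print(frames.shape)                   # torch.Size([4, 64, 177])
# conv0: (32000 - 3200) // 160 + 1 = 181 frames, then 181 - 2 - 2 = 177 after conv1/conv2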
# def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
# """
#
# :param state:
# :param is_best:
# :param filename:
# :param best_filename:
# :return:
# """
# torch.save(state, filename)
# if is_best:
# shutil.copyfile(filename, best_filename)
class BLSTM(nn.Module):
"""
Bi LSTM model used for voice activity detection or speaker turn detection
"""
def __init__(self,
input_size,
lstm_1,
lstm_2,
linear_1,
linear_2,
output_size=1):
blstm_sizes):
"""
:param input_size:
:param lstm_1:
:param lstm_2:
:param linear_1:
:param linear_2:
:param output_size:
:param blstm_sizes:
"""
super(BLSTM, self).__init__()
self.input_size = input_size
self.blstm_sizes = blstm_sizes
self.blstm_layers = torch.nn.ModuleList()
for blstm_size in blstm_sizes:
self.blstm_layers.append(nn.LSTM(input_size, blstm_size // 2, bidirectional=True, batch_first=True))
input_size = blstm_size
self.output_size = blstm_size
self.lstm_1 = nn.LSTM(input_size, lstm_1 // 2, bidirectional=True, batch_first=True)
self.lstm_2 = nn.LSTM(lstm_1, lstm_2 // 2, bidirectional=True, batch_first=True)
self.linear_1 = nn.Linear(lstm_2, linear_1)
self.linear_2 = nn.Linear(linear_1, linear_2)
self.output = nn.Linear(linear_2, output_size)
self.hidden = None
"""
Bi LSTM model used for voice activity detection or speaker turn detection
"""
def forward(self, inputs):
"""
@@ -129,42 +92,104 @@ class BLSTM(nn.Module):
:param inputs:
:return:
"""
hiddens = []
if self.hidden is None:
hidden_1, hidden_2 = None, None
#hidden_1, hidden_2 = None, None
for _s in self.blstm_sizes:
hiddens.append(None)
else:
hidden_1, hidden_2 = self.hidden
tmp, hidden_1 = self.lstm_1(inputs, hidden_1)
x, hidden_2 = self.lstm_2(tmp, hidden_2)
self.hidden = (hidden_1, hidden_2)
x = torch.tanh(self.linear_1(x))
x = torch.tanh(self.linear_2(x))
x = torch.sigmoid(self.output(x))
hiddens = list(self.hidden)
x = inputs
for idx, _s in enumerate(self.blstm_sizes):
x, hiddens[idx] = self.blstm_layers[idx](x, hiddens[idx])
self.hidden = tuple(hiddens)
return x
def output_size(self):
return self.output_size
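# A minimal usage sketch (not from the commit): blstm_sizes lists the full
# bidirectional output width of each stacked layer, so each nn.LSTM is built
# with hidden size blstm_size // 2 in each direction.
import torch

blstm = BLSTM(input_size=64, blstm_sizes=[128, 128])
features = torch.randn(4, 181, 64)   # (batch, sequence length, feature dimension)
blstm.hidden = None                  # the hidden state persists across calls; reset it for a new batch
out = blstm(features)
print(out.shape)                     # torch.Size([4, 181, 128])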
class SeqToSeq(nn.Module):
"""
Bi LSTM model used for voice activity detection or speaker turn detection
Model used for voice activity detection or speaker turn detection
This model can include a pre-processor to take raw waveform as input,
a BLSTM module for sequence-to-sequence processing
and other linear or convolutional layers
"""
def __init__(self,
input_size,
lstm_1,
lstm_2,
linear_1,
linear_2,
output_size=1):
model_archi):
# Todo Write like the Xtractor in order to enable a flexible build of the model including \
# Sincnet preprocessor, Convolutional filters, TDNN, BLSTM and other possible layers
super(SeqToSeq, self).__init__()
self.preprocessor = PreNet(sample_rate=16000,
windows_duration=0.2,
frame_shift=0.01)
self.lstm_1 = nn.LSTM(input_size, lstm_1 // 2, bidirectional=True, batch_first=True)
self.lstm_2 = nn.LSTM(lstm_1, lstm_2 // 2, bidirectional=True, batch_first=True)
self.linear_1 = nn.Linear(lstm_2, linear_1)
self.linear_2 = nn.Linear(linear_1, linear_2)
self.output = nn.Linear(linear_2, output_size)
self.hidden = None
# Load Yaml configuration
with open(model_archi, 'r') as fh:
cfg = yaml.load(fh, Loader=yaml.FullLoader)
self.loss = cfg["loss"]
"""
Prepare Preprocessor
"""
self.preprocessor = None
self.feature_size = None
if "preprocessor" in cfg:
if cfg['preprocessor']["type"] == "sincnet":
self.preprocessor = SincNet(
waveform_normalize=cfg['preprocessor']["waveform_normalize"],
sample_rate=cfg['preprocessor']["sample_rate"],
min_low_hz=cfg['preprocessor']["min_low_hz"],
min_band_hz=cfg['preprocessor']["min_band_hz"],
out_channels=cfg['preprocessor']["out_channels"],
kernel_size=cfg['preprocessor']["kernel_size"],
stride=cfg['preprocessor']["stride"],
max_pool=cfg['preprocessor']["max_pool"],
instance_normalize=cfg['preprocessor']["instance_normalize"],
activation=cfg['preprocessor']["activation"],
dropout=cfg['preprocessor']["dropout"]
)
self.feature_size = self.preprocessor.dimension
"""
Prepare sequence to sequence network
"""
# Get Feature size
if self.feature_size is None:
self.feature_size = cfg["feature_size"]
input_size = self.feature_size
sequence_to_sequence = BLSTM(input_size=input_size,
blstm_sizes=cfg["sequence_to_sequence"]["blstm_sizes"])
input_size = sequence_to_sequence.output_size
"""
Prepare post-processing network
"""
# Create sequential object for the second part of the network
post_processing_layers = []
for k in cfg["post_processing"].keys():
if k.startswith("lin"):
post_processing_layers.append((k, torch.nn.Linear(input_size,
cfg["post_processing"][k]["output"])))
input_size = cfg["post_processing"][k]["output"]
elif k.startswith("activation"):
post_processing_layers.append((k, self.activation))
elif k.startswith('batch_norm'):
post_processing_layers.append((k, torch.nn.BatchNorm1d(input_size)))
elif k.startswith('dropout'):
post_processing_layers.append((k, torch.nn.Dropout(p=cfg["post_processing"][k])))
self.before_speaker_embedding = torch.nn.Sequential(OrderedDict(post_processing_layers))
self.before_speaker_embedding_weight_decay = cfg["before_embedding"]["weight_decay"]
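# A hypothetical `model_archi` configuration mirroring the keys read above
# (loss, feature_size, sequence_to_sequence, post_processing, before_embedding).
# All values are illustrative, not taken from the commit; the optional "preprocessor"
# block is omitted for brevity, and "activation" post-processing keys are skipped
# because this version of __init__ never defines self.activation.
import yaml

example_cfg = {
    "loss": "cce",
    "feature_size": 64,
    "sequence_to_sequence": {"blstm_sizes": [128, 128]},
    "post_processing": {
        "linear1": {"output": 64},   # keys starting with "lin" become nn.Linear layers
        "dropout1": 0.25,            # keys starting with "dropout" become nn.Dropout(p=value)
        "linear2": {"output": 2},
    },
    "before_embedding": {"weight_decay": 0.0002},
}
with open("seqtoseq.yaml", "w") as fh:
    yaml.dump(example_cfg, fh, sort_keys=False)   # keep the layer order of post_processing
# model = SeqToSeq(model_archi="seqtoseq.yaml")
# would build BLSTM(64 -> [128, 128]) followed by Linear(128, 64), Dropout(0.25), Linear(64, 2)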
def forward(self, inputs):
"""
@@ -180,8 +205,9 @@ class SeqToSeq(nn.Module):
x, hidden_2 = self.lstm_2(tmp, hidden_2)
self.hidden = (hidden_1, hidden_2)
x = torch.tanh(self.linear_1(x))
x = torch.tanh(self.linear_2(x))
x = torch.sigmoid(self.output(x))
#x = torch.tanh(self.linear_2(x))
x = self.linear_2(x)
#x = torch.sigmoid(self.output(x))
return x
@@ -41,6 +41,11 @@ import torch
from ..diar import Diar
from pathlib import Path
from sidekit.nnet.xsets import PreEmphasis
from sidekit.nnet.xsets import MFCC
from sidekit.nnet.xsets import CMVN
from sidekit.nnet.xsets import FrequencyMask
from sidekit.nnet.xsets import TemporalMask
from torch.utils.data import Dataset
from torchvision import transforms
from collections import namedtuple
@@ -98,11 +103,22 @@ def mdtm_to_label(mdtm_filename,
:param start_time:
:param stop_time:
:param sample_number:
:param speaker_dict:
:return:
"""
diarization = Diar.read_mdtm(mdtm_filename)
diarization.sort(['show', 'start'])
# When one segment starts just one frame after the previous one ends,
# we replace its start time by the stop time of the previous segment to avoid artificial holes
previous_stop = 0
for ii, seg in enumerate(diarization.segments):
if ii == 0:
previous_stop = seg['stop']
else:
if seg['start'] == diarization.segments[ii - 1]['stop'] + 1:
diarization.segments[ii]['start'] = diarization.segments[ii - 1]['stop']
# Create the empty labels
label = numpy.zeros(sample_number, dtype=int)
@@ -112,32 +128,52 @@
for t in range(sample_number):
time_stamps[t] = start_time + (2 * t + 1) * period / 2
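# Worked example (assuming period = 0.01 s, i.e. a 100 Hz output framerate):
# sample 0 is stamped at start_time + 0.005 s, sample 1 at start_time + 0.015 s, ...
# i.e. each time stamp sits at the centre of its 10 ms frame.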
# Find the label of the first sample
seg_idx = 0
while diarization.segments[seg_idx]['stop'] < start_time:
while diarization.segments[seg_idx]['stop'] / 100. < start_time:
seg_idx += 1
# RESUME HERE
#ii = 0
#while diarization.segments[seg_idx]['start'] < stop_time:
# while time_stamps[ii] < diarization.segments[seg_idx]['stop']:
# label[ii] = speaker_dict[diarization.segments[seg_idx]['cluster']]
# ii += 1
# start = int(diarization.segments[seg_idx]['start']) * framerate // sampling_frequency
# stop = int(diarization.segments[seg_idx]['stop']) * framerate // sampling_frequency
# spk_idx = speaker_dict[segment['cluster']]
# label[start:stop] = spk_idx
# seg_idx += 1
# Get label of each sample
for ii, t in enumerate(time_stamps):
# If we have not yet reached the first overlapping segment (i.e. we are still in non-speech)
if t <= diarization.segments[seg_idx]['start']/100.:
# Keep the label at 0 (non-speech)
pass
# If we are already inside the first overlapping segment
elif diarization.segments[seg_idx]['start']/100. < t < diarization.segments[seg_idx]['stop']/100.:
label[ii] = speaker_dict[diarization.segments[seg_idx]['cluster']]
# If we move on to the next segment
elif diarization.segments[seg_idx]['stop']/100. < t:
seg_idx += 1
# We are between two segments:
if t < diarization.segments[seg_idx]['start']/100.:
pass
elif diarization.segments[seg_idx]['start']/100. < t < diarization.segments[seg_idx]['stop']/100.:
label[ii] = speaker_dict[diarization.segments[seg_idx]['cluster']]
return label
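# A minimal usage sketch (file name and speaker indices are illustrative):
# label 200 output frames covering a 2 s window at a 100 Hz output framerate.
speaker_dict = {"spk_A": 1, "spk_B": 2}   # 0 is left for non-speech
frame_labels = mdtm_to_label(mdtm_filename="show1.mdtm",
                             start_time=10.,
                             stop_time=12.,
                             sample_number=200,
                             speaker_dict=speaker_dict)
# frame_labels is a numpy array of 200 ints: 0 for non-speech, otherwise the speaker index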
def get_segment_label(label, seg_idx, mode, duration, framerate, seg_shift, collar_duration, filter_type="gate"):
def get_segment_label(label,
seg_idx,
mode,
duration,
framerate,
seg_shift,
collar_duration,
filter_type="gate"):
"""
:param label:
:param seg_idx:
:param mode:
:param duration:
:param framerate:
:param seg_shift:
:param collar_duration:
:param filter_type:
:return:
"""
# Create labels with Diracs at every speaker change detection
spk_change = numpy.zeros(label.shape, dtype=int)
@@ -172,88 +208,48 @@ def get_segment_label(label, seg_idx, mode, duration, framerate, seg_shift, coll
return segment_label[seg_idx]
class DiarSet(Dataset):
def process_segment_label(label,
mode,
framerate,
collar_duration,
filter_type="gate"):
"""
Object creates a dataset for
"""
def __init__(self,
data_dir,
mode,
duration=2.,
seg_shift=0.25,
filter_type="gate",
collar_duration=0.1,
framerate=16000):
"""
Create batches of wavform samples for deep neural network training
:param data_dir: the root directory of ALLIES data
:param mode: can be "vad", "spk_turn", "overlap"
:param duration: duration of the segments in seconds
:param seg_shift: shift to generate overlaping segments
:param filter_type:
:param collar_duration:
"""
self.framerate = framerate
self.show_duration = {}
self.segments = []
self.duration = duration
self.seg_shift = seg_shift
self.input_dir = data_dir
self.mode = mode
self.filter_type = filter_type
self.collar_duration = collar_duration
self.wav_name_format = data_dir + '/wav/{}.wav'
self.mdtm_name_format = data_dir + '/mdtm/{}.mdtm'
# load the list of training file names
training_file_list = [str(f).split("/")[-1].split('.')[0] for f in list(Path(data_dir + "/wav/").rglob("*.[wW][aA][vV]"))]
for show in training_file_list:
# Load waveform
signal = sidekit.frontend.io.read_audio(self.wav_name_format.format(show), self.framerate)[0]
# Get speaker labels from MDTM
label = mdtm_to_label(self.mdtm_name_format.format(show), signal.shape, self.framerate)
# Create labels with Diracs at every speaker change detection
spk_change = numpy.zeros(signal.shape, dtype=int)
spk_change[:-1] = label[:-1] ^ label[1:]
spk_change = numpy.not_equal(spk_change, numpy.zeros(signal.shape, dtype=int))
# Create short segments with overlap
tmp = framing(spk_change,
int(self.framerate * duration),
win_shift=int(self.framerate * seg_shift),
context=(0, 0),
pad='zeros')
# Select only segments with at least a speaker change
keep_seg = numpy.not_equal(tmp.sum(1), 0)
keep_idx = numpy.argwhere(keep_seg.squeeze()).squeeze()
:param label:
:param seg_idx:
:param mode:
:param duration:
:param framerate:
:param seg_shift:
:param collar_duration:
:param filter_type:
:return:
"""
# Create labels with Diracs at every speaker change detection
spk_change = numpy.zeros(label.shape, dtype=int)
spk_change[:-1] = label[:-1] ^ label[1:]
spk_change = numpy.not_equal(spk_change, numpy.zeros(label.shape, dtype=int))
for idx in keep_idx:
self.segments.append((show, idx))
# depending on the mode, generate the labels and select the segments
if mode == "vad":
output_label = (label > 0.5).astype(numpy.long)
self.len = len(self.segments)
elif mode == "spk_turn":
# Apply convolution to replace diracs by a chosen shape (gate or triangle)
filter_sample = int(collar_duration * framerate * 2 + 1)
def __getitem__(self, index):
show, idx = self.segments[index]
conv_filt = numpy.ones(filter_sample)
if filter_type == "triangle":
conv_filt = scipy.signal.triang(filter_sample)
output_label = numpy.convolve(conv_filt, spk_change, mode='same')
data, total_duration = load_wav_segment(self.wav_name_format.format(show),
idx, self.duration, self.seg_shift, framerate=self.framerate)
elif mode == "overlap":
raise NotImplementedError()
tmp_label = mdtm_to_label(self.mdtm_name_format.format(show), total_duration, self.framerate)
label = get_segment_label(tmp_label, idx, self.mode, self.duration, self.framerate,
self.seg_shift, self.collar_duration, filter_type=self.filter_type)
return torch.from_numpy(data).type(torch.FloatTensor), torch.from_numpy(label.astype('long'))
else:
raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")
def __len__(self):
return self.len
return output_label
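# A minimal sketch of the two implemented modes (values illustrative, reusing
# frame_labels from the mdtm_to_label sketch above). With framerate=100 (the output
# framerate) and collar_duration=0.1 s, the "spk_turn" mode convolves the
# speaker-change Diracs with a filter of int(0.1 * 100 * 2 + 1) = 21 frames,
# shaped as a gate or a triangle.
vad_target = process_segment_label(frame_labels, mode="vad",
                                   framerate=100, collar_duration=0.1)
turn_target = process_segment_label(frame_labels, mode="spk_turn",
                                    framerate=100, collar_duration=0.1,
                                    filter_type="triangle")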
def seqSplit(mdtm_dir,
@@ -280,16 +276,17 @@ def seqSplit(mdtm_dir,
# For each border time B get a segment between B - duration and B + duration
# in which we will pick up randomly later
for idx, seg in enumerate(ref.segments):
if idx > 0 and seg["start"] > duration and seg["start"] + duration < last_stop:
if idx > 0 and seg["start"] / 100. > duration and seg["start"] + duration < last_stop:
segment_list.append(show=seg['show'],
cluster="",
start=float(seg["start"] - duration) / 100.,
stop=float(seg["start"] + duration) / 100.)
start=float(seg["start"]) / 100. - duration,
stop=float(seg["start"]) / 100. + duration)
elif idx < len(ref.segments) - 1 and seg["stop"] + duration < last_stop:
segment_list.append(show=seg['show'],
cluster="",
start=float(seg["stop"] - duration) / 100.,
stop=float(seg["stop"] + duration) / 100.)
start=float(seg["stop"]) / 100. - duration,
stop=float(seg["stop"]) / 100. + duration)
# Get list of unique speakers
speakers = ref.unique('cluster')
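# Worked example of the windows built above (MDTM times are in centiseconds):
# with duration=2. and a segment border at frame 1050, the stored chunk runs from
# 1050 / 100. - 2. = 8.5 s to 1050 / 100. + 2. = 12.5 s. SeqSet.__getitem__ later
# draws a start uniformly within the first duration seconds of that chunk and reads
# duration seconds from it, so the border always falls inside the training window.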
@@ -308,22 +305,34 @@ class SeqSet(Dataset):
def __init__(self,
wav_dir,
mdtm_dir,
segment_list,
mode,
duration=2.,
filter_type="gate",
collar_duration=0.1,
framerate=16000,
transform_pipeline=None):
audio_framerate=16000,
output_framerate=100,
transform_pipeline=""):
"""
:param wav_dir:
:param mdtm_dir:
:param mode:
:param duration:
:param filter_type:
:param collar_duration:
:param audio_framerate:
:param output_framerate:
:param transform_pipeline:
"""
self.wav_dir = wav_dir
self.mdtm_dir = mdtm_dir
self.segment_list = segment_list
self.mode = mode
self.duration = duration
self.filter_type = filter_type
self.collar_duration = collar_duration
self.framerate = framerate
self.audio_framerate = audio_framerate
self.output_framerate = output_framerate
self.transform_pipeline = transform_pipeline
@@ -363,34 +372,30 @@ class SeqSet(Dataset):
seg = self.segment_list[index]
# Randomly pick an audio chunk within the current segment
start = random.uniform(seg.start_time, seg.start_time + self.duration)
start = random.uniform(seg["start"], seg["start"] + self.duration)
sig, _ = soundfile.read(self.wav_dir + seg.show + ".wav",
start=start * self.sample_rate,
stop=(start + self.duration) * self.sample_rate
sig, _ = soundfile.read(self.wav_dir + seg["show"] + ".wav",
start=int(start * self.audio_framerate),
stop=int((start + self.duration) * self.audio_framerate)
)
sig += 0.0001 * numpy.random.randn(sig.shape[0])
if self.transform_pipeline:
sig, _, __, ___ = self.transforms((sig, None, None, None))
label = mdtm_to_label(mdtm_filename=self.mdtm_dir + seg.show + ".mdtm",
start_time=start,
stop_time=start + self.duration,
sample_number=sig.shape[0],
speaker_dict=self.speaker_dict)
sig, speaker_idx, _, __, _t, _s = self.transforms((sig, None, None, None, None, None))
tmp_label = mdtm_to_label(mdtm_filename=self.mdtm_dir + seg["show"] + ".mdtm",
start_time=start,
stop_time=start + self.duration,
sample_number=sig.shape[1],
speaker_dict=self.speaker_dict)
label = process_segment_label(label=tmp_label,