Commit e120ee8a authored by Anthony Larcher's avatar Anthony Larcher

seq2seq

parent c447ff2b
......@@ -38,7 +38,7 @@ from .clustering.hac_utils import bic_square_root
from .clustering.cc_iv import ConnectedComponent
from .nnet.wavsets import AlliesSet
from .nnet.wavsets import SeqSet
from .nnet.seqtoseq import PreNet
from .nnet.seqtoseq import BLSTM
......
......@@ -23,6 +23,6 @@
Copyright 2014-2020 Anthony Larcher
"""
from .wavsets import AlliesSet
from .wavsets import SeqSet
from .seqtoseq import PreNet
from .seqtoseq import BLSTM
\ No newline at end of file
......@@ -28,6 +28,7 @@ import sys
import numpy
import random
import h5py
import shutil
import torch
import torch.nn as nn
from torch import optim
......@@ -35,6 +36,7 @@ from torch.utils.data import Dataset
import logging
from sidekit.nnet.vad_rnn import BLSTM
from torch.utils.data import DataLoader
__license__ = "LGPL"
__author__ = "Anthony Larcher"
......@@ -45,7 +47,18 @@ __status__ = "Production"
__docformat__ = 'reStructuredText'
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
"""
:param state:
:param is_best:
:param filename:
:param best_filename:
:return:
"""
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, best_filename)
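
# Minimal usage sketch of save_checkpoint (illustrative only; the variable names and
# file names below are hypothetical placeholders, mirroring the calls made in seqTrain):
#
#     state = {'epoch': epoch,
#              'model_state_dict': model.state_dict(),
#              'optimizer_state_dict': optimizer.state_dict(),
#              'accuracy': best_accuracy}
#     save_checkpoint(state,
#                     is_best=accuracy > best_accuracy,
#                     filename='checkpoint.pth.tar',
#                     best_filename='model_best.pth.tar')
#
#     # Restoring later:
#     checkpoint = torch.load('model_best.pth.tar')
#     model.load_state_dict(checkpoint['model_state_dict'])
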
class PreNet(nn.Module):
    def __init__(self,
......@@ -130,21 +143,269 @@ class BLSTM(nn.Module):
class SeqToSeq(nn.Module):
"""
Bi LSTM model used for voice activity detection or speaker turn detection
"""
def __init__(self,
input_size,
lstm_1,
lstm_2,
linear_1,
linear_2,
output_size=1):
"""
def __init__(self):
self.preprocessor = PreNet(sample_rate=16000,
windows_duration=0.2,
frame_shift=0.01)
:param input_size:
:param lstm_1:
:param lstm_2:
:param linear_1:
:param linear_2:
:param output_size:
"""
super(BLSTM, self).__init__()
self.sequence_model = BLSTM(input_size=1,
lstm_1=64,
lstm_2=40,
linear_1=40,
linear_2=10)
self.lstm_1 = nn.LSTM(input_size, lstm_1 // 2, bidirectional=True, batch_first=True)
self.lstm_2 = nn.LSTM(lstm_1, lstm_2 // 2, bidirectional=True, batch_first=True)
self.linear_1 = nn.Linear(lstm_2, linear_1)
self.linear_2 = nn.Linear(linear_1, linear_2)
self.output = nn.Linear(linear_2, output_size)
self.hidden = None
def forward(self, input):
x = self.preprocessor(input)
output = self.sequence_model(x)
return output
def forward(self, inputs):
"""
:param inputs:
:return:
"""
if self.hidden is None:
hidden_1, hidden_2 = None, None
else:
hidden_1, hidden_2 = self.hidden
tmp, hidden_1 = self.lstm_1(inputs, hidden_1)
x, hidden_2 = self.lstm_2(tmp, hidden_2)
self.hidden = (hidden_1, hidden_2)
x = torch.tanh(self.linear_1(x))
x = torch.tanh(self.linear_2(x))
x = torch.sigmoid(self.output(x))
return x
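
# Minimal shape-check sketch for SeqToSeq (illustrative only; the layer sizes reuse the
# 1/64/40/40/10 values quoted earlier in this file and are not prescribed defaults):
def _seq_to_seq_shape_check():
    model = SeqToSeq(input_size=1, lstm_1=64, lstm_2=40, linear_1=40, linear_2=10)
    x = torch.randn(8, 200, 1)   # (batch, frames, input_size)
    y = model(x)                 # sigmoid outputs, one value per frame
    assert y.shape == (8, 200, 1)
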
def seqTrain(data_dir,
             mode,
             duration=2.,
             seg_shift=0.25,
             filter_type="gate",
             collar_duration=0.1,
             framerate=16000,
             epochs=100,
             batch_size=32,
             lr=0.0001,
             loss="cross_validation",
             patience=10,
             log_interval=10,  # number of batches between log messages (default value assumed)
             tmp_model_name=None,
             best_model_name=None,
             multi_gpu=True,
             opt='sgd',
             num_thread=10
             ):
    """
    Train a sequence-to-sequence model for voice activity or speaker turn detection.

    :param data_dir: root directory of the training data
    :param mode: can be "vad", "spk_turn" or "overlap"
    :param duration: duration of the segments in seconds
    :param seg_shift: shift between two consecutive segments
    :param filter_type: type of filter applied to the labels
    :param collar_duration: collar duration around the speaker changes
    :param framerate: sampling frequency of the audio signal
    :param epochs: number of training epochs
    :param batch_size: size of the mini-batches
    :param lr: initial learning rate
    :param loss: name of the loss to use
    :param patience: number of epochs without improvement before stopping
    :param log_interval: number of batches between two training log messages
    :param tmp_model_name: name of the file storing the current model
    :param best_model_name: name of the file storing the best model
    :param multi_gpu: train on several GPUs with DataParallel when available
    :param opt: optimizer to use, 'sgd', 'adam' or 'rmsprop'
    :param num_thread: number of workers used by the DataLoaders
    :return: None
    """
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Start from scratch; the layer sizes reuse the 1/64/40/40/10 values quoted earlier
    # in this file and are illustrative rather than prescribed defaults
    model = SeqToSeq(input_size=1,
                     lstm_1=64,
                     lstm_2=40,
                     linear_1=40,
                     linear_2=10)
# TODO implement a model adaptation
if torch.cuda.device_count() > 1 and multi_gpu:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = torch.nn.DataParallel(model)
else:
print("Train on a single GPU")
model.to(device)
"""
Create two dataloaders for training and evaluation
"""
    # TODO: instantiate the training and validation datasets (e.g. SeqSet); placeholders for now
    training_set, validation_set = None, None
training_loader = DataLoader(training_set,
batch_size=batch_size,
shuffle=True,
drop_last=True,
num_workers=num_thread)
validation_loader = DataLoader(validation_set,
batch_size=batch_size,
drop_last=True,
num_workers=num_thread)
"""
Set the training options
"""
if opt == 'sgd':
_optimizer = torch.optim.SGD
_options = {'lr': lr, 'momentum': 0.9}
elif opt == 'adam':
_optimizer = torch.optim.Adam
_options = {'lr': lr}
elif opt == 'rmsprop':
_optimizer = torch.optim.RMSprop
_options = {'lr': lr}
params = [
{
'params': [
param for name, param in model.named_parameters() if 'bn' not in name
]
},
{
'params': [
param for name, param in model.named_parameters() if 'bn' in name
],
'weight_decay': 0
},
]
    # SeqToSeq does not define a weight_decay attribute, so all parameters share the same options
    if type(model) is SeqToSeq:
        optimizer = _optimizer([{'params': model.parameters()}], **_options)
    else:
        optimizer = _optimizer([{'params': model.module.parameters()}], **_options)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)
best_accuracy = 0.0
best_accuracy_epoch = 1
curr_patience = patience
for epoch in range(1, epochs + 1):
# Process one epoch and return the current model
if curr_patience == 0:
print(f"Stopping at epoch {epoch} for cause of patience")
break
model = train_epoch(model,
epoch,
training_loader,
optimizer,
log_interval,
device=device)
# Add the cross validation here
accuracy, val_loss = cross_validation(model, validation_loader, device=device)
logging.critical("*** Cross validation accuracy = {} %".format(accuracy))
# Decrease learning rate according to the scheduler policy
scheduler.step(val_loss)
print(f"Learning rate is {optimizer.param_groups[0]['lr']}")
# remember best accuracy and save checkpoint
is_best = accuracy > best_accuracy
best_accuracy = max(accuracy, best_accuracy)
if type(model) is SeqToSeq:
save_checkpoint({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': best_accuracy,
'scheduler': scheduler
}, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
else:
save_checkpoint({
'epoch': epoch,
'model_state_dict': model.module.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': best_accuracy,
'scheduler': scheduler
}, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
if is_best:
best_accuracy_epoch = epoch
curr_patience = patience
else:
curr_patience -= 1
logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")
def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
"""
:param model:
:param epoch:
:param training_loader:
:param optimizer:
:param log_interval:
:param device:
:param clipping:
:return:
"""
model.train()
criterion = torch.nn.CrossEntropyLoss(reduction='mean')
accuracy = 0.0
for batch_idx, (data, target) in enumerate(training_loader):
target = target.squeeze()
optimizer.zero_grad()
        output = model(data.to(device))
        # CrossEntropyLoss expects (batch, classes, frames) while the model outputs (batch, frames, classes)
        loss = criterion(output.permute(0, 2, 1), target.to(device))
        loss.backward()
        optimizer.step()
        # Accumulate frame-level accuracy
        accuracy += (torch.argmax(output.data, 2) == target.to(device)).sum()
        if batch_idx % log_interval == 0:
            batch_size = target.shape[0]
            logging.critical('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                epoch, batch_idx + 1, training_loader.__len__(),
                100. * batch_idx / training_loader.__len__(), loss.item(),
                100.0 * accuracy.item() / ((batch_idx + 1) * batch_size * target.shape[1])))
return model
def cross_validation(model, validation_loader, device):
"""
:param model:
:param validation_loader:
:param device:
:return:
"""
model.eval()
accuracy = 0.0
loss = 0.0
criterion = torch.nn.CrossEntropyLoss()
with torch.no_grad():
for batch_idx, (data, target) in enumerate(validation_loader):
batch_size = target.shape[0]
target = target.squeeze()
            output = model(data.to(device))
            accuracy += (torch.argmax(output.data, 2) == target.to(device)).sum()
            loss += criterion(output.permute(0, 2, 1), target.to(device))

    # Return the frame-level accuracy (in percent) and the normalised loss
    return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * batch_size * target.shape[1]), \
           loss.cpu().numpy() / ((batch_idx + 1) * batch_size)
......@@ -32,14 +32,20 @@ __status__ = "Production"
__docformat__ = 'reStructuredText'
import numpy
import pathlib
import random
import scipy
import sidekit
import soundfile
import torch
from ..diar import Diar
from pathlib import Path
from torch.utils.data import Dataset
from torchvision import transforms
from collections import namedtuple
#Segment = namedtuple('Segment', ['show', 'start_time', 'end_time'])
def framing(sig, win_size, win_shift=1, context=(0, 0), pad='zeros'):
"""
......@@ -82,39 +88,51 @@ def load_wav_segment(wav_file_name, idx, duration, seg_shift, framerate=16000):
def mdtm_to_label(mdtm_filename,
                  start_time,
                  stop_time,
                  sample_number,
                  speaker_dict):
"""
:param show:
:param show_duration:
:param allies_dir:
:param mode:
:param duration:
:param start:
:param framerate:
:param filter_type:
:param collar_duration:
:param mdtm_filename:
:param start_time:
:param stop_time:
:param sample_number:
:return:
"""
diarization = Diar.read_mdtm(mdtm_filename)
diarization.sort(['show', 'start'])
    # Create the empty labels
    label = numpy.zeros(sample_number, dtype=int)
# Compute the time stamp of each sample
time_stamps = numpy.zeros(sample_number, dtype=numpy.float32)
period = (stop_time - start_time) / sample_number
for t in range(sample_number):
time_stamps[t] = start_time + (2 * t + 1) * period / 2
    # Find the first segment that ends after the start of the window
    seg_idx = 0
    while seg_idx < len(diarization.segments) and diarization.segments[seg_idx]['stop'] < start_time:
        seg_idx += 1

    # Fill the labels with the speaker indices: every sample whose time stamp falls inside
    # a segment takes the index of that segment's speaker
    # (this assumes segment boundaries and time stamps are expressed in the same unit)
    ii = 0
    while ii < sample_number and seg_idx < len(diarization.segments):
        current_segment = diarization.segments[seg_idx]
        if time_stamps[ii] >= current_segment['stop']:
            seg_idx += 1
        else:
            if time_stamps[ii] >= current_segment['start']:
                label[ii] = speaker_dict[current_segment['cluster']]
            ii += 1

    return label
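
# Illustrative sketch of mdtm_to_label (the file name, times and speaker mapping below are
# hypothetical): labelling a 2-second window starting at t = 12 s with one label every
# 10 ms, given a speaker dictionary such as the one returned by seqSplit later in this file.
#
#     labels = mdtm_to_label(mdtm_filename="show1.mdtm",
#                            start_time=12.0,
#                            stop_time=14.0,
#                            sample_number=200,
#                            speaker_dict={"spk_A": 0, "spk_B": 1})
#     # labels is a numpy array of 200 integer speaker indices, one per frame
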
......@@ -154,12 +172,12 @@ def get_segment_label(label, seg_idx, mode, duration, framerate, seg_shift, coll
return segment_label[seg_idx]
class DiarSet(Dataset):
"""
    Dataset object that provides batches of waveform segments for VAD, speaker-turn or overlap detection training
"""
def __init__(self,
                 data_dir,
mode,
duration=2.,
seg_shift=0.25,
......@@ -170,7 +188,7 @@ class AlliesSet(Dataset):
        Create batches of waveform samples for deep neural network training
        :param data_dir: the root directory of ALLIES data
:param mode: can be "vad", "spk_turn", "overlap"
:param duration: duration of the segments in seconds
        :param seg_shift: shift to generate overlapping segments
......@@ -182,16 +200,16 @@ class AlliesSet(Dataset):
self.segments = []
self.duration = duration
self.seg_shift = seg_shift
        self.input_dir = data_dir
self.mode = mode
        self.filter_type = filter_type
self.collar_duration = collar_duration
        self.wav_name_format = data_dir + '/wav/{}.wav'
        self.mdtm_name_format = data_dir + '/mdtm/{}.mdtm'
# load the list of training file names
        training_file_list = [str(f).split("/")[-1].split('.')[
            0] for f in list(Path(data_dir + "/wav/").rglob("*.[wW][aA][vV]"))
]
for show in training_file_list:
......@@ -236,3 +254,143 @@ class AlliesSet(Dataset):
def __len__(self):
return self.len
def seqSplit(mdtm_dir,
duration=2.):
"""
:param mdtm_dir:
:param duration:
:return:
"""
segment_list = Diar()
speaker_dict = dict()
    spk_idx = 0  # global speaker index, kept separate from the segment counter used below
# For each MDTM
for mdtm_file in pathlib.Path(mdtm_dir).glob('*.mdtm'):
# Load MDTM file
ref = Diar.read_mdtm(mdtm_file)
ref.sort()
last_stop = ref.segments[-1]["stop"]
        # Get the borders of the segments (not the start of the first nor the end of the last).
# For each border time B get a segment between B - duration and B + duration
# in which we will pick up randomly later
for idx, seg in enumerate(ref.segments):
if idx > 0 and seg["start"] > duration and seg["start"] + duration < last_stop:
segment_list.append(show=seg['show'],
cluster="",
start=float(seg["start"] - duration) / 100.,
stop=float(seg["start"] + duration) / 100.)
elif idx < len(ref.segments) - 1 and seg["stop"] + duration < last_stop:
segment_list.append(show=seg['show'],
cluster="",
start=float(seg["stop"] - duration) / 100.,
stop=float(seg["stop"] + duration) / 100.)
# Get list of unique speakers
speakers = ref.unique('cluster')
for spk in speakers:
            if spk not in speaker_dict:
                speaker_dict[spk] = spk_idx
                spk_idx += 1
return segment_list, speaker_dict
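
# Illustrative sketch of seqSplit (the directory below is a hypothetical placeholder):
#
#     segment_list, speaker_dict = seqSplit("/path/to/mdtm/", duration=2.)
#     # segment_list is a Diar whose segments are centred on the speaker-change points,
#     # speaker_dict maps every speaker (cluster) name to an integer index
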
class SeqSet(Dataset):
"""
    Dataset object for sequence-to-sequence training: raw waveform chunks with frame-level labels
"""
def __init__(self,
wav_dir,
mdtm_dir,
segment_list,
mode,
duration=2.,
filter_type="gate",
collar_duration=0.1,
framerate=16000,
transform_pipeline=None):
self.wav_dir = wav_dir
self.mdtm_dir = mdtm_dir
self.segment_list = segment_list
self.mode = mode
self.duration = duration
self.filter_type = filter_type
self.collar_duration = collar_duration
self.framerate = framerate
self.transform_pipeline = transform_pipeline
        _transform = []
        # The transform classes below (PreEmphasis, MFCC, CMVN, FrequencyMask, TemporalMask)
        # are assumed to be provided by sidekit's nnet transforms; they are not imported in this file
        if self.transform_pipeline is not None and self.transform_pipeline != '':
            trans = self.transform_pipeline.split(',')
for t in trans:
if 'PreEmphasis' in t:
_transform.append(PreEmphasis())
if 'MFCC' in t:
_transform.append(MFCC())
if "CMVN" in t:
_transform.append(CMVN())
if "FrequencyMask" in t:
a = int(t.split('-')[0].split('(')[1])
b = int(t.split('-')[1].split(')')[0])
_transform.append(FrequencyMask(a, b))
if "TemporalMask" in t:
a = int(t.split("(")[1].split(")")[0])
_transform.append(TemporalMask(a))
self.transforms = transforms.Compose(_transform)
segment_list, speaker_dict = seqSplit(mdtm_dir=self.mdtm_dir,
duration=self.duration)
self.segment_list = segment_list
self.speaker_dict = speaker_dict
self.len = len(segment_list)
def __getitem__(self, index):
"""
On renvoie un segment wavform brut mais il faut que les labels soient échantillonés à la bonne fréquence
(trames)
:param index:
:return:
"""
# Get segment info to load from
seg = self.segment_list[index]
        # Randomly pick an audio chunk within the current segment
        start = random.uniform(seg['start'], seg['start'] + self.duration)
        sig, _ = soundfile.read(self.wav_dir + seg['show'] + ".wav",
                                start=int(start * self.framerate),
                                stop=int((start + self.duration) * self.framerate)
                                )
sig += 0.0001 * numpy.random.randn(sig.shape[0])
if self.transform_pipeline:
sig, _, __, ___ = self.transforms((sig, None, None, None))
label = mdtm_to_label(mdtm_filename=self.mdtm_dir + seg.show + ".mdtm",
start_time=start,
stop_time=start + self.duration,
sample_number=sig.shape[0],
speaker_dict=self.speaker_dict)
        # For each sampling time we need to post-process the label according to the mode
        # (vad / spk_turn / overlap), e.g. with get_segment_label; TO BE MODIFIED, the call
        # is left commented out because its inputs are not available in this class yet:
        # label = get_segment_label(label, index, self.mode, self.duration, self.framerate,
        #                           self.seg_shift, self.collar_duration, filter_type=self.filter_type)

        return torch.from_numpy(sig).type(torch.FloatTensor), \
               torch.from_numpy(label.astype('long'))
def __len__(self):
return self.len
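
# Illustrative wiring sketch (the paths and the transform pipeline string are hypothetical
# placeholders): how a SeqSet instance is expected to feed a training loop such as seqTrain
# through a DataLoader.
def _build_seq_loader():
    from torch.utils.data import DataLoader
    dataset = SeqSet(wav_dir="/path/to/wav/",
                     mdtm_dir="/path/to/mdtm/",
                     segment_list=None,          # recomputed internally from mdtm_dir
                     mode="vad",
                     duration=2.,
                     transform_pipeline="MFCC,CMVN")
    return DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True, num_workers=5)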