Commit c809d6b7 authored by Anthony Larcher

VAD

parent 836858d6
@@ -23,28 +23,20 @@
Copyright 2014-2020 Anthony Larcher
"""
import os
import sys
import logging
import pandas
import numpy
from collections import OrderedDict
import random
import h5py
import shutil
import torch
import torch.nn as nn
import yaml
from torch import optim
from torch.utils.data import Dataset
from .wavsets import SeqSet
from collections import OrderedDict
from sidekit.nnet.sincnet import SincNet
from torch.utils.data import DataLoader
from .wavsets import SeqSet
from .wavsets import create_train_val_seqtoseq
__license__ = "LGPL"
__author__ = "Anthony Larcher"
__author__ = "Anthony Larcher, Martin Lebourdais, Meysam Shamsi"
__copyright__ = "Copyright 2015-2020 Anthony Larcher"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
@@ -52,21 +44,33 @@ __status__ = "Production"
__docformat__ = 'reStructuredText'
# def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
# """
#
# :param state:
# :param is_best:
# :param filename:
# :param best_filename:
# :return:
# """
# torch.save(state, filename)
# if is_best:
# shutil.copyfile(filename, best_filename)
class BLSTM(nn.Module):
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
"""
:param state:
:param is_best:
:param filename:
:param best_filename:
:return:
"""
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, best_filename)
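# Usage sketch (illustrative, mirroring the call made in seqTrain below):
# save_checkpoint({'epoch': epoch,
#                  'model_state_dict': model.state_dict(),
#                  'optimizer_state_dict': optimizer.state_dict(),
#                  'accuracy': best_accuracy,
#                  'scheduler': scheduler},
#                 is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + ".pt")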
def init_weights(m):
"""
:return:
"""
if type(m) == torch.nn.Linear:
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
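# Typically applied recursively with Module.apply, e.g. self.post_processing.apply(init_weights) as done in SeqToSeq below.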
class BLSTM(torch.nn.Module):
"""
Bi LSTM model used for voice activity detection, speaker turn detection, overlap detection and resegmentation
"""
def __init__(self,
input_size,
blstm_sizes):
@@ -78,17 +82,13 @@ class BLSTM(nn.Module):
super(BLSTM, self).__init__()
self.input_size = input_size
self.blstm_sizes = blstm_sizes
self.blstm_layers = []
for blstm_size in blstm_sizes:
self.blstm_layers.append(nn.LSTM(input_size, blstm_size // 2, bidirectional=True, batch_first=True))
input_size = blstm_size
self.output_size = blstm_size
self.output_size = blstm_sizes[0] * 2
self.hidden = None
"""
Bi LSTM model used for voice activity detection or speaker turn detection
"""
self.blstm_layers = torch.nn.LSTM(input_size,
blstm_sizes,
bidirectional=True,
batch_first=True,
num_layers=2)
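# Assuming blstm_sizes is a single hidden size here: with batch_first=True the LSTM expects
# inputs of shape (batch, seq_len, input_size) and, being bidirectional, returns outputs of
# shape (batch, seq_len, 2 * hidden_size) per frame.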
def forward(self, inputs):
"""
@@ -96,32 +96,18 @@ class BLSTM(nn.Module):
:param inputs:
:return:
"""
#for idx, _s in enumerate(self.blstm_sizes):
# self.blstm_layers[idx].flatten_parameters()
hiddens = []
if self.hidden is None:
#hidden_1, hidden_2 = None, None
for _s in self.blstm_sizes:
hiddens.append(None)
else:
hiddens = list(self.hidden)
x = inputs
outputs = []
for idx, _s in enumerate(self.blstm_sizes):
x, hiddens[idx] = self.blstm_layers[idx](x, hiddens[idx])
outputs.append(x)
self.hidden = tuple(hiddens)
output = torch.cat(outputs, dim=2)
return x
output, h = self.blstm_layers(inputs)
return output
def output_size(self):
"""
:return:
"""
return self.output_size
class SeqToSeq(nn.Module):
class SeqToSeq(torch.nn.Module):
"""
Model used for voice activity detection or speaker turn detection
This model can include a pre-processor to input raw waveform,
@@ -198,8 +184,7 @@ class SeqToSeq(nn.Module):
post_processing_layers.append((k, torch.nn.Dropout(p=cfg["post_processing"][k])))
self.post_processing = torch.nn.Sequential(OrderedDict(post_processing_layers))
#self.before_speaker_embedding_weight_decay = cfg["post_processing"]["weight_decay"]
self.post_processing.apply(init_weights)
def forward(self, inputs):
"""
@@ -271,27 +256,16 @@ def seqTrain(dataset_yaml,
model = torch.nn.DataParallel(model)
else:
print("Train on a single GPU")
model.to(device)
with open(dataset_yaml, "r") as fh:
dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
"""
Create two dataloaders for training and evaluation
"""
with open(dataset_yaml, "r") as fh:
dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
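# The YAML file is expected to provide at least the keys read in this module and in wavsets,
# e.g. seed, batch_size, mdtm_dir and duration.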
#df = pandas.read_csv(dataset_params["dataset_description"])
#training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
torch.manual_seed(dataset_params['seed'])
training_set = SeqSet(dataset_yaml,
wav_dir="data/wav/",
mdtm_dir="data/mdtm/",
mode="vad",
duration=2.,
filter_type="gate",
collar_duration=0.1,
audio_framerate=16000,
output_framerate=100,
transform_pipeline="MFCC")
training_set, validation_set = create_train_val_seqtoseq(dataset_yaml)
training_loader = DataLoader(training_set,
batch_size=dataset_params["batch_size"],
@@ -300,15 +274,11 @@ def seqTrain(dataset_yaml,
pin_memory=True,
num_workers=num_thread)
#validation_set = SeqSet(dataset_yaml,
# set_type="validation",
# dataset_df=validation_df)
#validation_loader = DataLoader(validation_set,
# batch_size=dataset_params["batch_size"],
# drop_last=True,
# pin_memory=True,
# num_workers=num_thread)
validation_loader = DataLoader(validation_set,
batch_size=dataset_params["batch_size"],
drop_last=True,
pin_memory=True,
num_workers=num_thread)
"""
Set the training options
@@ -338,24 +308,13 @@ def seqTrain(dataset_yaml,
]
optimizer = _optimizer([{'params': model.parameters()},], **_options)
#if type(model) is SeqToSeq:
# optimizer = _optimizer([
# {'params': model.parameters(),
# 'weight_decay': model.weight_decay},],
# **_options
# )
#else:
# optimizer = _optimizer([
# {'params': model.module.sequence_network.parameters(),
# #'weight_decay': model.module.sequence_network_weight_decay},],
# **_options
# )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)
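# ReduceLROnPlateau in 'min' mode lowers the learning rate when the value passed to
# scheduler.step() (here the cross-validation loss) stops decreasing.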
best_accuracy = 0.0
best_accuracy_epoch = 1
curr_patience = patience
for epoch in range(1, epochs + 1):
# Process one epoch and return the current model
if curr_patience == 0:
@@ -369,41 +328,41 @@
device=device)
# Cross validation here
#accuracy, val_loss = cross_validation(model, validation_loader, device=device)
#logging.critical("*** Cross validation accuracy = {} %".format(accuracy))
accuracy, val_loss = cross_validation(model, validation_loader, device=device)
logging.critical("*** Cross validation accuracy = {} %".format(accuracy))
# Decrease learning rate according to the scheduler policy
#scheduler.step(val_loss)
#print(f"Learning rate is {optimizer.param_groups[0]['lr']}")
scheduler.step(val_loss)
print(f"Learning rate is {optimizer.param_groups[0]['lr']}")
# remember best accuracy and save checkpoint
#is_best = accuracy > best_accuracy
#best_accuracy = max(accuracy, best_accuracy)
#if type(model) is SeqToSeq:
# save_checkpoint({
# 'epoch': epoch,
# 'model_state_dict': model.state_dict(),
# 'optimizer_state_dict': optimizer.state_dict(),
# 'accuracy': best_accuracy,
# 'scheduler': scheduler
# }, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
#else:
# save_checkpoint({
# 'epoch': epoch,
# 'model_state_dict': model.module.state_dict(),
# 'optimizer_state_dict': optimizer.state_dict(),
# 'accuracy': best_accuracy,
# 'scheduler': scheduler
# }, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
#if is_best:
# best_accuracy_epoch = epoch
# curr_patience = patience
#else:
# curr_patience -= 1
#logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")
is_best = accuracy > best_accuracy
best_accuracy = max(accuracy, best_accuracy)
if type(model) is SeqToSeq:
save_checkpoint({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': best_accuracy,
'scheduler': scheduler
}, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
else:
save_checkpoint({
'epoch': epoch,
'model_state_dict': model.module.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': best_accuracy,
'scheduler': scheduler
}, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
if is_best:
best_accuracy_epoch = epoch
curr_patience = patience
else:
curr_patience -= 1
logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")
def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
@@ -420,9 +379,12 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
"""
model.to(device)
model.train()
criterion = torch.nn.CrossEntropyLoss(reduction='mean')
criterion = torch.nn.CrossEntropyLoss(reduction='mean', weight=torch.FloatTensor([0.1, 0.9]).to(device))
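# The class weights [0.1, 0.9] down-weight class 0 with respect to class 1, presumably to
# compensate for the imbalance between non-speech and speech frames (an assumption, not documented here).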
recall = 0.0
precision = 0.0
accuracy = 0.0
for batch_idx, (data, target) in enumerate(training_loader):
target = target.squeeze()
optimizer.zero_grad()
@@ -433,14 +395,32 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
loss = criterion(output, target.to(device))
loss.backward(retain_graph=True)
optimizer.step()
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
rc, pr, acc = calc_recall(output.data, target, device)
recall += rc.item()
precision += pr.item()
accuracy += acc.item()
if batch_idx % log_interval == 0:
batch_size = target.shape[0]
logging.critical('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
epoch, batch_idx + 1, training_loader.__len__(),
100. * batch_idx / training_loader.__len__(), loss.item(),
100.0 * accuracy.item() / ((batch_idx + 1) * batch_size * 198)))
if precision != 0 or recall != 0:
f_measure = 2 * (precision / (batch_idx + 1)) * (recall / (batch_idx + 1)) / \
((precision / (batch_idx + 1)) + (recall / (batch_idx + 1)))
logging.critical(
'Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.3f} '\
'Recall: {:.3f} Precision: {:.3f} ' \
'F-Measure: {:.3f}'.format(epoch,
batch_idx + 1,
training_loader.__len__(),
100. * batch_idx / training_loader.__len__(), loss.item(),
100.0 * accuracy / ((batch_idx + 1)),
100.0 * recall/ ((batch_idx+1)),
100.0 * precision / ((batch_idx+1)),
f_measure)
)
else:
print(f"precision = {precision} and recall = {recall}")
return model
@@ -454,18 +434,77 @@ def cross_validation(model, validation_loader, device):
"""
model.eval()
recall = 0.0
precision = 0.0
accuracy = 0.0
loss = 0.0
criterion = torch.nn.CrossEntropyLoss()
with torch.no_grad():
for batch_idx, (data, target) in enumerate(validation_loader):
batch_size = target.shape[0]
target = target.squeeze()
output = model(data.to(device),target=target.to(device),is_eval=True)
print(output.shape)
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
output = model(data.to(device))
output = output.permute(1, 2, 0)
target = target.permute(1, 0)
nbpoint = output.shape[0]
rc, pr, acc = calc_recall(output.data, target, device)
recall += rc.item()
precision += pr.item()
accuracy += acc.item()
batch_size = target.shape[0]
if precision != 0 or recall != 0:
f_measure = 2 * (precision / ((batch_idx + 1))) * (recall / ((batch_idx + 1))) / \
((precision / ((batch_idx + 1))) + (recall / ((batch_idx + 1))))
logging.critical(
'Validation: [{}/{} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.3f} ' \
'Recall: {:.3f} Precision: {:.3f} ' \
'F-Measure: {:.3f}'.format(batch_idx + 1,
validation_loader.__len__(),
100. * batch_idx / validation_loader.__len__(), float(loss),
100.0 * accuracy / ((batch_idx + 1)),
100.0 * recall / ((batch_idx + 1)),
100.0 * precision / ((batch_idx + 1)),
f_measure)
)
loss += criterion(output, target.to(device))
# accuracy and loss are accumulated per batch above, so average them over the number of batches
return 100. * accuracy / (batch_idx + 1), \
float(loss) / (batch_idx + 1)
def calc_recall(output, target, device):
"""
Compute recall, precision and accuracy, averaged over the batch dimension.
:param output: network output, shaped (frames, classes, batch) as used in cross_validation
:param target: reference labels, shaped (frames, batch)
:param device: torch device on which the computation is run
:return: a tuple (recall, precision, accuracy), each averaged over the batch
"""
y_trueb = target.to(device)
y_predb = output
rc = 0.0
pr = 0.0
acc = 0.0
for b in range(y_trueb.shape[-1]):
y_true = y_trueb[:,b]
y_pred = y_predb[:,:,b]
assert y_true.ndim == 1
assert y_pred.ndim == 1 or y_pred.ndim == 2
if y_pred.ndim == 2:
y_pred = y_pred.argmax(dim=1)
tp = (y_true * y_pred).sum().to(torch.float32)
tn = ((1 - y_true) * (1 - y_pred)).sum().to(torch.float32)
fp = ((1 - y_true) * y_pred).sum().to(torch.float32)
fn = (y_true * (1 - y_pred)).sum().to(torch.float32)
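# Illustrative example: y_true = [0, 1, 1, 0] and y_pred = [0, 1, 0, 1] give tp=1, tn=1, fp=1, fn=1,
# hence precision = recall = accuracy = 0.5 for that element of the batch.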
epsilon = 1e-7
pr += tp / (tp + fp + epsilon)
rc += tp / (tp + fn + epsilon)
acc += (tp + tn) / (tp + fp + tn + fn + epsilon)
rc /= len(y_trueb[0])
pr /= len(y_trueb[0])
acc /= len(y_trueb[0])
return rc, pr, acc
\ No newline at end of file
@@ -38,6 +38,7 @@ import scipy
import sidekit
import soundfile
import torch
import yaml
from ..diar import Diar
from pathlib import Path
@@ -93,6 +94,7 @@ def load_wav_segment(wav_file_name, idx, duration, seg_shift, framerate=16000):
def mdtm_to_label(mdtm_filename,
mode,
start_time,
stop_time,
sample_number,
@@ -120,7 +122,7 @@ def mdtm_to_label(mdtm_filename,
diarization.segments[ii]['start'] = diarization.segments[ii - 1]['stop']
# Create the empty labels
label = numpy.zeros(sample_number, dtype=int)
label = list(numpy.zeros(sample_number, dtype=int))
# Compute the time stamp of each sample
time_stamps = numpy.zeros(sample_number, dtype=numpy.float32)
@@ -128,28 +130,11 @@ def mdtm_to_label(mdtm_filename,
for t in range(sample_number):
time_stamps[t] = start_time + (2 * t + 1) * period / 2
# Find the label of the
# first sample
seg_idx = 0
while diarization.segments[seg_idx]['stop'] / 100. < start_time:
seg_idx += 1
for ii, t in enumerate(time_stamps):
# Si on est pas encore dans le premier segment qui overlape (on est donc dans du non-speech)
if t <= diarization.segments[seg_idx]['start']/100.:
# On laisse le label 0 (non-speech)
pass
# Si on est déjà dans le premier segment qui overlape
elif diarization.segments[seg_idx]['start']/100. < t < diarization.segments[seg_idx]['stop']/100. :
label[ii] = speaker_dict[diarization.segments[seg_idx]['cluster']]
# Si on change de segment
elif diarization.segments[seg_idx]['stop']/100. < t and len(diarization.segments) > seg_idx + 1:
seg_idx += 1
# On est entre deux segments:
if t < diarization.segments[seg_idx]['start']/100.:
pass
elif diarization.segments[seg_idx]['start']/100. < t < diarization.segments[seg_idx]['stop']/100.:
label[ii] = speaker_dict[diarization.segments[seg_idx]['cluster']]
for idx, time in enumerate(time_stamps):
# Collect every speaker whose segment overlaps this time stamp (an empty list means non-speech)
lbls = []
for seg in diarization.segments:
if seg['start'] / 100. <= time <= seg['stop'] / 100.:
lbls.append(speaker_dict[seg['cluster']])
label[idx] = lbls
return label
@@ -193,7 +178,7 @@ def get_segment_label(label,
output_label = numpy.convolve(conv_filt, spk_change, mode='same')
elif mode == "overlap":
raise NotImplementedError()
output_label = (label > 0.5).astype(numpy.long)
else:
raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")
@@ -225,16 +210,29 @@ def process_segment_label(label,
:param filter_type:
:return:
"""
# Create labels with Diracs at every speaker change detection
spk_change = numpy.zeros(label.shape, dtype=int)
spk_change[:-1] = label[:-1] ^ label[1:]
spk_change = numpy.not_equal(spk_change, numpy.zeros(label.shape, dtype=int))
# depending of the mode, generates the labels and select the segments
if mode == "vad":
output_label = (label > 0.5).astype(numpy.long)
output_label = numpy.array([len(a) > 0 for a in label]).astype(numpy.long)
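# Each entry of label is expected to be the list of speakers active in that frame
# (see mdtm_to_label), so voice activity simply means "at least one active speaker".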
elif mode == "spk_turn":
tmp_label = []
for a in label:
if len(a) == 0:
tmp_label.append(0)
elif len(a) == 1:
tmp_label.append(a[0])
else:
tmp_label.append(sum(a) * 1000)
# Use the collapsed per-frame labels computed above (0 = non-speech, speaker index, or an overlap marker)
label = numpy.array(tmp_label)
# Create labels with Diracs at every speaker change detection
spk_change = numpy.zeros(label.shape, dtype=int)
spk_change[:-1] = label[:-1] ^ label[1:]
spk_change = numpy.not_equal(spk_change, numpy.zeros(label.shape, dtype=int))
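# Illustrative example: label [0, 0, 3, 3, 5, 5] XOR-ed with its shifted version gives
# [0, 3, 0, 6, 0, 0], i.e. non-zero exactly at the frames where the speaker label changes.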
# Apply convolution to replace diracs by a chosen shape (gate or triangle)
filter_sample = int(collar_duration * framerate * 2 + 1)
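# For instance, with collar_duration=0.1 and a label framerate of 100 frames per second,
# filter_sample = int(0.1 * 100 * 2 + 1) = 21, i.e. a +/- 0.1 s collar around each change point.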
@@ -244,7 +242,11 @@ def process_segment_label(label,
output_label = numpy.convolve(conv_filt, spk_change, mode='same')
elif mode == "overlap":
raise NotImplementedError()
label = numpy.array([len(a) for a in label]).astype(numpy.long)
# For the moment, we just consider two classes: overlap / no-overlap
# in the future we might want to classify according to the number of speaker speaking at the same time
output_label = (label > 1).astype(numpy.long)
else:
raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")
@@ -303,10 +305,11 @@ class SeqSet(Dataset):
Object creates a dataset for sequence to sequence training
"""
def __init__(self,
dataset_yaml,
wav_dir,
mdtm_dir,
mode,
segment_list=None,
speaker_dict=None,
duration=2.,
filter_type="gate",
collar_duration=0.1,
@@ -356,8 +359,10 @@ class SeqSet(Dataset):
_transform.append(TemporalMask(a))
self.transforms = transforms.Compose(_transform)
segment_list, speaker_dict = seqSplit(mdtm_dir=self.mdtm_dir,
duration=self.duration)
if segment_list is None and speaker_dict is None:
segment_list, speaker_dict = seqSplit(mdtm_dir=self.mdtm_dir,
duration=self.duration)
self.segment_list = segment_list
self.speaker_dict = speaker_dict
self.len = len(segment_list)
@@ -388,7 +393,8 @@ class SeqSet(Dataset):
start_time=start,
stop_time=start + self.duration,
sample_number=sig.shape[1],
speaker_dict=self.speaker_dict)
speaker_dict=self.speaker_dict,
overlap=self.mode=="overlap")
label = process_segment_label(label=tmp_label,
mode=self.mode,
@@ -400,3 +406,68 @@ class SeqSet(Dataset):
def __len__(self):
return self.len
def create_train_val_seqtoseq(dataset_yaml):
"""
:param self:
:param wav_dir:
:param mdtm_dir:
:param mode:
:param segment_list
:param speaker_dict:
:param duration:
:param filter_type:
:param collar_duration:
:param audio_framerate:
:param output_framerate:
:param transform_pipeline:
:return:
"""
with open(dataset_yaml, "r") as fh:
dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
torch.manual_seed(dataset_params['seed'])
# Read all MDTM files and ouptut a list of segments with minimum duration as well as a speaker dictionary
segment_list, speaker_dict = seqSplit(mdtm_dir=dataset_params["mdtm_dir"],
duration=dataset_params["duration"])
split_idx = numpy.random.choice