Commit c809d6b7 authored by Anthony Larcher

VAD

parent 836858d6
@@ -23,28 +23,20 @@
Copyright 2014-2020 Anthony Larcher
"""

import logging
import shutil
import torch
import yaml
from collections import OrderedDict

from sidekit.nnet.sincnet import SincNet
from torch.utils.data import DataLoader
from .wavsets import SeqSet
from .wavsets import create_train_val_seqtoseq

__license__ = "LGPL" __license__ = "LGPL"
__author__ = "Anthony Larcher" __author__ = "Anthony Larcher, Martin Lebourdais, Meysam Shamsi"
__copyright__ = "Copyright 2015-2020 Anthony Larcher" __copyright__ = "Copyright 2015-2020 Anthony Larcher"
__maintainer__ = "Anthony Larcher" __maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr" __email__ = "anthony.larcher@univ-lemans.fr"
...@@ -52,21 +44,33 @@ __status__ = "Production" ...@@ -52,21 +44,33 @@ __status__ = "Production"
__docformat__ = 'reS' __docformat__ = 'reS'

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    """
    Save the training state and, when it is the best seen so far, copy it to a separate file.

    :param state: dict gathering the objects to serialize (model, optimizer, epoch, accuracy, ...)
    :param is_best: boolean, True when the current model is the best seen so far
    :param filename: name of the checkpoint file
    :param best_filename: name of the copy kept for the best model
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)

def init_weights(m):
    """
    Initialize a Linear layer with Xavier-uniform weights and a small constant bias.
    Intended to be used through torch.nn.Module.apply().

    :param m: the torch.nn.Module to initialize; only torch.nn.Linear instances are modified
    """
    if type(m) == torch.nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
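
For context, init_weights is written to be passed to torch.nn.Module.apply(), which visits every submodule. A minimal sketch, assuming illustrative layer sizes that are not from this commit:

import torch

# Hypothetical stack; only the Linear layers are re-initialized by init_weights
post_proc = torch.nn.Sequential(
    torch.nn.Linear(128, 64),
    torch.nn.Tanh(),
    torch.nn.Linear(64, 2),
)
post_proc.apply(init_weights)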

class BLSTM(torch.nn.Module):
    """
    Bi-LSTM model used for voice activity detection, speaker turn detection,
    overlap detection and resegmentation
    """
    def __init__(self,
                 input_size,
                 blstm_sizes):
@@ -78,17 +82,13 @@ class BLSTM(nn.Module):
        super(BLSTM, self).__init__()
        self.input_size = input_size
        self.blstm_sizes = blstm_sizes
        # Bidirectional output: twice the hidden size of one direction
        self.output_size = blstm_sizes[0] * 2

        # torch.nn.LSTM expects an integer hidden size per direction,
        # hence blstm_sizes[0] rather than the blstm_sizes list itself
        self.blstm_layers = torch.nn.LSTM(input_size,
                                          blstm_sizes[0],
                                          bidirectional=True,
                                          batch_first=True,
                                          num_layers=2)

    def forward(self, inputs):
        """
@@ -96,32 +96,18 @@ class BLSTM(nn.Module):
        :param inputs: batch of sequences, tensor of shape (batch, time, input_size)
        :return: BLSTM output, tensor of shape (batch, time, 2 * hidden size)
        """
        output, h = self.blstm_layers(inputs)
        return output

    def output_size(self):
        """
        :return: the dimension of the features produced by the BLSTM
        """
        # Note: the instance attribute of the same name, set in __init__, shadows this method
        return self.output_size
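
A quick shape check for the class above; the sizes are illustrative and assume the constructor as written here (hidden size blstm_sizes[0], bidirectional):

import torch

blstm = BLSTM(input_size=30, blstm_sizes=[64])
x = torch.randn(8, 200, 30)   # (batch, time, features)
y = blstm(x)
print(y.shape)                # torch.Size([8, 200, 128]), i.e. 2 directions x 64 units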

class SeqToSeq(torch.nn.Module):
    """
    Model used for voice activity detection or speaker turn detection
    This model can include a pre-processor to input raw waveform,
@@ -198,8 +184,7 @@ class SeqToSeq(nn.Module):
                post_processing_layers.append((k, torch.nn.Dropout(p=cfg["post_processing"][k])))

        self.post_processing = torch.nn.Sequential(OrderedDict(post_processing_layers))
        self.post_processing.apply(init_weights)
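
The post-processing head is built from an ordered config mapping. A hedged sketch of the consuming pattern; the cfg fragment and key names are assumptions, and the real loop, largely elided by this hunk, also handles non-dropout layer types:

from collections import OrderedDict
import torch

cfg = {"post_processing": {"dropout1": 0.25, "dropout2": 0.1}}   # hypothetical fragment

post_processing_layers = []
for k in cfg["post_processing"]:
    if k.startswith("dropout"):
        post_processing_layers.append((k, torch.nn.Dropout(p=cfg["post_processing"][k])))

post_processing = torch.nn.Sequential(OrderedDict(post_processing_layers))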

    def forward(self, inputs):
        """
@@ -271,27 +256,16 @@ def seqTrain(dataset_yaml,
        model = torch.nn.DataParallel(model)
    else:
        print("Train on a single GPU")
    model.to(device)

    with open(dataset_yaml, "r") as fh:
        dataset_params = yaml.load(fh, Loader=yaml.FullLoader)

    """
    Create two dataloaders for training and evaluation
    """
    training_set, validation_set = create_train_val_seqtoseq(dataset_yaml)

    training_loader = DataLoader(training_set,
                                 batch_size=dataset_params["batch_size"],
@@ -300,15 +274,11 @@ def seqTrain(dataset_yaml,
                                 pin_memory=True,
                                 num_workers=num_thread)

    validation_loader = DataLoader(validation_set,
                                   batch_size=dataset_params["batch_size"],
                                   drop_last=True,
                                   pin_memory=True,
                                   num_workers=num_thread)

    """
    Set the training options
@@ -338,24 +308,13 @@ def seqTrain(dataset_yaml,
    ]
    optimizer = _optimizer([{'params': model.parameters()},], **_options)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)

    best_accuracy = 0.0
    best_accuracy_epoch = 1
    curr_patience = patience

    for epoch in range(1, epochs + 1):
        # Process one epoch and return the current model
        if curr_patience == 0:
@@ -369,41 +328,41 @@ def seqTrain(dataset_yaml,
                            device=device)

        # Cross validation here
        accuracy, val_loss = cross_validation(model, validation_loader, device=device)
        logging.critical("*** Cross validation accuracy = {} %".format(accuracy))

        # Decrease learning rate according to the scheduler policy
        scheduler.step(val_loss)
        print(f"Learning rate is {optimizer.param_groups[0]['lr']}")

        # remember best accuracy and save checkpoint
        is_best = accuracy > best_accuracy
        best_accuracy = max(accuracy, best_accuracy)

        if type(model) is SeqToSeq:
            save_checkpoint({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': best_accuracy,
                'scheduler': scheduler
            }, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
        else:
            save_checkpoint({
                'epoch': epoch,
                'model_state_dict': model.module.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': best_accuracy,
                'scheduler': scheduler
            }, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')

        if is_best:
            best_accuracy_epoch = epoch
            curr_patience = patience
        else:
            curr_patience -= 1

    logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")

def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
@@ -420,9 +379,12 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
""" """
model.to(device) model.to(device)
model.train() model.train()
criterion = torch.nn.CrossEntropyLoss(reduction='mean') criterion = torch.nn.CrossEntropyLoss(reduction='mean', weight=torch.FloatTensor([0.1,0.9]))
recall = 0.0
precision = 0.0
accuracy = 0.0 accuracy = 0.0
for batch_idx, (data, target) in enumerate(training_loader): for batch_idx, (data, target) in enumerate(training_loader):
target = target.squeeze() target = target.squeeze()
optimizer.zero_grad() optimizer.zero_grad()
@@ -433,14 +395,32 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
        loss = criterion(output, target.to(device))
        loss.backward(retain_graph=True)
        optimizer.step()

        rc, pr, acc = calc_recall(output.data, target, device)
        recall += rc.item()
        precision += pr.item()
        accuracy += acc.item()

        if batch_idx % log_interval == 0:
            batch_size = target.shape[0]
            if precision != 0 or recall != 0:
                # Harmonic mean of the running average precision and recall
                f_measure = 2 * (precision / (batch_idx + 1)) * (recall / (batch_idx + 1)) / \
                            ((precision / (batch_idx + 1)) + (recall / (batch_idx + 1)))
                logging.critical(
                    'Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.3f} '
                    'Recall: {:.3f} Precision: {:.3f} '
                    'F-Measure: {:.3f}'.format(epoch,
                                               batch_idx + 1,
                                               training_loader.__len__(),
                                               100. * batch_idx / training_loader.__len__(),
                                               loss.item(),
                                               100.0 * accuracy / (batch_idx + 1),
                                               100.0 * recall / (batch_idx + 1),
                                               100.0 * precision / (batch_idx + 1),
                                               f_measure)
                )
            else:
                print(f"precision = {precision} and recall = {recall}")
    return model
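
The [0.1, 0.9] class weights used above bias the loss toward the minority speech class. A small, self-contained illustration of the effect, with arbitrary values:

import torch

criterion = torch.nn.CrossEntropyLoss(weight=torch.FloatTensor([0.1, 0.9]), reduction='none')
logits = torch.tensor([[2.0, 0.0], [0.0, 2.0]])  # two equally confident, equally wrong predictions
targets = torch.tensor([1, 0])                   # first frame is speech (1), second non-speech (0)
print(criterion(logits, targets))                # the speech error costs 9x the non-speech error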
@@ -454,18 +434,77 @@ def cross_validation(model, validation_loader, device):
""" """
model.eval() model.eval()
recall = 0.0
precision = 0.0
accuracy = 0.0 accuracy = 0.0
loss = 0.0 loss = 0.0
criterion = torch.nn.CrossEntropyLoss() criterion = torch.nn.CrossEntropyLoss()
with torch.no_grad(): with torch.no_grad():
for batch_idx, (data, target) in enumerate(validation_loader): for batch_idx, (data, target) in enumerate(validation_loader):
batch_size = target.shape[0] batch_size = target.shape[0]
target = target.squeeze() target = target.squeeze()
output = model(data.to(device),target=target.to(device),is_eval=True) output = model(data.to(device))
print(output.shape) output = output.permute(1, 2, 0)
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum() target = target.permute(1, 0)
nbpoint = output.shape[0]
rc, pr, acc = calc_recall(output.data, target, device)
recall += rc.item()
precision += pr.item()
accuracy += acc.item()
batch_size = target.shape[0]
if precision != 0 or recall != 0:
f_measure = 2 * (precision / ((batch_idx + 1))) * (recall / ((batch_idx + 1))) / \
((precision / ((batch_idx + 1))) + (recall / ((batch_idx + 1))))
logging.critical(
'Validation: [{}/{} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.3f} ' \
'Recall: {:.3f} Precision: {:.3f} "\
F-Measure: {:.3f}'.format(batch_idx + 1,
validation_loader.__len__(),
100. * batch_idx / validation_loader.__len__(), loss.item(),
100.0 * accuracy / ((batch_idx + 1)),
100.0 * recall / ((batch_idx + 1)),
100.0 * precision / ((batch_idx + 1)),
f_measure)
)
loss += criterion(output, target.to(device))
return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * batch_size), \
loss.cpu().numpy() / ((batch_idx + 1) * batch_size)

def calc_recall(output, target, device):
    """
    Compute the recall, precision and accuracy, averaged over the batch dimension.

    :param output: model scores, tensor of shape (time, n_classes, batch)
    :param target: ground-truth labels, tensor of shape (time, batch)
    :param device: device on which to run the computation
    :return: recall, precision and accuracy as 0-dimensional tensors
    """
    y_trueb = target.to(device)
    y_predb = output
    rc = 0.0
    pr = 0.0
    acc = 0.0

    for b in range(y_trueb.shape[-1]):
        y_true = y_trueb[:, b]
        y_pred = y_predb[:, :, b]

        assert y_true.ndim == 1
        assert y_pred.ndim == 1 or y_pred.ndim == 2

        if y_pred.ndim == 2:
            y_pred = y_pred.argmax(dim=1)

        tp = (y_true * y_pred).sum().to(torch.float32)
        tn = ((1 - y_true) * (1 - y_pred)).sum().to(torch.float32)
        fp = ((1 - y_true) * y_pred).sum().to(torch.float32)
        fn = (y_true * (1 - y_pred)).sum().to(torch.float32)

        epsilon = 1e-7

        pr += tp / (tp + fp + epsilon)
        rc += tp / (tp + fn + epsilon)
        acc += (tp + tn) / (tp + fp + tn + fn + epsilon)

    # Average over the batch dimension
    rc /= y_trueb.shape[-1]
    pr /= y_trueb.shape[-1]
    acc /= y_trueb.shape[-1]

    return rc, pr, acc
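
A hedged usage sketch for calc_recall; per the permutes in cross_validation, scores are expected as (time, n_classes, batch) and labels as (time, batch):

import torch

scores = torch.randn(200, 2, 8)          # (time, n_classes, batch)
labels = torch.randint(0, 2, (200, 8))   # (time, batch), binary speech/non-speech labels
rc, pr, acc = calc_recall(scores, labels, device=torch.device("cpu"))
print(rc.item(), pr.item(), acc.item())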
@@ -38,6 +38,7 @@ import scipy
import sidekit
import soundfile
import torch
import yaml

from ..diar import Diar
from pathlib import Path
@@ -93,6 +94,7 @@ def load_wav_segment(wav_file_name, idx, duration, seg_shift, framerate=16000):
def mdtm_to_label(mdtm_filename,
                  mode,
                  start_time,
                  stop_time,
                  sample_number,
@@ -120,7 +122,7 @@ def mdtm_to_label(mdtm_filename,
            diarization.segments[ii]['start'] = diarization.segments[ii - 1]['stop']

    # Create the empty labels
    label = list(numpy.zeros(sample_number, dtype=int))

    # Compute the time stamp of each sample
    time_stamps = numpy.zeros(sample_number, dtype=numpy.float32)
@@ -128,28 +130,11 @@ def mdtm_to_label(mdtm_filename,
    for t in range(sample_number):
        time_stamps[t] = start_time + (2 * t + 1) * period / 2

    # For each sample, collect the labels of every segment that spans its time stamp
    for idx, time in enumerate(time_stamps):
        lbls = []
        for seg in diarization.segments:
            if seg['start'] / 100. <= time <= seg['stop'] / 100.:
                lbls.append(speaker_dict[seg['cluster']])
        # ASSUMED completion: the loop as committed never writes lbls back, so the
        # first matching speaker is kept here as a minimal way to fill the labels
        if len(lbls) > 0:
            label[idx] = lbls[0]

    return label
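
To make the per-sample lookup concrete, a toy illustration with simplified stand-ins for diarization.segments and speaker_dict; segment boundaries are in centiseconds, hence the division by 100:

segments = [{'start': 0,   'stop': 150, 'cluster': 'spk1'},
            {'start': 120, 'stop': 300, 'cluster': 'spk2'}]
speaker_dict = {'spk1': 1, 'spk2': 2}

time = 1.3   # seconds: inside both segments, i.e. overlapped speech
lbls = [speaker_dict[seg['cluster']]
        for seg in segments
        if seg['start'] / 100. <= time <= seg['stop'] / 100.]
print(lbls)  # [1, 2]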
@@ -193,7 +178,7 @@ def get_segment_label(label,
        output_label = numpy.convolve(conv_filt, spk_change, mode='same')

    elif mode == "overlap":
        output_label = (label > 0.5).astype(numpy.long)

    else:
        raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")
@@ -225,16 +210,29 @@ def process_segment_label(label,
    :param filter_type:
    :return:
    """
# Create labels with Diracs at every speaker change detection
spk_change = numpy.zeros(label.shape, dtype=int)
spk_change[:-1] = label[:-1] ^ label[1:]
spk_change = numpy.not_equal(spk_change, numpy.zeros(label.shape, dtype=int))