Commit 7df4d68e authored by Martin Lebourdais
parents 4de1fd1a b0835fe2
@@ -60,4 +60,4 @@ __maintainer__ = "Sylvain Meignier"
__email__ = "sylvain.meignierr@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'
__version__ = "0.1.4.7"
__version__ = "0.1.4.8"
@@ -24,5 +24,7 @@ Copyright 2014-2020 Anthony Larcher
"""
from .wavsets import SeqSet
from .wavsets import create_train_val_seqtoseq
from .seqtoseq import BLSTM
from .seqtoseq import SeqToSeq
\ No newline at end of file
from .seqtoseq import SeqToSeq
from .seqtoseq import seqTrain
\ No newline at end of file
@@ -23,29 +23,22 @@
Copyright 2014-2020 Anthony Larcher
"""
import os
import sys
import logging
import pandas
import numpy
from collections import OrderedDict
import random
import h5py
import shutil
import torch
import torch.nn as nn
import yaml
from sklearn.model_selection import train_test_split
from torch import optim
from torch.utils.data import Dataset
from .loss import ConcordanceCorCoeff
from .wavsets import SeqSet
from collections import OrderedDict
from sidekit.nnet.sincnet import SincNet
from torch.utils.data import DataLoader
from .wavsets import SeqSet
from .wavsets import create_train_val_seqtoseq
__license__ = "LGPL"
__author__ = "Anthony Larcher"
__author__ = "Anthony Larcher, Martin Lebourdais, Meysam Shamsi"
__copyright__ = "Copyright 2015-2020 Anthony Larcher"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
@@ -68,7 +61,21 @@ def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename
shutil.copyfile(filename, best_filename)
class BLSTM(nn.Module):
def init_weights(m):
"""
:return:
"""
if type(m) == torch.nn.Linear:
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
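# A minimal usage sketch for the helper above (illustrative, not part of this
# commit): torch.nn.Module.apply() visits every sub-module, so all Linear
# layers of a model can be Xavier-initialised in a single call, e.g.
#
#   model = torch.nn.Sequential(torch.nn.Linear(64, 32),
#                               torch.nn.ReLU(),
#                               torch.nn.Linear(32, 2))
#   model.apply(init_weights)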
class BLSTM(torch.nn.Module):
"""
Bi LSTM model used for voice activity detection, speaker turn detection, overlap detection and resegmentation
"""
def __init__(self,
input_size,
blstm_sizes):
@@ -80,20 +87,13 @@ class BLSTM(nn.Module):
super(BLSTM, self).__init__()
self.input_size = input_size
self.blstm_sizes = blstm_sizes
#self.blstm_layers = []
# for blstm_size in blstm_sizes:
# print(f"Input size {input_size},Output_size {self.output_size}")
# self.blstm_layers.append(nn.LSTM(input_size, blstm_size, bidirectional=False, batch_first=True))
# input_size = blstm_size
self.output_size = blstm_sizes[0] * 2
# self.blstm_layers = torch.nn.ModuleList(self.blstm_layers)
self.output_size = blstm_sizes * 2
self.blstm_layers = nn.LSTM(input_size,blstm_sizes[0],bidirectional=True,batch_first=True,num_layers=2)
self.hidden = None
"""
Bi LSTM model used for voice activity detection or speaker turn detection
"""
self.blstm_layers = torch.nn.LSTM(input_size,
blstm_sizes,
bidirectional=True,
batch_first=True,
num_layers=2)
def forward(self, inputs):
"""
@@ -101,35 +101,18 @@ class BLSTM(nn.Module):
:param inputs:
:return:
"""
#for idx, _s in enumerate(self.blstm_sizes):
# self.blstm_layers[idx].flatten_parameters()
hiddens = []
if self.hidden is None:
#hidden_1, hidden_2 = None, None
for _s in self.blstm_sizes:
hiddens.append(None)
else:
hiddens = list(self.hidden)
x = inputs
outputs = []
# for idx, _s in enumerate(self.blstm_sizes):
# # self.blstm_layers[idx].flatten_parameters()
# print("IN",x.shape)
# x, hiddens[idx] = self.blstm_layers[idx](x, hiddens[idx])
# print("OUT",x.shape)
# outputs.append(x)
# self.hidden = tuple(hiddens)
# output = torch.cat(outputs, dim=2)
output,h = self.blstm_layers(x)
output, h = self.blstm_layers(inputs)
return output
def output_size(self):
"""
:return:
"""
return self.output_size
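# A hedged usage sketch of the BLSTM module above (illustrative sizes, not part
# of this commit). With batch_first=True the expected input is
# (batch, frames, features) and the bidirectional output carries
# 2 * blstm_sizes features per frame:
#
#   blstm = BLSTM(input_size=30, blstm_sizes=64)   # e.g. 30 cepstral features
#   frames = torch.randn(8, 200, 30)               # batch of 8 sequences, 200 frames each
#   out = blstm(frames)                            # -> torch.Size([8, 200, 128])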
class SeqToSeq(nn.Module):
class SeqToSeq(torch.nn.Module):
"""
Model used for voice activity detection or speaker turn detection
This model can include a pre-processor to input raw waveform,
@@ -205,8 +188,7 @@ class SeqToSeq(nn.Module):
post_processing_layers.append((k, torch.nn.Dropout(p=cfg["post_processing"][k])))
self.post_processing = torch.nn.Sequential(OrderedDict(post_processing_layers))
#self.before_speaker_embedding_weight_decay = cfg["post_processing"]["weight_decay"]
self.post_processing.apply(init_weights)
def forward(self, inputs):
"""
@@ -226,7 +208,6 @@ class SeqToSeq(nn.Module):
def seqTrain(dataset_yaml,
val_dataset_yaml,
model_yaml,
mode,
epochs=100,
lr=0.0001,
patience=10,
@@ -235,11 +216,6 @@ def seqTrain(dataset_yaml,
best_model_name=None,
multi_gpu=True,
opt='sgd',
filter_type="gate",
collar_duration=0.1,
framerate=16000,
output_rate=100,
batch_size=32,
log_interval=10,
num_thread=10,
non_overlap_dataset = None,
@@ -281,29 +257,17 @@ def seqTrain(dataset_yaml,
model = torch.nn.DataParallel(model)
else:
print("Train on a single GPU")
model.to(device)
with open(dataset_yaml, "r") as fh:
dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
"""
Create two dataloaders for training and evaluation
"""
with open(dataset_yaml, "r") as fh:
dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
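# Assumed shape of the dataset YAML loaded above (illustrative values only; the
# real schema comes from the s4d configuration files, not from this commit).
# The keys shown are the ones read in this function:
#
#   dataset_description: /path/to/dataset.csv   # one row per recording
#   validation_ratio: 0.1
#   wav_dir: /path/to/wav/
#   mdtm_dir: /path/to/mdtm/
#   seed: 1234
#   batch_size: 32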
df = pandas.read_csv(dataset_params["dataset_description"])
training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
_wav_dir=dataset_params['wav_dir']
_mdtm_dir=dataset_params['mdtm_dir']
torch.manual_seed(dataset_params['seed'])
training_set = SeqSet(dataset_yaml,
wav_dir=_wav_dir,
mdtm_dir=_mdtm_dir,
mode=mode,
duration=2.,
filter_type="gate",
collar_duration=0.1,
audio_framerate=framerate,
output_framerate=output_rate,
transform_pipeline="MFCC")
training_set, validation_set = create_train_val_seqtoseq(dataset_yaml)
training_loader = DataLoader(training_set,
batch_size=dataset_params["batch_size"],
drop_last=True,
@@ -311,23 +275,9 @@ def seqTrain(dataset_yaml,
pin_memory=True,
num_workers=num_thread)
validation_set = SeqSet(val_dataset_yaml,
wav_dir=_wav_dir,
mdtm_dir=_mdtm_dir,
mode=mode,
duration=2.,
filter_type="gate",
collar_duration=0.1,
audio_framerate=framerate,
output_framerate=output_rate,
set_type= "validation",
transform_pipeline="MFCC")
validation_loader = DataLoader(validation_set,
batch_size=dataset_params["batch_size"],
drop_last=True,
shuffle=True,
pin_memory=True,
num_workers=num_thread)
@@ -359,24 +309,13 @@ def seqTrain(dataset_yaml,
]
optimizer = _optimizer([{'params': model.parameters()},], **_options)
#if type(model) is SeqToSeq:
# optimizer = _optimizer([
# {'params': model.parameters(),
# 'weight_decay': model.weight_decay},],
# **_options
# )
#else:
# optimizer = _optimizer([
# {'params': model.module.sequence_network.parameters(),
# #'weight_decay': model.module.sequence_network_weight_decay},],
# **_options
# )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)
best_fmes = 0.0
best_fmes_epoch = 1
curr_patience = patience
for epoch in range(1, epochs + 1):
# Process one epoch and return the current model
if curr_patience == 0:
@@ -390,23 +329,23 @@ def seqTrain(dataset_yaml,
device=device)
# Cross validation here
fmes, val_loss = cross_validation(model, validation_loader, device=device)
logging.critical("*** Validation f-Measure = {}".format(fmes))
accuracy, val_loss = cross_validation(model, validation_loader, device=device)
logging.critical("*** Cross validation accuracy = {} %".format(accuracy))
# Decrease learning rate according to the scheduler policy
scheduler.step(val_loss)
print(f"Learning rate is {optimizer.param_groups[0]['lr']}")
## remember best accuracy and save checkpoint
is_best = fmes > best_fmes
best_fmes = max(fmes, best_fmes)
# remember best accuracy and save checkpoint
is_best = accuracy > best_accuracy
best_accuracy = max(accuracy, best_accuracy)
if type(model) is SeqToSeq:
save_checkpoint({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': best_fmes,
'accuracy': best_accuracy,
'scheduler': scheduler
}, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
else:
@@ -414,46 +353,19 @@ def seqTrain(dataset_yaml,
'epoch': epoch,
'model_state_dict': model.module.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': best_fmes,
'accuracy': best_accuracy,
'scheduler': scheduler
}, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
if is_best:
best_fmes_epoch = epoch
best_accuracy_epoch = epoch
curr_patience = patience
else:
curr_patience -= 1
logging.critical(f"Best F-Mesure {best_fmes * 100.} obtained at epoch {best_fmes_epoch}")
def calc_recall(output,target,device):
y_trueb = target.to(device)
y_predb = output
rc = 0.0
pr = 0.0
batch_size = y_trueb.shape[1]
for b in range(batch_size):
y_true = y_trueb[:,b]
y_pred = y_predb[:,:,b]
assert y_true.ndim == 1
assert y_pred.ndim == 1 or y_pred.ndim == 2
if y_pred.ndim == 2:
y_pred = y_pred.argmax(dim=1)
tp = (y_true * y_pred).sum().to(torch.float32)
tn = ((1 - y_true) * (1 - y_pred)).sum().to(torch.float32)
fp = ((1 - y_true) * y_pred).sum().to(torch.float32)
fn = (y_true * (1 - y_pred)).sum().to(torch.float32)
logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")
epsilon = 1e-7
precision = tp / (tp + fp + epsilon)
recall = tp / (tp + fn + epsilon)
rc+=recall
pr+=precision
return rc,pr
def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
"""
@@ -468,12 +380,14 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
"""
model.to(device)
model.train()
criterion = torch.nn.CrossEntropyLoss(reduction='mean')
#criterion = ccc_loss
criterion = torch.nn.CrossEntropyLoss(reduction='mean', weight=torch.FloatTensor([0.1,0.9]).to(device))
recall = 0.0
precision = 0.0
accuracy = 0.0
for batch_idx, (data, target) in enumerate(training_loader):
target = target.squeeze()
# tnumpy = target.numpy()
# print(tnumpy.shape)
@@ -491,18 +405,36 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
loss.backward(retain_graph=True)
optimizer.step()
rc, pr, acc = calc_recall(output.data, target, device)
recall += rc.item()
precision += pr.item()
accuracy += acc.item()
rc,pr = calc_recall(output.data,target,device)
accuracy += pr
recall += rc
if batch_idx % log_interval == 0:
batch_size = target.shape[0]
# print(100.0 * accuracy.item() / ((batch_idx + 1) * batch_size * 198))
logging.critical('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}\tRecall: {:.3f}'.format(
epoch, batch_idx + 1, training_loader.__len__(),
100. * batch_idx / training_loader.__len__(), loss.item(),
100.0 * accuracy.item() / ((batch_idx + 1) * batch_size * 198),
100.0 * recall.item() / ((batch_idx+1) *batch_size * 198)
))
if precision!=0 or recall!=0:
f_measure = 2 * (precision / ((batch_idx + 1))) * (recall / ((batch_idx+1))) /\
((precision / ((batch_idx + 1) ))+(recall / ((batch_idx + 1))))
logging.critical(
'Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.3f} '
'Recall: {:.3f} Precision: {:.3f} '
'F-Measure: {:.3f}'.format(epoch,
batch_idx + 1,
training_loader.__len__(),
100. * batch_idx / training_loader.__len__(), loss.item(),
100.0 * accuracy / ((batch_idx + 1)),
100.0 * recall/ ((batch_idx+1)),
100.0 * precision / ((batch_idx+1)),
f_measure)
)
else:
print(f"precision = {precision} and recall = {recall}")
return model
return model
def pearsonr(x, y):
@@ -560,7 +492,7 @@ def llincc(x, y):
return ccc:
'''
def cross_validation(model, validation_loader, device):
"""
@@ -570,9 +502,12 @@ def cross_validation(model, validation_loader, device):
:return:
"""
model.eval()
recall = 0.0
precision = 0.0
accuracy = 0.0
loss = 0.0
criterion = torch.nn.CrossEntropyLoss()
with torch.no_grad():
for batch_idx, (data, target) in enumerate(validation_loader):
@@ -581,13 +516,66 @@ def cross_validation(model, validation_loader, device):
output = model(data.to(device))
output = output.permute(1, 2, 0)
target = target.permute(1, 0)
nbpoint = output.shape[0]
rc,pr = calc_recall(output.data,target,device)
accuracy+= pr
recall += rc
loss += criterion(output, target.to(device))
fmes = 2*(accuracy*recall)/(recall+accuracy)
return fmes / ((batch_idx + 1) * batch_size), \
loss.cpu().numpy() / ((batch_idx + 1) * batch_size)
rc, pr, acc = calc_recall(output.data, target, device)
recall += rc.item()
precision += pr.item()
accuracy += acc.item()
batch_size = target.shape[0]
if precision != 0 or recall != 0:
f_measure = 2 * (precision / ((batch_idx + 1))) * (recall / ((batch_idx + 1))) / \
((precision / ((batch_idx + 1))) + (recall / ((batch_idx + 1))))
logging.critical(
'Validation: [{}/{} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.3f} '
'Recall: {:.3f} Precision: {:.3f} '
'F-Measure: {:.3f}'.format(batch_idx + 1,
validation_loader.__len__(),
100. * batch_idx / validation_loader.__len__(), loss.item(),
100.0 * accuracy / ((batch_idx + 1)),
100.0 * recall / ((batch_idx + 1)),
100.0 * precision / ((batch_idx + 1)),
f_measure)
)
return accuracy, loss
def calc_recall(output,target,device):
"""
:param output:
:param target:
:param device:
:return:
"""
y_trueb = target.to(device)
y_predb = output
rc = 0.0
pr = 0.0
acc= 0.0
for b in range(y_trueb.shape[-1]):
y_true = y_trueb[:,b]
y_pred = y_predb[:,:,b]
assert y_true.ndim == 1
assert y_pred.ndim == 1 or y_pred.ndim == 2
if y_pred.ndim == 2:
y_pred = y_pred.argmax(dim=1)
tp = (y_true * y_pred).sum().to(torch.float32)
tn = ((1 - y_true) * (1 - y_pred)).sum().to(torch.float32)
fp = ((1 - y_true) * y_pred).sum().to(torch.float32)
fn = (y_true * (1 - y_pred)).sum().to(torch.float32)
epsilon = 1e-7
pr+= tp / (tp + fp + epsilon)
rc+= tp / (tp + fn + epsilon)
acc += (tp + tn) / (tp + fp + tn + fn + epsilon)
rc/=len(y_trueb[0])
pr/=len(y_trueb[0])
acc/=len(y_trueb[0])
return rc,pr,acc
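# A hedged sanity check for calc_recall (toy tensors, not part of this commit).
# As used in cross_validation, predictions are expected with shape
# (frames, classes, batch) and targets with shape (frames, batch); the returned
# precision / recall / accuracy are averaged over the batch dimension:
#
#   y_true = torch.tensor([[1], [1], [0], [0]])      # 4 frames, batch of 1
#   y_pred = torch.tensor([[[0.1], [0.9]],           # frame 0 -> class 1 (tp)
#                          [[0.8], [0.2]],           # frame 1 -> class 0 (fn)
#                          [[0.7], [0.3]],           # frame 2 -> class 0 (tn)
#                          [[0.2], [0.8]]])          # frame 3 -> class 1 (fp)
#   rc, pr, acc = calc_recall(y_pred, y_true, device="cpu")
#   # tp = fp = fn = tn = 1, so recall, precision and accuracy are all ~0.5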
@@ -38,6 +38,7 @@ import scipy
import sidekit
import soundfile
import torch
import yaml
from ..diar import Diar
from pathlib import Path
@@ -123,7 +124,7 @@ def mdtm_to_label(mdtm_filename,
diarization.segments[ii]['start'] = diarization.segments[ii - 1]['stop']
# Create the empty labels
label = numpy.zeros(sample_number, dtype=int)
label = []
# Compute the time stamp of each sample
time_stamps = numpy.zeros(sample_number, dtype=numpy.float32)
@@ -131,38 +132,16 @@ def mdtm_to_label(mdtm_filename,
for t in range(sample_number):
time_stamps[t] = start_time + (2 * t + 1) * period / 2
for idx, time in enumerate(time_stamps):
lbls = []
for seg in diarization.segments:
if seg['start'] / 100. <= time <= seg['stop'] / 100.:
lbls.append(speaker_dict[seg['cluster']])
for ii,i in enumerate(numpy.linspace(start_time,stop_time,num=sample_number)):
cnt = 0
for d in diarization.segments:
# print(d)
# print(d['start'],d['stop'])
if d['start']/100 <= i <= d['stop']/100:
cnt+=1
overlaps[ii]=cnt
# Find the label of the first sample
seg_idx = 0
while diarization.segments[seg_idx]['stop'] / 100. < start_time:
#sig, speaker_idx, _, __, _t, _s = self.transforms((sig, None, None, None, None, None))
seg_idx += 1
for ii, t in enumerate(time_stamps):
# If we are not yet in the first overlapping segment (so we are still in non-speech)
if t <= diarization.segments[seg_idx]['start']/100.:
# Keep label 0 (non-speech)
pass
# If we are already inside the first overlapping segment
elif diarization.segments[seg_idx]['start']/100. < t < diarization.segments[seg_idx]['stop']/100. :
label[ii] = speaker_dict[diarization.segments[seg_idx]['cluster']]
# If we move on to the next segment
elif diarization.segments[seg_idx]['stop']/100. < t and len(diarization.segments) > seg_idx + 1:
seg_idx += 1
# We are between two segments:
if t < diarization.segments[seg_idx]['start']/100.:
pass
elif diarization.segments[seg_idx]['start']/100. < t < diarization.segments[seg_idx]['stop']/100.:
label[ii] = speaker_dict[diarization.segments[seg_idx]['cluster']]
if len(lbls) > 0:
label.append(lbls)
else:
label.append([])
return (label,overlaps)
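# Illustrative example of the per-sample label structure returned above (toy
# values, not taken from the commit). Each entry lists every speaker index
# active at that sample, so overlapped speech appears as a list with more than
# one element and non-speech as an empty list:
#
#   label = [[], [], [2], [2], [2, 3], [3], []]
#   #        non-speech, speaker 2, overlap of speakers 2 and 3, speaker 3, non-speech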
@@ -207,7 +186,7 @@ def get_segment_label(label,
output_label = numpy.convolve(conv_filt, spk_change, mode='same')
elif mode == "overlap":
output_label = (overlaps > 1).astype(numpy.long)
output_label = (label > 0.5).astype(numpy.long)
else:
raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")
@@ -240,16 +219,29 @@ def process_segment_label(label,
:param filter_type:
:return:
"""
# Create labels with Diracs at every speaker change detection
spk_change = numpy.zeros(label.shape, dtype=int)
spk_change[:-1] = label[:-1] ^ label[1:]
spk_change = numpy.not_equal(spk_change, numpy.zeros(label.shape, dtype=int))
# depending of the mode, generates the labels and select the segments
if mode == "vad":
output_label = (label > 0.5).astype(numpy.long)
output_label = numpy.array([len(a) > 0 for a in label]).astype(numpy.long)
elif mode == "spk_turn":
tmp_label = []
for a in label:
if len(a) == 0:
tmp_label.append(0)
elif len(a) == 1:
tmp_label.append(a[0])
else:
tmp_label.append(sum(a) * 1000)
label = numpy.array(label)
# Create labels with Diracs at every speaker change detection
spk_change = numpy.zeros(label.shape, dtype=int)
spk_change[:-1] = label[:-1] ^ label[1:]
spk_change = numpy.not_equal(spk_change, numpy.zeros(label.shape, dtype=int))
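# Toy illustration of the encoding above (assumed values, not from the commit):
# XOR between consecutive frame labels yields a non-zero value exactly where the
# speaker index changes, and the convolution below widens each resulting Dirac
# into a 'gate' or 'triangle' collar around the change point:
#
#   label      = numpy.array([2, 2, 2, 3, 3, 0, 0])
#   spk_change = [0, 0, 1, 0, 1, 0, 0]   # Diracs at the 2->3 and 3->0 boundaries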
# Apply convolution to replace diracs by a chosen shape (gate or triangle)
filter_sample = int(collar_duration * framerate * 2 + 1)
@@ -259,7 +251,11 @@ def process_segment_label(label,
output_label = numpy.convolve(conv_filt, spk_change, mode='same')
elif mode == "overlap":
output_label = (overlaps>1).astype(numpy.long)
label = numpy.array([len(a) for a in label]).astype(numpy.long)
# For the moment, we just consider two classes: overlap / no-overlap
# in the future we might want to classify according to the number of speaker speaking at the same time
output_label = (label > 1).astype(numpy.long)
else: