Commit 476a1625 authored by Martin Lebourdais's avatar Martin Lebourdais

Add handling of overlaps

parent 5502d3ea
Pipeline #620 failed with stages in 0 seconds
@@ -35,7 +35,7 @@ import shutil
import torch
import torch.nn as nn
import yaml
from sklearn.model_selection import train_test_split
from torch import optim
from torch.utils.data import Dataset
@@ -52,18 +52,19 @@ __status__ = "Production"
__docformat__ = 'reStructuredText'
# def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
# """
#
# :param state:
# :param is_best:
# :param filename:
# :param best_filename:
# :return:
# """
# torch.save(state, filename)
# if is_best:
# shutil.copyfile(filename, best_filename)
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
"""
:param state:
:param is_best:
:param filename:
:param best_filename:
:return:
"""
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, best_filename)
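# A minimal, hypothetical usage sketch of save_checkpoint, kept as a comment so
# it does not run at import time; the toy Linear model, Adam optimizer and file
# names below are placeholders, not objects defined in this module:
# _toy_model = nn.Linear(4, 2)
# _toy_optimizer = optim.Adam(_toy_model.parameters(), lr=1e-3)
# save_checkpoint({'epoch': 0,
#                  'model_state_dict': _toy_model.state_dict(),
#                  'optimizer_state_dict': _toy_optimizer.state_dict(),
#                  'accuracy': 0.0,
#                  'scheduler': None},
#                 is_best=True,
#                 filename='checkpoint.pth.tar',
#                 best_filename='model_best.pth.tar')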
class BLSTM(nn.Module):
@@ -78,13 +79,16 @@ class BLSTM(nn.Module):
super(BLSTM, self).__init__()
self.input_size = input_size
self.blstm_sizes = blstm_sizes
self.blstm_layers = []
for blstm_size in blstm_sizes:
self.blstm_layers.append(nn.LSTM(input_size, blstm_size // 2, bidirectional=True, batch_first=True))
input_size = blstm_size
self.output_size = blstm_size
#self.blstm_layers = []
# for blstm_size in blstm_sizes:
# print(f"Input size {input_size},Output_size {self.output_size}")
# self.blstm_layers.append(nn.LSTM(input_size, blstm_size, bidirectional=False, batch_first=True))
# input_size = blstm_size
self.output_size = blstm_sizes[0] * 2
# self.blstm_layers = torch.nn.ModuleList(self.blstm_layers)
self.blstm_layers = nn.LSTM(input_size,blstm_sizes[0],bidirectional=True,batch_first=True,num_layers=2)
self.hidden = None
"""
Bi LSTM model used for voice activity detection or speaker turn detection
@@ -109,13 +113,16 @@ class BLSTM(nn.Module):
x = inputs
outputs = []
for idx, _s in enumerate(self.blstm_sizes):
x, hiddens[idx] = self.blstm_layers[idx](x, hiddens[idx])
outputs.append(x)
self.hidden = tuple(hiddens)
output = torch.cat(outputs, dim=2)
return x
# for idx, _s in enumerate(self.blstm_sizes):
# # self.blstm_layers[idx].flatten_parameters()
# print("IN",x.shape)
# x, hiddens[idx] = self.blstm_layers[idx](x, hiddens[idx])
# print("OUT",x.shape)
# outputs.append(x)
# self.hidden = tuple(hiddens)
# output = torch.cat(outputs, dim=2)
output,h = self.blstm_layers(x)
return output
def output_size(self):
return self.output_size
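# Shape sanity check for the refactored layer above: a single 2-layer
# bidirectional nn.LSTM whose output feature dimension is blstm_sizes[0] * 2.
# The sizes used here (30 input features, hidden size 64, batch 8, 198 frames)
# are illustrative assumptions, not values taken from this module:
# _lstm = nn.LSTM(30, 64, bidirectional=True, batch_first=True, num_layers=2)
# _out, (_h, _c) = _lstm(torch.randn(8, 198, 30))
# assert _out.shape == (8, 198, 128)   # last dim = 2 * hidden size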
@@ -168,8 +175,7 @@ class SeqToSeq(nn.Module):
if self.feature_size is None:
self.feature_size = cfg["feature_size"]
input_size = self.feature_size
input_size = self.feature_size
self.sequence_to_sequence = BLSTM(input_size=input_size,
blstm_sizes=cfg["sequence_to_sequence"]["blstm_sizes"])
@@ -255,6 +261,7 @@ def seqTrain(dataset_yaml,
:return:
"""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
# Start from scratch
if model_name is None:
@@ -265,7 +272,6 @@ def seqTrain(dataset_yaml,
logging.critical(f"*** Load model from = {model_name}")
checkpoint = torch.load(model_name)
model = SeqToSeq(model_yaml)
if torch.cuda.device_count() > 1 and multi_gpu:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = torch.nn.DataParallel(model)
@@ -278,37 +284,48 @@ def seqTrain(dataset_yaml,
"""
with open(dataset_yaml, "r") as fh:
dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
#df = pandas.read_csv(dataset_params["dataset_description"])
#training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
df = pandas.read_csv(dataset_params["dataset_description"])
training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
_wav_dir=dataset_params['wav_dir']
_mdtm_dir=dataset_params['mdtm_dir']
torch.manual_seed(dataset_params['seed'])
training_set = SeqSet(dataset_yaml,
wav_dir="data/wav/",
mdtm_dir="data/mdtm/",
mode="vad",
wav_dir=_wav_dir,
mdtm_dir=_mdtm_dir,
mode=mode,
duration=2.,
filter_type="gate",
collar_duration=0.1,
audio_framerate=16000,
output_framerate=100,
audio_framerate=framerate,
output_framerate=output_rate,
transform_pipeline="MFCC")
training_loader = DataLoader(training_set,
batch_size=dataset_params["batch_size"],
shuffle=True,
drop_last=True,
shuffle=True,
pin_memory=True,
num_workers=num_thread)
#validation_set = SeqSet(dataset_yaml,
# set_type="validation",
# dataset_df=validation_df)
validation_set = SeqSet(dataset_yaml,
wav_dir=_wav_dir,
mdtm_dir=_mdtm_dir,
mode=mode,
duration=2.,
filter_type="gate",
collar_duration=0.1,
audio_framerate=framerate,
output_framerate=output_rate,
set_type= "validation",
transform_pipeline="MFCC")
#validation_loader = DataLoader(validation_set,
# batch_size=dataset_params["batch_size"],
# drop_last=True,
# pin_memory=True,
# num_workers=num_thread)
validation_loader = DataLoader(validation_set,
batch_size=dataset_params["batch_size"],
drop_last=True,
shuffle=True,
pin_memory=True,
num_workers=num_thread)
"""
Set the training options
@@ -369,41 +386,41 @@ def seqTrain(dataset_yaml,
device=device)
# Cross validation here
#accuracy, val_loss = cross_validation(model, validation_loader, device=device)
#logging.critical("*** Cross validation accuracy = {} %".format(accuracy))
accuracy, val_loss = cross_validation(model, validation_loader, device=device)
logging.critical("*** Cross validation accuracy = {} %".format(accuracy))
# Decrease learning rate according to the scheduler policy
#scheduler.step(val_loss)
#print(f"Learning rate is {optimizer.param_groups[0]['lr']}")
# remember best accuracy and save checkpoint
#is_best = accuracy > best_accuracy
#best_accuracy = max(accuracy, best_accuracy)
#if type(model) is SeqToSeq:
# save_checkpoint({
# 'epoch': epoch,
# 'model_state_dict': model.state_dict(),
# 'optimizer_state_dict': optimizer.state_dict(),
# 'accuracy': best_accuracy,
# 'scheduler': scheduler
# }, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
#else:
# save_checkpoint({
# 'epoch': epoch,
# 'model_state_dict': model.module.state_dict(),
# 'optimizer_state_dict': optimizer.state_dict(),
# 'accuracy': best_accuracy,
# 'scheduler': scheduler
# }, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
#if is_best:
# best_accuracy_epoch = epoch
# curr_patience = patience
#else:
# curr_patience -= 1
#logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")
scheduler.step(val_loss)
print(f"Learning rate is {optimizer.param_groups[0]['lr']}")
## remember best accuracy and save checkpoint
is_best = accuracy > best_accuracy
best_accuracy = max(accuracy, best_accuracy)
if type(model) is SeqToSeq:
save_checkpoint({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': best_accuracy,
'scheduler': scheduler
}, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
else:
save_checkpoint({
'epoch': epoch,
'model_state_dict': model.module.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': best_accuracy,
'scheduler': scheduler
}, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
if is_best:
best_accuracy_epoch = epoch
curr_patience = patience
else:
curr_patience -= 1
logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")
def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
@@ -421,15 +438,19 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
model.to(device)
model.train()
criterion = torch.nn.CrossEntropyLoss(reduction='mean')
accuracy = 0.0
for batch_idx, (data, target) in enumerate(training_loader):
target = target.squeeze()
optimizer.zero_grad()
output = model(data.to(device))
data = data.to(device)
output = model(data)
output = output.permute(1, 2, 0)
target = target.permute(1, 0)
#print(output.shape)
#print(torch.argmax(output[:,:,0],1))
#print(target[:,0])
loss = criterion(output, target.to(device))
loss.backward(retain_graph=True)
optimizer.step()
@@ -461,11 +482,14 @@ def cross_validation(model, validation_loader, device):
for batch_idx, (data, target) in enumerate(validation_loader):
batch_size = target.shape[0]
target = target.squeeze()
output = model(data.to(device),target=target.to(device),is_eval=True)
print(output.shape)
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
output = model(data.to(device))
output = output.permute(1, 2, 0)
target = target.permute(1, 0)
nbpoint = output.shape[0]
accuracy += (((torch.argmax(output.data, 1) == target.to(device)).sum()).cpu().numpy())/nbpoint
loss += criterion(output, target.to(device))
return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * batch_size), \
return 100. * accuracy / ((batch_idx + 1) * batch_size), \
loss.cpu().numpy() / ((batch_idx + 1) * batch_size)
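# Worked example of the normalisation above (illustrative numbers, assuming the
# model output is shaped (batch, frames, classes)): with batch_size = 8 and 198
# output frames, a batch where every frame is classified correctly contributes
# (198 * 8) / nbpoint = (198 * 8) / 198 = 8 to `accuracy`; after B batches the
# first returned value is 100 * (8 * B) / (B * 8) = 100 %.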
@@ -69,11 +69,11 @@ def framing(sig, win_size, win_shift=1, context=(0, 0), pad='zeros'):
shape = (int((sig.shape[0] - win_size) / win_shift) + 1, 1, _win_size, sig.shape[1])
strides = tuple(map(lambda x: x * dsize, [win_shift * sig.shape[1], 1, sig.shape[1], 1]))
return numpy.lib.stride_tricks.as_strided(sig,
shape=shape,
strides=strides).squeeze()
shape=shape,
strides=strides).squeeze()
def load_wav_segment(wav_file_name, idx, duration, seg_shift, framerate=16000):
"""
def load_wav_segment(wav_file_name, idx, duration, seg_shift, framerate=16000):
"""
:param wav_file_name:
:param idx:
@@ -85,18 +85,18 @@ def load_wav_segment(wav_file_name, idx, duration, seg_shift, framerate=16000):
# Load waveform
signal = sidekit.frontend.io.read_audio(wav_file_name, framerate)[0]
tmp = framing(signal,
int(framerate * duration),
win_shift=int(framerate * seg_shift),
context=(0, 0),
pad='zeros')
int(framerate * duration),
win_shift=int(framerate * seg_shift),
context=(0, 0),
pad='zeros')
return tmp[idx], len(signal)
def mdtm_to_label(mdtm_filename,
start_time,
stop_time,
sample_number,
speaker_dict):
start_time,
stop_time,
sample_number,
speaker_dict):
"""
:param mdtm_filename:
@@ -108,7 +108,8 @@ def mdtm_to_label(mdtm_filename,
"""
diarization = Diar.read_mdtm(mdtm_filename)
diarization.sort(['show', 'start'])
overlaps = numpy.zeros(sample_number, dtype=int)
# When one segment starts just the frame after the previous one ends,
# we replace its start time with the previous stop time to avoid artificial holes
previous_stop = 0
@@ -127,13 +128,21 @@ def mdtm_to_label(mdtm_filename,
period = (stop_time - start_time) / sample_number
for t in range(sample_number):
time_stamps[t] = start_time + (2 * t + 1) * period / 2
for ii,i in enumerate(numpy.linspace(start_time,stop_time,num=sample_number)):
cnt = 0
for d in diarization.segments:
if d['start']/100 <= i <= d['stop']/100:
cnt+=1
overlaps[ii]=cnt
# Find the label of the first sample
seg_idx = 0
while diarization.segments[seg_idx]['stop'] / 100. < start_time:
#sig, speaker_idx, _, __, _t, _s = self.transforms((sig, None, None, None, None, None))
seg_idx += 1
for ii, t in enumerate(time_stamps):
# If we are not yet inside the first segment that overlaps the window (so we are still in non-speech)
if t <= diarization.segments[seg_idx]['start']/100.:
@@ -151,17 +160,18 @@ def mdtm_to_label(mdtm_filename,
elif diarization.segments[seg_idx]['start']/100. < t < diarization.segments[seg_idx]['stop']/100.:
label[ii] = speaker_dict[diarization.segments[seg_idx]['cluster']]
return label
return (label,overlaps)
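# Toy illustration of the overlap counter returned above, written with plain
# dicts instead of a Diar object (the two segments and the 10 sample points are
# made up): segment A covers 0-200 cs, segment B covers 150-350 cs, so samples
# falling between 1.5 s and 2.0 s see both speakers and get a count of 2.
# _segs = [{'start': 0, 'stop': 200}, {'start': 150, 'stop': 350}]
# _counts = numpy.zeros(10, dtype=int)
# for _ii, _t in enumerate(numpy.linspace(0., 4., num=10)):
#     _counts[_ii] = sum(1 for _d in _segs if _d['start'] / 100 <= _t <= _d['stop'] / 100)
# _counts        -> array([1, 1, 1, 1, 2, 1, 1, 1, 0, 0])
# (_counts > 1)  -> overlapped-speech mask, as used by process_segment_label below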
def get_segment_label(label,
seg_idx,
mode,
duration,
framerate,
seg_shift,
collar_duration,
filter_type="gate"):
overlaps,
seg_idx,
mode,
duration,
framerate,
seg_shift,
collar_duration,
filter_type="gate"):
"""
:param label:
@@ -193,26 +203,27 @@ def get_segment_label(label,
output_label = numpy.convolve(conv_filt, spk_change, mode='same')
elif mode == "overlap":
raise NotImplementedError()
output_label = (overlaps > 0.5).astype(numpy.long)
else:
raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")
# Create segments with overlap
segment_label = framing(output_label,
int(framerate * duration),
win_shift=int(framerate * seg_shift),
context=(0, 0),
pad='zeros')
int(framerate * duration),
win_shift=int(framerate * seg_shift),
context=(0, 0),
pad='zeros')
return segment_label[seg_idx]
def process_segment_label(label,
mode,
framerate,
collar_duration,
filter_type="gate"):
overlaps,
mode,
framerate,
collar_duration,
filter_type="gate"):
"""
:param label:
@@ -244,7 +255,7 @@ def process_segment_label(label,
output_label = numpy.convolve(conv_filt, spk_change, mode='same')
elif mode == "overlap":
raise NotImplementedError()
output_label = (overlaps>1).astype(numpy.long)
else:
raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")
@@ -253,9 +264,9 @@ def process_segment_label(label,
def seqSplit(mdtm_dir,
duration=2.):
duration=2.):
"""
:param mdtm_dir:
:param duration:
:return:
@@ -278,17 +289,17 @@ def seqSplit(mdtm_dir,
for idx, seg in enumerate(ref.segments):
if idx > 0 and seg["start"] / 100. > duration and seg["start"] + duration < last_stop:
segment_list.append(show=seg['show'],
cluster="",
start=float(seg["start"]) / 100. - duration,
stop=float(seg["start"]) / 100. + duration)
cluster="",
start=float(seg["start"]) / 100. - duration,
stop=float(seg["start"]) / 100. + duration)
elif idx < len(ref.segments) - 1 and seg["stop"] + duration < last_stop:
segment_list.append(show=seg['show'],
cluster="",
start=float(seg["stop"]) / 100. - duration,
stop=float(seg["stop"]) / 100. + duration)
cluster="",
start=float(seg["stop"]) / 100. - duration,
stop=float(seg["stop"]) / 100. + duration)
# Get list of unique speakers
# Get list of unique speakers
speakers = ref.unique('cluster')
for spk in speakers:
if not spk in speaker_dict:
@@ -303,16 +314,18 @@ class SeqSet(Dataset):
Object creates a dataset for sequence to sequence training
"""
def __init__(self,
dataset_yaml,
wav_dir,
mdtm_dir,
mode,
duration=2.,
filter_type="gate",
collar_duration=0.1,
audio_framerate=16000,
output_framerate=100,
transform_pipeline=""):
dataset_yaml,
wav_dir,
mdtm_dir,
mode,
duration=2.,
filter_type="gate",
collar_duration=0.1,
audio_framerate=16000,
output_framerate=100,
set_type="train",
dataset_df=None,
transform_pipeline=""):
"""
:param wav_dir:
@@ -357,7 +370,7 @@ class SeqSet(Dataset):
self.transforms = transforms.Compose(_transform)
segment_list, speaker_dict = seqSplit(mdtm_dir=self.mdtm_dir,
duration=self.duration)
duration=self.duration)
self.segment_list = segment_list
self.speaker_dict = speaker_dict
self.len = len(segment_list)
@@ -371,31 +384,40 @@ class SeqSet(Dataset):
"""
# Get segment info to load from
seg = self.segment_list[index]
ok = False
# Randomly pick an audio chunk within the current segment
start = random.uniform(seg["start"], seg["start"] + self.duration)
sig, _ = soundfile.read(self.wav_dir + seg["show"] + ".wav",
start=int(start * self.audio_framerate),
stop=int((start + self.duration) * self.audio_framerate)
)
sig += 0.0001 * numpy.random.randn(sig.shape[0])
if self.transform_pipeline:
sig, speaker_idx, _, __, _t, _s = self.transforms((sig, None, None, None, None, None))
tmp_label = mdtm_to_label(mdtm_filename=self.mdtm_dir + seg["show"] + ".mdtm",
start_time=start,
stop_time=start + self.duration,
sample_number=sig.shape[1],
speaker_dict=self.speaker_dict)
while not ok:
try:
ok=True
start = random.uniform(seg["start"], seg["start"] + self.duration)
sig, _ = soundfile.read(self.wav_dir + seg["show"] + ".wav",
start=int(start * self.audio_framerate),
stop=int((start + self.duration) * self.audio_framerate)
)
sig += 0.0001 * numpy.random.randn(sig.shape[0])
if self.transform_pipeline:
sig, speaker_idx,_t, _s = self.transforms((sig, None, None, None, None, None))
except ValueError as e:
ok=False
# sig, speaker_idx, _, __, _t, _s = self.transforms((sig, None, None, None, None, None))
if sig.T.shape != (198,30):
print(sig.T.shape)
ok=False
tmp_label,overlaps = mdtm_to_label(mdtm_filename=self.mdtm_dir + seg["show"] + ".mdtm",
start_time=start,
stop_time=start + self.duration,
sample_number=sig.shape[1],
speaker_dict=self.speaker_dict)
label = process_segment_label(label=tmp_label,
mode=self.mode,
framerate=self.output_framerate,
collar_duration=self.collar_duration,
filter_type=self.filter_type)
overlaps=overlaps,
mode=self.mode,
framerate=self.output_framerate,
collar_duration=self.collar_duration,
filter_type=self.filter_type)
return torch.from_numpy(sig.T).type(torch.FloatTensor), torch.from_numpy(label.astype('long'))
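# Hedged usage sketch of SeqSet with a DataLoader; the YAML file and the data
# directories are placeholders that must exist on disk, mode "overlap" is just
# one of the accepted values, and the expected shapes follow the (198, 30) MFCC
# check in __getitem__ above:
# _set = SeqSet("dataset.yaml",
#               wav_dir="data/wav/",
#               mdtm_dir="data/mdtm/",
#               mode="overlap",
#               duration=2.,
#               audio_framerate=16000,
#               output_framerate=100,
#               transform_pipeline="MFCC")
# _loader = DataLoader(_set, batch_size=4, shuffle=True, drop_last=True)
# _feats, _labels = next(iter(_loader))
# _feats.shape   -> torch.Size([4, 198, 30])   (batch, frames, MFCC features)
# _labels.shape  -> torch.Size([4, 198])       (batch, frames)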
def __len__(self):
......