Commit 61e6919f authored by Meysam Shamsi

add lstms4d VAD

parent 3bb25db0
import s4d
import torch
import logging
import shutil
import numpy
import pathlib
import random
import scipy
import sidekit
import soundfile
import yaml
import os
import tqdm
from s4d.diar import Diar
from collections import OrderedDict
from sidekit.nnet.sincnet import SincNet
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from sidekit.nnet.xsets import PreEmphasis
from sidekit.nnet.xsets import MFCC
from sidekit.nnet.xsets import CMVN
from sidekit.nnet.xsets import FrequencyMask
from sidekit.nnet.xsets import TemporalMask
from torchvision import transforms
from collections import namedtuple
# from s4d.s4d.nnet.wavsets import SeqSet
# from s4d.s4d.nnet.wavsets import mdtm_to_label
from s4d.nnet.wavsets import overlapping
from s4d.nnet.wavsets import process_segment_label
from s4d.nnet.wavsets import create_train_val_seqtoseq
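# Evaluation script: slide a window over each show, run a trained s4d sequence-to-sequence
# VAD model on the windows, overlap-add the window outputs and score the frame-level SAD
# decisions against the MDTM references (optionally restricted by UEM files).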
def mdtm_to_label_(mdtm_filename,
start_time,
stop_time,
sample_number,
speaker_dict):
"""
:param mdtm_filename:
:param start_time:
:param stop_time:
:param sample_number:
:param speaker_dict:
:return:
"""
diarization = Diar.read_mdtm(mdtm_filename)
diarization.sort(['show', 'start'])
# When one segment starts just the frame after the previous one ends,
# we replace the time of the start by the time of the previous stop to avoid artificial holes
previous_stop = 0
for ii, seg in enumerate(diarization.segments):
if ii == 0:
previous_stop = seg['stop']
else:
if seg['start'] == diarization.segments[ii - 1]['stop'] + 1:
diarization.segments[ii]['start'] = diarization.segments[ii - 1]['stop']
# Create the empty labels
label = []
# Compute the time stamp of each sample
time_stamps = numpy.zeros((sample_number, 2), dtype=numpy.float32)
period = (stop_time - start_time) / sample_number
st = start_time
for t in range(sample_number):
time_stamps[t] = [st, st + period]
st += period
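# keep only the reference segments that overlap the [start_time, stop_time] window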
framed_segments = [seg for seg in diarization.segments if overlapping((seg['start'],seg['stop']),(start_time,stop_time))]
for idx, times in enumerate(time_stamps):
lbls = []
for seg in framed_segments:
# keep the speakers of any segment that covers the start or the end of this sample
if seg['start'] < times[0] < seg['stop'] or seg['start'] < times[1] < seg['stop']:
lbls.append(speaker_dict[seg['cluster']])
# an empty list marks a non-speech sample
label.append(lbls)
return label
def seqSplit_slidWin(mdtm_file,
wav_dir,
uem_dir=None,
duration=2.,
step=0.2):
"""
:param mdtm_dir:
:param duration:
:return:
"""
segment_list = Diar()
speaker_dict = dict()
idx = 0
# Load MDTM file
ref = Diar.read_mdtm(mdtm_file)
ref.sort()
last_stop = ref.segments[-1]["stop"]
# Get the borders of the segments (not the start of the first and not the end of the last)
# Check the length of the audio file
head, tail = os.path.split(mdtm_file)
showName=tail.split(".")[0]
nfo = soundfile.info(wav_dir + showName + ".wav")
# split into windows of length `duration` with a stride of `step` (sliding window)
if uem_dir is None:
start_frame = 0
while (start_frame + duration) < nfo.duration * nfo.samplerate:
segment_list.append(show=showName,
cluster="",
start=start_frame,
stop=start_frame + duration)
start_frame += step
else:
with open(uem_dir) as f:
lines = [line.rstrip() for line in f]
for seg in lines:
# collapse repeated spaces so that the UEM fields can be split on a single space
seg = " ".join(seg.split())
start_frame = int(float(seg.split(" ")[2]) * nfo.samplerate)
stop_frame = int(float(seg.split(" ")[3]) * nfo.samplerate)
while (start_frame + duration) < stop_frame:
segment_list.append(show=showName,
cluster="",
start=start_frame,
stop=start_frame + duration)
start_frame += step
# Get list of unique speakers
speakers = ref.unique('cluster')
for spk in speakers:
if spk not in speaker_dict:
speaker_dict[spk] = idx
idx += 1
return segment_list, speaker_dict
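# Example call (hypothetical paths; window and stride are given in samples, as done
# in SeqSet_slidWin below):
# segment_list, speaker_dict = seqSplit_slidWin(mdtm_file="../mdtm/show1.mdtm",
#                                               wav_dir="../wav/",
#                                               duration=int(2. * 16000),
#                                               step=int(0.2 * 16000))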
class SeqSet_slidWin(Dataset):
"""
Dataset object for sequence-to-sequence training and evaluation, built by sliding a window over each show
"""
def __init__(self,
wav_dir,
mdtm_file,
mode,
uem_dir=None,
segment_list=None,
speaker_dict=None,
duration=2.,
step=0.2,
filter_type="gate",
collar_duration=0.1,
audio_framerate=16000,
output_framerate=100,
transform_pipeline=""):
"""
:param wav_dir:
:param mdtm_dir:
:param mode:
:param duration:
:param filter_type:
:param collar_duration:
:param audio_framerate:
:param output_framerate:
:param transform_pipeline:
"""
self.wav_dir = wav_dir
self.mdtm_file = mdtm_file
self.uem_dir=uem_dir
self.mode = mode
self.duration = duration
self.step = step
self.filter_type = filter_type
self.collar_duration = collar_duration
self.audio_framerate = audio_framerate
self.output_framerate = output_framerate
self.transform_pipeline = transform_pipeline
_transform = []
if self.transform_pipeline != '':
trans = self.transform_pipeline.split(',')
for t in trans:
if 'PreEmphasis' in t:
_transform.append(PreEmphasis())
if 'MFCC' in t:
_transform.append(MFCC())
if "CMVN" in t:
_transform.append(CMVN())
if "FrequencyMask" in t:
a = int(t.split('-')[0].split('(')[1])
b = int(t.split('-')[1].split(')')[0])
_transform.append(FrequencyMask(a, b))
if "TemporalMask" in t:
a = int(t.split("(")[1].split(")")[0])
_transform.append(TemporalMask(a))
self.transforms = transforms.Compose(_transform)
if segment_list is None and speaker_dict is None:
segment_list, speaker_dict = seqSplit_slidWin(mdtm_file=self.mdtm_file,
wav_dir=self.wav_dir,
duration=int(self.duration* self.audio_framerate),
step=int(self.step* self.audio_framerate))
self.segment_list = segment_list
self.speaker_dict = speaker_dict
self.len = len(segment_list)
def __getitem__(self, index):
"""
On renvoie un segment wavform brut mais il faut que les labels soient échantillonés à la bonne fréquence
(trames)
:param index:
:return:
"""
# Get segment info to load from
seg = self.segment_list[index]
sig, _ = soundfile.read(self.wav_dir + seg["show"] + ".wav",
start=seg["start"],
stop=seg["stop"])
sig += 0.0001 * numpy.random.randn(sig.shape[0])
if self.transform_pipeline:
sig, speaker_idx, _t, _s = self.transforms((sig, None, None, None, None, None))
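# convert the segment boundaries from audio samples to output frames (100 frames per second)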
tmp_label = mdtm_to_label_(mdtm_filename=self.mdtm_file,
start_time=(seg["start"]/self.audio_framerate)*100,
stop_time=(seg["stop"]/self.audio_framerate)*100,
sample_number=sig.shape[1],
speaker_dict=self.speaker_dict)
label = process_segment_label(label=tmp_label,
mode=self.mode,
framerate=self.output_framerate,
collar_duration=self.collar_duration,
filter_type=self.filter_type)
return torch.from_numpy(sig.T).type(torch.FloatTensor), torch.from_numpy(label.astype('long'))
def __len__(self):
return self.len
def unfold_(final_output,output,step):
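"""
Overlap-add the window-level outputs into a single frame-level sequence.
Consecutive windows are shifted by step seconds (step * 100 frames): the overlapping
frames are summed with the accumulated sequence and the last step * 100 frames are appended.
For example, with step=0.2 and 200-frame windows, the first 180 frames of each new window
are added to the tail of the accumulated sequence and its last 20 frames are appended.
"""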
if len(output.shape)<2:
output=numpy.array([output])
len_seq=output.shape[1]
sampleNumber=int(step*100)
for i in range(len(output)):
if len(final_output)==0:
final_output=numpy.array(output[i]).tolist()
else:
final_output=final_output[:-len_seq+sampleNumber]\
+(numpy.array(final_output[-len_seq+sampleNumber:])+numpy.array(output[i])[:-sampleNumber]).tolist()\
+numpy.array(output[i]).tolist()[-sampleNumber:]
return final_output
def cross_validation(model, validation_loader, device, step,uemfile_dir=None):
"""
:param model:
:param validation_loader:
:param device:
:return:
"""
model.eval()
recall = 0.0
precision = 0.0
accuracy = 0.0
loss = 0.0
criterion = torch.nn.CrossEntropyLoss()
final_output=[]
final_target=[]
with torch.no_grad():
for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader)):
batch_size = target.shape[0]
target = target.squeeze()
output = model(data.to(device))
final_output=unfold_(final_output,output.cpu().numpy(),step)
final_target=unfold_(final_target,target.cpu().numpy(),step)
sad_seq=numpy.argmax(numpy.array(final_output),axis=1)
recall, precision, accuracy = calc_recall(sad_seq,(numpy.array(final_target)!=0),uemfile_dir=uemfile_dir)
# guard against the case where both precision and recall are zero
f_measure = 0.0
if precision != 0 or recall != 0:
f_measure = 2 * precision * recall / (precision + recall)
logging.critical(
'Validation: [{}/{} ({:.0f}%)] Accuracy: {:.3f} ' \
'Recall: {:.3f} Precision: {:.3f} '\
'F-Measure: {:.3f}'.format(batch_idx + 1,
validation_loader.__len__(),
100. * batch_idx / validation_loader.__len__(),
100.0 * accuracy,
100.0 * recall,
100.0 * precision,
f_measure)
)
return 100.0 * accuracy, 100.0 * recall, 100.0 * precision, sad_seq
def calc_recall(output,target,uemfile_dir=None):
"""
:param output:
:param target:
:return:
"""
rc = 0.0
pr = 0.0
acc= 0.0
if uemfile_dir is not None:
with open(uemfile_dir) as f:
lines = [line.rstrip() for line in f]
for seg in lines:
# collapse repeated spaces so that the UEM fields can be split on a single space
seg = " ".join(seg.split())
start_frame = int(float(seg.split(" ")[2])) * 100
stop_frame = int(float(seg.split(" ")[3])) * 100
target = target[start_frame:stop_frame]
output = output[start_frame:stop_frame]
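# frame-level confusion counts over binary speech/non-speech decisions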
tp = sum(target * output)
tn = sum((1 - target) * (1 - output))
fp = sum((1 - target) * output)
fn = sum(target * (1 - output))
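# Example: target = [1, 1, 0, 0] and output = [1, 0, 0, 1] give tp = tn = fp = fn = 1,
# hence recall = precision = accuracy = 0.5 (up to the epsilon term below)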
epsilon = 1e-7
pr = tp / (tp + fp + epsilon)
rc = tp / (tp + fn + epsilon)
acc = (tp + tn) / (tp + fp + tn + fn + epsilon)
return rc, pr, acc
_dataset_yaml = "../dihard"
_model_yaml = "../sts.yaml"
_epochs = 100
_lr = 0.001
_patience = 10
_multi_gpu = True
_opt='sgd'
_num_thread = 10
dataset_yaml=_dataset_yaml+".yaml"
model_yaml=_model_yaml
epochs=_epochs
lr=_lr
patience=_patience
model_name="../model/best_VAD.pt"
multi_gpu=_multi_gpu
opt=_opt
num_thread=_num_thread
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
# Load the model
logging.critical(f"*** Load model from = {model_name}")
checkpoint = torch.load(model_name,map_location='cpu')
model = s4d.nnet.SeqToSeq(model_yaml)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
if torch.cuda.device_count() > 1 and multi_gpu:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = torch.nn.DataParallel(model)
else:
print("Test on a single GPU")
model.to(device)
with open(dataset_yaml, "r") as fh:
dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
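# The dataset YAML must provide at least the keys read below; the values here are
# only illustrative, the actual configuration file may differ:
#   wav_dir: ../wav/
#   mdtm_dir: ../mdtm/
#   uem_dir: ../uem/            # optional
#   mode: vad
#   filter_type: gate
#   collar_duration: 0.1
#   sample_rate: 16000
#   output_rate: 100
#   batch_size: 32
#   eval:
#     duration: 2.
#     transformation:
#       pipeline: MFCC,CMVN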
slide_step=0.2
# slide_step=1.5
result_total={}
total_acc=0
total_len=0
total_rec=0
total_pr=0
for mdtm_file in pathlib.Path(dataset_params["mdtm_dir"]).glob('*.mdtm'):
showName=str(mdtm_file)[len(dataset_params["mdtm_dir"]):].split(".")[0]
if "uem_dir" in dataset_params:
uem_dir_=dataset_params["uem_dir"]+showName+".uem"
else:
uem_dir_=None
# if showName.startswith("BFMTV_CultureEtVous"):
test_set = SeqSet_slidWin(wav_dir=dataset_params["wav_dir"],
mdtm_file=mdtm_file,
uem_dir=uem_dir_,
mode=dataset_params["mode"],
duration=dataset_params["eval"]["duration"],
step=slide_step,
filter_type=dataset_params["filter_type"],
collar_duration=dataset_params["collar_duration"],
audio_framerate=dataset_params["sample_rate"],
output_framerate=dataset_params["output_rate"],
transform_pipeline=dataset_params["eval"]["transformation"]["pipeline"])
test_loader = DataLoader(test_set,
batch_size=dataset_params["batch_size"],
drop_last=False,
shuffle=False,
pin_memory=True,
num_workers=num_thread)
# Cross validation here
accuracy,recall,precision,output = cross_validation(model,
test_loader,
device=device,
step=slide_step,
uemfile_dir=uem_dir_)
# Save the SAD decisions as speech segments in a .lab file
# (start_frm >= 0 means a speech segment is currently open)
start_frm = -1
end_frm = 0
with open("../VAD_result_"+_dataset_yaml+"/sad_"+showName+".lab", "w") as text_file:
for frm in range(len(output)):
if output[frm]:
if start_frm >= 0:
end_frm = frm
else:
start_frm = frm
end_frm = frm
else:
if start_frm >= 0:
text_file.write(str(start_frm/100.)+"\t"+str((end_frm+1)/100.)+"\tspeaker\n")
start_frm = -1
end_frm = 0
if start_frm >= 0:
text_file.write(str(start_frm/100.)+"\t"+str((end_frm+1)/100.)+"\tspeaker\n")
# Print performance of show
logging.critical("*** {}(length:{} sec) = {} %".format(showName,len(output)/100.,accuracy))
result_total[showName]={"acc":accuracy,"rec":recall,"pr":precision,"length":len(output)}
total_acc+=accuracy*len(output)
total_len+=len(output)
total_rec+=recall*len(output)
total_pr+=precision*len(output)
ac=total_acc/total_len
rc=total_rec/total_len
pr=total_pr/total_len
fm=2*rc*pr/(rc+pr)
logging.critical("\n\n*** Total : Accuracy: {} Recall: {} Precision: {} F-Measure: {}%".format(ac,rc,pr,fm))
# save total performance in json file
import json
with open("../performance_"+_dataset_yaml+"_"+dataset_params["mode"]+".json", 'w') as f:
json.dump(result_total, f)
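
# Training script: launch s4d's sequence-to-sequence training for the VAD model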
import s4d
import torch
_dataset_yaml = "../dihard.yaml"
_model_yaml = "../sts.yaml"
_epochs = 100
_lr = 0.001
_patience = 10
_tmp_model_name = "../model/checkpoint_VAD_DIHARD"
_best_model_name = "../model/best_VAD_DIHARD"
_multi_gpu = True
_opt='sgd'
_num_thread = 10
s4d.nnet.seqTrain(dataset_yaml=_dataset_yaml,
model_yaml=_model_yaml,
epochs=_epochs,
lr=_lr,
patience=_patience,
model_name=None,
tmp_model_name=_tmp_model_name,
best_model_name=_best_model_name,
multi_gpu=_multi_gpu,
opt=_opt,
num_thread=_num_thread)