Commit 0403dd7f authored by Martin Lebourdais

Prediction without labels (not clean, but should work, testing in progress)

parent 2bbed2f5
@@ -109,7 +109,7 @@ def calculate_pyannote_metrics(predf,trueyaml,task='overlap'):
cfg = yaml.load(fh, Loader=yaml.FullLoader)
else:
cfg=trueyaml
show_list = cfg["file_list"]
uem=None
if "uem_dir" in cfg:
@@ -238,8 +238,8 @@ def model_pred(prediction_yaml,
wav_dir = cfg["wav_dir"]
model_name = cfg["model_name"]
model_yaml = cfg["model_archi"]
seg_set = cfg["seg_set"]
label_set = cfg["label_set"]
seg_set = cfg.get("seg_set",None)
label_set = cfg.get("label_set",None)
uem_dir = cfg.get("uem_dir",None)
batch_size = cfg["batch_size"]
audio_fr = cfg["audio_samplerate"]
@@ -259,7 +259,7 @@ def model_pred(prediction_yaml,
# loading model
with open(model_yaml, "r") as fh:
archi = yaml.load(fh, Loader=yaml.FullLoader)
# manage context in case of convolutional pre-processing
context = 0
if not is_multichannel:
@@ -272,7 +272,7 @@ def model_pred(prediction_yaml,
checkpoint = torch.load(model_name, map_location=device)
# Pytorch fails with saving mel scale transformation
if 'pre_processing.sacc.mel_scale.fb' in checkpoint['model_state_dict'].keys():
checkpoint['model_state_dict']['pre_processing.sacc.mel_scale.fb'] = torch.Tensor([])
if 'pre_processing.complex_sacc.mel_scale.fb' in checkpoint['model_state_dict'].keys():
@@ -329,7 +329,7 @@ def model_pred(prediction_yaml,
show=show,
first = first,
out_dir=out_dir,
)
pred_history.append({"show":show,"pred":pred,"target":target,"raw_pred":raw_pred[:,1]})
@@ -347,12 +347,12 @@ def model_pred(prediction_yaml,
os.mkdir(res_path+"rttm/")
#compute scores + create RTTM file with current prediction
scores = metrics(pred_history,out_rttm,min_on,min_off,compute_ap_score=True,uem=uem_dir)
# Serializing json
json_object = json.dumps(scores, indent = 4)
# Writing to json file
with open(fname, "w") as outfile:
outfile.write(json_object)
@@ -375,7 +375,7 @@ def model_pred(prediction_yaml,
min_off = min_off,
metrics=scores)
scores["test_config"]={"th_in":th_in,"th_out":th_out,"min_on":min_on,"min_off":min_off}
# metrics from Pyannote as a comparison
if pyametrics is not None:
logger.message("\n\n--------Pyannote--------")
@@ -388,9 +388,46 @@ def model_pred(prediction_yaml,
return pred_history,scores
else:
return pred_history
res_path="{}metrics/".format(exp_dir)
out_rttm = "{}rttm/{:s}_in_{:d}_out_{:d}_on_{:d}_off_{:d}.rttm".format(res_path,task,int(th_in*100),int(th_out*100),int(min_on*1e3),int(min_off*1e-3),)
resf = open(out_rttm,'w')
for show in pred_history:
pred = show["pred"]
uem = uem_dir
create_rttm(resf,pred,show['show'],min_on=min_on,min_off=min_off,uem=uem,show=show["show"])
resf.close()
return pred_history,0
'''
def metrics(predict_per_show,resfpath,min_on,min_off,compute_ap_score=False,uem=None):
resf = open(resfpath,'w')
macrotp = 0
macrotn = 0
macrofp = 0
macrofn = 0
ap=0.0
for show in predict_per_show:
true = show["target"]
uem_mask = get_uem_mask(uem,show["show"],true)
pred = show["pred"]
truem = true[uem_mask]
predm = pred[uem_mask]
tp = numpy.sum(truem * predm)
tn = numpy.sum((1 - truem) * (1 - predm))
fp = numpy.sum((1 - truem) * predm)
fn = numpy.sum(truem * (1 - predm))
macrotp += tp
macrotn += tn
macrofp += fp
macrofn += fn
epsilon = 1e-7
if compute_ap_score:
raw_pred=show["raw_pred"]
ap+=average_precision_score(true,raw_pred,)
create_rttm(resf,pred,show['show'],min_on=min_on,min_off=min_off,uem=uem,show=show["show"])
'''
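# Illustrative sketch (not part of this commit): from the macro counts accumulated in the
# commented-out metrics() above (macrotp, macrotn, macrofp, macrofn) and its epsilon guard,
# the usual derived scores would be computed roughly as follows; the helper name is ours.
def _macro_scores(tp, tn, fp, fn, epsilon=1e-7):
    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)
    f1 = 2 * precision * recall / (precision + recall + epsilon)
    accuracy = (tp + tn) / (tp + tn + fp + fn + epsilon)
    return {"precision": precision, "recall": recall, "f1": f1, "accuracy": accuracy}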
def predict(batch_size,
validation_loader,
model, # TODO (2021/03/26) define synthetic overlap ratio to guarantee
@@ -405,7 +442,7 @@ def predict(batch_size,
show=None,
first = True,
out_dir=None,
):
"""
TO BE MODIFIED: TAKE ONLY THE WAV FILE NAME AND CREATE THE DATA LOADER INSIDE
@@ -423,18 +460,27 @@ def predict(batch_size,
output_target = []
output_idx = []
done = first
islabel = False
if first:
sm = torch.nn.Softmax(dim=2)
with torch.no_grad():
for batch_idx, (win_idx, data, target) in enumerate(validation_loader):
target = target.squeeze().cpu().numpy()
for batch_idx, data_full in enumerate(validation_loader):
if len(data_full) ==3:
win_idx, data, target = data_full
islabel=True
else:
win_idx,data = data_full
islabel = False
if islabel:
target = target.squeeze().cpu().numpy()
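# Note (explanatory, added): the validation loader is assumed to yield
# (win_idx, data, target) when labels are available and (win_idx, data) in the
# label-free prediction case introduced by this commit.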
output = sm(model(data.to(device))).cpu().numpy()
del(data)
for ii in range(output.shape[0]):
output_data.append(output[ii])
output_target.append(target[ii])
if islabel:
output_target.append(target[ii])
output_idx.append(int(win_idx[ii]))
# Unfold outputs by averaging sliding windows
final_output,final_target = multi_label_combination(output_idx,
output_target,
@@ -445,11 +491,15 @@
raw_output = final_output[:,1]
if out_dir:
pickle.dump(final_output,open(out_dir+f"pred_{show}.pkl","wb"))
pickle.dump(final_target,open(out_dir+f"target_{show}.pkl","wb"))
if islabel:
pickle.dump(final_target,open(out_dir+f"target_{show}.pkl","wb"))
else:
if out_dir:
final_output = pickle.load(open(out_dir+f"pred_{show}.pkl","rb"))
final_target = pickle.load(open(out_dir+f"target_{show}.pkl","rb"))
if islabel:
final_target = pickle.load(open(out_dir+f"target_{show}.pkl","rb"))
else:
final_target = []
else:
raise Exception("No path where to find previously computed predictions !")
@@ -544,11 +594,12 @@ def multi_label_combination(output_idx, output_target, output_data, shift, outpu
Author: Anthony Larcher
"""
islabel = len(output_target)>0
win_shift = int(shift * output_rate)
# Initialize the size of final_output
final_output = numpy.zeros((win_shift * (len(output_data) - 1) + output_data[0].shape[0], output_data[0].shape[1]))
final_target = numpy.zeros(win_shift * (len(output_data) - 1) + output_data[0].shape[0])
overlaping_label_count = numpy.zeros(final_output.shape)
@@ -556,19 +607,27 @@ def multi_label_combination(output_idx, output_target, output_data, shift, outpu
tmp = numpy.ones(output_data[0].shape)
# Loop over the overlapping windows
for idx, tmp_t, tmp_d in zip(output_idx, output_target, output_data):
start_idx = win_shift * idx
stop_idx = start_idx + win_len
if islabel:
for idx, tmp_t, tmp_d in zip(output_idx, output_target, output_data):
start_idx = win_shift * idx
stop_idx = start_idx + win_len
overlaping_label_count[start_idx: stop_idx, :] += tmp
final_output[start_idx: stop_idx, :] += tmp_d
final_target[start_idx: stop_idx] += tmp_t
else:
for idx, tmp_d in zip(output_idx, output_data):
start_idx = win_shift * idx
stop_idx = start_idx + win_len
overlaping_label_count[start_idx: stop_idx, :] += tmp
final_output[start_idx: stop_idx, :] += tmp_d
overlaping_label_count[start_idx: stop_idx, :] += tmp
final_output[start_idx: stop_idx, :] += tmp_d
final_target[start_idx: stop_idx] += tmp_t
# Divide by the number of overlapping values
raw_output = final_output
final_output /= overlaping_label_count
final_target /= overlaping_label_count[:, 0].squeeze()
if islabel:
final_target /= overlaping_label_count[:, 0].squeeze()
return final_output,final_target
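# Illustrative sketch (not from this commit) of the overlap-averaging idea implemented by
# multi_label_combination above: each sliding window's posteriors are summed into a
# frame-aligned buffer, then divided by how many windows cover each frame (assumes numpy
# is imported as in this module; the helper name is ours).
def _average_sliding_windows(windows, win_shift):
    # windows: list of (win_len, n_classes) arrays placed win_shift frames apart
    win_len = windows[0].shape[0]
    total = numpy.zeros((win_shift * (len(windows) - 1) + win_len, windows[0].shape[1]))
    count = numpy.zeros_like(total)
    for w_idx, w in enumerate(windows):
        total[w_idx * win_shift: w_idx * win_shift + win_len] += w
        count[w_idx * win_shift: w_idx * win_shift + win_len] += 1
    return total / count  # frame-level posteriors averaged over overlapping windows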
@@ -104,11 +104,13 @@ def prepare_loaders(dataset_yaml, logger=None, rng_=None, seed=1234):
task=task,
rng = rng)
elif sampler_typ == "random":
print(f"TASK : {task} , OVART: {art_ola_ratio}")
sampler = SeqSetRandomSampler(batch_size=batch_size,
batch_num=batch_num,
list_file=file_list,
seg_set=seg_set,
task=task,
mode="train",
artificial_ov_ratio=art_ola_ratio,
rng = rng)
else:
@@ -126,6 +128,8 @@ def prepare_loaders(dataset_yaml, logger=None, rng_=None, seed=1234):
batch_num=batch_num,
list_file=eval_file_list,
seg_set=eval_seg_set,
task=task,
mode="eval",
rng = rng)
eval_loader = DataLoader(evaluation_set,
@@ -156,12 +156,14 @@ class SeqToSeq(torch.nn.Module):
self.feature_size = cfg["feature_size"]
self.samplerate = cfg["samplerate"]
self.channel_number = cfg["channel_number"]
self.sum_channels = False
self.is_mfcc = False
# pre-processing layers
pre_processing_layers = []
for k in cfg["pre_processing"].keys():
if k.startswith("mfcc"):
self.is_mfcc = True
n_fft = cfg["pre_processing"][k]["n_fft"]
win_length = cfg["pre_processing"][k].get("win_length",480)
hop_length = cfg["pre_processing"][k]["win_shift"]
@@ -190,6 +192,7 @@ class SeqToSeq(torch.nn.Module):
self.feature_size = n_mels // stride[0]
input_size = self.feature_size
if k.startswith("mono_mel"):
self.is_mfcc=True
pre_processing_layers.append((k,MelSpec(samplerate=self.samplerate,
conf=cfg["pre_processing"][k])))
input_size = self.feature_size
@@ -204,6 +207,7 @@ class SeqToSeq(torch.nn.Module):
input_size=self.feature_size
# In case of multichannel input, tConv is a learnable filter-and-sum beamforming
if k.startswith("tConv"):
self.sum_channels = True
window_length = cfg["pre_processing"][k]["window_length"]
hop_length = cfg["pre_processing"][k]["hop_length"]
# number of FIR filters applied to each channel of the array
@@ -282,50 +286,57 @@ class SeqToSeq(torch.nn.Module):
post_processing_activation = torch.nn.Tanh()
post_processing_layers = []
for k in cfg["post_processing"].keys():
self.post_processing_fl = False
if "post_processing" in cfg:
self.post_processing_fl = True
for k in cfg["post_processing"].keys():
if k.startswith("lin"):
post_processing_layers.append((k, torch.nn.Linear(input_size,
cfg["post_processing"][k]["output"])))
input_size = cfg["post_processing"][k]["output"]
if k.startswith("lin"):
post_processing_layers.append((k, torch.nn.Linear(input_size,
cfg["post_processing"][k]["output"])))
input_size = cfg["post_processing"][k]["output"]
elif k.startswith("activation"):
post_processing_layers.append((k, post_processing_activation))
elif k.startswith("activation"):
post_processing_layers.append((k, post_processing_activation))
elif k.startswith('batch_norm'):
post_processing_layers.append((k, torch.nn.BatchNorm1d(input_size)))
elif k.startswith('batch_norm'):
post_processing_layers.append((k, torch.nn.BatchNorm1d(input_size)))
elif k.startswith('dropout'):
post_processing_layers.append((k, torch.nn.Dropout(p=cfg["post_processing"][k])))
elif k.startswith('dropout'):
post_processing_layers.append((k, torch.nn.Dropout(p=cfg["post_processing"][k])))
self.post_processing = torch.nn.Sequential(OrderedDict(post_processing_layers))
self.post_processing.apply(self._init_weights)
self.post_processing = torch.nn.Sequential(OrderedDict(post_processing_layers))
self.post_processing.apply(self._init_weights)
self.output_size = input_size
def forward(self, inputs):
"""
:param inputs: raw audio signal
:return:
"""
if self.center:
if self.sum_channels:
nch=inputs.shape[1]
inputs = inputs.sum(dim=1, keepdim=True)/nch
if self.is_mfcc and self.center:
x = self.pre_processing(inputs[:,:,:-1])
elif self.is_wavlm:
#inputs = inputs.unfold(-1,self.wavlm_win_length, self.wavlm_hop_length)
# print('in wavelm shape before squeeze',inputs.shape)
inputs = inputs.squeeze(dim=1)
# print('in wavelm shape after squeeze',inputs.shape)
x = self.pre_processing(inputs)
# print("out wavelm",x.shape)
else:
x=inputs
x = self.pre_processing(inputs)
# remove energy
if len(x.shape) == 4:
x = torch.squeeze(x[:,:,:])
x = self.sequence_to_sequence(x.permute(0,2,1))
x = self.post_processing(x)
if self.is_lstm:
x = x.permute(0,2,1)
x = self.sequence_to_sequence(x)
if self.post_processing_fl:
x = self.post_processing(x)
else:
x = x.permute(0,2,1)
return x
def get_features(self,inputs):
@@ -341,11 +352,11 @@ class SeqToSeq(torch.nn.Module):
return x, w_comb, w_att
elif self.is_wavlm:
#inputs = inputs.unfold(-1,self.wavlm_win_length, self.wavlm_hop_length)
print('in wavelm shape before squeeze',inputs.shape)
#print('in wavelm shape before squeeze',inputs.shape)
inputs = inputs.squeeze(dim=1)
print('in wavelm shape after squeeze',inputs.shape)
#print('in wavelm shape after squeeze',inputs.shape)
x = self.pre_processing(inputs)
print("out wavelm",x.shape)
#print("out wavelm",x.shape)
return x
else:
x = self.pre_processing(inputs)
@@ -108,36 +108,30 @@ class SeqSetRandomSampler(torch.utils.data.Sampler):
idx_iter = list()
segments = shelve.open(seg_set)
for show in tqdm.tqdm(
self.show_list, desc="Sampler initialization, first pass", unit="show"
):
#print(list(segments.keys()))
seg_tags = segments.get(show)
try:
seg_tags = segments.get(show)
except:
print(segments[show])
for ii in range(len(seg_tags)):
seg = seg_tags[ii]
ovseg=None
if mode=="train" and task=="overlap":
if rng.random()<artificial_ov_ratio:
ovseg = seg_tags[rng.integers(len(seg_tags))]
for seg in seg_tags:
idx_iter.append({'seg':seg,'overlap':None})
idx_iter.append({'seg':seg,'overlap':ovseg})
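# Note (explanatory, added): during training on the "overlap" task, each segment is paired
# with a randomly drawn second segment with probability artificial_ov_ratio; e.g.
# artificial_ov_ratio = 0.3 would make roughly 30% of the sampled items carry an extra
# segment to be mixed in later as artificial overlap.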
self.index_iterator = numpy.array(idx_iter)
self.length = batch_num * batch_size
segments.close()
def __iter__(self):
"""
:return:
"""
self.rng.shuffle(self.index_iterator)
self.iter = self.index_iterator[: self.length]
return iter(self.iter)
def __len__(self) -> int:
"""
:return:
"""
return self.length
@@ -630,14 +624,14 @@ class SeqSet(torch.utils.data.Dataset):
for show in segments:
crnt_set = segments.get(show)
self.time_base_start[show] = crnt_set[0]["start"] # centiseconds
# label augmentation for speaker turn detection
if not "labelling" in dataset_params.keys():
self.collar_duration = 0.125
else:
self.collar_duration = dataset_params["labelling"]["collar_duration"]
# read the files containing the per-show t_start values to build a list of segments
self.duration = numpy.ceil(
dataset_params[mode]["duration"] * self.audio_fr
@@ -658,6 +652,7 @@ class SeqSet(torch.utils.data.Dataset):
self.transformation = dataset_params["eval"]["transformation"]
self.transform = dict()
self.spec_aug = False
if (self.transformation["pipeline"] != "") and (
self.transformation["pipeline"] is not None
):
@@ -708,54 +703,43 @@ class SeqSet(torch.utils.data.Dataset):
"stop":float(segarr[4]),
}
idx_start = numpy.round(seg["start"] / 100.0 * self.audio_fr).astype(int)
idx_start = numpy.round(seg["start"] / 100.0 * self.audio_fr-numpy.ceil(self.context/2)).astype(int)
waveform = waveform_loader(self.wav_dir+seg["show"]+".wav",
idx_start=idx_start,
seg_len=self.duration,
context=self.context)
# load audio waveform (dim = (channels,length))
# normalization is applied here
if self.mod == 'xvector':
pass
# Extract the segment from a pretrained xvector file
else:
waveform, speech_fs = torchaudio.load(
filepath=self.wav_dir + seg["show"] + ".wav",
frame_offset=idx_start,
num_frames=self.duration,
channels_first=True,
normalize=True
)
#waveform = normalize(waveform)
# is the signal mono or not?
is_multichannel = waveform.shape[0] > 1
# add low energy noise to avoid zero values
waveform += 1e-6 * torch.randn(waveform.shape[0],waveform.shape[1])
# data augmentation if needed
if self.transform and not is_multichannel:
waveform = data_augmentation(
waveform,
self.audio_fr,
self.transform,
self.transform_number,
noise_df=self.noise_df,
rir_df=self.rir_df,
babble_noise = self.babble_noise
)
if self.transform and not is_multichannel:
waveform = data_augmentation(
waveform,
speech_fs,
# in case of multichannel signal, each channel has to be transformed
elif self.transform and is_multichannel:
for ii in range(waveform.shape[0]):
waveform[ii,:] = data_augmentation(
waveform[ii,:].unsqueeze(0),
self.audio_fr,
self.transform,
self.transform_number,
noise_df=self.noise_df,
rir_df=self.rir_df,
babble_noise = self.babble_noise
)
if waveform is None:
print(seg["show"],seg["start"] / 100.0,seg["stop"] / 100.0)
# in case of multichannel signal, each channel has to be transformed
elif self.transform and is_multichannel:
for ii in range(waveform.shape[0]):
waveform[ii,:] = data_augmentation(
waveform[ii,:],
speech_fs,
self.transform,
self.transform_number,
noise_df=self.noise_df,
rir_df=self.rir_df,
babble_noise = self.babble_noise
)
if self.spec_aug:
waveform = self.spec_aug(waveform)
@@ -766,15 +750,16 @@ class SeqSet(torch.utils.data.Dataset):
with h5py.File(self.label_set, "r") as data:
crnt_label = data[seg["show"]]["total"][:, start:stop]
expected_frames_num = int(self.duration/self.audio_fr * self.output_fr)
if crnt_label.shape[1]<expected_frames_num:
crnt_label = numpy.pad(crnt_label,[(0,0),(0,expected_frames_num - crnt_label.shape[1])])
if crnt_label.shape[1] > expected_frames_num:
crnt_label = crnt_label[:,:expected_frames_num]
# crnt_label = numpy.ones((stop-start,))
if self.task == "vad":
output_label = (crnt_label > 0).astype(numpy.long)
# could probably be optimized...
elif self.task == "spk_turn":
label = numpy.zeros_like(crnt_label)
label[:,:-1] = (numpy.abs(crnt_label[:,:-1] - crnt_label[:,1:]) > 0).astype(
@@ -787,8 +772,9 @@ class SeqSet(torch.utils.data.Dataset):
output_label = numpy.convolve(conv_filt, label.squeeze(), mode='same')
output_label = (numpy.expand_dims(output_label,axis=0)>=1).astype(numpy.long)
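# Explanatory note (added): frames where consecutive label values differ are marked as
# speaker-change points; convolving with conv_filt (presumably a short vector of ones
# spanning the collar, defined in the elided lines) dilates each change point over a
# collar of frames, and the >= 1 threshold binarizes the result.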
elif "ov" in self.task:
elif self.task == "overlap":
# batch overlap ratio
if struct['overlap'] is not None:
# Loading artificial overlap
@@ -798,27 +784,29 @@ class SeqSet(torch.utils.data.Dataset):
"show":ov_segarr[0],
"stop":float(ov_segarr[4]),
}
idx_start_ov = numpy.round(ov_seg["start"] / 100.0 * self.audio_fr).astype(int)
frame_count_ov = self.duration
idx_start_ov = numpy.round(ov_seg["start"] / 100.0 * self.audio_fr-numpy.ceil(self.context/2)).astype(int)
waveform_ov = waveform_loader(self.wav_dir+ov_seg["show"]+".wav",
idx_start=idx_start_ov,
seg_len=frame_count_ov,
context=self.context)
waveform_ov, _ = torchaudio.load(
filepath=self.wav_dir + ov_seg["show"] + ".wav",
frame_offset=idx_start_ov,
num_frames=frame_count_ov,
channels_first=True,
)
start_ov = numpy.round((ov_seg["start"] - self.time_base_start[ov_seg["show"]])/ 100.0 * self.output_fr).astype(int)
stop_ov = numpy.round((ov_seg["stop"] - self.time_base_start[ov_seg["show"]])/ 100.0 * self.output_fr ).astype(int)
start_ov = numpy.round(ov_seg["start"] / 100.0 * self.output_fr).astype(int)
stop_ov = numpy.round(ov_seg["stop"] / 100.0 * self.output_fr).astype(int)
with h5py.File(self.label_set, "r") as data:
label_ov = data[seg["show"]]["total"][:, start_ov:stop_ov]
label_ov = data[ov_seg["show"]]["total"][:, start_ov:stop_ov]
if label_ov.shape[1]<expected_frames_num:
label_ov = numpy.pad(label_ov,[(0,0),(0,expected_frames_num - label_ov.shape[1])])
if label_ov.shape[1] > expected_frames_num:
label_ov = label_ov[:,:expected_frames_num]
speech_power = waveform.norm(p=2)
noise_power = waveform_ov.norm(p=2)
snr_db = 10 * self.rng.random()+1
#snr_db=0
snr = 10 ** (snr_db / 20)
scale = snr * noise_power / speech_power
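# Worked example (illustrative, added): snr_db is drawn uniformly in [1, 11) above, so e.g.
# snr_db = 6 gives snr = 10 ** (6 / 20) ≈ 2.0; applying this scale to `waveform` would set
# its L2 norm to snr times that of the overlapping segment waveform_ov.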
@@ -833,11 +821,10 @@ class SeqSet(torch.utils.data.Dataset):
else:
raise NotImplementedError()
if torch.isnan(waveform).any():
print("Waveform NAN !!!")
# self.logger.segmentslog(seg['show'],seg['start']/100,seg['stop']/100,(struct['overlap'] is not None))
return waveform, torch.from_numpy(output_label).T