Commit 218ade91 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

running VAD, missing test mode

parent 6ebc4ab4
...@@ -443,6 +443,8 @@ def cross_validation(model, validation_loader, device): ...@@ -443,6 +443,8 @@ def cross_validation(model, validation_loader, device):
target = target.permute(1, 0) target = target.permute(1, 0)
nbpoint = output.shape[0] nbpoint = output.shape[0]
loss = criterion(output, target.to(device))
rc, pr, acc = calc_recall(output.data, target, device) rc, pr, acc = calc_recall(output.data, target, device)
recall += rc.item() recall += rc.item()
precision += pr.item() precision += pr.item()
...@@ -502,3 +504,6 @@ def calc_recall(output,target,device): ...@@ -502,3 +504,6 @@ def calc_recall(output,target,device):
acc/=len(y_trueb[0]) acc/=len(y_trueb[0])
return rc,pr,acc return rc,pr,acc
...@@ -281,19 +281,19 @@ def seqSplit(mdtm_dir, ...@@ -281,19 +281,19 @@ def seqSplit(mdtm_dir,
# Get the borders of the segments (not the start of the first and not the end of the last # Get the borders of the segments (not the start of the first and not the end of the last
# Check the length of audio # Check the length of audio
nfo = soundfile.info(wav_dir + mdtm_file[len(mdtm_dir):].split(".")[0] + ".wav") nfo = soundfile.info(wav_dir + str(mdtm_file)[len(mdtm_dir):].split(".")[0] + ".wav")
# For each border time B get a segment between B - duration and B + duration # For each border time B get a segment between B - duration and B + duration
# in which we will pick up randomly later # in which we will pick up randomly later
for idx, seg in enumerate(ref.segments): for idx, seg in enumerate(ref.segments):
if seg["start"] / 100. > (duration / 2.) and seg["start"] + (duration / 2.) < nfo.duration:
if seg["start"] / 100. > duration and seg["start"] / 100. + duration < nfo.duration:
segment_list.append(show=seg['show'], segment_list.append(show=seg['show'],
cluster="", cluster="",
start=float(seg["start"]) / 100. - duration, start=float(seg["start"]) / 100. - duration,
stop=float(seg["start"]) / 100. + duration) stop=float(seg["start"]) / 100. + duration)
if seg["stop"] / 100. > (duration / 2.) and seg["stop"] + (duration / 2.) < nfo.duration: if seg["stop"] / 100. > duration and seg["stop"] / 100. + duration < nfo.duration:
segment_list.append(show=seg['show'], segment_list.append(show=seg['show'],
cluster="",
cluster="", cluster="",
start=float(seg["stop"]) / 100. - duration, start=float(seg["stop"]) / 100. - duration,
stop=float(seg["stop"]) / 100. + duration) stop=float(seg["stop"]) / 100. + duration)
...@@ -439,6 +439,7 @@ def create_train_val_seqtoseq(dataset_yaml): ...@@ -439,6 +439,7 @@ def create_train_val_seqtoseq(dataset_yaml):
# Read all MDTM files and ouptut a list of segments with minimum duration as well as a speaker dictionary # Read all MDTM files and ouptut a list of segments with minimum duration as well as a speaker dictionary
segment_list, speaker_dict = seqSplit(mdtm_dir=dataset_params["mdtm_dir"], segment_list, speaker_dict = seqSplit(mdtm_dir=dataset_params["mdtm_dir"],
wav_dir=dataset_params["wav_dir"],
duration=dataset_params["train"]["duration"]) duration=dataset_params["train"]["duration"])
split_idx = numpy.random.choice([True, False], split_idx = numpy.random.choice([True, False],
...@@ -452,6 +453,9 @@ def create_train_val_seqtoseq(dataset_yaml): ...@@ -452,6 +453,9 @@ def create_train_val_seqtoseq(dataset_yaml):
else: else:
segment_list_val.append_seg(seg) segment_list_val.append_seg(seg)
print(f"Length of training: {len(segment_list_train)}\n validation: {len(segment_list_val)}")
# Split the list of segment between training and validation sets # Split the list of segment between training and validation sets
train_set = SeqSet(wav_dir=dataset_params["wav_dir"], train_set = SeqSet(wav_dir=dataset_params["wav_dir"],
mdtm_dir=dataset_params["mdtm_dir"], mdtm_dir=dataset_params["mdtm_dir"],
...@@ -477,4 +481,7 @@ def create_train_val_seqtoseq(dataset_yaml): ...@@ -477,4 +481,7 @@ def create_train_val_seqtoseq(dataset_yaml):
output_framerate=dataset_params["output_rate"], output_framerate=dataset_params["output_rate"],
transform_pipeline=dataset_params["eval"]["transformation"]["pipeline"]) transform_pipeline=dataset_params["eval"]["transformation"]["pipeline"])
print(f"train segs: {train_set.__len__()}, validation: {validation_set.__len__()}")
return train_set, validation_set return train_set, validation_set
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment