Commit c5afedbc authored by Anthony Larcher

debug vad

parent ef259deb
@@ -82,7 +82,7 @@ class BLSTM(torch.nn.Module):
         super(BLSTM, self).__init__()
         self.input_size = input_size
         self.blstm_sizes = blstm_sizes
-        self.output_size = blstm_sizes[0] * 2
+        self.output_size = blstm_sizes * 2
         self.blstm_layers = torch.nn.LSTM(input_size,
                                           blstm_sizes,
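This hunk tracks a refactor from a list of per-layer sizes to a single hidden size: `torch.nn.LSTM` takes one `hidden_size` integer, and a bidirectional LSTM concatenates forward and backward hidden states, so the output feature dimension is `hidden_size * 2`. A minimal sketch of that relationship, assuming the truncated constructor passes `bidirectional=True` (which the `* 2` implies):

import torch

hidden_size = 64
lstm = torch.nn.LSTM(input_size=30, hidden_size=hidden_size,
                     bidirectional=True, batch_first=True)
x = torch.randn(8, 100, 30)                 # (batch, time, features)
output, _ = lstm(x)
assert output.shape[-1] == hidden_size * 2  # forward + backward states concatenated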
@@ -373,7 +373,7 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
     """
     model.to(device)
     model.train()
-    criterion = torch.nn.CrossEntropyLoss(reduction='mean', weight=torch.FloatTensor([0.1,0.9]))
+    criterion = torch.nn.CrossEntropyLoss(reduction='mean', weight=torch.FloatTensor([0.1,0.9]).to(device))
     recall = 0.0
     precision = 0.0
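The fix above moves the class-weight tensor onto the same device as the model: when the logits live on the GPU, `torch.nn.CrossEntropyLoss` raises a device-mismatch error if its `weight` tensor is still on the CPU. A minimal sketch of the pattern, assuming a two-class (non-speech / speech) output:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Class weights must live on the same device as the logits they weight.
criterion = torch.nn.CrossEntropyLoss(reduction='mean',
                                      weight=torch.FloatTensor([0.1, 0.9]).to(device))
logits = torch.randn(4, 2, device=device)           # (batch, n_classes)
targets = torch.randint(0, 2, (4,), device=device)  # integer class labels
loss = criterion(logits, targets)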
@@ -94,7 +94,6 @@ def load_wav_segment(wav_file_name, idx, duration, seg_shift, framerate=16000):
 def mdtm_to_label(mdtm_filename,
-                  mode,
                   start_time,
                   stop_time,
                   sample_number,
@@ -122,7 +121,7 @@ def mdtm_to_label(mdtm_filename,
         diarization.segments[ii]['start'] = diarization.segments[ii - 1]['stop']
     # Create the empty labels
-    label = list(numpy.zeros(sample_number, dtype=int))
+    label = []
     # Compute the time stamp of each sample
     time_stamps = numpy.zeros(sample_number, dtype=numpy.float32)
@@ -136,6 +135,11 @@ def mdtm_to_label(mdtm_filename,
             if seg['start'] / 100. <= time <= seg['stop'] / 100.:
                 lbls.append(speaker_dict[seg['cluster']])
+        if len(lbls) > 0:
+            label.append(lbls)
+        else:
+            label.append([])
     return label
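With this change, `mdtm_to_label` no longer returns a flat vector of zeros but one list of speaker indices per sample: frames covered by several segments carry several labels, and silent frames carry an empty list. A sketch of the resulting structure, using a hypothetical two-speaker dictionary:

# Hypothetical output for 5 samples: silence, speaker 0, overlap of 0 and 1, ...
label = [[], [0], [0, 1], [1], []]
n_active = [len(lbls) for lbls in label]  # 0 = silence, 2+ = overlapped speech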
@@ -393,8 +397,7 @@ class SeqSet(Dataset):
                               start_time=start,
                               stop_time=start + self.duration,
                               sample_number=sig.shape[1],
-                              speaker_dict=self.speaker_dict,
-                              overlap=self.mode=="overlap")
+                              speaker_dict=self.speaker_dict)
         label = process_segment_label(label=tmp_label,
                                       mode=self.mode,
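The `overlap` flag disappears from the `mdtm_to_label` call because deciding how to collapse the per-sample label lists now belongs to `process_segment_label`, which receives `self.mode` directly. A hedged sketch of what such a post-processing step could look like; the body below is an assumption, only the function name and the `mode` keyword appear in the diff:

def process_segment_label(label, mode, **kwargs):
    # Sketch only: collapse per-sample speaker lists according to the mode.
    if mode == "vad":
        return [1 if len(lbls) > 0 else 0 for lbls in label]
    elif mode == "overlap":
        return [min(len(lbls), 2) for lbls in label]  # 0, 1, or 2+ speakers
    raise ValueError(f"unknown mode: {mode}")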
@@ -432,7 +435,7 @@ def create_train_val_seqtoseq(dataset_yaml):
     # Read all MDTM files and output a list of segments with minimum duration as well as a speaker dictionary
     segment_list, speaker_dict = seqSplit(mdtm_dir=dataset_params["mdtm_dir"],
-                                          duration=dataset_params["duration"])
+                                          duration=dataset_params["train"]["duration"])
     split_idx = numpy.random.choice([True, False],
                                     size=(len(segment_list),),
@@ -448,26 +451,26 @@ def create_train_val_seqtoseq(dataset_yaml):
     # Split the list of segments between training and validation sets
     train_set = SeqSet(wav_dir=dataset_params["wav_dir"],
                        mdtm_dir=dataset_params["mdtm_dir"],
-                       mode=dataset_param["mode"],
+                       mode=dataset_params["mode"],
                        segment_list=segment_list_train,
                        speaker_dict=speaker_dict,
-                       duration=dataset_param["train"]["duration"],
-                       filter_type=dataset_param["filter_type"],
-                       collar_duration=dataset_param["collar_duration"],
-                       audio_framerate=dataset_param["sample_rate"],
-                       output_framerate=dataset_param["output_rate"],
-                       transform_pipeline=dataset_param["train"]["transformation"]["pipeline"])
+                       duration=dataset_params["train"]["duration"],
+                       filter_type=dataset_params["filter_type"],
+                       collar_duration=dataset_params["collar_duration"],
+                       audio_framerate=dataset_params["sample_rate"],
+                       output_framerate=dataset_params["output_rate"],
+                       transform_pipeline=dataset_params["train"]["transformation"]["pipeline"])

     validation_set = SeqSet(wav_dir=dataset_params["wav_dir"],
                             mdtm_dir=dataset_params["mdtm_dir"],
-                            mode=dataset_param["mode"],
+                            mode=dataset_params["mode"],
                             segment_list=segment_list_val,
                             speaker_dict=speaker_dict,
-                            duration=dataset_param["eval"]["duration"],
-                            filter_type=dataset_param["filter_type"],
-                            collar_duration=dataset_param["collar_duration"],
-                            audio_framerate=dataset_param["sample_rate"],
-                            output_framerate=dataset_param["output_rate"],
-                            transform_pipeline=dataset_param["eval"]["transformation"]["pipeline"])
+                            duration=dataset_params["eval"]["duration"],
+                            filter_type=dataset_params["filter_type"],
+                            collar_duration=dataset_params["collar_duration"],
+                            audio_framerate=dataset_params["sample_rate"],
+                            output_framerate=dataset_params["output_rate"],
+                            transform_pipeline=dataset_params["eval"]["transformation"]["pipeline"])
     return train_set, validation_set
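Most of this hunk fixes a `dataset_param` / `dataset_params` typo; the remaining change reads the duration from the nested `train` section. The key accesses in the diff imply a configuration layout along these lines (a sketch reconstructed from those accesses, with hypothetical values, not the actual YAML file):

dataset_params = {
    "wav_dir": "...",         # paths elided
    "mdtm_dir": "...",
    "mode": "vad",
    "filter_type": "gate",    # hypothetical value
    "collar_duration": 0.1,
    "sample_rate": 16000,
    "output_rate": 100,
    "train": {"duration": 2.0,
              "transformation": {"pipeline": ""}},
    "eval": {"duration": 2.0,
             "transformation": {"pipeline": ""}},
}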