Commit 26c1887e authored by Anthony Larcher's avatar Anthony Larcher
Browse files

cross entrpoy weighin

parent 30c6fb21
......@@ -297,7 +297,8 @@ class SeqToSeq(torch.nn.Module):
device,
shift, # former step
output_rate,
threshold, #output_offset=None,
th_in, #output_offset=None,
th_out,
only_lab_generation=False):
"""
A MODIFIER POU NE PRENDRE QUE LE NOM DU FICHIER WAV ET CRÉER LE DATZA LOADER À L'INTERIEUR
......@@ -338,8 +339,16 @@ class SeqToSeq(torch.nn.Module):
output_rate,
mode="mean")
numpy.save("toto.npy",final_output)
vad = numpy.zeros(final_output.shape[0], dtype='bool')
speech = False
ii = 0
while ii < final_output.shape[0]:
if final_output[ii, 1] > th_in and not speech:
speech = True
elif final_output[ii, 1] < th_out and speech:
speech = False
vad[ii] = speech
ii += 1
#if not only_lab_generation:
# final_target = _unfold(final_target, target.cpu().numpy(), shift)
......@@ -366,6 +375,7 @@ class SeqToSeq(torch.nn.Module):
# return 100.0 * accuracy, 100.0 * recall, 100.0 * precision, sad_seq, tp, fp, tn, fn
#else:
# return 0, 0, 0, sad_seq, 0, 0, 0, 0
return vad
def seqTrain(dataset_yaml,
model_yaml,
......@@ -539,7 +549,7 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
"""
model.to(device)
model.train()
criterion = torch.nn.CrossEntropyLoss(reduction='mean', weight=torch.FloatTensor([0.1,0.9]).to(device))
criterion = torch.nn.CrossEntropyLoss(reduction='mean', weight=torch.FloatTensor([0.9,0.1]).to(device))
recall = 0.0
precision = 0.0
......
......@@ -687,7 +687,6 @@ class SeqSetSingle(Dataset):
self.wav_fn = wav_fn
self.mdtm_fn = mdtm_fn
self.uem_fn=uem_fn
self.mode = mode
self.audio_framerate = audio_framerate
self.output_framerate = output_framerate
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment