Commit e164e364 authored by Martin Lebourdais
Browse files

Adding new loss

parent f9db7c36
import torch
import torch.nn as nn
class ConcordanceCorCoeff(nn.Module):
    """Lin's Concordance Correlation Coefficient (CCC) as a loss.

    Returns ``1 - CCC`` so that perfect agreement between ``prediction``
    and ``ground_truth`` gives a loss of 0 and disagreement increases the
    loss.  All statistics are reduced along dim 0 (the sample/sequence
    axis); for 1-D inputs the result is a scalar tensor.
    """

    def __init__(self):
        super(ConcordanceCorCoeff, self).__init__()
        # Reduction ops bound as attributes — kept for backward
        # compatibility with the original interface.
        self.mean = torch.mean
        self.var = torch.var
        self.sum = torch.sum
        self.sqrt = torch.sqrt
        self.std = torch.std

    def forward(self, prediction, ground_truth):
        """Compute ``1 - CCC`` between ``prediction`` and ``ground_truth``.

        :param prediction: predicted values, sample axis first
        :param ground_truth: reference values, same shape as ``prediction``
        :return: tensor ``1 - CCC`` (scalar for 1-D inputs)
        """
        mean_gt = self.mean(ground_truth, 0)
        mean_pred = self.mean(prediction, 0)
        var_gt = self.var(ground_truth, 0)
        var_pred = self.var(prediction, 0)
        v_pred = prediction - mean_pred
        v_gt = ground_truth - mean_gt
        # Pearson correlation between the centred signals.
        cor = self.sum(v_pred * v_gt) / (
            self.sqrt(self.sum(v_pred ** 2)) * self.sqrt(self.sum(v_gt ** 2))
        )
        # Fix: reduce the standard deviations along dim 0 as well, so they
        # are consistent with var_gt/var_pred above.  The original reduced
        # std over the flattened tensor, which disagrees with the dim-0
        # variances for >1-D inputs; for 1-D inputs the value is unchanged.
        sd_gt = self.std(ground_truth, 0)
        sd_pred = self.std(prediction, 0)
        numerator = 2 * cor * sd_gt * sd_pred
        denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
        ccc = numerator / denominator
        return 1 - ccc
......@@ -39,6 +39,7 @@ from sklearn.model_selection import train_test_split
from torch import optim
from torch.utils.data import Dataset
from .loss import ConcordanceCorCoeff
from .wavsets import SeqSet
from sidekit.nnet.sincnet import SincNet
from torch.utils.data import DataLoader
......@@ -224,8 +225,6 @@ class SeqToSeq(nn.Module):
def seqTrain(dataset_yaml,
val_dataset_yaml,
norm_dataset_yaml,
over_dataset_yaml,
model_yaml,
mode,
epochs=100,
......@@ -242,7 +241,9 @@ def seqTrain(dataset_yaml,
output_rate=100,
batch_size=32,
log_interval=10,
num_thread=10
num_thread=10,
non_overlap_dataset = None,
overlap_dataset = None
):
"""
......@@ -291,9 +292,8 @@ def seqTrain(dataset_yaml,
training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
_wav_dir=dataset_params['wav_dir']
_mdtm_dir=dataset_params['mdtm_dir']
torch.manual_seed(dataset_params['seed'])
training_set_norm = SeqSet(norm_dataset_yaml,
training_set = SeqSet(dataset_yaml,
wav_dir=_wav_dir,
mdtm_dir=_mdtm_dir,
mode=mode,
......@@ -303,17 +303,7 @@ def seqTrain(dataset_yaml,
audio_framerate=framerate,
output_framerate=output_rate,
transform_pipeline="MFCC")
training_set_overlap = SeqSet(over_dataset_yaml,
wav_dir=_wav_dir,
mdtm_dir=_mdtm_dir,
mode=mode,
duration=2.,
filter_type="gate",
collar_duration=0.1,
audio_framerate=framerate,
output_framerate=output_rate,
transform_pipeline="add_noise,MFCC")
training_set = torch.utils.data.ConcatDataset([training_set_norm,training_set_overlap])
training_loader = DataLoader(training_set,
batch_size=dataset_params["batch_size"],
drop_last=True,
......@@ -441,7 +431,8 @@ def calc_recall(output,target,device):
y_predb = output
rc = 0.0
pr = 0.0
for b in range(64):
batch_size = y_trueb.shape[1]
for b in range(batch_size):
y_true = y_trueb[:,b]
y_pred = y_predb[:,:,b]
assert y_true.ndim == 1
......@@ -477,12 +468,16 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
"""
model.to(device)
model.train()
criterion = torch.nn.CrossEntropyLoss(reduction='mean')
#criterion = torch.nn.CrossEntropyLoss(reduction='mean')
criterion = ConcordanceCorCoeff()
recall = 0.0
accuracy = 0.0
for batch_idx, (data, target) in enumerate(training_loader):
target = target.squeeze()
# tnumpy = target.numpy()
# print(tnumpy.shape)
# print(sum(tnumpy)/(tnumpy.shape[1]))
optimizer.zero_grad()
data = data.to(device)
......
......@@ -135,10 +135,10 @@ def mdtm_to_label(mdtm_filename,
for ii,i in enumerate(numpy.linspace(start_time,stop_time,num=sample_number)):
cnt = 0
for d in diarization.segments:
# print(d)
# print(d['start'],d['stop'])
if d['start']/100 <= i <= d['stop']/100:
cnt+=1
if over and cnt==0:
cnt=2
overlaps[ii]=cnt
# Find the label of the
# first sample
......@@ -259,7 +259,7 @@ def process_segment_label(label,
output_label = numpy.convolve(conv_filt, spk_change, mode='same')
elif mode == "overlap":
output_label = (overlaps>0.5).astype(numpy.long)
output_label = (overlaps>1).astype(numpy.long)
else:
raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment