Commit e164e364 authored by Martin Lebourdais

Adding new loss

parent f9db7c36
import torch
import torch.nn as nn


class ConcordanceCorCoeff(nn.Module):
    """Loss based on the Concordance Correlation Coefficient (CCC).

    CCC = 2 * cor * sd_gt * sd_pred / (var_gt + var_pred + (mean_gt - mean_pred)**2)

    The forward pass returns 1 - CCC, so perfect agreement yields a loss of 0.
    """

    def __init__(self):
        super(ConcordanceCorCoeff, self).__init__()
        self.mean = torch.mean
        self.var = torch.var
        self.sum = torch.sum
        self.sqrt = torch.sqrt
        self.std = torch.std

    def forward(self, prediction, ground_truth):
        # Statistics along the first axis
        mean_gt = self.mean(ground_truth, 0)
        mean_pred = self.mean(prediction, 0)
        var_gt = self.var(ground_truth, 0)
        var_pred = self.var(prediction, 0)
        # Pearson correlation between prediction and ground truth
        v_pred = prediction - mean_pred
        v_gt = ground_truth - mean_gt
        cor = self.sum(v_pred * v_gt) / (self.sqrt(self.sum(v_pred ** 2)) * self.sqrt(self.sum(v_gt ** 2)))
        sd_gt = self.std(ground_truth)
        sd_pred = self.std(prediction)
        numerator = 2 * cor * sd_gt * sd_pred
        denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
        ccc = numerator / denominator
        return 1 - ccc
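For context, the Concordance Correlation Coefficient combines Pearson correlation with a penalty on mean and variance mismatch, so minimising 1 - CCC pushes predictions towards the ground truth in scale as well as shape. A minimal standalone check of the new criterion (the 1-D float tensors and the import path below are illustrative assumptions, not taken from the trainer):

import torch
# adjust the import to wherever loss.py lives in your checkout
from loss import ConcordanceCorCoeff

criterion = ConcordanceCorCoeff()
ground_truth = torch.tensor([0., 1., 1., 0., 1.])
prediction = torch.tensor([0.1, 0.8, 0.9, 0.2, 0.7])

loss = criterion(prediction, ground_truth)       # small when prediction tracks the target
print(loss.item())

worse = criterion(torch.rand(5), ground_truth)   # uncorrelated predictions typically give a larger loss
print(worse.item())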
@@ -39,6 +39,7 @@ from sklearn.model_selection import train_test_split
 from torch import optim
 from torch.utils.data import Dataset
+from .loss import ConcordanceCorCoeff
 from .wavsets import SeqSet
 from sidekit.nnet.sincnet import SincNet
 from torch.utils.data import DataLoader
@@ -224,8 +225,6 @@ class SeqToSeq(nn.Module):
 def seqTrain(dataset_yaml,
              val_dataset_yaml,
-             norm_dataset_yaml,
-             over_dataset_yaml,
              model_yaml,
              mode,
              epochs=100,
@@ -242,7 +241,9 @@ def seqTrain(dataset_yaml,
              output_rate=100,
              batch_size=32,
              log_interval=10,
-             num_thread=10
+             num_thread=10,
+             non_overlap_dataset = None,
+             overlap_dataset = None
              ):
     """
@@ -291,29 +292,18 @@ def seqTrain(dataset_yaml,
     training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
     _wav_dir=dataset_params['wav_dir']
     _mdtm_dir=dataset_params['mdtm_dir']
     torch.manual_seed(dataset_params['seed'])
-    training_set_norm = SeqSet(norm_dataset_yaml,
+    training_set = SeqSet(dataset_yaml,
                           wav_dir=_wav_dir,
                           mdtm_dir=_mdtm_dir,
                           mode=mode,
                           duration=2.,
                           filter_type="gate",
                           collar_duration=0.1,
                           audio_framerate=framerate,
                           output_framerate=output_rate,
                           transform_pipeline="MFCC")
-    training_set_overlap = SeqSet(over_dataset_yaml,
-                          wav_dir=_wav_dir,
-                          mdtm_dir=_mdtm_dir,
-                          mode=mode,
-                          duration=2.,
-                          filter_type="gate",
-                          collar_duration=0.1,
-                          audio_framerate=framerate,
-                          output_framerate=output_rate,
-                          transform_pipeline="add_noise,MFCC")
-    training_set = torch.utils.data.ConcatDataset([training_set_norm,training_set_overlap])
     training_loader = DataLoader(training_set,
                                  batch_size=dataset_params["batch_size"],
                                  drop_last=True,
@@ -441,7 +431,8 @@ def calc_recall(output,target,device):
     y_predb = output
     rc = 0.0
     pr = 0.0
-    for b in range(64):
+    batch_size = y_trueb.shape[1]
+    for b in range(batch_size):
         y_true = y_trueb[:,b]
         y_pred = y_predb[:,:,b]
         assert y_true.ndim == 1
@@ -477,12 +468,16 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
     """
     model.to(device)
     model.train()
-    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
+    #criterion = torch.nn.CrossEntropyLoss(reduction='mean')
+    criterion = ConcordanceCorCoeff()
     recall = 0.0
     accuracy = 0.0
     for batch_idx, (data, target) in enumerate(training_loader):
         target = target.squeeze()
+        # tnumpy = target.numpy()
+        # print(tnumpy.shape)
+        # print(sum(tnumpy)/(tnumpy.shape[1]))
         optimizer.zero_grad()
         data = data.to(device)
...
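Deriving the loop bound in calc_recall from y_trueb.shape[1] instead of a hard-coded 64 lets the metric cope with whatever batch size the loader delivers, including a shorter final batch. A toy illustration of the layout this slicing assumes (time on axis 0, batch last; the shapes are assumptions inferred from the indexing above, not taken from the trainer):

import torch

frames, classes, batch = 200, 2, 7            # deliberately not 64
y_trueb = torch.randint(0, classes, (frames, batch))
y_predb = torch.rand(frames, classes, batch)

batch_size = y_trueb.shape[1]                 # whatever the loader actually delivered
for b in range(batch_size):
    y_true = y_trueb[:, b]                    # (frames,) labels of one sequence
    y_pred = y_predb[:, :, b]                 # (frames, classes) scores of that sequence
    assert y_true.ndim == 1 and y_pred.shape == (frames, classes)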
@@ -135,10 +135,10 @@ def mdtm_to_label(mdtm_filename,
     for ii,i in enumerate(numpy.linspace(start_time,stop_time,num=sample_number)):
         cnt = 0
         for d in diarization.segments:
+            # print(d)
+            # print(d['start'],d['stop'])
             if d['start']/100 <= i <= d['stop']/100:
                 cnt+=1
-        if over and cnt==0:
-            cnt=2
         overlaps[ii]=cnt
     # Find the label of the
     # first sample
@@ -259,7 +259,7 @@ def process_segment_label(label,
         output_label = numpy.convolve(conv_filt, spk_change, mode='same')
     elif mode == "overlap":
-        output_label = (overlaps>0.5).astype(numpy.long)
+        output_label = (overlaps>1).astype(numpy.long)
     else:
         raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")
...
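With the synthetic-overlap path gone, overlaps[ii] is simply the number of diarization segments covering each sample, so the label threshold moves from >0.5 (any speech at all) to >1 (at least two simultaneous speakers). A small numpy illustration of the difference (the per-sample counts below are made up):

import numpy

overlaps = numpy.array([0, 1, 1, 2, 3, 1, 0])   # hypothetical speaker counts per sample
vad_like = (overlaps > 0.5).astype(int)         # old threshold -> [0 1 1 1 1 1 0] (speech activity)
overlap = (overlaps > 1).astype(int)            # new threshold -> [0 0 0 1 1 0 0] (overlapped speech only)
print(vad_like, overlap)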