Commit 4de1fd1a authored by Martin Lebourdais

working with 0.67 f-score on dihard_dev_01

parent e164e364
 import torch
 import torch.nn as nn
 class ConcordanceCorCoeff(nn.Module):
     def __init__(self):
         super(ConcordanceCorCoeff, self).__init__()
         self.mean = torch.mean
         self.var = torch.var
         self.sum = torch.sum
         self.sqrt = torch.sqrt
         self.std = torch.std
+        self.softmax = nn.Softmax(dim=1)
     def forward(self, prediction, ground_truth):
+        temp1 = torch.abs(prediction)
+        prediction = torch.argmax(temp1, dim=1).float()
+        ground_truth = ground_truth.float()
         mean_gt = self.mean(ground_truth, 0)
         mean_pred = self.mean(prediction, 0)
         var_gt = self.var(ground_truth, 0)
         var_pred = self.var(prediction, 0)
         v_pred = prediction - mean_pred
         v_gt = ground_truth - mean_gt
         cor = self.sum(v_pred * v_gt) / (self.sqrt(self.sum(v_pred ** 2)) * self.sqrt(self.sum(v_gt ** 2)))
         sd_gt = self.std(ground_truth)
         sd_pred = self.std(prediction)
         numerator = 2 * cor * sd_gt * sd_pred
         denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
         ccc = numerator / denominator
-        return 1 - ccc
+        return torch.mean(1 - ccc)
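As a quick sanity check, the updated loss can be exercised on dummy tensors. The snippet below is a minimal sketch and not part of the commit: the tensor shapes, the binary labels and the variable names are assumptions, chosen only so that the argmax over dim=1 in forward() makes sense. Note that torch.argmax carries no gradient, so this loss cannot backpropagate through the prediction logits.

# Minimal usage sketch (assumed shapes, not taken from the commit):
# (num_frames, num_classes) logits against (num_frames,) binary labels.
import torch

criterion = ConcordanceCorCoeff()        # the class defined above
logits = torch.randn(200, 2)             # per-frame class scores, dim=1 is the class axis
labels = torch.randint(0, 2, (200,))     # e.g. overlap / no-overlap frame labels
loss = criterion(logits, labels)         # argmax over dim=1, then mean(1 - CCC)
print(loss.item())                       # in [0, 2]; 0 would mean perfect concordance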
@@ -468,8 +468,8 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
     """
     model.to(device)
     model.train()
-    #criterion = torch.nn.CrossEntropyLoss(reduction='mean')
-    criterion = ConcordanceCorCoeff()
+    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
+    #criterion = ccc_loss
     recall = 0.0
     accuracy = 0.0
     for batch_idx, (data, target) in enumerate(training_loader):
@@ -505,8 +505,62 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
         ))
     return model
+def pearsonr(x, y):
+    mean_x = torch.mean(x)
+    mean_y = torch.mean(y)
+    xm = x.sub(mean_x)
+    ym = y.sub(mean_y)
+    r_num = xm.dot(ym)
+    r_den = torch.norm(xm, 2) * torch.norm(ym, 2)
+    r_val = r_num / r_den
+    return r_val
+def ccc_loss(pred, gt):
+    batch_size = 64
+    temp1 = torch.argmax(torch.abs(pred), dim=1)
+    acc_ccc = 0
+    for n in range(batch_size):
+        curr_pred = temp1[:, n].float()
+        curr_gt = gt[:, n].float()
+        true_m = torch.mean(curr_gt)
+        true_var = torch.var(curr_gt)
+        pred_m = torch.mean(curr_pred)
+        pred_var = torch.var(curr_pred)
+        rho = pearsonr(curr_pred, curr_gt)
+        std_pred = torch.std(curr_pred)
+        std_true = torch.std(curr_gt)
+        ccc = (
+            2
+            * rho
+            * std_true
+            * std_pred
+            / (std_pred ** 2 + std_true ** 2 + (pred_m - true_m) ** 2)
+        )
+        acc_ccc += (1 - ccc)
+    return (acc_ccc / batch_size)
+'''
+def llincc(x, y):
+    true_m = np.mean(y)
+    true_var = np.var(y)
+    pred_m = np.mean(x)
+    pred_var = np.var(x)
+    rho, _ = pearsonr(x, y)
+    std_pred = np.std(x)
+    std_true = np.std(y)
+    ccc = (
+        2
+        * rho
+        * std_true
+        * std_pred
+        / (std_pred ** 2 + std_true ** 2 + (pred_m - true_m) ** 2)
+    )
+    return ccc
+'''
 def cross_validation(model, validation_loader, device):
     """
 ...
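The added ccc_loss mirrors the module above, but scores each sequence of the batch separately with the pearsonr helper and averages 1 - CCC over a hard-coded batch_size of 64. The call below is a minimal sketch; the (frames, classes, batch) layout is an assumption inferred from the dim=1 argmax and the [:, n] indexing, not something stated in the commit.

# Sketch only: assumed layout (frames, classes, batch) vs. (frames, batch),
# with pearsonr and ccc_loss from the hunk above in scope.
import torch

pred = torch.randn(200, 2, 64)        # per-frame class scores for 64 sequences
gt = torch.randint(0, 2, (200, 64))   # binary frame labels, one column per sequence
loss = ccc_loss(pred, gt)             # loops over the 64 columns, averages 1 - CCC
print(loss.item())

Because batch_size is hard-coded, a final batch with fewer than 64 sequences would make the [:, n] indexing fail.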
@@ -207,7 +207,7 @@ def get_segment_label(label,
         output_label = numpy.convolve(conv_filt, spk_change, mode='same')
     elif mode == "overlap":
-        output_label = (overlaps > 0.5).astype(numpy.long)
+        output_label = (overlaps > 1).astype(numpy.long)
     else:
         raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")
 ...
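The overlap branch now labels a frame as positive only when overlaps exceeds 1 instead of 0.5. Assuming overlaps counts the simultaneously active speakers per frame (an interpretation, not stated in the diff), the sketch below illustrates the effect of the stricter threshold; it casts with int rather than the numpy.long alias used in the repository, since some NumPy versions no longer provide it.

# Illustrative only: hypothetical per-frame speaker counts.
import numpy

overlaps = numpy.array([0, 1, 1, 2, 3, 1, 0])
old_label = (overlaps > 0.5).astype(int)   # old threshold: any active speech is flagged
new_label = (overlaps > 1).astype(int)     # new threshold: only frames with more than one speaker
print(old_label)                           # [0 1 1 1 1 1 0]
print(new_label)                           # [0 0 0 1 1 0 0]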