loss.py 863 Bytes
Newer Older
Martin Lebourdais's avatar
Martin Lebourdais committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
import torch.nn as nn
class ConcordanceCorCoeff(nn.Module):
    """Loss equal to ``1 - CCC``, with CCC Lin's concordance correlation coefficient.

    CCC = 2 * cov(pred, gt) / (var(pred) + var(gt) + (mean(pred) - mean(gt)) ** 2)

    CCC is 1 for perfect agreement, 0 for no linear association and -1 for
    perfect inverse agreement, so the returned loss lies in [0, 2].

    All statistics are reduced along dimension 0 (the sample axis): a 1-D input
    yields a 0-dim tensor, an ``(N, D)`` input yields one loss per column,
    shape ``(D,)``. Variance and covariance both use the unbiased
    ``1 / (N - 1)`` normalisation (``torch.var``'s default); Lin's original
    definition uses ``1 / N``, which differs noticeably only for small ``N``.
    The result is NaN when ``N < 2`` or when both inputs are constant with
    equal means (the denominator is then 0).
    """

    def __init__(self):
        super().__init__()
        # Reduction helpers kept as attributes for backward compatibility with
        # any external code that references them.
        self.mean = torch.mean
        self.var = torch.var
        self.sum = torch.sum
        self.sqrt = torch.sqrt
        self.std = torch.std

    def forward(self, prediction, ground_truth):
        """Return ``1 - CCC`` between ``prediction`` and ``ground_truth``.

        Args:
            prediction: model outputs. Assumed to have the same shape as
                ``ground_truth`` with samples along dim 0 (TODO confirm with
                callers; mismatched shapes would silently broadcast).
            ground_truth: reference values.

        Returns:
            Tensor holding ``1 - CCC`` reduced over dim 0.
        """
        mean_gt = self.mean(ground_truth, 0)
        mean_pred = self.mean(prediction, 0)
        var_gt = self.var(ground_truth, 0)
        var_pred = self.var(prediction, 0)
        v_pred = prediction - mean_pred
        v_gt = ground_truth - mean_gt
        # Covariance is computed directly instead of as corr * sd_gt * sd_pred.
        # The two are algebraically equal, but the correlation form divides by
        # sqrt(sum(v_pred ** 2)), which is 0 for a constant prediction and made
        # both the loss and its gradient NaN. Reducing over dim 0 keeps it
        # consistent with the per-column means and variances above, and the
        # 1 / (n - 1) factor matches torch.var's unbiased default.
        n = ground_truth.shape[0]
        cov = self.sum(v_pred * v_gt, 0) / (n - 1)
        numerator = 2 * cov
        denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
        ccc = numerator / denominator
        return 1 - ccc