loss.py 1.12 KB
Newer Older
Martin Lebourdais's avatar
Martin Lebourdais committed
1
2
import torch
import torch.nn as nn
3

Martin Lebourdais's avatar
Martin Lebourdais committed
4
5
class ConcordanceCorCoeff(nn.Module):
    """Loss based on Lin's concordance correlation coefficient (CCC).

    ``forward`` returns ``mean(1 - CCC)``, where, with all statistics taken
    along dim 0 using the unbiased (N - 1) estimator (``torch.var``'s
    default)::

        CCC = 2 * cov(pred, gt) / (var(pred) + var(gt) + (mean(pred) - mean(gt)) ** 2)

    CCC is 1 for perfect agreement, so the returned value is 0 in that case.
    """

    def __init__(self):
        super(ConcordanceCorCoeff, self).__init__()
        # Aliases to torch reductions; kept as attributes for backward
        # compatibility with any code that accesses them on the module.
        self.mean = torch.mean
        self.var = torch.var
        self.sum = torch.sum
        self.sqrt = torch.sqrt
        self.std = torch.std
        # NOTE(review): not used by forward(); possibly intended for a soft
        # (expected-class) variant of the prediction. Kept for compatibility.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, prediction, ground_truth):
        """Compute ``mean(1 - CCC)`` between predicted classes and targets.

        Args:
            prediction: Score tensor with classes along dim 1. It is reduced
                to hard class indices via ``argmax(abs(prediction), dim=1)``.
            ground_truth: Target tensor compared element-wise with those
                indices (presumably integer class labels with the same shape
                as the argmax result; TODO confirm with callers). It is cast
                to float.

        Returns:
            A scalar tensor: ``1 - CCC``, averaged over any dims left after the
            dim-0 reduction. At least two samples along dim 0 are required;
            with one sample the unbiased statistics are NaN.
        """
        # NOTE(review): argmax yields integer indices with no autograd history,
        # so the returned value carries no gradient w.r.t. `prediction`.
        # Selecting by absolute value also lets a large negative score win;
        # confirm both are intended.
        prediction = torch.argmax(torch.abs(prediction), dim=1).float()
        ground_truth = ground_truth.float()

        n_samples = prediction.shape[0]
        mean_gt = self.mean(ground_truth, 0)
        mean_pred = self.mean(prediction, 0)
        var_gt = self.var(ground_truth, 0)
        var_pred = self.var(prediction, 0)

        # Unbiased covariance along dim 0. This equals pearson_r * sd_gt * sd_pred
        # but stays finite (0) when either series is constant, where the
        # correlation itself would be 0/0 = NaN.
        cov = self.sum((prediction - mean_pred) * (ground_truth - mean_gt), 0) / (n_samples - 1)

        numerator = 2 * cov
        denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
        # The denominator is exactly 0 only if both series are constant with
        # equal means, i.e. identical: perfect concordance, CCC = 1.
        ccc = torch.where(
            denominator == 0,
            torch.ones_like(denominator),
            numerator / denominator,
        )

        return torch.mean(1 - ccc)