Commit dce620d5 authored by Pierre Champion's avatar Pierre Champion
Browse files

pytorch 1.8 compatible

parent cc7d1469
......@@ -66,7 +66,7 @@ class CCELoss(torch.nn.Module):
def forward(self, embbedings, target):
x = self.module(embbedings)
if target == None:
return torch.tensor(torch.nan), x
return torch.tensor(float('nan')), x
loss = self.criterion(x, target)
return loss, x
......@@ -127,7 +127,7 @@ class ArcMarginProduct(torch.nn.Module):
cosine = torch.nn.functional.linear(torch.nn.functional.normalize(input),
torch.nn.functional.normalize(self.weight))
if target == None:
return torch.tensor(torch.nan), cosine * self.s
return torch.tensor(float('nan')), cosine * self.s
# cos(theta + m)
sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
phi = cosine * self.cos_m - sine * self.sin_m
......@@ -176,7 +176,7 @@ class SoftmaxAngularProto(torch.nn.Module):
cce_prediction = self.cce_backend(x)
if target is None:
return torch.tensor(torch.nan), cce_prediction
return torch.tensor(float('nan')), cce_prediction
x = x.reshape(-1, 2, x.size()[-1]).squeeze(1)
......@@ -226,7 +226,7 @@ class AngularProximityMagnet(torch.nn.Module):
cce_prediction = self.cce_backend(x)
if target is None:
return torch.tensor(torch.nan), cce_prediction
return torch.tensor(float('nan')), cce_prediction
x = x.reshape(-1, 2, x.size()[-1]).squeeze(1)
out_anchor = torch.mean(x[:, 1:, :], 1)
......@@ -279,7 +279,7 @@ class CircleMargin(torch.nn.Module):
cosine = cosine.reshape(cosine.shape[0], -1, self.k).max(-1)[0]
if target is None:
return torch.tensor(torch.nan), cosine * self.gamma
return torch.tensor(float('nan')), cosine * self.gamma
one_hot = torch.zeros_like(cosine)
one_hot.scatter_(1, target.view(-1, 1), 1)
......@@ -329,7 +329,7 @@ class CircleProto(torch.nn.Module):
torch.nn.functional.normalize(self.weight))
if target == None:
return torch.tensor(torch.nan), cosine * self.gamma
return torch.tensor(float('nan')), cosine * self.gamma
one_hot = torch.zeros_like(cosine)
one_hot.scatter_(1, target.view(-1, 1), 1)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment