Commit 57295377 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

arcface

parent 98fdd2d7
......@@ -106,7 +106,7 @@ class Arcface(torch.nn.Module):
self.mm = self.sin_m * m # issue 1
self.threshold = math.cos(math.pi - m)
def forward(self, embbedings, label):
def forward(self, embbedings, target):
# weights norm
nB = len(embbedings)
kernel_norm = l2_norm(self.kernel, axis=0)
......@@ -127,7 +127,7 @@ class Arcface(torch.nn.Module):
cos_theta_m[cond_mask] = keep_val[cond_mask]
output = cos_theta * 1.0 # a little bit hacky way to prevent in_place operation on cos_theta
idx_ = torch.arange(0, nB, dtype=torch.long)
output[idx_, label] = cos_theta_m[idx_, label]
output[idx_, target] = cos_theta_m[idx_, target]
output *= self.s # scale up in order to make softmax work, first introduced in normface
return output
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment