Commit 0f3b874a authored by Anthony Larcher's avatar Anthony Larcher
Browse files

modification of amm

parent 31911a3e
...@@ -250,6 +250,18 @@ class ArcMarginProduct(torch.nn.Module): ...@@ -250,6 +250,18 @@ class ArcMarginProduct(torch.nn.Module):
self.th = math.cos(math.pi - self.m) self.th = math.cos(math.pi - self.m)
self.mm = math.sin(math.pi - self.m) * self.m self.mm = math.sin(math.pi - self.m) * self.m
def __change_params(self, s=None, m=None):
if s is None:
s = self.s
if m is None:
m = self.m
self.s = s
self.m = m
self.cos_m = math.cos(self.m)
self.sin_m = math.sin(self.m)
self.th = math.cos(math.pi - self.m)
self.mm = math.sin(math.pi - self.m) * self.m
def forward(self, input, target=None): def forward(self, input, target=None):
# cos(theta) # cos(theta)
cosine = torch.nn.functional.linear(torch.nn.functional.normalize(input), torch.nn.functional.normalize(self.weight)) cosine = torch.nn.functional.linear(torch.nn.functional.normalize(input), torch.nn.functional.normalize(self.weight))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment