Commit 575a86bc authored by Anthony Larcher's avatar Anthony Larcher
Browse files

doc

parent dc71e2b2
...@@ -251,6 +251,11 @@ class ArcMarginProduct(torch.nn.Module): ...@@ -251,6 +251,11 @@ class ArcMarginProduct(torch.nn.Module):
self.mm = math.sin(math.pi - self.m) * self.m self.mm = math.sin(math.pi - self.m) * self.m
def change_params(self, s=None, m=None): def change_params(self, s=None, m=None):
"""
:param s:
:param m:
"""
if s is None: if s is None:
s = self.s s = self.s
if m is None: if m is None:
...@@ -263,8 +268,15 @@ class ArcMarginProduct(torch.nn.Module): ...@@ -263,8 +268,15 @@ class ArcMarginProduct(torch.nn.Module):
self.mm = math.sin(math.pi - self.m) * self.m self.mm = math.sin(math.pi - self.m) * self.m
def forward(self, input, target=None): def forward(self, input, target=None):
"""
:param input:
:param target:
:return:
"""
# cos(theta) # cos(theta)
cosine = torch.nn.functional.linear(torch.nn.functional.normalize(input), torch.nn.functional.normalize(self.weight)) cosine = torch.nn.functional.linear(torch.nn.functional.normalize(input),
torch.nn.functional.normalize(self.weight))
if target == None: if target == None:
return cosine * self.s return cosine * self.s
# cos(theta + m) # cos(theta + m)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment