Commit 575a86bc authored by Anthony Larcher's avatar Anthony Larcher
Browse files

doc

parent dc71e2b2
......@@ -251,6 +251,11 @@ class ArcMarginProduct(torch.nn.Module):
self.mm = math.sin(math.pi - self.m) * self.m
def change_params(self, s=None, m=None):
"""
:param s:
:param m:
"""
if s is None:
s = self.s
if m is None:
......@@ -263,8 +268,15 @@ class ArcMarginProduct(torch.nn.Module):
self.mm = math.sin(math.pi - self.m) * self.m
def forward(self, input, target=None):
"""
:param input:
:param target:
:return:
"""
# cos(theta)
cosine = torch.nn.functional.linear(torch.nn.functional.normalize(input), torch.nn.functional.normalize(self.weight))
cosine = torch.nn.functional.linear(torch.nn.functional.normalize(input),
torch.nn.functional.normalize(self.weight))
if target == None:
return cosine * self.s
# cos(theta + m)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment