Commit aa304dee authored by Anthony Larcher's avatar Anthony Larcher
Browse files
parents f41e9de3 a10aae62
......@@ -93,8 +93,10 @@ class ArcLinear(torch.nn.Module):
fX : `torch.Tensor`
logits after the angular margin transformation
"""
# normalize the feature vectors and W
xnorm = torch.nn.functional.normalize(x)
# the feature vectors has been normalized before calling this layer
#xnorm = torch.nn.functional.normalize(x)
xnorm = x
# normalize W
Wnorm = torch.nn.functional.normalize(self.W)
target = target.long().view(-1, 1)
# calculate cosθj (the logits)
......
......@@ -518,6 +518,7 @@ class Xtractor(torch.nn.Module):
self.after_speaker_embedding = torch.nn.Sequential(OrderedDict(after_embedding_layers))
elif self.loss == "aam":
self.norm_embedding = True
self.after_speaker_embedding = ArcLinear(input_size,
self.speaker_number,
margin=self.aam_margin,
......@@ -544,7 +545,7 @@ class Xtractor(torch.nn.Module):
x = self.before_speaker_embedding(x)
if self.norm_embedding:
x_norm = x.norm(p=2,dim=1, keepdim=True) / 10.
x_norm = x.norm(p=2,dim=1, keepdim=True) / 10. # Why 10. ?
x = torch.div(x, x_norm)
if is_eval:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment