Commit 13062045 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

arcface

parent aa29bc58
......@@ -51,7 +51,9 @@ from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from .sincnet import SincNet
#from torch.utils.tensorboard import SummaryWriter
from .loss import ArcLinear, ArcFace
from .loss import ArcLinear
from .loss import ArcFace
from .loss import l2_norm
import tqdm
......@@ -550,8 +552,9 @@ class Xtractor(torch.nn.Module):
if self.norm_embedding:
#x_norm = x.norm(p=2,dim=1, keepdim=True) / 10. # Why 10. ?
x_norm = torch.linalg.norm(x, ord=2, dim=1, keepdim=True, out=None, dtype=None)
x = torch.div(x, x_norm)
#x_norm = torch.linalg.norm(x, ord=2, dim=1, keepdim=True, out=None, dtype=None)
#x = torch.div(x, x_norm)
x = l2_norm(x)
if is_eval:
return x
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment