Commit aa29bc58 authored by Anthony Larcher

arcface

parent 27ad221c
@@ -28,6 +28,7 @@ Copyright 2014-2020 Anthony Larcher
import h5py
import logging
import math
import sys
import numpy
import torch
@@ -37,7 +38,7 @@ from collections import OrderedDict
from .xsets import XvectorMultiDataset, XvectorDataset, StatDataset
from ..bosaris import IdMap
from ..statserver import StatServer
from torch.nn import Parameter
#from .classification import Classification
@@ -51,6 +52,113 @@ __status__ = "Production"
__docformat__ = 'reStructuredText'
class ArcMarginModel(torch.nn.Module):
    """Additive Angular Margin (ArcFace) classification head."""

    def __init__(self, args):
        super(ArcMarginModel, self).__init__()

        self.weight = Parameter(torch.FloatTensor(args.num_classes, args.emb_size))
        torch.nn.init.xavier_uniform_(self.weight)

        self.easy_margin = args.easy_margin
        self.m = args.margin_m
        self.s = args.margin_s

        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        self.th = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

    def forward(self, input, label):
        x = torch.nn.functional.normalize(input)
        W = torch.nn.functional.normalize(self.weight)
        cosine = torch.nn.functional.linear(x, W)
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # apply the margin only on the target-class cosine, then scale by s
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output
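A minimal usage sketch (not part of the commit; the args values, shapes and class counts are illustrative): the scaled logits returned by forward() are meant to be fed to a standard cross-entropy loss.

    # illustrative smoke test, assuming the hypothetical args below
    import argparse
    args = argparse.Namespace(num_classes=10, emb_size=256, easy_margin=False, margin_m=0.5, margin_s=64.0)
    head = ArcMarginModel(args)
    embeddings = torch.randn(8, 256)           # raw (unnormalised) embeddings
    labels = torch.randint(0, 10, (8,))        # target class indices
    logits = head(embeddings, labels)          # margin on the target-class cosine, scaled by s
    loss = torch.nn.functional.cross_entropy(logits, labels)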
def l2_norm(input, axis=1):
    # normalise the input tensor to unit L2 norm along the given axis
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output
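For illustration only (not in the commit): with axis=0 every column is normalised to unit L2 norm, which is how the class weight matrix (kernel) is treated by the heads below.

    # illustrative check of l2_norm
    w = torch.randn(5, 3)
    w_n = l2_norm(w, axis=0)
    print(torch.norm(w_n, dim=0))              # tensor([1., 1., 1.]) up to rounding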
class Arcface(torch.nn.Module):
    # implementation of the additive angular margin loss (ArcFace) from https://arxiv.org/abs/1801.07698
    def __init__(self, embedding_size, classnum, s=64., m=0.5):
        super(Arcface, self).__init__()
        self.classnum = classnum
        self.kernel = Parameter(torch.Tensor(embedding_size, classnum))
        # initialise the kernel with unit-norm columns
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.m = m  # the margin value, default is 0.5
        self.s = s  # scale value, default is 64, see normface https://arxiv.org/abs/1704.06369
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.mm = self.sin_m * m  # issue 1
        self.threshold = math.cos(math.pi - m)

    def forward(self, embeddings, label):
        # weights norm
        nB = len(embeddings)
        kernel_norm = l2_norm(self.kernel, axis=0)
        # cos(theta + m)
        cos_theta = torch.mm(embeddings, kernel_norm)
        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
        cos_theta_2 = torch.pow(cos_theta, 2)
        sin_theta_2 = 1 - cos_theta_2
        sin_theta = torch.sqrt(sin_theta_2)
        cos_theta_m = (cos_theta * self.cos_m - sin_theta * self.sin_m)
        # this condition keeps theta + m in the range [0, pi]:
        # 0 <= theta + m <= pi  <=>  -m <= theta <= pi - m
        cond_v = cos_theta - self.threshold
        cond_mask = cond_v <= 0
        keep_val = (cos_theta - self.mm)  # when theta is not in [0, pi], use the cosface formulation instead
        cos_theta_m[cond_mask] = keep_val[cond_mask]
        output = cos_theta * 1.0  # a slightly hacky way to prevent an in-place operation on cos_theta
        idx_ = torch.arange(0, nB, dtype=torch.long)
        output[idx_, label] = cos_theta_m[idx_, label]
        output *= self.s  # scale up so that softmax works, first introduced in normface
        return output
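A hedged usage sketch (shapes and hyper-parameters are assumptions, not part of the commit): embeddings are expected to be L2-normalised along dim=1 so that the product with the normalised kernel yields cosines in [-1, 1].

    # illustrative usage of Arcface
    head = Arcface(embedding_size=256, classnum=1000, s=64., m=0.5)
    emb = torch.nn.functional.normalize(torch.randn(8, 256), dim=1)   # unit-norm embeddings
    labels = torch.randint(0, 1000, (8,))
    logits = head(emb, labels)                 # (8, 1000) scaled cosines, margin on the target class
    loss = torch.nn.functional.cross_entropy(logits, labels)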
################################## Cosface head #############################################################
class Am_softmax(torch.nn.Module):
    # implementation of the additive margin softmax loss from https://arxiv.org/abs/1801.05599
    def __init__(self, embedding_size=512, classnum=51332):
        super(Am_softmax, self).__init__()
        self.classnum = classnum
        self.kernel = Parameter(torch.Tensor(embedding_size, classnum))
        # initialise the kernel with unit-norm columns
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.m = 0.35  # additive margin recommended by the paper
        self.s = 30.  # see normface https://arxiv.org/abs/1704.06369

    def forward(self, embeddings, label):
        kernel_norm = l2_norm(self.kernel, axis=0)
        cos_theta = torch.mm(embeddings, kernel_norm)
        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
        phi = cos_theta - self.m
        label = label.view(-1, 1)  # size=(B, 1)
        index = cos_theta.data * 0.0  # size=(B, Classnum)
        index.scatter_(1, label.data.view(-1, 1), 1)
        index = index.bool()
        output = cos_theta * 1.0
        output[index] = phi[index]  # only change the target-class outputs
        output *= self.s  # scale up so that softmax works, first introduced in normface
        return output
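For contrast with Arcface above, a small sketch (values are illustrative only): Am_softmax subtracts the margin from the cosine itself (CosFace style) instead of adding it to the angle.

    # illustrative usage of Am_softmax
    head = Am_softmax(embedding_size=256, classnum=1000)
    emb = torch.nn.functional.normalize(torch.randn(4, 256), dim=1)
    labels = torch.randint(0, 1000, (4,))
    logits = head(emb, labels)                 # target-class logit becomes s * (cos(theta) - 0.35)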
class ArcLinear(torch.nn.Module):
"""Additive Angular Margin linear module (ArcFace)
......
@@ -51,7 +51,7 @@ from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from .sincnet import SincNet
#from torch.utils.tensorboard import SummaryWriter
from .loss import ArcLinear
from .loss import ArcLinear, Arcface
import tqdm
@@ -519,10 +519,14 @@ class Xtractor(torch.nn.Module):
        elif self.loss == "aam":
            self.norm_embedding = True
            self.after_speaker_embedding = ArcLinear(input_size,
                                                     self.speaker_number,
                                                     margin=self.aam_margin,
                                                     s=self.aam_s)
            #self.after_speaker_embedding = ArcLinear(input_size,
            #                                         self.speaker_number,
            #                                         margin=self.aam_margin,
            #                                         s=self.aam_s)
            self.after_speaker_embedding = Arcface(embedding_size=input_size,
                                                   classnum=self.speaker_number,
                                                   s=64.,
                                                   m=0.5)

        self.after_speaker_embedding_weight_decay = cfg["after_embedding"]["weight_decay"]
@@ -545,7 +549,8 @@ class Xtractor(torch.nn.Module):
        x = self.before_speaker_embedding(x)

        if self.norm_embedding:
            x_norm = x.norm(p=2, dim=1, keepdim=True) / 10. # Why 10. ?
            #x_norm = x.norm(p=2, dim=1, keepdim=True) / 10. # Why 10. ?
            x_norm = torch.linalg.norm(x, ord=2, dim=1, keepdim=True)
            x = torch.div(x, x_norm)

        if is_eval:
......