Ambuj Mehrish / sidekit / Commits

Commit aa29bc58 ("arcface")
Authored Nov 01, 2020 by Anthony Larcher
Parent: 27ad221c
Changes: 2 files
nnet/loss.py
@@ -28,6 +28,7 @@ Copyright 2014-2020 Anthony Larcher

import h5py
import logging
import math
import sys
import numpy
import torch
@@ -37,7 +38,7 @@ from collections import OrderedDict

from .xsets import XvectorMultiDataset, XvectorDataset, StatDataset
from ..bosaris import IdMap
from ..statserver import StatServer
from torch.nn import Parameter
#from .classification import Classification
@@ -51,6 +52,113 @@ __status__ = "Production"
__docformat__
=
'reS'
class
ArcMarginModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
args
):
super
(
ArcMarginModel
,
self
).
__init__
()
self
.
weight
=
Parameter
(
torch
.
FloatTensor
(
num_classes
,
args
.
emb_size
))
nn
.
init
.
xavier_uniform_
(
self
.
weight
)
self
.
easy_margin
=
args
.
easy_margin
self
.
m
=
args
.
margin_m
self
.
s
=
args
.
margin_s
self
.
cos_m
=
math
.
cos
(
self
.
m
)
self
.
sin_m
=
math
.
sin
(
self
.
m
)
self
.
th
=
math
.
cos
(
math
.
pi
-
self
.
m
)
self
.
mm
=
math
.
sin
(
math
.
pi
-
self
.
m
)
*
self
.
m
def
forward
(
self
,
input
,
label
):
x
=
F
.
normalize
(
input
)
W
=
F
.
normalize
(
self
.
weight
)
cosine
=
F
.
linear
(
x
,
W
)
sine
=
torch
.
sqrt
(
1.0
-
torch
.
pow
(
cosine
,
2
))
phi
=
cosine
*
self
.
cos_m
-
sine
*
self
.
sin_m
# cos(theta + m)
if
self
.
easy_margin
:
phi
=
torch
.
where
(
cosine
>
0
,
phi
,
cosine
)
else
:
phi
=
torch
.
where
(
cosine
>
self
.
th
,
phi
,
cosine
-
self
.
mm
)
one_hot
=
torch
.
zeros
(
cosine
.
size
(),
device
=
device
)
one_hot
.
scatter_
(
1
,
label
.
view
(
-
1
,
1
).
long
(),
1
)
output
=
(
one_hot
*
phi
)
+
((
1.0
-
one_hot
)
*
cosine
)
output
*=
self
.
s
return
output
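The margin is applied without ever computing theta explicitly: phi relies on the angle-addition identity cos(theta + m) = cos(theta)·cos(m) - sin(theta)·sin(m). A minimal sketch (not part of the commit) verifying that identity on random angles, assuming only torch and math:

import math
import torch

m = 0.5                                  # margin in radians
theta = torch.rand(8) * math.pi          # random angles in [0, pi]
cosine, sine = torch.cos(theta), torch.sin(theta)

# phi as computed in forward(): no explicit acos needed
phi = cosine * math.cos(m) - sine * math.sin(m)

# direct computation of cos(theta + m) for comparison
direct = torch.cos(theta + m)

assert torch.allclose(phi, direct, atol=1e-6)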
def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output
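Both margin heads below call l2_norm with axis=0 to normalize the columns of their (embedding_size, classnum) kernel, so that torch.mm of unit-norm embeddings against it yields cosines. A quick sketch reusing the l2_norm defined above (sizes are illustrative only):

import torch

W = torch.randn(512, 1000)      # (embedding_size, classnum), arbitrary sizes
W_norm = l2_norm(W, axis=0)     # each column rescaled to unit L2 norm

# every column of the normalized kernel now has norm 1
assert torch.allclose(W_norm.norm(2, dim=0), torch.ones(1000), atol=1e-5)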
class Arcface(torch.nn.Module):
    # implementation of additive margin softmax loss in https://arxiv.org/abs/1801.05599
    def __init__(self, embedding_size, classnum, s=64., m=0.5):
        super(Arcface, self).__init__()
        self.classnum = classnum
        self.kernel = Parameter(torch.Tensor(embedding_size, classnum))
        # initial kernel
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.m = m  # the margin value, default is 0.5
        self.s = s  # scalar value, default is 64, see normface https://arxiv.org/abs/1704.06369
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.mm = self.sin_m * m  # issue 1
        self.threshold = math.cos(math.pi - m)

    def forward(self, embbedings, label):
        # weights norm
        nB = len(embbedings)
        kernel_norm = l2_norm(self.kernel, axis=0)
        # cos(theta+m)
        cos_theta = torch.mm(embbedings, kernel_norm)
        # output = torch.mm(embbedings,kernel_norm)
        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
        cos_theta_2 = torch.pow(cos_theta, 2)
        sin_theta_2 = 1 - cos_theta_2
        sin_theta = torch.sqrt(sin_theta_2)
        cos_theta_m = (cos_theta * self.cos_m - sin_theta * self.sin_m)
        # this condition keeps theta+m in the range [0, pi]:
        #   0 <= theta + m <= pi
        #  -m <= theta <= pi - m
        cond_v = cos_theta - self.threshold
        cond_mask = cond_v <= 0
        keep_val = (cos_theta - self.mm)  # when theta is not in [0, pi], use cosface instead
        cos_theta_m[cond_mask] = keep_val[cond_mask]
        output = cos_theta * 1.0  # a little bit hacky way to prevent in_place operation on cos_theta
        idx_ = torch.arange(0, nB, dtype=torch.long)
        output[idx_, label] = cos_theta_m[idx_, label]
        output *= self.s  # scale up in order to make softmax work, first introduced in normface
        return output
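Note that Arcface normalizes only its kernel; the embeddings are expected to arrive already L2-normalized (which is what the norm_embedding branch of Xtractor.forward below provides). A minimal usage sketch, with hypothetical sizes, feeding the scaled logits to a standard cross-entropy loss:

import torch

head = Arcface(embedding_size=256, classnum=10)   # hypothetical sizes
emb = torch.nn.functional.normalize(torch.randn(4, 256), p=2, dim=1)
labels = torch.randint(0, 10, (4,))

logits = head(emb, labels)                        # (4, 10), scaled by s=64
loss = torch.nn.CrossEntropyLoss()(logits, labels)
loss.backward()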
################################## Cosface head #############################################################

class Am_softmax(torch.nn.Module):
    # implementation of additive margin softmax loss in https://arxiv.org/abs/1801.05599
    def __init__(self, embedding_size=512, classnum=51332):
        super(Am_softmax, self).__init__()
        self.classnum = classnum
        self.kernel = Parameter(torch.Tensor(embedding_size, classnum))
        # initial kernel
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.m = 0.35  # additive margin recommended by the paper
        self.s = 30.  # see normface https://arxiv.org/abs/1704.06369

    def forward(self, embbedings, label):
        kernel_norm = l2_norm(self.kernel, axis=0)
        cos_theta = torch.mm(embbedings, kernel_norm)
        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
        phi = cos_theta - self.m
        label = label.view(-1, 1)  # size=(B,1)
        index = cos_theta.data * 0.0  # size=(B,Classnum)
        index.scatter_(1, label.data.view(-1, 1), 1)
        index = index.byte()  # NOTE: byte masks are deprecated in recent PyTorch; .bool() is the modern equivalent
        output = cos_theta * 1.0
        output[index] = phi[index]  # only change the correct predicted output
        output *= self.s  # scale up in order to make softmax work, first introduced in normface
        return output
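The two heads differ only in how the margin enters: Am_softmax (CosFace) subtracts m from the cosine itself, while Arcface adds m to the angle. A small sketch, not part of the commit, contrasting the two penalties on the same cosine values (margins taken from the defaults above):

import torch

cos_theta = torch.tensor([0.9, 0.5, 0.1])
m_arc, m_cos = 0.5, 0.35

# CosFace / Am_softmax: additive margin in cosine space
phi_cosface = cos_theta - m_cos          # tensor([0.5500, 0.1500, -0.2500])

# ArcFace: additive margin in angle space, cos(theta + m)
theta = torch.acos(cos_theta)
phi_arcface = torch.cos(theta + m_arc)

print(phi_cosface)
print(phi_arcface)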
class ArcLinear(torch.nn.Module):
    """Additive Angular Margin linear module (ArcFace)
...
nnet/xvector.py
@@ -51,7 +51,7 @@ from torch.utils.data import DataLoader

 from sklearn.model_selection import train_test_split
 from .sincnet import SincNet
 #from torch.utils.tensorboard import SummaryWriter
-from .loss import ArcLinear
+from .loss import ArcLinear, ArcFace
 import tqdm
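One caveat with this hunk: nnet/loss.py in this commit defines the class as Arcface (lowercase f), so importing ArcFace as written would raise an ImportError unless one of the two names is aligned. A sketch of the import as it would need to read against the class defined above (assuming the class name is the one kept):

# assuming the class keeps the name used in nnet/loss.py in this commit
from .loss import ArcLinear, Arcface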
@@ -519,10 +519,14 @@ class Xtractor(torch.nn.Module):

         elif self.loss == "aam":
             self.norm_embedding = True
-            self.after_speaker_embedding = ArcLinear(input_size,
-                                                     self.speaker_number,
-                                                     margin=self.aam_margin,
-                                                     s=self.aam_s)
+            #self.after_speaker_embedding = ArcLinear(input_size,
+            #                                         self.speaker_number,
+            #                                         margin=self.aam_margin,
+            #                                         s=self.aam_s)
+            self.after_speaker_embedding = ArcFace(embedding_size=input_size,
+                                                   classnum=self.speaker_number,
+                                                   s=64.,
+                                                   margin=0.5)
             self.after_speaker_embedding_weight_decay = cfg["after_embedding"]["weight_decay"]
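A second caveat: Arcface.__init__ takes the margin as m, not margin, so this call would raise a TypeError once the import is fixed. A sketch of the call matching the signature defined in nnet/loss.py (same intent, s=64. and a 0.5 margin):

# matching Arcface(embedding_size, classnum, s=64., m=0.5) from nnet/loss.py
self.after_speaker_embedding = Arcface(embedding_size=input_size,
                                       classnum=self.speaker_number,
                                       s=64.,
                                       m=0.5)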
@@ -545,7 +549,8 @@ class Xtractor(torch.nn.Module):

         x = self.before_speaker_embedding(x)

         if self.norm_embedding:
-            x_norm = x.norm(p=2, dim=1, keepdim=True) / 10. # Why 10. ?
+            #x_norm = x.norm(p=2,dim=1, keepdim=True) / 10. # Why 10. ?
+            x_norm = torch.linalg.norm(x, ord=2, dim=1, keepdim=True, out=None, dtype=None)
             x = torch.div(x, x_norm)

         if is_eval:
...
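The replacement normalization divides each embedding by its L2 norm, dropping the unexplained /10. of the old line. Functionally this matches torch.nn.functional.normalize; a quick check, not part of the commit (sizes illustrative):

import torch

x = torch.randn(4, 256)
x_norm = torch.linalg.norm(x, ord=2, dim=1, keepdim=True)
manual = torch.div(x, x_norm)

reference = torch.nn.functional.normalize(x, p=2, dim=1)
assert torch.allclose(manual, reference, atol=1e-6)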