Commit fea3a88a authored by Anthony Larcher

refactoring AAM loss

parent 94b6f695
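For context: the `aam` loss is the additive angular margin softmax (ArcFace-style), which this file wraps in the `ArcLinear` layer touched below. The idea is to replace the target-class logit cos(theta_y) with s * cos(theta_y + m) before cross-entropy. A minimal sketch of that idea, using a hypothetical `ArcMarginSketch` stand-in (the real `ArcLinear` may differ in details):

```python
import torch
import torch.nn.functional as F

class ArcMarginSketch(torch.nn.Module):
    """Hypothetical stand-in illustrating additive angular margin logits."""
    def __init__(self, in_features, n_classes, margin=0.5, s=30.0):
        super().__init__()
        self.margin, self.s = margin, s
        self.weight = torch.nn.Parameter(torch.randn(n_classes, in_features))

    def forward(self, x, target):
        # cosine similarity between L2-normalized embeddings and class weights
        cos = F.linear(F.normalize(x), F.normalize(self.weight))
        theta = torch.acos(cos.clamp(-1 + 1e-7, 1 - 1e-7))
        # add the margin m to the target-class angle only, then rescale by s
        one_hot = F.one_hot(target, cos.size(1)).bool()
        logits = self.s * torch.where(one_hot, torch.cos(theta + self.margin), cos)
        return logits  # feed to torch.nn.CrossEntropyLoss
```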
@@ -243,7 +243,7 @@ class Xtractor(torch.nn.Module):
     def __init__(self,
                  speaker_number,
                  model_archi="xvector",
-                 loss="cce",
+                 loss=None,
                  norm_embedding=False,
                  aam_margin=0.5,
                  aam_s=0.5):
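With `loss` now defaulting to `None`, callers have to choose it explicitly and each architecture branch validates it. A hedged usage sketch (values illustrative, not recommended settings):

```python
model = Xtractor(speaker_number=1000,
                 model_archi="xvector",
                 loss="aam",          # or "cce"
                 aam_margin=0.5,
                 aam_s=30.0)
```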
@@ -256,11 +256,13 @@ class Xtractor(torch.nn.Module):
         self.feature_size = None
         self.norm_embedding = norm_embedding
-        self.loss = loss
-        if loss not in ["cce", 'aam']:
-            raise NotImplementedError(f"The valid loss are for now cce and aam ")
         if model_archi == "xvector":
+            if loss not in ["cce", 'aam']:
+                raise NotImplementedError("The only valid losses for now are 'cce' and 'aam'")
+            else:
+                self.loss = loss
             self.feature_size = 30
             self.activation = torch.nn.LeakyReLU(0.2)
@@ -292,7 +294,7 @@ class Xtractor(torch.nn.Module):
             if self.loss == "aam":
                 self.after_speaker_embedding = torch.nn.Sequential(OrderedDict([
-                    ("arclinear8", ArcLinear(512, int(self.speaker_number), margin=aam_margin, s=aam_s))
+                    ("arclinear", ArcLinear(512, int(self.speaker_number), margin=aam_margin, s=aam_s))
                 ]))
             elif self.loss == "cce":
                 self.after_speaker_embedding = torch.nn.Sequential(OrderedDict([
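A side effect of the `arclinear8` → `arclinear` rename: `OrderedDict` names become `state_dict` keys, so checkpoints saved under the old name will not load with `strict=True`. A quick illustration, with a plain `Linear` standing in for `ArcLinear`:

```python
import torch
from collections import OrderedDict

head = torch.nn.Sequential(OrderedDict([
    ("arclinear", torch.nn.Linear(512, 10))  # stand-in for ArcLinear
]))
print(list(head.state_dict().keys()))  # ['arclinear.weight', 'arclinear.bias']
```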
@@ -310,6 +312,12 @@ class Xtractor(torch.nn.Module):
             self.after_speaker_embedding_weight_decay = 0.002
         elif model_archi == "rawnet2":
+            if loss not in ["cce", 'aam']:
+                raise NotImplementedError("The only valid losses for now are 'cce' and 'aam'")
+            else:
+                self.loss = loss
             filts = [128, [128, 128], [128, 256], [256, 256]]
             self.norm_embedding = True
@@ -355,6 +363,11 @@ class Xtractor(torch.nn.Module):
             with open(model_archi, 'r') as fh:
                 cfg = yaml.load(fh, Loader=yaml.FullLoader)
             self.loss = cfg["loss"]
+            if self.loss == "aam":
+                self.aam_margin = cfg["aam_margin"]
+                self.aam_s = cfg["aam_s"]
             """
             Prepare Preprocessor
             """
@@ -422,7 +435,7 @@ class Xtractor(torch.nn.Module):
                 elif k.startswith("activation"):
                     segmental_layers.append((k, self.activation))
-                elif k.startswith('norm'):
+                elif k.startswith('batch_norm'):
                     segmental_layers.append((k, torch.nn.BatchNorm1d(input_size)))
             self.sequence_network = torch.nn.Sequential(OrderedDict(segmental_layers))
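Note that existing YAML architecture files presumably need the matching key rename (e.g. `norm1` → `batch_norm1`), since keys starting with plain `norm` no longer match this branch; the same rename appears in the two hunks below.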
@@ -438,7 +451,7 @@ class Xtractor(torch.nn.Module):
                                            nb_gru_layer=cfg["stat_pooling"]["nb_gru_layer"])
             """
-            Prepapre last part of the network (after pooling)
+            Prepare last part of the network (after pooling)
             """
             # Create sequential object for the second part of the network
             input_size = input_size * 2
@@ -455,7 +468,7 @@ class Xtractor(torch.nn.Module):
                 elif k.startswith("activation"):
                     before_embedding_layers.append((k, self.activation))
-                elif k.startswith('norm'):
+                elif k.startswith('batch_norm'):
                     before_embedding_layers.append((k, torch.nn.BatchNorm1d(input_size)))
                 elif k.startswith('dropout'):
@@ -476,21 +489,15 @@ class Xtractor(torch.nn.Module):
                     input_size = cfg["after_embedding"][k]["output"]
                 elif k.startswith('arc'):
-                    if cfg["after_embedding"][k]["output"] == "speaker_number":
-                        after_embedding_layers.append(
-                            (k, ArcLinear(input_size, self.speaker_number, margin=aam_margin, s=aam_s)))
-                    else:
-                        after_embedding_layers.append(
-                            (k, ArcLinear(input_size,
-                                          self.speaker_number,
-                                          margin=aam_margin,
-                                          s=aam_s)))
-                        input_size = self.speaker_number
+                    after_embedding_layers.append((k, ArcLinear(input_size,
+                                                                self.speaker_number,
+                                                                margin=self.aam_margin,
+                                                                s=self.aam_s)))
                 elif k.startswith("activation"):
                     after_embedding_layers.append((k, self.activation))
-                elif k.startswith('norm'):
+                elif k.startswith('batch_norm'):
                     after_embedding_layers.append((k, torch.nn.BatchNorm1d(input_size)))
                 elif k.startswith('dropout'):
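Besides collapsing the duplicated branch, this switches the YAML-configured path to the `self.aam_margin`/`self.aam_s` values read from the config earlier in this commit, instead of the constructor arguments, so the margin actually follows the config file.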
@@ -517,15 +524,17 @@ class Xtractor(torch.nn.Module):
         x = self.stat_pooling(x)
         x = self.before_speaker_embedding(x)
-        if is_eval:
-            return x
         if self.norm_embedding:
             x_norm = x.norm(p=2,dim=1, keepdim=True) / 10.
             x = torch.div(x, x_norm)
+        if is_eval:
+            return x
         if self.loss == "cce":
             x = self.after_speaker_embedding(x)
         elif self.loss == "aam":
-            x = self.after_speaker_embedding(x,target=target)
+            if not is_eval:
+                x = self.after_speaker_embedding(x, target=target)
......
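After the reordering in the last hunk, eval-mode calls return the (optionally length-normalized) embedding before any classification head, and the AAM head is only applied during training, where it needs the target labels. A hedged usage sketch (the forward signature beyond `is_eval` and `target` is assumed):

```python
emb = model(features, is_eval=True)        # extract speaker embedding
logits = model(features, target=labels)    # training forward with loss == "aam"
```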