Commit 2edd99d1 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

Add an `embedding_size` option so the speaker-embedding dimension is configurable instead of hard-coded per architecture

parent 1416db4e
......@@ -376,7 +376,8 @@ class Xtractor(torch.nn.Module):
loss=None,
norm_embedding=False,
aam_margin=0.2,
aam_s=30):
aam_s=30,
embedding_size=256):
"""
If config is None, default architecture is created
:param model_archi:
......@@ -420,7 +421,7 @@ class Xtractor(torch.nn.Module):
("batch_norm5", torch.nn.BatchNorm1d(1536))
]))
self.embedding_size = 512
self.embedding_size = embedding_size
self.stat_pooling = MeanStdPooling()
self.stat_pooling_weight_decay = 0
......@@ -454,7 +455,7 @@ class Xtractor(torch.nn.Module):
self.preprocessor = MelSpecFrontEnd(n_mels=80)
self.sequence_network = PreResNet34()
self.embedding_size = 256
self.embedding_size = embedding_size
self.before_speaker_embedding = torch.nn.Linear(in_features=5120,
out_features=self.embedding_size)
......@@ -477,7 +478,7 @@ class Xtractor(torch.nn.Module):
elif model_archi == "fastresnet34":
self.preprocessor = MelSpecFrontEnd()
self.sequence_network = PreFastResNet34()
self.embedding_size = 256
self.embedding_size = embedding_size
self.before_speaker_embedding = torch.nn.Linear(in_features = 2560,
out_features = self.embedding_size)
......@@ -510,7 +511,9 @@ class Xtractor(torch.nn.Module):
hop_length=160,
n_mels=80)
self.sequence_network = PreHalfResNet34()
self.embedding_size = 256
self.embedding_size = embedding_size
#self.embedding_size = 256
#self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
# out_features = self.embedding_size)
......@@ -924,6 +927,7 @@ def update_training_dictionary(dataset_description,
# Initialize model options
model_opts["speaker_number"] = None
model_opts["embedding_size"] = 256
model_opts["loss"] = dict()
model_opts["loss"]["type"] ="aam"
model_opts["loss"]["aam_margin"] = 0.2
......@@ -1711,7 +1715,10 @@ def extract_embeddings(idmap_name,
checkpoint = torch.load(model_filename, map_location=device)
speaker_number = checkpoint["speaker_number"]
model_opts = checkpoint["model_archi"]
model = Xtractor(speaker_number, model_archi=model_opts["model_type"], loss=model_opts["loss"]["type"])
model = Xtractor(speaker_number,
model_archi=model_opts["model_type"],
loss=model_opts["loss"]["type"],
embedding_size=model_opts["embedding_size"])
model.load_state_dict(checkpoint["model_state_dict"])
else:
model = model_filename
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment