Commit 9fa896da authored by Anthony Larcher's avatar Anthony Larcher
Browse files

rawnet2

parent 5dad76a2
......@@ -223,6 +223,12 @@ class Xtractor(torch.nn.Module):
out_features = int(self.speaker_number),
bias = True)
self.preprocessor_weight_decay = 0.000
self.sequence_network_weight_decay = 0.000
self.stat_pooling_weight_decay = 0.000
self.before_speaker_embedding_weight_decay = 0.00
self.after_speaker_embedding_weight_decay = 0.00
else:
# Load Yaml configuration
with open(model_archi, 'r') as fh:
......@@ -419,9 +425,9 @@ def xtrain(speaker_number,
model.load_state_dict(checkpoint["model_state_dict"])
else:
# Initialize a first model
if model_yaml is "xvector":
if model_yaml == "xvector":
model = Xtractor(speaker_number, "xverctor")
elif model_yaml is "rawnet2":
elif model_yaml == "rawnet2":
model = Xtractor(speaker_number, "rawnet2")
else:
model = Xtractor(speaker_number, model_yaml)
......@@ -477,8 +483,12 @@ def xtrain(speaker_number,
if type(model) is Xtractor:
optimizer = _optimizer([
{'params': model.preprocessor.parameters(),
'weight_decay': model.preprocessor_weight_decay},
{'params': model.sequence_network.parameters(),
'weight_decay': model.sequence_network_weight_decay},
{'params': model.stat_pooling.parameters(),
'weight_decay': model.stat_pooling_weight_decay},
{'params': model.before_speaker_embedding.parameters(),
'weight_decay': model.before_speaker_embedding_weight_decay},
{'params': model.after_speaker_embedding.parameters(),
......
Supports Markdown
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment