Commit 558fce3f authored by Anthony Larcher

fix API

parent c63b451e
@@ -365,10 +365,10 @@ class Xtractor(torch.nn.Module):
         with open(model_archi, 'r') as fh:
             cfg = yaml.load(fh, Loader=yaml.FullLoader)
-            self.loss = cfg["loss"]
+            self.loss = cfg["training"]["loss"]
             if self.loss == "aam":
-                self.aam_margin = cfg["aam_margin"]
-                self.aam_s = cfg["aam_s"]
+                self.aam_margin = cfg["training"]["aam_margin"]
+                self.aam_s = cfg["training"]["aam_s"]
         """
         Prepare Preprocessor
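
With this change the loss settings are read from the "training" section of the architecture YAML instead of its top level. A minimal sketch of the expected layout (the values are illustrative assumptions; only the key paths appear in this hunk and in the xtrain hunk below):

training:
  loss: aam          # read as cfg["training"]["loss"]
  aam_margin: 0.2    # illustrative value, only read when loss == "aam"
  aam_s: 30.0        # illustrative value, only read when loss == "aam"
  clipping: false    # illustrative value, read as model_archi["training"]["clipping"] in xtrain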
@@ -655,7 +655,7 @@ def xtrain(speaker_number,
         if clipping is None:
             clipping = model_archi["training"]["clipping"]
-    if model_name is None
+    if model_name is None:
         model = Xtractor(speaker_number, model_yaml)
     # If we start from an existing model
@@ -665,22 +665,22 @@ def xtrain(speaker_number,
         checkpoint = torch.load(model_name)
         model = Xtractor(speaker_number, model_yaml)
         """
         Here we remove all layers that we don't want to reload
         """
         pretrained_dict = checkpoint["model_state_dict"]
         for part in reset_parts:
             pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith(part)}
         new_model_dict = model.state_dict()
         new_model_dict.update(pretrained_dict)
         model.load_state_dict(new_model_dict)
         # Freeze required layers
         for name, param in model.named_parameters():
             if name.split(".")[0] in freeze_parts:
                 param.requires_grad = False
     print(model)
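
For readers skimming the diff: this block drops from the checkpoint every parameter whose name starts with one of the reset_parts prefixes, merges the surviving weights into the freshly built model, and then freezes the modules listed in freeze_parts. A self-contained sketch of the same pattern, assuming a generic torch.nn.Module; the helper name load_pretrained and its signature are illustrative, not part of the sidekit API:

import torch

def load_pretrained(model, checkpoint_path, reset_parts, freeze_parts):
    # Illustrative helper: partial reload of a checkpoint plus layer freezing.
    checkpoint = torch.load(checkpoint_path)
    pretrained_dict = checkpoint["model_state_dict"]
    # Discard parameters belonging to the parts that should be re-initialised.
    for part in reset_parts:
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if not k.startswith(part)}
    # Merge the remaining pretrained weights into the new model's state dict.
    new_model_dict = model.state_dict()
    new_model_dict.update(pretrained_dict)
    model.load_state_dict(new_model_dict)
    # Freeze the modules listed in freeze_parts, matched on the first
    # dotted component of each parameter name.
    for name, param in model.named_parameters():
        if name.split(".")[0] in freeze_parts:
            param.requires_grad = False
    return model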