Commit cf8fdc1a authored by Anthony Larcher

resume from predefined architecture

parent c5253816
@@ -953,18 +953,38 @@ def xtrain(speaker_number,
     logging.critical(f"Start process at {time.strftime('%H:%M:%S', time.localtime())}")

     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-    # Start from scratch
-    if model_name is None and model_yaml in ["xvector", "rawnet2", "resnet34", "fastresnet34"]:
-        # Initialize a first model
-        if model_yaml == "xvector":
-            model = Xtractor(speaker_number, "xvector", loss=loss)
-        elif model_yaml == "rawnet2":
-            model = Xtractor(speaker_number, "rawnet2")
-        elif model_yaml == "resnet34":
-            model = Xtractor(speaker_number, "resnet34")
-        elif model_yaml == "fastresnet34":
-            model = Xtractor(speaker_number, "fastresnet34")
+    # Use a predefined architecture
+    if model_yaml in ["xvector", "rawnet2", "resnet34", "fastresnet34"]:
+        if model_name is None:
+            model = Xtractor(speaker_number, model_yaml)
+        else:
+            logging.critical(f"*** Load model from = {model_name}")
+            checkpoint = torch.load(model_name)
+            model = Xtractor(speaker_number, model_yaml)
+            """
+            Here we remove all layers that we don't want to reload
+            """
+            pretrained_dict = checkpoint["model_state_dict"]
+            for part in reset_parts:
+                pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith(part)}
+            new_model_dict = model.state_dict()
+            new_model_dict.update(pretrained_dict)
+            model.load_state_dict(new_model_dict)
+            # Freeze required layers
+            for name, param in model.named_parameters():
+                if name.split(".")[0] in freeze_parts:
+                    param.requires_grad = False
+        model_archi = model_yaml
     # Here use a config file to build the architecture
     else:
         with open(model_yaml, 'r') as fh:
             model_archi = yaml.load(fh, Loader=yaml.FullLoader)
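The commit replaces the per-architecture if/elif chain with a single Xtractor call and adds checkpoint resuming: reload a saved state dict, drop the parameters listed in reset_parts so those layers keep their fresh initialization, then freeze the sub-modules listed in freeze_parts. The same partial-reload pattern can be reproduced in plain PyTorch; the sketch below is illustrative only, and the helper name resume_model and its arguments mirror the reset_parts/freeze_parts semantics of the diff but are not part of sidekit's API.

import torch

def resume_model(model, checkpoint_path, reset_parts=(), freeze_parts=()):
    # Load the saved training state; "model_state_dict" matches the key used in the diff.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    pretrained_dict = checkpoint["model_state_dict"]

    # Drop every tensor whose name starts with a prefix in reset_parts,
    # so those layers keep the fresh initialization of the new model.
    for part in reset_parts:
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if not k.startswith(part)}

    # Merge the surviving pretrained weights into the new model's state dict
    # and load the result, so missing keys fall back to the fresh weights.
    new_model_dict = model.state_dict()
    new_model_dict.update(pretrained_dict)
    model.load_state_dict(new_model_dict)

    # Freeze the requested top-level sub-modules by name.
    for name, param in model.named_parameters():
        if name.split(".")[0] in freeze_parts:
            param.requires_grad = False
    return model

For instance, resume_model(model, "model.pt", reset_parts=["after_speaker_embedding"], freeze_parts=["sequence_network"]) would reload every weight except the classification head and keep the frame-level network frozen during fine-tuning; the module names here are hypothetical and depend on the actual model.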