Commit b04c99df authored by Anthony Larcher's avatar Anthony Larcher
Browse files

Merge branch 'dev_al' of into dev_al

parents aa50412d cf8fdc1a
......@@ -838,18 +838,38 @@ def xtrain(speaker_number,
logging.critical(f"Start process at {time.strftime('%H:%M:%S', time.localtime())}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Start from scratch
if model_name is None and model_yaml in ["xvector", "rawnet2", "resnet34", "fastresnet34"]:
# Initialize a first model
if model_yaml == "xvector":
model = Xtractor(speaker_number, "xvector", loss=loss)
elif model_yaml == "rawnet2":
model = Xtractor(speaker_number, "rawnet2")
elif model_yaml == "resnet34":
model = Xtractor(speaker_number, "resnet34")
elif model_yaml == "fastresnet34":
model = Xtractor(speaker_number, "fastresnet34")
# Use a predefined architecture
if model_yaml in ["xvector", "rawnet2", "resnet34", "fastresnet34"]:
if model_name is None:
model = Xtractor(speaker_number, model_yaml)
logging.critical(f"*** Load model from = {model_name}")
checkpoint = torch.load(model_name)
model = Xtractor(speaker_number, model_yaml)
Here we remove all layers that we don't want to reload
pretrained_dict = checkpoint["model_state_dict"]
for part in reset_parts:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith(part)}
new_model_dict = model.state_dict()
# Freeze required layers
for name, param in model.named_parameters():
if name.split(".")[0] in freeze_parts:
param.requires_grad = False
model_archi = model_yaml
# Here use a config file to build the architecture
with open(model_yaml, 'r') as fh:
model_archi = yaml.load(fh, Loader=yaml.FullLoader)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment