Commit c2809646 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

fix API

parent c63b451e
......@@ -273,19 +273,19 @@ class Xtractor(torch.nn.Module):
self.sequence_network = torch.nn.Sequential(OrderedDict([
("conv1", torch.nn.Conv1d(self.feature_size, 512, 5, dilation=1)),
("activation1", torch.nn.LeakyReLU(0.2)),
("norm1", torch.nn.BatchNorm1d(512)),
("batch_norm1", torch.nn.BatchNorm1d(512)),
("conv2", torch.nn.Conv1d(512, 512, 3, dilation=2)),
("activation2", torch.nn.LeakyReLU(0.2)),
("norm2", torch.nn.BatchNorm1d(512)),
("batch_norm2", torch.nn.BatchNorm1d(512)),
("conv3", torch.nn.Conv1d(512, 512, 3, dilation=3)),
("activation3", torch.nn.LeakyReLU(0.2)),
("norm3", torch.nn.BatchNorm1d(512)),
("batch_norm3", torch.nn.BatchNorm1d(512)),
("conv4", torch.nn.Conv1d(512, 512, 1)),
("activation4", torch.nn.LeakyReLU(0.2)),
("norm4", torch.nn.BatchNorm1d(512)),
("batch_norm4", torch.nn.BatchNorm1d(512)),
("conv5", torch.nn.Conv1d(512, 1536, 1)),
("activation5", torch.nn.LeakyReLU(0.2)),
("norm5", torch.nn.BatchNorm1d(1536))
("batch_norm5", torch.nn.BatchNorm1d(1536))
]))
self.stat_pooling = MeanStdPooling()
......@@ -301,11 +301,11 @@ class Xtractor(torch.nn.Module):
elif self.loss == "cce":
self.after_speaker_embedding = torch.nn.Sequential(OrderedDict([
("activation6", torch.nn.LeakyReLU(0.2)),
("norm6", torch.nn.BatchNorm1d(512)),
("batch_norm6", torch.nn.BatchNorm1d(512)),
("dropout6", torch.nn.Dropout(p=0.05)),
("linear7", torch.nn.Linear(512, 512)),
("activation7", torch.nn.LeakyReLU(0.2)),
("norm7", torch.nn.BatchNorm1d(512)),
("batch_norm7", torch.nn.BatchNorm1d(512)),
("linear8", torch.nn.Linear(512, int(self.speaker_number)))
]))
......@@ -361,9 +361,12 @@ class Xtractor(torch.nn.Module):
self.after_speaker_embedding_weight_decay = 0.00
else:
# Load Yaml configuration
with open(model_archi, 'r') as fh:
cfg = yaml.load(fh, Loader=yaml.FullLoader)
if isinstance(model_archi, dict):
cfg = model_archi
else:
# Load Yaml configuration
with open(model_archi, 'r') as fh:
cfg = yaml.load(fh, Loader=yaml.FullLoader)
self.loss = cfg["loss"]
if self.loss == "aam":
......@@ -819,7 +822,7 @@ def xtrain(speaker_number,
'accuracy': best_accuracy,
'scheduler': scheduler,
'speaker_number' : speaker_number,
'model_archi': model_yaml
'model_archi': model_archi
}, is_best, filename=tmp_model_name+".pt", best_filename=best_model_name+'.pt')
else:
save_checkpoint({
......@@ -829,7 +832,7 @@ def xtrain(speaker_number,
'accuracy': best_accuracy,
'scheduler': scheduler,
'speaker_number': speaker_number,
'model_archi': model_yaml
'model_archi': model_archi
}, is_best, filename=tmp_model_name+".pt", best_filename=best_model_name+'.pt')
if is_best:
......@@ -975,8 +978,8 @@ def extract_embeddings(idmap_name,
if speaker_number is None:
speaker_number = checkpoint["speaker_number"]
if model_yaml is None:
model_yaml = checkpoint["model_archi"]
model = Xtractor(speaker_number, model_archi=model_yaml)
model_archi = checkpoint["model_archi"]
model = Xtractor(speaker_number, model_archi=model_archi)
model.load_state_dict(checkpoint["model_state_dict"])
else:
model = model_filename
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment