Commit 43927406 authored by Anthony Larcher
Browse files

resnet

parent 0ff214ca
......@@ -420,10 +420,10 @@ class BasicBlock(torch.nn.Module):
)
def forward(self, x):
    """
    Forward pass of the residual block.

    Applies conv1 -> bn1 -> ReLU, then conv2 -> bn2, adds the shortcut
    branch computed from the block input and applies a final ReLU.

    :param x: input of the block
    :return: output of the block
    """
    # Each layer is applied exactly once: evaluating conv1/bn1 a second time
    # would only waste compute and run bn1 twice on the same batch.
    out = torch.nn.functional.relu(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))
    # In-place accumulation of the shortcut branch
    out += self.shortcut(x)
    out = torch.nn.functional.relu(out)
    return out
......@@ -488,7 +488,7 @@ class ResNet(torch.nn.Module):
return torch.nn.Sequential(*layers)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = torch.nn.functional.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
......@@ -534,7 +534,8 @@ class PreResNet34(torch.nn.Module):
return torch.nn.Sequential(*layers)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = x.unsqueeze(1)
out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
......
......@@ -344,6 +344,8 @@ class Xtractor(torch.nn.Module):
("linear6", torch.nn.Linear(3072, 512))
]))
self.embedding_size = 512
if self.loss == "aam":
self.after_speaker_embedding = ArcMarginProduct(512,
int(self.speaker_number),
......@@ -375,21 +377,17 @@ class Xtractor(torch.nn.Module):
self.stat_pooling = MeanStdPooling()
self.stat_pooling_weight_decay = 0
if loss not in ["cce", 'aam']:
raise NotImplementedError(f"The valid loss are for now cce and aam ")
else:
self.loss = loss
self.loss = "aam"
if self.loss == "aam":
if loss == 'aam':
self.after_speaker_embedding = ArcLinear(256,
int(self.speaker_number),
margin=aam_margin, s=aam_s)
self.after_speaker_embedding = ArcMarginProduct(256, int(self.speaker_number), s=64, m=0.2, easy_margin=True)
elif self.loss == "cce":
self.after_speaker_embedding = torch.nn.Linear(in_features = 256,
out_features = int(self.speaker_number),
bias = True)
self.embedding_size = 256
self.preprocessor_weight_decay = 0.000
self.sequence_network_weight_decay = 0.000
self.stat_pooling_weight_decay = 0.000
......@@ -750,12 +748,14 @@ def xtrain(speaker_number,
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Start from scratch
if model_name is None and model_yaml in ["xvector", "rawnet2"]:
if model_name is None and model_yaml in ["xvector", "rawnet2", "resnet34"]:
# Initialize a first model
if model_yaml == "xvector":
model = Xtractor(speaker_number, "xvector", loss=loss)
elif model_yaml == "rawnet2":
model = Xtractor(speaker_number, "rawnet2")
elif model_yaml == "resnet34":
model = Xtractor(speaker_number, "resnet34")
model_archi = model_yaml
else:
with open(model_yaml, 'r') as fh:
......@@ -988,6 +988,11 @@ def xtrain(speaker_number,
is_best = val_acc > best_accuracy
best_accuracy = max(val_acc, best_accuracy)
if tmp_model_name is None:
tmp_model_name = "tmp_model"
if best_model_name is None:
best_model_name = "best_model"
if type(model) is Xtractor:
save_checkpoint({
'epoch': epoch,
......@@ -1153,6 +1158,131 @@ def cross_validation(model, validation_loader, device, validation_shape):
loss.cpu().numpy() / ((batch_idx + 1) * batch_size), equal_error_rate
class XtractorTop(torch.nn.Module):
    """
    Last part of a trained extractor, rebuilt from a checkpoint file.

    Only the ``before_speaker_embedding`` and ``after_speaker_embedding`` parts
    are instantiated (the sequence network stored in the checkpoint is not
    used); their weights are copied from the checkpoint.
    """

    def __init__(self,
                 model_filename,
                 loss=None,
                 aam_margin=None,
                 aam_s=None):
        """
        Build the top layers described in the checkpoint and load their weights.

        :param model_filename: path of the checkpoint; it must contain the keys
            "model_archi", "speaker_number" and "model_state_dict"
        :param loss: accepted for interface compatibility and not used: the loss
            is read from ``model_archi["training"]["loss"]`` of the checkpoint
        :param aam_margin: margin given to the ArcLinear layers ("arc*" entries
            of the "after_embedding" section)
        :param aam_s: scale given to the ArcLinear layers ("arc*" entries of the
            "after_embedding" section)
        """
        super(XtractorTop, self).__init__()

        # Load the model and only use the last part of it (not to use sequence_network)
        # NOTE(review): torch.load unpickles the file, only load trusted checkpoints
        checkpoint = torch.load(model_filename, map_location='cpu')
        cfg = checkpoint["model_archi"]
        pretrained_dict = checkpoint["model_state_dict"]
        self.speaker_number = checkpoint["speaker_number"]

        # forward() dispatches on self.loss, it has to follow the architecture built below
        self.loss = cfg["training"]["loss"]

        # Get activation function
        if cfg["activation"] == 'LeakyReLU':
            self.activation = torch.nn.LeakyReLU(0.2)
        elif cfg["activation"] == 'PReLU':
            self.activation = torch.nn.PReLU()
        elif cfg["activation"] == 'ReLU6':
            self.activation = torch.nn.ReLU6()
        else:
            self.activation = torch.nn.ReLU()

        # Number of features flowing out of the last layer created so far
        width = None

        before_layers = []
        for k, params in cfg["before_embedding"].items():
            if k.startswith("lin"):
                # Linear layers are sized from the stored weights (out_features, in_features)
                weight = pretrained_dict["before_speaker_embedding." + k + ".weight"]
                width = weight.shape[0]
                before_layers.append((k, torch.nn.Linear(weight.shape[1], width)))
            elif k.startswith("activation"):
                before_layers.append((k, self.activation))
            elif k.startswith('batch_norm'):
                norm_width = self._norm_width(pretrained_dict,
                                              "before_speaker_embedding." + k,
                                              width)
                before_layers.append((k, torch.nn.BatchNorm1d(norm_width)))
            elif k.startswith('dropout'):
                before_layers.append((k, torch.nn.Dropout(p=params)))
        self.before_speaker_embedding = torch.nn.Sequential(OrderedDict(before_layers))

        # if loss_criteria is "cce"
        # Create sequential object for the second part of the network
        if self.loss == "cce":
            # Start from an empty list so that the layers of before_speaker_embedding
            # are not duplicated in after_speaker_embedding
            after_layers = []
            for k, params in cfg["after_embedding"].items():
                if k.startswith("lin"):
                    if params["output"] == "speaker_number":
                        after_layers.append((k, torch.nn.Linear(width, int(self.speaker_number))))
                    else:
                        after_layers.append((k, torch.nn.Linear(width, params["output"])))
                        width = params["output"]
                elif k.startswith('arc'):
                    after_layers.append((k, ArcLinear(width,
                                                      int(self.speaker_number),
                                                      margin=aam_margin,
                                                      s=aam_s)))
                elif k.startswith("activation"):
                    after_layers.append((k, self.activation))
                elif k.startswith('batch_norm'):
                    norm_width = self._norm_width(pretrained_dict,
                                                  "after_speaker_embedding." + k,
                                                  width)
                    after_layers.append((k, torch.nn.BatchNorm1d(norm_width)))
                elif k.startswith('dropout'):
                    after_layers.append((k, torch.nn.Dropout(p=params)))
            self.after_speaker_embedding = torch.nn.Sequential(OrderedDict(after_layers))

        elif self.loss == "aam":
            self.after_speaker_embedding = ArcMarginProduct(width,
                                                            int(self.speaker_number),
                                                            s=64,
                                                            m=0.2,
                                                            easy_margin=True)

        # Now load layers from the file: only the entries that exist in this module are copied
        new_model_dict = self.state_dict()
        for k, v in pretrained_dict.items():
            if k in new_model_dict:
                new_model_dict[k] = v
        self.load_state_dict(new_model_dict)

    @staticmethod
    def _norm_width(state_dict, layer_name, running_width):
        """
        Return the number of features of a batch normalisation layer.

        :param state_dict: state dictionary read from the checkpoint
        :param layer_name: full name of the layer in the state dictionary
        :param running_width: output size of the previous layer, used when the
            checkpoint stores no weight for this layer
        :return: the number of features of the layer
        """
        stored = state_dict.get(layer_name + ".weight")
        if stored is not None:
            return stored.shape[0]
        if running_width is None:
            raise ValueError(f"Cannot infer the size of layer {layer_name}")
        return running_width

    def forward(self, x, is_eval=False, target=None,):
        """
        Apply the top layers to x.

        :param x: input of before_speaker_embedding
        :param is_eval: False for training
        :param target: target given to after_speaker_embedding when the loss is
            "aam" and is_eval is False
        :return: for "cce", the output of after_speaker_embedding, together with
            the embedding when is_eval is True; for "aam", the output of
            after_speaker_embedding and the l2-normalised embedding; for any
            other loss, the output of before_speaker_embedding
        """
        x = self.before_speaker_embedding(x)

        if self.loss == "cce":
            if is_eval:
                return self.after_speaker_embedding(x), x
            else:
                return self.after_speaker_embedding(x)

        elif self.loss == "aam":
            if not is_eval:
                x = self.after_speaker_embedding(l2_norm(x), target=target), l2_norm(x)
            else:
                x = self.after_speaker_embedding(l2_norm(x), target=None), l2_norm(x)

        return x
def extract_embeddings(idmap_name,
model_filename,
data_root_name,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment