Commit 01885d7c authored by Anthony Larcher's avatar Anthony Larcher
Browse files

integration of AAM

parent e1ac82f0
......@@ -52,7 +52,7 @@ __docformat__ = 'reS'
class ArcLinear(nn.Module):
class ArcLinear(torch.nn.Module):
"""Additive Angular Margin linear module (ArcFace)
Parameters
......@@ -75,8 +75,8 @@ class ArcLinear(nn.Module):
self.nclass = nclass
self.margin = margin
self.s = s
self.W = nn.Parameter(torch.Tensor(nclass, nfeat))
nn.init.xavier_uniform_(self.W)
self.W = torch.nn.Parameter(torch.Tensor(nclass, nfeat))
torch.nn.init.xavier_uniform_(self.W)
def forward(self, x, target=None):
"""Apply the angular margin transformation
......@@ -94,8 +94,8 @@ class ArcLinear(nn.Module):
logits after the angular margin transformation
"""
# normalize the feature vectors and W
xnorm = F.normalize(x)
Wnorm = F.normalize(self.W)
xnorm = torch.nn.functional.normalize(x)
Wnorm = torch.nn.functional.normalize(self.W)
target = target.long().view(-1, 1)
# calculate cosθj (the logits)
cos_theta_j = torch.matmul(xnorm, torch.transpose(Wnorm, 0, 1))
......
......@@ -579,17 +579,29 @@ class IdMapSet(Dataset):
:param index:
:return:
"""
sig, _ = soundfile.read(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}",
start=int(self.idmap.start[index]),
stop=int(self.idmap.stop[index]))
if self.idmap.start[index] is None:
start = 0.0
if self.idmap.start[index] is None and self.idmap.stop[index] is None:
sig, _ = soundfile.read(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}")
start = 0
stop = len(sig)
else:
start = int(self.idmap.start[index])
stop = int(self.idmap.stop[index])
sig, _ = soundfile.read(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}",
start=start,
stop=stop)
sig += 0.0001 * numpy.random.randn(sig.shape[0])
if self.transform_pipeline is not None:
sig, _, ___, _____, _t, _s = self.transforms((sig, 0, 0, 0, 0, 0))
return torch.from_numpy(sig).type(torch.FloatTensor), \
self.idmap.leftids[index], \
self.idmap.rightids[index], \
self.idmap.start[index], self.idmap.stop[index]
start, stop
#self.idmap.start[index], self.idmap.stop[index]
def __len__(self):
......
......@@ -347,7 +347,7 @@ class Xtractor(torch.nn.Module):
self.after_speaker_embedding = ArcLinear(1024,
int(self.speaker_number),
margin=aam_margin, s=aam_s)
elif self.loss == "cce"
elif self.loss == "cce":
self.after_speaker_embedding = torch.nn.Linear(in_features = 1024,
out_features = int(self.speaker_number),
bias = True)
......@@ -450,6 +450,8 @@ class Xtractor(torch.nn.Module):
gru_node=cfg["stat_pooling"]["gru_node"],
nb_gru_layer=cfg["stat_pooling"]["nb_gru_layer"])
self.stat_pooling_weight_decay = cfg["stat_pooling"]["weight_decay"]
"""
Prepare last part of the network (after pooling)
"""
......@@ -477,33 +479,42 @@ class Xtractor(torch.nn.Module):
self.before_speaker_embedding = torch.nn.Sequential(OrderedDict(before_embedding_layers))
self.before_speaker_embedding_weight_decay = cfg["before_embedding"]["weight_decay"]
# if loss_criteria is "cce"
# Create sequential object for the second part of the network
after_embedding_layers = []
for k in cfg["after_embedding"].keys():
if k.startswith("lin"):
if cfg["after_embedding"][k]["output"] == "speaker_number":
after_embedding_layers.append((k, torch.nn.Linear(input_size, self.speaker_number)))
else:
after_embedding_layers.append((k, torch.nn.Linear(input_size,
if self.loss == "cce":
after_embedding_layers = []
for k in cfg["after_embedding"].keys():
if k.startswith("lin"):
if cfg["after_embedding"][k]["output"] == "speaker_number":
after_embedding_layers.append((k, torch.nn.Linear(input_size, self.speaker_number)))
else:
after_embedding_layers.append((k, torch.nn.Linear(input_size,
cfg["after_embedding"][k]["output"])))
input_size = cfg["after_embedding"][k]["output"]
input_size = cfg["after_embedding"][k]["output"]
elif k.startswith('arc'):
after_embedding_layers.append((k, ArcLinear(input_size,
self.speaker_number,
margin=self.aam_margin,
s=self.aam_s)))
elif k.startswith('arc'):
after_embedding_layers.append((k, ArcLinear(input_size,
self.speaker_number,
margin=self.aam_margin,
s=self.aam_s)))
elif k.startswith("activation"):
after_embedding_layers.append((k, self.activation))
elif k.startswith("activation"):
after_embedding_layers.append((k, self.activation))
elif k.startswith('batch_norm'):
after_embedding_layers.append((k, torch.nn.BatchNorm1d(input_size)))
elif k.startswith('batch_norm'):
after_embedding_layers.append((k, torch.nn.BatchNorm1d(input_size)))
elif k.startswith('dropout'):
after_embedding_layers.append((k, torch.nn.Dropout(p=cfg["after_embedding"][k])))
elif k.startswith('dropout'):
after_embedding_layers.append((k, torch.nn.Dropout(p=cfg["after_embedding"][k])))
self.after_speaker_embedding = torch.nn.Sequential(OrderedDict(after_embedding_layers))
elif self.loss == "aam":
self.after_speaker_embedding = ArcLinear(input_size,
self.speaker_number,
margin=self.aam_margin,
s=self.aam_s)
self.after_speaker_embedding = torch.nn.Sequential(OrderedDict(after_embedding_layers))
self.after_speaker_embedding_weight_decay = cfg["after_embedding"]["weight_decay"]
......@@ -698,30 +709,25 @@ def xtrain(speaker_number,
},
]
param_list = []
if type(model) is Xtractor:
optimizer = _optimizer([
{'params': model.preprocessor.parameters(),
'weight_decay': model.preprocessor_weight_decay},
{'params': model.sequence_network.parameters(),
'weight_decay': model.sequence_network_weight_decay},
{'params': model.stat_pooling.parameters(),
'weight_decay': model.stat_pooling_weight_decay},
{'params': model.before_speaker_embedding.parameters(),
'weight_decay': model.before_speaker_embedding_weight_decay},
{'params': model.after_speaker_embedding.parameters(),
'weight_decay': model.after_speaker_embedding_weight_decay}],
**_options
)
if model.preprocessor is not None:
param_list.append({'params': model.preprocessor.parameters(), 'weight_decay': model.preprocessor_weight_decay})
param_list.append({'params': model.sequence_network.parameters(), 'weight_decay': model.sequence_network_weight_decay})
param_list.append({'params': model.stat_pooling.parameters(), 'weight_decay': model.stat_pooling_weight_decay})
param_list.append({'params': model.before_speaker_embedding.parameters(), 'weight_decay': model.before_speaker_embedding_weight_decay})
param_list.append({'params': model.after_speaker_embedding.parameters(), 'weight_decay': model.after_speaker_embedding_weight_decay})
else:
optimizer = _optimizer([
{'params': model.module.sequence_network.parameters(),
'weight_decay': model.module.sequence_network_weight_decay},
{'params': model.module.before_speaker_embedding.parameters(),
'weight_decay': model.module.before_speaker_embedding_weight_decay},
{'params': model.module.after_speaker_embedding.parameters(),
'weight_decay': model.module.after_speaker_embedding_weight_decay}],
**_options
)
if model.module.preprocessor is not None:
param_list.append({'params': model.module.preprocessor.parameters(), 'weight_decay': model.module.preprocessor_weight_decay})
param_list.append({'params': model.module.sequence_network.parameters(), 'weight_decay': model.module.sequence_network_weight_decay})
param_list.append({'params': model.module.stat_pooling.parameters(), 'weight_decay': model.module.stat_pooling_weight_decay})
param_list.append({'params': model.module.before_speaker_embedding.parameters(), 'weight_decay': model.module.before_speaker_embedding_weight_decay})
param_list.append({'params': model.module.after_speaker_embedding.parameters(), 'weight_decay': model.module.after_speaker_embedding_weight_decay})
optimizer = _optimizer(param_list, **_options)
#optimizer = torch.optim.SGD(params,
# lr=lr,
......@@ -803,14 +809,25 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device,
model.train()
criterion = torch.nn.CrossEntropyLoss(reduction='mean')
if isinstance(model, Xtractor):
loss_criteria = model.loss
else:
loss_criteria = model.module.loss
accuracy = 0.0
running_loss = 0.0
for batch_idx, (data, target) in enumerate(training_loader):
target = target.squeeze()
target = target.to(device)
optimizer.zero_grad()
output = model(data.to(device))
if loss_criteria == 'aam':
output = model(data.to(device), target=target)
else:
output = model(data.to(device), target=None)
#with GuruMeditation():
loss = criterion(output, target.to(device))
loss = criterion(output, target)
if not torch.isnan(loss):
loss.backward()
if clipping:
......@@ -819,8 +836,7 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device,
optimizer.step()
running_loss += loss.item()
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
accuracy += (torch.argmax(output.data, 1) == target).sum()
if batch_idx % log_interval == 0:
batch_size = target.shape[0]
logging.critical('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
......@@ -866,6 +882,11 @@ def cross_validation(model, validation_loader, device):
"""
model.eval()
if isinstance(model, Xtractor):
loss_criteria = model.loss
else:
loss_criteria = model.module.loss
accuracy = 0.0
loss = 0.0
criterion = torch.nn.CrossEntropyLoss()
......@@ -873,11 +894,16 @@ def cross_validation(model, validation_loader, device):
for batch_idx, (data, target) in enumerate(validation_loader):
batch_size = target.shape[0]
target = target.squeeze()
output = model(data.to(device))
if loss_criteria == "aam":
output = model(data.to(device), target=target)
else:
output = model(data.to(device), target=None)
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
loss += criterion(output, target.to(device))
return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * batch_size), \
loss.cpu().numpy() / ((batch_idx + 1) * batch_size)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment