Commit 9b0b63eb authored by Anthony Larcher's avatar Anthony Larcher
Browse files
parents fe6e8004 1c256079
......@@ -223,6 +223,56 @@ class ArcLinear(torch.nn.Module):
# project margin differences into cosθj
return self.s * (cos_theta_j + one_hot * (cos_theta_yi_margin - cos_theta_yi))
class ArcMarginProduct(nn.Module):
    r"""Large-margin arc-distance (ArcFace-style) output layer.

    Produces scaled logits where the target-class logit uses
    ``cos(theta + m)`` instead of ``cos(theta)``.

    Args:
        in_features: size of each input sample (embedding dimension)
        out_features: size of each output sample (number of classes)
        s: norm (scale) applied to the output logits
        m: additive angular margin
        easy_margin: if True, only apply the margin where cos(theta) > 0
    """
    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        torch.nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        # Pre-computed constants of the angular-margin formula.
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Below this threshold cos(theta + m) is no longer monotonic in theta;
        # fall back to a linear penalty there (standard ArcFace trick).
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, target):
        """Return scaled (and, in training, margin-penalized) class logits.

        Args:
            input: batch of embeddings, shape (batch, in_features)
            target: LongTensor of class indices, shape (batch,); when None
                (evaluation mode) plain scaled cosine logits are returned.
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        if target is None:
            return cosine * self.s
        else:
            # Clamp before sqrt: cosine**2 can exceed 1 by floating-point
            # error, which would yield NaN.
            sine = torch.sqrt(torch.clamp(1.0 - torch.pow(cosine, 2), min=0.0))
            phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
            if self.easy_margin:
                phi = torch.where(cosine > 0, phi, cosine)
            else:
                phi = torch.where(cosine > self.th, phi, cosine - self.mm)
            # --------------------------- convert label to one-hot ---------------------------
            # zeros_like keeps the one-hot on the same device/dtype as the
            # logits (was hard-coded to 'cuda', crashing on CPU-only runs).
            one_hot = torch.zeros_like(cosine)
            one_hot.scatter_(1, target.view(-1, 1).long(), 1)
            # Apply the margin-penalized logit only on the target class.
            output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
            output *= self.s
            return output
......
......@@ -50,13 +50,19 @@ from .xsets import IdMapSet
from .xsets import IdMapSet_per_speaker
from .res_net import RawPreprocessor, ResBlockWFMS
from ..bosaris import IdMap
from ..bosaris import Key
from ..bosaris import Ndx
from ..bosaris.detplot import rocch
from ..bosaris.detplot import rocch2eer
from ..statserver import StatServer
from ..iv_scoring import cosine_scoring
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from .sincnet import SincNet
from .loss import ArcLinear
from .loss import ArcFace
from .loss import l2_norm
from .loss import ArcMarginProduct
......@@ -94,7 +100,7 @@ class GuruMeditation (torch.autograd.detect_anomaly):
super(GuruMeditation, self).__exit__()
if isinstance(value, RuntimeError):
traceback.print_tb(trace)
halt(str(value))
self.halt(str(value))
def halt(msg):
print (msg)
......@@ -119,7 +125,7 @@ def matplotlib_imshow(img, one_channel=False):
if one_channel:
plt.imshow(npimg, cmap="Greys")
else:
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.imshow(numpy.transpose(npimg, (1, 2, 0)))
def speech_to_probs(model, speech):
'''
......@@ -156,6 +162,47 @@ def plot_classes_preds(model, speech, labels):
return fig
def compute_metrics(model,
                    validation_loader,
                    device,
                    val_embs_shape,
                    speaker_number,
                    model_archi,
                    idmap_name='h5f/idmap_test.h5',
                    ndx_name='h5f/ndx_test.h5',
                    key_name='h5f/key_test.h5',
                    data_root_name="data/vox1/wav/",
                    transform_pipeline="MFCC,CMVN"):
    """Compute validation and test metrics for a speaker-verification model.

    Runs cross-validation on the validation set, then extracts embeddings
    for the test idmap and scores them with cosine scoring to obtain a
    test EER.

    Args:
        model: the model (also passed as ``model_filename`` to
            ``extract_embeddings``)
        validation_loader: DataLoader over the validation set
        device: torch device used for inference
        val_embs_shape: [num_validation_samples, embedding_dim] used to
            pre-allocate the validation embedding matrix
        speaker_number: number of speakers the model was trained on
        model_archi: model architecture description (yaml) passed to
            ``extract_embeddings``
        idmap_name: IdMap file listing the test segments
        ndx_name: Ndx trial list used for scoring
        key_name: Key file giving target/non-target trial labels
        data_root_name: root directory of the test audio files
        transform_pipeline: feature-extraction pipeline description

    Returns:
        tuple: (validation accuracy, validation loss, validation EER,
        test EER)
    """
    val_acc, val_loss, val_eer = cross_validation(model, validation_loader, device, val_embs_shape)

    xv_stat = extract_embeddings(idmap_name=idmap_name,
                                 speaker_number=speaker_number,
                                 model_filename=model,
                                 model_yaml=model_archi,
                                 data_root_name=data_root_name,
                                 device=device,
                                 transform_pipeline=transform_pipeline)

    scores = cosine_scoring(xv_stat, xv_stat,
                            Ndx(ndx_name),
                            wccn=None, check_missing=True)

    tar, non = scores.get_tar_non(Key(key_name))

    # rocch expects double-precision numpy arrays.
    pmiss, pfa = rocch(numpy.array(tar).astype(numpy.double),
                       numpy.array(non).astype(numpy.double))
    test_eer = rocch2eer(pmiss, pfa)

    return val_acc, val_loss, val_eer, test_eer
def get_lr(optimizer):
"""
......@@ -289,15 +336,17 @@ class Xtractor(torch.nn.Module):
]))
self.stat_pooling = MeanStdPooling()
self.stat_pooling_weight_decay = 0
self.before_speaker_embedding = torch.nn.Sequential(OrderedDict([
("linear6", torch.nn.Linear(3072, 512))
]))
if self.loss == "aam":
self.after_speaker_embedding = torch.nn.Sequential(OrderedDict([
("arclinear", ArcLinear(512, int(self.speaker_number), margin=aam_margin, s=aam_s))
]))
self.after_speaker_embedding = ArcMarginProduct(512,
int(self.speaker_number),
s=64,
m=0.2,
easy_margin=True)
elif self.loss == "cce":
self.after_speaker_embedding = torch.nn.Sequential(OrderedDict([
("activation6", torch.nn.LeakyReLU(0.2)),
......@@ -559,17 +608,17 @@ class Xtractor(torch.nn.Module):
#x = torch.div(x, x_norm)
x = l2_norm(x)
if is_eval:
return x
if self.loss == "cce":
x = self.after_speaker_embedding(x)
if is_eval:
return self.after_speaker_embedding(x), x
else:
return self.after_speaker_embedding(x)
elif self.loss == "aam":
if not is_eval:
x = self.after_speaker_embedding(x,target=target)
x = self.after_speaker_embedding(l2_norm(x), target=target), l2_norm(x)
else:
x = self.after_speaker_embedding(x, target=None)
x = self.after_speaker_embedding(l2_norm(x), target=None), l2_norm(x)
return x
......@@ -836,6 +885,18 @@ def xtrain(speaker_number,
best_accuracy = 0.0
best_accuracy_epoch = 1
curr_patience = patience
val_acc, val_loss, val_eer, test_eer = compute_metrics(model,
validation_loader,
device,
[validation_set.__len__(), 512],
speaker_number,
model_archi)
logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Initial metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Initial metrics - Test EER = {test_eer * 100} %")
for epoch in range(1, epochs + 1):
# Process one epoch and return the current model
if curr_patience == 0:
......@@ -851,15 +912,22 @@ def xtrain(speaker_number,
tb_writer=writer)
# Add the cross validation here
accuracy, val_loss = cross_validation(model, validation_loader, device=device)
logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Cross validation accuracy = {accuracy} %")
val_acc, val_loss, val_eer, test_eer = compute_metrics(model,
validation_loader,
device,
[validation_set.__len__(), 512],
speaker_number,
model_archi)
logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Test EER = {test_eer * 100} %")
# Decrease learning rate according to the scheduler policy
scheduler.step(val_loss)
# remember best accuracy and save checkpoint
is_best = accuracy > best_accuracy
best_accuracy = max(accuracy, best_accuracy)
is_best = val_acc > best_accuracy
best_accuracy = max(val_acc, best_accuracy)
if type(model) is Xtractor:
save_checkpoint({
......@@ -923,7 +991,7 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device,
optimizer.zero_grad()
if loss_criteria == 'aam':
output = model(data, target=target)
output, _ = model(data, target=target)
else:
output = model(data, target=None)
......@@ -974,7 +1042,7 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device,
return model
def cross_validation(model, validation_loader, device):
def cross_validation(model, validation_loader, device, validation_shape):
"""
:param model:
......@@ -992,6 +1060,8 @@ def cross_validation(model, validation_loader, device):
accuracy = 0.0
loss = 0.0
criterion = torch.nn.CrossEntropyLoss()
embeddings = torch.zeros(validation_shape)
classes = torch.zeros([validation_shape[0]])
with torch.no_grad():
for batch_idx, (data, target) in enumerate(validation_loader):
batch_size = target.shape[0]
......@@ -999,16 +1069,30 @@ def cross_validation(model, validation_loader, device):
data = data.squeeze().to(device)
if loss_criteria == "aam":
output = model(data, target=target)
batch_predictions, batch_embeddings = model(data, target=None, is_eval=True)
else:
output = model(data, target=None)
batch_predictions, batch_embeddings = model(data, target=None, is_eval=True)
batch_embeddings = l2_norm(batch_embeddings)
accuracy += (torch.argmax(output.data, 1) == target).sum()
accuracy += (torch.argmax(batch_predictions.data, 1) == target).sum()
loss += criterion(output, target)
loss += criterion(batch_predictions, target)
embeddings[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0],:] = batch_embeddings.cpu().detach()
classes[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0]] = target.cpu().detach()
mask = ((torch.ger(classes.to(device).float() + 1,
(1 / (classes.to(device).float() + 1))) == 1).long() * 2 - 1).float().cpu().numpy()
scores = torch.mm(embeddings.to(device), embeddings.to(device).T).cpu().numpy()
scores = scores[numpy.tril_indices(scores.shape[0], -1)]
mask = mask[numpy.tril_indices(mask.shape[0], -1)]
negatives = scores[numpy.argwhere(mask == -1)][:, 0].astype(float)
positives = scores[numpy.argwhere(mask == 1)][:, 0].astype(float)
# Faster EER computation available here : https://github.com/gl3lan/fast_eer
pmiss, pfa = rocch(numpy.array(positives).astype(numpy.double), numpy.array(negatives).astype(numpy.double))
return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * batch_size), \
loss.cpu().numpy() / ((batch_idx + 1) * batch_size)
loss.cpu().numpy() / ((batch_idx + 1) * batch_size), rocch2eer(pmiss, pfa)
def extract_embeddings(idmap_name,
......@@ -1080,7 +1164,7 @@ def extract_embeddings(idmap_name,
for idx, (data, mod, seg, start, stop) in tqdm.tqdm(enumerate(dataloader)):
if data.shape[1] > 20000000:
data = data[...,:20000000]
vec = model(data.to(device), is_eval=True)
preds, vec = model(data.to(device), is_eval=True)
embeddings.stat1[idx, :] = vec.detach().cpu()
return embeddings
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment