Commit 524b10ee authored by Anthony Larcher's avatar Anthony Larcher
Browse files

merge

parent 8e62d9aa
......@@ -59,6 +59,7 @@ from ..bosaris import Ndx
from ..statserver import StatServer
from ..iv_scoring import cosine_scoring
from .sincnet import SincNet
from ..bosaris.detplot import rocch, rocch2eer
from .loss import SoftmaxAngularProto, ArcLinear
from .loss import l2_norm
from .loss import ArcMarginProduct
......@@ -240,17 +241,18 @@ def test_metrics(model,
num_thread=num_thread,
mixed_precision=mixed_precision)
scores = cosine_scoring(xv_stat,
xv_stat,
Ndx(ndx_test_filename),
wccn=None,
check_missing=True,
device=device)
tar, non = cosine_scoring(xv_stat,
xv_stat,
Ndx(ndx_test_filename),
wccn=None,
check_missing=True,
device=device
).get_tar_non(Key(key_test_filename))
k = Key(key_test_filename)
tar, non = scores.get_tar_non(k)
test_eer = eer(numpy.array(non).astype(numpy.double), numpy.array(tar).astype(numpy.double))
return test_eer
#test_eer = eer(numpy.array(non).astype(numpy.double), numpy.array(tar).astype(numpy.double))
pmiss, pfa = rocch(tar, non)
return rocch2eer(pmiss, pfa)
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
......@@ -1045,13 +1047,16 @@ def xtrain(speaker_number,
test_eer = 100.
classes = torch.LongTensor(validation_set.sessions['speaker_idx'].to_numpy())
classes = torch.ShortTensor(validation_set.sessions['speaker_idx'].to_numpy())
mask = classes.unsqueeze(1) == classes.unsqueeze(1).T
tar_indices = torch.tril(mask, -1).numpy()
non_indices = torch.tril(~mask, -1).numpy()
tar_non_ratio = numpy.sum(tar_indices)/numpy.sum(non_indices)
non_indices *= numpy.random.choice([False, True], size=non_indices.shape, p=[1-tar_non_ratio, tar_non_ratio])
#tar_indices *= numpy.random.choice([False, True], size=tar_indices.shape, p=[0.9, 0.1])
#non_indices *= numpy.random.choice([False, True], size=non_indices.shape, p=[0.9, 0.1])
local_device = "cpu" if classes.shape[0] > 3e4 else device
mask = ((torch.ger(classes.to(local_device).float() + 1,
(1 / (classes.to(local_device).float() + 1))) == 1).long() * 2 - 1).float().cpu()
mask = mask.numpy()
mask = mask[numpy.tril_indices(mask.shape[0], -1)]
logging.critical("val tar count : {:d}, non count : {:d}".format(numpy.sum(tar_indices), numpy.sum(non_indices)))
for epoch in range(1, epochs + 1):
# Process one epoch and return the current model
......@@ -1070,19 +1075,12 @@ def xtrain(speaker_number,
# Add the cross validation here
if math.fmod(epoch, 1) == 0:
val_acc, val_loss, val_eer = cross_validation(model, validation_loader, device, [validation_set.__len__(), embedding_size], mask, mixed_precision)
val_acc, val_loss, val_eer = cross_validation(model, validation_loader, device, [validation_set.__len__(), embedding_size], tar_indices, non_indices, mixed_precision)
logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
if compute_test_eer:
test_eer = test_metrics(model,
device,
idmap_test_filename=dataset_params["test_set"]["idmap_test_filename"],
ndx_test_filename=dataset_params["test_set"]["ndx_test_filename"],
key_test_filename=dataset_params["test_set"]["key_test_filename"],
data_root_name=dataset_params["test_set"]["data_root_name"],
num_thread=num_thread,
mixed_precision=mixed_precision)
test_eer = test_metrics(model, device, speaker_number, num_thread, mixed_precision)
#logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Reversed Test EER = {rev_eer * 100} %")
logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Test EER = {test_eer * 100} %")
# remember best accuracy and save checkpoint
......@@ -1230,7 +1228,7 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interva
return model
def cross_validation(model, validation_loader, device, validation_shape, mask, mixed_precision=False):
def cross_validation(model, validation_loader, device, validation_shape, tar_indices, non_indices, mixed_precision=False):
"""
:param model:
......@@ -1249,9 +1247,10 @@ def cross_validation(model, validation_loader, device, validation_shape, mask, m
loss = 0.0
criterion = torch.nn.CrossEntropyLoss()
embeddings = torch.zeros(validation_shape)
classes = torch.zeros([validation_shape[0]])
#classes = torch.zeros([validation_shape[0]])
cursor = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1)):
for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1, disable=None)):
target = target.squeeze()
target = target.to(device)
batch_size = target.shape[0]
......@@ -1267,18 +1266,22 @@ def cross_validation(model, validation_loader, device, validation_shape, mask, m
accuracy += (torch.argmax(batch_predictions.data, 1) == target).sum()
loss += criterion(batch_predictions, target)
embeddings[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0],:] = batch_embeddings.detach().cpu()
classes[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0]] = target.detach().cpu()
embeddings[cursor:cursor + batch_size,:] = batch_embeddings.detach().cpu()
#classes[cursor:cursor + batch_size] = target.detach().cpu()
cursor += batch_size
#print(classes.shape[0])
local_device = "cpu" if embeddings.shape[0] > 3e4 else device
scores = torch.mm(embeddings.to(local_device), embeddings.to(local_device).T).cpu().numpy()
scores = scores[numpy.tril_indices(scores.shape[0], -1)]
negatives = scores[numpy.argwhere(mask == -1)][:, 0].astype(float)
positives = scores[numpy.argwhere(mask == 1)][:, 0].astype(float)
embeddings = embeddings.to(local_device)
scores = torch.einsum('ij,kj', embeddings, embeddings).cpu().numpy()
negatives = scores[non_indices]
positives = scores[tar_indices]
# Faster EER computation available here : https://github.com/gl3lan/fast_eer
equal_error_rate = eer(negatives, positives)
#equal_error_rate = eer(negatives, positives)
pmiss, pfa = rocch(positives, negatives)
equal_error_rate = rocch2eer(pmiss, pfa)
return (100. * accuracy.cpu().numpy() / validation_shape[0],
loss.cpu().numpy() / ((batch_idx + 1) * batch_size),
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment