Commit 7bbe5324 authored by Anthony Larcher

Merge branch 'dev_al' of https://git-lium.univ-lemans.fr/Larcher/sidekit into dev_al

parents e7c23b4f 00496ab5
@@ -1021,6 +1021,14 @@ def xtrain(speaker_number,
test_eer = 100.
classes = torch.LongTensor(validation_set.sessions['speaker_idx'].to_numpy())
+ local_device = "cpu" if classes.shape[0] > 3e4 else device
+ mask = ((torch.ger(classes.to(local_device).float() + 1,
+                    (1 / (classes.to(local_device).float() + 1))) == 1).long() * 2 - 1).float().cpu()
+ mask = mask.numpy()
+ mask = mask[numpy.tril_indices(mask.shape[0], -1)]
for epoch in range(1, epochs + 1):
# Process one epoch and return the current model
if curr_patience == 0:
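For context, the added block builds a pairwise same/different-speaker mask with an outer-product trick: after shifting the labels by one, torch.ger(c, 1/c) equals 1 exactly where the two labels match, and the strictly lower triangle keeps one entry per unordered pair. The sketch below reproduces the idea on toy labels; it is illustrative and not code from this commit:

```python
import numpy
import torch

classes = torch.LongTensor([0, 0, 1, 2, 1])    # toy speaker indices
shifted = classes.float() + 1                  # +1 avoids division by zero
# outer product c_i / c_j equals 1 exactly where the two labels match
mask = ((torch.ger(shifted, 1 / shifted) == 1).long() * 2 - 1).float().numpy()
# keep one entry per unordered pair: +1 same speaker, -1 different speaker
pair_labels = mask[numpy.tril_indices(mask.shape[0], -1)]
print(pair_labels)   # +1 for the two same-speaker pairs (0,1) and (2,4), -1 for the rest
```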
@@ -1038,7 +1046,7 @@ def xtrain(speaker_number,
# Add the cross validation here
if math.fmod(epoch, 1) == 0:
- val_acc, val_loss, val_eer = cross_validation(model, validation_loader, device, [validation_set.__len__(), embedding_size], mixed_precision)
+ val_acc, val_loss, val_eer = cross_validation(model, validation_loader, device, [validation_set.__len__(), embedding_size], mask, mixed_precision)
logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
if compute_test_eer:
@@ -1183,7 +1191,7 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interva
return model
- def cross_validation(model, validation_loader, device, validation_shape, mixed_precision=False):
+ def cross_validation(model, validation_loader, device, validation_shape, mask, mixed_precision=False):
"""
:param model:
@@ -1230,7 +1238,6 @@ def cross_validation(model, validation_loader, device, validation_shape, mixed_p
local_device = "cpu" if embeddings.shape[0] > 3e4 else device
scores = torch.mm(embeddings.to(local_device), embeddings.to(local_device).T).cpu().numpy()
scores = scores[numpy.tril_indices(scores.shape[0], -1)]
- mask = mask[numpy.tril_indices(mask.shape[0], -1)]
negatives = scores[numpy.argwhere(mask == -1)][:, 0].astype(float)
positives = scores[numpy.argwhere(mask == 1)][:, 0].astype(float)
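For context, the lower-triangular scores and the ±1 mask give one same/different label per embedding pair, and the EER comes from sweeping a threshold over those two score populations. A rough standalone illustration in plain numpy, standing in for whatever sidekit computes internally (the score populations here are toy data):

```python
import numpy

def approximate_eer(positives, negatives):
    """Rough EER: sweep thresholds and take the point where FAR and FRR meet."""
    thresholds = numpy.sort(numpy.concatenate([positives, negatives]))
    far = numpy.array([(negatives >= t).mean() for t in thresholds])  # false acceptance rate
    frr = numpy.array([(positives < t).mean() for t in thresholds])   # false rejection rate
    idx = numpy.argmin(numpy.abs(far - frr))
    return (far[idx] + frr[idx]) / 2

rng = numpy.random.default_rng(0)
positives = rng.normal(1.0, 0.5, size=1000)   # toy same-speaker pair scores
negatives = rng.normal(0.0, 0.5, size=5000)   # toy different-speaker pair scores
print(f"approximate EER = {approximate_eer(positives, negatives) * 100:.2f} %")
```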
@@ -1586,11 +1593,16 @@ def extract_sliding_embedding(idmap_name,
desc='xvector extraction',
mininterval=1)):
with torch.cuda.amp.autocast(enabled=mixed_precision):
- vec = model(x=data.to(device), is_eval=True)
- embeddings.append(vec.detach().cpu())
- modelset += [mod,] * embeddings.shape[0]
- segset += [seg,] * embeddings.shape[0]
- starts += numpy.arange(start, start + embeddings.shape[0] * window_shift , window_shift)
+ data = data.squeeze()
+ tmp_data = torch.split(data, data.shape[0] // max(1, data.shape[0] // 100))
+ for td in tmp_data:
+     vec = model(x=td.to(device), is_eval=True)
+     embeddings.append(vec.detach().cpu())
+ modelset += [mod,] * data.shape[0]
+ segset += [seg,] * data.shape[0]
+ starts += [numpy.arange(start, start + data.shape[0] * window_shift, window_shift),]
+ # TODO: resume here
# Create the StatServer
embeddings = StatServer()
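For context, the rewritten extraction loop runs the sliding windows through the model in chunks of roughly 100 windows, so a long segment never hits the GPU in a single batch. A condensed sketch of that pattern; the helper name is illustrative, and the model(x=..., is_eval=True) call signature is taken from the diff:

```python
import torch

def sliding_window_embeddings(model, data, device, chunk_size=100):
    """Forward the sliding windows through `model` in chunks and stack the results."""
    data = data.squeeze()
    out = []
    with torch.no_grad():
        for chunk in torch.split(data, chunk_size):   # at most `chunk_size` windows per pass
            out.append(model(x=chunk.to(device), is_eval=True).detach().cpu())
    return torch.cat(out, dim=0)                      # one embedding per window
```

Chunking bounds peak GPU memory while the concatenated output matches what a single unchunked forward pass would produce.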