Commit 84493543 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

debug

parent f7052bd4
......@@ -693,15 +693,15 @@ class Xtractor(torch.nn.Module):
if self.loss == "cce":
if is_eval:
return self.after_speaker_embedding(x), x
return x
else:
return self.after_speaker_embedding(x)
return self.after_speaker_embedding(x), x
elif self.loss == "aam":
if not is_eval:
x = self.after_speaker_embedding(torch.nn.functional.normalize(x, dim=1), target=target), torch.nn.functional.normalize(x, dim=1)
if is_eval:
x = torch.nn.functional.normalize(x, dim=1)
else:
x = self.after_speaker_embedding(torch.nn.functional.normalize(x, dim=1), target=None), torch.nn.functional.normalize(x, dim=1)
x = self.after_speaker_embedding(torch.nn.functional.normalize(x, dim=1), target=target), torch.nn.functional.normalize(x, dim=1)
return x
......@@ -1160,9 +1160,9 @@ def cross_validation(model, validation_loader, device, validation_shape, mixed_p
data = data.squeeze().to(device)
with torch.cuda.amp.autocast(enabled=mixed_precision):
if loss_criteria == "aam":
batch_predictions, batch_embeddings = model(data, target=None, is_eval=True)
batch_predictions, batch_embeddings = model(data, target=target, is_eval=False)
else:
batch_predictions, batch_embeddings = model(data, target=None, is_eval=True)
batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
batch_embeddings = l2_norm(batch_embeddings)
accuracy += (torch.argmax(batch_predictions.data, 1) == target).sum()
......@@ -1683,7 +1683,7 @@ def extract_embeddings(idmap_name,
if data.shape[1] > 20000000:
data = data[...,:20000000]
with torch.cuda.amp.autocast(enabled=mixed_precision):
preds, vec = model(data.to(device), is_eval=True)
vec = model(data.to(device), is_eval=True)
embeddings.stat1[idx, :] = vec.detach().cpu()
return embeddings
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment