Commit f7052bd4 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

debug

parent 0ae9bb73
......@@ -381,9 +381,9 @@ class Xtractor(torch.nn.Module):
("linear8", torch.nn.Linear(512, int(self.speaker_number)))
]))
self.sequence_network_weight_decay = 0.0
self.before_speaker_embedding_weight_decay = 0.0
self.after_speaker_embedding_weight_decay = 0.0
self.sequence_network_weight_decay = 0.0002
self.before_speaker_embedding_weight_decay = 0.002
self.after_speaker_embedding_weight_decay = 0.002
self.embedding_size = 512
elif model_archi == "resnet34":
......@@ -693,15 +693,15 @@ class Xtractor(torch.nn.Module):
if self.loss == "cce":
if is_eval:
return x
else:
return self.after_speaker_embedding(x), x
else:
return self.after_speaker_embedding(x)
elif self.loss == "aam":
if is_eval:
x = torch.nn.functional.normalize(x, dim=1)
else:
if not is_eval:
x = self.after_speaker_embedding(torch.nn.functional.normalize(x, dim=1), target=target), torch.nn.functional.normalize(x, dim=1)
else:
x = self.after_speaker_embedding(torch.nn.functional.normalize(x, dim=1), target=None), torch.nn.functional.normalize(x, dim=1)
return x
......@@ -1160,9 +1160,9 @@ def cross_validation(model, validation_loader, device, validation_shape, mixed_p
data = data.squeeze().to(device)
with torch.cuda.amp.autocast(enabled=mixed_precision):
if loss_criteria == "aam":
batch_predictions, batch_embeddings = model(data, target=target, is_eval=False)
batch_predictions, batch_embeddings = model(data, target=None, is_eval=True)
else:
batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
batch_predictions, batch_embeddings = model(data, target=None, is_eval=True)
batch_embeddings = l2_norm(batch_embeddings)
accuracy += (torch.argmax(batch_predictions.data, 1) == target).sum()
......@@ -1683,7 +1683,7 @@ def extract_embeddings(idmap_name,
if data.shape[1] > 20000000:
data = data[...,:20000000]
with torch.cuda.amp.autocast(enabled=mixed_precision):
vec = model(data.to(device), is_eval=True)
preds, vec = model(data.to(device), is_eval=True)
embeddings.stat1[idx, :] = vec.detach().cpu()
return embeddings
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment