Commit 4283067f authored by Anthony Larcher's avatar Anthony Larcher
Browse files

debug

parent ee92acf2
......@@ -232,6 +232,10 @@ class SideSet(Dataset):
self.transform["add_noise"] = self.transformation["add_noise"]
if "add_reverb" in transforms:
self.transform["add_reverb"] = self.transformation["add_reverb"]
if "codec" in transforms:
self.transform["codec"] = []
if "phone_filtering" in transforms:
self.transform["phone_filtering"] = []
self.noise_df = None
if "add_noise" in self.transform:
......
......@@ -1433,7 +1433,6 @@ def xtrain(speaker_number,
pretrained_dict = checkpoint["model_state_dict"]
for part in reset_parts:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith(part)}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith(part)}
new_model_dict = model.state_dict()
new_model_dict.update(pretrained_dict)
......@@ -1936,17 +1935,10 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
batch_size = target.shape[0]
data = data.squeeze().to(device)
with torch.cuda.amp.autocast(enabled=mixed_precision):
if loss_criteria == 'aam':
batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
elif loss_criteria == 'aps':
batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
else:
batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
batch_embeddings = l2_norm(batch_embeddings)
batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
accuracy += (torch.argmax(batch_predictions.data, 1) == target).sum()
loss += criterion(batch_predictions, target)
embeddings[cursor:cursor + batch_size,:] = batch_embeddings.detach().cpu()
#classes[cursor:cursor + batch_size] = target.detach().cpu()
cursor += batch_size
local_device = "cpu" if embeddings.shape[0] > 3e4 else device
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment