Commit e7c23b4f authored by Anthony Larcher's avatar Anthony Larcher
Browse files

debug

parent 03d0e5e1
......@@ -32,6 +32,7 @@ import pandas
import random
import torch
import torchaudio
torchaudio.set_audio_backend("sox_io")
import tqdm
import soundfile
import yaml
......@@ -347,7 +348,7 @@ class IdMapSet(Dataset):
if self.idmap.stop[index] is None:
speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
duration = len(speech) - start
duration = speech.shape[1] - start
else:
start = int(self.idmap.start[index])
duration = int(self.idmap.stop[index]) - start
......
......@@ -1193,7 +1193,7 @@ def cross_validation(model, validation_loader, device, validation_shape, mixed_p
:return:
"""
model.eval()
print("In cross validation")
if isinstance(model, Xtractor):
loss_criteria = model.loss
else:
......@@ -1205,7 +1205,10 @@ def cross_validation(model, validation_loader, device, validation_shape, mixed_p
embeddings = torch.zeros(validation_shape)
classes = torch.zeros([validation_shape[0]])
with torch.no_grad():
for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1)):
#for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1)):
print("In cross validation 2")
for batch_idx, (data, target) in enumerate(validation_loader):
print("In cross validation 3")
batch_size = target.shape[0]
target = target.squeeze().to(device)
data = data.squeeze().to(device)
......@@ -1364,7 +1367,7 @@ def extract_embeddings_per_speaker(idmap_name,
model_archi = checkpoint["model_archi"]
model = Xtractor(checkpoint["speaker_number"], model_archi=model_archi)
model = Xtractor(checkpoint["speaker_number"], model_archi=model_archi, loss="aam")
model.load_state_dict(checkpoint["model_state_dict"])
else:
model = model_filename
......@@ -1429,7 +1432,7 @@ def extract_sliding_embedding(idmap_name,
file_extension="wav",
transform_pipeline=None,
num_thread=1):
"""
:param idmap_name:
:param window_length:
......@@ -1443,7 +1446,7 @@ def extract_sliding_embedding(idmap_name,
:param file_extension:
:param transform_pipeline:
:return:
"""
# From the original IdMap, create the new one to extract x-vectors
if not isinstance(idmap_name, IdMap):
input_idmap = IdMap(idmap_name)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment