Commit d8ac6322 authored by Anthony Larcher
Browse files

debug

parent 7bbe5324
......@@ -293,8 +293,8 @@ class IdMapSet(Dataset):
idmap_name,
data_path,
file_extension,
transform_pipeline={},
sliding_window=True,
transform_pipeline="",
sliding_window=False,
window_len=24000,
window_shift=8000,
sample_rate=16000,
......@@ -358,9 +358,9 @@ class IdMapSet(Dataset):
start = max(0, int(middle - (self.min_sample_nb / 2)))
duration = self.min_sample_nb
speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
frame_offset=start,
num_frames=duration)
speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
frame_offset=start,
num_frames=duration)
speech += 10e-6 * torch.randn(speech.shape)
......
......@@ -226,7 +226,7 @@ def test_metrics(model,
key_test_filename = 'h5f/key_test.h5'
data_root_name='/lium/scratch/larcher/voxceleb1/test/wav'
transform_pipeline = dict()
transform_pipeline = ""
xv_stat = extract_embeddings(idmap_name=idmap_test_filename,
model_filename=model,
......@@ -837,9 +837,6 @@ def xtrain(speaker_number,
new_model_dict.update(pretrained_dict)
model.load_state_dict(new_model_dict)
print("Modifiy margin: 0.4")
model.after_speaker_embedding.m = 0.4
# Freeze required layers
for name, param in model.named_parameters():
......@@ -1201,7 +1198,6 @@ def cross_validation(model, validation_loader, device, validation_shape, mask, m
:return:
"""
model.eval()
print("In cross validation")
if isinstance(model, Xtractor):
loss_criteria = model.loss
else:
......@@ -1214,9 +1210,7 @@ def cross_validation(model, validation_loader, device, validation_shape, mask, m
classes = torch.zeros([validation_shape[0]])
with torch.no_grad():
#for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1)):
print("In cross validation 2")
for batch_idx, (data, target) in enumerate(validation_loader):
print("In cross validation 3")
batch_size = target.shape[0]
target = target.squeeze().to(device)
data = data.squeeze().to(device)
......@@ -1254,7 +1248,7 @@ def extract_embeddings(idmap_name,
data_root_name,
device,
file_extension="wav",
transform_pipeline={},
transform_pipeline="",
frame_shift=0.01,
frame_duration=0.025,
extract_after_pooling=False,
......@@ -1306,7 +1300,6 @@ def extract_embeddings(idmap_name,
data_path=data_root_name,
file_extension=file_extension,
transform_pipeline=transform_pipeline,
frame_rate=int(1. / frame_shift),
min_duration=(model_cs + 2) * frame_shift * 2
)
......@@ -1602,7 +1595,7 @@ def extract_sliding_embedding(idmap_name,
segset += [seg,] * data.shape[0]
starts += [numpy.arange(start, start + embeddings.shape[0] * window_shift , window_shift),]
REPRENDRE ICI
#REPRENDRE ICI
# Create the StatServer
embeddings = StatServer()
......
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment