Commit 7c0073ee authored by Gaël Le Lan

bugfix validation

parent 3a2b9935
@@ -208,7 +208,7 @@ class SideSet(Dataset):
                 chunk_nb = len(possible_starts)
             else:
                 chunk_nb = min(len(possible_starts), chunk_per_segment)
-            starts = numpy.random.permutation(possible_starts)[:chunk_nb] / self.sample_rate
+            starts = numpy.random.permutation(possible_starts)[:chunk_nb]

             # Once we know how many segments are selected, create the other fields to fill the DataFrame
             for ii in range(chunk_nb):
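This hunk drops the division by `self.sample_rate`, so `starts` stays in samples rather than seconds. A minimal sketch of the unit mismatch, with hypothetical values, assuming `possible_starts` holds sample indices that downstream code consumes as sample offsets:

```python
import numpy

sample_rate = 16000
possible_starts = numpy.array([0, 8000, 16000])   # sample indices
chunk_nb = 2

# dividing converts to seconds (0.0, 0.5, 1.0); any later code that
# slices audio with these values as sample offsets would read from
# near-zero positions, hence the division is removed in this commit
starts_seconds = possible_starts / sample_rate
starts_samples = numpy.random.permutation(possible_starts)[:chunk_nb]
```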
@@ -256,13 +256,16 @@ class SideSet(Dataset):
         # TODO is this required ?
         nfo = soundfile.info(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}")
         original_start = int(current_session['start'])
-        lowest_shift = self.overlap/2
-        highest_shift = self.overlap/2
-        if original_start < (current_session['file_start']*self.sample_rate + self.sample_number/2):
-            lowest_shift = int(original_start - current_session['file_start']*self.sample_rate)
-        if original_start + self.sample_number > (current_session['file_start'] + current_session['file_duration'])*self.sample_rate - self.sample_number/2:
-            highest_shift = int((current_session['file_start'] + current_session['file_duration'])*self.sample_rate - (original_start + self.sample_number))
-        start_frame = original_start + int(random.uniform(-lowest_shift, highest_shift))
+        if self.overlap > 0:
+            lowest_shift = self.overlap/2
+            highest_shift = self.overlap/2
+            if original_start < (current_session['file_start']*self.sample_rate + self.sample_number/2):
+                lowest_shift = int(original_start - current_session['file_start']*self.sample_rate)
+            if original_start + self.sample_number > (current_session['file_start'] + current_session['file_duration'])*self.sample_rate - self.sample_number/2:
+                highest_shift = int((current_session['file_start'] + current_session['file_duration'])*self.sample_rate - (original_start + self.sample_number))
+            start_frame = original_start + int(random.uniform(-lowest_shift, highest_shift))
+        else:
+            start_frame = original_start
         if start_frame + self.sample_number >= nfo.frames:
             start_frame = numpy.min(nfo.frames - self.sample_number - 1)
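Wrapping the jitter in `if self.overlap > 0:` means a chunk keeps its scheduled start when no overlap was requested; previously a chunk sitting near a file boundary could still be shifted, because the boundary branches overwrote the zero shifts. A rough sketch of the clamped jitter, with hypothetical values in samples:

```python
import random

sample_rate = 16000
sample_number = 32000          # 2 s chunk
overlap = 8000                 # allow up to +/- 0.25 s of jitter
file_start = 0                 # seconds
original_start = 4000          # chunk begins 0.25 s into the file

if overlap > 0:
    lowest_shift = highest_shift = overlap / 2
    # near the file head: never shift left past the first sample
    if original_start < file_start * sample_rate + sample_number / 2:
        lowest_shift = int(original_start - file_start * sample_rate)  # 4000
    start_frame = original_start + int(random.uniform(-lowest_shift, highest_shift))
else:
    start_frame = original_start   # no jitter at all without overlap
```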
@@ -1105,7 +1105,7 @@ def xtrain(speaker_number,
     test_eer = 100.
-    classes = torch.ByteTensor(validation_set.sessions['speaker_idx'].to_numpy())
+    classes = torch.ShortTensor(validation_set.sessions['speaker_idx'].to_numpy())
     mask = classes.unsqueeze(1) == classes.unsqueeze(1).T
     tar_indices = torch.tril(mask, -1).numpy()
     non_indices = torch.tril(~mask, -1).numpy()
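`ByteTensor` is unsigned 8-bit, so any `speaker_idx` above 255 silently wraps around and two distinct speakers can be marked as a target pair in the mask; `ShortTensor` (signed 16-bit) is safe up to 32767 classes. A minimal demonstration of the collision, with made-up indices:

```python
import torch

speaker_idx = torch.tensor([0, 1, 256, 257])   # more than 256 speakers in play
as_bytes = speaker_idx.to(torch.uint8)         # wraps to [0, 1, 0, 1]
as_shorts = speaker_idx.to(torch.int16)        # values preserved

mask8 = as_bytes.unsqueeze(1) == as_bytes.unsqueeze(1).T
mask16 = as_shorts.unsqueeze(1) == as_shorts.unsqueeze(1).T
print(torch.equal(mask8, mask16))   # False: 0/256 and 1/257 collide as bytes
```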
@@ -1302,6 +1302,7 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
     criterion = torch.nn.CrossEntropyLoss()
     embeddings = torch.zeros(validation_shape)
     #classes = torch.zeros([validation_shape[0]])
+    cursor = 0
     with torch.no_grad():
         for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1, disable=None)):
             target = target.squeeze()
@@ -1318,8 +1319,9 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
             batch_embeddings = l2_norm(batch_embeddings)
             accuracy += (torch.argmax(batch_predictions.data, 1) == target).sum()
             loss += criterion(batch_predictions, target)
-            embeddings[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0],:] = batch_embeddings.detach().cpu()
-            #classes[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0]] = target.detach().cpu()
+            embeddings[cursor:cursor + batch_size,:] = batch_embeddings.detach().cpu()
+            #classes[cursor:cursor + batch_size] = target.detach().cpu()
+            cursor += batch_size
     #print(classes.shape[0])
     local_device = "cpu" if embeddings.shape[0] > 3e4 else device