Commit 7c0073ee authored by Gaël Le Lan's avatar Gaël Le Lan
Browse files

Bug fix: validation

parent 3a2b9935
......@@ -208,7 +208,7 @@ class SideSet(Dataset):
chunk_nb = len(possible_starts)
else:
chunk_nb = min(len(possible_starts), chunk_per_segment)
starts = numpy.random.permutation(possible_starts)[:chunk_nb] / self.sample_rate
starts = numpy.random.permutation(possible_starts)[:chunk_nb]
# Once we know how many segments are selected, create the other fields to fill the DataFrame
for ii in range(chunk_nb):
......@@ -256,13 +256,16 @@ class SideSet(Dataset):
# TODO is this required ?
nfo = soundfile.info(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}")
original_start = int(current_session['start'])
lowest_shift = self.overlap/2
highest_shift = self.overlap/2
if original_start < (current_session['file_start']*self.sample_rate + self.sample_number/2):
lowest_shift = int(original_start - current_session['file_start']*self.sample_rate)
if original_start + self.sample_number > (current_session['file_start'] + current_session['file_duration'])*self.sample_rate - self.sample_number/2:
highest_shift = int((current_session['file_start'] + current_session['file_duration'])*self.sample_rate - (original_start + self.sample_number))
start_frame = original_start + int(random.uniform(-lowest_shift, highest_shift))
if self.overlap > 0:
lowest_shift = self.overlap/2
highest_shift = self.overlap/2
if original_start < (current_session['file_start']*self.sample_rate + self.sample_number/2):
lowest_shift = int(original_start - current_session['file_start']*self.sample_rate)
if original_start + self.sample_number > (current_session['file_start'] + current_session['file_duration'])*self.sample_rate - self.sample_number/2:
highest_shift = int((current_session['file_start'] + current_session['file_duration'])*self.sample_rate - (original_start + self.sample_number))
start_frame = original_start + int(random.uniform(-lowest_shift, highest_shift))
else:
start_frame = original_start
if start_frame + self.sample_number >= nfo.frames:
start_frame = numpy.min(nfo.frames - self.sample_number - 1)
......
......@@ -1105,7 +1105,7 @@ def xtrain(speaker_number,
test_eer = 100.
classes = torch.ByteTensor(validation_set.sessions['speaker_idx'].to_numpy())
classes = torch.ShortTensor(validation_set.sessions['speaker_idx'].to_numpy())
mask = classes.unsqueeze(1) == classes.unsqueeze(1).T
tar_indices = torch.tril(mask, -1).numpy()
non_indices = torch.tril(~mask, -1).numpy()
......@@ -1302,6 +1302,7 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
criterion = torch.nn.CrossEntropyLoss()
embeddings = torch.zeros(validation_shape)
#classes = torch.zeros([validation_shape[0]])
cursor = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1, disable=None)):
target = target.squeeze()
......@@ -1318,8 +1319,9 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
batch_embeddings = l2_norm(batch_embeddings)
accuracy += (torch.argmax(batch_predictions.data, 1) == target).sum()
loss += criterion(batch_predictions, target)
embeddings[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0],:] = batch_embeddings.detach().cpu()
#classes[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0]] = target.detach().cpu()
embeddings[cursor:cursor + batch_size,:] = batch_embeddings.detach().cpu()
#classes[cursor:cursor + batch_size] = target.detach().cpu()
cursor += batch_size
#print(classes.shape[0])
local_device = "cpu" if embeddings.shape[0] > 3e4 else device
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment