Commit 80656610 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

update extract_embeddings

parent 5491cb86
...@@ -360,7 +360,7 @@ class IdMapSet(Dataset): ...@@ -360,7 +360,7 @@ class IdMapSet(Dataset):
transform_number=1, transform_number=1,
sliding_window=False, sliding_window=False,
window_len=3., window_len=3.,
window_shift=1., window_shift=1.5,
sample_rate=16000, sample_rate=16000,
min_duration=0.165 min_duration=0.165
): ):
...@@ -378,18 +378,19 @@ class IdMapSet(Dataset): ...@@ -378,18 +378,19 @@ class IdMapSet(Dataset):
self.file_extension = file_extension self.file_extension = file_extension
self.len = self.idmap.leftids.shape[0] self.len = self.idmap.leftids.shape[0]
self.transformation = transform_pipeline self.transformation = transform_pipeline
self.min_sample_nb = min_duration * sample_rate self.min_duration = min_duration
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.sliding_window = sliding_window self.sliding_window = sliding_window
self.window_len = window_len self.window_len = int(window_len * self.sample_rate)
self.window_shift = window_shift self.window_shift = int(window_shift * self.sample_rate)
self.transform_number = transform_number self.transform_number = transform_number
self.noise_df = None self.noise_df = None
if "add_noise" in self.transformation: if "add_noise" in self.transformation:
# Load the noise dataset, filter according to the duration # Load the noise dataset, filter according to the duration
noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"]) noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
self.noise_df = noise_df.set_index(noise_df.type) tmp_df = noise_df.loc[noise_df['duration'] > self.duration]
self.noise_df = tmp_df['file_id'].tolist()
self.rir_df = None self.rir_df = None
if "add_reverb" in self.transformation: if "add_reverb" in self.transformation:
...@@ -403,18 +404,19 @@ class IdMapSet(Dataset): ...@@ -403,18 +404,19 @@ class IdMapSet(Dataset):
:param index: :param index:
:return: :return:
""" """
# Read start and stop and convert to time in seconds
if self.idmap.start[index] is None: if self.idmap.start[index] is None:
start = 0 start = 0
else: else:
start = int(self.idmap.start[index]) * 160 start = int(self.idmap.start[index] * 0.01 * self.sample_rate)
if self.idmap.stop[index] is None: if self.idmap.stop[index] is None:
speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}") speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
duration = int(speech.shape[1] - start) duration = int(speech.shape[1] - start)
else: else:
duration = int(self.idmap.stop[index]) * 160 - start duration = int(self.idmap.stop[index] * 0.01) * self.sample_rate - start
# add this in case the segment is too short # add this in case the segment is too short
if duration <= self.min_sample_nb: if duration <= self.self.min_duration * self.sample_rate:
middle = start + duration // 2 middle = start + duration // 2
start = max(0, int(middle - (self.min_sample_nb / 2))) start = max(0, int(middle - (self.min_sample_nb / 2)))
duration = int(self.min_sample_nb) duration = int(self.min_sample_nb)
...@@ -426,7 +428,7 @@ class IdMapSet(Dataset): ...@@ -426,7 +428,7 @@ class IdMapSet(Dataset):
speech += 10e-6 * torch.randn(speech.shape) speech += 10e-6 * torch.randn(speech.shape)
if self.sliding_window: if self.sliding_window:
speech = speech.squeeze().unfold(0,self.window_len,self.window_shift) speech = speech.squeeze().unfold(0, self.window_len, self.window_shift)
if len(self.transformation.keys()) > 0: if len(self.transformation.keys()) > 0:
speech = data_augmentation(speech, speech = data_augmentation(speech,
......
...@@ -1449,6 +1449,14 @@ def train_epoch(model, ...@@ -1449,6 +1449,14 @@ def train_epoch(model,
else: else:
output, _ = model(data, target=None) output, _ = model(data, target=None)
loss = criterion(output, target) loss = criterion(output, target)
scaler.scale(loss).backward()
if clipping:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
scaler.step(optimizer)
scaler.update()
else: else:
if loss_criteria == 'aam': if loss_criteria == 'aam':
output, _ = model(data, target=target) output, _ = model(data, target=target)
...@@ -1461,18 +1469,9 @@ def train_epoch(model, ...@@ -1461,18 +1469,9 @@ def train_epoch(model,
output, _ = model(data, target=None) output, _ = model(data, target=None)
loss = criterion(output, target) loss = criterion(output, target)
#if not torch.isnan(loss):
if True:
if scaler is not None:
scaler.scale(loss).backward()
if clipping:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
scaler.step(optimizer)
scaler.update()
else:
loss.backward() loss.backward()
optimizer.step() optimizer.step()
running_loss += loss.item() running_loss += loss.item()
accuracy += (torch.argmax(output.data, 1) == target).sum() accuracy += (torch.argmax(output.data, 1) == target).sum()
...@@ -1489,18 +1488,6 @@ def train_epoch(model, ...@@ -1489,18 +1488,6 @@ def train_epoch(model,
loss.item(), loss.item(),
100.0 * accuracy.item() / ((batch_idx + 1) * batch_size))) 100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))
#else:
# save_checkpoint({
# 'epoch': training_monitor.current_epoch,
# 'model_state_dict': model.state_dict(),
# 'optimizer_state_dict': optimizer.state_dict(),
# 'accuracy': 0.0,
# 'scheduler': 0.0
# }, False, filename="model_loss_NAN.pt", best_filename='toto.pt')
# with open("batch_loss_NAN.pkl", "wb") as fh:
# pickle.dump(data.cpu(), fh)
# import sys
# sys.exit()
running_loss = 0.0 running_loss = 0.0
if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
...@@ -1570,12 +1557,10 @@ def extract_embeddings(idmap_name, ...@@ -1570,12 +1557,10 @@ def extract_embeddings(idmap_name,
model_filename, model_filename,
data_root_name, data_root_name,
device, device,
loss,
file_extension="wav", file_extension="wav",
transform_pipeline="", transform_pipeline="",
frame_shift=0.01, frame_shift=1.5,
frame_duration=0.025, frame_duration=3.,
extract_after_pooling=False,
num_thread=1, num_thread=1,
mixed_precision=False): mixed_precision=False):
""" """
...@@ -1622,7 +1607,7 @@ def extract_embeddings(idmap_name, ...@@ -1622,7 +1607,7 @@ def extract_embeddings(idmap_name,
data_path=data_root_name, data_path=data_root_name,
file_extension=file_extension, file_extension=file_extension,
transform_pipeline=transform_pipeline, transform_pipeline=transform_pipeline,
min_duration=(model_cs + 2) * frame_shift * 2 min_duration=1.5
) )
dataloader = DataLoader(dataset, dataloader = DataLoader(dataset,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment