Commit 80656610 authored by Anthony Larcher

update extract_embeddings

parent 5491cb86
@@ -360,7 +360,7 @@ class IdMapSet(Dataset):
                  transform_number=1,
                  sliding_window=False,
                  window_len=3.,
-                 window_shift=1.,
+                 window_shift=1.5,
                  sample_rate=16000,
                  min_duration=0.165
                  ):
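The hunk above only changes the default window shift from 1 s to 1.5 s. A minimal instantiation sketch, assuming the sidekit.nnet.xsets import path and the leading idmap_name/data_path arguments (the full signature is not visible in this diff); file paths are placeholders:

from sidekit.nnet.xsets import IdMapSet   # assumed module path

dataset = IdMapSet(idmap_name="test_idmap.h5",   # hypothetical IdMap file
                   data_path="data/wav",         # hypothetical audio root
                   file_extension="wav",
                   sliding_window=True,
                   window_len=3.,                # window length in seconds
                   window_shift=1.5,             # new default: 1.5 s instead of 1 s
                   sample_rate=16000,
                   min_duration=0.165)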
@@ -378,18 +378,19 @@ class IdMapSet(Dataset):
         self.file_extension = file_extension
         self.len = self.idmap.leftids.shape[0]
         self.transformation = transform_pipeline
-        self.min_sample_nb = min_duration * sample_rate
+        self.min_duration = min_duration
         self.sample_rate = sample_rate
         self.sliding_window = sliding_window
-        self.window_len = window_len
-        self.window_shift = window_shift
+        self.window_len = int(window_len * self.sample_rate)
+        self.window_shift = int(window_shift * self.sample_rate)
         self.transform_number = transform_number
         self.noise_df = None
         if "add_noise" in self.transformation:
             # Load the noise dataset, filter according to the duration
             noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
-            self.noise_df = noise_df.set_index(noise_df.type)
+            tmp_df = noise_df.loc[noise_df['duration'] > self.min_duration]
+            self.noise_df = tmp_df['file_id'].tolist()
         self.rir_df = None
         if "add_reverb" in self.transformation:
@@ -403,18 +404,19 @@ class IdMapSet(Dataset):
         :param index:
         :return:
         """
+        # Read start and stop and convert them from frames to samples
         if self.idmap.start[index] is None:
             start = 0
         else:
-            start = int(self.idmap.start[index]) * 160
+            start = int(self.idmap.start[index] * 0.01 * self.sample_rate)
         if self.idmap.stop[index] is None:
             speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
             duration = int(speech.shape[1] - start)
         else:
-            duration = int(self.idmap.stop[index]) * 160 - start
+            duration = int(self.idmap.stop[index] * 0.01 * self.sample_rate) - start
         # pad the segment in case it is too short
-        if duration <= self.min_sample_nb:
+        if duration <= self.min_duration * self.sample_rate:
             middle = start + duration // 2
-            start = max(0, int(middle - (self.min_sample_nb / 2)))
-            duration = int(self.min_sample_nb)
+            start = max(0, int(middle - (self.min_duration * self.sample_rate / 2)))
+            duration = int(self.min_duration * self.sample_rate)
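The old "* 160" factor hard-wired 10 ms frames at 16 kHz; the new expression derives the same conversion from self.sample_rate, assuming idmap.start/idmap.stop are expressed in 10 ms frames. A worked example:

sample_rate = 16000
start_frames = 500                                # 500 frames x 10 ms = 5.0 s
start = int(start_frames * 0.01 * sample_rate)    # 80000 samples, identical to 500 * 160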
@@ -426,7 +428,7 @@ class IdMapSet(Dataset):
             speech += 10e-6 * torch.randn(speech.shape)
         if self.sliding_window:
-            speech = speech.squeeze().unfold(0,self.window_len,self.window_shift)
+            speech = speech.squeeze().unfold(0, self.window_len, self.window_shift)
         if len(self.transformation.keys()) > 0:
             speech = data_augmentation(speech,
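unfold slices the waveform into overlapping windows along dimension 0, producing (n_samples - window_len) // window_shift + 1 windows. A minimal sketch with the sample-domain sizes computed in __init__:

import torch

signal = torch.randn(16000 * 10)            # 10 s of 16 kHz mono audio
windows = signal.unfold(0, 48000, 24000)    # 3 s windows, 1.5 s shift
print(windows.shape)                        # torch.Size([5, 48000]): (160000 - 48000) // 24000 + 1 = 5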
@@ -1449,6 +1449,14 @@ def train_epoch(model,
             else:
                 output, _ = model(data, target=None)
             loss = criterion(output, target)
+            scaler.scale(loss).backward()
+            if clipping:
+                scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            if loss_criteria == 'aam':
                 output, _ = model(data, target=target)
@@ -1461,46 +1469,25 @@ def train_epoch(model,
                 output, _ = model(data, target=None)
             loss = criterion(output, target)
-            #if not torch.isnan(loss):
-            if True:
-                if scaler is not None:
-                    scaler.scale(loss).backward()
-                    if clipping:
-                        scaler.unscale_(optimizer)
-                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
-                    scaler.step(optimizer)
-                    scaler.update()
-                else:
-                    loss.backward()
-                    optimizer.step()
-                running_loss += loss.item()
-                accuracy += (torch.argmax(output.data, 1) == target).sum()
-                if math.fmod(batch_idx, training_opts["log_interval"]) == 0:
-                    batch_size = target.shape[0]
-                    training_monitor.update(training_loss=loss.item(),
-                                            training_acc=100.0 * accuracy.item() / ((batch_idx + 1) * batch_size))
-                    training_monitor.logger.info('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
-                        training_monitor.current_epoch,
-                        batch_idx + 1,
-                        training_loader.__len__(),
-                        100. * batch_idx / training_loader.__len__(),
-                        loss.item(),
-                        100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))
-            #else:
-            #    save_checkpoint({
-            #        'epoch': training_monitor.current_epoch,
-            #        'model_state_dict': model.state_dict(),
-            #        'optimizer_state_dict': optimizer.state_dict(),
-            #        'accuracy': 0.0,
-            #        'scheduler': 0.0
-            #    }, False, filename="model_loss_NAN.pt", best_filename='toto.pt')
-            #    with open("batch_loss_NAN.pkl", "wb") as fh:
-            #        pickle.dump(data.cpu(), fh)
-            #    import sys
-            #    sys.exit()
+            loss.backward()
+            optimizer.step()
+        running_loss += loss.item()
+        accuracy += (torch.argmax(output.data, 1) == target).sum()
+        if math.fmod(batch_idx, training_opts["log_interval"]) == 0:
+            batch_size = target.shape[0]
+            training_monitor.update(training_loss=loss.item(),
+                                    training_acc=100.0 * accuracy.item() / ((batch_idx + 1) * batch_size))
+            training_monitor.logger.info('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
+                training_monitor.current_epoch,
+                batch_idx + 1,
+                training_loader.__len__(),
+                100. * batch_idx / training_loader.__len__(),
+                loss.item(),
+                100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))
             running_loss = 0.0
         if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
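The restructured loop moves the GradScaler calls next to the loss computation instead of burying them under the old "if True:" guard, and drops the commented-out NaN checkpointing. A self-contained sketch of the same mixed-precision pattern, with a toy model and batch standing in for the ones in this file:

import torch

model = torch.nn.Linear(10, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
clipping = True

data, target = torch.randn(8, 10), torch.randint(0, 4, (8,))
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
    loss = criterion(model(data), target)
scaler.scale(loss).backward()
if clipping:
    scaler.unscale_(optimizer)               # unscale before clipping the raw gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
scaler.step(optimizer)                       # skips the update if gradients overflowed
scaler.update()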
@@ -1570,12 +1557,10 @@ def extract_embeddings(idmap_name,
                        model_filename,
                        data_root_name,
                        device,
-                       loss,
                        file_extension="wav",
                        transform_pipeline="",
-                       frame_shift=0.01,
-                       frame_duration=0.025,
-                       extract_after_pooling=False,
+                       frame_shift=1.5,
+                       frame_duration=3.,
                        num_thread=1,
                        mixed_precision=False):
     """
@@ -1622,7 +1607,7 @@ def extract_embeddings(idmap_name,
                        data_path=data_root_name,
                        file_extension=file_extension,
                        transform_pipeline=transform_pipeline,
-                       min_duration=(model_cs + 2) * frame_shift * 2
+                       min_duration=1.5
                        )
     dataloader = DataLoader(dataset,
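After this commit, frame_shift and frame_duration default to the sliding-window shift and length in seconds (1.5 s / 3 s), mirroring the IdMapSet defaults, and the unused loss and extract_after_pooling arguments are gone. A hedged usage sketch, assuming the sidekit.nnet.xvector module path; file names are placeholders:

import torch
from sidekit.nnet.xvector import extract_embeddings   # assumed module path

embeddings = extract_embeddings(idmap_name="test_idmap.h5",      # hypothetical
                                model_filename="best_model.pt",  # hypothetical
                                data_root_name="data/wav",
                                device=torch.device("cuda"),
                                transform_pipeline="",
                                frame_shift=1.5,    # seconds between window starts
                                frame_duration=3.,  # window length in seconds
                                mixed_precision=False)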