Commit cf1afe6c authored by Anthony Larcher

debug and noise augmentation

parent 6430129d
@@ -453,11 +453,11 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi
     """
     # Select the data augmentation randomly
-    if len(transform_dict.keys) >= transform_number:
-        aug_idx = numpy.arange(len(transform_dict.keys))
+    if len(transform_dict.keys()) >= transform_number:
+        aug_idx = numpy.arange(len(transform_dict.keys()))
     else:
-        aug_idx = numpy.random.randint(0, len(transform_dict), transform_number)
-    augmentations = list(numpy.array(transform_dict.keys())[aug_idx])
+        aug_idx = random.choices(numpy.arange(len(transform_dict.keys())), k=transform_number)  # random.choice has no k argument; choices samples with replacement
+    augmentations = numpy.array(list(transform_dict.keys()))[aug_idx]
     if "phone_filtering" in augmentations:
         speech, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
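If the goal is `transform_number` *distinct* augmentations, `random.sample` avoids the repeats that sampling with replacement (`random.choices`) can produce. A minimal sketch, with assumed dictionary contents:

```python
import random

transform_dict = {"add_noise": {}, "add_reverb": {}, "stretch": {}}  # contents assumed
transform_number = 2

# Draw distinct augmentation names; random.sample never repeats an element.
augmentations = random.sample(list(transform_dict.keys()),
                              k=min(transform_number, len(transform_dict)))
```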
@@ -471,11 +471,11 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi
     if "stretch" in augmentations:
         strech = torchaudio.functional.TimeStretch()
-        rate = numpy.random.uniform(0.8, 1.2)
+        rate = random.uniform(0.8, 1.2)
         speech = strech(speech, rate)
     if "add_reverb" in augmentations:
-        rir_nfo = numpy.random.randint(0, len(rir_df))
+        rir_nfo = random.randrange(len(rir_df))
         rir_fn = transform_dict["add_noise"]["data_path"] + "/" + rir_nfo + ".wav"
         rir, rir_fs = torchaudio.load(rir_fn)
         rir = rir[rir_nfo[1], :]  # keep selected channel
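For reference, reverberation by convolving with a room impulse response (RIR) is usually done on a normalized, time-flipped RIR so that `conv1d` (a cross-correlation) performs a true convolution. A self-contained sketch in the style of the torchaudio augmentation tutorial; the synthetic tensors are stand-ins, not this commit's data:

```python
import torch

torch.manual_seed(0)
speech = torch.randn(1, 16000)  # stand-in for a mono waveform (channels, time)
rir = torch.randn(1, 800).abs() * torch.exp(-torch.linspace(0, 8, 800))  # decaying stand-in RIR

rir = rir / rir.norm(p=2)   # normalize the impulse response energy
rir = torch.flip(rir, [1])  # flip so conv1d performs true convolution

# Left-pad the speech so the output keeps the original length.
padded = torch.nn.functional.pad(speech, (rir.shape[1] - 1, 0))
reverbed = torch.nn.functional.conv1d(padded[None, ...], rir[None, ...])[0]
print(reverbed.shape)  # torch.Size([1, 16000])
```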
@@ -483,15 +483,44 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi
         speech = torch.nn.functional.conv1d(speech_[None, ...], rir[None, ...])[0]
     if "add_noise" in augmentations:
+        # Pick a noise sample from the noise_df
+        noise_row = noise_df.iloc[random.randrange(noise_df.shape[0])]
+        noise_type = noise_row['type']
+        noise_start = noise_row['start']
+        noise_duration = noise_row['duration']
+        noise_file_id = noise_row['file_id']
         # Pick a SNR level
-        snr_db = random.choice(transform_dict["add_noise"]["snr"])
+        # TODO make SNRs configurable by noise type
+        if noise_type == 'music':
+            snr_db = random.randint(5, 15)
+        elif noise_type == 'noise':
+            snr_db = random.randint(0, 15)
+        else:
+            snr_db = random.randint(13, 20)
+        if noise_duration * sample_rate > speech.shape[1]:
+            # Force frame_offset to stay in the first 20 seconds of the file, otherwise it takes too long to load
+            frame_offset = random.randrange(noise_start * sample_rate, min(int(20 * sample_rate), int((noise_start + noise_duration) * sample_rate - speech.shape[1])))
+        else:
+            frame_offset = noise_start * sample_rate
+        noise_fn = transform_dict["add_noise"]["data_path"] + "/" + noise_file_id + ".wav"
+        if noise_duration * sample_rate > speech.shape[1]:
+            noise, noise_fs = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(speech.shape[1]))
+        else:
+            noise, noise_fs = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(noise_duration * sample_rate))
-        # Pick a file name from the noise_df
-        noise_fn = transform_dict["add_noise"]["data_path"] + "/" + random.choice(noise_df) + ".wav"
-        noise, noise_fs = torchaudio.load(noise_fn, frame_offset=0, num_frames=speech.shape[1])
         speech_power = speech.norm(p=2)
         noise_power = noise.norm(p=2)
+        #if numpy.random.randint(0, 2) == 1:
+        #    noise = torch.flip(noise, dims=[0, 1])
+        if noise.shape[1] < speech.shape[1]:
+            noise = torch.tensor(numpy.resize(noise.numpy(), speech.shape))
         snr = math.exp(snr_db / 10)
         scale = snr * noise_power / speech_power
         speech = (scale * speech + noise) / 2
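Note that `math.exp(snr_db / 10)` above uses base *e*, while the conventional decibel relation for amplitudes is snr_db = 20·log10(‖speech‖ / ‖noise‖), i.e. base 10. A hedged sketch of the textbook scaling; the function name and mixing convention are mine, not the commit's:

```python
import torch

def add_noise_at_snr(speech: torch.Tensor, noise: torch.Tensor, snr_db: float) -> torch.Tensor:
    """Mix noise into speech so that 20*log10(||speech|| / ||scaled_noise||) == snr_db.

    Assumes speech and noise already have identical shapes.
    """
    speech_power = speech.norm(p=2)
    noise_power = noise.norm(p=2)
    scale = speech_power / (noise_power * 10 ** (snr_db / 20))  # base 10, per the dB definition
    return speech + scale * noise

noisy = add_noise_at_snr(torch.randn(1, 16000), torch.randn(1, 16000), snr_db=10.0)
```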
@@ -32,7 +32,6 @@ import pandas
 import random
 import torch
 import torchaudio
-torchaudio.set_audio_backend("sox_io")
 import tqdm
 import soundfile
 import yaml
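Dropping the explicit backend selection is consistent with how torchaudio has evolved: `set_audio_backend` is deprecated on recent releases, and loading dispatches to a default backend. A minimal sketch, with an assumed file name:

```python
import torchaudio

# No set_audio_backend call needed; the default I/O backend is used.
waveform, sample_rate = torchaudio.load("example.wav")  # file name assumed
```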
@@ -166,7 +165,7 @@ class SideSet(Dataset):
             self.transformation = dataset["train"]["transformation"]
         else:
             self.duration = dataset["eval"]["duration"]
-            self.transformation = dataset["eval"]["transformation"]
+            self.transformation = dataset["eval"]["transformation"]
         self.sample_number = int(self.duration * self.sample_rate)
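For orientation only, the keys read in this constructor imply a dataset description of roughly the following shape. This is a hypothetical sketch, not the project's actual YAML schema:

```python
# Hypothetical dataset dict with the keys accessed above ("pipeline" is assumed).
dataset = {
    "train": {"duration": 4.0, "transformation": {"pipeline": "add_noise"}},
    "eval": {"duration": 4.0, "transformation": {}},
}
```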
@@ -226,7 +226,7 @@ def test_metrics(model,
     key_test_filename = 'h5f/key_test.h5'
     data_root_name = '/lium/scratch/larcher/voxceleb1/test/wav'
-    transform_pipeline = ""
+    transform_pipeline = dict()
     xv_stat = extract_embeddings(idmap_name=idmap_test_filename,
                                  model_filename=model,
@@ -399,6 +399,7 @@ class Xtractor(torch.nn.Module):
                 ("linear8", torch.nn.Linear(512, int(self.speaker_number)))
             ]))
             self.preprocessor_weight_decay = 0.0002
+            self.sequence_network_weight_decay = 0.0002
             self.before_speaker_embedding_weight_decay = 0.002
             self.after_speaker_embedding_weight_decay = 0.002
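Per-block weight-decay attributes like these are typically consumed through `torch.optim` parameter groups. A hedged, self-contained sketch; the sub-module layout mirrors the attribute prefixes above and is assumed, not taken from this commit:

```python
import torch

class TinyXtractor(torch.nn.Module):
    """Stand-in model with one sub-network per weight-decay attribute."""
    def __init__(self):
        super().__init__()
        self.preprocessor = torch.nn.Linear(8, 8)
        self.sequence_network = torch.nn.Linear(8, 8)
        self.before_speaker_embedding = torch.nn.Linear(8, 8)
        self.after_speaker_embedding = torch.nn.Linear(8, 8)
        self.preprocessor_weight_decay = 0.0002
        self.sequence_network_weight_decay = 0.0002
        self.before_speaker_embedding_weight_decay = 0.002
        self.after_speaker_embedding_weight_decay = 0.002

model = TinyXtractor()
# One parameter group per sub-network, each with its own weight decay.
optimizer = torch.optim.SGD(
    [{"params": model.preprocessor.parameters(), "weight_decay": model.preprocessor_weight_decay},
     {"params": model.sequence_network.parameters(), "weight_decay": model.sequence_network_weight_decay},
     {"params": model.before_speaker_embedding.parameters(), "weight_decay": model.before_speaker_embedding_weight_decay},
     {"params": model.after_speaker_embedding.parameters(), "weight_decay": model.after_speaker_embedding_weight_decay}],
    lr=0.01, momentum=0.9)
```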
@@ -1207,14 +1208,14 @@ def cross_validation(model, validation_loader, device, validation_shape, mask, m
     embeddings = torch.zeros(validation_shape)
     classes = torch.zeros([validation_shape[0]])
     with torch.no_grad():
-        #for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1)):
-        for batch_idx, (data, target) in enumerate(validation_loader):
-            target = target.squeeze()
-            target = target.to(device)
+        for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1)):
             batch_size = target.shape[0]
+            target = target.squeeze().to(device)
             data = data.squeeze().to(device)
             with torch.cuda.amp.autocast(enabled=mixed_precision):
                 if loss_criteria == 'aam':
-                    batch_predictions, batch_embeddings = model(data, target=target, is_eval=False)
+                    batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
                 elif loss_criteria == 'aps':
                     batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
                 else:
@@ -1224,7 +1225,7 @@ def cross_validation(model, validation_loader, device, validation_shape, mask, m
                 accuracy += (torch.argmax(batch_predictions.data, 1) == target).sum()
                 loss += criterion(batch_predictions, target)
             embeddings[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0], :] = batch_embeddings.detach().cpu()
-            classes[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0]] = target.detach().cpu()
+            #classes[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0]] = target.detach().cpu()
             #print(classes.shape[0])
     local_device = "cpu" if embeddings.shape[0] > 3e4 else device
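The slicing above preallocates one large embeddings tensor and fills it batch by batch; indexing by `batch_predictions.shape[0]` rather than `batch_size` handles a final, smaller batch. A minimal self-contained sketch of that pattern, with assumed shapes:

```python
import torch

n_samples, emb_dim, batch_size = 10, 4, 3
embeddings = torch.zeros(n_samples, emb_dim)  # preallocated output

batches = torch.randn(n_samples, emb_dim).split(batch_size)  # stand-in for a DataLoader
for batch_idx, batch_embeddings in enumerate(batches):
    start = batch_idx * batch_size
    # The last batch may hold fewer than batch_size items.
    embeddings[start:start + batch_embeddings.shape[0], :] = batch_embeddings
```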