Commit 69df51c0 authored by Gaël Le Lan

bugfix, enhance add_noise

parent 795e4559
@@ -455,11 +455,11 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi
     """
     # Select the data augmentation randomly
-    if len(transform_dict.keys) >= transform_number:
-        aug_idx = numpy.arange(len(transform_dict.keys))
+    if len(transform_dict.keys()) >= transform_number:
+        aug_idx = numpy.arange(len(transform_dict.keys()))
     else:
-        aug_idx = numpy.random.randint(0, len(transform_dict), transform_number)
-    augmentations = list(numpy.array(transform_dict.keys())[aug_idx])
+        aug_idx = random.choice(numpy.arange(len(transform_dict.keys())), k=transform_number)
+    augmentations = numpy.array(list(transform_dict.keys()))[aug_idx]
     if "phone_filtering" in augmentations:
         speech, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
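
The first hunk fixes a crash: transform_dict.keys without parentheses is a bound method, so len(transform_dict.keys) raises a TypeError. Note, though, that the new else-branch calls random.choice(..., k=...), and the stdlib random.choice accepts no k argument. A minimal working sketch of the apparent intent, drawing transform_number distinct augmentation names (pick_augmentations is a hypothetical helper, not part of the commit):

    import random

    def pick_augmentations(transform_dict, transform_number):
        # When no more transforms are configured than requested, apply them all.
        names = list(transform_dict.keys())
        if len(names) <= transform_number:
            return names
        # random.choice() takes no k argument; random.sample() draws
        # transform_number distinct names (random.choices() would allow repeats).
        return random.sample(names, k=transform_number)
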
@@ -473,11 +473,11 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi
     if "stretch" in augmentations:
         strech = torchaudio.functional.TimeStretch()
-        rate = numpy.random.uniform(0.8,1.2)
+        rate = random.uniform(0.8,1.2)
         speech = strech(speech, rate)
     if "add_reverb" in augmentations:
-        rir_nfo = numpy.random.randint(0, len(rir_df))
+        rir_nfo = random.randrange(len(rir_df))
         rir_fn = transform_dict["add_noise"]["data_path"] + "/" + rir_nfo + ".wav"
         rir, rir_fs = torchaudio.load(rir_fn)
         rir = rir[rir_nfo[1], :] #keep selected channel
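
Throughout these hunks, numpy.random calls are swapped for the stdlib random module. A likely motivation: PyTorch DataLoader workers reseed Python's random (and torch) per worker, but not NumPy's global RNG, so forked workers can draw identical "random" augmentations when numpy.random is used. A companion pattern from the PyTorch documentation, sketched here (seed_worker is illustrative, not part of this commit):

    import random

    import numpy
    import torch

    def seed_worker(worker_id):
        # Give each DataLoader worker distinct numpy/random streams.
        worker_seed = torch.initial_seed() % 2**32
        numpy.random.seed(worker_seed)
        random.seed(worker_seed)

    # Usage: DataLoader(dataset, num_workers=4, worker_init_fn=seed_worker)
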
@@ -486,40 +486,43 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi
     if "add_noise" in augmentations:
         # Pick a noise sample from the noise_df
-        noise_row = noise_df.sample()
+        noise_row = noise_df.iloc[random.randrange(noise_df.shape[0])]
         noise_type = noise_row['type']
         noise_start = noise_row['start']
         noise_duration = noise_row['duration']
+        assert noise_type in ['music', 'noise', 'speech']
+        noise_file_id = noise_row['file_id']
         # Pick a SNR level
+        # TODO make SNRs configurable by noise type
         if noise_type == 'music':
-            snr_db = random.choice(transform_dict["add_noise"]["snr"])
+            snr_db = random.randint(5, 15)
         elif noise_type == 'noise':
-            snr_db = numpy.random.randint(0, 16)
+            snr_db = random.randint(0, 15)
         else:
-            snr_db = numpy.random.randint(13, 20)
+            snr_db = random.randint(13, 20)
         if noise_duration * sample_rate > speech.shape[1]:
-            frame_offset = 0
+            # We force frame_offset to stay in the 20 first seconds of the file, otherwise it takes too long to load
+            frame_offset = random.randrange(noise_start * sample_rate, min(int(20*sample_rate), int((noise_start + noise_duration) * sample_rate - speech.shape[1])))
         else:
-            frame_offset = numpy.random.randint(noise_start * sample_rate, speech.shape[1] - noise_duration * sample_rate)
+            frame_offset = noise_start * sample_rate
-        noise_fn = transform_dict["add_noise"]["data_path"] + "/" + noise_row['file_id'] + ".wav"
-        noise, noise_fs = torchaudio.load(noise_fn, frame_offset=frame_offset, num_frames=frame_offset + noise_duration * sample_rate)
+        noise_fn = transform_dict["add_noise"]["data_path"] + "/" + noise_file_id + ".wav"
+        if noise_duration * sample_rate > speech.shape[1]:
+            noise, noise_fs = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(speech.shape[1]))
+        else:
+            noise, noise_fs = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(noise_duration * sample_rate))
         speech_power = speech.norm(p=2)
         noise_power = noise.norm(p=2)
-        if numpy.random.randint(0, 2) == 1:
-            noise = torch.flip(noise, dims=[0, 1])
+        #if numpy.random.randint(0, 2) == 1:
+        #    noise = torch.flip(noise, dims=[0, 1])
         if noise.shape[1] < speech.shape[1]:
            noise = torch.tensor(numpy.resize(noise.numpy(), speech.shape))
         snr = math.exp(snr_db / 10)
         scale = snr * noise_power / speech_power
         speech = (scale * speech + noise) / 2
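
The mixing tail (unchanged context) follows the torchaudio tutorial form, including snr = math.exp(snr_db / 10). Conventional decibel algebra is base 10, not base e: for amplitudes, a target SNR of snr_db corresponds to a factor of 10 ** (snr_db / 20). A self-contained sketch that scales the noise instead of the speech (mix_at_snr is illustrative, not the commit's code):

    import torch

    def mix_at_snr(speech, noise, snr_db):
        # Scale noise so that 20 * log10(||speech|| / ||scaled noise||) == snr_db.
        # Assumes noise has already been cropped or tiled to speech's shape.
        speech_power = speech.norm(p=2)
        noise_power = noise.norm(p=2)
        scale = (speech_power / noise_power) * 10 ** (-snr_db / 20)
        return speech + scale * noise
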
@@ -220,9 +220,13 @@ class SideSet(Dataset):
         self.sessions = pandas.DataFrame.from_dict(df_dict)
         self.len = len(self.sessions)
-        self.transform = []
+        self.transform = dict()
         if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
-            self.transform_list = self.transformation["pipeline"].split(',')
+            transforms = self.transformation["pipeline"].split(',')
+            if "add_noise" in transforms:
+                self.transform["add_noise"] = self.transformation["add_noise"]
+            if "add_reverb" in transforms:
+                self.transform["add_reverb"] = self.transformation["add_reverb"]
         self.noise_df = None
         if "add_noise" in self.transform:
@@ -243,8 +247,9 @@ class SideSet(Dataset):
         # Check the size of the file
         current_session = self.sessions.iloc[index]
+        # TODO is this required ?
         nfo = soundfile.info(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}")
-        start_frame = int(current_session['start'] * self.sample_rate)
+        start_frame = int(current_session['start'])
         if start_frame + self.sample_number >= nfo.frames:
             start_frame = numpy.min(nfo.frames - self.sample_number - 1)
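
In the last line of this hunk (unchanged context), numpy.min is applied to the scalar nfo.frames - self.sample_number - 1, which is an identity rather than a clamp. The apparent intent, sketched:

    # Keep the requested window inside the file.
    if start_frame + self.sample_number >= nfo.frames:
        start_frame = max(0, nfo.frames - self.sample_number - 1)
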
@@ -300,11 +300,14 @@ class AttentivePooling(torch.nn.Module):
         """
         """
+        # TODO Make global_context configurable (True/False)
+        # TODO Make convolution parameters configurable
         super(AttentivePooling, self).__init__()
         self.attention = torch.nn.Sequential(
-            torch.nn.Conv1d(num_channels * 10, 128, kernel_size=1),
+            torch.nn.Conv1d(num_channels * 10 * 3, 128, kernel_size=1),
             torch.nn.ReLU(),
+            torch.nn.BatchNorm1d(128),
+            torch.nn.Tanh(),
             torch.nn.Conv1d(128, num_channels * 10, kernel_size=1),
             torch.nn.Softmax(dim=2),
         )
@@ -321,9 +324,11 @@ class AttentivePooling(torch.nn.Module):
         :param x:
         :return:
         """
-        w = self.attention(x)
+        global_context = self.global_context(x).unsqueeze(2).repeat(1, 1, x.shape[-1])
+        w = self.attention(torch.cat([x, global_context], dim=1))
+        #w = self.attention(x)
         mu = torch.sum(x * w, dim=2)
         rh = torch.sqrt( ( torch.sum((x**2) * w, dim=2) - mu**2 ).clamp(min=1e-5) )
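
The pooling forward now concatenates each frame with a repeated utterance-level context before computing attention weights. The global_context module itself is not shown in this diff; the sketch below assumes it returns the concatenated mean and standard deviation over time (the usual attentive-statistics recipe, which matches the num_channels * 10 * 3 input width above). AttentiveStatsPool is illustrative, not the commit's class:

    import torch

    class AttentiveStatsPool(torch.nn.Module):
        """Sketch: attentive statistics pooling with global context."""

        def __init__(self, in_channels):
            super().__init__()
            self.attention = torch.nn.Sequential(
                torch.nn.Conv1d(in_channels * 3, 128, kernel_size=1),
                torch.nn.ReLU(),
                torch.nn.BatchNorm1d(128),
                torch.nn.Tanh(),
                torch.nn.Conv1d(128, in_channels, kernel_size=1),
                torch.nn.Softmax(dim=2),
            )

        def forward(self, x):
            # x: (batch, channels, time); broadcast mean/std over the time axis
            t = x.shape[-1]
            context = torch.cat([x.mean(dim=2, keepdim=True).expand(-1, -1, t),
                                 x.std(dim=2, keepdim=True).expand(-1, -1, t)], dim=1)
            w = self.attention(torch.cat([x, context], dim=1))
            # Attention-weighted mean and standard deviation per channel
            mu = torch.sum(x * w, dim=2)
            rh = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
            return torch.cat([mu, rh], dim=1)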