Commit 89b36652 authored by Anthony Larcher
Browse files

merge

parent fc1b0895
......@@ -236,7 +236,6 @@ def data_augmentation(speech,
noise += transform(noise_)
noise /= pick_count
speech_power = speech.norm(p=2)
noise_power = noise.norm(p=2)
snr = 10 ** (snr_db / 20)
......@@ -280,9 +279,12 @@ def load_noise_seg(noise_row, speech_shape, sample_rate, data_path):
noise_fn = data_path + "/" + noise_file_id + ".wav"
if noise_duration * sample_rate > speech_shape[1]:
noise_seg, _ = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(speech_shape[1]))
noise_seg, noise_sr = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(speech_shape[1]))
else:
noise_seg, _ = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(noise_duration * sample_rate))
noise_seg, noise_sr = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(noise_duration * sample_rate))
assert noise_sr == sample_rate
#if numpy.random.randint(0, 2) == 1:
# noise = torch.flip(noise, dims=[0, 1])
if noise_seg.shape[1] < speech_shape[1]:
noise_seg = torch.tensor(numpy.resize(noise_seg.numpy(), speech_shape))
......
......@@ -204,8 +204,10 @@ class MelSpecFrontEnd(torch.nn.Module):
n_mels=self.melkwargs['n_mels'])
self.CMVN = torch.nn.InstanceNorm1d(self.n_mels)
self.time_masking = torchaudio.transforms.TimeMasking(time_mask_param=5)
self.freq_masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=10)
def forward(self, x):
def forward(self, x, is_eval=False):
"""
:param x:
......@@ -219,6 +221,9 @@ class MelSpecFrontEnd(torch.nn.Module):
out = self.MelSpec(out)+1e-6
out = torch.log(out)
out = self.CMVN(out)
if not is_eval:
out = self.freq_masking(out)
out = self.time_masking(out)
return out
......
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment