Commit 34da4db1 authored by Félix Michaud's avatar Félix Michaud
Browse files

dataloader

parent a668b588
......@@ -67,15 +67,46 @@ def warpgrid(HO, WO, warp=True):
def create_im(mag):
    """Warp a magnitude spectrogram onto a fixed 256-row grid and flip it vertically.

    Parameters
    ----------
    mag : torch.Tensor
        2-D magnitude spectrogram (frequency x time) -- assumed from the
        `_stft` caller; TODO confirm shape.

    Returns
    -------
    torch.Tensor
        4-D tensor of shape (1, 1, 256, W) after grid warping and `np.flipud`.
    """
    # add batch and channel dims: (F, T) -> (1, 1, F, T)
    magim = mag.unsqueeze(0).unsqueeze(0)
    # warp the frequency axis onto 256 rows, keeping the original time width
    grid_warp = torch.from_numpy(warpgrid(256, magim.shape[3], warp=True))
    # NOTE(review): a stale pre-commit line applied torch.log() here and was
    # immediately overwritten by the plain grid_sample below; the dead
    # computation has been removed.
    magim = F.grid_sample(magim, grid_warp)
    # np.flipud flips axis 0 (the batch dim of size 1) -- presumably the
    # frequency axis was intended; TODO confirm against the training pipeline.
    return torch.from_numpy(np.flipud(magim).copy())
def create_mask(mag):
    """Warp a magnitude spectrogram onto the fixed mask grid [256, 44] and flip it.

    Parameters
    ----------
    mag : torch.Tensor
        2-D magnitude spectrogram (frequency x time) -- assumed; TODO confirm.

    Returns
    -------
    torch.Tensor
        4-D tensor of shape (1, 1, 256, 44) after grid warping and `np.flipud`.
    """
    # add batch and channel dims: (F, T) -> (1, 1, F, T)
    magim = mag.unsqueeze(0).unsqueeze(0)
    # NOTE(review): the stale pre-commit warp to [264, 52] with torch.log()
    # was left in by the diff and caused the input to be warped twice; only
    # the post-commit [256, 44] warp is kept.
    grid_warp = torch.from_numpy(warpgrid(256, 44, warp=True))
    magim = F.grid_sample(magim, grid_warp)
    # np.flipud flips axis 0 (size-1 batch dim) -- presumably the frequency
    # axis was intended; TODO confirm.
    return torch.from_numpy(np.flipud(magim).copy())
# kernel size: 5, padding: 3 -> image size [264, 52]
# kernel size: 3, padding: 1 -> image size [256, 44]
# (the exact size depends on the overlap of the STFT)
def freq_mask(spec):
    """Zero out one random horizontal (frequency) band of a spectrogram.

    The band height is drawn uniformly from [rows/40, rows/10] and its top
    edge uniformly over the valid positions, so the band always lies fully
    inside the array.

    Parameters
    ----------
    spec : 2-D array-like, (rows, columns).

    Returns
    -------
    numpy.ndarray
        `spec` with the selected rows multiplied by zero.

    Fixes vs. the original:
    - `np.ones([pos-1, ...])` raised ValueError (negative dimension) when
      the band landed at the top (`pos == 0`).
    - the band was placed one row too high (`pos-1` instead of `pos`).
    """
    rows, columns = np.shape(spec)[0], np.shape(spec)[1]
    # height of the masked band
    band = random.randint(int(rows / 40), int(rows / 10))
    # top row of the band, chosen so the band fits entirely inside
    pos = random.randint(0, rows - band - 1)
    mask = np.ones([rows, columns])
    mask[pos:pos + band, :] = 0
    return spec * mask
def time_mask(spec):
    """Zero out one random vertical (time) band of a spectrogram.

    The band width is drawn uniformly from [columns/40, columns/10] and its
    left edge uniformly over the valid positions, so the band always lies
    fully inside the array.

    Parameters
    ----------
    spec : 2-D array-like, (rows, columns).

    Returns
    -------
    numpy.ndarray
        `spec` with the selected columns multiplied by zero.

    Fixes vs. the original:
    - `np.ones([rows, pos-1])` raised ValueError (negative dimension) when
      the band landed at the left edge (`pos == 0`).
    - the band was placed one column too far left (`pos-1` instead of `pos`).
    """
    rows, columns = np.shape(spec)[0], np.shape(spec)[1]
    # width of the masked band
    band = random.randint(int(columns / 40), int(columns / 10))
    # leftmost column of the band, chosen so the band fits entirely inside
    pos = random.randint(0, columns - band - 1)
    mask = np.ones([rows, columns])
    mask[:, pos:pos + band] = 0
    return spec * mask
class Dataset(data.Dataset):
'Characterizes a dataset for PyTorch'
......@@ -100,13 +131,31 @@ class Dataset(data.Dataset):
if len(list(dict_classes.keys() )) == 0:
print("** WARNING ** No data loaded from " + path)
return dict_classes
def mixed_audio_augmentation(self, audio):
    """Mix a randomly chosen noise file into `audio`, scaled by a random factor.

    A noise class is drawn uniformly from `self.dict_noises`, then a file
    uniformly from that class. The noise is scaled by
    ceil(max(audio) * k) with k drawn from 1..7 before being added.

    NOTE(review): the stale pre-commit `return audio + audio_noise` made the
    scaling lines below unreachable dead code; it has been removed so the
    post-commit scaled mix is actually applied.

    Parameters
    ----------
    audio : array-like waveform (dtype/range not visible here -- TODO confirm).

    Returns
    -------
    The input waveform with scaled noise added (same shape as `audio`,
    assuming the noise file has the same length -- TODO confirm).
    """
    # pick a random noise class, then a random file inside it
    classe_noise = random.choice(list(self.dict_noises.keys()))
    filename_noise = random.choice(self.dict_noises[classe_noise])
    audio_noise, sr = load_file(filename_noise)
    # scale the noise relative to the signal's peak amplitude
    coeff = int(np.ceil(np.max(audio) * random.choice([1, 2, 3, 4, 5, 6, 7])))
    return audio + audio_noise * coeff
# Randomly apply at least one masking band to the spectrogram.
def spec_augmentation(self, spec):
    """Apply SpecAugment-style masking: at least one random band is zeroed.

    Draws n in {0, 1, 2}. With n == 0 a single band (time OR frequency,
    chosen at random) is applied; otherwise n rounds of one time band plus
    one frequency band each are applied.

    Parameters
    ----------
    spec : torch.Tensor, 2-D spectrogram.

    Returns
    -------
    torch.Tensor with the masked bands zeroed.
    """
    spec = spec.numpy()
    n_rounds = random.randint(0, 2)
    if n_rounds == 0:
        # still guarantee exactly one band: pick time or frequency at random
        if random.randint(0, 1) == 1:
            spec = time_mask(spec)
        else:
            spec = freq_mask(spec)
    else:
        for _ in range(n_rounds):
            spec = time_mask(spec)
            spec = freq_mask(spec)
    return torch.from_numpy(spec)
def __len__(self):
'Denotes the total number of samples'
......@@ -122,7 +171,7 @@ class Dataset(data.Dataset):
classe_name = list(self.dict_classes.keys())[cl]
idx = int(random.random() * len(self.dict_classes[classe_name]) )
filename = self.dict_classes[classe_name][idx]
files.append([classe_name+'class', filename])
files.append([classe_name, filename])
return files
'[class_name, filename, [mask], [magnitude], [phase] ]'
......@@ -147,11 +196,14 @@ class Dataset(data.Dataset):
if self.augmentation:
audio_mix = self.mixed_audio_augmentation(audio_mix)
mag_mix, phase_max = _stft(audio_mix)
mag_mix = mag_mix.squeeze(0).squeeze(0)
mags_mix = create_im(mag_mix)
mags_mix = mags_mix.squeeze(0)
return [mags_mix, files]
mag_mix, phase_mix = _stft(audio_mix)
mag_mix = mag_mix.squeeze(0).squeeze(0)
print(mag_mix.size(),'mag_mix before spec' )
mag_mix = self.spec_augmentation(mag_mix)
print(mag_mix.size(),'mag_mix after spec' )
# mags_mix = create_im(mag_mix)
# mags_mix = mags_mix.squeeze(0)
# return [mags_mix, phase_mix, files]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment