Commit 39759e1a authored by Félix Michaud's avatar Félix Michaud
Browse files

before image data augmentation

parent ba0b9f1d
......@@ -67,16 +67,46 @@ def warpgrid(HO, WO, warp=True):
def create_im(mag):
magim = mag.unsqueeze(0).unsqueeze(0)
grid_warp = torch.from_numpy(warpgrid(256, magim.shape[3], warp=True))
magim = torch.log(F.grid_sample(magim, grid_warp))
magim = F.grid_sample(magim, grid_warp)
return torch.from_numpy(np.flipud(magim).copy())
def create_mask(mag):
magim = mag.unsqueeze(0).unsqueeze(0)
grid_warp = torch.from_numpy(warpgrid(264, 52, warp=True))
magim = torch.log(F.grid_sample(magim, grid_warp))
magim = F.grid_sample(magim, grid_warp)
return torch.from_numpy(np.flipud(magim).copy())
#kernel size:5, padding:3, image size:[264, 52]
def freq_mask(spec):
fbank_size = np.shape(spec)
rows , columns = fbank_size[0], fbank_size[1]
#width of the band
fact1 = random.randint(int(rows/40), int((rows/10)))
frame = np.zeros([fact1, columns])
#position of the band on the y axis
pos = random.randint(0, rows-fact1-1)
up = np.ones([pos-1, columns])
down = np.ones([rows-(pos+fact1)+1, columns])
mask = torch.from_numpy(np.concatenate((up, frame, down), axis=0)).float()
masked = spec * mask
return masked
def time_mask(spec):
fbank_size = np.shape(spec)
rows , columns = fbank_size[0], fbank_size[1]
#width of the band
fact1 = random.randint(int(columns/40), int((columns/10)))
frame = np.zeros([rows, fact1])
#position of the band on the x axis
pos = random.randint(0, columns-fact1-1)
left = np.ones([rows, pos-1])
right = np.ones([rows, columns-(pos+fact1)+1])
mask = torch.from_numpy(np.concatenate((left, frame, right), axis=1)).float()
print(spec.type(), 'spec type')
print(mask.type(), 'mask type')
masked = spec * mask
return masked
class Dataset(data.Dataset):
'Characterizes a dataset for PyTorch'
def __init__(self, path, nb_classes=2, augmentation=True, path_background="./noises"):
......@@ -108,6 +138,22 @@ class Dataset(data.Dataset):
audio_noise, sr = load_file(filename_noise)
return audio + audio_noise
def spec_augmentation(self, spec):
n = random.randint(0, 2)
if n == 0:
t = random.randint(0, 1)
if t == 1:
spec = time_mask(spec)
if t == 0:
spec = freq_mask(spec)
else:
for ii in range(n):
spec = time_mask(spec)
spec = freq_mask(spec)
return spec
def __len__(self):
'Denotes the total number of samples'
# return len(self.dict_classes)
......@@ -122,7 +168,7 @@ class Dataset(data.Dataset):
classe_name = list(self.dict_classes.keys())[cl]
idx = int(random.random() * len(self.dict_classes[classe_name]) )
filename = self.dict_classes[classe_name][idx]
files.append([classe_name+'class', filename])
files.append([classe_name, filename])
return files
'[class_name, filename, [mask], [magnitude], [phase] ]'
......@@ -149,6 +195,7 @@ class Dataset(data.Dataset):
mag_mix, phase_max = _stft(audio_mix)
mag_mix = mag_mix.squeeze(0).squeeze(0)
mag_mix = self.spec_augmentation(mag_mix)
mags_mix = create_im(mag_mix)
mags_mix = mags_mix.squeeze(0)
return [mags_mix, files]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment