Commit e6fa65dc authored by Félix Michaud

On-the-fly dataloader

parent 72ee634d
import torch
from torch.utils import data
import os
import fnmatch
import librosa
from scipy import signal
import numpy as np
import torch.nn.functional as F
import random
import collections
# Load one audio file at its native sampling rate, mixed down to mono
def load_file(file):
    audio_raw, rate = librosa.load(file, sr=None, mono=True)
    return audio_raw, rate

# Band-pass filter the raw audio between 800 Hz and 7000 Hz with an FIR filter
def filt(audio_raw, rate):
    band = [800, 7000]   # desired pass band, Hz
    trans_width = 100    # width of transition from pass band to stop band, Hz
    numtaps = 250        # size of the FIR filter
    edges = [0, band[0] - trans_width,
             band[0], band[1],
             band[1] + trans_width, 0.5 * rate]
    taps = signal.remez(numtaps, edges, [0, 1, 0], fs=rate, type='bandpass')
    sig_filt = signal.lfilter(taps, 1, audio_raw)
    return sig_filt

# Compute the STFT of one signal and return magnitude and phase as (1, 1, W, H) tensors
def _stft(audio):
    spec = librosa.stft(audio, n_fft=1022, hop_length=256)
    amp = np.abs(spec)
    phase = np.angle(spec)
    W = np.shape(amp)[0]
    H = np.shape(amp)[1]
    tch_mag = torch.empty(1, 1, W, H, dtype=torch.float)
    tch_mag[0, 0, :, :] = torch.from_numpy(amp)
    tch_phase = torch.empty(1, 1, W, H, dtype=torch.float)
    tch_phase[0, 0, :, :] = torch.from_numpy(phase)
    return tch_mag, tch_phase

# Return a binary mask with the STFT dimensions: bins whose magnitude
# exceeds mean + 2 * std are set to 1
def threshold(mag):
    av = np.mean(mag[0, 0].numpy())
    vari = np.var(mag[0, 0].numpy())
    param = av + np.sqrt(vari) * 2   # threshold value
    gt_mask = (mag[0, 0] > param).float()
    return gt_mask

# Build the (1, HO, WO, 2) sampling grid used by grid_sample to resize spectrograms
def warpgrid(HO, WO, warp=True):
    # regular meshgrid over [-1, 1] x [-1, 1]
    x = np.linspace(-1, 1, WO)
    y = np.linspace(-1, 1, HO)
    xv, yv = np.meshgrid(x, y)
    grid = np.zeros((1, HO, WO, 2))
    grid[0, :, :, 0] = xv
    grid[0, :, :, 1] = yv
    grid = grid.astype(np.float32)
    return grid

# Warp a mask/magnitude onto a fixed-height (256-bin) image with grid_sample
def create_im(mag):
    magim = mag.unsqueeze(0).unsqueeze(0)
    grid_warp = torch.from_numpy(warpgrid(256, magim.shape[3], warp=True))
    magim = F.grid_sample(magim, grid_warp)
    return magim
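
# Example (sketch): the helpers above can be chained on a synthetic tone to check
# the shapes produced at each step. The 16 kHz rate, 1 s duration and 2 kHz tone
# are illustrative values only, not ones used by the dataset below.
#
#   rate = 16000
#   t = np.linspace(0, 1.0, rate, endpoint=False)
#   tone = np.sin(2 * np.pi * 2000 * t).astype(np.float32)
#   mag, phase = _stft(filt(tone, rate))   # (1, 1, 512, ~63) tensors for n_fft=1022
#   mask = threshold(mag)                  # (512, ~63) binary mask
#   image = create_im(mask)                # (1, 1, 256, ~63) warped image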

class Dataset(data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, path, nb_classes=2, augmentation=True, path_background="./noises"):
        self.dict_classes = self.load_data(path)
        self.nb_classes = nb_classes
        self.augmentation = augmentation
        self.path_background = path_background
        if self.augmentation:
            self.dict_noises = self.load_data(path_background)

    def load_data(self, path, ext='wav'):
        dict_classes = collections.OrderedDict()
        for root, dirnames, filenames in os.walk(path):
            for filename in fnmatch.filter(filenames, '*.' + ext):
                classe = root.split("/")[-1]
                if classe in dict_classes.keys():
                    dict_classes[classe].append(os.path.join(root, filename))
                else:
                    dict_classes[classe] = [os.path.join(root, filename)]
        if len(list(dict_classes.keys())) == 0:
            print("** WARNING ** No data loaded from " + path)
        return dict_classes

    def mixed_audio_augmentation(self, audio):
        'Mix a randomly chosen background noise into the audio'
        classe_noise = random.randint(0, len(list(self.dict_noises.keys())) - 1)
        classe_noise = list(self.dict_noises.keys())[classe_noise]
        filename_noise = self.dict_noises[classe_noise][random.randint(0, len(self.dict_noises[classe_noise]) - 1)]
        audio_noise, sr = load_file(filename_noise)
        # trim both signals to the shorter one so they can be summed
        n = min(len(audio), len(audio_noise))
        return audio[:n] + audio_noise[:n]

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.dict_classes)

    def load_files(self, nb_classes):
        files = []
        for cl in range(nb_classes):
            'Pick one random audio file for this class'
            classe_name = list(self.dict_classes.keys())[cl]
            idx = int(random.random() * len(self.dict_classes[classe_name]))
            filename = self.dict_classes[classe_name][idx]
            files.append([classe_name, filename])
        return files

    'Each entry of files becomes [class_name, filename, [mask], [magnitude], [phase]]'
    def __getitem__(self, index):
        'Load one audio file per class, compute its mask, magnitude and phase'
        files = self.load_files(self.nb_classes)
        audio_mix = None
        for f in files:
            audio_raw, rate = load_file(f[1])
            audio = filt(audio_raw, rate)
            mag, phase = _stft(audio)
            gt_mask = threshold(mag)
            mask = create_im(gt_mask)
            f.append(mask)
            f.append(mag)
            f.append(phase)
            if audio_mix is None:
                audio_mix = audio
            else:
                # trim to the shorter signal before summing the sources
                n = min(len(audio_mix), len(audio))
                audio_mix = audio_mix[:n] + audio[:n]
        'Build the mixed mask'
        if self.augmentation:
            audio_mix = self.mixed_audio_augmentation(audio_mix)
        mag_mix, phase_mix = _stft(audio_mix)
        mag_mix = mag_mix.squeeze(0).squeeze(0)
        mask_mix = create_im(mag_mix)
        return [mask_mix, files]
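
# Minimal usage sketch (assumptions: audio is organised as <path>/<class_name>/*.wav,
# with background noises under ./noises/<class_name>/*.wav; the './data' path and the
# printout below are illustrative only, not values required by this file).
if __name__ == '__main__':
    dataset = Dataset('./data', nb_classes=2, augmentation=True,
                      path_background='./noises')
    # one item is [mask_mix, files]; each entry of files is
    # [class_name, filename, mask, magnitude, phase]
    mask_mix, files = dataset[0]
    print(mask_mix.shape)
    for class_name, filename, mask, mag, phase in files:
        print(class_name, filename, mask.shape, mag.shape, phase.shape)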