# Dataloader_solo.py
# Audio loading, filtering, spectrogram and augmentation helpers, plus a
# PyTorch ``Dataset`` that mixes one recording of a target class with calls
# from other classes and optional background / Gaussian noise.
import torch
from torch.utils import data
import os
import fnmatch
import librosa
from scipy import signal
import numpy as np
import torch.nn.functional as F
import random
import collections

#load 1 audio
def load_file(file):
    """Load one audio file as a mono signal resampled to 22050 Hz.

    Returns the ``(samples, sample_rate)`` pair produced by ``librosa.load``.
    """
    target_sr = 22050
    samples, sample_rate = librosa.load(file, sr=target_sr, mono=True)
    return samples, sample_rate


def filt(audio_raw, rate, band=(800, 7000), trans_width=100, numtaps=250):
    """Band-pass filter a signal with an equiripple (Parks-McClellan) FIR.

    Parameters
    ----------
    audio_raw : array_like
        1-D input signal.
    rate : float
        Sampling rate in Hz. ``band[1] + trans_width`` must stay below
        ``rate / 2``.
    band : (float, float)
        Pass-band edges in Hz (default 800-7000 Hz).
    trans_width : float
        Width in Hz of each transition between pass band and stop bands.
    numtaps : int
        Length of the FIR filter.

    Returns
    -------
    ndarray
        Filtered signal with the same length as ``audio_raw``. Filtering is
        causal (``lfilter``), so the output is delayed relative to the input.
    """
    low, high = band
    edges = [0, low - trans_width,
             low, high,
             high + trans_width, 0.5 * rate]
    # ``fs`` replaces the ``Hz`` keyword, which SciPy deprecated and later
    # removed; ``fs`` has been accepted since SciPy 1.0.
    taps = signal.remez(numtaps, edges, [0, 1, 0], fs=rate, type='bandpass')
    return signal.lfilter(taps, 1, audio_raw)

# return the mag and phase for 1 stft in tensor
def _stft(audio):
    """Compute the STFT of ``audio`` and split it into magnitude and phase.

    Uses ``n_fft=1022`` and ``hop_length=256``. Each result is a freshly
    allocated float32 tensor of shape ``(1, 1, n_freq, n_frames)``, i.e. the
    spectrogram with leading batch and channel axes.
    """
    def _to_batched_tensor(arr):
        # Copy the 2-D array into a new (1, 1, rows, cols) float tensor.
        out = torch.empty((1, 1) + arr.shape, dtype=torch.float)
        out[0, 0] = torch.from_numpy(arr)
        return out

    spectrum = librosa.stft(audio, n_fft=1022, hop_length=256)
    magnitude = _to_batched_tensor(np.abs(spectrum))
    angle = _to_batched_tensor(np.angle(spectrum))
    return magnitude, angle

# Return a 0/1 float mask with the height/width of the spectrogram (mean + 2*std threshold)
def threshold(mag, n_std=2):
    """Binarise a spectrogram: 1 where it exceeds mean + ``n_std`` * std.

    Parameters
    ----------
    mag : torch.Tensor
        CPU tensor indexed as ``mag[0, 0]`` (e.g. the ``(1, 1, H, W)`` output
        of ``create_mask``); only that 2-D slice is used.
    n_std : float
        Number of standard deviations above the mean (default 2).

    Returns
    -------
    torch.Tensor
        Float 0/1 mask with the shape of ``mag[0, 0]``.
    """
    values = mag[0, 0].numpy()
    # np.std is computed as sqrt(np.var), so this matches mean + sqrt(var) * 2.
    level = np.mean(values) + np.std(values) * n_std
    return (mag[0, 0] > level).float()

#create the grid for the image
def warpgrid_log(HO, WO, warp=True):
    """Build a ``grid_sample`` sampling grid with a log-warped vertical axis.

    Parameters
    ----------
    HO, WO : int
        Output grid height and width.
    warp : bool
        If True, map the uniform vertical axis ``y`` in [-1, 1] through
        ``(21 ** ((y + 1) / 2) - 11) / 10``; if False, apply the inverse
        mapping ``log(10 * y + 11) / log(21) * 2 - 1``. Both keep -1 and 1
        fixed.

    Returns
    -------
    ndarray of float32, shape ``(1, HO, WO, 2)``
        The last axis holds the ``(x, y)`` sampling coordinates.
    """
    cols = np.linspace(-1, 1, WO)
    rows = np.linspace(-1, 1, HO)
    gx, gy = np.meshgrid(cols, rows)
    if warp:
        gy = (np.power(21, (gy + 1) / 2) - 11) / 10
    else:
        gy = np.log(gy * 10 + 11) / np.log(21) * 2 - 1
    stacked = np.stack((gx, gy), axis=-1)[np.newaxis]
    return stacked.astype(np.float32)


#create image from the grid
def create_im(mag, out_height=256):
    """Zero-centre a spectrogram and resample it onto a log-frequency grid.

    Parameters
    ----------
    mag : torch.Tensor
        2-D CPU spectrogram whose rows are frequency bins (as produced by
        ``_stft`` after squeezing the batch/channel axes).
    out_height : int
        Number of rows of the warped output (default 256).

    Returns
    -------
    torch.Tensor
        Shape ``(1, 1, out_height, mag.shape[1])``.
    """
    magim = mag.unsqueeze(0).unsqueeze(0)
    # Zero-centre the data.
    magim = magim - torch.mean(magim)
    grid_warp = torch.from_numpy(warpgrid_log(out_height, magim.shape[3], warp=True))
    # grid_sample requires input and grid to share a dtype; the grid is float32.
    grid_warp = grid_warp.to(magim.dtype)
    # align_corners=False has been PyTorch's default since 1.3; passing it
    # explicitly keeps behaviour identical and silences the per-call warning.
    magim = F.grid_sample(magim, grid_warp, align_corners=False)
    # NOTE(review): np.flipud flips axis 0, which is the size-1 batch axis of
    # this 4-D array, so it is a no-op rather than a frequency flip. Kept
    # unchanged so the result stays aligned with create_mask's output.
    return torch.from_numpy(np.flipud(magim).copy())


def create_mask(mag, out_height=256):
    """Resample a spectrogram onto the log-frequency grid without centring it.

    Used in ``Dataset.__getitem__`` on each source's magnitude before
    ``threshold`` turns it into a ground-truth mask.

    Parameters
    ----------
    mag : torch.Tensor
        2-D CPU spectrogram whose rows are frequency bins.
    out_height : int
        Number of rows of the warped output (default 256).

    Returns
    -------
    torch.Tensor
        Shape ``(1, 1, out_height, mag.shape[1])``.
    """
    magim = mag.unsqueeze(0).unsqueeze(0)
    grid_warp = torch.from_numpy(warpgrid_log(out_height, magim.shape[3], warp=True))
    # grid_sample requires input and grid to share a dtype; the grid is float32.
    grid_warp = grid_warp.to(magim.dtype)
    # Explicit default (False since PyTorch 1.3) avoids the per-call warning.
    magim = F.grid_sample(magim, grid_warp, align_corners=False)
    # NOTE(review): np.flipud acts on the size-1 batch axis here, so it is a
    # no-op; kept for consistency with create_im.
    return torch.from_numpy(np.flipud(magim).copy())

# Create a band of zeros in the spectrogram over a random frequency range
def freq_mask(spec):
    """Zero out a random horizontal band (consecutive rows) of ``spec``.

    The band height is drawn from ``[int(rows/100), int(rows/60))``, widened
    to contain at least one value when that range would be empty. The band
    starts at a random row index >= 9.

    Parameters
    ----------
    spec : torch.Tensor
        2-D spectrogram ``(rows, columns)``.

    Returns
    -------
    torch.Tensor
        ``spec`` multiplied by a float 0/1 mask; ``spec`` itself is not modified.

    Raises
    ------
    ValueError
        If ``spec`` has too few rows to place the band.
    """
    rows, columns = np.shape(spec)[:2]
    low = int(rows / 100)
    # np.random.randint needs high > low; for some heights the original bounds
    # coincide (e.g. rows=100 -> randint(1, 1)) and would raise.
    high = max(int(rows / 60), low + 1)
    width = np.random.randint(low, high)
    # Band occupies rows [pos - 1, pos - 1 + width).
    pos = random.randint(10, rows - width - 1)
    start = pos - 1
    mask = np.ones((rows, columns))
    mask[start:start + width, :] = 0
    return spec * torch.from_numpy(mask).float()

# Create a band of zeros in the spectrogram over a random time range
def time_mask(spec):
    """Zero out a random vertical band (consecutive columns) of ``spec``.

    The band width is drawn from ``[int(columns/100), int(columns/60))``,
    widened to contain at least one value when that range would be empty.
    The band starts at a random column index >= 9.

    Parameters
    ----------
    spec : torch.Tensor
        2-D spectrogram ``(rows, columns)``.

    Returns
    -------
    torch.Tensor
        ``spec`` multiplied by a float 0/1 mask; ``spec`` itself is not modified.

    Raises
    ------
    ValueError
        If ``spec`` has too few columns to place the band.
    """
    rows, columns = np.shape(spec)[:2]
    low = int(columns / 100)
    # np.random.randint needs high > low; e.g. columns=100 would give
    # randint(1, 1) and raise.
    high = max(int(columns / 60), low + 1)
    width = np.random.randint(low, high)
    # Band occupies columns [pos - 1, pos - 1 + width).
    pos = random.randint(10, columns - width - 1)
    start = pos - 1
    mask = np.ones((rows, columns))
    mask[:, start:start + width] = 0
    return spec * torch.from_numpy(mask).float()


def manipulate(data, sampling_rate, shift_max, shift_direction):
    """Randomly time-shift ``data`` and silence the samples that wrapped around.

    The shift (in samples) is drawn from ``[0, sampling_rate * shift_max)``.
    It is applied with ``np.roll``, so a positive shift moves samples to
    later indices. ``'right'`` negates the shift, and ``'both'`` negates it
    with probability 1/2. Any other value keeps it positive.
    NOTE(review): with np.roll semantics, 'right' moves content *earlier* in
    time; confirm the naming is intended.

    Returns a new array; ``data`` is not modified.
    """
    shift = np.random.randint(sampling_rate * shift_max)
    if shift_direction == 'right':
        shift = -shift
    elif shift_direction == 'both':
        if np.random.randint(0, 2) == 1:
            shift = -shift
    augmented_data = np.roll(data, shift)
    # Silence the wrapped-around head (positive shift) or tail (negative
    # shift). A zero shift must leave the signal intact; previously the
    # ``else`` branch ran ``augmented_data[0:] = 0`` and blanked everything.
    if shift > 0:
        augmented_data[:shift] = 0
    elif shift < 0:
        augmented_data[shift:] = 0
    return augmented_data


def _rms_energy(x):
    return np.sqrt(np.mean(x**2))

#add noise to signal from a same size vector
def _add_noise(signal, noise_file_name, snr, sample_rate):
    """

    :param signal:
    :param noise_file_name:
    :param snr:
    :return:
    """
    # Open noise file
    if isinstance(noise_file_name, np.ndarray):
        noise = noise_file_name
    else:
        noise, fs_noise = librosa.load(noise_file_name, sample_rate)

    # Generate random section of masker
    if len(noise) < len(signal):
        dup_factor = len(signal) // len(noise) + 1
        noise = np.tile(noise, dup_factor)

    if len(noise) != len(signal):
        idx = np.random.randint(1, len(noise) - len(signal))
        noise = noise[idx:idx + len(signal)]

    # Compute energy of both signals
    N_dB = _rms_energy(noise)
    S_dB = _rms_energy(signal)

    # Rescale N
    N_new = S_dB - snr
    noise_scaled = 10 ** (N_new / 20) * noise / 10 ** (N_dB / 20)
    noisy = signal + noise_scaled

    return (noisy - noisy.mean()) / noisy.std()


# Window the clip and place it at a random offset in a silent signal of sr * max_time samples
def time_elong(sr, audio, max_time=2):
    """Place a Hann-windowed ``audio`` clip at a random offset in silence.

    The clip is multiplied by a Hann window of its own length, then
    zero-padded on both sides to exactly ``int(sr * max_time)`` samples.

    Parameters
    ----------
    sr : int
        Sampling rate in Hz.
    audio : ndarray
        1-D clip; must not be longer than ``sr * max_time`` samples.
    max_time : float
        Output duration in seconds (default 2).

    Returns
    -------
    ndarray
        1-D array of length ``int(sr * max_time)``.

    Raises
    ------
    ValueError
        If ``audio`` is longer than the requested output. Previously this case
        printed a message and then crashed with UnboundLocalError.
    """
    target = int(sr * max_time)
    dim = len(audio)
    if dim > target:
        raise ValueError(
            f"audio has {dim} samples, longer than the {target}-sample target "
            "(sr * max_time)")
    windowed = audio * np.hanning(dim)
    # Any left offset in [0, target - dim] fits. The old exclusive bound
    # (target - dim - 1) raised for clips within two samples of ``target``.
    blockl = np.random.randint(0, target - dim + 1)
    left = np.zeros(blockl)
    right = np.zeros(target - blockl - dim)
    # The result is already 1-D, so no mono conversion is needed.
    return np.concatenate((left, windowed, right), axis=0)


class Dataset(data.Dataset):
    """Characterizes a dataset for PyTorch.

    Items are synthesised on the fly rather than indexed. Each item mixes:

    * one random recording of the target class ``name_cl`` (pitch-augmented,
      windowed into ``max_time`` seconds and band-pass filtered);
    * ``nb_classes_noise`` recordings drawn from random classes under ``path``;
    * when ``augmentation`` is True, either a background-noise file from
      ``path_background`` or Gaussian noise.

    Class names are the names of the directories that contain ``.wav`` files.
    """

    def __init__(self, mode, path, name_cl, nb_classes=1, nb_classes_noise=1,
                 augmentation=True, path_background="./noises"):
        # 'train' -> 200000 items per epoch, anything else -> 32 (see __len__).
        self.mode = mode
        # class name -> list of wav paths found under ``path``
        self.dict_classes = self.load_data(path)
        # stored for callers; not read by any method of this class
        self.nb_classes = nb_classes
        # number of interfering calls mixed into each item
        self.nb_classes_noise = nb_classes_noise
        # target class; must be a key of ``dict_classes``
        self.name_cl = name_cl
        self.augmentation = augmentation
        self.path_background = path_background
        if self.augmentation:
            # class name -> list of background-noise wav paths
            self.dict_noises = self.load_data(path_background)

    def load_data(self, path, ext='wav'):
        """Group files ending in ``ext`` under ``path`` by their directory name.

        Returns an ``OrderedDict`` mapping directory base name -> list of file
        paths, and prints a warning when nothing was found.
        """
        dict_classes = collections.OrderedDict()
        for root, _dirnames, filenames in os.walk(path):
            # os.path.basename is the portable form of root.split("/")[-1].
            classe = os.path.basename(root)
            for filename in fnmatch.filter(filenames, '*' + ext):
                dict_classes.setdefault(classe, []).append(os.path.join(root, filename))
        if not dict_classes:
            print("** WARNING ** No data loaded from " + path)
        return dict_classes

    def get_noise(self):
        """Return the path of a random file from a random background-noise class."""
        classe_noise = random.choice(list(self.dict_noises.keys()))
        return random.choice(self.dict_noises[classe_noise])

    def data_augment(self, audio, sampling_rate):
        """Apply a tiny random pitch shift (at most +/-0.001 steps) to ``audio``."""
        step_pitch = random.uniform(-0.001, 0.001)
        # librosa >= 0.10 makes ``sr`` keyword-only.
        return librosa.effects.pitch_shift(audio, sr=sampling_rate, n_steps=step_pitch)

    def spec_augmentation(self, spec):
        """Randomly mask at least one band of the spectrogram.

        With equal probability: one time mask, one frequency mask, or two
        rounds of (time mask, frequency mask).
        """
        n = random.randint(0, 2)
        if n == 0:
            spec = time_mask(spec)
        elif n == 1:
            spec = freq_mask(spec)
        else:
            for _ in range(n):
                spec = time_mask(spec)
                spec = freq_mask(spec)
        return spec

    def __len__(self):
        """Denotes the total number of samples (items are drawn at random)."""
        return 200000 if self.mode == 'train' else 32

    def load_files(self, nb_classes):
        """Pick ``nb_classes`` random ``[class_name, filename]`` pairs.

        Each pair comes from a class chosen uniformly (with replacement) from
        ``dict_classes``, which may include the target class.
        """
        class_names = list(self.dict_classes.keys())
        files = []
        for _ in range(nb_classes):
            files.extend(self.load_class(random.choice(class_names)))
        return files

    def load_class(self, classe_name):
        """Return ``[[classe_name, filename]]`` for a random file of ``classe_name``."""
        filename = random.choice(self.dict_classes[classe_name])
        return [[classe_name, filename]]

    def __getitem__(self, index):
        """Build one random training example; ``index`` is not used.

        Returns
        -------
        list
            ``[mags_mix, phase_mix, files]``, where:

            * ``mags_mix`` is the log-warped, zero-centred, spec-augmented
              magnitude of the mixture, shape ``(1, 256, n_frames)``;
            * ``phase_mix`` is the mixture phase as returned by ``_stft``;
            * ``files`` is a list of
              ``[class_name, filename, mask, magnitude, phase]`` entries, one
              per target source.
        """
        max_time = 2
        target_files = self.load_class(self.name_cl)

        audio_mix = None
        for entry in target_files:
            audio_raw, sr = load_file(entry[1])
            audio_raw = self.data_augment(audio_raw, sr)
            audio = filt(time_elong(sr, audio_raw, max_time), sr)
            mag, phase = _stft(audio)
            mag = create_mask(mag.squeeze(0).squeeze(0))
            mask = threshold(mag)
            entry.extend([mask, mag, phase])
            # Build a new array instead of aliasing ``audio`` and mutating it
            # in place later.
            audio_mix = audio.copy() if audio_mix is None else audio_mix + audio

        # Add calls from random classes as interfering sources.
        for _classe, filename in self.load_files(self.nb_classes_noise):
            audio_raw, sr = load_file(filename)
            audio_mix = audio_mix + filt(time_elong(sr, audio_raw, max_time), sr)

        # Randomly add either natural background noise or Gaussian noise. The
        # mixing now lives inside this branch: previously it ran
        # unconditionally and raised NameError when augmentation was False.
        if self.augmentation:
            if random.randint(0, 1) == 1:
                n_noise = self.get_noise()
                snr = np.random.randint(-10, 0)
            else:
                # Same draws as the former (1, N) array collapsed to mono.
                n_noise = np.random.normal(loc=0, scale=1, size=max_time * sr)
                snr = np.random.randint(30, 50)  # -10/5 for natural noise, 30/50
            audio_mix = _add_noise(audio_mix, n_noise, snr, sr)

        mag_mix, phase_mix = _stft(audio_mix)
        mag_mix = self.spec_augmentation(mag_mix.squeeze(0).squeeze(0))
        mags_mix = create_im(mag_mix).squeeze(0)
        return [mags_mix, phase_mix, target_files]