Commit 3a3f49af authored by Félix Michaud

new dataset

parent 3b5c1c8b
import torch
from torch.utils import data
import os
import fnmatch
import librosa
from scipy import signal
import numpy as np
import torch.nn.functional as F
import random
import collections
# Load one audio file, resampled to 22.05 kHz mono.
def load_file(file):
    audio_raw, rate = librosa.load(file, sr=22050, mono=True)
    return audio_raw, rate
# Band-pass FIR filter (remez design) keeping roughly 800-7000 Hz.
def filt(audio_raw, rate):
    band = [800, 7000]   # desired pass band, Hz
    trans_width = 100    # width of transition from pass band to stop band, Hz
    numtaps = 250        # size of the FIR filter
    edges = [0, band[0] - trans_width,
             band[0], band[1],
             band[1] + trans_width, 0.5 * rate]
    # note: newer SciPy spells the sampling-rate argument fs=rate instead of Hz=rate
    taps = signal.remez(numtaps, edges, [0, 1, 0], Hz=rate, type='bandpass')
    sig_filt = signal.lfilter(taps, 1, audio_raw)
    return sig_filt
# Return the magnitude and phase of one STFT as tensors.
def _stft(audio):
    spec = librosa.stft(audio, n_fft=1022, hop_length=256)
    amp = np.abs(spec)
    phase = np.angle(spec)
    W = np.shape(amp)[0]
    H = np.shape(amp)[1]
    tch_mag = torch.empty(1, 1, W, H, dtype=torch.float)
    tch_mag[0, 0, :, :] = torch.from_numpy(amp)
    tch_phase = torch.empty(1, 1, W, H, dtype=torch.float)
    tch_phase[0, 0, :, :] = torch.from_numpy(phase)
    return tch_mag, tch_phase
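# Note: n_fft=1022 gives 1022 // 2 + 1 = 512 frequency bins, so the magnitude
# can later be resampled to a power-of-two height (256) by grid_sample.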
# Return one binary torch mask with the dimensions of the STFT.
def threshold(mag):
    av = np.mean(mag[0, 0].numpy())
    vari = np.var(mag[0, 0].numpy())
    param = av + 2 * np.sqrt(vari)  # threshold: mean + 2 std
    gt_mask = (mag[0, 0] > param).float()
    return gt_mask
# Create the sampling grid used by grid_sample to resize the spectrogram.
def warpgrid(HO, WO, warp=True):
    # meshgrid
    x = np.linspace(-1, 1, WO)
    y = np.linspace(-1, 1, HO)
    xv, yv = np.meshgrid(x, y)
    grid = np.zeros((1, HO, WO, 2))
    grid[0, :, :, 0] = xv
    grid[0, :, :, 1] = yv
    return grid.astype(np.float32)
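# Note: the grid above is the identity mapping over [-1, 1] x [-1, 1] (the warp
# flag is currently unused), so grid_sample effectively performs a bilinear
# resize of the spectrogram to HO x WO.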
# Create the network-input image from the grid.
def create_im(mag):
    magim = mag.unsqueeze(0).unsqueeze(0)
    grid_warp = torch.from_numpy(warpgrid(256, magim.shape[3], warp=True))
    magim = F.grid_sample(magim, grid_warp)
    # flip the frequency axis so low frequencies sit at the bottom of the image
    # (np.flipud on the 4D tensor only flipped the size-1 batch axis)
    return torch.flip(magim, dims=[2])

def create_mask(mag):
    magim = mag.unsqueeze(0).unsqueeze(0)
    grid_warp = torch.from_numpy(warpgrid(256, 44, warp=True))
    magim = F.grid_sample(magim, grid_warp)
    return torch.flip(magim, dims=[2])
# Target image sizes: kernel size 5 with padding 3 gives [264, 52];
# kernel size 3 with padding 1 gives [256, 44].
# The width depends on the overlap of the STFT.
# Zero out a random horizontal (frequency) band of the spectrogram.
def freq_mask(spec):
    fbank_size = np.shape(spec)
    rows, columns = fbank_size[0], fbank_size[1]
    # width of the band
    fact1 = random.randint(int(rows / 40), int(rows / 10))
    frame = np.zeros([fact1, columns])
    # position of the band on the y axis
    pos = random.randint(10, rows - fact1 - 1)
    up = np.ones([pos - 1, columns])
    down = np.ones([rows - (pos + fact1) + 1, columns])
    mask = torch.from_numpy(np.concatenate((up, frame, down), axis=0)).float()
    return spec * mask
# Zero out a random vertical (time) band of the spectrogram.
def time_mask(spec):
    fbank_size = np.shape(spec)
    rows, columns = fbank_size[0], fbank_size[1]
    # width of the band
    fact1 = random.randint(int(columns / 40), int(columns / 10))
    frame = np.zeros([rows, fact1])
    # position of the band on the x axis
    pos = random.randint(10, columns - fact1 - 1)
    left = np.ones([rows, pos - 1])
    right = np.ones([rows, columns - (pos + fact1) + 1])
    mask = torch.from_numpy(np.concatenate((left, frame, right), axis=1)).float()
    return spec * mask
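# These two masks amount to a SpecAugment-style augmentation: each call blanks
# one random frequency or time band of the spectrogram.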
# Time-shift the signal by up to shift_max seconds, zeroing the vacated samples.
def manipulate(data, sampling_rate, shift_max, shift_direction):
    shift = np.random.randint(sampling_rate * shift_max)
    if shift_direction == 'right':
        shift = -shift
    elif shift_direction == 'both':
        direction = np.random.randint(0, 2)
        if direction == 1:
            shift = -shift
    augmented_data = np.roll(data, shift)
    # silence the samples that wrapped around at the head/tail
    if shift > 0:
        augmented_data[:shift] = 0
    else:
        augmented_data[shift:] = 0
    return augmented_data
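# Usage sketch (illustrative values): shift a 22.05 kHz clip by up to 0.25 s in
# either direction, zeroing the samples the roll vacated:
#   augmented = manipulate(audio, 22050, 0.25, 'both')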
class Dataset(data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, path, nb_classes=2, augmentation=True, path_background="./noises"):
        self.dict_classes = self.load_data(path)
        self.nb_classes = nb_classes
        self.augmentation = augmentation
        self.path_background = path_background
        if self.augmentation:
            self.dict_noises = self.load_data(path_background)
    def load_data(self, path, ext='wav'):
        'Map every class (sub-directory name) to the list of its audio files.'
        dict_classes = collections.OrderedDict()
        for root, dirnames, filenames in os.walk(path):
            for filename in fnmatch.filter(filenames, '*' + ext):
                classe = root.split("/")[-1]
                if classe in dict_classes:
                    dict_classes[classe].append(os.path.join(root, filename))
                else:
                    dict_classes[classe] = [os.path.join(root, filename)]
        if len(dict_classes) == 0:
            print("** WARNING ** No data loaded from " + path)
        return dict_classes
    def mixed_audio_augmentation(self, audio, sampling_rate):
        'Add random natural noise, then random pitch and time shifts.'
        classe_noise = random.randint(0, len(list(self.dict_noises.keys())) - 1)
        classe_noise = list(self.dict_noises.keys())[classe_noise]
        # random natural-noise augmentation
        filename_noise = self.dict_noises[classe_noise][random.randint(0, len(self.dict_noises[classe_noise]) - 1)]
        audio_noise, sr = load_file(filename_noise)
        coeff = int(np.ceil(np.max(audio) * random.choice([1, 2, 3, 4, 5, 6, 7])))
        # add the noise over the overlapping samples only, in case lengths differ
        n = min(len(audio), len(audio_noise))
        noisy_audio = audio.copy()
        noisy_audio[:n] = noisy_audio[:n] + audio_noise[:n] * coeff
        # random pitch shifting
        step_pitch = random.uniform(-10, 10)
        mod_audio = librosa.effects.pitch_shift(noisy_audio, sr=sampling_rate, n_steps=step_pitch)
        # random time shifting
        final_audio = manipulate(mod_audio, sampling_rate, 0.25, 'both')
        return final_audio
    # Randomly apply at least one mask band to the spectrogram.
    def spec_augmentation(self, spec):
        n = random.randint(0, 2)
        if n == 0:
            t = random.randint(0, 1)
            if t == 1:
                spec = time_mask(spec)
            else:
                spec = freq_mask(spec)
        else:
            for ii in range(n):
                spec = time_mask(spec)
                spec = freq_mask(spec)
        return spec
    def __len__(self):
        'Denotes the total number of samples'
        # samples are drawn at random in __getitem__, so this just sets the epoch size
        # return len(self.dict_classes)
        nb_samples = 400000
        return nb_samples
    def load_files(self, nb_classes):
        'Pick one random audio file per class.'
        files = []
        for cl in range(nb_classes):
            classe_name = list(self.dict_classes.keys())[cl]
            idx = int(random.random() * len(self.dict_classes[classe_name]))
            filename = self.dict_classes[classe_name][idx]
            files.append([classe_name, filename])
        return files
    # Each entry of files becomes [class_name, filename, mask, magnitude, phase].
    def __getitem__(self, index):
        'Load one audio file per class and mix them.'
        files = self.load_files(self.nb_classes)
        audio_mix = None
        for f in files:
            audio_raw, rate = load_file(f[1])
            audio = filt(audio_raw, rate)
            mag, phase = _stft(audio)
            mag = create_mask(mag.squeeze(0).squeeze(0))
            mask = threshold(mag)
            f.append(mask)
            f.append(mag)
            f.append(phase)
            if audio_mix is None:
                audio_mix = audio
            else:
                audio_mix += audio
        # build the mixed spectrogram
        if self.augmentation:
            audio_mix = self.mixed_audio_augmentation(audio_mix, rate)
        mag_mix, phase_mix = _stft(audio_mix)
        mag_mix = mag_mix.squeeze(0).squeeze(0)
        # mag_mix = self.spec_augmentation(mag_mix)
        mags_mix = create_im(mag_mix)
        mags_mix = mags_mix.squeeze(0)
        return [mags_mix, phase_mix, files]
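# Minimal usage sketch (assumes a ./data_sound/trainset/<class>/*.wav layout,
# as used by the training script below):
#   dataset = Dataset('./data_sound/trainset/', nb_classes=2,
#                     path_background='./data_sound/noises/')
#   mags_mix, phase_mix, files = dataset[0]
#   print(mags_mix.shape)  # [1, 256, T] warped magnitude of the mixture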
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 26 22:14:38 2019
@author: felix
"""
import os
import random
import time
import fnmatch
import csv
import torch
from arguments import ArgParser
from unet import UNet
import torch.nn as nn
from tensorboardX import SummaryWriter
from matplotlib import image as mpimg
from Dataloader import Dataset
import matplotlib
import numpy as np
import collections
import scipy.io.wavfile
import librosa  # needed by build_audio for the inverse STFT
# Collect all file paths with the given extension under path.
def create_list(path, ext):
    list_names = []
    for root, dirnames, filenames in os.walk(path):
        for filename in fnmatch.filter(filenames, '*' + ext):
            list_names.append(os.path.join(root, filename))
    return list_names
def create_optimizer(nets, args):
    net_sound = nets
    param_groups = [{'params': net_sound.parameters(), 'lr': args.lr_sound}]
    return torch.optim.Adam(param_groups)
def unwrap_mask(infos):
    'Stack the per-class ground-truth masks into one batch tensor.'
    # gt_masks = torch.empty(args.batch_size, args.nb_classes, 256, 44, dtype=torch.float)
    # for a kernel of 5:
    gt_masks = torch.empty(args.batch_size, args.nb_classes, 264, 52, dtype=torch.float)
    for ii in range(args.batch_size):
        for jj in range(args.nb_classes):
            gt_masks[ii, jj] = infos[jj][2][ii]
    return gt_masks
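# infos comes from the Dataloader after collation: infos[jj] is the collated
# [class_name, filename, mask, magnitude, phase] list of class jj, so
# infos[jj][2][ii] is the ground-truth mask of class jj for sample ii.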
def build_audio(audio_names, pred_masks, magmix, phasemix):
    for ii in range(args.batch_size):
        # use per-sample locals so the batch tensors are not overwritten
        masks_ii = pred_masks[ii].numpy()
        magmix_ii = magmix[ii].squeeze(0).detach().numpy()
        phasemix_ii = phasemix[ii].squeeze(0).detach().numpy()
        for n in range(args.nb_classes):
            name = audio_names[n][1][ii]
            magnew = masks_ii[n] * magmix_ii
            spec = magnew.astype(complex) * np.exp(1j * phasemix_ii)
            audio = librosa.istft(spec, hop_length=256)
            scipy.io.wavfile.write('restored_audio/restored_{}.wav'.format(name), 22050, audio)
def train(net, loader_train, optimizer, args):
    torch.set_grad_enabled(True)
    num_batch = 0
    criterion = nn.BCELoss()
    for ii, batch_data in enumerate(loader_train):
        num_batch += 1
        magmix = batch_data[0].to(args.device)
        masks_pred = net(magmix)
        masks = unwrap_mask(batch_data[2]).to(args.device)
        loss = criterion(masks_pred, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # write the loss value and elapsed time for every batch
        batchtime = (time.time() - args.starting_training_time) / 60  # minutes
        with open("./losses/loss_train/loss_timesU65.csv", "a") as f:
            writer = csv.writer(f)
            writer.writerow([str(loss.cpu().detach().numpy()), batchtime, num_batch])
        if ii % args.save_per_batchs == 0:
            torch.save({
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
                'Saved_modelsix/modelU65_batchs{}.pth.tar'.format(num_batch))
def evaluation(net, loader, args):
    # no gradient updates during evaluation
    torch.set_grad_enabled(False)
    num_batch = 0
    criterion = nn.BCELoss()
    args.out_threshold = 0.4
    for ii, batch_data in enumerate(loader):
        # forward pass
        magmix = batch_data[0].to(args.device)
        masks = unwrap_mask(batch_data[2]).to(args.device)
        num_batch += 1
        masks_pred = net(magmix)
        loss = criterion(masks_pred, masks)
        # write the loss value and elapsed time for every batch
        batchtime = (time.time() - args.starting_training_time) / 60  # minutes
        with open("./losses/loss_eval/loss_times_eval{}.csv".format(args.saved_model), "a") as f:
            writer = csv.writer(f)
            writer.writerow([str(loss.cpu().detach().numpy()), batchtime, num_batch])
#***************************************************
#****************** MAIN ***************************
#***************************************************
if __name__ == '__main__':
    # arguments
    parser = ArgParser()
    args = parser.parse_train_arguments()
    args.batch_size = 16
    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.starting_training_time = time.time()
    args.save_per_batchs = 30
    args.nb_classes = 2
    args.mode = 'train'
    args.lr_sound = 1e-5  # create_optimizer reads args.lr_sound (was misspelled lr_sounds)
    args.saved_model = '5_5000'
    # model definition
    net = UNet(n_channels=1, n_classes=args.nb_classes)
    net = net.to(args.device)
    # set up the optimizer
    optimizer = create_optimizer(net, args)
    ###########################################################
    ################### TRAINING ##############################
    ###########################################################
    if args.mode == 'train':
        # overwrite the files used for loss and time logging
        fichierLoss = open("./losses/loss_train/loss_timesU65.csv", "w")
        fichierLoss.close()
        # dataset loading
        root = './data_sound/trainset/'
        ext = '.wav'
        train_classes = Dataset(root, nb_classes=args.nb_classes, path_background="./data_sound/noises/")
        loader_train = torch.utils.data.DataLoader(
            train_classes,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=20)
        for epoch in range(0, 1):
            train(net, loader_train, optimizer, args)
            torch.save({
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
                'Saved_modelsix/modelU65epoch{}.pth.tar'.format(epoch))
    ###########################################################
    ################### EVALUATION ############################
    ###########################################################
    if args.mode == 'eval':
        # overwrite the files used for loss and time logging
        fichierLoss = open("./losses/loss_eval/loss_times_eval{}.csv".format(args.saved_model), "w")
        fichierLoss.close()
        # dataset loading
        root = './data_sound/valset/'
        ext = '.wav'
        val_classes = Dataset(root, nb_classes=args.nb_classes, path_background="./data_sound/noises/")
        # initialization of the model from the saved checkpoint
        checkpoint = torch.load('Saved_models2/model{}.pth.tar'.format(args.saved_model))
        net.load_state_dict(checkpoint['model_state_dict'])
        loader_eval = torch.utils.data.DataLoader(
            val_classes,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=20)
        for epoch in range(0, 1):
            evaluation(net, loader_eval, args)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 28 16:14:55 2019
@author: felix
"""
import os
import fnmatch
import librosa
from scipy import signal
import numpy as np
import torch.nn.functional as F
import torch
from unet import UNet
from matplotlib import image as mpimg
import matplotlib
import scipy.io.wavfile
def create_list(path, ext):
    list_names = []
    for root, dirnames, filenames in os.walk(path):
        for filename in fnmatch.filter(filenames, '*' + ext):
            list_names.append(os.path.join(root, filename))
    return list_names
# Load an audio file and keep only the first 5 seconds.
def load_audio(audio):
    audio_raw, rate = librosa.load(audio, mono=True)
    marc = rate * 5
    audio1 = audio_raw[0:int(marc)]
    return audio1, rate
def Visualizer(r_input):
    'Save the network input as a grayscale image.'
    my_cm = matplotlib.cm.get_cmap('Greys')
    r_input = r_input.detach().squeeze(0)
    normed_data = (np.asarray(r_input) - torch.min(r_input).item()) / (torch.max(r_input).item() - torch.min(r_input).item())
    mpimg.imsave('input.png', np.flipud(my_cm(normed_data)))
def Visualizer2(list_files, model):
    'Save each predicted mask as an image.'
    for ii in range(len(list_files)):
        mpimg.imsave('./simul_npt/resultmodel_{}_{}.png'.format(ii, model), np.flipud(np.asarray(list_files[ii])))
def np2rgb(tens_im):
    'Map up to three predicted masks to red/green/blue colormaps.'
    my_cm1 = matplotlib.cm.get_cmap('Reds')
    my_cm2 = matplotlib.cm.get_cmap('Greens')
    my_cm3 = matplotlib.cm.get_cmap('Blues')
    colors = [my_cm1, my_cm2, my_cm3]
    N = np.shape(tens_im)[0]
    tens_im = tens_im.detach()
    mapped_data = [None for n in range(N)]
    for n in range(N):
        thres_im = (tens_im[n] > 0.6).float()  # binarize before normalizing
        normed_data = (np.asarray(thres_im) - torch.min(thres_im).item()) / (torch.max(thres_im).item() - torch.min(thres_im).item())
        mapped_data[n] = torch.from_numpy(colors[n](normed_data))
    return mapped_data
def filt(audio_raw, rate):
    band = [800, 7000]   # desired pass band, Hz
    trans_width = 100    # width of transition from pass band to stop band, Hz
    numtaps = 250        # size of the FIR filter
    edges = [0, band[0] - trans_width,
             band[0], band[1],
             band[1] + trans_width, 0.5 * rate]
    taps = signal.remez(numtaps, edges, [0, 1, 0], Hz=rate, type='bandpass')
    sig_filt = signal.lfilter(taps, 1, audio_raw)
    return sig_filt
def _stft(audio):
    spec = librosa.stft(audio, n_fft=1022, hop_length=256)
    amp = np.abs(spec)
    phase = np.angle(spec)
    W = np.shape(amp)[0]
    H = np.shape(amp)[1]
    tch_mag = torch.empty(1, 1, W, H, dtype=torch.float)
    tch_mag[0, 0, :, :] = torch.from_numpy(amp)
    tch_phase = torch.empty(1, 1, W, H, dtype=torch.float)
    tch_phase[0, 0, :, :] = torch.from_numpy(phase)
    return tch_mag, tch_phase
def threshold(mag):
    av = np.mean(mag[0, 0].numpy())
    vari = np.var(mag[0, 0].numpy())
    param = av + np.sqrt(vari) * 1  # threshold: mean + 1 std
    gt_mask = (mag[0, 0] > param).float()
    return gt_mask
def warpgrid(HO, WO, warp=True):
    # meshgrid
    x = np.linspace(-1, 1, WO)
    y = np.linspace(-1, 1, HO)
    xv, yv = np.meshgrid(x, y)
    grid = np.zeros((1, HO, WO, 2))
    grid[0, :, :, 0] = xv
    grid[0, :, :, 1] = yv
    return grid.astype(np.float32)
def create_im(mag):
    T = mag.shape[3]
    # warp the spectrogram to a 256-row image
    grid_warp = torch.from_numpy(warpgrid(256, T, warp=True))
    magim = F.grid_sample(mag, grid_warp)
    # flip the frequency axis (np.flipud on the 4D tensor only flipped the size-1 batch axis)
    return torch.flip(magim, dims=[2])
def build_audio(pred_masks, sr, magmix, phasemix, model):
    'Apply each predicted mask to the mixture and invert the STFT.'
    pred_masks = pred_masks.squeeze(0).detach().numpy()
    magmix = magmix.squeeze(0).squeeze(0).detach().numpy()
    phasemix = phasemix.squeeze(0).squeeze(0).detach().numpy()
    nb_masks = pred_masks.shape[0]
    for n in range(nb_masks):
        magnew = pred_masks[n] * magmix
        spec = magnew.astype(complex) * np.exp(1j * phasemix)
        audio = librosa.istft(spec, hop_length=256)
        scipy.io.wavfile.write('./test/restored_audio/restored_model_{}_{}.wav'.format(model, n), sr, audio)
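# Note: reconstruction reuses the phase of the mixture for every source, a
# common approximation when only magnitude masks are predicted; the masked
# magnitude is recombined with exp(1j * phase) before the inverse STFT.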
# gather every saved checkpoint under root_dir2
model = []
root_dir2 = './Saved_models_npt/'
ext = '.tar'
for root, dirnames, filenames in os.walk(root_dir2):
    for filename in fnmatch.filter(filenames, '*' + ext):
        model.append(os.path.join(root, filename))
crow = './test/crow/43215301.ogg'
pewee = './test/pewee/pewee2.ogg'
noise = 'srain.wav'
file1, sr = load_audio(crow)
file2, _ = load_audio(pewee)
file3, _ = load_audio(noise)
mix = file1 + file2 + file3
filtmix = filt(mix, sr)
#scipy.io.wavfile.write('./test/inputaudio.wav', sr, filtmix)
magmix, phasemix = _stft(filtmix)
im = create_im(magmix)
phase = create_im(phasemix)
Visualizer(im.squeeze(0))
#
net = UNet(n_channels=1, n_classes=2)
for mod in model:
    real_mod = mod.rsplit('/', 1)[1].split('.', 1)[0]
    checkpoint = torch.load('Saved_models_npt/{}.pth.tar'.format(real_mod))