Commit 02d9bff3 authored by Félix Michaud

audiotest

parent 53bacf18
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 26 22:14:38 2019
@author: felix
"""
import os
import random
import time
import fnmatch
import csv
import collections
import numpy as np
import scipy.io.wavfile  # needed by build_audio below
import librosa           # needed by build_audio below
import torch
import torch.nn as nn
from arguments import ArgParser
from unet import UNet
from Dataloader import Dataset
from tensorboardX import SummaryWriter
from matplotlib import image as mpimg
import matplotlib
# list every file under `path` that matches the given extension
def create_list(path, ext):
    list_names = []
    for root, dirnames, filenames in os.walk(path):
        for filename in fnmatch.filter(filenames, '*' + ext):
            list_names.append(os.path.join(root, filename))
    return list_names
def create_optimizer(nets, args):
    net_sound = nets
    param_groups = [{'params': net_sound.parameters(), 'lr': args.lr_sound}]
    return torch.optim.Adam(param_groups)
def unwrap_mask(infos):
    # gt_masks = torch.empty(args.batch_size, args.nb_classes, 256, 44, dtype=torch.float)
    # mask size for a kernel of 5
    gt_masks = torch.empty(args.batch_size, args.nb_classes, 264, 52, dtype=torch.float)
    for ii in range(args.batch_size):
        for jj in range(args.nb_classes):
            gt_masks[ii, jj] = infos[jj][2][ii]
    return gt_masks
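# Assumed layout of batch_data[2] (the Dataloader itself is not part of this
# commit): infos[class][2][batch] holds one class's ground-truth mask for one
# sample as a 264x52 tensor, so unwrap_mask only restacks those tensors into
# a (batch, class, 264, 52) block matching the network output.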
def build_audio(audio_names, pred_masks, magmix, phasemix):
    for ii in range(args.batch_size):
        # use per-sample locals so the batch tensors are not overwritten
        masks = pred_masks[ii].numpy()
        mag = magmix[ii].squeeze(0).detach().numpy()
        phase = phasemix[ii].squeeze(0).detach().numpy()
        for n in range(args.nb_classes):
            name = audio_names[n][1][ii]
            magnew = masks[n] * mag
            spec = magnew.astype(complex) * np.exp(1j * phase)
            audio = librosa.istft(spec, hop_length=256)
            scipy.io.wavfile.write('restored_audio/restored_{}.wav'.format(name), 22050, audio)
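# Note: this reconstruction treats the network outputs as ratio masks in
# [0, 1] applied to the mixture magnitude, which is consistent with the
# BCELoss objective in train() below; the mixture phase is reused unchanged.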
def train(net, loader_train, optimizer, args):
    torch.set_grad_enabled(True)
    num_batch = 0
    criterion = nn.BCELoss()
    for ii, batch_data in enumerate(loader_train):
        num_batch += 1
        magmix = batch_data[0].to(args.device)
        masks_pred = net(magmix)
        masks = unwrap_mask(batch_data[2]).to(args.device)
        loss = criterion(masks_pred, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # elapsed time (in minutes) since training started
        batchtime = (time.time() - args.starting_training_time) / 60
        # append the loss and elapsed time for every batch
        with open("./losses/loss_train/loss_timesU65.csv", "a") as f:
            writer = csv.writer(f)
            writer.writerow([str(loss.cpu().detach().numpy()), batchtime, num_batch])
        if ii % args.save_per_batchs == 0:
            torch.save({
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
                'Saved_modelsix/modelU65_batchs{}.pth.tar'.format(num_batch))
def evaluation(net, loader, args):
    # no gradient updates during evaluation
    torch.set_grad_enabled(False)
    num_batch = 0
    criterion = nn.BCELoss()
    args.out_threshold = 0.4
    for ii, batch_data in enumerate(loader):
        # forward pass
        magmix = batch_data[0].to(args.device)
        masks = unwrap_mask(batch_data[2]).to(args.device)
        num_batch += 1
        masks_pred = net(magmix)
        # loss
        loss = criterion(masks_pred, masks)
        # elapsed time (in minutes), needed for the log entry below
        batchtime = (time.time() - args.starting_training_time) / 60
        # append the loss and elapsed time for every batch
        with open("./losses/loss_eval/loss_times_eval{}.csv".format(args.saved_model), "a") as f:
            writer = csv.writer(f)
            writer.writerow([str(loss.cpu().detach().numpy()), batchtime, num_batch])
#***************************************************
#****************** MAIN ***************************
#***************************************************
if __name__ == '__main__':
    # arguments
    parser = ArgParser()
    args = parser.parse_train_arguments()
    args.batch_size = 16
    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.starting_training_time = time.time()
    args.save_per_batchs = 30
    args.nb_classes = 2
    args.mode = 'train'
    args.lr_sound = 1e-5  # name must match the attribute read in create_optimizer
    args.saved_model = '5_5000'
    # model definition
    net = UNet(n_channels=1, n_classes=args.nb_classes)
    net = net.to(args.device)
    # set up the optimizer
    optimizer = create_optimizer(net, args)
    ###########################################################
    ################### TRAINING ##############################
    ###########################################################
    if args.mode == 'train':
        # overwrite the file used for loss and time logging
        fichierLoss = open("./losses/loss_train/loss_timesU65.csv", "w")
        fichierLoss.close()
        # dataset loading
        root = './data_sound/trainset/'
        ext = '.wav'
        train_classes = Dataset(root, nb_classes=args.nb_classes, path_background="./data_sound/noises/")
        loader_train = torch.utils.data.DataLoader(
            train_classes,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=20)
        for epoch in range(0, 1):
            train(net, loader_train, optimizer, args)
            torch.save({
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
                'Saved_modelsix/modelU65epoch{}.pth.tar'.format(epoch))
    ###########################################################
    ################### EVALUATION ############################
    ###########################################################
    if args.mode == 'eval':
        # overwrite the file used for loss and time logging
        fichierLoss = open("./losses/loss_eval/loss_times_eval{}.csv".format(args.saved_model), "w")
        fichierLoss.close()
        # dataset loading
        root = './data_sound/valset/'
        ext = '.wav'
        val_classes = Dataset(root, nb_classes=args.nb_classes, path_background="./data_sound/noises/")
        # initialize the model from the saved checkpoint
        checkpoint = torch.load('Saved_models2/model{}.pth.tar'.format(args.saved_model))
        net.load_state_dict(checkpoint['model_state_dict'])
        loader_eval = torch.utils.data.DataLoader(
            val_classes,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=20)
        for epoch in range(0, 1):
            evaluation(net, loader_eval, args)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 28 16:14:55 2019
@author: felix
"""
import os
import fnmatch
import librosa
from scipy import signal
import scipy.io.wavfile  # needed by build_audio below
import numpy as np
import torch
import torch.nn.functional as F
from unet import UNet
from matplotlib import image as mpimg
import matplotlib
import soundfile as sf
def create_list(path, ext):
    list_names = []
    for root, dirnames, filenames in os.walk(path):
        for filename in fnmatch.filter(filenames, '*' + ext):
            list_names.append(os.path.join(root, filename))
    return list_names
def load_audio(audio):
    audio_raw, rate = librosa.load(audio, sr=22050, mono=True)
    marc = 22050 * 5  # keep the first 5 seconds
    audio1 = audio_raw[0:int(marc)]
    return audio1, rate
def load_file(file):
    audio_raw, rate = sf.read(file)
    audio_ = librosa.resample(audio_raw, orig_sr=rate, target_sr=22050)
    marc = 22050 * 5  # keep the first 5 seconds
    audio = audio_[0:int(marc), 0]  # first channel of a multichannel file
    return audio, 22050
def Visualizer(r_input):
    my_cm = matplotlib.cm.get_cmap('Greys')
    r_input = r_input.detach().squeeze(0)
    normed_data = (np.asarray(r_input) - torch.min(r_input).item()) / (torch.max(r_input).item() - torch.min(r_input).item())
    mpimg.imsave('input.png', np.flipud(my_cm(normed_data)))
def Visualizer2(list_files, model):
    # print(len(list_files), 'len list files')
    for ii in range(len(list_files)):
        mpimg.imsave('./simulation/resultmodel_{}_{}.png'.format(ii, model), np.flipud(np.asarray(list_files[ii])))
def np2rgb(tens_im):
    # map up to three mask channels onto red/green/blue colormaps
    my_cm1 = matplotlib.cm.get_cmap('Reds')
    my_cm2 = matplotlib.cm.get_cmap('Greens')
    my_cm3 = matplotlib.cm.get_cmap('Blues')
    colors = [my_cm1, my_cm2, my_cm3]
    N = np.shape(tens_im)[0]
    # print(N, 'N')
    tens_im = tens_im.detach()
    mapped_data = [None for n in range(N)]
    for n in range(N):
        # binarize at 0.5, then normalize to [0, 1] for the colormap
        thres_im = (tens_im[n] > 0.5).float()
        normed_data = (np.asarray(thres_im) - torch.min(thres_im).item()) / (torch.max(thres_im).item() - torch.min(thres_im).item())
        mapped_data[n] = torch.from_numpy(colors[n](normed_data))
    return mapped_data
def filt(audio_raw, rate):
    band = [800, 7000]   # desired pass band, Hz
    trans_width = 100    # width of transition from pass band to stop band, Hz
    numtaps = 250        # size of the FIR filter
    edges = [0, band[0] - trans_width,
             band[0], band[1],
             band[1] + trans_width, 0.5 * rate]
    taps = signal.remez(numtaps, edges, [0, 1, 0], fs=rate)
    sig_filt = signal.lfilter(taps, 1, audio_raw)
    return sig_filt
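# With rate = 22050 Hz the edges above give stop bands of 0-700 Hz and
# 7100-11025 Hz around the 800-7000 Hz pass band, i.e. a band-pass covering
# the frequency range of interest before the STFT.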
def _stft(audio):
    spec = librosa.stft(audio, n_fft=1022, hop_length=256)
    amp = np.abs(spec)
    phase = np.angle(spec)
    W = np.shape(amp)[0]
    H = np.shape(amp)[1]
    tch_mag = torch.empty(1, 1, W, H, dtype=torch.float)
    tch_mag[0, 0, :, :] = torch.from_numpy(amp)
    tch_phase = torch.empty(1, 1, W, H, dtype=torch.float)
    tch_phase[0, 0, :, :] = torch.from_numpy(phase)
    return tch_mag, tch_phase
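# Shape check: n_fft=1022 gives 512 frequency bins, and for the 5 s / 22050 Hz
# clips loaded above (110250 samples), librosa's default centered STFT gives
# 1 + 110250 // 256 = 431 frames, so both tensors come out as (1, 1, 512, 431).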
def threshold(mag):
    # binary mask: 1 where the magnitude exceeds mean + one standard deviation
    av = np.mean(mag[0, 0].numpy())
    vari = np.var(mag[0, 0].numpy())
    param = av + np.sqrt(vari) * 1  # threshold
    gt_mask = (mag[0, 0] > param).float()
    return gt_mask
def warpgrid(HO, WO, warp=True):
    # build a (1, HO, WO, 2) identity sampling grid for F.grid_sample
    x = np.linspace(-1, 1, WO)
    y = np.linspace(-1, 1, HO)
    xv, yv = np.meshgrid(x, y)
    grid = np.zeros((1, HO, WO, 2))
    grid[0, :, :, 0] = xv
    grid[0, :, :, 1] = yv
    return grid.astype(np.float32)
def create_im(mag):
    T = mag.shape[3]
    # warp the spectrogram to a fixed height of 256 rows
    grid_warp = torch.from_numpy(warpgrid(256, T, warp=True))
    magim = F.grid_sample(mag, grid_warp)
    return torch.from_numpy(np.flipud(magim).copy())
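# Shape walk-through (inferred, not stated in this commit): create_im resamples
# the (1, 1, 512, T) magnitude from _stft to (1, 1, 256, T), the size fed to
# net(im) below. Note that np.flipud acts on the first (batch) axis of this
# 4-D array, which has size 1, so it does not flip the frequency axis here.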
def build_audio(pred_masks, magmix, phasemix, model):
    pred_masks = pred_masks.squeeze(0).detach().numpy()
    magmix = magmix.squeeze(0).squeeze(0).detach().numpy()
    phasemix = phasemix.squeeze(0).squeeze(0).detach().numpy()
    nb_masks = pred_masks.shape[0]
    for n in range(nb_masks):
        # apply the predicted mask to the mixture magnitude, reuse the mixture phase
        magnew = pred_masks[n] * magmix
        spec = magnew.astype(complex) * np.exp(1j * phasemix)
        audio = librosa.istft(spec, hop_length=256)
        scipy.io.wavfile.write('restored_audio/restored_model_{}_{}.wav'.format(model, n), 22050, audio)
# collect every saved checkpoint under Saved_models2
model = []
root_dir2 = './Saved_models2/'
ext = '.tar'
for root, dirnames, filenames in os.walk(root_dir2):
    for filename in fnmatch.filter(filenames, '*' + ext):
        model.append(os.path.join(root, filename))

# build a test mixture: two bird recordings plus background noise
crow = './test/crow/43215301.ogg'
pewee = './test/pewee/pewee2.ogg'
noise = 'srain.wav'
file1, sr = load_file(crow)
file2, _ = load_file(pewee)
file3, _ = load_audio(noise)
mix = file1 + file2 + file3
filtmix = filt(mix, sr)
magmix, phasemix = _stft(filtmix)
im = create_im(magmix)
phase = create_im(phasemix)
Visualizer(im.squeeze(0))
# run every checkpoint on the test mixture
net = UNet(n_channels=1, n_classes=2)
for mod in model:
    # e.g. 'Saved_models2/model5_5000.pth.tar' -> 'model5_5000'
    real_mod = mod.rsplit('/', 1)[1].split('.', 1)[0]
    checkpoint = torch.load('Saved_models2/{}.pth.tar'.format(real_mod))
    net.load_state_dict(checkpoint['model_state_dict'])
    masks_pred = net(im)
    print(im.size(), 'im size')
    print(masks_pred.size(), 'masks pred size')
    # mapped_data = np2rgb(masks_pred.squeeze(0))
    # Visualizer2(mapped_data, real_mod)
    build_audio(masks_pred, im, phase, real_mod)