Commit 55888559 authored by Félix Michaud's avatar Félix Michaud

before audiotest

parent 4f39f42e
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 26 22:14:38 2019
@author: felix
"""
import os
import random
import time
import fnmatch
import csv
import torch
from arguments import ArgParser
from Calls import CallDataset
from unet import UNet
import torch.nn as nn
from tensorboardX import SummaryWriter
from matplotlib import image as mpimg
from Dataloader import Dataset
import matplotlib
import numpy as np
import collections
# collect all file paths under `path` that end with the given extension
def create_list(path, ext):
    list_names = []
    for root, dirnames, filenames in os.walk(path):
        for filename in fnmatch.filter(filenames, '*' + ext):
            list_names.append(os.path.join(root, filename))
    return list_names
def create_optimizer(nets, args):
    net_sound = nets
    param_groups = [{'params': net_sound.parameters(), 'lr': args.lr_sound}]
    return torch.optim.Adam(param_groups)
def unwrap_mask(infos):
    # infos[jj][2][ii] holds the ground-truth mask of class jj for sample ii;
    # stack them into a (batch, class, 256, 44) tensor
    gt_masks = torch.empty(args.batch_size, args.nb_classes, 256, 44, dtype=torch.float)
    for ii in range(args.batch_size):
        for jj in range(args.nb_classes):
            gt_masks[ii, jj] = infos[jj][2][ii]
    return gt_masks
def train(net, loader_train, optimizer, args):
    torch.set_grad_enabled(True)
    num_batch = 0
    criterion = nn.BCELoss()
    for ii, batch_data in enumerate(loader_train):
        num_batch += 1
        magmix = batch_data[0]
        # magmix = magmix.to(args.device)
        masks_pred = net(magmix)
        masks = unwrap_mask(batch_data[1])
        # masks = masks.to(args.device)
        loss = criterion(masks_pred, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # write the loss value and elapsed time for every batch
        batchtime = (time.time() - args.starting_training_time)/60  # minutes
        with open("./losses/loss_train/loss_times5.csv", "a") as f:
            writer = csv.writer(f)
            writer.writerow([str(loss.cpu().detach().numpy()), batchtime, num_batch])
        if ii % args.save_per_batchs == 0:
            torch.save({
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
                'Saved_models3/model5_batchs{}.pth.tar'.format(num_batch))
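# Note: nn.BCELoss expects predictions already in [0, 1], so the loss above
# assumes the UNet ends with a sigmoid; if it returned raw logits,
# nn.BCEWithLogitsLoss would be the numerically stable alternative.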
#***************************************************
#****************** MAIN ***************************
#***************************************************
if __name__ == '__main__':
    # arguments
    parser = ArgParser()
    args = parser.parse_train_arguments()
    args.batch_size = 3
    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.starting_training_time = time.time()
    args.save_per_batchs = 180
    args.nb_classes = 2
    args.mode = 'train'
    args.lr_sound = 1e-5  # read by create_optimizer
    args.saved_model = '5_5000'
    # model definition
    net = UNet(n_channels=1, n_classes=args.nb_classes)
    # net = net.to(args.device)
    # set up the optimizer
    optimizer = create_optimizer(net, args)
###########################################################
################### TRAINING ##############################
###########################################################
    if args.mode == 'train':
        # overwrite the files used to log loss values and elapsed time
        fichierLoss = open("./losses/loss_train/loss_times5.csv", "w")
        fichierLoss.close()
        # dataset loading
        root = './data_sound/trainset/'
        ext = '.wav'
        train_classes = Dataset(root, nb_classes=args.nb_classes, path_background="./data_sound/noises/")
        print('after dataset')
        loader_train = torch.utils.data.DataLoader(
            train_classes,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=20)
        lala = next(iter(loader_train))  # pull one batch as a quick sanity check
        # print('before epoch')
        # for epoch in range(0, 1):
        #     train(net, loader_train, optimizer, args)
        #     torch.save({
        #         'model_state_dict': net.state_dict(),
        #         'optimizer_state_dict': optimizer.state_dict()},
        #         'Saved_models3/model5epoch{}.pth.tar'.format(epoch))
###########################################################
################### EVALUATION ############################
###########################################################
    if args.mode == 'eval':
        # overwrite the files used to log loss values and elapsed time
        fichierLoss = open("./losses/loss_eval/loss_times_eval{}.csv".format(args.saved_model), "w")
        fichierLoss.close()
        # dataset loading
        # first class of birds
        crow = []  # will hold the joined path and file names
        root_dir1 = './data/valset/crow/'
        ext = '.wav'
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 28 16:14:55 2019
@author: felix
"""
import os
import fnmatch
import librosa
from librosa import display
from scipy import signal
import scipy.io.wavfile  # needed for scipy.io.wavfile.write below
import numpy as np
import torch.nn.functional as F
import torch
from unet import UNet
from matplotlib import image as mpimg
import matplotlib
import soundfile as sf
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
def create_list(path, ext):
    list_names = []
    for root, dirnames, filenames in os.walk(path):
        for filename in fnmatch.filter(filenames, '*' + ext):
            list_names.append(os.path.join(root, filename))
    return list_names
def load_audio(audio):
    audio_raw, rate = librosa.load(audio, sr=44100, mono=True)
    print(audio_raw.shape, 'audio noise shape')
    print(rate, 'rate noise')
    audio_ = librosa.resample(audio_raw, rate, 22050)
    marc = 22050 * 2.5  # number of samples in 2.5 s at 22050 Hz
    audio1 = audio_[0:int(marc)]
    print(audio1.shape, 'len noise')
    return audio1
def load_file(file):
    audio_raw, rate = sf.read(file)
    print(audio_raw.shape, 'audio call shape')
    audio_ = librosa.resample(audio_raw, rate, 22050)
    marc = 22050 * 2.5
    audio = audio_[0:int(marc), 0]
    print(int(marc), 'int marc')
    print(audio.shape, 'call shape')
    return audio, 22050
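# Unlike load_audio, load_file assumes multi-channel input: sf.read returns a
# (samples, channels) array and `audio_[0:int(marc), 0]` keeps only the first
# channel, so a mono file would make that second index fail.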
def Visualizer(raw_input):
    my_cm = matplotlib.cm.get_cmap('Greys')
    raw_input = raw_input.detach().squeeze(0)
    normed_data = (np.asarray(raw_input) - torch.min(raw_input).item()) / (torch.max(raw_input).item() - torch.min(raw_input).item())
    # mapped_data = my_cm(normed_data)
    print(np.shape(normed_data), 'shape normed data')
    mpimg.imsave('Inputrain.png', np.flipud(my_cm(normed_data)))
def Visualizer2(list_files, model):
    # print(len(list_files), 'len list files')
    for ii in range(len(list_files)):
        mpimg.imsave('./simulation/resultmodel_{}_{}.png'.format(ii, model), np.asarray(list_files[ii]))
def np2rgb(tens_im):
    # get_cmap expects a colormap name or instance, so use the built-in 'Blues'
    my_cm1 = matplotlib.cm.get_cmap('Reds')
    my_cm2 = matplotlib.cm.get_cmap('Blues')
    my_cm3 = matplotlib.cm.get_cmap('Greens')
    colors = [my_cm1, my_cm2, my_cm3]
    N = np.shape(tens_im)[0]
    # print(N, 'N')
    tens_im = tens_im.detach()
    mapped_data = [None for n in range(N)]
    for n in range(N):
        # binarize each predicted mask at 0.5 before colour-mapping it
        thres_im = (tens_im[n] > 0.5).float()
        normed_data = (np.asarray(thres_im) - torch.min(thres_im).item()) / (torch.max(thres_im).item() - torch.min(thres_im).item())
        mapped_data[n] = torch.from_numpy(colors[n](normed_data))
    return mapped_data
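# After thresholding each mask is binary, so the min/max normalization above is
# a no-op unless a mask is constant (all zeros or all ones), in which case it
# divides by zero and yields NaNs or infs; callers should guard that edge case.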
def filt(audio_raw, rate):
    band = [800, 7000]   # desired pass band, Hz
    trans_width = 100    # width of transition from pass band to stop band, Hz
    numtaps = 250        # size of the FIR filter
    edges = [0, band[0] - trans_width,
             band[0], band[1],
             band[1] + trans_width, 0.5*rate]
    taps = signal.remez(numtaps, edges, [0, 1, 0], Hz=rate, type='bandpass')
    sig_filt = signal.lfilter(taps, 1, audio_raw)
    return sig_filt
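# Note: `Hz` is the legacy sample-rate argument of scipy.signal.remez (newer
# releases spell it `fs`), and the 800-7000 Hz pass band presumably brackets
# the bird calls while rejecting low-frequency rain noise.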
def _stft(audio):
    spec = librosa.stft(audio, n_fft=1022, hop_length=256)
    amp = np.abs(spec)
    phase = np.angle(spec)
    W = np.shape(amp)[0]
    H = np.shape(amp)[1]
    tch_mag = torch.empty(1, 1, W, H, dtype=torch.float)
    tch_mag[0, 0, :, :] = torch.from_numpy(amp)
    tch_phase = torch.empty(1, 1, W, H, dtype=torch.float)
    tch_phase[0, 0, :, :] = torch.from_numpy(phase)
    return tch_mag, tch_phase, amp
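# With n_fft=1022 the STFT has n_fft/2 + 1 = 512 frequency bins, and a 2.5 s
# clip at 22050 Hz (55125 samples, hop 256) spans about 216 frames, so tch_mag
# comes out with shape (1, 1, 512, 216).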
def threshold(mag):
    # keep the time-frequency bins whose magnitude exceeds the mean plus one
    # standard deviation of the spectrogram
    av = np.mean(mag[0, 0].numpy())
    vari = np.var(mag[0, 0].numpy())
    param = av + np.sqrt(vari)  # threshold
    gt_mask = (mag[0, 0] > param).float()
    return gt_mask
def warpgrid(HO, WO, warp=True):
    # build a (1, HO, WO, 2) sampling grid with coordinates in [-1, 1]
    x = np.linspace(-1, 1, WO)
    y = np.linspace(-1, 1, HO)
    xv, yv = np.meshgrid(x, y)
    grid = np.zeros((1, HO, WO, 2))
    grid[0, :, :, 0] = xv
    grid[0, :, :, 1] = yv
    grid = grid.astype(np.float32)
    return grid
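# F.grid_sample interprets grid coordinates in [-1, 1] as spanning the input,
# so sampling this evenly spaced grid simply resizes the spectrogram to
# HO x WO by bilinear interpolation; the `warp` flag is unused here and seems
# to be kept for compatibility with the code this was adapted from.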
def create_im(mag):
    T = mag.shape[3]
    # warp the magnitude spectrogram to a fixed height of 256 bins, then
    # log-compress it to form the network input
    grid_warp = torch.from_numpy(warpgrid(256, T, warp=True))
    magim = torch.log(F.grid_sample(mag, grid_warp))
    return magim
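# Caveat: torch.log sends zero-magnitude bins to -inf; adding a small epsilon
# inside the log (e.g. `+ 1e-8`, an addition not present in the original code)
# would keep the input finite.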
model = []
root_dir2 = './Saved_models2/'
ext = '.tar'
for root, dirnames, filenames in os.walk(root_dir2):
    for filename in fnmatch.filter(filenames, '*' + ext):
        model.append(os.path.join(root, filename))
crow = './test/crow/43215301.ogg'
pewee = './test/pewee/pewee2.ogg'
noise = 'srain.wav'
file1, sr = load_file(crow)
file2, _ = load_file(pewee)
file3 = load_audio(noise)
#file3 = file3*5
mix = file1 + file2 + file3
filtmix = filt(mix, sr)
#magmix, _, amp = _stft(filtmix)
#SNR = 10*np.log10((sum(filtmix**2))/sum(file3**2))
#print(SNR, 'SNR')
scipy.io.wavfile.write('inputaudio.wav', 22050, filtmix)
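# scipy.io.wavfile.write stores a float array as a float WAV, which players
# expect to lie in [-1, 1]; summing three clips can exceed that range, so
# peak-normalizing `filtmix` before writing may be prudent.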
#mpl_fig = plt.figure()
#ax = mpl_fig.add_subplot(111)
#display.specshow(np.asarray(magmix[0].squeeze(0)), hop_length = 256, x_axis='time', y_axis='linear', sr = 22050)
#ax.set_xlabel("Time", fontsize=25, weight='bold')
#ax.set_ylabel("Hz", fontsize=25, weight='bold')
#font = {'family' : 'normal',
# 'weight' : 'bold',
# 'size' : 25}
#
#matplotlib.rc('font', **font)
#im = create_im(magmix)
###
##Visualizer(im.squeeze(0))
###
#net = UNet(n_channels=1, n_classes=2)
#for mod in model:
# real_mod = mod.rsplit('/', 1)[1].split('.', 1)[0]
# checkpoint = torch.load('Saved_models2/{}.pth.tar'.format(real_mod))
# net.load_state_dict(checkpoint['model_state_dict'])
# masks_pred = net(im)
# mapped_data = np2rgb(masks_pred.squeeze(0))
# Visualizer2(mapped_data, real_mod)
##
#Visualizer(im.squeeze(0))
##
#net = UNet(n_channels=1, n_classes=2)
#for mod in model:
# real_mod = mod.rsplit('/', 1)[1].split('.', 1)[0]
# checkpoint = torch.load('Saved_models/{}.pth.tar'.format(real_mod))
# net.load_state_dict(checkpoint['model_state_dict'])
# masks_pred = net(im)
# mapped_data = np2rgb(masks_pred.squeeze(0))
# Visualizer2(mapped_data, real_mod)