# Commit 20d10c55 authored by Félix Michaud — "before modifs" (parent af00e477)
import torch
from torch.utils import data
import os
import fnmatch
import librosa
from scipy import signal
import numpy as np
import torch.nn.functional as F
import random
import collections

# Load one audio file, resampled to 22050 Hz and mixed down to mono.
def load_file(file):
    audio_raw, rate = librosa.load(file, sr=22050, mono=True)
    return audio_raw, rate
# Band-pass filter the raw audio between 800 Hz and 7 kHz with an FIR filter.
def filt(audio_raw, rate):
    band = [800, 7000]   # desired pass band, Hz
    trans_width = 100    # width of transition from pass band to stop band, Hz
    numtaps = 250        # size of the FIR filter
    edges = [0, band[0] - trans_width,
             band[0], band[1],
             band[1] + trans_width, 0.5 * rate]
    # fs= replaces the deprecated Hz= keyword of scipy.signal.remez
    taps = signal.remez(numtaps, edges, [0, 1, 0], fs=rate, type='bandpass')
    sig_filt = signal.lfilter(taps, 1, audio_raw)
    return sig_filt
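
# A minimal sketch of the load -> filter front end. The _demo_* name and the
# synthesized tones are illustrative, not part of the project; any mono float
# signal at 22050 Hz goes through the same path.
def _demo_filt():
    rate = 22050
    t = np.linspace(0, 1, rate, endpoint=False)
    tone = np.sin(2 * np.pi * 3000 * t)   # 3 kHz tone, inside the pass band
    hum = np.sin(2 * np.pi * 100 * t)     # 100 Hz hum, outside the pass band
    return filt(tone + hum, rate)         # the hum is strongly attenuated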
# Return the magnitude and phase of one STFT as (1, 1, W, H) float tensors.
def _stft(audio):
    spec = librosa.stft(audio, n_fft=1022, hop_length=256)
    amp = np.abs(spec)
    phase = np.angle(spec)
    tch_mag = torch.from_numpy(amp).float().unsqueeze(0).unsqueeze(0)
    tch_phase = torch.from_numpy(phase).float().unsqueeze(0).unsqueeze(0)
    return tch_mag, tch_phase
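
# Shape sketch (illustrative helper): n_fft=1022 gives 512 frequency rows, and
# a 2 s clip at 22050 Hz with hop 256 gives 173 frames under librosa's default
# center=True padding.
def _demo_stft_shapes():
    mag, phase = _stft(np.zeros(2 * 22050, dtype=np.float32))
    assert mag.shape == (1, 1, 512, 173) and phase.shape == mag.shape
    return mag.shape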
# Return one torch matrix with the dimensions of the STFT: a binary mask that
# keeps the magnitude bins above mean + 2 * std.
def threshold(mag):
    av = np.mean(mag[0, 0].numpy())
    vari = np.var(mag[0, 0].numpy())
    param = av + np.sqrt(vari) * 2  # threshold
    gt_mask = (mag[0, 0] > param).float()
    return gt_mask
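
# Sketch (illustrative helper): for random noise only the rare bins more than
# two standard deviations above the mean survive the threshold.
def _demo_threshold():
    mag, _ = _stft(np.random.randn(22050).astype(np.float32))
    gt = threshold(mag)   # float tensor of 0s and 1s, shape (512, frames)
    return gt.mean()      # fraction of bins kept, well below 0.5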
# Create the sampling grid for F.grid_sample: warp a spectrogram onto a
# log-frequency image (warp=True) or back (warp=False).
def warpgrid_log(HO, WO, warp=True):
    # meshgrid over the output image, in grid_sample's [-1, 1] coordinates
    x = np.linspace(-1, 1, WO)
    y = np.linspace(-1, 1, HO)
    xv, yv = np.meshgrid(x, y)
    grid = np.zeros((1, HO, WO, 2))
    grid_x = xv
    if warp:
        grid_y = (np.power(21, (yv + 1) / 2) - 11) / 10
    else:
        grid_y = np.log(yv * 10 + 11) / np.log(21) * 2 - 1
    grid[:, :, :, 0] = grid_x
    grid[:, :, :, 1] = grid_y
    grid = grid.astype(np.float32)
    return grid
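
# Sketch of the forward warp that create_im / create_mask perform below
# (illustrative helper): map a (1, 1, 512, frames) spectrogram onto a
# 256-row log-frequency image.
def _demo_warp():
    spec = torch.randn(1, 1, 512, 44)
    grid = torch.from_numpy(warpgrid_log(256, spec.shape[3], warp=True))
    return F.grid_sample(spec, grid)   # shape (1, 1, 256, 44)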
# Create the log-frequency image fed to the network from the warp grid,
# zero-centering the magnitudes first.
def create_im(mag):
    magim = mag.unsqueeze(0).unsqueeze(0)
    # zero-center the data
    m = torch.mean(magim)
    magim = magim - m
    grid_warp = torch.from_numpy(warpgrid_log(256, magim.shape[3], warp=True))
    # grid_warp = torch.from_numpy(warpgrid_log(384, 192, warp=True))
    magim = F.grid_sample(magim, grid_warp)
    # np.flipud needs a NumPy array; flipping a torch tensor directly fails
    return torch.from_numpy(np.flipud(magim.numpy()).copy())

# Same warp as create_im, but without zero-centering, used for the masks.
def create_mask(mag):
    magim = mag.unsqueeze(0).unsqueeze(0)
    grid_warp = torch.from_numpy(warpgrid_log(256, magim.shape[3], warp=True))
    # grid_warp = torch.from_numpy(warpgrid_log(264, 52, warp=True))
    magim = F.grid_sample(magim, grid_warp)
    return torch.from_numpy(np.flipud(magim.numpy()).copy())
# kernel size 5, padding 3 -> image size [256, 44]
# kernel size 3, padding 1 -> image size [256, 44]
# (depends on the overlap of the STFT)

# SpecAugment-style masking: zero out one horizontal band of frequency bins.
def freq_mask(spec):
    rows, columns = np.shape(spec)
    # width of the band; randint needs low <= high, so rows/80 comes first
    # (the original arguments were reversed, which raises ValueError)
    fact1 = random.randint(int(rows / 80), int(rows / 60))
    frame = np.zeros([fact1, columns])
    # position of the band on the y axis
    pos = random.randint(10, rows - fact1 - 1)
    up = np.ones([pos - 1, columns])
    down = np.ones([rows - (pos + fact1) + 1, columns])
    mask = torch.from_numpy(np.concatenate((up, frame, down), axis=0)).float()
    masked = spec * mask
    return masked
# SpecAugment-style masking: zero out one vertical band of time frames.
def time_mask(spec):
    rows, columns = np.shape(spec)
    # width of the band; randint needs low <= high, so columns/80 comes first
    fact1 = random.randint(int(columns / 80), int(columns / 60))
    frame = np.zeros([rows, fact1])
    # position of the band on the x axis
    pos = random.randint(10, columns - fact1 - 1)
    left = np.ones([rows, pos - 1])
    right = np.ones([rows, columns - (pos + fact1) + 1])
    mask = torch.from_numpy(np.concatenate((left, frame, right), axis=1)).float()
    masked = spec * mask
    return masked
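
# Sketch (illustrative helper): chain the two masks on a 2-D magnitude tensor;
# both dimensions must be large enough (here 256 x 173) for the random band
# widths and positions to be valid.
def _demo_spec_masks():
    spec = torch.rand(256, 173)
    return time_mask(freq_mask(spec))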
# Time-shift augmentation: roll the signal by up to shift_max seconds and
# silence the samples that wrapped around.
def manipulate(data, sampling_rate, shift_max, shift_direction):
    shift = np.random.randint(sampling_rate * shift_max)
    if shift_direction == 'right':
        shift = -shift
    elif shift_direction == 'both':
        direction = np.random.randint(0, 2)
        if direction == 1:
            shift = -shift
    augmented_data = np.roll(data, shift)
    # set the heading / trailing samples to silence
    if shift > 0:
        augmented_data[:shift] = 0
    else:
        augmented_data[shift:] = 0
    return augmented_data
def _rms_energy(x):
    # RMS energy in dB; the SNR arithmetic in _add_noise below works in dB
    # (the original returned the linear RMS, missing the log conversion)
    return 20 * np.log10(np.sqrt(np.mean(x ** 2)))

# Add noise to a signal at a given SNR; the noise comes from a file or from a
# same-size vector.
def _add_noise(signal, noise_file_name, snr, sample_rate):
    """
    :param signal: clean signal, 1-D numpy array
    :param noise_file_name: path to a noise file, or a numpy array of noise
    :param snr: desired signal-to-noise ratio, dB
    :param sample_rate: sample rate used to load the noise file
    :return: noisy signal, normalized to zero mean and unit variance
    """
    # open the noise file
    if isinstance(noise_file_name, np.ndarray):
        noise = noise_file_name
    else:
        noise, fs_noise = librosa.load(noise_file_name, sr=sample_rate)
    # generate a random section of the masker, tiling it if it is too short
    if len(noise) < len(signal):
        dup_factor = len(signal) // len(noise) + 1
        noise = np.tile(noise, dup_factor)
    if len(noise) != len(signal):
        idx = np.random.randint(0, len(noise) - len(signal))
        noise = noise[idx:idx + len(signal)]
    # compute the energy of both signals, in dB
    N_dB = _rms_energy(noise)
    S_dB = _rms_energy(signal)
    # rescale the noise to reach the requested SNR
    N_new = S_dB - snr
    noise_scaled = 10 ** (N_new / 20) * noise / 10 ** (N_dB / 20)
    noisy = signal + noise_scaled
    return (noisy - noisy.mean()) / noisy.std()
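
# Sketch (illustrative helper): mix a clip with noise at 10 dB SNR. Passing an
# array instead of a path skips the file load, so no noise file is needed.
def _demo_add_noise():
    sig = np.random.randn(22050)
    noise = np.random.randn(22050)
    return _add_noise(sig, noise, snr=10, sample_rate=22050)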
# Create a new signal of length max_time seconds: each clip is Hann-windowed
# and summed in at a random offset.
def time_elong(sr, list_audio, max_time=2):
    final_audio = np.zeros((1, sr * max_time))
    for f in list_audio:
        if len(f) > sr * max_time:
            print('the new audio file has to be longer than the original')
        else:
            dim = len(f)
            f = f * np.hanning(dim)
            blockl = np.random.randint(0, sr * max_time - dim - 1)
            blockr = blockl + dim
            left = np.zeros(blockl)
            right = np.zeros(sr * max_time - blockr)
            new = np.concatenate((left, f, right), axis=0)
            final_audio += new
    return final_audio
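
# Sketch (illustrative helper): scatter two short clips over a 2 s canvas.
def _demo_time_elong():
    clips = [np.random.randn(5000), np.random.randn(8000)]
    return time_elong(22050, clips, max_time=2)   # shape (1, 44100)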
class Dataset(data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, path, nb_classes=2, augmentation=True, path_background="./noises"):
        self.dict_classes = self.load_data(path)
        self.nb_classes = nb_classes
        self.augmentation = augmentation
        self.path_background = path_background
        if self.augmentation:
            self.dict_noises = self.load_data(path_background)

    def load_data(self, path, ext='wav'):
        # map each class (sub-folder name) to the list of its audio files
        dict_classes = collections.OrderedDict()
        for root, dirnames, filenames in os.walk(path):
            for filename in fnmatch.filter(filenames, '*' + ext):
                classe = root.split("/")[-1]
                if classe in dict_classes.keys():
                    dict_classes[classe].append(os.path.join(root, filename))
                else:
                    dict_classes[classe] = [os.path.join(root, filename)]
        if len(list(dict_classes.keys())) == 0:
            print("** WARNING ** No data loaded from " + path)
        return dict_classes

    def mixed_audio_augmentation(self, audio, sampling_rate):
        # random natural-noise augmentation: a random file from a random
        # noise class, added with a random gain
        classe_noise = random.randint(0, len(list(self.dict_noises.keys())) - 1)
        classe_noise = list(self.dict_noises.keys())[classe_noise]
        filename_noise = self.dict_noises[classe_noise][random.randint(0, len(self.dict_noises[classe_noise]) - 1)]
        audio_noise, sr = load_file(filename_noise)
        coeff = int(np.ceil(np.max(audio) * random.choice([1, 2, 3, 4, 5, 6, 7])))
        noisy_audio = audio + audio_noise * coeff
        # random pitch shifting
        # step_pitch = random.uniform(-0.001, 0.001)
        # mod_audio = librosa.effects.pitch_shift(noisy_audio, sampling_rate, n_steps=step_pitch)
        # random time shifting
        # final_audio = manipulate(noisy_audio, sampling_rate, 0.1, 'both')
        return noisy_audio

    # randomly apply at least one masking band to the spectrogram
    def spec_augmentation(self, spec):
        n = random.randint(0, 2)
        if n == 0:
            t = random.randint(0, 1)
            if t == 1:
                spec = time_mask(spec)
            if t == 0:
                spec = freq_mask(spec)
        else:
            for ii in range(n):
                spec = time_mask(spec)
                spec = freq_mask(spec)
        return spec

    def __len__(self):
        'Denotes the total number of samples'
        # fixed epoch length, not tied to the number of files
        # return len(self.dict_classes)
        nb_samples = 400000
        return nb_samples

    def load_files(self, nb_classes):
        'Pick one random audio file per class'
        files = []
        for cl in range(nb_classes):
            classe_name = list(self.dict_classes.keys())[cl]
            idx = int(random.random() * len(self.dict_classes[classe_name]))
            filename = self.dict_classes[classe_name][idx]
            files.append([classe_name, filename])
        return files

    def __getitem__(self, index):
        'Load one audio file per class; each files entry becomes [class_name, filename, mask, magnitude, phase]'
        files = self.load_files(self.nb_classes)
        audio_mix = None
        for f in files:
            audio_raw, rate = load_file(f[1])
            audio = filt(audio_raw, rate)
            mag, phase = _stft(audio)
            mag = create_mask(mag.squeeze(0).squeeze(0))
            mask = threshold(mag)
            f.append(mask)
            f.append(mag)
            f.append(phase)
            if audio_mix is None:
                audio_mix = audio
            else:
                audio_mix += audio
        # build the mixed input
        if self.augmentation:
            audio_mix = self.mixed_audio_augmentation(audio_mix, rate)
        mag_mix, phase_mix = _stft(audio_mix)
        mag_mix = mag_mix.squeeze(0).squeeze(0)
        mag_mix = self.spec_augmentation(mag_mix)
        mags_mix = create_im(mag_mix)
        mags_mix = mags_mix.squeeze(0)
        return [mags_mix, phase_mix, files]
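
# Usage sketch: needs the class-per-folder layout on disk; the paths are the
# ones used by the training script below.
#   ds = Dataset('./data_sound/trainset/', nb_classes=3,
#                path_background='./data_sound/noises/')
#   mags_mix, phase_mix, files = ds[0]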
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 26 22:14:38 2019

@author: felix
"""
import os
import random
import time
import fnmatch
import csv
import torch
from arguments import ArgParser
from unet import UNet5
import torch.nn as nn
#from tensorboardX import SummaryWriter
from matplotlib import image as mpimg
from Dataloader import Dataset
import matplotlib
import numpy as np
import collections
import librosa            # used by build_audio (istft)
import scipy.io.wavfile   # used by build_audio (wav export)
# collect the files with the given extension under path
def create_list(path, ext):
    list_names = []
    for root, dirnames, filenames in os.walk(path):
        for filename in fnmatch.filter(filenames, '*' + ext):
            list_names.append(os.path.join(root, filename))
    return list_names

def create_optimizer(nets, args):
    net_sound = nets
    param_groups = [{'params': net_sound.parameters(), 'lr': args.lr_sound}]
    return torch.optim.Adam(param_groups)
def init_weights(m):
    # Xavier initialization for the conv layers, small positive bias
    if type(m) == nn.Conv2d:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

# stack the per-class ground-truth masks into one (batch, classes, 256, 44)
# tensor (44 columns corresponds to a kernel of 5)
def unwrap_mask(infos):
    gt_masks = torch.empty(args.batch_size, args.nb_classes, 256, 44, dtype=torch.float)
    for ii in range(args.batch_size):
        for jj in range(args.nb_classes):
            gt_masks[ii, jj] = infos[jj][2][ii]
    return gt_masks
# reconstruct and save one wav per class: mask the mixture magnitude, then
# invert the STFT using the mixture phase
def build_audio(audio_names, pred_masks, magmix, phasemix):
    for ii in range(args.batch_size):
        # use local names: the original reassigned the batch tensors inside
        # the loop, breaking every iteration after the first
        masks = pred_masks[ii].numpy()
        mag = magmix[ii].squeeze(0).detach().numpy()
        phase = phasemix[ii].squeeze(0).detach().numpy()
        for n in range(args.nb_classes):
            name = audio_names[n][1][ii]
            magnew = masks[n] * mag
            spec = magnew.astype(complex) * np.exp(1j * phase)
            audio = librosa.istft(spec, hop_length=256)
            scipy.io.wavfile.write('restored_audio/restored_{}.wav'.format(name), 22050, audio)
def save_arguments(args, path):
    # dump all run arguments to infos.txt and echo them to stdout
    file1 = open(path + "/infos.txt", "w")
    print("Input arguments:")
    for key, val in vars(args).items():
        file1.write("{} {}\n".format(key, val))
        print("{:16} {}".format(key, val))
    file1.close()
def train(net, loader_train, optimizer, path, args):
    torch.set_grad_enabled(True)
    num_batch = 0
    criterion = nn.BCELoss()
    for ii, batch_data in enumerate(loader_train):
        # add the species names to infos.txt on the first batch
        if ii == 0:
            args.species = []
            for n in range(args.nb_classes):
                args.species.append(batch_data[2][n][0][0])
            save_arguments(args, args.path)
        num_batch += 1
        magmix = batch_data[0]
        magmix = magmix.to(args.device)
        masks_pred = net(magmix, dropout=True)
        masks = unwrap_mask(batch_data[2])
        masks = masks.to(args.device)
        loss = criterion(masks_pred, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # write the loss value and elapsed time for every batch
        batchtime = (time.time() - args.starting_training_time) / 60  # minutes
        with open(path + "/loss_times.csv", "a") as f:
            writer = csv.writer(f)
            writer.writerow([str(loss.cpu().detach().numpy()), batchtime, num_batch])
        if ii % args.save_per_batchs == 0:
            torch.save({'model_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict()},
                       path + '/Saved_models/model_batchs{}.pth.tar'.format(num_batch))
def evaluation(net, loader, args):
    # no update of the gradients
    torch.set_grad_enabled(False)
    num_batch = 0
    criterion = nn.BCELoss()
    args.out_threshold = 0.4
    for ii, batch_data in enumerate(loader):
        # forward pass
        magmix = batch_data[0]
        magmix = magmix.to(args.device)
        masks = unwrap_mask(batch_data[2])
        masks = masks.to(args.device)
        num_batch += 1
        masks_pred = net(magmix)
        # loss
        loss = criterion(masks_pred, masks)
        # write the loss for every batch; batchtime was undefined in the
        # original, computed here the same way as in train
        batchtime = (time.time() - args.starting_training_time) / 60  # minutes
        with open("./losses/loss_eval/loss_times_eval{}.csv".format(args.saved_model), "a") as f:
            writer = csv.writer(f)
            writer.writerow([str(loss.cpu().detach().numpy()), batchtime, num_batch])
#***************************************************
#****************** MAIN ***************************
#***************************************************
if __name__ == '__main__':
    # arguments
    parser = ArgParser()
    args = parser.parse_train_arguments()
    args.batch_size = 16
    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.starting_training_time = time.time()
    args.save_per_batchs = 500
    args.nb_classes = 3
    args.mode = 'train'
    args.lr_sound = 1e-5  # read by create_optimizer (was misspelled lr_sounds)
    args.saved_model = '5_5000'
    # model definition
    net = UNet5(n_channels=1, n_classes=args.nb_classes)
    net.apply(init_weights)
    net = net.to(args.device)
    # set up the optimizer
    optimizer = create_optimizer(net, args)
    args.path = "./Unet5/nm"
    args._augment = 'nm'

    ###########################################################
    ################### TRAINING ##############################
    ###########################################################
    if args.mode == 'train':
        # overwrite the files used for loss and time logging
        fichierLoss = open(args.path + "/loss_times.csv", "w")
        fichierLoss.close()
        # dataset loading
        root = './data_sound/trainset/'
        ext = '.wav'
        train_classes = Dataset(root, nb_classes=args.nb_classes,
                                path_background="./data_sound/noises/")
        loader_train = torch.utils.data.DataLoader(
            train_classes,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=20)
        for epoch in range(0, 1):
            train(net, loader_train, optimizer, args.path, args)
            torch.save({
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
                args.path + '/Saved_models/model_epoch{}.pth.tar'.format(epoch))

    ###########################################################
    ################### EVALUATION ############################
    ###########################################################
    if args.mode == 'eval':
        # overwrite the files used for loss and time logging
        fichierLoss = open("./losses/loss_eval/loss_times_eval{}.csv".format(args.saved_model), "w")
        fichierLoss.close()
        # dataset loading
        root = './data_sound/valset/'
        ext = '.wav'
        val_classes = Dataset(root, nb_classes=args.nb_classes,
                              path_background="./data_sound/noises/")
        # initialization of the model from the saved checkpoint
        checkpoint = torch.load('Saved_models2/model{}.pth.tar'.format(args.saved_model))
        net.load_state_dict(checkpoint['model_state_dict'])
        loader_eval = torch.utils.data.DataLoader(
            val_classes,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=20)
        for epoch in range(0, 1):
            evaluation(net, loader_eval, args)