Commit 415e8076 authored by Félix Michaud's avatar Félix Michaud
Browse files

image 5 secs

parent 0cb369ee
......@@ -142,9 +142,9 @@ class Dataset(data.Dataset):
mag_mix, phase_max = _stft(audio_mix)
mag_mix = mag_mix.squeeze(0).squeeze(0)
mask_mix = create_im(mag_mix)
mags_mix = create_im(mag_mix)
return [mask_mix, files]
return [mags_mix, files]
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 27 01:04:00 2019
@author: felix
"""
import os
import fnmatch
import librosa
from scipy import signal
import numpy as np
import torch.nn.functional as F
import torch
import soundfile as sf
from matplotlib import image as mpimg
from unet import UNet
import matplotlib
def create_list(path, ext):
    """Recursively collect every file under *path* whose name ends with *ext*.

    Parameters:
        path: root directory to walk.
        ext:  file extension to match, e.g. '.ogg' (matched as '*<ext>').

    Returns:
        List of full file paths, in os.walk order.
    """
    pattern = '*' + ext
    matches = []
    for dirpath, _subdirs, names in os.walk(path):
        matches.extend(os.path.join(dirpath, name)
                       for name in fnmatch.filter(names, pattern))
    return matches
def load_file(file):
    """Load an audio file with soundfile.

    Returns:
        (samples, sample_rate) exactly as produced by sf.read.
    """
    # sf.read already returns the (data, samplerate) pair we want
    return sf.read(file)
def filt(audio_raw, rate):
    """Band-pass filter a signal between 800 and 7000 Hz with a 250-tap FIR.

    Parameters:
        audio_raw: 1-D array of audio samples.
        rate: sampling rate in Hz (must exceed 2 * 7100 so the upper
              stop-band edge stays below Nyquist).

    Returns:
        Filtered signal, same length as the input.
    """
    band = [800, 7000]   # desired pass band, Hz
    trans_width = 100    # width of transition from pass band to stop band, Hz
    numtaps = 250        # size of the FIR filter
    edges = [0, band[0] - trans_width,
             band[0], band[1],
             band[1] + trans_width, 0.5 * rate]
    # FIX: use `fs=` — the `Hz=` keyword was deprecated and later removed
    # from scipy.signal.remez; 'bandpass' is already the default type.
    taps = signal.remez(numtaps, edges, [0, 1, 0], fs=rate)
    return signal.lfilter(taps, 1, audio_raw)
def _stft(audio):
    """Short-time Fourier transform split into magnitude and phase.

    Parameters:
        audio: 1-D audio signal.

    Returns:
        (magnitude, phase) as two (1, 1, F, T) float32 torch tensors,
        where F = n_fft // 2 + 1 = 512 and T depends on the signal length.
    """
    spec = librosa.stft(audio, n_fft=1022, hop_length=256)
    # add batch and channel dims so the tensors feed straight into a CNN
    tch_mag = torch.from_numpy(np.abs(spec)).float().unsqueeze(0).unsqueeze(0)
    tch_phase = torch.from_numpy(np.angle(spec)).float().unsqueeze(0).unsqueeze(0)
    return tch_mag, tch_phase
def warpgrid(HO, WO, warp=True):
    """Build a (1, HO, WO, 2) float32 identity sampling grid for F.grid_sample.

    Channel 0 carries x coordinates and channel 1 carries y coordinates,
    both spanning [-1, 1] linearly.

    NOTE(review): the `warp` flag is accepted but ignored — every call yields
    the same identity grid; confirm whether a log-frequency warp was intended.
    """
    xv, yv = np.meshgrid(np.linspace(-1, 1, WO), np.linspace(-1, 1, HO))
    grid = np.stack((xv, yv), axis=-1)[np.newaxis, ...]
    return grid.astype(np.float32)
def create_im(mag):
    """Resample a (1, 1, F, T) magnitude spectrogram to 256 frequency rows.

    The time dimension T is preserved; frequency is interpolated to 256 bins
    by grid-sampling over the grid produced by warpgrid.
    """
    n_frames = mag.shape[3]
    sampling_grid = torch.from_numpy(warpgrid(256, n_frames, warp=True))
    # NOTE(review): align_corners is left at the framework default, exactly
    # as the original call did — behavior depends on the torch version.
    return F.grid_sample(mag, sampling_grid)
def np2rgb(tens_im):
    """Colormap each 2-D slice of a tensor into an RGBA image tensor.

    Every slice tens_im[n] is min-max normalised to [0, 1] and mapped through
    the 'BuGn' colormap, producing an (H, W, 4) float tensor per slice.

    Parameters:
        tens_im: tensor whose first dimension indexes the slices.

    Returns:
        List of RGBA torch tensors, one per slice.
    """
    my_cm = matplotlib.cm.get_cmap('BuGn')
    # FIX: detach so output coming straight from a network forward pass can
    # be converted to numpy (mirrors the detach() in the newer variant).
    tens_im = tens_im.detach()
    mapped_data = []
    for sl in tens_im:
        lo = torch.min(sl).item()
        hi = torch.max(sl).item()
        # NOTE(review): a constant slice (hi == lo) still divides by zero,
        # exactly as the original did — left unchanged to preserve behavior.
        normed_data = (np.asarray(sl) - lo) / (hi - lo)
        mapped_data.append(torch.from_numpy(my_cm(normed_data)))
    return mapped_data
def Visualizer(image):
    """Save the 2-D content of a (1, 1, H, W) tensor/array to big_mi.png."""
    plane = np.asarray(image.squeeze(0).squeeze(0))
    # np.flip with no axis argument reverses both axes
    mpimg.imsave('big_mi.png', np.flip(plane))
def VisualizerR(image):
    """Save the 2-D content of a (1, 1, H, W) tensor/array to big_miResult1.png."""
    plane = np.asarray(image.squeeze(0).squeeze(0))
    # np.flip with no axis argument reverses both axes
    mpimg.imsave('big_miResult1.png', np.flip(plane))
def VisualizerR2(image):
    """Save the 2-D content of a (1, 1, H, W) tensor/array to big_miResult2.png."""
    plane = np.asarray(image.squeeze(0).squeeze(0))
    # np.flip with no axis argument reverses both axes
    mpimg.imsave('big_miResult2.png', np.flip(plane))
# --- Inference script: mix two recordings, run the U-Net separator, save masks ---
data1, rate = load_file('./test/wood/43209071.ogg')
data2, _ = load_file('./test/crow/43215281.ogg')
# keep the left channel only and trim both signals to the same length
data1 = data1[:, 0]
limit = data1.shape[0]
data2 = data2[0:limit, 0]
# mix the two sources and band-pass filter the mixture
audio_mix = filt(data1 + data2, rate)
magmix, _ = _stft(audio_mix)
print(magmix.shape)
im_mig = create_im(magmix)
print(np.asarray(im_mig).shape)
Visualizer(im_mig)
# load the trained separation network and predict the two source masks
net = UNet(n_channels=1, n_classes=2)
checkpoint = torch.load('Saved_models/model5_5000.pth.tar')
net.load_state_dict(checkpoint['model_state_dict'])
# inference only — no_grad avoids tracking gradients (and lets np2rgb
# convert the output to numpy)
with torch.no_grad():
    mask_pred = net(im_mig)
print(mask_pred.size())
# FIX: squeeze is not in-place; the original discarded the result, so np2rgb
# iterated the batch dimension (size 1) and mapped[1] raised IndexError.
mask_pred = mask_pred.squeeze(0)
mapped = np2rgb(mask_pred)
VisualizerR(mapped[0])
VisualizerR2(mapped[1])
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 27 01:04:00 2019
Created on Fri Jun 28 16:14:55 2019
@author: felix
"""
......@@ -13,10 +13,10 @@ from scipy import signal
import numpy as np
import torch.nn.functional as F
import torch
import soundfile as sf
from matplotlib import image as mpimg
from unet import UNet
from matplotlib import image as mpimg
import matplotlib
import soundfile as sf
def create_list(path, ext):
list_names = []
......@@ -25,9 +25,38 @@ def create_list(path, ext):
list_names.append(os.path.join(root, filename))
return list_names
def load_file(file):
    """Load the first 5 seconds of an audio file's left channel.

    Parameters:
        file: path to an audio file readable by soundfile (must be at least
              5 seconds long and have at least 2 channels — TODO confirm).

    Returns:
        (audio, rate): 1-D array of the first `5 * rate` samples of
        channel 0, and the sample rate.
    """
    audio_raw, rate = sf.read(file)
    # FIX: the diff left a stale `return audio_raw, rate` before this crop,
    # which made the 5-second trim (the point of this commit) unreachable.
    marc = rate * 5
    audio = audio_raw[0:int(marc), 0]
    return audio, rate
def Visualizer(list_files):
    """Save each image in *list_files* to Input<i>.png, flipped on both axes.

    Parameters:
        list_files: sequence of 2-D image arrays/tensors convertible to numpy.
    """
    # enumerate instead of range(len(...)) — same order, same filenames
    for idx, img in enumerate(list_files):
        mpimg.imsave('Input{}.png'.format(idx), np.flip(np.asarray(img)))
def Visualizer2(list_files):
    """Save each image in *list_files* to result<i>.png, flipped on both axes.

    Parameters:
        list_files: sequence of 2-D image arrays/tensors convertible to numpy.
    """
    # enumerate instead of range(len(...)) — same order, same filenames
    for idx, img in enumerate(list_files):
        mpimg.imsave('result{}.png'.format(idx), np.flip(np.asarray(img)))
def np2rgb(tens_im):
    """Colormap each 2-D slice of a tensor into an RGBA image tensor.

    Every slice tens_im[n] is min-max normalised to [0, 1] and mapped through
    the 'Reds' colormap, producing an (H, W, 4) float tensor per slice.

    Returns:
        List of RGBA torch tensors, one per slice.
    """
    cmap = matplotlib.cm.get_cmap('Reds')
    # detach so tensors straight out of a forward pass convert to numpy
    tens_im = tens_im.detach()
    mapped = []
    for sl in tens_im:
        lo = torch.min(sl).item()
        hi = torch.max(sl).item()
        mapped.append(torch.from_numpy(cmap((np.asarray(sl) - lo) / (hi - lo))))
    return mapped
def filt(audio_raw, rate):
......@@ -56,6 +85,15 @@ def _stft(audio):
return tch_mag, tch_phase
def threshold(mag):
    """Binarise a (1, 1, F, T) magnitude spectrogram into a ground-truth mask.

    A bin is 1.0 when its magnitude exceeds mean + 2 * std of all bins in
    mag[0, 0], else 0.0.

    Parameters:
        mag: (1, 1, F, T) torch tensor (only mag[0, 0] is used).

    Returns:
        (F, T) float torch tensor mask.
    """
    # FIX: the original allocated torch.zeros(...) and immediately
    # overwrote it — dead code removed.
    vals = mag[0, 0].numpy()
    # mean + 2 standard deviations; np.std(x) == sqrt(np.var(x))
    cut = np.mean(vals) + 2.0 * np.std(vals)
    return (mag[0, 0] > cut).float()
def warpgrid(HO, WO, warp=True):
# meshgrid
x = np.linspace(-1, 1, WO)
......@@ -74,59 +112,31 @@ def create_im(mag):
# NOTE(review): this span is the interior of create_im(mag) — the `def` line
# sits inside the surrounding diff-hunk header.
T = mag.shape[3]
# 0.0 warp the spectrogram
# Build a sampling grid with 256 frequency rows; time dimension T is kept.
grid_warp = torch.from_numpy(warpgrid(256, T, warp=True))
# NOTE(review): the next two assignments are a leftover removed/added diff
# pair — the second (log-compressed) result immediately overwrites the first.
magim = F.grid_sample(mag, grid_warp)
magim = torch.log(F.grid_sample(mag, grid_warp))
return magim
def np2rgb(tens_im):
    """Colormap each 2-D slice of a tensor into an RGBA image tensor.

    Every slice tens_im[n] is min-max normalised to [0, 1] and mapped through
    the 'BuGn' colormap, producing an (H, W, 4) float tensor per slice.

    Returns:
        List of RGBA torch tensors, one per slice.
    """
    cmap = matplotlib.cm.get_cmap('BuGn')
    out = []
    for sl in tens_im:
        lo = torch.min(sl).item()
        hi = torch.max(sl).item()
        out.append(torch.from_numpy(cmap((np.asarray(sl) - lo) / (hi - lo))))
    return out
def Visualizer(image):
    """Save the 2-D content of a (1, 1, H, W) tensor/array to big_mi.png."""
    plane = np.asarray(image.squeeze(0).squeeze(0))
    # np.flip with no axis argument reverses both axes
    mpimg.imsave('big_mi.png', np.flip(plane))
def VisualizerR(image):
    """Save the 2-D content of a (1, 1, H, W) tensor/array to big_miResult1.png."""
    plane = np.asarray(image.squeeze(0).squeeze(0))
    # np.flip with no axis argument reverses both axes
    mpimg.imsave('big_miResult1.png', np.flip(plane))
def VisualizerR2(image):
    """Save the 2-D content of a (1, 1, H, W) tensor/array to big_miResult2.png."""
    plane = np.asarray(image.squeeze(0).squeeze(0))
    # np.flip with no axis argument reverses both axes
    mpimg.imsave('big_miResult2.png', np.flip(plane))
# NOTE(review): this span interleaves the REMOVED (pre-commit) and ADDED
# (post-commit) versions of the inference script — a diff artifact. The first
# pipeline below is the old one; the second (crow/wood paths under ./data)
# is the new one.
data1, rate = load_file('./test/wood/43209071.ogg')
data2, _ = load_file('./test/crow/43215281.ogg')
# old version: left channel only, trimmed to equal length
data1 = data1[:, 0]
limit = data1.shape[0]
data2 = data2[0:limit, 0]
audio_mix = filt(data1 + data2, rate)
magmix, _ = _stft(audio_mix)
print(magmix.shape)
im_mig = create_im(magmix)
print(np.asarray(im_mig).shape)
Visualizer(im_mig)
# new version: load_file now returns 5-second mono crops, so no slicing here
crow = './data/test/crow/43215281.ogg'
wood = './data/test/wood/43209071.ogg'
file1, sr = load_file(crow)
file2, _ = load_file(wood)
mix = file1 + file2
filtmix = filt(mix, sr)
magmix, _ = _stft(filtmix)
im = create_im(magmix)
print(im.size(), 'im size')
Visualizer(im.squeeze(0))
#
net = UNet(n_channels=1, n_classes=2)
checkpoint = torch.load('Saved_models/model5_5000.pth.tar')
net.load_state_dict(checkpoint['model_state_dict'])
# old version of the prediction step follows
mask_pred = net(im_mig)
print(mask_pred.size())
# NOTE(review): squeeze is not in-place — this result is discarded, so
# np2rgb below iterates the batch dim (size 1) and mapped[1] would raise
# IndexError; the new version (masks_pred.squeeze(0)) does this correctly.
mask_pred.squeeze(0)
mapped = np2rgb(mask_pred)
VisualizerR(mapped[0])
VisualizerR2(mapped[1])
net.load_state_dict(checkpoint['model_state_dict'])
masks_pred = net(im)
print(masks_pred.size(), 'mask pred size')
mapped_data = np2rgb(masks_pred.squeeze(0))
Visualizer2(mapped_data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment