Commit bb430801 authored by Félix Michaud

before images

parents 7ad181f0 846b6c59
@@ -41,7 +41,7 @@ def Visualizer(r_input):
 def Visualizer2(list_files, model):
     # print(len(list_files), 'len list files')
     for ii in range(len(list_files)):
-        mpimg.imsave('./simulation/resultmodel_{}_{}.png'.format(ii, model), np.flipud(np.asarray(list_files[ii])))
+        mpimg.imsave('./simul_npt/resultmodel_{}_{}.png'.format(ii, model), np.flipud(np.asarray(list_files[ii])))

 def np2rgb(tens_im):
@@ -54,7 +54,7 @@ def np2rgb(tens_im):
     tens_im = tens_im.detach()
     mapped_data = [None for n in range(N)]
     for n in range(N):
-        thres_im = tens_im[n] > 0.5
+        thres_im = tens_im[n] > 0.6
         normed_data = (np.asarray(thres_im) - torch.min(thres_im).item()) / (torch.max(thres_im).item() - torch.min(thres_im).item())
         mapped_data[n] = torch.from_numpy(colors[n](normed_data))
     return mapped_data
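For context, this hunk thresholds each predicted mask, min-max normalises it, and pushes it through a per-source colormap before saving. A minimal standalone sketch of that idea, using random data and matplotlib's viridis in place of the project's `colors[n]` (both the data and the colormap choice are assumptions, not the repository's code):

```python
import numpy as np
import matplotlib.cm as cm
import matplotlib.image as mpimg

pred = np.random.rand(256, 256)            # stand-in for one channel of masks_pred
thres = (pred > 0.6).astype(np.float32)    # binary mask, same threshold as above
span = thres.max() - thres.min()
normed = (thres - thres.min()) / span if span > 0 else thres   # guard against a flat mask
rgba = cm.viridis(normed)                  # HxWx4 RGBA array in [0, 1]
mpimg.imsave('mask_preview.png', np.flipud(rgba))   # flip vertically before saving, as Visualizer2 does
```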
@@ -130,7 +130,7 @@ def build_audio(pred_masks, sr, magmix, phasemix, model):
 model = []
-root_dir2 = './Saved_models2/'
+root_dir2 = './Saved_models_npt/'
 ext = '.tar'
 for root, dirnames, filenames in os.walk(root_dir2):
     for filename in fnmatch.filter(filenames, '*' + ext):
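The body of the inner loop is cut off by the hunk boundary above. As a rough sketch of the pattern only (the `append` of the joined path is a hypothetical reconstruction, not visible in this diff), collecting every `.tar` checkpoint under `root_dir2` could look like:

```python
import os
import fnmatch

root_dir2 = './Saved_models_npt/'
ext = '.tar'
model = []
for root, dirnames, filenames in os.walk(root_dir2):
    for filename in fnmatch.filter(filenames, '*' + ext):
        model.append(os.path.join(root, filename))   # hypothetical loop body
```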
@@ -147,7 +147,7 @@ file3, _ = load_audio(noise)
 mix = file1 + file2 + file3
 filtmix = filt(mix, sr)
-scipy.io.wavfile.write('./test/inputaudio.wav', sr, filtmix)
+#scipy.io.wavfile.write('./test/inputaudio.wav', sr, filtmix)
 magmix, phasemix = _stft(filtmix)
 im = create_im(magmix)
 phase = create_im(phasemix)
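A minimal sketch of the magnitude/phase split that `_stft` performs here, written with `scipy.signal.stft` and assumed parameters (the project's own sample rate, window size and hop are not shown in this diff):

```python
import numpy as np
from scipy import signal

sr = 22050                                          # assumed sample rate
mix = np.random.randn(2 * sr).astype(np.float32)    # stand-in for file1 + file2 + file3
_, _, Z = signal.stft(mix, fs=sr, nperseg=1024, noverlap=768)   # complex spectrogram
magmix = np.abs(Z)       # magnitude, what create_im turns into the network input
phasemix = np.angle(Z)   # phase, kept so build_audio can reconstruct a waveform later
```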
@@ -157,11 +157,9 @@ Visualizer(im.squeeze(0))
 net = UNet(n_channels=1, n_classes=2)
 for mod in model:
     real_mod = mod.rsplit('/', 1)[1].split('.', 1)[0]
-    checkpoint = torch.load('Saved_models2/{}.pth.tar'.format(real_mod))
+    checkpoint = torch.load('Saved_models_npt/{}.pth.tar'.format(real_mod))
     net.load_state_dict(checkpoint['model_state_dict'])
     masks_pred = net(im)
-    print(im.size(), 'im size')
-    print(masks_pred.size(), 'masks pred size')
-    # mapped_data = np2rgb(masks_pred.squeeze(0))
-    # Visualizer2(mapped_data, real_mod)
-    build_audio(masks_pred, sr, im, phase, real_mod)
+    mapped_data = np2rgb(masks_pred.squeeze(0))
+    Visualizer2(mapped_data, real_mod)
+    # build_audio(masks_pred, sr, im, phase, real_mod)
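The loop above reloads the weights of every saved model before predicting. A standalone sketch of that checkpoint-loading step, assuming the same `{'model_state_dict': ...}` checkpoint layout; the file name, the UNet import path and the dummy input are assumptions for illustration only:

```python
import torch
from unet import UNet   # assumed import path for the project's UNet

net = UNet(n_channels=1, n_classes=2)
checkpoint = torch.load('Saved_models_npt/example_model.pth.tar', map_location='cpu')
net.load_state_dict(checkpoint['model_state_dict'])
net.eval()                               # inference only

with torch.no_grad():                    # no gradients needed for prediction
    im = torch.rand(1, 1, 256, 256)      # stand-in for create_im(magmix)
    masks_pred = net(im)                 # one predicted mask per class
```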