Commit 8e223d52 authored by Marie Tahon's avatar Marie Tahon
Browse files

add holodncnn package

parent 0945ffde
import argparse
def parse():
'''
Add arguments.
'''
parser = argparse.ArgumentParser(
description='DnCNN')
# #name for user #name for program #type #default values #explanation sentences
parser.add_argument('--input_dir', dest='input_dir', type=str, default='./PyTorchCheckpoint/', help='directory of saved checkpoints for denoising operation or retraining')
parser.add_argument('--output_dir', dest='output_dir', type=str, default=None, help='directory of saved checkpoints for denoising operation or retraining')
parser.add_argument('--train_dir', dest='train_dir', type=str, default='./Holography/HOLODEEPmat/DATABASE/', help='directory of training database')
parser.add_argument('--eval_dir', dest='eval_dir', type=str, default='./Holography/HOLODEEPmat/DATABASE/', help='directory of evaluation database')
parser.add_argument('--test_dir', dest='test_dir', type=str, default='./Holography/DATAEVAL/DATAEVAL/', help='directory of testing database')
parser.add_argument('--save_test_dir', dest='save_test_dir', type=str, default='./TestImages/', help='directory where results of de-noising operation will be saved')
parser.add_argument('--train_patterns', dest='train_patterns', type=int, nargs='+', default=(1, 2, 3, 4, 5), help='patterns used for training')
parser.add_argument('--train_noises', dest='train_noises', type=str, default="0-1-1.5-2-2.5", help='noise levels used for training ')
parser.add_argument('--eval_patterns', dest='eval_patterns', type=int, nargs='+', default=(1, 2, 3, 4, 5), help='patterns used for eval')
parser.add_argument('--eval_noises', dest='eval_noises', type=str, default="0-1-1.5-2-2.5", help='noise levels used for eval ')
parser.add_argument('--test_patterns', dest='test_patterns', type=int, nargs='+', default=(1, 2, 3, 4, 5), help='patterns used for testing')
parser.add_argument('--test_noises', dest='test_noises', type=str, default="0-1-1.5-2-2.5", help='noise levels used for testing ')
parser.add_argument('--clean_train', dest='clean_train', type=str, default='data1/img_clean_train_1_0_two_50_50_3.npy', help='filepath of noise free file for training')
parser.add_argument('--noisy_train', dest='noisy_train', type=str, default='data1/img_noisy_train_1_0_two_50_50_3.npy', help='filepath of noisy file for training')
parser.add_argument('--clean_eval', dest='clean_eval', type=str, default='data1/img_clean_train_1-2-3_0-1-1.5two.npy', help='filepath of noise free file for eval')
parser.add_argument('--noisy_eval', dest='noisy_eval', type=str, default='data1/img_noisy_train_1-2-3_0-1-1.5two.npy', help='filepath of noisy file for eval')
parser.add_argument('--num_epochs', dest='num_epochs', type=int, default=200, help='number of epochs to train')
parser.add_argument('--D', dest='D', type=int, default=4, help='number of dilated convolutional layer (resBlock)')
parser.add_argument('--C', dest='C', type=int, default=64, help='kernel size of convolutional layer')
parser.add_argument('--plot', dest='plot', action='store_true', help='plot loss during training')
parser.add_argument('--lr', dest='lr', type=float, default=1e-3, help='learning rate for training')
parser.add_argument('--train_image_size', dest='train_image_size',type=int, nargs='+', default=(50, 50), help='size of train images')
parser.add_argument('--eval_image_size', dest='eval_image_size', type=int, nargs='+', default=(1024, 1024), help='size of eval images')
parser.add_argument('--test_image_size', dest='test_image_size', type=int, nargs='+', default=(1024, 1024), help='size of test images')
parser.add_argument('--image_mode', dest='image_mode', type=int, default=1, help='1 or 3 (black&white or RGB)')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=384, help="")
parser.add_argument('--epoch', dest='epoch', type=int, default=None, help='epoch\'s number from which we going to retrain')
parser.add_argument('--test_mode', dest='test_mode', action='store_true', help='testing phase')
parser.add_argument('--tsf', dest='tsf', action='store_true', help='add if code in tensorflow')
parser.add_argument('--graph', dest='graph', action='store_true', help='add if graph is visible')
parser.add_argument('--graph_fin', dest='graph_fin', action='store_true', help='add if graph is visible during training')
# Tensorflow arguments
parser.add_argument('--use_gpu', dest='use_gpu', type=int, default=1, help='gpu flag, 1 for GPU and 0 for CPU')
parser.add_argument('--checkpoint_dir', dest='ckpt_dir', type=str, default='./checkpoint', help='models are saved here')
parser.add_argument('--ckpt_dir', dest='ckpt_dir', type=str, default='./checkpoint', help='models are saved here')
parser.add_argument('--sample_dir', dest='sample_dir', type=str, default='./sample', help='sample are saved here')
#parser.add_argument('--test_dir', dest='test_dir', default='./test', help='test sample are saved here')
parser.add_argument('--params', dest='params', type=str, default='', help='hyper parameters')
parser.add_argument('--test_noisy_img', dest='noisy_img', type=str, help='path of the noisy image for testing')
parser.add_argument('--test_noisy_key', dest='noisy_key', type=str, help='key for noisy matlab image for testing')
parser.add_argument('--test_clean_img', dest='clean_img', type=str, help='path of the clean image for testing')
parser.add_argument('--test_clean_key', dest='clean_key', type=str, help='key for clean matlab image for testing')
parser.add_argument('--test_flip', dest='flip', type=bool, default=False, help='option for upside down flip of noisy (and clean) test image')
#parser.add_argument('--test_ckpt_index', dest='ckpt_index', type=str, default='', help='name and directory of the checkpoint that will be restored.')
parser.add_argument('--save_dir', dest='save_dir', type=str, default='./data1/', help='dir of patches')
parser.add_argument('--exp_file', dest='exp_file', type=str, help='experiment file')
parser.add_argument('--nb_iteration', dest='nb_iteration', type=int, default=3, help='number of iteration for de-noising operation')
parser.add_argument('--nb_rotation', dest='nb_rotation', type=int, default=8, help='number of ration for data augmentation')
parser.add_argument('--isDebug', dest='isDebug', action='store_true')
parser.add_argument('--patch_size', dest='patch_size', type=int, default=50)
parser.add_argument('--stride', dest='stride', type=int, default=50)
parser.add_argument('--step', dest='step', type=int, default=0)
parser.add_argument('--freq_save', dest='freq_save', type=int, default=1)
parser.add_argument('--phase_type', dest='phase_type', default="two")
parser.add_argument('--patch_per_image', dest='patch_per_image', type=int, default=384)
parser.add_argument('--noise_src_dir', dest='noise_src_dir', default="./chemin/")
parser.add_argument('--clean_src_dir', dest='clean_src_dir', default="./chemin/")
parser.add_argument('--perform_validation', dest='perform_validation', action="store_true")
parser.add_argument('--scales', dest='scales', type=int, nargs='+', default=(1), help='size of test images')
parser.add_argument('--originalsize', dest='originalsize', type=int, nargs='+', default=(1024, 1024), help='size of test images')
return parser.parse_args()
class Args():
'''
For jupyter notebook
Not up to date
'''
def __init__(self):
self.root_dir = '../dataset/BSDS300/images'
self.output_dir = '../checkpoints/'
self.noisy = 'data1/img_noisy_train_1-2-3-4-5_0_two_50_50_9.npy'
self.clean = 'data1/img_clean_train_1-2-3-4-5_0_two_50_50_9.npy'
self.num_epochs = 200
self.D = 4
self.C = 64
self.plot = False
self.model = 'dudncnn'
self.lr = 1e-3
self.image_size = (180, 180)
self.test_image_size = (320, 320)
self.batch_size = 60
self.sigma = 30
self.is_training = False
self.image_mode = 1
self.graph = False
import os
import torch
import torch.utils.data as td
import torchvision as tv
import numpy as np
from PIL import Image
from utils import *
from argument import *
class NoisyBSDSDataset(td.Dataset):
""" This class allow us to load and use data needed for model training
"""
def __init__(self, image_mode, image_size):
""" Initialize the data loader
Arguments:
clean(String) : The path of clean data
noisy(String) : The path of noisy data
image_mode(int) : The number of channel of the clean and noisy data
image_size((int, int)) : The size (in pixels) of clean and noisy data
"""
super(NoisyBSDSDataset, self).__init__()
self.image_mode = image_mode
self.image_size = image_size
def __len__(self):
return len(self.clean)
def __repr__(self):
return "NoisyBSDSDataset(image_mode={}, image_size={})". \
format(self.image_mode, self.image_size)
def getTrainingName(self):
return self.training_name
def __getitem__(self, idx):
cleanSample = self.clean[idx]
noisySample = self.noisy[idx]
cleanSample = cleanSample.reshape(self.image_mode, self.image_size[0], self.image_size[1])
noisySample = noisySample.reshape(self.image_mode, self.image_size[0], self.image_size[1])
cleanSample = torch.Tensor(cleanSample)
noisySample = torch.Tensor(noisySample)
return noisySample, cleanSample
class TrainDataset(NoisyBSDSDataset):
def __init__(self, clean, noisy, image_mode, image_size, nb_rotation=8):
""" Initialize the data loader
Arguments:
clean(String) : The path of clean data
noisy(String) : The path of noisy data
image_mode(int) : The number of channel of the clean and noisy data
image_size((int, int)) : The size (in pixels) of clean and noisy data
"""
super(TrainDataset, self).__init__(image_mode, image_size)
self.training_name = noisy
print("clean : ", clean)
print("noisy : ", noisy)
self.clean, self.noisy = load_train_data(filepath=clean, noisyfilepath=noisy, phase_type="two")
rdm = np.random.randint(0, 2, self.clean.shape[0])
self.clean = normalize_data(self.clean, 'two', rdm, True)
self.noisy = normalize_data(self.noisy, 'two', rdm, True)
self.clean = rotate_data(self.clean, nb_rotation)
self.noisy = rotate_data(self.noisy, nb_rotation)
print("data_size : ", self.clean.shape)
print("data_type : ", type(self.clean))
class EvalDataset(NoisyBSDSDataset):
def __init__(self, eval_dir, noises, patterns, image_mode, image_size):
""" Initialize the data loader
Arguments:
clean(String) : The path of clean data
noisy(String) : The path of noisy data
image_mode(int) : The number of channel of the clean and noisy data
image_size((int, int)) : The size (in pixels) of clean and noisy data
"""
super(EvalDataset, self).__init__(image_mode, image_size)
self.training_name = eval_dir + "-".join([str(pattern) for pattern in patterns]) + "/" + noises
self.clean, self.noisy = from_DATABASE(eval_dir, noises, patterns, False)
self.clean = np.array(self.clean)
self.noisy = np.array(self.noisy)
self.clean = self.clean.reshape(-1, self.image_size[0], self.image_size[1], self.image_mode)
self.noisy = self.noisy.reshape(-1, self.image_size[0], self.image_size[1], self.image_mode)
print("data_size : ", self.clean.shape)
print("data_type : ", type(self.clean))
def __getitem__(self, idx):
cleanSample = self.clean[idx]
noisySample = self.noisy[idx]
cleanSample = cleanSample.reshape(self.image_mode, self.image_size[0], self.image_size[1])
noisySample = noisySample.reshape(self.image_mode, self.image_size[0], self.image_size[1])
cleanSample = torch.Tensor(cleanSample)
noisySin = torch.Tensor(normalize_data(noisySample, 'sin', None))
noisyCos = torch.Tensor(normalize_data(noisySample, 'cos', None))
return (noisySin, noisyCos), cleanSample
#
# This file is part of DnCnn4Holo.
#
# Adapted from https://github.com/wbhu/DnCNN-tensorflow by Hu Wenbo
#
# DnCnn4Holo is a python script for phase image denoising.
# Home page: https://git-lium.univ-lemans.fr/tahon/dncnn-tensorflow-holography
#
# DnCnn4Holo is free software: you can redistribute it and/or modify
# it under the terms of the GNU LLesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# DnCnn4Holo is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with DnCnn4Holo. If not, see <http://www.gnu.org/licenses/>.
"""
Copyright 2019-2022 Marie Tahon
:mod:`__init__.py` definition function for DnCnn4Holo
"""
from holodncnn.utils import save_MAT_images, save_images, cal_psnr, cal_std_phase, rad_to_flat
from holodncnn.utils import *
from holodncnn.model import DnCNN
from holodncnn.nntools import Experiment, DenoisingStatsManager
from holodncnn.holosets import TrainHoloset, EvalHoloset
...@@ -4,29 +4,24 @@ matplotlib.use('TkAgg') ...@@ -4,29 +4,24 @@ matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg#, NavigationToolbar2TkAgg from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg#, NavigationToolbar2TkAgg
import multiprocessing import multiprocessing
import time import time
import random
from tkinter import * from tkinter import *
import torch from .model import *
from utils import * # from argument import *
import nntools as nt # from main_holo_tsf import *
from model import *
from data import *
from argument import *
from main_holo_tsf import *
import numpy as np import numpy as np
import datetime import datetime
from argument import *
#Create a window #Create a window
window=Tk() window=Tk()
# Main need arguments to use it # Main need arguments to use it
# ex: python main_test.py --output_dir ./PyTorchCheckpoint/ # ex: python main_test.py --output_dir ./PyTorchCheckpoint/
def main(): def main():
"""TO UPDATE with new yaml argument file"""
args = parse() args = parse()
max_epochs = args.num_epochs max_epochs = args.num_epochs
...@@ -70,7 +65,6 @@ def main(): ...@@ -70,7 +65,6 @@ def main():
print ('Done') print ('Done')
# =================================== # ===================================
def plot(): #Function to create the base plot, make sure to make global the lines, axes, canvas and any part that you would want to update later def plot(): #Function to create the base plot, make sure to make global the lines, axes, canvas and any part that you would want to update later
global line,ax,canvas global line,ax,canvas
...@@ -86,9 +80,9 @@ def plot(): #Function to create the base plot, make sure to make global the l ...@@ -86,9 +80,9 @@ def plot(): #Function to create the base plot, make sure to make global the l
canvas._tkcanvas.pack(side=TOP, fill=BOTH, expand=1) canvas._tkcanvas.pack(side=TOP, fill=BOTH, expand=1)
def update_checkpoint(exp_file,max): def update_checkpoint(exp_file, max):
checkpoint = None checkpoint = None
test = None # test = None
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
while(checkpoint is None): while(checkpoint is None):
...@@ -151,5 +145,6 @@ def simulation(q,ex_file,max_epochs): ...@@ -151,5 +145,6 @@ def simulation(q,ex_file,max_epochs):
q.put('Q') q.put('Q')
if __name__ == '__main__': if __name__ == '__main__':
main() main()
import torch import torch
from utils import NNRegressor
import numpy as np import numpy as np
from .nntools import NNRegressor
# from .nntools import * #as nt
class DnCNN(NNRegressor): class DnCNN(NNRegressor):
......
...@@ -13,7 +13,10 @@ from abc import ABC, abstractmethod ...@@ -13,7 +13,10 @@ from abc import ABC, abstractmethod
import datetime import datetime
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import utils from .utils import save_images
# import utils
toto = 4
class NeuralNetwork(nn.Module, ABC): class NeuralNetwork(nn.Module, ABC):
...@@ -54,6 +57,26 @@ class NeuralNetwork(nn.Module, ABC): ...@@ -54,6 +57,26 @@ class NeuralNetwork(nn.Module, ABC):
pass pass
class NNRegressor(NeuralNetwork):
"""
This class represent an abstract neural network
"""
def __init__(self):
super(NNRegressor, self).__init__()
self.mse = torch.nn.MSELoss()
def criterion(self, y ,d):
"""
This method return a float that evaluation the accuracy of the network
Arguments:
y (torch.Tensor) : The predicted noise free reference
d (torch.Tensor) : The clean reference
"""
return self.mse(y, d)
class StatsManager(object): class StatsManager(object):
""" """
A class meant to track the loss during a neural network learning experiment. A class meant to track the loss during a neural network learning experiment.
...@@ -102,6 +125,46 @@ class StatsManager(object): ...@@ -102,6 +125,46 @@ class StatsManager(object):
return self.running_loss / self.number_update return self.running_loss / self.number_update
class DenoisingStatsManager(StatsManager):
"""
This class manage the stats of an experiment
"""
def __init__(self):
super(DenoisingStatsManager, self).__init__()
def init(self):
super(DenoisingStatsManager, self).init()
self.running_psnr = 0
def accumulate(self, loss, x, y, d):
"""
This method add new results for the stats manager
Arguments:
loss (???)
x (torch.Tensor) : The noisy reference
y (torch.Tensor) : The predicted noise free reference
d (torch.Tensor) : The clean reference
"""
#print("test accumulate")
super(DenoisingStatsManager, self).accumulate(loss, x, y, d)
n = x.shape[0] * x.shape[1] * x.shape[2] * x.shape[3]
self.running_psnr += 10*torch.log10(4*n/(torch.norm(y-d)**2))
def summarize(self):
"""
This method return the actual stats managed by the stats manager
"""
loss = super(DenoisingStatsManager, self).summarize()
psnr = self.running_psnr / self.number_update if(self.number_update !=0) else self.running_psnr
return {'loss': loss, 'PSNR': psnr}
class Experiment(object): class Experiment(object):
""" """
A class meant to run a neural network learning experiment. A class meant to run a neural network learning experiment.
...@@ -351,25 +414,27 @@ class Experiment(object): ...@@ -351,25 +414,27 @@ class Experiment(object):
self.stats_manager.init() self.stats_manager.init()
self.net.eval() self.net.eval()
with torch.no_grad(): with torch.no_grad():
for (xSin, xCos), d in self.val_loader: for x, d in self.val_loader: # (xSin, xCos)
xSin = torch.sin(x)
xCos = torch.cos(x)
xSin, xCos, d = xSin.to(self.net.device), xCos.to(self.net.device), d.to(self.net.device) xSin, xCos, d = xSin.to(self.net.device), xCos.to(self.net.device), d.to(self.net.device)
ySin = self.net.forward(xSin) ySin = self.net.forward(xSin)
yCos = self.net.forward(xCos) yCos = self.net.forward(xCos)
x = torch.angle(xCos + xSin * 1J) x = torch.angle(xCos + xSin * 1J)
y = torch.angle(yCos + ySin * 1J) y = torch.angle(yCos + ySin * 1J)
utils.save_images(os.path.join('.', 'test.tiff'), y.cpu().numpy().reshape(1, 1024, 1024, 1)) save_images(os.path.join('.', 'test.tiff'), y.cpu().numpy().squeeze())
loss = self.net.criterion(y, d) loss = self.net.criterion(y, d)
self.stats_manager.accumulate(loss.item(), x, y, d) self.stats_manager.accumulate(loss.item(), x, y, d)
self.net.train() self.net.train()
return self.stats_manager.summarize() return self.stats_manager.summarize()
def getConfig(): def getConfig():
param = "null" # config = "null"
with open(os.path.join(self.input_dir, 'config.txt'), 'r') as f: with open(os.path.join(self.input_dir, 'config.txt'), 'r') as f:
param = f.read()[:-1] config = f.read()[:-1]
return param return config
def test(self, noisy): def test(self, noisy):
......
#
# This file is part of DnCnn4Holo.
#
# Adapted from https://github.com/wbhu/DnCNN-tensorflow by Hu Wenbo
#
# DnCnn4Holo is a python script for phase image denoising.
# Home page: https://git-lium.univ-lemans.fr/tahon/dncnn-tensorflow-holography
#
# DnCnn4Holo is free software: you can redistribute it and/or modify
# it under the terms of the GNU LLesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# DnCnn4Holo is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with DnCnn4Holo. If not, see <http://www.gnu.org/licenses/>.
"""
Copyright 2019-2020 Marie Tahon
:mod:`utils.py` definition of util function for DnCnn4Holo
"""
import numpy as np
from PIL import Image
from scipy.io import savemat
__license__ = "LGPL"
__author__ = "Marie Tahon"
__copyright__ = "Copyright 2019-2020 Marie Tahon"
__maintainer__ = "Marie Tahon"
__email__ = "marie.tahon@univ-lemans.fr"
__status__ = "Production"
#__docformat__ = 'reStructuredText'
def extract_sess_name(lp, ln, pt, stride, ps, np):
"""DEPRECATED
This method return a sessions name with his given parameters