Commit 8f519968 authored by Touklakos's avatar Touklakos
Browse files

MAJ Sprint 5

parent 0d7735ed
......@@ -76,12 +76,14 @@ def parse():
parser.add_argument('--exp_file', dest='exp_file', type=str, help='experiment file')
parser.add_argument('--nb_iteration', dest='nb_iteration', type=int, default=3, help='number of iteration for de-noising operation')
parser.add_argument('--nb_rotation', dest='nb_rotation', type=int, default=8, help='number of ration for data augmentation')
parser.add_argument('--isDebug', dest='isDebug', action='store_true')
parser.add_argument('--patch_size', dest='patch_size', default=50)
parser.add_argument('--stride', dest='stride', default=50)
parser.add_argument('--step', dest='step', default=0)
parser.add_argument('--freq_save', dest='freq_save', type=int, default=1)
parser.add_argument('--phase_type', dest='phase_type', default="two")
parser.add_argument('--patch_per_image', dest='patch_per_image', default=384)
parser.add_argument('--noise_src_dir', dest='noise_src_dir', default="./chemin/")
......
......@@ -57,7 +57,7 @@ class NoisyBSDSDataset(td.Dataset):
class TrainDataset(NoisyBSDSDataset):
def __init__(self, clean, noisy, image_mode, image_size):
def __init__(self, clean, noisy, image_mode, image_size, nb_rotation=8):
""" Initialize the data loader
Arguments:
......@@ -80,10 +80,8 @@ class TrainDataset(NoisyBSDSDataset):
self.clean = normalize_data(self.clean, 'two', rdm, True)
self.noisy = normalize_data(self.noisy, 'two', rdm, True)
rotation = 8
self.clean = rotate_data(self.clean, rotation)
self.noisy = rotate_data(self.noisy, rotation)
self.clean = rotate_data(self.clean, nb_rotation)
self.noisy = rotate_data(self.noisy, nb_rotation)
print("data_size : ", self.clean.shape)
print("data_type : ", type(self.clean))
......
......@@ -65,7 +65,7 @@ def evaluate_on_HOLODEEP(args, exp):
patterns = args.test_patterns
noises = args.test_noises
clean, noisy = from_DATABASE(args.eval_dir, noises, patterns, False)
clean, noisy = from_DATABASE(args.eval_dir, noises, patterns, True)
clean = np.array(clean)
noisy = np.array(noisy)
......@@ -109,7 +109,7 @@ def evaluate_on_DATAEVAL(args, exp):
def denoise_img(args, noisy, clean, name, exp, nb_iteration=3):
def denoise_img(args, noisy, clean, name, exp):
"""This method is used to do and save a de-noising operation on a given image
Arguments:
......@@ -118,10 +118,10 @@ def denoise_img(args, noisy, clean, name, exp, nb_iteration=3):
clean (numpy.array) : The clean reference
name (str) : The name used to save the results
exp (Experiment) : The model used to do the de-noising operation
nb_iteration (int, optional) : The number of iteration to de-noise the image
"""
clean_pred_rad = noisy
nb_iteration = args.nb_iteration
for j in range(nb_iteration):
clean_pred_rad = denoising_single_image(args, clean_pred_rad, exp)
......@@ -146,11 +146,12 @@ def denoising_single_image(args, noisy, exp):
noisyPy_cos = torch.Tensor(normalize_data(noisyPy, 'cos', None))
noisyPy_sin = torch.Tensor(normalize_data(noisyPy, 'sin', None))
clean_pred_cos = exp.test(noisyPy_cos)
clean_pred_sin = exp.test(noisyPy_sin)
clean_pred_cos = exp.test(noisyPy_cos).detach().cpu().numpy()
clean_pred_sin = exp.test(noisyPy_sin).detach().cpu().numpy()
clean_pred_rad = torch.angle(clean_pred_cos + clean_pred_sin * 1J)
clean_pred_rad = clean_pred_rad.detach().cpu().numpy().reshape(1, args.test_image_size[0], args.test_image_size[1], args.image_mode)
clean_pred_rad = np.angle(clean_pred_cos + clean_pred_sin * 1J)
clean_pred_rad = clean_pred_rad.reshape(1, args.test_image_size[0], args.test_image_size[1], args.image_mode)
return clean_pred_rad
......@@ -163,8 +164,6 @@ def run(args):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
trainData = TrainDataset(args.clean_train, args.noisy_train, args.image_mode, args.train_image_size)
evalData = EvalDataset(args.eval_dir, args.eval_noises, args.eval_patterns, args.image_mode, args.eval_image_size)
net = DnCNN(D=args.D, C=args.C, image_mode=args.image_mode).to(device)
adam = torch.optim.Adam(net.parameters(), lr=args.lr)
......@@ -172,12 +171,17 @@ def run(args):
exp = nt.Experiment(net, trainData, evalData, adam, statsManager, batch_size=args.batch_size, perform_validation_during_training=args.perform_validation, input_dir=args.input_dir, startEpoch=args.epoch)
exp = nt.Experiment(net, adam, statsManager, batch_size=args.batch_size, perform_validation_during_training=args.perform_validation, input_dir=args.input_dir, startEpoch=args.epoch, freq_save=args.freq_save)
if not args.test_mode :
print("\n=>Training until epoch :<===\n", args.num_epochs)
print("\n\Model training")
trainData = TrainDataset(args.clean_train, args.noisy_train, args.image_mode, args.train_image_size, nb_rotation=args.nb_rotation)
evalData = EvalDataset(args.eval_dir, args.eval_noises, args.eval_patterns, args.image_mode, args.eval_image_size)
exp.initData(trainData, evalData)
exp.run(num_epochs=args.num_epochs)
if(args.graph):
......@@ -197,6 +201,7 @@ def run(args):
if __name__ == '__main__':
args = parse()
print("\n\n")
......
......@@ -57,6 +57,4 @@ class DnCNN(NNRegressor):
h = torch.nn.functional.relu(self.conv[0](input))
for i in range(self.D):
h = torch.nn.functional.relu(self.bn[i](self.conv[i+1](h)))
y = self.conv[self.D+1](h)
z = input - y
return z
return input - self.conv[self.D+1](h)
......@@ -145,13 +145,8 @@ class Experiment(object):
set and the validation set. (default: False)
"""
def __init__(self, net, train_set, val_set, optimizer, stats_manager, startEpoch=None,
input_dir=None, batch_size=16, perform_validation_during_training=False):
# Define data loaders
train_loader = td.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True)
val_loader = td.DataLoader(val_set, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)
def __init__(self, net, optimizer, stats_manager, startEpoch=None,
input_dir=None, batch_size=16, perform_validation_during_training=False, freq_save=1):
# Initialize history
history = []
......@@ -168,7 +163,6 @@ class Experiment(object):
locs = {k: v for k, v in locals().items() if k is not 'self'}
self.__dict__.update(locs)
self.training_data = train_set.getTrainingName()
# Load checkpoint and check compatibility
if os.path.isfile(config_path):
......@@ -178,8 +172,17 @@ class Experiment(object):
# "Cannot create this experiment: "
# "I found a checkpoint conflicting with the current setting.")
self.load(startEpoch)
else:
self.save()
def initData(self, train_set, val_set):
    """Attach the training and validation datasets to this experiment.

    Stores both datasets on the instance, wraps them in DataLoaders, and
    records the training data's name for bookkeeping.

    Arguments:
        train_set : training dataset; must expose ``getTrainingName()``
            (e.g. ``TrainDataset``) — TODO confirm expected type
        val_set : validation dataset (e.g. ``EvalDataset``)
    """
    self.train_set = train_set
    self.val_set = val_set
    # Define data loaders
    # Training loader shuffles each epoch; both drop the last partial batch
    # (drop_last=True) so every batch has exactly self.batch_size samples,
    # and pin host memory to speed up host-to-GPU transfers.
    self.train_loader = td.DataLoader(train_set, batch_size=self.batch_size, shuffle=True, drop_last=True, pin_memory=True)
    # Validation loader keeps a fixed order (shuffle=False) for reproducible evaluation.
    self.val_loader = td.DataLoader(val_set, batch_size=self.batch_size, shuffle=False, drop_last=True, pin_memory=True)
    # Name of the training data, used elsewhere to identify/checkpoint this run.
    self.training_data = train_set.getTrainingName()
@property
def epoch(self):
......@@ -295,6 +298,7 @@ class Experiment(object):
server without display, ``plot`` can be used to show statistics
on ``stdout`` or save statistics in a log file. (default: None)
"""
self.save()
self.net.train()
self.stats_manager.init()
start_epoch = self.epoch
......@@ -320,7 +324,7 @@ class Experiment(object):
self.save_train(time.time() - s)
if((self.epoch % 1) == 0):
if((self.epoch % self.freq_save) == 0):
self.save()
if plot is not None:
plot(self)
......
......@@ -9,6 +9,6 @@
#SBATCH --mail-type=ALL
#SBATCH --mail-user=mano.brabant.etu@univ-lemans.fr
python3 main_holo.py --batch_size 16 --perform_validation --D 4 --num_epoch 100 --noisy_train data1/img_noisy_train_1-2-3-4-5_0-1-1.5-2-2.5_two_50_50_384.npy --clean_train data1/img_clean_train_1-2-3-4-5_0-1-1.5-2-2.5_two_50_50_384.npy
python3 main_holo.py --batch_size 16 --perform_validation --freq_save 2 --D 4 --num_epoch 100 --noisy_train data1/img_noisy_train_1-2-3-4-5_0-1-1.5-2-2.5_two_50_50_384.npy --clean_train data1/img_clean_train_1-2-3-4-5_0-1-1.5-2-2.5_two_50_50_384.npy
......@@ -9,6 +9,7 @@
runTest=$1
epoch=$2
D=$3
nbIteration=$4
#test1=/info/etu/m1/s171085/Projets/Portage-Keras-PyTorch/Portage-reseau-de-neurones-de-Keras-vers-PyTorch/dncnn-tensorflow-holography-master/Holography/DATAEVAL/DATAEVAL/DATA_1_Phase_Type1_2_0.25_1.5_4_50.mat
#test2=/info/etu/m1/s171085/Projets/Portage-Keras-PyTorch/Portage-reseau-de-neurones-de-Keras-vers-PyTorch/dncnn-tensorflow-holography-master/Holography/DATAEVAL/DATAEVAL/DATA_20_Phase_Type4_2_0.25_2.5_4_100.mat
......@@ -17,4 +18,4 @@ D=$3
#keyClean='Phase'
python main_holo.py --test_mode --input_dir $runTest --epoch $epoch --D $D
python main_holo.py --test_mode --input_dir $runTest --epoch $epoch --D $D --nb_iteration $nbIteration
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment