Commit cfbb377a authored by Touklakos's avatar Touklakos
Browse files

Correction batch_size

parent a7eff5a3
...@@ -80,12 +80,12 @@ def parse(): ...@@ -80,12 +80,12 @@ def parse():
parser.add_argument('--isDebug', dest='isDebug', action='store_true') parser.add_argument('--isDebug', dest='isDebug', action='store_true')
parser.add_argument('--patch_size', dest='patch_size', default=50) parser.add_argument('--patch_size', dest='patch_size', type=int, default=50)
parser.add_argument('--stride', dest='stride', default=50) parser.add_argument('--stride', dest='stride', type=int, default=50)
parser.add_argument('--step', dest='step', default=0) parser.add_argument('--step', dest='step', type=int, default=0)
parser.add_argument('--freq_save', dest='freq_save', type=int, default=1) parser.add_argument('--freq_save', dest='freq_save', type=int, default=1)
parser.add_argument('--phase_type', dest='phase_type', default="two") parser.add_argument('--phase_type', dest='phase_type', default="two")
parser.add_argument('--patch_per_image', dest='patch_per_image', default=384) parser.add_argument('--patch_per_image', dest='patch_per_image', type=int, default=384)
parser.add_argument('--noise_src_dir', dest='noise_src_dir', default="./chemin/") parser.add_argument('--noise_src_dir', dest='noise_src_dir', default="./chemin/")
parser.add_argument('--clean_src_dir', dest='clean_src_dir', default="./chemin/") parser.add_argument('--clean_src_dir', dest='clean_src_dir', default="./chemin/")
......
...@@ -67,7 +67,7 @@ parser.add_argument('--params', dest='params', type=str, default='', help='hyper ...@@ -67,7 +67,7 @@ parser.add_argument('--params', dest='params', type=str, default='', help='hyper
# check output arguments # check output arguments
#parser.add_argument('--from_file', dest='from_file', default="./data/img_clean_pats.npy", help='get pic from file') #parser.add_argument('--from_file', dest='from_file', default="./data/img_clean_pats.npy", help='get pic from file')
#parser.add_argument('--num_pic', dest='num_pic', type=int, default=10, help='number of pic to pick') #parser.add_argument('--num_pic', dest='num_pic', type=int, default=10, help='number of pic to pick')
args = parser.parse_args() #args = parser.parse_args()
#print(args.params['patch_size']) #print(args.params['patch_size'])
#hparams.parse(args.params) #hparams.parse(args.params)
......
...@@ -171,7 +171,7 @@ def run(args): ...@@ -171,7 +171,7 @@ def run(args):
exp = nt.Experiment(net, adam, statsManager, batch_size=args.batch_size, perform_validation_during_training=args.perform_validation, input_dir=args.input_dir, startEpoch=args.epoch, freq_save=args.freq_save) exp = nt.Experiment(net, adam, statsManager, perform_validation_during_training=args.perform_validation, input_dir=args.input_dir, startEpoch=args.epoch, freq_save=args.freq_save)
if not args.test_mode : if not args.test_mode :
...@@ -181,7 +181,7 @@ def run(args): ...@@ -181,7 +181,7 @@ def run(args):
trainData = TrainDataset(args.clean_train, args.noisy_train, args.image_mode, args.train_image_size, nb_rotation=args.nb_rotation) trainData = TrainDataset(args.clean_train, args.noisy_train, args.image_mode, args.train_image_size, nb_rotation=args.nb_rotation)
evalData = EvalDataset(args.eval_dir, args.eval_noises, args.eval_patterns, args.image_mode, args.eval_image_size) evalData = EvalDataset(args.eval_dir, args.eval_noises, args.eval_patterns, args.image_mode, args.eval_image_size)
exp.initData(trainData, evalData) exp.initData(trainData, evalData, batch_size=args.batch_size)
exp.run(num_epochs=args.num_epochs) exp.run(num_epochs=args.num_epochs)
if(args.graph): if(args.graph):
......
...@@ -146,7 +146,7 @@ class Experiment(object): ...@@ -146,7 +146,7 @@ class Experiment(object):
""" """
def __init__(self, net, optimizer, stats_manager, startEpoch=None, def __init__(self, net, optimizer, stats_manager, startEpoch=None,
input_dir=None, batch_size=16, perform_validation_during_training=False, freq_save=1): input_dir=None, perform_validation_during_training=False, freq_save=1):
# Initialize history # Initialize history
history = [] history = []
...@@ -173,13 +173,14 @@ class Experiment(object): ...@@ -173,13 +173,14 @@ class Experiment(object):
# "I found a checkpoint conflicting with the current setting.") # "I found a checkpoint conflicting with the current setting.")
self.load(startEpoch) self.load(startEpoch)
def initData(self, train_set, val_set): def initData(self, train_set, val_set, batch_size=16):
self.train_set = train_set self.train_set = train_set
self.val_set = val_set self.val_set = val_set
self.batch_size = batch_size
# Define data loaders # Define data loaders
self.train_loader = td.DataLoader(train_set, batch_size=self.batch_size, shuffle=True, drop_last=True, pin_memory=True) self.train_loader = td.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True)
self.val_loader = td.DataLoader(val_set, batch_size=self.batch_size, shuffle=False, drop_last=True, pin_memory=True) self.val_loader = td.DataLoader(val_set, batch_size=1, shuffle=False, drop_last=True, pin_memory=True)
self.training_data = train_set.getTrainingName() self.training_data = train_set.getTrainingName()
......
...@@ -37,7 +37,6 @@ import sys ...@@ -37,7 +37,6 @@ import sys
import re import re
import pathlib import pathlib
import numpy as np import numpy as np
import tensorflow as tf
from PIL import Image from PIL import Image
from scipy.io import loadmat, savemat from scipy.io import loadmat, savemat
from glob import glob from glob import glob
...@@ -402,7 +401,7 @@ def normalize_data(data,phase_type, rdm, phase_augmentation = False): ...@@ -402,7 +401,7 @@ def normalize_data(data,phase_type, rdm, phase_augmentation = False):
return data_n return data_n
elif (phase_type == 'two') & (phase_augmentation == True): elif (phase_type == 'two') & (phase_augmentation == True):
numPatch = data.shape[0] numPatch = data.shape[0]
newshape = (numPatch * 4, data.shape[1], data.shape[2], data.shape[3]) newshape = (numPatch * 8, data.shape[1], data.shape[2], data.shape[3])
data_n = np.zeros(shape = newshape) data_n = np.zeros(shape = newshape)
cpt = 0 cpt = 0
for k in range(numPatch): for k in range(numPatch):
...@@ -410,11 +409,11 @@ def normalize_data(data,phase_type, rdm, phase_augmentation = False): ...@@ -410,11 +409,11 @@ def normalize_data(data,phase_type, rdm, phase_augmentation = False):
data_n[1*numPatch + k,:,:,0] = np.sin( data[k,:,:,0]) data_n[1*numPatch + k,:,:,0] = np.sin( data[k,:,:,0])
data_n[2*numPatch + k,:,:,0] = np.cos( np.transpose( data[k,:,:,0]) ) data_n[2*numPatch + k,:,:,0] = np.cos( np.transpose( data[k,:,:,0]) )
data_n[3*numPatch + k,:,:,0] = np.sin( np.transpose( data[k,:,:,0]) ) data_n[3*numPatch + k,:,:,0] = np.sin( np.transpose( data[k,:,:,0]) )
#data_n[4*numPatch + k,:,:,0] = np.cos( math.pi/4 + data[k,:,:,0]) data_n[4*numPatch + k,:,:,0] = np.cos( math.pi/4 + data[k,:,:,0])
#data_n[5*numPatch + k,:,:,0] = np.sin( math.pi/4 + data[k,:,:,0]) data_n[5*numPatch + k,:,:,0] = np.sin( math.pi/4 + data[k,:,:,0])
#data_n[6*numPatch + k,:,:,0] = np.cos( np.transpose( math.pi/4 + data[k,:,:,0]) ) data_n[6*numPatch + k,:,:,0] = np.cos( np.transpose( math.pi/4 + data[k,:,:,0]) )
#data_n[7*numPatch + k,:,:,0] = np.sin( np.transpose( math.pi/4 + data[k,:,:,0]) ) data_n[7*numPatch + k,:,:,0] = np.sin( np.transpose( math.pi/4 + data[k,:,:,0]) )
print('nb of cos / sin / cos + flipud / sin + flipud: ', numPatch * 4) print('nb of cos / sin / cos + flipud / sin + flipud: ', numPatch * 8)
return data_n return data_n
else: else:
print('[!] phase type not exists (phi|cos|sin|two)') print('[!] phase type not exists (phi|cos|sin|two)')
...@@ -593,13 +592,6 @@ def cal_std_phase(im1, im2): ...@@ -593,13 +592,6 @@ def cal_std_phase(im1, im2):
return dev return dev
def tf_psnr(im1, im2):
# assert pixel value range is 0-1
#mse = tf.losses.mean_squared_error(labels=im2 * 255.0, predictions=im1 * 255.0)
mse = tf.losses.mean_squared_error(labels=im2, predictions=im1)
return 10.0 * (tf.log(1 / mse) / tf.log(10.0))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment