Commit 1bbc378f authored by Marie Tahon's avatar Marie Tahon
Browse files

remove torchvision from data.py and ref to tf in utils

parents 96f42832 cfbb377a
...@@ -80,12 +80,12 @@ def parse(): ...@@ -80,12 +80,12 @@ def parse():
parser.add_argument('--isDebug', dest='isDebug', action='store_true') parser.add_argument('--isDebug', dest='isDebug', action='store_true')
parser.add_argument('--patch_size', dest='patch_size', default=50) parser.add_argument('--patch_size', dest='patch_size', type=int, default=50)
parser.add_argument('--stride', dest='stride', default=50) parser.add_argument('--stride', dest='stride', type=int, default=50)
parser.add_argument('--step', dest='step', default=0) parser.add_argument('--step', dest='step', type=int, default=0)
parser.add_argument('--freq_save', dest='freq_save', type=int, default=1) parser.add_argument('--freq_save', dest='freq_save', type=int, default=1)
parser.add_argument('--phase_type', dest='phase_type', default="two") parser.add_argument('--phase_type', dest='phase_type', default="two")
parser.add_argument('--patch_per_image', dest='patch_per_image', default=384) parser.add_argument('--patch_per_image', dest='patch_per_image', type=int, default=384)
parser.add_argument('--noise_src_dir', dest='noise_src_dir', default="./chemin/") parser.add_argument('--noise_src_dir', dest='noise_src_dir', default="./chemin/")
parser.add_argument('--clean_src_dir', dest='clean_src_dir', default="./chemin/") parser.add_argument('--clean_src_dir', dest='clean_src_dir', default="./chemin/")
......
...@@ -67,7 +67,7 @@ parser.add_argument('--params', dest='params', type=str, default='', help='hyper ...@@ -67,7 +67,7 @@ parser.add_argument('--params', dest='params', type=str, default='', help='hyper
# check output arguments # check output arguments
#parser.add_argument('--from_file', dest='from_file', default="./data/img_clean_pats.npy", help='get pic from file') #parser.add_argument('--from_file', dest='from_file', default="./data/img_clean_pats.npy", help='get pic from file')
#parser.add_argument('--num_pic', dest='num_pic', type=int, default=10, help='number of pic to pick') #parser.add_argument('--num_pic', dest='num_pic', type=int, default=10, help='number of pic to pick')
args = parser.parse_args() #args = parser.parse_args()
#print(args.params['patch_size']) #print(args.params['patch_size'])
#hparams.parse(args.params) #hparams.parse(args.params)
......
...@@ -171,7 +171,7 @@ def run(args): ...@@ -171,7 +171,7 @@ def run(args):
exp = nt.Experiment(net, adam, statsManager, batch_size=args.batch_size, perform_validation_during_training=args.perform_validation, input_dir=args.input_dir, startEpoch=args.epoch, freq_save=args.freq_save) exp = nt.Experiment(net, adam, statsManager, perform_validation_during_training=args.perform_validation, input_dir=args.input_dir, startEpoch=args.epoch, freq_save=args.freq_save)
if not args.test_mode : if not args.test_mode :
...@@ -181,7 +181,7 @@ def run(args): ...@@ -181,7 +181,7 @@ def run(args):
trainData = TrainDataset(args.clean_train, args.noisy_train, args.image_mode, args.train_image_size, nb_rotation=args.nb_rotation) trainData = TrainDataset(args.clean_train, args.noisy_train, args.image_mode, args.train_image_size, nb_rotation=args.nb_rotation)
evalData = EvalDataset(args.eval_dir, args.eval_noises, args.eval_patterns, args.image_mode, args.eval_image_size) evalData = EvalDataset(args.eval_dir, args.eval_noises, args.eval_patterns, args.image_mode, args.eval_image_size)
exp.initData(trainData, evalData) exp.initData(trainData, evalData, batch_size=args.batch_size)
exp.run(num_epochs=args.num_epochs) exp.run(num_epochs=args.num_epochs)
if(args.graph): if(args.graph):
......
...@@ -146,7 +146,7 @@ class Experiment(object): ...@@ -146,7 +146,7 @@ class Experiment(object):
""" """
def __init__(self, net, optimizer, stats_manager, startEpoch=None, def __init__(self, net, optimizer, stats_manager, startEpoch=None,
input_dir=None, batch_size=16, perform_validation_during_training=False, freq_save=1): input_dir=None, perform_validation_during_training=False, freq_save=1):
# Initialize history # Initialize history
history = [] history = []
...@@ -173,13 +173,14 @@ class Experiment(object): ...@@ -173,13 +173,14 @@ class Experiment(object):
# "I found a checkpoint conflicting with the current setting.") # "I found a checkpoint conflicting with the current setting.")
self.load(startEpoch) self.load(startEpoch)
def initData(self, train_set, val_set): def initData(self, train_set, val_set, batch_size=16):
self.train_set = train_set self.train_set = train_set
self.val_set = val_set self.val_set = val_set
self.batch_size = batch_size
# Define data loaders # Define data loaders
self.train_loader = td.DataLoader(train_set, batch_size=self.batch_size, shuffle=True, drop_last=True, pin_memory=True) self.train_loader = td.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True)
self.val_loader = td.DataLoader(val_set, batch_size=self.batch_size, shuffle=False, drop_last=True, pin_memory=True) self.val_loader = td.DataLoader(val_set, batch_size=1, shuffle=False, drop_last=True, pin_memory=True)
self.training_data = train_set.getTrainingName() self.training_data = train_set.getTrainingName()
......
...@@ -37,7 +37,7 @@ import sys ...@@ -37,7 +37,7 @@ import sys
import re import re
import pathlib import pathlib
import numpy as np import numpy as np
#import tensorflow as tf import tensorflow as tf
from PIL import Image from PIL import Image
from scipy.io import loadmat, savemat from scipy.io import loadmat, savemat
from glob import glob from glob import glob
...@@ -402,7 +402,7 @@ def normalize_data(data,phase_type, rdm, phase_augmentation = False): ...@@ -402,7 +402,7 @@ def normalize_data(data,phase_type, rdm, phase_augmentation = False):
return data_n return data_n
elif (phase_type == 'two') & (phase_augmentation == True): elif (phase_type == 'two') & (phase_augmentation == True):
numPatch = data.shape[0] numPatch = data.shape[0]
newshape = (numPatch * 4, data.shape[1], data.shape[2], data.shape[3]) newshape = (numPatch * 8, data.shape[1], data.shape[2], data.shape[3])
data_n = np.zeros(shape = newshape) data_n = np.zeros(shape = newshape)
cpt = 0 cpt = 0
for k in range(numPatch): for k in range(numPatch):
...@@ -410,11 +410,11 @@ def normalize_data(data,phase_type, rdm, phase_augmentation = False): ...@@ -410,11 +410,11 @@ def normalize_data(data,phase_type, rdm, phase_augmentation = False):
data_n[1*numPatch + k,:,:,0] = np.sin( data[k,:,:,0]) data_n[1*numPatch + k,:,:,0] = np.sin( data[k,:,:,0])
data_n[2*numPatch + k,:,:,0] = np.cos( np.transpose( data[k,:,:,0]) ) data_n[2*numPatch + k,:,:,0] = np.cos( np.transpose( data[k,:,:,0]) )
data_n[3*numPatch + k,:,:,0] = np.sin( np.transpose( data[k,:,:,0]) ) data_n[3*numPatch + k,:,:,0] = np.sin( np.transpose( data[k,:,:,0]) )
#data_n[4*numPatch + k,:,:,0] = np.cos( math.pi/4 + data[k,:,:,0]) data_n[4*numPatch + k,:,:,0] = np.cos( math.pi/4 + data[k,:,:,0])
#data_n[5*numPatch + k,:,:,0] = np.sin( math.pi/4 + data[k,:,:,0]) data_n[5*numPatch + k,:,:,0] = np.sin( math.pi/4 + data[k,:,:,0])
#data_n[6*numPatch + k,:,:,0] = np.cos( np.transpose( math.pi/4 + data[k,:,:,0]) ) data_n[6*numPatch + k,:,:,0] = np.cos( np.transpose( math.pi/4 + data[k,:,:,0]) )
#data_n[7*numPatch + k,:,:,0] = np.sin( np.transpose( math.pi/4 + data[k,:,:,0]) ) data_n[7*numPatch + k,:,:,0] = np.sin( np.transpose( math.pi/4 + data[k,:,:,0]) )
print('nb of cos / sin / cos + flipud / sin + flipud: ', numPatch * 4) print('nb of cos / sin / cos + flipud / sin + flipud: ', numPatch * 8)
return data_n return data_n
else: else:
print('[!] phase type not exists (phi|cos|sin|two)') print('[!] phase type not exists (phi|cos|sin|two)')
...@@ -593,6 +593,7 @@ def cal_std_phase(im1, im2): ...@@ -593,6 +593,7 @@ def cal_std_phase(im1, im2):
return dev return dev
<<<<<<< HEAD
def tf_psnr(im1, im2): def tf_psnr(im1, im2):
''' '''
this function is deprecated this function is deprecated
...@@ -603,6 +604,8 @@ def tf_psnr(im1, im2): ...@@ -603,6 +604,8 @@ def tf_psnr(im1, im2):
return 10.0 * (tf.log(1 / mse) / tf.log(10.0)) return 10.0 * (tf.log(1 / mse) / tf.log(10.0))
=======
>>>>>>> cfbb377a755bc7b5e2b9ead857e01e5576ec3f67
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment