Commit 1bbc378f authored by Marie Tahon's avatar Marie Tahon
Browse files

remove torchvision from data.py and ref to tf in utils

parents 96f42832 cfbb377a
......@@ -80,12 +80,12 @@ def parse():
parser.add_argument('--isDebug', dest='isDebug', action='store_true')
parser.add_argument('--patch_size', dest='patch_size', default=50)
parser.add_argument('--stride', dest='stride', default=50)
parser.add_argument('--step', dest='step', default=0)
parser.add_argument('--patch_size', dest='patch_size', type=int, default=50)
parser.add_argument('--stride', dest='stride', type=int, default=50)
parser.add_argument('--step', dest='step', type=int, default=0)
parser.add_argument('--freq_save', dest='freq_save', type=int, default=1)
parser.add_argument('--phase_type', dest='phase_type', default="two")
parser.add_argument('--patch_per_image', dest='patch_per_image', default=384)
parser.add_argument('--patch_per_image', dest='patch_per_image', type=int, default=384)
parser.add_argument('--noise_src_dir', dest='noise_src_dir', default="./chemin/")
parser.add_argument('--clean_src_dir', dest='clean_src_dir', default="./chemin/")
......
......@@ -67,7 +67,7 @@ parser.add_argument('--params', dest='params', type=str, default='', help='hyper
# check output arguments
#parser.add_argument('--from_file', dest='from_file', default="./data/img_clean_pats.npy", help='get pic from file')
#parser.add_argument('--num_pic', dest='num_pic', type=int, default=10, help='number of pic to pick')
args = parser.parse_args()
#args = parser.parse_args()
#print(args.params['patch_size'])
#hparams.parse(args.params)
......
......@@ -171,7 +171,7 @@ def run(args):
exp = nt.Experiment(net, adam, statsManager, batch_size=args.batch_size, perform_validation_during_training=args.perform_validation, input_dir=args.input_dir, startEpoch=args.epoch, freq_save=args.freq_save)
exp = nt.Experiment(net, adam, statsManager, perform_validation_during_training=args.perform_validation, input_dir=args.input_dir, startEpoch=args.epoch, freq_save=args.freq_save)
if not args.test_mode :
......@@ -181,7 +181,7 @@ def run(args):
trainData = TrainDataset(args.clean_train, args.noisy_train, args.image_mode, args.train_image_size, nb_rotation=args.nb_rotation)
evalData = EvalDataset(args.eval_dir, args.eval_noises, args.eval_patterns, args.image_mode, args.eval_image_size)
exp.initData(trainData, evalData)
exp.initData(trainData, evalData, batch_size=args.batch_size)
exp.run(num_epochs=args.num_epochs)
if(args.graph):
......
......@@ -146,7 +146,7 @@ class Experiment(object):
"""
def __init__(self, net, optimizer, stats_manager, startEpoch=None,
input_dir=None, batch_size=16, perform_validation_during_training=False, freq_save=1):
input_dir=None, perform_validation_during_training=False, freq_save=1):
# Initialize history
history = []
......@@ -173,13 +173,14 @@ class Experiment(object):
# "I found a checkpoint conflicting with the current setting.")
self.load(startEpoch)
def initData(self, train_set, val_set):
def initData(self, train_set, val_set, batch_size=16):
self.train_set = train_set
self.val_set = val_set
self.batch_size = batch_size
# Define data loaders
self.train_loader = td.DataLoader(train_set, batch_size=self.batch_size, shuffle=True, drop_last=True, pin_memory=True)
self.val_loader = td.DataLoader(val_set, batch_size=self.batch_size, shuffle=False, drop_last=True, pin_memory=True)
self.train_loader = td.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True)
self.val_loader = td.DataLoader(val_set, batch_size=1, shuffle=False, drop_last=True, pin_memory=True)
self.training_data = train_set.getTrainingName()
......
......@@ -37,7 +37,7 @@ import sys
import re
import pathlib
import numpy as np
#import tensorflow as tf
import tensorflow as tf
from PIL import Image
from scipy.io import loadmat, savemat
from glob import glob
......@@ -402,7 +402,7 @@ def normalize_data(data,phase_type, rdm, phase_augmentation = False):
return data_n
elif (phase_type == 'two') & (phase_augmentation == True):
numPatch = data.shape[0]
newshape = (numPatch * 4, data.shape[1], data.shape[2], data.shape[3])
newshape = (numPatch * 8, data.shape[1], data.shape[2], data.shape[3])
data_n = np.zeros(shape = newshape)
cpt = 0
for k in range(numPatch):
......@@ -410,11 +410,11 @@ def normalize_data(data,phase_type, rdm, phase_augmentation = False):
data_n[1*numPatch + k,:,:,0] = np.sin( data[k,:,:,0])
data_n[2*numPatch + k,:,:,0] = np.cos( np.transpose( data[k,:,:,0]) )
data_n[3*numPatch + k,:,:,0] = np.sin( np.transpose( data[k,:,:,0]) )
#data_n[4*numPatch + k,:,:,0] = np.cos( math.pi/4 + data[k,:,:,0])
#data_n[5*numPatch + k,:,:,0] = np.sin( math.pi/4 + data[k,:,:,0])
#data_n[6*numPatch + k,:,:,0] = np.cos( np.transpose( math.pi/4 + data[k,:,:,0]) )
#data_n[7*numPatch + k,:,:,0] = np.sin( np.transpose( math.pi/4 + data[k,:,:,0]) )
print('nb of cos / sin / cos + flipud / sin + flipud: ', numPatch * 4)
data_n[4*numPatch + k,:,:,0] = np.cos( math.pi/4 + data[k,:,:,0])
data_n[5*numPatch + k,:,:,0] = np.sin( math.pi/4 + data[k,:,:,0])
data_n[6*numPatch + k,:,:,0] = np.cos( np.transpose( math.pi/4 + data[k,:,:,0]) )
data_n[7*numPatch + k,:,:,0] = np.sin( np.transpose( math.pi/4 + data[k,:,:,0]) )
print('nb of cos / sin / cos + flipud / sin + flipud: ', numPatch * 8)
return data_n
else:
print('[!] phase type not exists (phi|cos|sin|two)')
......@@ -593,6 +593,7 @@ def cal_std_phase(im1, im2):
return dev
<<<<<<< HEAD
def tf_psnr(im1, im2):
'''
this function is deprecated
......@@ -603,6 +604,8 @@ def tf_psnr(im1, im2):
return 10.0 * (tf.log(1 / mse) / tf.log(10.0))
=======
>>>>>>> cfbb377a755bc7b5e2b9ead857e01e5576ec3f67
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment