Commit a6ece072 authored by Touklakos's avatar Touklakos
Browse files

MAJ Cos and Sin for Eval

parent 3f76cf40
import os
import torch
import torch.utils.data as td
#import torchvision as tv
import torchvision as tv
import numpy as np
from PIL import Image
from utils import *
......@@ -103,8 +103,7 @@ class EvalDataset(NoisyBSDSDataset):
super(EvalDataset, self).__init__(image_mode, image_size)
self.training_name = eval_dir + "-".join([str(pattern) for pattern in patterns]) + "/" + noises
#get full images from HOLODEEP without filpupdow=False.
self.clean, self.noisy = from_DATABASE(eval_dir, noises, patterns, False)
self.clean = np.array(self.clean)
......@@ -116,3 +115,20 @@ class EvalDataset(NoisyBSDSDataset):
print("data_size : ", self.clean.shape)
print("data_type : ", type(self.clean))
def __getitem__(self, idx):
cleanSample = self.clean[idx]
noisySample = self.noisy[idx]
cleanSample = cleanSample.reshape(self.image_mode, self.image_size[0], self.image_size[1])
noisySample = noisySample.reshape(self.image_mode, self.image_size[0], self.image_size[1])
cleanSample = torch.Tensor(cleanSample)
noisySin = torch.Tensor(normalize_data(noisySample, 'sin', None))
noisyCos = torch.Tensor(normalize_data(noisySample, 'cos', None))
return (noisySin, noisyCos), cleanSample
......@@ -174,13 +174,20 @@ class Experiment(object):
self.load(startEpoch)
def initData(self, train_set, val_set, batch_size=16):
"""This method is used to initialize the training and evaluation data
Arguments:
train_set () : The training dataset
val_set () : The evaluation dataset
batch_size (int, optionnal) : The size of the batch for training data
"""
self.train_set = train_set
self.val_set = val_set
self.batch_size = batch_size
# Define data loaders
self.train_loader = td.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True)
self.val_loader = td.DataLoader(val_set, batch_size=1, shuffle=False, drop_last=True, pin_memory=True)
self.train_loader = td.DataLoader(train_set, batch_size=batch_size, num_workers=0,shuffle=True, drop_last=True, pin_memory=True)
self.val_loader = td.DataLoader(val_set, batch_size=1, num_workers=0, shuffle=False, drop_last=True, pin_memory=True)
self.training_data = train_set.getTrainingName()
......@@ -307,18 +314,8 @@ class Experiment(object):
if plot is not None:
plot(self)
s = time.time()
for param_group in self.optimizer.param_groups:
lr = param_group['lr']
for epoch in range(start_epoch, num_epochs):
if epoch < 30:
current_lr = lr
else:
current_lr = lr / 10.
#set learning rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = current_lr
self.stats_manager.init()
#train
for x, d in self.train_loader:
x, d = x.to(self.net.device), d.to(self.net.device)
self.optimizer.zero_grad()
......@@ -349,9 +346,13 @@ class Experiment(object):
self.stats_manager.init()
self.net.eval()
with torch.no_grad():
for x, d in self.val_loader:
x, d = x.to(self.net.device), d.to(self.net.device)
y = self.net.forward(x)
for (xSin, xCos), d in self.val_loader:
xSin, xCos, d = xSin.to(self.net.device), xCos.to(self.net.device), d.to(self.net.device)
ySin = self.net.forward(xSin)
yCos = self.net.forward(xCos)
x = torch.angle(xCos + xSin * 1J)
y = torch.angle(yCos + ySin * 1J)
utils.save_images(os.path.join('.', 'test.tiff'), y.cpu().numpy().reshape(1, 1024, 1024, 1))
loss = self.net.criterion(y, d)
self.stats_manager.accumulate(loss.item(), x, y, d)
self.net.train()
......@@ -379,6 +380,10 @@ class Experiment(object):
def trace(self):
loss_tab = []
for k,v in (self.history):
loss_tab = np.append(loss_tab,round(k['loss'],6))
print("affichage graphique loss: ")
plt.plot(np.arange(0,len(loss_tab)),loss_tab)
plt.title("Losses/epoch Graph ")
......
......@@ -9,6 +9,6 @@
#SBATCH --mail-type=ALL
#SBATCH --mail-user=mano.brabant.etu@univ-lemans.fr
python3 main_holo.py --batch_size 16 --perform_validation --freq_save 2 --D 4 --num_epoch 100 --noisy_train data1/img_noisy_train_1-2-3-4-5_0-1-1.5-2-2.5_two_50_50_384.npy --clean_train data1/img_clean_train_1-2-3-4-5_0-1-1.5-2-2.5_two_50_50_384.npy
python3 main_holo.py --batch_size 16 --perform_validation --D 4 --num_epoch 10 --noisy_train data1/img_noisy_train_1-2-3-4-5_0-1-1.5-2-2.5_two_50_50_384.npy --clean_train data1/img_clean_train_1-2-3-4-5_0-1-1.5-2-2.5_two_50_50_384.npy
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment