# main_holo.py
# Train a DnCNN de-noising model, or evaluate a trained one on test images.

import os
import pathlib
import time

import numpy as np
import torch

from utils import *
import nntools as nt
from model import *
from data import *
from argument import *


def save_clean_pred_rad(args, exp, clean_pred_rad, noisy, clean, nom_img = "NoisyPhasePred"):
    """Save the result of a de-noising operation and report its metrics.

    The noisy, clean and de-noised images are written under
    ``<args.save_test_dir>/<args.input_dir>/Test`` (created if missing). The
    de-noised image is also saved as a ``.mat`` file. The PSNR and the phase
    standard deviation are then printed and written to a ``.res`` file in the
    same directory.

    Arguments:
        args (ArgumentParser) :         The different info used to do and save the de-noising operation
        exp (Experiment) :              The de-noising model (only ``exp.epoch`` is read here)
        clean_pred_rad (numpy.array) :  The de-noised image
        noisy (numpy.array) :           The noised image
        clean (numpy.array) :           The noise free image
        nom_img (str, optional) :       The saving name for the result
    """

    save_dir = os.path.join(args.save_test_dir, args.input_dir, "Test")

    # exist_ok=True replaces the former "check, then create" sequence, which
    # could fail if the directory appeared between the two calls.
    os.makedirs(save_dir, exist_ok=True)

    epoch = exp.epoch

    def out_path(suffix):
        # Every output file is named "<nom_img>-<suffix>" inside save_dir.
        return os.path.join(save_dir, '%s-%s' % (nom_img, suffix))

    save_images(out_path('noisy.tiff'), noisy)
    save_images(out_path('clean.tiff'), clean)

    save_images(out_path('%d.tiff' % epoch), clean_pred_rad)
    save_MAT_images(out_path('%d.mat' % epoch), clean_pred_rad)

    psnr = cal_psnr(rad_to_flat(clean_pred_rad), rad_to_flat(clean))
    std = cal_std_phase(clean_pred_rad, clean)

    # One table drives both the console output and the .res file so the two
    # reports cannot drift apart.
    report = (
        ("image : ", nom_img),
        ("epoch : ", epoch),
        ("psnr : ", psnr),
        ("std : ", std),
    )

    print("\n")
    for label, value in report:
        print(label, value)
    print("\n")

    with open(out_path('%d.res' % epoch), 'w') as f:
        for label, value in report:
            print(label, value, file=f)




def evaluate_on_HOLODEEP(args, exp):
    """Run an evaluation on the images loaded from ``args.eval_dir``.

    Every (noisy, clean) pair returned by ``from_DATABASE`` for the patterns
    ``args.test_patterns`` and noises ``args.test_noises`` is de-noised and
    saved through ``denoise_img``. The average phase standard deviation over
    all images is printed at the end.

    Arguments:
        args (ArgumentParser) : The different info used to do and save the de-noising operations
        exp (Experiment) :      The de-noising model
    """

    patterns = args.test_patterns
    noises = args.test_noises

    clean, noisy = from_DATABASE(args.eval_dir, noises, patterns, True)

    clean = np.array(clean)
    noisy = np.array(noisy)

    nb_images = noisy.shape[0]
    running_std = 0

    for i in range(nb_images):
        clean_pred_rad = denoise_img(args, noisy[i], clean[i], "test-{:0>2}".format(i), exp)
        running_std += cal_std_phase(clean_pred_rad, clean[i])

    print("On the patterns : ", patterns)
    print("With noise : ", noises)

    if nb_images == 0:
        # Nothing was loaded: report it instead of dividing by zero.
        print("average_std : ", "n/a (no image was loaded from %s)" % args.eval_dir)
    else:
        print("average_std : ", running_std / nb_images)




def evaluate_on_DATAEVAL(args, exp):
    """Run an evaluation on every ``.mat`` file found in ``args.test_dir``.

    For each file the "Phaseb" entry is loaded as the noisy image and the
    "Phase" entry as the clean reference (both flipped upside down). All files
    are loaded first; each pair is then de-noised and saved through
    ``denoise_img`` under the file's base name.

    Arguments:
        args (ArgumentParser) : The different info used to do and save the de-noising operations
        exp (Experiment) :      The model used to do the de-noising operation
    """

    nameList = get_files(pathlib.Path(args.test_dir), '.*.mat')

    # Load everything up front so a broken file is detected before any
    # de-noising result is written.
    dataList = [
        (load_test_data(name, key = "Phaseb", flipupdown = True),
         load_test_data(name, key = "Phase", flipupdown = True))
        for name in nameList
    ]

    for name, (noisy, clean) in zip(nameList, dataList):
        denoise_img(args, noisy, clean, os.path.basename(name), exp)




def denoise_img(args, noisy, clean, name, exp):
    """Do and save a de-noising operation on a given image.

    The model is applied ``args.nb_iteration`` times in a row, each pass
    taking the output of the previous one as input. With zero iterations the
    noisy image itself is saved and returned as the prediction.

    Arguments:
        args (ArgumentParser) :         The different info used to do and save the de-noising operations
        noisy (numpy.array) :           The image to de-noise
        clean (numpy.array) :           The clean reference
        name (str) :                    The name used to save the results
        exp (Experiment) :              The model used to do the de-noising operation

    Returns:
        The de-noised image produced by the last iteration.
    """

    clean_pred_rad = noisy

    for _ in range(args.nb_iteration):
        clean_pred_rad = denoising_single_image(args, clean_pred_rad, exp)

    save_clean_pred_rad(args, exp, clean_pred_rad, noisy, clean, nom_img = name)

    return clean_pred_rad



def denoising_single_image(args, noisy, exp):
    """Do one de-noising pass on a given image.

    The image is reshaped to ``(1, image_mode, height, width)``, normalised
    twice through ``normalize_data`` ('cos' and 'sin' modes), and each result
    is passed through ``exp.test``. The two predictions are recombined into an
    angle with ``np.angle(cos + i * sin)``.

    Arguments:
        args (ArgumentParser) : The different info used to do the de-noising operation
        noisy (numpy.array) :   The image to denoise
        exp (Experiment) :      The model used to denoise

    Returns:
        numpy.array : the predicted angles, reshaped to
        ``(1, height, width, image_mode)``.
    """

    height = args.test_image_size[0]
    width = args.test_image_size[1]

    noisyPy = noisy.reshape(1, args.image_mode, height, width)

    noisyPy_cos = torch.Tensor(normalize_data(noisyPy, 'cos', None))
    noisyPy_sin = torch.Tensor(normalize_data(noisyPy, 'sin', None))

    # Detach from the autograd graph and bring the result back to the host
    # before converting to numpy.
    clean_pred_cos = exp.test(noisyPy_cos).detach().cpu().numpy()
    clean_pred_sin = exp.test(noisyPy_sin).detach().cpu().numpy()

    clean_pred_rad = np.angle(clean_pred_cos + clean_pred_sin * 1j)

    # NOTE(review): this is a reshape, not a transpose. It only matches a true
    # channel-first -> channel-last conversion when image_mode == 1; confirm
    # before using multi-channel inputs.
    clean_pred_rad = clean_pred_rad.reshape(1, height, width, args.image_mode)

    return clean_pred_rad



def run(args):
    """Build the model and either train it or use it for de-noising.

    A DnCNN network, an Adam optimiser and a stats manager are wrapped in an
    ``nt.Experiment``. When ``args.test_mode`` is false the experiment is
    trained for ``args.num_epochs`` epochs (and its curves traced if
    ``args.graph`` is set). Otherwise the model is evaluated: on the two
    evaluation sets when ``args.noisy_img`` is None, or on the single image
    given by ``args.noisy_img`` / ``args.clean_img``.

    Arguments:
        args (ArgumentParser) : The parsed command line options
    """

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    net = DnCNN(D=args.D, C=args.C, image_mode=args.image_mode).to(device)
    adam = torch.optim.Adam(net.parameters(), lr=args.lr)
    statsManager = DenoisingStatsManager()

    exp = nt.Experiment(
        net,
        adam,
        statsManager,
        perform_validation_during_training=args.perform_validation,
        input_dir=args.input_dir,
        output_dir=args.output_dir,
        startEpoch=args.epoch,
        freq_save=args.freq_save,
    )

    if not args.test_mode:
        print("\n=>Training until epoch :<===\n", args.num_epochs)
        # The original literal was "\n\Model training": "\M" is an invalid
        # escape sequence that printed a stray backslash.
        print("\nModel training")

        trainData = TrainDataset(args.clean_train, args.noisy_train, args.image_mode, args.train_image_size, nb_rotation=args.nb_rotation)
        evalData  = EvalDataset(args.eval_dir, args.eval_noises, args.eval_patterns, args.image_mode, args.eval_image_size)

        exp.initData(trainData, evalData, batch_size=args.batch_size)
        exp.run(num_epochs=args.num_epochs)

        if args.graph:
            exp.trace()

    else:
        print("args.noisy_img : ", args.noisy_img)

        if args.noisy_img is None:
            # No explicit image: run both built-in evaluations.
            evaluate_on_HOLODEEP(args, exp)
            evaluate_on_DATAEVAL(args, exp)
        else:
            noisy = load_test_data(args.noisy_img, key = args.noisy_key, flipupdown=args.flip)
            clean = load_test_data(args.clean_img, key = args.clean_key, flipupdown=args.flip)

            denoise_img(args, noisy, clean, os.path.basename(args.noisy_img), exp)
223

if __name__ == '__main__':

    args = parse()

    # Echo every parsed option so a run can be reproduced from its log.
    print("\n\n")
    for option, value in vars(args).items():
        print(option + " : ", value)
    print("\n\n")

    # Fixed seed for reproducible runs.
    torch.manual_seed(123456789)

    start_time = time.time()
    run(args)
    print("Time elapsed : ", time.time() - start_time)