-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtest_real_denoising_dnd.py
94 lines (74 loc) · 3.22 KB
/
test_real_denoising_dnd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# The implementation builds on Restormer code https://github.com/swz30/Restormer
import numpy as np
import os
import argparse
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.archs.cat_unet_arch import CAT_Unet
from skimage import img_as_ubyte
import cv2
import h5py
import scipy.io as sio
parser = argparse.ArgumentParser(description='Real Image Denoising using CAT')
parser.add_argument('--input_dir', default='datasets/real-DN/DND/', type=str, help='Directory of validation images')
parser.add_argument('--result_dir', default='results/test_CAT_Real_DN/DND/', type=str, help='Directory for results')
parser.add_argument('--weights', default='experiments/pretrained_models/Real-DN/Real_DN_CAT.pth', type=str, help='Path to weights')
parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory')
args = parser.parse_args()
####### Load yaml #######
yaml_file = 'options/test/test_CAT_RealDenoising.yml'
import yaml
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader
x = yaml.load(open(yaml_file, mode='r'), Loader=Loader)
s = x['network_g'].pop('type')
##########################
result_dir_mat = os.path.join(args.result_dir, 'mat')
os.makedirs(result_dir_mat, exist_ok=True)
if args.save_images:
result_dir_png = os.path.join(args.result_dir, 'png')
os.makedirs(result_dir_png, exist_ok=True)
model_restoration = CAT_Unet(**x['network_g'])
checkpoint = torch.load(args.weights)
model_restoration.load_state_dict(checkpoint['params'])
print("===>Testing using weights: ",args.weights)
model_restoration.cuda()
model_restoration = nn.DataParallel(model_restoration)
model_restoration.eval()
israw = False
eval_version="1.0"
# Load info
infos = h5py.File(os.path.join(args.input_dir, 'info.mat'), 'r')
info = infos['info']
bb = info['boundingboxes']
# Process data
with torch.no_grad():
for i in tqdm(range(50)):
Idenoised = np.zeros((20,), dtype=object)
filename = '%04d.mat'%(i+1)
filepath = os.path.join(args.input_dir, 'images_srgb', filename)
img = h5py.File(filepath, 'r')
Inoisy = np.float32(np.array(img['InoisySRGB']).T)
# bounding box
ref = bb[0][i]
boxes = np.array(info[ref]).T
for k in range(20):
idx = [int(boxes[k,0]-1),int(boxes[k,2]),int(boxes[k,1]-1),int(boxes[k,3])]
noisy_patch = torch.from_numpy(Inoisy[idx[0]:idx[1],idx[2]:idx[3],:]).unsqueeze(0).permute(0,3,1,2).cuda()
restored_patch = model_restoration(noisy_patch)
restored_patch = torch.clamp(restored_patch,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
Idenoised[k] = restored_patch
if args.save_images:
save_file = os.path.join(result_dir_png, '%04d_%02d.png'%(i+1,k+1))
denoised_img = img_as_ubyte(restored_patch)
cv2.imwrite(save_file, cv2.cvtColor(img_as_ubyte(denoised_img), cv2.COLOR_RGB2BGR))
# save denoised data
sio.savemat(os.path.join(result_dir_mat, filename),
{"Idenoised": Idenoised,
"israw": israw,
"eval_version": eval_version},
)