-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo.py
69 lines (51 loc) · 1.95 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import sys
import torch
# Pick the compute device once; tensors and the model are moved to DEVICE later.
gpu_available = torch.cuda.is_available()
if gpu_available:
    DEVICE = 'cuda'
    # Make CUDA float tensors the default tensor type.
    # NOTE(review): torch.set_default_tensor_type is marked deprecated in
    # recent PyTorch releases; consider torch.set_default_device once the
    # minimum supported PyTorch version is confirmed.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    DEVICE = 'cpu'
# Add the 'DPIR' directory that sits next to this file to the import path
# (presumably required by the libs.* imports that follow).
sys.path.append(os.path.join(os.path.dirname(__file__), 'DPIR'))
import cv2
import matplotlib.pyplot as plt
from libs.pnp import admm, deep_denoiser, modulo
from libs.utils import load_img, load_model
def normalize(x: torch.Tensor) -> torch.Tensor:
    """Rescale every 2-D plane of ``x`` to the range [0, 1].

    The minimum and maximum are taken over the last two dimensions only, so
    each leading index (for a 4-D input: each batch item and channel) is
    scaled independently of the others.

    Args:
        x: Tensor with at least two dimensions.

    Returns:
        Tensor shaped like ``x`` where each plane has been mapped so that its
        minimum becomes 0 and its maximum becomes 1. A constant plane
        (max == min) maps to all zeros instead of NaN.
    """
    lo = torch.amin(x, dim=(-2, -1), keepdim=True)
    hi = torch.amax(x, dim=(-2, -1), keepdim=True)
    span = hi - lo
    # A flat plane has zero span; dividing by it would produce NaN (0 / 0),
    # so substitute 1 there and let the numerator (all zeros) decide.
    span = torch.where(span == 0, torch.ones_like(span), span)
    return (x - lo) / span
def tensor2img(x: torch.Tensor):
    """Convert a 4-D (batch, channel, height, width) tensor to a uint8 array.

    The tensor is reordered to channel-last, all size-1 dimensions are
    dropped, and the values are multiplied by 255 and cast to ``uint8``.
    No clipping or rounding is applied, so the input is expected to lie in
    [0, 1] already (callers in this file normalise first).

    Args:
        x: 4-D tensor on any device; may be attached to an autograd graph.

    Returns:
        NumPy ``uint8`` array; for a single multi-channel image this is
        (height, width, channel).
    """
    # detach() first: Tensor.numpy() refuses tensors that require grad,
    # which is what a model output looks like outside torch.no_grad().
    channel_last = x.detach().permute(0, 2, 3, 1).squeeze().cpu().numpy()
    return (channel_last * 255).astype('uint8')
def img2tensor(x):
    """Turn a (height, width, channel) image into a float32 (1, channel, height, width) tensor."""
    channel_first = torch.tensor(x, dtype=torch.float32).permute(2, 0, 1)
    # Prepend a batch dimension of size 1.
    return channel_first.unsqueeze(0)
def centering(x):
    """Subtract from each 2-D plane of ``x`` its own mean over the last two dimensions."""
    plane_mean = x.mean(dim=(-1, -2), keepdim=True)
    return x - plane_mean
# Image parameters
# Size argument handed to load_img when the input image is read.
img_size = 1024
# Divisor applied to the loaded image tensor before the modulo (wrapping)
# step, which itself uses a modulus of 1.0.
wrapping_threshold = 64
# Recovery parameters, all forwarded to admm() further down.
# NOTE(review): the exact meaning of each value is defined by libs.pnp.admm;
# confirm there before tuning.
max_iters = 5
epsilon = 0.1
# Passed to admm as `_lambda * epsilon`, not as this raw value.
_lambda = 0.1
gamma = 1.1
# Read the test image and scale it down by the wrapping threshold.
img_path = os.path.join(".", "data", "kodim23.png")
img = load_img(img_path, img_size)
img_t = img2tensor(img)
img_t = img_t / wrapping_threshold
# Build the wrapped measurement: perturb with random noise scaled by 0.1,
# wrap with modulus 1.0, then move the result to the compute device.
modulo_t = img_t + 0.1 * torch.rand_like(img_t)
modulo_t = modulo(modulo_t, 1.0)
modulo_t = modulo_t.to(DEVICE)
# From here on img_t is the mean-centred clean image, kept on DEVICE
# (it is only used for display below).
img_t = centering(img_t)
img_t = img_t.to(DEVICE)
# Unwrapping: load the denoiser model onto DEVICE and run ADMM on the
# wrapped measurement.
model_name = 'drunet_color'
model_pool = 'model_zoo'
model = load_model(model_name, model_pool).to(DEVICE)
img_est = admm(
    modulo_t,
    deep_denoiser,
    model,
    max_iters=max_iters,
    epsilon=epsilon,
    _lambda=_lambda * epsilon,
    gamma=gamma,
)
# Visualize: one row of three panels, each normalised to [0, 1] per plane
# before conversion to a displayable uint8 image.
panels = (
    ("Original Image", img_t),
    ("Modulo Image", modulo_t),
    ("Recovered Image", img_est),
)
plt.figure(figsize=(10, 10))
for position, (caption, tensor) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(tensor2img(normalize(tensor)))
    plt.title(caption)
    plt.axis('off')
plt.show()