-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsubpixel.py
99 lines (86 loc) · 3.56 KB
/
subpixel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#subpixel.py
import tensorflow as tf
from keras import backend as K
from tensorflow.keras.layers import Lambda, Conv2D
def SubpixelConv2D(input_shape, scale=4, idx=0):
    """
    Build a Keras ``Lambda`` layer performing sub-pixel (pixel-shuffle)
    upsampling via ``tf.nn.depth_to_space``.

    NOTE: TensorFlow backend only.

    Ref:
        [1] Real-Time Single Image and Video Super-Resolution Using an
            Efficient Sub-Pixel Convolutional Neural Network
            Shi et al., https://arxiv.org/abs/1609.05158

    :param input_shape: tensor shape, (batch, height, width, channel)
        (kept for interface compatibility; not used directly — the output
        shape is derived from the shape passed in by Keras at build time)
    :param scale: upsampling scale. Default=4
    :param idx: integer suffix used to give the layer a unique name
    :return: a ``Lambda`` layer named ``subpixel_<idx>``
    """
    def compute_output_shape(shape):
        # Height/width grow by `scale`; channels shrink by `scale**2`.
        batch, height, width, channels = shape[0], shape[1], shape[2], shape[3]
        return (batch,
                height * scale,
                width * scale,
                int(channels / (scale ** 2)))

    def pixel_shuffle(tensor):
        # depth_to_space rearranges channel blocks into spatial positions.
        return tf.nn.depth_to_space(tensor, scale)

    return Lambda(pixel_shuffle,
                  output_shape=compute_output_shape,
                  name='subpixel_{}'.format(idx))
class Subpixel(Conv2D):
    """
    Conv2D layer followed by a phase-shift (sub-pixel / pixel-shuffle)
    upsampling by an integer factor ``r``.

    Internally the convolution produces ``r*r*filters`` channels, which the
    phase shift then rearranges into an output that is ``r`` times larger in
    both height and width with ``filters`` channels.

    :param filters: number of output channels AFTER the phase shift
    :param kernel_size: convolution kernel size
    :param r: integer upsampling factor
    Remaining parameters are forwarded unchanged to ``Conv2D``.
    """
    def __init__(self,
                 filters,
                 kernel_size,
                 r,
                 padding='valid',
                 data_format=None,
                 strides=(1, 1),
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        # The underlying convolution must emit r*r times as many channels so
        # that the phase shift can trade depth for spatial resolution.
        super(Subpixel, self).__init__(
            filters=r * r * filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs)
        self.r = r

    def _phase_shift(self, I):
        """Rearrange (bsize, a, b, c) -> (bsize, a*r, b*r, c/(r*r))."""
        r = self.r
        bsize, a, b, c = I.get_shape().as_list()
        bsize = K.shape(I)[0]  # Handling Dimension(None) type for undefined batch dim
        # BUGFIX: use integer division — `c / (r*r)` is a float in Python 3
        # and K.reshape rejects non-integer dimensions.
        X = K.reshape(I, [bsize, a, b, c // (r * r), r, r])  # bsize, a, b, c/(r*r), r, r
        X = K.permute_dimensions(X, (0, 1, 2, 5, 4, 3))      # bsize, a, b, r, r, c/(r*r)
        # Keras backend does not support tf.split, so in future versions this could be nicer
        X = [X[:, i, :, :, :, :] for i in range(a)]          # a, [bsize, b, r, r, c/(r*r)]
        X = K.concatenate(X, 2)                              # bsize, b, a*r, r, c/(r*r)
        X = [X[:, i, :, :, :] for i in range(b)]             # b, [bsize, a*r, r, c/(r*r)]
        X = K.concatenate(X, 2)                              # bsize, a*r, b*r, c/(r*r)
        return X

    def call(self, inputs):
        # Convolve first (Conv2D.call), then phase-shift the channel blocks
        # into spatial positions.
        return self._phase_shift(super(Subpixel, self).call(inputs))

    def compute_output_shape(self, input_shape):
        unshifted = super(Subpixel, self).compute_output_shape(input_shape)
        # BUGFIX: integer division — a float channel count is not a valid
        # Keras shape dimension in Python 3.
        return (unshifted[0],
                self.r * unshifted[1],
                self.r * unshifted[2],
                unshifted[3] // (self.r * self.r))

    def get_config(self):
        # Skip Conv2D's own get_config (super(Conv2D, self)) so we can strip
        # the keys Subpixel's __init__ does not accept.
        config = super(Conv2D, self).get_config()
        config.pop('rank')
        config.pop('dilation_rate')
        # BUGFIX: integer division — `/=` would serialize `filters` as a
        # float, and __init__ would then re-multiply a float on reload.
        config['filters'] //= self.r * self.r
        config['r'] = self.r
        return config