forked from deepak112/Keras-SRGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNetwork.py
125 lines (86 loc) · 4.51 KB
/
Network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python
#title :Network.py
#description :Architecture file(Generator and Discriminator)
#author :Deepak Birla
#date :2018/10/30
#usage :from Network import Generator, Discriminator
#python_version :3.5.4
# Modules
from keras.layers import Dense
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.core import Flatten
from keras.layers import Input
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Model
from keras.layers.advanced_activations import LeakyReLU, PReLU
from keras.layers import add
# Residual block
def res_block_gen(model, kernal_size, filters, strides):
    """Append one generator residual block to ``model`` and return the result.

    The block is Conv2D -> BatchNormalization -> PReLU -> Conv2D ->
    BatchNormalization, whose output is summed element-wise with the
    block's input (skip connection).

    Args:
        model: Tensor the block is applied to; also used as the skip input.
        kernal_size: Kernel size passed to both Conv2D layers.
        filters: Number of filters passed to both Conv2D layers.
        strides: Strides passed to both Conv2D layers.

    Returns:
        The tensor produced by ``add([input, branch_output])``.

    NOTE(review): the skip addition presumably requires the branch output
    to have the same shape as the input, so ``strides`` other than 1 or a
    ``filters`` value different from the input's channel count would likely
    fail at ``add`` -- confirm before calling with other values.
    """
    # Both convolutions in the block share the same configuration.
    conv_kwargs = dict(
        filters=filters,
        kernel_size=kernal_size,
        strides=strides,
        padding="same",
    )

    skip = model

    branch = Conv2D(**conv_kwargs)(skip)
    branch = BatchNormalization(momentum=0.5)(branch)
    # Parametric ReLU; shared_axes=[1, 2] shares the learned slope across
    # those axes (the spatial axes if the data is channels-last -- assumed).
    branch = PReLU(
        alpha_initializer="zeros",
        alpha_regularizer=None,
        alpha_constraint=None,
        shared_axes=[1, 2],
    )(branch)
    branch = Conv2D(**conv_kwargs)(branch)
    branch = BatchNormalization(momentum=0.5)(branch)

    return add([skip, branch])
def up_sampling_block(model, kernal_size, filters, strides):
    """Append one upsampling block to ``model`` and return the result.

    The block is Conv2D -> UpSampling2D(size=2) -> LeakyReLU(alpha=0.2).

    As noted by the original author, a Conv2DTranspose layer (or a custom
    deconvolution function, e.g. the one said to live in Utils.py) could be
    used in place of the Conv2D + UpSampling2D pair.

    Args:
        model: Tensor the block is applied to.
        kernal_size: Kernel size passed to the Conv2D layer.
        filters: Number of filters passed to the Conv2D layer.
        strides: Strides passed to the Conv2D layer.

    Returns:
        The output tensor of the final LeakyReLU layer.
    """
    features = Conv2D(
        filters=filters,
        kernel_size=kernal_size,
        strides=strides,
        padding="same",
    )(model)
    enlarged = UpSampling2D(size=2)(features)
    return LeakyReLU(alpha=0.2)(enlarged)
def discriminator_block(model, filters, kernel_size, strides):
    """Append one discriminator block to ``model`` and return the result.

    The block is Conv2D -> BatchNormalization(momentum=0.5) ->
    LeakyReLU(alpha=0.2).

    Args:
        model: Tensor the block is applied to.
        filters: Number of filters passed to the Conv2D layer.
        kernel_size: Kernel size passed to the Conv2D layer.
        strides: Strides passed to the Conv2D layer.

    Returns:
        The output tensor of the LeakyReLU layer.
    """
    # Layers are created in the order they are applied.
    stages = (
        Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
        ),
        BatchNormalization(momentum=0.5),
        LeakyReLU(alpha=0.2),
    )

    output = model
    for stage in stages:
        output = stage(output)
    return output
# The network architecture follows the SRGAN paper: https://arxiv.org/pdf/1609.04802.pdf
class Generator(object):
    """Builder for the generator network.

    The network is: Conv2D(64, k=9) -> PReLU -> ``num_res_blocks`` residual
    blocks -> Conv2D(64, k=3) -> BatchNormalization -> skip-add with the
    output of the first PReLU -> ``num_upsampling_blocks`` upsampling blocks
    -> Conv2D(3, k=9) -> tanh.

    Args:
        noise_shape: Shape passed to ``Input(shape=...)`` for the generator
            input.
        num_res_blocks: Number of residual blocks (``res_block_gen`` calls).
            Defaults to 16, the value that was previously hard-coded.
        num_upsampling_blocks: Number of upsampling blocks
            (``up_sampling_block`` calls). Defaults to 2, the value that was
            previously hard-coded.

    Raises:
        ValueError: If either block count is negative.
    """

    def __init__(self, noise_shape, num_res_blocks=16, num_upsampling_blocks=2):
        if num_res_blocks < 0 or num_upsampling_blocks < 0:
            raise ValueError(
                "num_res_blocks and num_upsampling_blocks must be non-negative"
            )
        self.noise_shape = noise_shape
        self.num_res_blocks = num_res_blocks
        self.num_upsampling_blocks = num_upsampling_blocks

    def generator(self):
        """Build and return the generator as a Keras ``Model``.

        Returns:
            A ``Model`` whose input is ``Input(shape=self.noise_shape)`` and
            whose output is the tanh-activated 3-filter convolution.
        """
        gen_input = Input(shape=self.noise_shape)

        model = Conv2D(filters=64, kernel_size=9, strides=1, padding="same")(gen_input)
        model = PReLU(
            alpha_initializer="zeros",
            alpha_regularizer=None,
            alpha_constraint=None,
            shared_axes=[1, 2],
        )(model)

        # Kept for the long skip connection around the residual stack.
        gen_model = model

        # Residual blocks: kernel size 3, 64 filters, stride 1.
        for _ in range(self.num_res_blocks):
            model = res_block_gen(model, 3, 64, 1)

        model = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(model)
        model = BatchNormalization(momentum=0.5)(model)
        model = add([gen_model, model])

        # Upsampling blocks: kernel size 3, 256 filters, stride 1.
        for _ in range(self.num_upsampling_blocks):
            model = up_sampling_block(model, 3, 256, 1)

        model = Conv2D(filters=3, kernel_size=9, strides=1, padding="same")(model)
        model = Activation("tanh")(model)

        generator_model = Model(inputs=gen_input, outputs=model)
        return generator_model
# The network architecture follows the SRGAN paper: https://arxiv.org/pdf/1609.04802.pdf
class Discriminator(object):
    """Builder for the discriminator network.

    The network is: Conv2D(64, k=3, s=1) -> LeakyReLU -> seven
    ``discriminator_block`` stages -> Flatten -> Dense(1024) -> LeakyReLU ->
    Dense(1) -> sigmoid.

    Args:
        image_shape: Shape passed to ``Input(shape=...)`` for the
            discriminator input.
    """

    def __init__(self, image_shape):
        self.image_shape = image_shape

    def discriminator(self):
        """Build and return the discriminator as a Keras ``Model``.

        Returns:
            A ``Model`` whose input is ``Input(shape=self.image_shape)`` and
            whose output is the sigmoid-activated single-unit Dense layer.
        """
        dis_input = Input(shape=self.image_shape)

        features = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(dis_input)
        features = LeakyReLU(alpha=0.2)(features)

        # (filters, strides) for each discriminator block, in order;
        # every block uses kernel size 3.
        block_specs = (
            (64, 2),
            (128, 1),
            (128, 2),
            (256, 1),
            (256, 2),
            (512, 1),
            (512, 2),
        )
        for n_filters, stride in block_specs:
            features = discriminator_block(features, n_filters, 3, stride)

        features = Flatten()(features)
        features = Dense(1024)(features)
        features = LeakyReLU(alpha=0.2)(features)

        features = Dense(1)(features)
        validity = Activation("sigmoid")(features)

        return Model(inputs=dis_input, outputs=validity)