model.py
import numpy as np
import tensorflow as tf


def attention(tensor, params):
    """Attention model for the grid-world domain.

    Picks out, for each sample in the batch, the Q-values at the agent's
    (S1, S2) grid position.
    """
    S1, S2 = params
    # Flatten the position coordinates
    s1 = tf.reshape(S1, [-1])
    s2 = tf.reshape(S2, [-1])
    # Indices for slicing: one (batch, row, column) triple per sample
    N = tf.shape(tensor)[0]
    idx = tf.stack([tf.range(N), s1, s2], axis=1)
    # Slice out the attended Q-values
    q_out = tf.gather_nd(tensor, idx, name='q_out')
    return q_out
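

# Note: for a batch of N samples, `attention` returns a [N, ch_q] tensor where
# q_out[i, :] == tensor[i, S1[i], S2[i], :].
# Rough NumPy equivalent (illustration only, not part of the TensorFlow graph):
#   q_out = np.stack([tensor[i, s1[i], s2[i]] for i in range(len(tensor))])
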
def VIN(X, S1, S2, args):
    """Value Iteration Network for the 2D grid-world domain."""
    k = args.k        # Number of Value Iteration computations
    ch_i = args.ch_i  # Channels in input layer
    ch_h = args.ch_h  # Channels in initial hidden layer
    ch_q = args.ch_q  # Channels in q layer (~actions)
    # Initial hidden layer: 3x3 convolution over the input observation
    h = tf.layers.conv2d(inputs=X,
                         filters=ch_h,
                         kernel_size=[3, 3],
                         strides=[1, 1],
                         padding='same',
                         activation=None,
                         use_bias=True,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                         bias_initializer=tf.zeros_initializer(),
                         name='h0',
                         reuse=None)
    # Reward map: single-channel 3x3 convolution over the hidden features
    r = tf.layers.conv2d(inputs=h,
                         filters=1,
                         kernel_size=[3, 3],
                         strides=[1, 1],
                         padding='same',
                         activation=None,
                         use_bias=False,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                         bias_initializer=None,
                         name='r',
                         reuse=None)
    # Add collection of reward image
    tf.add_to_collection('r', r)
    # Initialize value map (zero everywhere)
    v = tf.zeros_like(r)
    rv = tf.concat([r, v], axis=3)
    q = tf.layers.conv2d(inputs=rv,
                         filters=ch_q,
                         kernel_size=[3, 3],
                         strides=[1, 1],
                         padding='same',
                         activation=None,
                         use_bias=False,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                         bias_initializer=None,
                         name='q',
                         reuse=None)  # Initial pass before sharing weights
    v = tf.reduce_max(q, axis=3, keep_dims=True, name='v')
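    # Value-iteration recurrence unrolled below, reusing the same 'q' convolution
    # weights on every pass (reuse=True):
    #   Q_n = conv3x3([R, V_{n-1}])   -- stacked reward and value maps
    #   V_n(s) = max_a Q_n(s, a)      -- max over the ch_q action channels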
    # K iterations of VI module
    for i in range(0, k - 1):
        rv = tf.concat([r, v], axis=3)
        q = tf.layers.conv2d(inputs=rv,
                             filters=ch_q,
                             kernel_size=[3, 3],
                             strides=[1, 1],
                             padding='same',
                             activation=None,
                             use_bias=False,
                             kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                             bias_initializer=None,
                             name='q',
                             reuse=True)  # Sharing weights
        v = tf.reduce_max(q, axis=3, keep_dims=True, name='v')
    # Add collection of value images
    tf.add_to_collection('v', v)
    # Do one last convolution
    rv = tf.concat([r, v], axis=3)
    q = tf.layers.conv2d(inputs=rv,
                         filters=ch_q,
                         kernel_size=[3, 3],
                         strides=[1, 1],
                         padding='same',
                         activation=None,
                         use_bias=False,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                         bias_initializer=None,
                         name='q',
                         reuse=True)  # Sharing weights
    # Attention model: select the Q-values at the agent's (S1, S2) position
    q_out = attention(tensor=q, params=[S1, S2])
    # Final fully connected layer mapping the attended Q-values to one logit per action
    logits = tf.layers.dense(inputs=q_out,
                             units=8,
                             activation=None,
                             use_bias=False,
                             kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                             name='logits')
    prob_actions = tf.nn.softmax(logits, name='probability_actions')
    return logits, prob_actions
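

# --- Minimal usage sketch (not part of the original model definition) ---
# Illustrative only, assuming TensorFlow 1.x: the grid size and the hyperparameter
# values below (k, ch_i, ch_h, ch_q) are assumptions for a small grid world, not
# values taken from this repository.
if __name__ == '__main__':
    from argparse import Namespace

    config = Namespace(k=10, ch_i=2, ch_h=150, ch_q=10)  # assumed hyperparameters
    # Assumed input: a 2-channel 8x8 grid image (e.g. obstacle map and goal map)
    X = tf.placeholder(tf.float32, shape=[None, 8, 8, config.ch_i], name='X')
    S1 = tf.placeholder(tf.int32, shape=[None], name='S1')  # agent row indices
    S2 = tf.placeholder(tf.int32, shape=[None], name='S2')  # agent column indices
    logits, prob_actions = VIN(X, S1, S2, config)
    print(logits)        # float32 tensor of shape (?, 8): one logit per action
    print(prob_actions)  # float32 tensor of shape (?, 8): softmax over actions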