-
Notifications
You must be signed in to change notification settings - Fork 129
/
Copy pathind_rnn_cell_test.py
59 lines (49 loc) · 2.37 KB
/
ind_rnn_cell_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Tests for the IndRNN cell."""
import numpy as np
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from ind_rnn_cell import IndRNNCell
class IndRNNCellTest(test.TestCase):
  """Single-step graph-mode tests for `IndRNNCell`."""

  def _run_single_step(self, recurrent_weights, **cell_kwargs):
    """Builds a 4-unit IndRNNCell, runs one step and returns its output.

    The input kernel is initialized to all ones and the activation is the
    identity, so the returned values are the cell's (pre)activations. The
    step is fed the input [[1, 0, 0, 0]] and the previous state
    [[2, 2, 2, 2]]. With an all-ones input kernel, each unit therefore
    receives 1*1 from the input, plus 2 * its recurrent weight (assuming the
    cell's bias, not visible here, starts at zero — TODO confirm).

    Args:
      recurrent_weights: Values for the constant initializer of the
        recurrent kernel, one per unit.
      **cell_kwargs: Extra keyword arguments forwarded unchanged to
        `IndRNNCell`, e.g. `recurrent_min_abs` / `recurrent_max_abs`.

    Returns:
      The evaluated cell output (first element of the `sess.run` result).
    """
    with self.test_session() as sess:
      # Placeholder-like tensors; their real values are supplied via the
      # feed dict below (TF1 allows feeding any feedable tensor by name).
      x = array_ops.zeros([1, 4])
      m = array_ops.zeros([1, 4])
      cell = IndRNNCell(
          num_units=4,
          recurrent_kernel_initializer=init_ops.constant_initializer(
              recurrent_weights),
          input_kernel_initializer=init_ops.constant_initializer(1.),
          activation=array_ops.identity,
          **cell_kwargs)
      output, _ = cell(x, m)
      sess.run([variables.global_variables_initializer()])
      res = sess.run([output],
                     {x.name: np.array([[1., 0., 0., 0.]]),
                      m.name: np.array([[2., 2., 2., 2.]])})
      return res[0]

  def testIndRNNCell(self):
    """Tests basic cell functionality."""
    res = self._run_single_step([-3., -2., 1., 3.])
    # (Pre)activations (1*1 + 2*rec_weight) should be -5, -3, 3, 7
    self.assertAllEqual(res, [[-5., -3., 3., 7.]])

  def testIndRNNCellBounds(self):
    """Tests cell with recurrent weights exceeding the bounds."""
    res = self._run_single_step([-5., -2., 0.1, 5.],
                                recurrent_min_abs=1.,
                                recurrent_max_abs=3.)
    # Recurrent weights should be clipped to -3, -2, 1, 3, so the
    # (pre)activations (1*1 + 2*rec_weight) should again be -5, -3, 3, 7
    self.assertAllEqual(res, [[-5., -3., 3., 7.]])