Skip to content

Commit

Permalink
Change imports to use sonnet
Browse files Browse the repository at this point in the history
  • Loading branch information
sergomezcol committed Apr 7, 2017
1 parent caa1448 commit 7ebea07
Show file tree
Hide file tree
Showing 18 changed files with 44 additions and 3,937 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# [Learning to Learn](https://arxiv.org/abs/1606.04474) in TensorFlow

Compatible with TensorFlow 1.0

## Dependencies

* [TensorFlow >=1.0](https://www.tensorflow.org/)
* [Sonnet >=1.0](https://github.com/deepmind/sonnet)


## Training
Expand Down
4 changes: 2 additions & 2 deletions meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
import os

import mock
import sonnet as snt
import tensorflow as tf

from tensorflow.python.framework import ops
from tensorflow.python.util import nest

import networks
import nn


def _nested_assign(ref, value):
Expand Down Expand Up @@ -379,7 +379,7 @@ def time_step(t, fx_array, x, state):
# Log internal variables.
for k, net in nets.items():
print("Optimizer '{}' variables".format(k))
print([op.name for op in nn.get_variables_in_module(net)])
print([op.name for op in snt.get_variables_in_module(net)])

return MetaLoss(loss, update, reset, fx_final, x_final)

Expand Down
10 changes: 5 additions & 5 deletions meta_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
from nose_parameterized import parameterized
import numpy as np
from six.moves import xrange
import sonnet as snt
import tensorflow as tf

import meta
import nn
import problems


Expand Down Expand Up @@ -141,10 +141,10 @@ def testConvolutional(self):
"""Tests L2L applied to problem with convolutions."""
kernel_shape = 4
def convolutional_problem():
conv = nn.Conv2D(output_channels=1,
kernel_shape=kernel_shape,
stride=1,
name="conv")
conv = snt.Conv2D(output_channels=1,
kernel_shape=kernel_shape,
stride=1,
name="conv")
output = conv(tf.random_normal((100, 100, 3, 10)))
return tf.reduce_sum(output)

Expand Down
18 changes: 9 additions & 9 deletions networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
import dill as pickle
import numpy as np
import six
import sonnet as snt
import tensorflow as tf

import nn
import preprocess


Expand All @@ -47,7 +47,7 @@ def factory(net, net_options=(), net_path=None):
def save(network, sess, filename=None):
"""Save the variables contained by a network to disk."""
to_save = collections.defaultdict(dict)
variables = nn.get_variables_in_module(network)
variables = snt.get_variables_in_module(network)

for v in variables:
split = v.name.split(":")[0].split("/")
Expand All @@ -63,7 +63,7 @@ def save(network, sess, filename=None):


@six.add_metaclass(abc.ABCMeta)
class Network(nn.RNNCore):
class Network(snt.RNNCore):
"""Base class for meta-optimizer networks."""

@abc.abstractmethod
Expand Down Expand Up @@ -166,8 +166,8 @@ def __init__(self, output_size, layers, preprocess_name="identity",
tf modules). Default is `tf.identity`.
preprocess_options: Gradient preprocessing options.
scale: Gradient scaling (default is 1.0).
initializer: Variable initializer for linear layer. See `nn.Linear` and
`nn.LSTM` docs for more info. This parameter can be a string (e.g.
initializer: Variable initializer for linear layer. See `snt.Linear` and
`snt.LSTM` docs for more info. This parameter can be a string (e.g.
"zeros" will be converted to tf.zeros_initializer).
name: Module name.
"""
Expand All @@ -188,12 +188,12 @@ def __init__(self, output_size, layers, preprocess_name="identity",
name = "lstm_{}".format(i)
init = _get_layer_initializers(initializer, name,
("w_gates", "b_gates"))
self._cores.append(nn.LSTM(size, name=name, initializers=init))
self._rnn = nn.DeepRNN(self._cores, skip_connections=False,
name="deep_rnn")
self._cores.append(snt.LSTM(size, name=name, initializers=init))
self._rnn = snt.DeepRNN(self._cores, skip_connections=False,
name="deep_rnn")

init = _get_layer_initializers(initializer, "linear", ("w", "b"))
self._linear = nn.Linear(output_size, name="linear", initializers=init)
self._linear = snt.Linear(output_size, name="linear", initializers=init)

def _build(self, inputs, prev_state):
"""Connects the `StandardDeepLSTM` module into the graph.
Expand Down
10 changes: 5 additions & 5 deletions networks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@

from nose_parameterized import parameterized
import numpy as np
import sonnet as snt
import tensorflow as tf

import networks
import nn


class CoordinateWiseDeepLSTMTest(tf.test.TestCase):
Expand All @@ -45,7 +45,7 @@ def testTrainable(self):
state = net.initial_state_for_inputs(gradients)
net(gradients, state)
# Weights and biases for two layers.
variables = nn.get_variables_in_module(net)
variables = snt.get_variables_in_module(net)
self.assertEqual(len(variables), 4)

@parameterized.expand([
Expand Down Expand Up @@ -90,7 +90,7 @@ def testTrainable(self):
state = net.initial_state_for_inputs(gradients)
net(gradients, state)
# Weights and biases for two layers.
variables = nn.get_variables_in_module(net)
variables = snt.get_variables_in_module(net)
self.assertEqual(len(variables), 4)

@parameterized.expand([
Expand Down Expand Up @@ -134,7 +134,7 @@ def testNonTrainable(self):
net = networks.Sgd()
state = net.initial_state_for_inputs(gradients)
net(gradients, state)
variables = nn.get_variables_in_module(net)
variables = snt.get_variables_in_module(net)
self.assertEqual(len(variables), 0)

def testResults(self):
Expand Down Expand Up @@ -169,7 +169,7 @@ def testNonTrainable(self):
net = networks.Adam()
state = net.initial_state_for_inputs(gradients)
net(gradients, state)
variables = nn.get_variables_in_module(net)
variables = snt.get_variables_in_module(net)
self.assertEqual(len(variables), 0)

def testZeroLearningRate(self):
Expand Down
44 changes: 0 additions & 44 deletions nn/__init__.py

This file was deleted.

Loading

0 comments on commit 7ebea07

Please sign in to comment.