Flexible LSTM Architectures. (#139)
* Remove n_lstm parameter from the BasePolicy and ActorCritic class.

* Initial draft for more flexible lstm architectures.

* Fix bug in value network construction.

* Fix formatting.

* Add tests for flexible lstm policy architectures.

* Add documentation.

* Also test old LSTM Policy.

* fix formatting stable_baselines/common/policies.py

Co-Authored-By: erniejunior <[email protected]>

* fix typo in stable_baselines/common/policies.py

Co-Authored-By: erniejunior <[email protected]>

* explain magic numbers in stable_baselines/common/policies.py

Co-Authored-By: erniejunior <[email protected]>

* Explain that 'lstm' entry is mandatory and the LSTM is shared between value and policy network.

* update changelog

* Fix indentation error.
ernestum authored and araffin committed Jan 5, 2019
1 parent f2890e0 commit 596a5c4
Showing 7 changed files with 162 additions and 55 deletions.
20 changes: 17 additions & 3 deletions docs/guide/custom_policy.rst
@@ -128,8 +128,23 @@ Initially shared then diverging: ``[128, dict(vf=[256], pi=[16])]``
| |
action value
The ``LstmPolicy`` can be used to construct recurrent policies in a similar way:

If however, your task requires a more granular control over the policy architecture, you can redefine the policy directly:
.. code-block:: python

    class CustomLSTMPolicy(LstmPolicy):
        def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=64, reuse=False, **_kwargs):
            super().__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
                             net_arch=[8, 'lstm', dict(vf=[5, 10], pi=[10])],
                             layer_norm=True, feature_extraction="mlp", **_kwargs)
Here the ``net_arch`` parameter takes an additional (mandatory) 'lstm' entry within the shared network section.
The LSTM is shared between value network and policy network.
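
As a usage sketch (not part of this commit), such a custom recurrent policy can be passed directly to a recurrent-capable algorithm such as PPO2. The environment, the number of parallel copies, and the ``nminibatches`` value below are illustrative assumptions:

.. code-block:: python

    import gym

    from stable_baselines import PPO2
    from stable_baselines.common.vec_env import DummyVecEnv

    # Recurrent policies require a vectorized environment; four parallel copies is an arbitrary choice.
    env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])

    # For LSTM policies the number of environments should be a multiple of nminibatches.
    model = PPO2(CustomLSTMPolicy, env, nminibatches=4, verbose=1)
    model.learn(total_timesteps=10000)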




If your task requires even more granular control over the policy architecture, you can redefine the policy directly:

.. code-block:: python
@@ -144,8 +159,7 @@ If however, your task requires a more granular control over the policy architect
# with a nature_cnn feature extractor
class CustomPolicy(ActorCriticPolicy):
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, **kwargs):
super(CustomPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256,
reuse=reuse, scale=True)
super(CustomPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse, scale=True)
with tf.variable_scope("model", reuse=reuse):
activ = tf.nn.relu
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
@@ -14,6 +14,7 @@ Pre-Release 2.4.0a (WIP)
- fixed bug related to shape of true_reward (@abhiskk)
- fixed example code in documentation of tf_util:Function (@JohannesAck)
- added learning rate schedule for SAC
- added more flexible custom LSTM policies

Release 2.3.0 (2018-12-05)
--------------------------
137 changes: 104 additions & 33 deletions stable_baselines/common/policies.py
@@ -97,15 +97,14 @@ class BasePolicy(ABC):
:param n_env: (int) The number of environments to run
:param n_steps: (int) The number of steps to run for each environment
:param n_batch: (int) The number of batch to run (n_envs * n_steps)
:param n_lstm: (int) The number of LSTM cells (for recurrent policies)
:param reuse: (bool) If the policy is reusable or not
:param scale: (bool) whether or not to scale the input
:param obs_phs: (TensorFlow Tensor, TensorFlow Tensor) a tuple containing an override for observation placeholder
and the processed observation placeholder respectively
:param add_action_ph: (bool) whether or not to create an action placeholder
"""

def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, scale=False,
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, scale=False,
obs_phs=None, add_action_ph=False):
self.n_env = n_env
self.n_steps = n_steps
@@ -114,8 +113,7 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256
self.obs_ph, self.processed_obs = observation_input(ob_space, n_batch, scale=scale)
else:
self.obs_ph, self.processed_obs = obs_phs
self.masks_ph = tf.placeholder(tf.float32, [n_batch], name="masks_ph") # mask (done t-1)
self.states_ph = tf.placeholder(tf.float32, [self.n_env, n_lstm * 2], name="states_ph") # states

self.action_ph = None
if add_action_ph:
self.action_ph = tf.placeholder(dtype=ac_space.dtype, shape=(None,) + ac_space.shape, name="action_ph")
@@ -157,14 +155,13 @@ class ActorCriticPolicy(BasePolicy):
:param n_env: (int) The number of environments to run
:param n_steps: (int) The number of steps to run for each environment
:param n_batch: (int) The number of batch to run (n_envs * n_steps)
:param n_lstm: (int) The number of LSTM cells (for recurrent policies)
:param reuse: (bool) If the policy is reusable or not
:param scale: (bool) whether or not to scale the input
"""

def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, scale=False):
super(ActorCriticPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=n_lstm,
reuse=reuse, scale=scale)
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, scale=False):
super(ActorCriticPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse,
scale=scale)
self.pdtype = make_proba_dist_type(ac_space)
self.is_discrete = isinstance(ac_space, Discrete)
self.policy = None
@@ -235,40 +232,114 @@ class LstmPolicy(ActorCriticPolicy):
:param n_lstm: (int) The number of LSTM cells (for recurrent policies)
:param reuse: (bool) If the policy is reusable or not
:param layers: ([int]) The size of the Neural network before the LSTM layer (if None, default to [64, 64])
:param net_arch: (list) Specification of the actor-critic policy network architecture. Notation similar to the
format described in mlp_extractor but with additional support for a 'lstm' entry in the shared network part.
:param cnn_extractor: (function (TensorFlow Tensor, ``**kwargs``): (TensorFlow Tensor)) the CNN feature extraction
:param layer_norm: (bool) Whether or not to use layer normalizing LSTMs
:param feature_extraction: (str) The feature extraction type ("cnn" or "mlp")
:param kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
"""

def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, layers=None,
cnn_extractor=nature_cnn, layer_norm=False, feature_extraction="cnn", **kwargs):
super(LstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
net_arch=None, act_fun=tf.tanh, cnn_extractor=nature_cnn, layer_norm=False, feature_extraction="cnn",
**kwargs):
super(LstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse,
scale=(feature_extraction == "cnn"))

if layers is None:
layers = [64, 64]

with tf.variable_scope("model", reuse=reuse):
if feature_extraction == "cnn":
extracted_features = cnn_extractor(self.processed_obs, **kwargs)
else:
activ = tf.tanh
extracted_features = tf.layers.flatten(self.processed_obs)
for i, layer_size in enumerate(layers):
extracted_features = activ(linear(extracted_features, 'pi_fc' + str(i), n_hidden=layer_size,
init_scale=np.sqrt(2)))
input_sequence = batch_to_seq(extracted_features, self.n_env, n_steps)
masks = batch_to_seq(self.masks_ph, self.n_env, n_steps)
rnn_output, self.snew = lstm(input_sequence, masks, self.states_ph, 'lstm1', n_hidden=n_lstm,
layer_norm=layer_norm)
rnn_output = seq_to_batch(rnn_output)
value_fn = linear(rnn_output, 'vf', 1)
with tf.variable_scope("input", reuse=True):
self.masks_ph = tf.placeholder(tf.float32, [n_batch], name="masks_ph") # mask (done t-1)
# n_lstm * 2 dim because of the cell and hidden states of the LSTM
self.states_ph = tf.placeholder(tf.float32, [self.n_env, n_lstm * 2], name="states_ph") # states

self.proba_distribution, self.policy, self.q_value = \
self.pdtype.proba_distribution_from_latent(rnn_output, rnn_output)
if net_arch is None: # Legacy mode
warnings.warn("The layers parameter is deprecated. Use the net_arch parameter instead.")
if layers is None:
layers = [64, 64]

self.value_fn = value_fn
with tf.variable_scope("model", reuse=reuse):
if feature_extraction == "cnn":
extracted_features = cnn_extractor(self.processed_obs, **kwargs)
else:
extracted_features = tf.layers.flatten(self.processed_obs)
for i, layer_size in enumerate(layers):
extracted_features = act_fun(linear(extracted_features, 'pi_fc' + str(i), n_hidden=layer_size,
init_scale=np.sqrt(2)))
input_sequence = batch_to_seq(extracted_features, self.n_env, n_steps)
masks = batch_to_seq(self.masks_ph, self.n_env, n_steps)
rnn_output, self.snew = lstm(input_sequence, masks, self.states_ph, 'lstm1', n_hidden=n_lstm,
layer_norm=layer_norm)
rnn_output = seq_to_batch(rnn_output)
value_fn = linear(rnn_output, 'vf', 1)

self.proba_distribution, self.policy, self.q_value = \
self.pdtype.proba_distribution_from_latent(rnn_output, rnn_output)

self.value_fn = value_fn
else: # Use the new net_arch parameter
if layers is not None:
warnings.warn("The new net_arch parameter overrides the deprecated layers parameter.")
if feature_extraction == "cnn":
raise NotImplementedError()

with tf.variable_scope("model", reuse=reuse):
latent = tf.layers.flatten(self.processed_obs)
policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network
value_only_layers = [] # Layer sizes of the network that only belongs to the value network

# Iterate through the shared layers and build the shared parts of the network
lstm_layer_constructed = False
for idx, layer in enumerate(net_arch):
if isinstance(layer, int): # Check that this is a shared layer
layer_size = layer
latent = act_fun(linear(latent, "shared_fc{}".format(idx), layer_size, init_scale=np.sqrt(2)))
elif layer == "lstm":
if lstm_layer_constructed:
raise ValueError("The net_arch parameter must only contain one occurrence of 'lstm'!")
input_sequence = batch_to_seq(latent, self.n_env, n_steps)
masks = batch_to_seq(self.masks_ph, self.n_env, n_steps)
rnn_output, self.snew = lstm(input_sequence, masks, self.states_ph, 'lstm1', n_hidden=n_lstm,
layer_norm=layer_norm)
latent = seq_to_batch(rnn_output)
lstm_layer_constructed = True
else:
assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts"
if 'pi' in layer:
assert isinstance(layer['pi'],
list), "Error: net_arch[-1]['pi'] must contain a list of integers."
policy_only_layers = layer['pi']

if 'vf' in layer:
assert isinstance(layer['vf'],
list), "Error: net_arch[-1]['vf'] must contain a list of integers."
value_only_layers = layer['vf']
break # From here on the network splits up in policy and value network

# Build the non-shared part of the policy-network
latent_policy = latent
for idx, pi_layer_size in enumerate(policy_only_layers):
if pi_layer_size == "lstm":
raise NotImplementedError("LSTMs are only supported in the shared part of the policy network.")
assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
latent_policy = act_fun(
linear(latent_policy, "pi_fc{}".format(idx), pi_layer_size, init_scale=np.sqrt(2)))

# Build the non-shared part of the value-network
latent_value = latent
for idx, vf_layer_size in enumerate(value_only_layers):
if vf_layer_size == "lstm":
raise NotImplementedError("LSTMs are only supported in the shared part of the value function "
"network.")
assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers."
latent_value = act_fun(
linear(latent_value, "vf_fc{}".format(idx), vf_layer_size, init_scale=np.sqrt(2)))

if not lstm_layer_constructed:
raise ValueError("The net_arch parameter must contain at least one occurrence of 'lstm'!")

self.value_fn = linear(latent_value, 'vf', 1)
# TODO: why not init_scale = 0.001 here like in the feedforward
self.proba_distribution, self.policy, self.q_value = \
self.pdtype.proba_distribution_from_latent(latent_policy, latent_value)
self.initial_state = np.zeros((self.n_env, n_lstm * 2), dtype=np.float32)
self._setup_init()
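
For illustration only (not part of the diff), a hedged sketch of net_arch values and how the validation logic above treats them:

# Accepted: exactly one 'lstm' entry in the shared part, with an optional trailing
# dict giving the non-shared policy ('pi') and value ('vf') layer sizes.
net_arch = [8, 'lstm', dict(vf=[5, 10], pi=[10])]

# Rejected: no 'lstm' entry at all ->
# ValueError("The net_arch parameter must contain at least one occurrence of 'lstm'!")
net_arch = [64, 64, dict(vf=[32], pi=[32])]

# Rejected: 'lstm' inside one of the non-shared lists ->
# NotImplementedError, since LSTMs are only supported in the shared part.
net_arch = [64, dict(pi=['lstm', 32], vf=[32])]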

@@ -310,8 +381,8 @@ class FeedForwardPolicy(ActorCriticPolicy):

def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, layers=None, net_arch=None,
act_fun=tf.tanh, cnn_extractor=nature_cnn, feature_extraction="cnn", **kwargs):
super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256,
reuse=reuse, scale=(feature_extraction == "cnn"))
super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse,
scale=(feature_extraction == "cnn"))

if layers is not None:
warnings.warn("Usage of the `layers` parameter is deprecated! Use net_arch instead "
11 changes: 5 additions & 6 deletions stable_baselines/ddpg/policies.py
@@ -15,14 +15,13 @@ class DDPGPolicy(BasePolicy):
:param n_env: (int) The number of environments to run
:param n_steps: (int) The number of steps to run for each environment
:param n_batch: (int) The number of batch to run (n_envs * n_steps)
:param n_lstm: (int) The number of LSTM cells (for recurrent policies)
:param reuse: (bool) If the policy is reusable or not
:param scale: (bool) whether or not to scale the input
"""

def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, scale=False):
super(DDPGPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=n_lstm, reuse=reuse,
scale=scale, add_action_ph=True)
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, scale=False):
super(DDPGPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse, scale=scale,
add_action_ph=True)
assert isinstance(ac_space, Box), "Error: the action space must be of type gym.spaces.Box"
assert (np.abs(ac_space.low) == ac_space.high).all(), "Error: the action space low and high must be symmetric"
self.qvalue_fn = None
@@ -106,8 +105,8 @@ class FeedForwardPolicy(DDPGPolicy):

def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, layers=None,
cnn_extractor=nature_cnn, feature_extraction="cnn", layer_norm=False, **kwargs):
super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256,
reuse=reuse, scale=(feature_extraction == "cnn"))
super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse,
scale=(feature_extraction == "cnn"))
self.layer_norm = layer_norm
self.feature_extraction = feature_extraction
self.cnn_kwargs = kwargs
11 changes: 5 additions & 6 deletions stable_baselines/deepq/policies.py
@@ -16,19 +16,18 @@ class DQNPolicy(BasePolicy):
:param n_env: (int) The number of environments to run
:param n_steps: (int) The number of steps to run for each environment
:param n_batch: (int) The number of batch to run (n_envs * n_steps)
:param n_lstm: (int) The number of LSTM cells (for recurrent policies)
:param reuse: (bool) If the policy is reusable or not
:param scale: (bool) whether or not to scale the input
:param obs_phs: (TensorFlow Tensor, TensorFlow Tensor) a tuple containing an override for observation placeholder
and the processed observation placeholder respectively
:param dueling: (bool) if true double the output MLP to compute a baseline for action scores
"""

def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, scale=False,
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, scale=False,
obs_phs=None, dueling=True):
# DQN policies need an override for the obs placeholder, due to the architecture of the code
super(DQNPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=n_lstm, reuse=reuse,
scale=scale, obs_phs=obs_phs)
super(DQNPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse, scale=scale,
obs_phs=obs_phs)
assert isinstance(ac_space, Discrete), "Error: the action space for DQN must be of type gym.spaces.Discrete"
self.n_actions = ac_space.n
self.value_fn = None
@@ -92,8 +91,8 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals
cnn_extractor=nature_cnn, feature_extraction="cnn",
obs_phs=None, layer_norm=False, dueling=True, **kwargs):
super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps,
n_batch, n_lstm=256, dueling=dueling,
reuse=reuse, scale=(feature_extraction == "cnn"), obs_phs=obs_phs)
n_batch, dueling=dueling, reuse=reuse,
scale=(feature_extraction == "cnn"), obs_phs=obs_phs)
if layers is None:
layers = [64, 64]

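
Since n_lstm is now removed from the BasePolicy, ActorCriticPolicy, DDPGPolicy and DQNPolicy constructors, downstream custom policies should no longer forward it; only LstmPolicy still accepts an n_lstm argument. A hedged before/after sketch (the class name MyCustomPolicy is hypothetical):

# Before (would now fail with TypeError: unexpected keyword argument 'n_lstm'):
# super(MyCustomPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch,
#                                      n_lstm=256, reuse=reuse, scale=True)

# After: drop the n_lstm argument.
super(MyCustomPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch,
                                     reuse=reuse, scale=True)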