Feature/policy kwargs (openai#162)
* Add policy_kwargs with tests.

* Add documentation.

* Mark unused variable.
ernestum authored and hill-a committed Jan 15, 2019
1 parent 002fb35 commit 88a5c5d
Showing 12 changed files with 154 additions and 64 deletions.
2 changes: 2 additions & 0 deletions docs/modules/policies.rst
@@ -6,6 +6,8 @@ Policy Networks
===============

Stable-baselines provides a set of default policies that can be used with most action spaces.
To customize a default policy, pass the ``policy_kwargs`` argument to the model class you use.
Those kwargs are then passed to the policy on instantiation.
If you need more control over the policy architecture, you can also create a custom policy (see :ref:`custom_policy`).

.. note::
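For context, a minimal usage sketch of the new parameter (not part of the diff). It assumes a standard Gym environment and the `layers` keyword argument of the built-in `MlpPolicy`; substitute whatever kwargs your policy class accepts:

import gym

from stable_baselines import A2C

env = gym.make('CartPole-v1')

# The dict is forwarded verbatim to MlpPolicy.__init__ every time the model
# builds a policy network (step model, train model, ...).
model = A2C('MlpPolicy', env, policy_kwargs=dict(layers=[64, 64]), verbose=1)
model.learn(total_timesteps=10000)

# policy_kwargs is saved with the model, so loading rebuilds the same architecture.
model.save("a2c_cartpole")
loaded = A2C.load("a2c_cartpole")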
12 changes: 7 additions & 5 deletions stable_baselines/a2c/a2c.py
@@ -34,14 +34,15 @@ class A2C(ActorCriticRLModel):
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
(used only for loading)
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
"""

def __init__(self, policy, env, gamma=0.99, n_steps=5, vf_coef=0.25, ent_coef=0.01, max_grad_norm=0.5,
learning_rate=7e-4, alpha=0.99, epsilon=1e-5, lr_schedule='linear', verbose=0, tensorboard_log=None,
_init_setup_model=True):
_init_setup_model=True, policy_kwargs=None):

super(A2C, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True,
_init_setup_model=_init_setup_model)
_init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs)

self.n_steps = n_steps
self.gamma = gamma
@@ -99,12 +100,12 @@ def setup_model(self):
n_batch_train = self.n_envs * self.n_steps

step_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1,
n_batch_step, reuse=False)
n_batch_step, reuse=False, **self.policy_kwargs)

with tf.variable_scope("train_model", reuse=True,
custom_getter=tf_util.outer_scope_getter("train_model")):
train_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs,
self.n_steps, n_batch_train, reuse=True)
self.n_steps, n_batch_train, reuse=True, **self.policy_kwargs)

with tf.variable_scope("loss", reuse=False):
self.actions_ph = train_model.pdtype.sample_placeholder([None], name="action_ph")
@@ -260,7 +261,8 @@ def save(self, save_path):
"observation_space": self.observation_space,
"action_space": self.action_space,
"n_envs": self.n_envs,
"_vectorize_action": self._vectorize_action
"_vectorize_action": self._vectorize_action,
"policy_kwargs": self.policy_kwargs
}

params = self.sess.run(self.params)
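The a2c.py hunks above show the pattern every algorithm in this commit follows: the stored dict is unpacked into each policy construction, and the empty-dict default keeps prior behaviour unchanged. A small standalone sketch of that pattern (illustrative class, not library code):

class DummyPolicy:
    """Stands in for an ActorCriticPolicy subclass; only `layers` matters here."""
    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch,
                 reuse=False, layers=None):
        self.layers = [64, 64] if layers is None else layers

policy_kwargs = {}  # the default: **{} adds nothing, so behaviour is unchanged
step_model = DummyPolicy(None, None, None, 1, 1, None, reuse=False, **policy_kwargs)
assert step_model.layers == [64, 64]

policy_kwargs = dict(layers=[32, 32])  # user-supplied kwargs reach every instantiation
train_model = DummyPolicy(None, None, None, 1, 5, None, reuse=True, **policy_kwargs)
assert train_model.layers == [32, 32]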
15 changes: 9 additions & 6 deletions stable_baselines/acer/acer_simple.py
@@ -90,15 +90,16 @@ class ACER(ActorCriticRLModel):
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
"""

def __init__(self, policy, env, gamma=0.99, n_steps=20, num_procs=1, q_coef=0.5, ent_coef=0.01, max_grad_norm=10,
learning_rate=7e-4, lr_schedule='linear', rprop_alpha=0.99, rprop_epsilon=1e-5, buffer_size=5000,
replay_ratio=4, replay_start=1000, correction_term=10.0, trust_region=True, alpha=0.99, delta=1,
verbose=0, tensorboard_log=None, _init_setup_model=True):
verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None):

super(ACER, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True,
_init_setup_model=_init_setup_model)
_init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs)

self.n_steps = n_steps
self.replay_ratio = replay_ratio
@@ -180,14 +181,14 @@ def setup_model(self):
n_batch_train = self.n_envs * (self.n_steps + 1)

step_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1,
n_batch_step, reuse=False)
n_batch_step, reuse=False, **self.policy_kwargs)

self.params = find_trainable_variables("model")

with tf.variable_scope("train_model", reuse=True,
custom_getter=tf_util.outer_scope_getter("train_model")):
train_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs,
self.n_steps + 1, n_batch_train, reuse=True)
self.n_steps + 1, n_batch_train, reuse=True, **self.policy_kwargs)

with tf.variable_scope("moving_average"):
# create averaged model
@@ -202,7 +203,8 @@ def custom_getter(getter, name, *args, **kwargs):
with tf.variable_scope("polyak_model", reuse=True, custom_getter=custom_getter):
self.polyak_model = polyak_model = self.policy(self.sess, self.observation_space, self.action_space,
self.n_envs, self.n_steps + 1,
self.n_envs * (self.n_steps + 1), reuse=True)
self.n_envs * (self.n_steps + 1), reuse=True,
**self.policy_kwargs)

with tf.variable_scope("loss", reuse=False):
self.done_ph = tf.placeholder(tf.float32, [self.n_batch]) # dones
@@ -539,7 +541,8 @@ def save(self, save_path):
"observation_space": self.observation_space,
"action_space": self.action_space,
"n_envs": self.n_envs,
"_vectorize_action": self._vectorize_action
"_vectorize_action": self._vectorize_action,
"policy_kwargs": self.policy_kwargs
}

params = self.sess.run(self.params)
12 changes: 7 additions & 5 deletions stable_baselines/acktr/acktr_disc.py
@@ -38,14 +38,15 @@ class ACKTR(ActorCriticRLModel):
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
:param async_eigen_decomp: (bool) Use async eigen decomposition
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
"""

def __init__(self, policy, env, gamma=0.99, nprocs=1, n_steps=20, ent_coef=0.01, vf_coef=0.25, vf_fisher_coef=1.0,
learning_rate=0.25, max_grad_norm=0.5, kfac_clip=0.001, lr_schedule='linear', verbose=0,
tensorboard_log=None, _init_setup_model=True, async_eigen_decomp=False):
tensorboard_log=None, _init_setup_model=True, async_eigen_decomp=False, policy_kwargs=None):

super(ACKTR, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True,
_init_setup_model=_init_setup_model)
_init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs)

self.n_steps = n_steps
self.gamma = gamma
@@ -115,15 +116,15 @@ def setup_model(self):
n_batch_train = self.n_envs * self.n_steps

self.model = step_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs,
1, n_batch_step, reuse=False)
1, n_batch_step, reuse=False, **self.policy_kwargs)

self.params = params = find_trainable_variables("model")

with tf.variable_scope("train_model", reuse=True,
custom_getter=tf_util.outer_scope_getter("train_model")):
self.model2 = train_model = self.policy(self.sess, self.observation_space, self.action_space,
self.n_envs, self.n_steps, n_batch_train,
reuse=True)
reuse=True, **self.policy_kwargs)

with tf.variable_scope("loss", reuse=False, custom_getter=tf_util.outer_scope_getter("loss")):
self.advs_ph = advs_ph = tf.placeholder(tf.float32, [None])
@@ -330,7 +331,8 @@ def save(self, save_path):
"observation_space": self.observation_space,
"action_space": self.action_space,
"n_envs": self.n_envs,
"_vectorize_action": self._vectorize_action
"_vectorize_action": self._vectorize_action,
"policy_kwargs": self.policy_kwargs
}

params = self.sess.run(self.params)
16 changes: 11 additions & 5 deletions stable_baselines/common/base_class.py
@@ -26,14 +26,15 @@ class BaseRLModel(ABC):
:param policy_base: (BasePolicy) the base policy used by this method
"""

def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base):
def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base, policy_kwargs=None):
if isinstance(policy, str):
self.policy = get_policy_from_name(policy_base, policy)
else:
self.policy = policy
self.env = env
self.verbose = verbose
self._requires_vec_env = requires_vec_env
self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
self.observation_space = None
self.action_space = None
self.n_envs = None
@@ -317,9 +318,9 @@ class ActorCriticRLModel(BaseRLModel):
"""

def __init__(self, policy, env, _init_setup_model, verbose=0, policy_base=ActorCriticPolicy,
requires_vec_env=False):
requires_vec_env=False, policy_kwargs=None):
super(ActorCriticRLModel, self).__init__(policy, env, verbose=verbose, requires_vec_env=requires_vec_env,
policy_base=policy_base)
policy_base=policy_base, policy_kwargs=policy_kwargs)

self.sess = None
self.initial_state = None
@@ -424,6 +425,11 @@ def save(self, save_path):
def load(cls, load_path, env=None, **kwargs):
data, params = cls._load_from_file(load_path)

if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']:
raise ValueError("The specified policy kwargs do not equal the stored policy kwargs. "
"Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'],
kwargs['policy_kwargs']))

model = cls(policy=data["policy"], env=None, _init_setup_model=False)
model.__dict__.update(data)
model.__dict__.update(kwargs)
@@ -451,9 +457,9 @@ class OffPolicyRLModel(BaseRLModel):
:param policy_base: (BasePolicy) the base policy used by this method
"""

def __init__(self, policy, env, replay_buffer, verbose=0, *, requires_vec_env, policy_base):
def __init__(self, policy, env, replay_buffer, verbose=0, *, requires_vec_env, policy_base, policy_kwargs=None):
super(OffPolicyRLModel, self).__init__(policy, env, verbose=verbose, requires_vec_env=requires_vec_env,
policy_base=policy_base)
policy_base=policy_base, policy_kwargs=policy_kwargs)

self.replay_buffer = replay_buffer

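The base-class changes are where the parameter is stored (defaulting to an empty dict) and where load() now guards against conflicting kwargs. A sketch of how that guard behaves from the user's side (hypothetical file name, same assumed `layers` kwarg as above):

import gym

from stable_baselines import A2C

env = gym.make('CartPole-v1')
model = A2C('MlpPolicy', env, policy_kwargs=dict(layers=[32, 32]))
model.save("a2c_custom_arch")

A2C.load("a2c_custom_arch")                                       # ok: kwargs restored from the file
A2C.load("a2c_custom_arch", policy_kwargs=dict(layers=[32, 32]))  # ok: matches the stored dict
try:
    A2C.load("a2c_custom_arch", policy_kwargs=dict(layers=[16]))  # mismatch
except ValueError as err:
    print(err)  # "The specified policy kwargs do not equal the stored policy kwargs. ..."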
24 changes: 17 additions & 7 deletions stable_baselines/ddpg/ddpg.py
@@ -167,6 +167,7 @@ class DDPG(OffPolicyRLModel):
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
"""

def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, nb_train_steps=50,
@@ -175,11 +176,11 @@ def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, n
normalize_returns=False, enable_popart=False, observation_range=(-5., 5.), critic_l2_reg=0.,
return_range=(-np.inf, np.inf), actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1.,
render=False, render_eval=False, memory_limit=100, verbose=0, tensorboard_log=None,
_init_setup_model=True):
_init_setup_model=True, policy_kwargs=None):

# TODO: replay_buffer refactoring
super(DDPG, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose, policy_base=DDPGPolicy,
requires_vec_env=False)
requires_vec_env=False, policy_kwargs=policy_kwargs)

# Parameters.
self.gamma = gamma
@@ -294,10 +295,12 @@ def setup_model(self):
else:
self.ret_rms = None

self.policy_tf = self.policy(self.sess, self.observation_space, self.action_space, 1, 1, None)
self.policy_tf = self.policy(self.sess, self.observation_space, self.action_space, 1, 1, None,
**self.policy_kwargs)

# Create target networks.
self.target_policy = self.policy(self.sess, self.observation_space, self.action_space, 1, 1, None)
self.target_policy = self.policy(self.sess, self.observation_space, self.action_space, 1, 1, None,
**self.policy_kwargs)
self.obs_target = self.target_policy.obs_ph
self.action_target = self.target_policy.action_ph

@@ -309,13 +312,14 @@ def setup_model(self):
if self.param_noise is not None:
# Configure perturbed actor.
self.param_noise_actor = self.policy(self.sess, self.observation_space, self.action_space, 1, 1,
None)
None, **self.policy_kwargs)
self.obs_noise = self.param_noise_actor.obs_ph
self.action_noise_ph = self.param_noise_actor.action_ph

# Configure separate copy for stddev adoption.
self.adaptive_param_noise_actor = self.policy(self.sess, self.observation_space,
self.action_space, 1, 1, None)
self.action_space, 1, 1, None,
**self.policy_kwargs)
self.obs_adapt_noise = self.adaptive_param_noise_actor.obs_ph
self.action_adapt_noise = self.adaptive_param_noise_actor.action_ph

@@ -989,7 +993,8 @@ def save(self, save_path):
"policy": self.policy,
"memory_policy": self.memory_policy,
"n_envs": self.n_envs,
"_vectorize_action": self._vectorize_action
"_vectorize_action": self._vectorize_action,
"policy_kwargs": self.policy_kwargs
}

params = self.sess.run(self.params)
@@ -1001,6 +1006,11 @@ def save(self, save_path):
def load(cls, load_path, env=None, **kwargs):
data, params = cls._load_from_file(load_path)

if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']:
raise ValueError("The specified policy kwargs do not equal the stored policy kwargs. "
"Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'],
kwargs['policy_kwargs']))

model = cls(None, env, _init_setup_model=False)
model.__dict__.update(data)
model.__dict__.update(kwargs)
14 changes: 10 additions & 4 deletions stable_baselines/deepq/dqn.py
@@ -52,11 +52,11 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000
learning_starts=1000, target_network_update_freq=500, prioritized_replay=False,
prioritized_replay_alpha=0.6, prioritized_replay_beta0=0.4, prioritized_replay_beta_iters=None,
prioritized_replay_eps=1e-6, param_noise=False, verbose=0, tensorboard_log=None,
_init_setup_model=True):
_init_setup_model=True, policy_kwargs=None):

# TODO: replay_buffer refactoring
super(DQN, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose, policy_base=DQNPolicy,
requires_vec_env=False)
requires_vec_env=False, policy_kwargs=policy_kwargs)

self.checkpoint_path = checkpoint_path
self.param_noise = param_noise
@@ -115,7 +115,7 @@ def setup_model(self):
optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)

self.act, self._train_step, self.update_target, self.step_model = deepq.build_train(
q_func=self.policy,
q_func=partial(self.policy, **self.policy_kwargs),
ob_space=self.observation_space,
ac_space=self.action_space,
optimizer=optimizer,
@@ -315,7 +315,8 @@ def save(self, save_path):
"action_space": self.action_space,
"policy": self.policy,
"n_envs": self.n_envs,
"_vectorize_action": self._vectorize_action
"_vectorize_action": self._vectorize_action,
"policy_kwargs": self.policy_kwargs
}

params = self.sess.run(self.params)
@@ -326,6 +327,11 @@ def save(self, save_path):
def load(cls, load_path, env=None, **kwargs):
data, params = cls._load_from_file(load_path)

if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']:
raise ValueError("The specified policy kwargs do not equal the stored policy kwargs. "
"Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'],
kwargs['policy_kwargs']))

model = cls(policy=data["policy"], env=env, _init_setup_model=False)
model.__dict__.update(data)
model.__dict__.update(kwargs)
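DQN is the one model that cannot simply splat the kwargs at the call site, because deepq.build_train takes a policy constructor (q_func) and instantiates it internally; hence the functools.partial above. A standalone sketch of that trick (simplified, made-up policy function):

from functools import partial

def my_q_policy(sess, ob_space, ac_space, n_env, n_steps, n_batch,
                reuse=False, layers=None):
    # Stand-in for a DQNPolicy subclass; just shows which kwargs arrive.
    print("constructed with layers =", layers)

policy_kwargs = dict(layers=[64, 64])
q_func = partial(my_q_policy, **policy_kwargs)

# The caller only ever supplies the positional arguments and `reuse`;
# the pre-bound keyword arguments ride along with every call.
q_func(None, None, None, 1, 1, None, reuse=False)   # prints: constructed with layers = [64, 64]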