diff --git a/docs/modules/policies.rst b/docs/modules/policies.rst index 4770424698..448df7c187 100644 --- a/docs/modules/policies.rst +++ b/docs/modules/policies.rst @@ -6,6 +6,8 @@ Policy Networks =============== Stable-baselines provides a set of default policies, that can be used with most action spaces. +To customize the default policies, you can pass the `policy_kwargs` argument to the model class you use. +Those kwargs are then passed to the policy on instantiation. If you need more control on the policy architecture, You can also create a custom policy (see :ref:`custom_policy`). .. note:: diff --git a/stable_baselines/a2c/a2c.py b/stable_baselines/a2c/a2c.py index bfa209b0c2..e54a157c00 100644 --- a/stable_baselines/a2c/a2c.py +++ b/stable_baselines/a2c/a2c.py @@ -34,14 +34,15 @@ class A2C(ActorCriticRLModel): :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance (used only for loading) + :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation """ def __init__(self, policy, env, gamma=0.99, n_steps=5, vf_coef=0.25, ent_coef=0.01, max_grad_norm=0.5, learning_rate=7e-4, alpha=0.99, epsilon=1e-5, lr_schedule='linear', verbose=0, tensorboard_log=None, - _init_setup_model=True): + _init_setup_model=True, policy_kwargs=None): super(A2C, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True, - _init_setup_model=_init_setup_model) + _init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs) self.n_steps = n_steps self.gamma = gamma @@ -99,12 +100,12 @@ def setup_model(self): n_batch_train = self.n_envs * self.n_steps step_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1, - n_batch_step, reuse=False) + n_batch_step, reuse=False, **self.policy_kwargs) with tf.variable_scope("train_model", reuse=True, custom_getter=tf_util.outer_scope_getter("train_model")): train_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, - self.n_steps, n_batch_train, reuse=True) + self.n_steps, n_batch_train, reuse=True, **self.policy_kwargs) with tf.variable_scope("loss", reuse=False): self.actions_ph = train_model.pdtype.sample_placeholder([None], name="action_ph") @@ -260,7 +261,8 @@ def save(self, save_path): "observation_space": self.observation_space, "action_space": self.action_space, "n_envs": self.n_envs, - "_vectorize_action": self._vectorize_action + "_vectorize_action": self._vectorize_action, + "policy_kwargs": self.policy_kwargs } params = self.sess.run(self.params) diff --git a/stable_baselines/acer/acer_simple.py b/stable_baselines/acer/acer_simple.py index 36012f72e1..04e4c1d964 100644 --- a/stable_baselines/acer/acer_simple.py +++ b/stable_baselines/acer/acer_simple.py @@ -90,15 +90,16 @@ class ACER(ActorCriticRLModel): :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation """ def __init__(self, policy, env, gamma=0.99, n_steps=20, num_procs=1, q_coef=0.5, ent_coef=0.01, max_grad_norm=10, learning_rate=7e-4, lr_schedule='linear', rprop_alpha=0.99, rprop_epsilon=1e-5, buffer_size=5000, replay_ratio=4,
replay_start=1000, correction_term=10.0, trust_region=True, alpha=0.99, delta=1, - verbose=0, tensorboard_log=None, _init_setup_model=True): + verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None): super(ACER, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True, - _init_setup_model=_init_setup_model) + _init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs) self.n_steps = n_steps self.replay_ratio = replay_ratio @@ -180,14 +181,14 @@ def setup_model(self): n_batch_train = self.n_envs * (self.n_steps + 1) step_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1, - n_batch_step, reuse=False) + n_batch_step, reuse=False, **self.policy_kwargs) self.params = find_trainable_variables("model") with tf.variable_scope("train_model", reuse=True, custom_getter=tf_util.outer_scope_getter("train_model")): train_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, - self.n_steps + 1, n_batch_train, reuse=True) + self.n_steps + 1, n_batch_train, reuse=True, **self.policy_kwargs) with tf.variable_scope("moving_average"): # create averaged model @@ -202,7 +203,8 @@ def custom_getter(getter, name, *args, **kwargs): with tf.variable_scope("polyak_model", reuse=True, custom_getter=custom_getter): self.polyak_model = polyak_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, self.n_steps + 1, - self.n_envs * (self.n_steps + 1), reuse=True) + self.n_envs * (self.n_steps + 1), reuse=True, + **self.policy_kwargs) with tf.variable_scope("loss", reuse=False): self.done_ph = tf.placeholder(tf.float32, [self.n_batch]) # dones @@ -539,7 +541,8 @@ def save(self, save_path): "observation_space": self.observation_space, "action_space": self.action_space, "n_envs": self.n_envs, - "_vectorize_action": self._vectorize_action + "_vectorize_action": self._vectorize_action, + "policy_kwargs": self.policy_kwargs } params = self.sess.run(self.params) diff --git a/stable_baselines/acktr/acktr_disc.py b/stable_baselines/acktr/acktr_disc.py index 3cfb7a21a5..de9134f135 100644 --- a/stable_baselines/acktr/acktr_disc.py +++ b/stable_baselines/acktr/acktr_disc.py @@ -38,14 +38,15 @@ class ACKTR(ActorCriticRLModel): :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance :param async_eigen_decomp: (bool) Use async eigen decomposition + :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation """ def __init__(self, policy, env, gamma=0.99, nprocs=1, n_steps=20, ent_coef=0.01, vf_coef=0.25, vf_fisher_coef=1.0, learning_rate=0.25, max_grad_norm=0.5, kfac_clip=0.001, lr_schedule='linear', verbose=0, - tensorboard_log=None, _init_setup_model=True, async_eigen_decomp=False): + tensorboard_log=None, _init_setup_model=True, async_eigen_decomp=False, policy_kwargs=None): super(ACKTR, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True, - _init_setup_model=_init_setup_model) + _init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs) self.n_steps = n_steps self.gamma = gamma @@ -115,7 +116,7 @@ def setup_model(self): n_batch_train = self.n_envs * self.n_steps self.model = step_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, - 1, n_batch_step, reuse=False) + 1, n_batch_step, reuse=False, **self.policy_kwargs) self.params = params = 
find_trainable_variables("model") @@ -123,7 +124,7 @@ def setup_model(self): custom_getter=tf_util.outer_scope_getter("train_model")): self.model2 = train_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, self.n_steps, n_batch_train, - reuse=True) + reuse=True, **self.policy_kwargs) with tf.variable_scope("loss", reuse=False, custom_getter=tf_util.outer_scope_getter("loss")): self.advs_ph = advs_ph = tf.placeholder(tf.float32, [None]) @@ -330,7 +331,8 @@ def save(self, save_path): "observation_space": self.observation_space, "action_space": self.action_space, "n_envs": self.n_envs, - "_vectorize_action": self._vectorize_action + "_vectorize_action": self._vectorize_action, + "policy_kwargs": self.policy_kwargs } params = self.sess.run(self.params) diff --git a/stable_baselines/common/base_class.py b/stable_baselines/common/base_class.py index ff2c6bb406..ed01fcae8d 100644 --- a/stable_baselines/common/base_class.py +++ b/stable_baselines/common/base_class.py @@ -26,7 +26,7 @@ class BaseRLModel(ABC): :param policy_base: (BasePolicy) the base policy used by this method """ - def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base): + def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base, policy_kwargs=None): if isinstance(policy, str): self.policy = get_policy_from_name(policy_base, policy) else: @@ -34,6 +34,7 @@ def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base): self.env = env self.verbose = verbose self._requires_vec_env = requires_vec_env + self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs self.observation_space = None self.action_space = None self.n_envs = None @@ -317,9 +318,9 @@ class ActorCriticRLModel(BaseRLModel): """ def __init__(self, policy, env, _init_setup_model, verbose=0, policy_base=ActorCriticPolicy, - requires_vec_env=False): + requires_vec_env=False, policy_kwargs=None): super(ActorCriticRLModel, self).__init__(policy, env, verbose=verbose, requires_vec_env=requires_vec_env, - policy_base=policy_base) + policy_base=policy_base, policy_kwargs=policy_kwargs) self.sess = None self.initial_state = None @@ -424,6 +425,11 @@ def save(self, save_path): def load(cls, load_path, env=None, **kwargs): data, params = cls._load_from_file(load_path) + if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']: + raise ValueError("The specified policy kwargs do not equal the stored policy kwargs. 
" + "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'], + kwargs['policy_kwargs'])) + model = cls(policy=data["policy"], env=None, _init_setup_model=False) model.__dict__.update(data) model.__dict__.update(kwargs) @@ -451,9 +457,9 @@ class OffPolicyRLModel(BaseRLModel): :param policy_base: (BasePolicy) the base policy used by this method """ - def __init__(self, policy, env, replay_buffer, verbose=0, *, requires_vec_env, policy_base): + def __init__(self, policy, env, replay_buffer, verbose=0, *, requires_vec_env, policy_base, policy_kwargs=None): super(OffPolicyRLModel, self).__init__(policy, env, verbose=verbose, requires_vec_env=requires_vec_env, - policy_base=policy_base) + policy_base=policy_base, policy_kwargs=policy_kwargs) self.replay_buffer = replay_buffer diff --git a/stable_baselines/ddpg/ddpg.py b/stable_baselines/ddpg/ddpg.py index 43c2f62e83..ef64d4404e 100644 --- a/stable_baselines/ddpg/ddpg.py +++ b/stable_baselines/ddpg/ddpg.py @@ -167,6 +167,7 @@ class DDPG(OffPolicyRLModel): :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation """ def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, nb_train_steps=50, @@ -175,11 +176,11 @@ def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, n normalize_returns=False, enable_popart=False, observation_range=(-5., 5.), critic_l2_reg=0., return_range=(-np.inf, np.inf), actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1., render=False, render_eval=False, memory_limit=100, verbose=0, tensorboard_log=None, - _init_setup_model=True): + _init_setup_model=True, policy_kwargs=None): # TODO: replay_buffer refactoring super(DDPG, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose, policy_base=DDPGPolicy, - requires_vec_env=False) + requires_vec_env=False, policy_kwargs=policy_kwargs) # Parameters. self.gamma = gamma @@ -294,10 +295,12 @@ def setup_model(self): else: self.ret_rms = None - self.policy_tf = self.policy(self.sess, self.observation_space, self.action_space, 1, 1, None) + self.policy_tf = self.policy(self.sess, self.observation_space, self.action_space, 1, 1, None, + **self.policy_kwargs) # Create target networks. - self.target_policy = self.policy(self.sess, self.observation_space, self.action_space, 1, 1, None) + self.target_policy = self.policy(self.sess, self.observation_space, self.action_space, 1, 1, None, + **self.policy_kwargs) self.obs_target = self.target_policy.obs_ph self.action_target = self.target_policy.action_ph @@ -309,13 +312,14 @@ def setup_model(self): if self.param_noise is not None: # Configure perturbed actor. self.param_noise_actor = self.policy(self.sess, self.observation_space, self.action_space, 1, 1, - None) + None, **self.policy_kwargs) self.obs_noise = self.param_noise_actor.obs_ph self.action_noise_ph = self.param_noise_actor.action_ph # Configure separate copy for stddev adoption. 
self.adaptive_param_noise_actor = self.policy(self.sess, self.observation_space, - self.action_space, 1, 1, None) + self.action_space, 1, 1, None, + **self.policy_kwargs) self.obs_adapt_noise = self.adaptive_param_noise_actor.obs_ph self.action_adapt_noise = self.adaptive_param_noise_actor.action_ph @@ -989,7 +993,8 @@ def save(self, save_path): "policy": self.policy, "memory_policy": self.memory_policy, "n_envs": self.n_envs, - "_vectorize_action": self._vectorize_action + "_vectorize_action": self._vectorize_action, + "policy_kwargs": self.policy_kwargs } params = self.sess.run(self.params) @@ -1001,6 +1006,11 @@ def save(self, save_path): def load(cls, load_path, env=None, **kwargs): data, params = cls._load_from_file(load_path) + if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']: + raise ValueError("The specified policy kwargs do not equal the stored policy kwargs. " + "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'], + kwargs['policy_kwargs'])) + model = cls(None, env, _init_setup_model=False) model.__dict__.update(data) model.__dict__.update(kwargs) diff --git a/stable_baselines/deepq/dqn.py b/stable_baselines/deepq/dqn.py index 44f9bb151d..edee7f5416 100644 --- a/stable_baselines/deepq/dqn.py +++ b/stable_baselines/deepq/dqn.py @@ -52,11 +52,11 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000 learning_starts=1000, target_network_update_freq=500, prioritized_replay=False, prioritized_replay_alpha=0.6, prioritized_replay_beta0=0.4, prioritized_replay_beta_iters=None, prioritized_replay_eps=1e-6, param_noise=False, verbose=0, tensorboard_log=None, - _init_setup_model=True): + _init_setup_model=True, policy_kwargs=None): # TODO: replay_buffer refactoring super(DQN, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose, policy_base=DQNPolicy, - requires_vec_env=False) + requires_vec_env=False, policy_kwargs=policy_kwargs) self.checkpoint_path = checkpoint_path self.param_noise = param_noise @@ -115,7 +115,7 @@ def setup_model(self): optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate) self.act, self._train_step, self.update_target, self.step_model = deepq.build_train( - q_func=self.policy, + q_func=partial(self.policy, **self.policy_kwargs), ob_space=self.observation_space, ac_space=self.action_space, optimizer=optimizer, @@ -315,7 +315,8 @@ def save(self, save_path): "action_space": self.action_space, "policy": self.policy, "n_envs": self.n_envs, - "_vectorize_action": self._vectorize_action + "_vectorize_action": self._vectorize_action, + "policy_kwargs": self.policy_kwargs } params = self.sess.run(self.params) @@ -326,6 +327,11 @@ def save(self, save_path): def load(cls, load_path, env=None, **kwargs): data, params = cls._load_from_file(load_path) + if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']: + raise ValueError("The specified policy kwargs do not equal the stored policy kwargs. 
" + "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'], + kwargs['policy_kwargs'])) + model = cls(policy=data["policy"], env=env, _init_setup_model=False) model.__dict__.update(data) model.__dict__.update(kwargs) diff --git a/stable_baselines/ppo1/pposgd_simple.py b/stable_baselines/ppo1/pposgd_simple.py index 6b429532ea..98a7024c5c 100644 --- a/stable_baselines/ppo1/pposgd_simple.py +++ b/stable_baselines/ppo1/pposgd_simple.py @@ -37,14 +37,15 @@ class PPO1(ActorCriticRLModel): :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation """ def __init__(self, policy, env, gamma=0.99, timesteps_per_actorbatch=256, clip_param=0.2, entcoeff=0.01, optim_epochs=4, optim_stepsize=1e-3, optim_batchsize=64, lam=0.95, adam_epsilon=1e-5, - schedule='linear', verbose=0, tensorboard_log=None, _init_setup_model=True): + schedule='linear', verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None): super().__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=False, - _init_setup_model=_init_setup_model) + _init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs) self.gamma = gamma self.timesteps_per_actorbatch = timesteps_per_actorbatch @@ -85,12 +86,12 @@ def setup_model(self): # Construct network for new policy self.policy_pi = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1, - None, reuse=False) + None, reuse=False, **self.policy_kwargs) # Network for old policy with tf.variable_scope("oldpi", reuse=False): old_pi = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1, - None, reuse=False) + None, reuse=False, **self.policy_kwargs) with tf.variable_scope("loss", reuse=False): # Target advantage function (if applicable) @@ -327,7 +328,8 @@ def save(self, save_path): "observation_space": self.observation_space, "action_space": self.action_space, "n_envs": self.n_envs, - "_vectorize_action": self._vectorize_action + "_vectorize_action": self._vectorize_action, + "policy_kwargs": self.policy_kwargs } params = self.sess.run(self.params) diff --git a/stable_baselines/ppo2/ppo2.py b/stable_baselines/ppo2/ppo2.py index 780ef25f26..bdf264e914 100644 --- a/stable_baselines/ppo2/ppo2.py +++ b/stable_baselines/ppo2/ppo2.py @@ -36,14 +36,15 @@ class PPO2(ActorCriticRLModel): :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation """ def __init__(self, policy, env, gamma=0.99, n_steps=128, ent_coef=0.01, learning_rate=2.5e-4, vf_coef=0.5, max_grad_norm=0.5, lam=0.95, nminibatches=4, noptepochs=4, cliprange=0.2, verbose=0, - tensorboard_log=None, _init_setup_model=True): + tensorboard_log=None, _init_setup_model=True, policy_kwargs=None): super(PPO2, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True, - _init_setup_model=_init_setup_model) + _init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs) self.learning_rate = learning_rate 
self.cliprange = cliprange @@ -112,12 +113,12 @@ def setup_model(self): n_batch_train = self.n_batch // self.nminibatches act_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1, - n_batch_step, reuse=False) + n_batch_step, reuse=False, **self.policy_kwargs) with tf.variable_scope("train_model", reuse=True, custom_getter=tf_util.outer_scope_getter("train_model")): train_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs // self.nminibatches, self.n_steps, n_batch_train, - reuse=True) + reuse=True, **self.policy_kwargs) with tf.variable_scope("loss", reuse=False): self.action_ph = train_model.pdtype.sample_placeholder([None], name="action_ph") @@ -155,6 +156,8 @@ def setup_model(self): with tf.variable_scope('model'): self.params = tf.trainable_variables() + for var in self.params: + tf.summary.histogram(var.name, var) grads = tf.gradients(loss, self.params) if self.max_grad_norm is not None: grads, _grad_norm = tf.clip_by_global_norm(grads, self.max_grad_norm) @@ -350,7 +353,8 @@ def save(self, save_path): "observation_space": self.observation_space, "action_space": self.action_space, "n_envs": self.n_envs, - "_vectorize_action": self._vectorize_action + "_vectorize_action": self._vectorize_action, + "policy_kwargs": self.policy_kwargs } params = self.sess.run(self.params) diff --git a/stable_baselines/sac/sac.py b/stable_baselines/sac/sac.py index cb8897ec31..10f6858030 100644 --- a/stable_baselines/sac/sac.py +++ b/stable_baselines/sac/sac.py @@ -56,15 +56,16 @@ class SAC(OffPolicyRLModel): :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation """ def __init__(self, policy, env, gamma=0.99, learning_rate=3e-4, buffer_size=50000, learning_starts=100, train_freq=1, batch_size=64, tau=0.005, ent_coef='auto', target_update_interval=1, gradient_steps=1, target_entropy='auto', - verbose=0, tensorboard_log=None, _init_setup_model=True): + verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None): super(SAC, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose, - policy_base=SACPolicy, requires_vec_env=False) + policy_base=SACPolicy, requires_vec_env=False, policy_kwargs=policy_kwargs) self.buffer_size = buffer_size self.learning_rate = learning_rate @@ -130,8 +131,10 @@ def setup_model(self): with tf.variable_scope("input", reuse=False): # Create policy and target TF objects - self.policy_tf = self.policy(self.sess, self.observation_space, self.action_space) - self.target_policy = self.policy(self.sess, self.observation_space, self.action_space) + self.policy_tf = self.policy(self.sess, self.observation_space, self.action_space, + **self.policy_kwargs) + self.target_policy = self.policy(self.sess, self.observation_space, self.action_space, + **self.policy_kwargs) # Initialize Placeholders self.observations_ph = self.policy_tf.obs_ph @@ -494,7 +497,8 @@ def save(self, save_path): "action_space": self.action_space, "policy": self.policy, "n_envs": self.n_envs, - "_vectorize_action": self._vectorize_action + "_vectorize_action": self._vectorize_action, + "policy_kwargs": self.policy_kwargs } params = self.sess.run(self.params) @@ -506,6 +510,11 @@ def save(self, save_path): 
def load(cls, load_path, env=None, **kwargs): data, params = cls._load_from_file(load_path) + if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']: + raise ValueError("The specified policy kwargs do not equal the stored policy kwargs. " + "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'], + kwargs['policy_kwargs'])) + model = cls(policy=data["policy"], env=env, _init_setup_model=False) model.__dict__.update(data) model.__dict__.update(kwargs) diff --git a/stable_baselines/trpo_mpi/trpo_mpi.py b/stable_baselines/trpo_mpi/trpo_mpi.py index 839b9b4a50..4cc1e9cdf7 100644 --- a/stable_baselines/trpo_mpi/trpo_mpi.py +++ b/stable_baselines/trpo_mpi/trpo_mpi.py @@ -21,7 +21,7 @@ class TRPO(ActorCriticRLModel): def __init__(self, policy, env, gamma=0.99, timesteps_per_batch=1024, max_kl=0.01, cg_iters=10, lam=0.98, entcoeff=0.0, cg_damping=1e-2, vf_stepsize=3e-4, vf_iters=3, verbose=0, tensorboard_log=None, - _init_setup_model=True): + _init_setup_model=True, policy_kwargs=None): """ learns a TRPO policy using the given environment @@ -39,9 +39,10 @@ def __init__(self, policy, env, gamma=0.99, timesteps_per_batch=1024, max_kl=0.0 :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation """ super(TRPO, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=False, - _init_setup_model=_init_setup_model) + _init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs) self.using_gail = False self.timesteps_per_batch = timesteps_per_batch @@ -468,7 +469,8 @@ def save(self, save_path): "observation_space": self.observation_space, "action_space": self.action_space, "n_envs": self.n_envs, - "_vectorize_action": self._vectorize_action + "_vectorize_action": self._vectorize_action, + "policy_kwargs": self.policy_kwargs } params = self.sess.run(self.params) diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index bd574c5fd5..28fd46a522 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -1,9 +1,12 @@ import os +import gym import pytest +import tensorflow as tf from stable_baselines import A2C, ACER, ACKTR, DQN, PPO1, PPO2, TRPO, SAC, DDPG from stable_baselines.common.policies import FeedForwardPolicy +from stable_baselines.common.vec_env import DummyVecEnv from stable_baselines.deepq.policies import FeedForwardPolicy as DQNPolicy from stable_baselines.ddpg.policies import FeedForwardPolicy as DDPGPolicy from stable_baselines.sac.policies import FeedForwardPolicy as SACPolicy @@ -35,15 +38,15 @@ def __init__(self, *args, **kwargs): feature_extraction="mlp") MODEL_DICT = { - 'a2c': (A2C, CustomCommonPolicy), - 'acer': (ACER, CustomCommonPolicy), - 'acktr': (ACKTR, CustomCommonPolicy), - 'dqn': (DQN, CustomDQNPolicy), - 'ddpg': (DDPG, CustomDDPGPolicy), - 'ppo1': (PPO1, CustomCommonPolicy), - 'ppo2': (PPO2, CustomCommonPolicy), - 'sac': (SAC, CustomSACPolicy), - 'trpo': (TRPO, CustomCommonPolicy), + 'a2c': (A2C, CustomCommonPolicy, dict(act_fun=tf.nn.relu)), + 'acer': (ACER, CustomCommonPolicy, dict(act_fun=tf.nn.relu)), + 'acktr': (ACKTR, CustomCommonPolicy, dict(act_fun=tf.nn.relu)), + 'dqn': (DQN, CustomDQNPolicy, dict()), + 'ddpg': (DDPG, CustomDDPGPolicy, dict()), + 'ppo1': (PPO1, 
CustomCommonPolicy, dict(act_fun=tf.nn.relu)), + 'ppo2': (PPO2, CustomCommonPolicy, dict(act_fun=tf.nn.relu)), + 'sac': (SAC, CustomSACPolicy, dict()), + 'trpo': (TRPO, CustomCommonPolicy, dict(act_fun=tf.nn.relu)), } @@ -55,11 +58,9 @@ def test_custom_policy(model_name): """ try: - model_class, policy = MODEL_DICT[model_name] - if model_name in ['ddpg', 'sac']: - env = 'MountainCarContinuous-v0' - else: - env = 'CartPole-v1' + model_class, policy, _ = MODEL_DICT[model_name] + env = 'MountainCarContinuous-v0' if model_name in ['ddpg', 'sac'] else 'CartPole-v1' + # create and train model = model_class(policy, env) model.learn(total_timesteps=100, seed=0) @@ -82,3 +83,44 @@ def test_custom_policy(model_name): finally: if os.path.exists("./test_model"): os.remove("./test_model") + + +@pytest.mark.parametrize("model_name", MODEL_DICT.keys()) +def test_custom_policy_kwargs(model_name): + """ + Test if the algorithm (with a custom policy and custom policy_kwargs) can be saved and loaded without any issues. + :param model_name: (str) the name of the RL model to test + """ + + try: + model_class, policy, policy_kwargs = MODEL_DICT[model_name] + env = 'MountainCarContinuous-v0' if model_name in ['ddpg', 'sac'] else 'CartPole-v1' + + # create and train + model = model_class(policy, env, policy_kwargs=policy_kwargs) + model.learn(total_timesteps=100, seed=0) + + model.save("./test_model") + del model + + # loading + + env = DummyVecEnv([lambda: gym.make(env)]) + + # Load while specifying policy_kwargs + model = model_class.load("./test_model", policy=policy, env=env, policy_kwargs=policy_kwargs) + model.learn(total_timesteps=100, seed=0) + del model + + # Load without specifying policy_kwargs + model = model_class.load("./test_model", policy=policy, env=env) + model.learn(total_timesteps=100, seed=0) + del model + + # Loading with different (wrong) policy_kwargs must raise an error + with pytest.raises(ValueError): + model = model_class.load("./test_model", policy=policy, env=env, policy_kwargs=dict(wrong="kwargs")) + + finally: + if os.path.exists("./test_model"): + os.remove("./test_model")
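For reference, a minimal usage sketch of the `policy_kwargs` argument introduced by this patch (illustration only, not part of the diff itself). It assumes the built-in `MlpPolicy` accepts the `act_fun` keyword, like the `FeedForwardPolicy` subclasses used in the tests above, and the file name `ppo2_cartpole` is arbitrary:

    import tensorflow as tf

    from stable_baselines import PPO2

    # Keyword arguments in `policy_kwargs` are forwarded to the policy constructor,
    # here swapping the default activation function for ReLU.
    model = PPO2('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(act_fun=tf.nn.relu))
    model.learn(total_timesteps=1000)
    model.save("ppo2_cartpole")

    # The kwargs are stored alongside the weights: loading without `policy_kwargs`
    # reuses the stored ones, while loading with different kwargs raises a ValueError
    # (the consistency check added to `load()` above).
    model = PPO2.load("ppo2_cartpole")

Loading with, for example, `policy_kwargs=dict(act_fun=tf.nn.tanh)` would trip that check, which is what `test_custom_policy_kwargs` exercises with `dict(wrong="kwargs")`.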