From 4929e54449d0b01b7c1d3841a0b66ccabc7f7bff Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 21 Sep 2019 18:14:02 +0200 Subject: [PATCH] Update DQN arguments: add double_q (#481) * Update DQN arguments: add double_q * Fix CI build * [ci skip] Update link * Replace pdf link with arxiv link --- docs/misc/changelog.rst | 4 +++- docs/modules/dqn.rst | 12 ++++++++++-- stable_baselines/deepq/dqn.py | 23 +++++++++++------------ tests/test_atari.py | 2 +- 4 files changed, 25 insertions(+), 16 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 8286e69b3f..0741d14d83 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -21,12 +21,14 @@ Breaking Changes: wrapped in `if __name__ == '__main__'`. You can restore previous behavior by explicitly setting `start_method = 'fork'`. See `PR #428 `_. -- updated dependencies: tensorflow v1.8.0 is now required +- Updated dependencies: tensorflow v1.8.0 is now required +- Remove `checkpoint_path` and `checkpoint_freq` argument from `DQN` that were not used New Features: ^^^^^^^^^^^^^ - **important change** Switch to using zip-archived JSON and Numpy `savez` for storing models for better support across library/Python versions. (@Miffyli) +- Add `double_q` argument to `DQN` constructor Bug Fixes: ^^^^^^^^^^ diff --git a/docs/modules/dqn.rst b/docs/modules/dqn.rst index 0e14d8616c..1b3d48f9ef 100644 --- a/docs/modules/dqn.rst +++ b/docs/modules/dqn.rst @@ -27,7 +27,16 @@ and its extensions (Double-DQN, Dueling-DQN, Prioritized Experience Replay). Notes ----- -- Original paper: https://arxiv.org/abs/1312.5602 +- DQN paper: https://arxiv.org/abs/1312.5602 +- Dueling DQN: https://arxiv.org/abs/1511.06581 +- Double-Q Learning: https://arxiv.org/abs/1509.06461 +- Prioritized Experience Replay: https://arxiv.org/abs/1511.05952 + +.. note:: + + By default, the DQN class has double q learning and dueling extensions enabled. + See `Issue #406 `_ for disabling dueling. + To disable double-q learning, you can change the default value in the constructor. Can I use? @@ -60,7 +69,6 @@ Example from stable_baselines import DQN env = gym.make('CartPole-v1') - env = DummyVecEnv([lambda: env]) model = DQN(MlpPolicy, env, verbose=1) model.learn(total_timesteps=25000) diff --git a/stable_baselines/deepq/dqn.py b/stable_baselines/deepq/dqn.py index 9b10691807..637c9462c7 100644 --- a/stable_baselines/deepq/dqn.py +++ b/stable_baselines/deepq/dqn.py @@ -15,7 +15,11 @@ class DQN(OffPolicyRLModel): """ - The DQN model class. DQN paper: https://arxiv.org/pdf/1312.5602.pdf + The DQN model class. + DQN paper: https://arxiv.org/abs/1312.5602 + Dueling DQN: https://arxiv.org/abs/1511.06581 + Double-Q Learning: https://arxiv.org/abs/1509.06461 + Prioritized Experience Replay: https://arxiv.org/abs/1511.05952 :param policy: (DQNPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, LnMlpPolicy, ...) :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str) @@ -27,11 +31,7 @@ class DQN(OffPolicyRLModel): :param exploration_final_eps: (float) final value of random action probability :param train_freq: (int) update the model every `train_freq` steps. set to None to disable printing :param batch_size: (int) size of a batched sampled from replay buffer for training - :param checkpoint_freq: (int) how often to save the model. This is so that the best version is restored at the - end of the training. 
-        If you do not wish to restore the best version at the end of the training set this variable to None.
-    :param checkpoint_path: (str) replacement path used if you need to log to somewhere else than a temporary
-        directory.
+    :param double_q: (bool) Whether to enable Double-Q learning or not.
     :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
     :param target_network_update_freq: (int) update the target network every `target_network_update_freq` steps.
     :param prioritized_replay: (bool) if True prioritized replay buffer will be used.
@@ -50,7 +50,7 @@ class DQN(OffPolicyRLModel):
     """

     def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=50000, exploration_fraction=0.1,
-                 exploration_final_eps=0.02, train_freq=1, batch_size=32, checkpoint_freq=10000, checkpoint_path=None,
+                 exploration_final_eps=0.02, train_freq=1, batch_size=32, double_q=True,
                  learning_starts=1000, target_network_update_freq=500, prioritized_replay=False,
                  prioritized_replay_alpha=0.6, prioritized_replay_beta0=0.4, prioritized_replay_beta_iters=None,
                  prioritized_replay_eps=1e-6, param_noise=False, verbose=0, tensorboard_log=None,
@@ -60,7 +60,6 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000
         super(DQN, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose, policy_base=DQNPolicy,
                                   requires_vec_env=False, policy_kwargs=policy_kwargs)

-        self.checkpoint_path = checkpoint_path
         self.param_noise = param_noise
         self.learning_starts = learning_starts
         self.train_freq = train_freq
@@ -68,7 +67,6 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000
         self.prioritized_replay_eps = prioritized_replay_eps
         self.batch_size = batch_size
         self.target_network_update_freq = target_network_update_freq
-        self.checkpoint_freq = checkpoint_freq
         self.prioritized_replay_alpha = prioritized_replay_alpha
         self.prioritized_replay_beta0 = prioritized_replay_beta0
         self.prioritized_replay_beta_iters = prioritized_replay_beta_iters
@@ -79,6 +77,7 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000
         self.gamma = gamma
         self.tensorboard_log = tensorboard_log
         self.full_tensorboard_log = full_tensorboard_log
+        self.double_q = double_q

         self.graph = None
         self.sess = None
@@ -131,7 +130,8 @@ def setup_model(self):
                     grad_norm_clipping=10,
                     param_noise=self.param_noise,
                     sess=self.sess,
-                    full_tensorboard_log=self.full_tensorboard_log
+                    full_tensorboard_log=self.full_tensorboard_log,
+                    double_q=self.double_q
                 )
                 self.proba_step = self.step_model.proba_step
                 self.params = tf_util.get_trainable_vars("deepq")
@@ -334,7 +334,7 @@ def get_parameter_list(self):
     def save(self, save_path, cloudpickle=False):
         # params
         data = {
-            "checkpoint_path": self.checkpoint_path,
+            "double_q": self.double_q,
             "param_noise": self.param_noise,
             "learning_starts": self.learning_starts,
             "train_freq": self.train_freq,
@@ -342,7 +342,6 @@ def save(self, save_path, cloudpickle=False):
             "prioritized_replay_eps": self.prioritized_replay_eps,
             "batch_size": self.batch_size,
             "target_network_update_freq": self.target_network_update_freq,
-            "checkpoint_freq": self.checkpoint_freq,
             "prioritized_replay_alpha": self.prioritized_replay_alpha,
             "prioritized_replay_beta0": self.prioritized_replay_beta0,
             "prioritized_replay_beta_iters": self.prioritized_replay_beta_iters,
diff --git a/tests/test_atari.py b/tests/test_atari.py
index d5633b7c63..2b94da238d 100644
--- a/tests/test_atari.py
+++ b/tests/test_atari.py
@@ -63,7 +63,7 @@ def test_deepq():
     model = DQN(env=env, policy=CnnPolicy, learning_rate=1e-4, buffer_size=10000, exploration_fraction=0.1,
                 exploration_final_eps=0.01, train_freq=4, learning_starts=10000, target_network_update_freq=1000,
-                gamma=0.99, prioritized_replay=True, prioritized_replay_alpha=0.6, checkpoint_freq=10000)
+                gamma=0.99, prioritized_replay=True, prioritized_replay_alpha=0.6)
     model.learn(total_timesteps=NUM_TIMESTEPS)
     env.close()
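
What the new flag changes, in brief: with `double_q=True` (the default) the online network selects the greedy next action and the target network evaluates it, which reduces the overestimation bias of standard DQN (see the Double-Q Learning paper linked in the docstring). The library builds this target in TensorFlow via the `deepq.build_train` call shown in the `setup_model` hunk above; the NumPy function below is only an illustrative sketch of the target computation, not the library's implementation.

.. code-block:: python

    import numpy as np

    def td_targets(rewards, dones, q_next_online, q_next_target, gamma=0.99, double_q=True):
        """Illustrative 1-step TD targets for a batch of transitions.

        q_next_online / q_next_target: (batch_size, n_actions) arrays of
        Q(s', .) from the online and target networks respectively.
        """
        if double_q:
            # Double-Q: the online network picks the action, the target network evaluates it.
            best_actions = np.argmax(q_next_online, axis=1)
            next_values = q_next_target[np.arange(len(best_actions)), best_actions]
        else:
            # Standard DQN: the target network both picks and evaluates the action.
            next_values = np.max(q_next_target, axis=1)
        # Zero out the bootstrap term for terminal transitions.
        return rewards + gamma * (1.0 - dones) * next_values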
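
In user code, `double_q` defaults to `True`, so existing scripts keep Double-Q learning enabled; passing `double_q=False` falls back to the standard DQN target. Below is a minimal usage sketch mirroring the CartPole example from docs/modules/dqn.rst. The `dueling` switch via `policy_kwargs` is the mechanism referenced by Issue #406 in the docs note and is an assumption here, not part of this patch.

.. code-block:: python

    import gym

    from stable_baselines.deepq.policies import MlpPolicy
    from stable_baselines import DQN

    env = gym.make('CartPole-v1')

    # Plain DQN: disable Double-Q learning via the new constructor argument and
    # (assumption, see Issue #406) disable dueling through policy_kwargs.
    model = DQN(MlpPolicy, env, double_q=False,
                policy_kwargs=dict(dueling=False), verbose=1)
    model.learn(total_timesteps=25000)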