Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DQN Atari examples #187

Merged
merged 19 commits into from
Aug 29, 2020
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ We select some of famous reinforcement learning platforms: 2 GitHub repos with m

All of the platforms use 5 different seeds for testing. We erase those trials which failed for training. The reward threshold is 195.0 in CartPole and -250.0 in Pendulum over consecutive 100 episodes' mean returns (except for PyTorch-DRL).

We will add results of Atari Pong / Mujoco these days.
The Atari/Mujoco benchmark results are under [examples/atari/](examples/atari/) and [examples/mujoco/](examples/mujoco/) folders.

### Reproducible

Expand Down
6 changes: 6 additions & 0 deletions docs/tutorials/trick.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ Tianshou has many short-but-efficient lines of code. For example, when we want t
.. Jiayi: I write each line of code after quite a lot of time of consideration. Details make a difference.


Atari/Mujoco Task Specific
--------------------------

Please refer to `Atari examples page <https://github.com/thu-ml/tianshou/tree/master/examples/atari>`_ and `Mujoco examples page <https://github.com/thu-ml/tianshou/tree/master/examples/mujoco>`_.


Finally
-------

Expand Down
25 changes: 25 additions & 0 deletions examples/atari/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Atari General

The sample speed is \~3000 env step per second (\~12000 Atari frame per second in fact since we use frame_stack=4) under the normal mode (use a CNN policy and a collector, also storing data into the buffer). The main bottleneck is training the convolutional neural network.

The Atari env seed cannot be fixed due to the discussion [here](https://github.com/openai/gym/issues/1478), but it is not a big issue since on Atari it will always have the similar results.

The env wrapper is a crucial thing. Without wrappers, the agent cannot perform well enough on Atari games. Many existing RL codebases use [OpenAI wrapper](https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py), but it is not the original DeepMind version ([related issue](https://github.com/openai/baselines/issues/240)). Dopamine has a different [wrapper](https://github.com/google/dopamine/blob/master/dopamine/discrete_domains/atari_lib.py) but unfortunately it cannot work very well in our codebase.

# DQN (single run)

One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.

| task | best reward | reward curve | parameters | time cost |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | ------------------- |
| PongNoFrameskip-v4 | 20 | ![](results/dqn/Pong_rew.png) | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch_size 64` | ~30 min (~15 epoch) |
| BreakoutNoFrameskip-v4 | 316 | ![](results/dqn/Breakout_rew.png) | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
| EnduroNoFrameskip-v4 | 670 | ![](results/dqn/Enduro_rew.png) | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4 " --test_num 100` | 3~4h (100 epoch) |
| QbertNoFrameskip-v4 | 7307 | ![](results/dqn/Qbert_rew.png) | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
| MsPacmanNoFrameskip-v4 | 2107 | ![](results/dqn/MsPacman_rew.png) | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
| SeaquestNoFrameskip-v4 | 2088 | ![](results/dqn/Seaquest_rew.png) | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
| SpaceInvadersNoFrameskip-v4 | 812.2 | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |

Note: The eps_train_final and eps_test in the original DQN paper is 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that smaller eps helps improve the performance. Also, a large batchsize (say 64 instead of 32) will help faster convergence but will slow down the training speed.

We haven't tuned this result to the best, so have fun with playing these hyperparameters!
147 changes: 147 additions & 0 deletions examples/atari/atari_dqn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import os
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.discrete import DQN
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer

from atari_wrapper import wrap_deepmind


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps_test', type=float, default=0.005)
parser.add_argument('--eps_train', type=float, default=1.)
parser.add_argument('--eps_train_final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--n_step', type=int, default=3)
parser.add_argument('--target_update_freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step_per_epoch', type=int, default=10000)
parser.add_argument('--collect_per_step', type=int, default=10)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--training_num', type=int, default=16)
parser.add_argument('--test_num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--frames_stack', type=int, default=4)
parser.add_argument('--resume_path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()


def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
episode_life=False, clip_rewards=False)


def test_dqn(args=get_args()):
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.env.action_space.shape or env.env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape: ", args.state_shape)
print("Actions shape: ", args.action_shape)
# make environments
train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
for _ in range(args.training_num)])
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
net = DQN(*args.state_shape,
args.action_shape, args.device).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = DQNPolicy(net, optim, args.gamma, args.n_step,
target_update_freq=args.target_update_freq)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path))
print("Loaded agent from: ", args.resume_path)
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM
buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
save_last_obs=True, stack_num=args.frames_stack)
# collector
train_collector = Collector(policy, train_envs, buffer)
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(x):
if env.env.spec.reward_threshold:
return x >= env.spec.reward_threshold
elif 'Pong' in args.task:
return x >= 20

def train_fn(x):
# nature DQN setting, linear decay in the first 1M steps
now = x * args.collect_per_step * args.step_per_epoch
if now <= 1e6:
eps = args.eps_train - now / 1e6 * \
(args.eps_train - args.eps_train_final)
policy.set_eps(eps)
else:
policy.set_eps(args.eps_train_final)
print("set eps =", policy.eps)

def test_fn(x):
policy.set_eps(args.eps_test)

# watch agent's performance
def watch():
print("Testing agent ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
pprint.pprint(result)

if args.watch:
watch()
exit(0)

# test train_collector and start filling replay buffer
train_collector.collect(n_step=args.batch_size * 4)
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False)

pprint.pprint(result)
watch()


if __name__ == '__main__':
test_dqn(get_args())
Loading