DQN Atari examples #187

Merged · 19 commits · Aug 29, 2020
11 changes: 11 additions & 0 deletions examples/atari/README.md
@@ -0,0 +1,11 @@
Use DQN to play Atari games:

| task | best reward | reward curve | command | time cost |
| ---- | ---- | ---- | ---- | ---- |
| PongNoFrameskip-v4 | 20 | ![](results/dqn/pong_rew.png) | `python3 atari_dqn.py` | ~30 min (15 epochs) |
| BreakoutNoFrameskip-v4 | 316 | ![](results/dqn/Breakout_rew.png) | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test_num 100` | 3~4 h (100 epochs) |
| EnduroNoFrameskip-v4 | 670 | ![](results/dqn/Enduro_rew.png) | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4" --test_num 100` | 3~4 h (100 epochs) |
| QbertNoFrameskip-v4 | 7307 | ![](results/dqn/Qbert_rew.png) | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test_num 100` | 3~4 h (100 epochs) |
| MsPacmanNoFrameskip-v4 | 2107 | ![](results/dqn/MsPacman_rew.png) | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test_num 100` | 3~4 h (100 epochs) |
| SeaquestNoFrameskip-v4 | 2088 | ![](results/dqn/Seaquest_rew.png) | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test_num 100` | 3~4 h (100 epochs) |
| SpaceInvadersNoFrameskip-v4 | 812.2 | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test_num 100` | 3~4 h (100 epochs) |
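
After training, the best checkpoint is saved to `<logdir>/<task>/dqn/policy.pth` (with the default `--logdir log`, e.g. `log/PongNoFrameskip-v4/dqn/policy.pth`). To replay a trained agent, pass `--watch` together with `--resume_path`, for example `python3 atari_dqn.py --task "PongNoFrameskip-v4" --watch --resume_path log/PongNoFrameskip-v4/dqn/policy.pth --render 0.03`; watch mode runs `--test_num` evaluation episodes and then exits.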
145 changes: 145 additions & 0 deletions examples/atari/atari_dqn.py
@@ -0,0 +1,145 @@
import os
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.discrete import DQN
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer

from atari_wrapper import wrap_deepmind


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps_test', type=float, default=0.005)
    parser.add_argument('--eps_train', type=float, default=1.)
    parser.add_argument('--eps_train_final', type=float, default=0.05)
    parser.add_argument('--buffer-size', type=int, default=100000)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--n_step', type=int, default=3)
    parser.add_argument('--target_update_freq', type=int, default=500)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step_per_epoch', type=int, default=10000)
    parser.add_argument('--collect_per_step', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--training_num', type=int, default=16)
    parser.add_argument('--test_num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--frames_stack', type=int, default=4)
    parser.add_argument('--resume_path', type=str, default=None)
    parser.add_argument('--watch', default=False, action='store_true',
                        help='watch the play of pre-trained policy only')
    return parser.parse_args()


def make_atari_env(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
    # keep the raw episode structure and rewards for evaluation
    return wrap_deepmind(args.task, frame_stack=args.frames_stack,
                         episode_life=False, clip_rewards=False)


def test_dqn(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape: ", args.state_shape)
    print("Actions shape: ", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
                                   for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
                                  for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = DQN(*args.state_shape,
              args.action_shape, args.device).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DQNPolicy(net, optim, args.gamma, args.n_step,
                       target_update_freq=args.target_update_freq)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path))
        print("Loaded agent from: ", args.resume_path)
    # collector
    train_collector = Collector(
        policy, train_envs,
        ReplayBuffer(args.buffer_size, ignore_obs_next=True))  # save memory
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(x):
        if env.spec.reward_threshold:
            return x >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return x >= 20
        return False

    def train_fn(x):
        # nature DQN setting, linear decay in the first 1M steps
        now = x * args.collect_per_step * args.step_per_epoch
        if now <= 1e6:
            eps = args.eps_train - now / 1e6 * \
                (args.eps_train - args.eps_train_final)
            policy.set_eps(eps)
        else:
            policy.set_eps(args.eps_train_final)
        print("set eps =", policy.eps)

    def test_fn(x):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Testing agent ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=args.render)
        pprint.pprint(result)

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * 4)
    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False)

    pprint.pprint(result)
    watch()


if __name__ == '__main__':
    test_dqn(get_args())
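
The exploration schedule in `train_fn` above follows the Nature DQN recipe: epsilon is annealed linearly from `eps_train` (1.0) to `eps_train_final` (0.05) over roughly the first 1M environment steps and then held constant. A standalone sketch of the same schedule (`linear_eps` is a hypothetical helper, not part of this PR):

```python
def linear_eps(env_step, eps_start=1.0, eps_final=0.05, decay_steps=1_000_000):
    # Linearly anneal epsilon over the first decay_steps steps, then hold it.
    if env_step <= decay_steps:
        return eps_start - (eps_start - eps_final) * env_step / decay_steps
    return eps_final


# e.g. linear_eps(0) == 1.0, linear_eps(500_000) == 0.525, linear_eps(2_000_000) == 0.05
```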
228 changes: 228 additions & 0 deletions examples/atari/atari_wrapper.py
@@ -0,0 +1,228 @@
import cv2
import gym
import numpy as np
from collections import deque


class NoopResetEnv(gym.Wrapper):
    """Sample initial states by taking a random number of no-ops on reset.
    No-op is assumed to be action 0.

    :param gym.Env env: the environment to wrap.
    :param int noop_max: the maximum value of no-ops to run.
    """

    def __init__(self, env, noop_max=30):
        super().__init__(env)
        self.noop_max = noop_max
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self):
        self.env.reset()
        noops = self.unwrapped.np_random.randint(self.noop_max) + 1
        for _ in range(noops):
            obs, _, done, _ = self.env.step(self.noop_action)
            if done:
                obs = self.env.reset()
        return obs


class MaxAndSkipEnv(gym.Wrapper):
    """Return only every `skip`-th frame (frameskipping) using the most recent
    raw observations (for max pooling across time steps).

    :param gym.Env env: the environment to wrap.
    :param int skip: number of frames to skip between returned observations.
    """

    def __init__(self, env, skip=4):
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        """Step the environment with the given action. Repeat action, sum
        reward, and max over the last observations.
        """
        obs_list, total_reward, done = [], 0., None
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            obs_list.append(obs)
            total_reward += reward
            if done:
                break
        max_frame = np.max(obs_list[-2:], axis=0)
        return max_frame, total_reward, done, info


class EpisodicLifeEnv(gym.Wrapper):
    """Make end-of-life == end-of-episode, but only reset on true game over.
    Done by DeepMind for DQN and co. since it helps value estimation.

    :param gym.Env env: the environment to wrap.
    """

    def __init__(self, env):
        super().__init__(env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal, then update lives
        # to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if 0 < lives < self.lives:
            # for Qbert we sometimes stay in the lives == 0 condition for a
            # few frames, so it's important to keep lives > 0, so that we only
            # reset once the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, info

    def reset(self):
        """Call the Gym environment reset only when lives are exhausted. This
        way all states are still reachable even though lives are episodic, and
        the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset()
        else:
            # no-op step to advance from the terminal/lost-life state
            obs, _, _, _ = self.env.step(0)
        self.lives = self.env.unwrapped.ale.lives()
        return obs


class FireResetEnv(gym.Wrapper):
    """Take action on reset for environments that are fixed until firing.

    :param gym.Env env: the environment to wrap.
    """

    def __init__(self, env):
        super().__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self):
        self.env.reset()
        for act in [1, 2]:
            obs, _, done, _ = self.env.step(act)
            if done:
                self.env.reset()
        return obs


class WarpFrame(gym.ObservationWrapper):
    """Warp frames to 84x84 as done in the Nature paper and later work.

    :param gym.Env env: the environment to wrap.
    """

    def __init__(self, env):
        super().__init__(env)
        self.size = 84
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(self.size, self.size),
            dtype=env.observation_space.dtype)

    def observation(self, frame):
        """Return the current observation from a frame."""
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        return cv2.resize(frame, (self.size, self.size),
                          interpolation=cv2.INTER_AREA)


class ScaledFloatFrame(gym.ObservationWrapper):
    """Normalize observations to 0~1.

    :param gym.Env env: the environment to wrap.
    """

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = gym.spaces.Box(
            low=0, high=1., shape=env.observation_space.shape,
            dtype=np.float32)

    def observation(self, observation):
        return np.array(observation, dtype=np.float32) / 255.


class ClipRewardEnv(gym.RewardWrapper):
    """Clip the reward to {+1, 0, -1} by its sign.

    :param gym.Env env: the environment to wrap.
    """

    def __init__(self, env):
        super().__init__(env)
        self.reward_range = (-1, 1)

    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)


class FrameStack(gym.Wrapper):
    """Stack the last n_frames frames.

    :param gym.Env env: the environment to wrap.
    :param int n_frames: the number of frames to stack.
    """

    def __init__(self, env, n_frames):
        super().__init__(env)
        self.n_frames = n_frames
        self.frames = deque([], maxlen=n_frames)
        shape = (n_frames,) + env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=shape, dtype=env.observation_space.dtype)

    def reset(self):
        obs = self.env.reset()
        for _ in range(self.n_frames):
            self.frames.append(obs)
        return self._get_ob()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.frames.append(obs)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        # the original wrapper uses `LazyFrames`, but since we copy into a
        # numpy buffer anyway, it has no effect here
        return np.stack(self.frames, axis=0)


def make_atari(env_id):
    """Create a wrapped Atari environment.

    :param str env_id: the environment ID.
    :return: the wrapped Atari environment.
    """
    env = gym.make(env_id)
    assert 'NoFrameskip' in env.spec.id
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    return env


def wrap_deepmind(env_id, episode_life=True, clip_rewards=True,
                  frame_stack=4, scale=False):
    """Configure an environment for DeepMind-style Atari.

    :param str env_id: the Atari environment id.
    :param bool episode_life: wrap the episodic life wrapper.
    :param bool clip_rewards: wrap the reward clipping wrapper.
    :param int frame_stack: wrap the frame stacking wrapper.
    :param bool scale: wrap the observation scaling wrapper.
    :return: the wrapped Atari environment.
    """
    env = make_atari(env_id)
    if episode_life:
        env = EpisodicLifeEnv(env)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = WarpFrame(env)
    if scale:
        env = ScaledFloatFrame(env)
    if clip_rewards:
        env = ClipRewardEnv(env)
    if frame_stack:
        env = FrameStack(env, frame_stack)
    return env
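
As a quick sanity check of the wrapper stack, `wrap_deepmind` returns stacked 84x84 grayscale frames with clipped rewards by default. A minimal usage sketch (assumes `gym` with the Atari ROMs installed):

```python
from atari_wrapper import wrap_deepmind

env = wrap_deepmind('PongNoFrameskip-v4', frame_stack=4)
obs = env.reset()
print(obs.shape, obs.dtype)  # (4, 84, 84), uint8 stacked frames
obs, rew, done, info = env.step(env.action_space.sample())
print(rew)  # rewards are clipped to {-1, 0, +1}
```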
Binary file added examples/atari/results/dqn/Breakout_rew.png
Binary file added examples/atari/results/dqn/Enduro_rew.png
Binary file added examples/atari/results/dqn/MsPacman_rew.png
Binary file added examples/atari/results/dqn/Qbert_rew.png
Binary file added examples/atari/results/dqn/Seaquest_rew.png
Binary file added examples/atari/results/dqn/SpaceInvader_rew.png
Binary file added examples/atari/results/dqn/pong_rew.png