Fix Atari module by adding SimpleMonitor back in #255

Closed
wants to merge 2 commits
71 changes: 71 additions & 0 deletions baselines/common/misc_util.py
@@ -4,6 +4,7 @@
import pickle
import random
import tempfile
import time
import zipfile


@@ -152,6 +153,76 @@ def __float__(self):
"""Get the current estimate"""
return self._value


class SimpleMonitor(gym.Wrapper):
    def __init__(self, env):
        """Adds two quantities to the info returned by every step:

        num_steps: int
            Number of steps taken so far
        rewards: [float]
            All the cumulative rewards for the episodes completed so far.
        """
        super().__init__(env)
        # current episode state
        self._current_reward = None
        self._num_steps = None
        # temporary monitor state that we do not save
        self._time_offset = None
        self._total_steps = None
        # monitor state
        self._episode_rewards = []
        self._episode_lengths = []
        self._episode_end_times = []

    def _reset(self):
        obs = self.env.reset()
        # recompute temporary state if needed
        if self._time_offset is None:
            self._time_offset = time.time()
            if len(self._episode_end_times) > 0:
                self._time_offset -= self._episode_end_times[-1]
        if self._total_steps is None:
            self._total_steps = sum(self._episode_lengths)
        # update monitor state
        if self._current_reward is not None:
            self._episode_rewards.append(self._current_reward)
            self._episode_lengths.append(self._num_steps)
            self._episode_end_times.append(time.time() - self._time_offset)
        # reset episode state
        self._current_reward = 0
        self._num_steps = 0

        return obs

    def _step(self, action):
        obs, rew, done, info = self.env.step(action)
        self._current_reward += rew
        self._num_steps += 1
        self._total_steps += 1
        info['steps'] = self._total_steps
        info['rewards'] = self._episode_rewards
        return (obs, rew, done, info)

    def get_state(self):
        return {
            'env_id': self.env.unwrapped.spec.id,
            'episode_data': {
                'episode_rewards': self._episode_rewards,
                'episode_lengths': self._episode_lengths,
                'episode_end_times': self._episode_end_times,
                'initial_reset_time': 0,
            }
        }

    def set_state(self, state):
        assert state['env_id'] == self.env.unwrapped.spec.id
        ed = state['episode_data']
        self._episode_rewards = ed['episode_rewards']
        self._episode_lengths = ed['episode_lengths']
        self._episode_end_times = ed['episode_end_times']


def boolean_flag(parser, name, default=False, help=None):
"""Add a boolean flag to argparse parser.

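For context (not part of this diff): a minimal usage sketch of the restored wrapper, assuming the old gym _reset/_step wrapper API that this code targets; the environment id and random-action loop are only examples.

import gym
from baselines.common.misc_util import SimpleMonitor

# Sketch only: wrap an environment directly and read the two fields that
# SimpleMonitor adds to info on every step.
env = SimpleMonitor(gym.make("PongNoFrameskip-v4"))
obs = env.reset()
for _ in range(1000):
    obs, rew, done, info = env.step(env.action_space.sample())
    # info['steps']   -> total steps taken across all episodes so far
    # info['rewards'] -> returns of the episodes completed so far
    if done:
        obs = env.reset()

# The monitor state can be checkpointed alongside the learner and restored
# later with set_state().
state = env.get_state()
env.set_state(state)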
2 changes: 2 additions & 0 deletions baselines/deepq/experiments/atari/enjoy.py
@@ -10,6 +10,7 @@
from baselines import deepq
from baselines.common.misc_util import (
    boolean_flag,
    SimpleMonitor,
)
from baselines import bench
from baselines.common.atari_wrappers_deprecated import wrap_dqn
@@ -31,6 +32,7 @@ def parse_args():
def make_env(game_name):
    env = gym.make(game_name + "NoFrameskip-v4")
    env = bench.Monitor(env, None)
    env = SimpleMonitor(env)
    env = wrap_dqn(env)
    return env

2 changes: 2 additions & 0 deletions baselines/deepq/experiments/atari/train.py
@@ -19,6 +19,7 @@
    relatively_safe_pickle_dump,
    set_global_seeds,
    RunningAvg,
    SimpleMonitor,
)
from baselines.common.schedules import LinearSchedule, PiecewiseSchedule
from baselines import bench
@@ -63,6 +64,7 @@ def parse_args():
def make_env(game_name):
    env = gym.make(game_name + "NoFrameskip-v4")
    monitored_env = bench.Monitor(env, logger.get_dir())  # puts rewards and number of steps in info, before environment is wrapped
    monitored_env = SimpleMonitor(monitored_env)
    env = wrap_dqn(monitored_env)  # applies a bunch of modifications to simplify the observation space (downsample, make b/w)
    return env, monitored_env

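A sketch (assumptions, not taken from this PR) of why make_env returns both handles: the agent interacts with the fully wrapped env, while the SimpleMonitor underneath is queried for logging and checkpointing. Everything other than make_env itself is illustrative.

# Sketch only: hypothetical training-loop excerpt using the two return values
# of make_env; the random-action loop stands in for the real learner.
env, monitored_env = make_env("Pong")
obs = env.reset()
for t in range(10000):
    obs, rew, done, info = env.step(env.action_space.sample())
    if done:
        obs = env.reset()
    # wrap_dqn passes info through, so the monitor's counters stay visible here
    if t % 1000 == 0 and len(info["rewards"]) > 0:
        print("episodes:", len(info["rewards"]), "last return:", info["rewards"][-1])

# SimpleMonitor state can be saved with the model so episode statistics
# survive a restart.
monitor_state = monitored_env.get_state()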
3 changes: 2 additions & 1 deletion baselines/deepq/experiments/atari/wang2015_eval.py
@@ -6,14 +6,15 @@
import baselines.common.tf_util as U

from baselines import deepq, bench
from baselines.common.misc_util import get_wrapper_by_name, boolean_flag, set_global_seeds
from baselines.common.misc_util import get_wrapper_by_name, SimpleMonitor, boolean_flag, set_global_seeds
from baselines.common.atari_wrappers_deprecated import wrap_dqn
from baselines.deepq.experiments.atari.model import model, dueling_model


def make_env(game_name):
    env = gym.make(game_name + "NoFrameskip-v4")
    env_monitored = bench.Monitor(env, None)
    env_monitored = SimpleMonitor(env_monitored)
    env = wrap_dqn(env_monitored)
    return env_monitored, env

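Since wang2015_eval.py also imports get_wrapper_by_name, the monitor could be located from the fully wrapped env instead of threading the second return value around; a sketch, assuming get_wrapper_by_name(env, classname) unwraps the stack until a wrapper's class name matches.

# Sketch only: find the SimpleMonitor inside the wrapped stack by class name.
monitored_env, env = make_env("Breakout")
monitor = get_wrapper_by_name(env, "SimpleMonitor")
assert monitor is monitored_env  # same object, reached through the wrappers
print(monitor.get_state()["episode_data"]["episode_rewards"])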