diff --git a/baselines/common/misc_util.py b/baselines/common/misc_util.py
index a8d79779ef..4e45ce7871 100644
--- a/baselines/common/misc_util.py
+++ b/baselines/common/misc_util.py
@@ -4,6 +4,7 @@
 import pickle
 import random
 import tempfile
+import time
 import zipfile
 
@@ -152,6 +153,76 @@ def __float__(self):
         """Get the current estimate"""
         return self._value
 
+
+class SimpleMonitor(gym.Wrapper):
+    def __init__(self, env):
+        """Adds two quantities to info returned by every step:
+
+        steps: int
+            Number of steps taken so far (across all episodes)
+        rewards: [float]
+            All the cumulative rewards for the episodes completed so far.
+        """
+        super().__init__(env)
+        # current episode state
+        self._current_reward = None
+        self._num_steps = None
+        # temporary monitor state that we do not save
+        self._time_offset = None
+        self._total_steps = None
+        # monitor state
+        self._episode_rewards = []
+        self._episode_lengths = []
+        self._episode_end_times = []
+
+    def _reset(self):
+        obs = self.env.reset()
+        # recompute temporary state if needed
+        if self._time_offset is None:
+            self._time_offset = time.time()
+            if len(self._episode_end_times) > 0:
+                self._time_offset -= self._episode_end_times[-1]
+        if self._total_steps is None:
+            self._total_steps = sum(self._episode_lengths)
+        # update monitor state
+        if self._current_reward is not None:
+            self._episode_rewards.append(self._current_reward)
+            self._episode_lengths.append(self._num_steps)
+            self._episode_end_times.append(time.time() - self._time_offset)
+        # reset episode state
+        self._current_reward = 0
+        self._num_steps = 0
+
+        return obs
+
+    def _step(self, action):
+        obs, rew, done, info = self.env.step(action)
+        self._current_reward += rew
+        self._num_steps += 1
+        self._total_steps += 1
+        info['steps'] = self._total_steps
+        info['rewards'] = self._episode_rewards
+        return (obs, rew, done, info)
+
+    def get_state(self):
+        return {
+            'env_id': self.env.unwrapped.spec.id,
+            'episode_data': {
+                'episode_rewards': self._episode_rewards,
+                'episode_lengths': self._episode_lengths,
+                'episode_end_times': self._episode_end_times,
+                'initial_reset_time': 0,
+            }
+        }
+
+    def set_state(self, state):
+        assert state['env_id'] == self.env.unwrapped.spec.id
+        ed = state['episode_data']
+        self._episode_rewards = ed['episode_rewards']
+        self._episode_lengths = ed['episode_lengths']
+        self._episode_end_times = ed['episode_end_times']
+
+
 def boolean_flag(parser, name, default=False, help=None):
     """Add a boolean flag to argparse parser.
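For context, here is a minimal sketch of how the wrapper added above is used; the environment id ("PongNoFrameskip-v4") and the random-action loop are illustrative assumptions, not part of this change:

```python
import gym

from baselines.common.misc_util import SimpleMonitor

# Hypothetical usage: any registered NoFrameskip Atari env works the same way.
env = SimpleMonitor(gym.make("PongNoFrameskip-v4"))
obs = env.reset()
for _ in range(1000):
    obs, rew, done, info = env.step(env.action_space.sample())
    if done:
        # info['steps'] counts steps across all episodes; info['rewards']
        # gains one entry per completed episode (appended on the next reset()).
        print(info['steps'], info['rewards'])
        obs = env.reset()
```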
diff --git a/baselines/deepq/experiments/atari/enjoy.py b/baselines/deepq/experiments/atari/enjoy.py
index db2b70f751..99378e703f 100644
--- a/baselines/deepq/experiments/atari/enjoy.py
+++ b/baselines/deepq/experiments/atari/enjoy.py
@@ -10,6 +10,7 @@
 from baselines import deepq
 from baselines.common.misc_util import (
     boolean_flag,
+    SimpleMonitor,
 )
 from baselines import bench
 from baselines.common.atari_wrappers_deprecated import wrap_dqn
@@ -31,6 +32,7 @@ def parse_args():
 
 def make_env(game_name):
     env = gym.make(game_name + "NoFrameskip-v4")
     env = bench.Monitor(env, None)
+    env = SimpleMonitor(env)
     env = wrap_dqn(env)
     return env
 
diff --git a/baselines/deepq/experiments/atari/train.py b/baselines/deepq/experiments/atari/train.py
index 30188c815d..2ada812a3c 100644
--- a/baselines/deepq/experiments/atari/train.py
+++ b/baselines/deepq/experiments/atari/train.py
@@ -19,6 +19,7 @@
     relatively_safe_pickle_dump,
     set_global_seeds,
     RunningAvg,
+    SimpleMonitor,
 )
 from baselines.common.schedules import LinearSchedule, PiecewiseSchedule
 from baselines import bench
@@ -63,6 +64,7 @@ def parse_args():
 
 def make_env(game_name):
     env = gym.make(game_name + "NoFrameskip-v4")
     monitored_env = bench.Monitor(env, logger.get_dir())  # puts rewards and number of steps in info, before environment is wrapped
+    monitored_env = SimpleMonitor(monitored_env)
     env = wrap_dqn(monitored_env)  # applies a bunch of modification to simplify the observation space (downsample, make b/w)
     return env, monitored_env
 
diff --git a/baselines/deepq/experiments/atari/wang2015_eval.py b/baselines/deepq/experiments/atari/wang2015_eval.py
index 42b8ba8f87..6b01842963 100644
--- a/baselines/deepq/experiments/atari/wang2015_eval.py
+++ b/baselines/deepq/experiments/atari/wang2015_eval.py
@@ -6,7 +6,7 @@
 import baselines.common.tf_util as U
 
 from baselines import deepq, bench
-from baselines.common.misc_util import get_wrapper_by_name, boolean_flag, set_global_seeds
+from baselines.common.misc_util import get_wrapper_by_name, SimpleMonitor, boolean_flag, set_global_seeds
 from baselines.common.atari_wrappers_deprecated import wrap_dqn
 from baselines.deepq.experiments.atari.model import model, dueling_model
 
@@ -14,6 +14,7 @@
 def make_env(game_name):
     env = gym.make(game_name + "NoFrameskip-v4")
     env_monitored = bench.Monitor(env, None)
+    env_monitored = SimpleMonitor(env_monitored)
     env = wrap_dqn(env_monitored)
     return env_monitored, env
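The get_state()/set_state() pair on SimpleMonitor exists so monitor statistics can survive a restart. A minimal sketch of how a training script might persist them; the helper names and pickle path are hypothetical, not part of this diff:

```python
import pickle

def save_monitor_state(monitor, path):
    # Hypothetical helper: dump episode rewards/lengths/end times plus env id.
    with open(path, "wb") as f:
        pickle.dump(monitor.get_state(), f)

def restore_monitor_state(monitor, path):
    # set_state() asserts the saved env id matches the wrapped env,
    # so a checkpoint from one game cannot be restored onto another.
    with open(path, "rb") as f:
        monitor.set_state(pickle.load(f))
```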