diff --git a/ctoybox/baselines/baselines/deepq/utils.py b/ctoybox/baselines/baselines/deepq/utils.py index 4dae7a6b..0fb15695 100644 --- a/ctoybox/baselines/baselines/deepq/utils.py +++ b/ctoybox/baselines/baselines/deepq/utils.py @@ -1,8 +1,6 @@ from baselines.common.input import observation_input from baselines.common.tf_util import adjust_shape -import tensorflow as tf - # ================================================================ # Placeholders # ================================================================ @@ -40,29 +38,6 @@ def make_feed_dict(self, data): return {self._placeholder: adjust_shape(self._placeholder, data)} -class Uint8Input(PlaceholderTfInput): - def __init__(self, shape, name=None): - """Takes input in uint8 format which is cast to float32 and divided by 255 - before passing it to the model. - - On GPU this ensures lower data transfer times. - - Parameters - ---------- - shape: [int] - shape of the tensor. - name: str - name of the underlying placeholder - """ - - super().__init__(tf.placeholder(tf.uint8, [None] + list(shape), name=name)) - self._shape = shape - self._output = tf.cast(super().get(), tf.float32) / 255.0 - - def get(self): - return self._output - - class ObservationInput(PlaceholderTfInput): def __init__(self, observation_space, name=None): """Creates an input placeholder tailored to a specific observation space diff --git a/ctoybox/baselines/baselines/run.py b/ctoybox/baselines/baselines/run.py index 0e7be629..881b8b35 100644 --- a/ctoybox/baselines/baselines/run.py +++ b/ctoybox/baselines/baselines/run.py @@ -111,7 +111,7 @@ def build_env(args, extra_args): env = atari_wrappers.make_atari(env_id, None) env.seed(seed) env = bench.Monitor(env, logger.get_dir()) - env = atari_wrappers.wrap_deepmind(env, frame_stack=True, scale=True) + env = atari_wrappers.wrap_deepmind(env, frame_stack=True) elif alg == 'trpo_mpi': env = atari_wrappers.make_atari(env_id, None) env.seed(seed)