Skip to content

Commit

Permalink
attempt to fix DQN bug; copied commit 4121d9c from baselines repo
Browse files Browse the repository at this point in the history
  • Loading branch information
kclary committed Jan 13, 2019
1 parent e00c798 commit 825b40f
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 26 deletions.
25 changes: 0 additions & 25 deletions ctoybox/baselines/baselines/deepq/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from baselines.common.input import observation_input
from baselines.common.tf_util import adjust_shape

import tensorflow as tf

# ================================================================
# Placeholders
# ================================================================
Expand Down Expand Up @@ -40,29 +38,6 @@ def make_feed_dict(self, data):
return {self._placeholder: adjust_shape(self._placeholder, data)}


class Uint8Input(PlaceholderTfInput):
def __init__(self, shape, name=None):
"""Takes input in uint8 format which is cast to float32 and divided by 255
before passing it to the model.
On GPU this ensures lower data transfer times.
Parameters
----------
shape: [int]
shape of the tensor.
name: str
name of the underlying placeholder
"""

super().__init__(tf.placeholder(tf.uint8, [None] + list(shape), name=name))
self._shape = shape
self._output = tf.cast(super().get(), tf.float32) / 255.0

def get(self):
return self._output


class ObservationInput(PlaceholderTfInput):
def __init__(self, observation_space, name=None):
"""Creates an input placeholder tailored to a specific observation space
Expand Down
2 changes: 1 addition & 1 deletion ctoybox/baselines/baselines/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def build_env(args, extra_args):
env = atari_wrappers.make_atari(env_id, None)
env.seed(seed)
env = bench.Monitor(env, logger.get_dir())
env = atari_wrappers.wrap_deepmind(env, frame_stack=True, scale=True)
env = atari_wrappers.wrap_deepmind(env, frame_stack=True)
elif alg == 'trpo_mpi':
env = atari_wrappers.make_atari(env_id, None)
env.seed(seed)
Expand Down

1 comment on commit 825b40f

@kclary
Copy link
Collaborator Author

@kclary kclary commented on 825b40f Jan 13, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Github thread with discussion of error and fix: openai/baselines#431

Please sign in to comment.