Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check Action Type #712

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added Tests, changed changlog
MentalGear committed Feb 29, 2020
commit 5ed28e5496e0f67d4549610aff83316a20336561
2 changes: 1 addition & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
@@ -57,7 +57,7 @@ Others:
- Removed redundant return value from ``a2c.utils::total_episode_reward_logger``. (@shwang)
- Cleanup and refactoring in ``common/identity_env.py`` (@shwang)
- Added a Makefile to simplify common development tasks (build the doc, type check, run the tests)
- Action Type Check: Assertion added to ``VecEnv` to check if action is of type list or np.ndarray. Otherwise a developer friendly message is displayed on how to fix the issue.
- Action Type Check: Assertion added to ``VecEnv` to check if action is of type list or np.ndarray. Otherwise a developer friendly message is displayed on how to fix the issue. (@mentalgear)
Documentation:
^^^^^^^^^^^^^^
3 changes: 2 additions & 1 deletion stable_baselines/common/vec_env/base_vec_env.py
Original file line number Diff line number Diff line change
@@ -146,7 +146,8 @@ def step(self, actions):
:param actions: ([int] or [float]) the action
:return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
"""
assert isinstance( actions,( list, np.ndarray ) ), "Action must be of type list or np.ndarray. Try wrapping your action variable in a list [ ] to fix this issue."
if not isinstance( actions,( list, np.ndarray ) ):
raise TypeError( "Action must be of type list or np.ndarray. Try wrapping your action variable in a list [ ] to fix this issue." )
self.step_async(actions)
return self.step_wait()

17 changes: 17 additions & 0 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
@@ -127,6 +127,20 @@ def wrong_step(_action):
check_env(env)


def test_action_format ( env, action ):
"""
Helper to check that the error is caught.
:param env: (gym.Env)
:param new_step_return: (tuple)
"""

with pytest.raises(TypeError):
step( 0 )

with pytest.raises(not TypeError):
step( [ 0 ] )


def test_common_failures_step():
"""
Test that common failure cases of the `step` method are caught
@@ -147,3 +161,6 @@ def test_common_failures_step():
# Done is not a boolean
check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, {}))
check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, {}))

# action format must be [] or np.ndarray
test_action_format ( env, 0 )