Skip to content

Commit

Permalink
Use benchmark_environments test code (#16)
Browse files Browse the repository at this point in the history
* Use benchmark_environments test code

* Force CI to regenerate cache

* Fix for Gym breaking change: openai/gym#1802
  • Loading branch information
AdamGleave authored Feb 11, 2020
1 parent 59c2f25 commit faa59d0
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 10 deletions.
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
gym[mujoco]
benchmark-environments @ git+https://github.com/HumanCompatibleAI/benchmark-environments.git
imitation @ git+https://github.com/HumanCompatibleAI/imitation.git@e99844
stable-baselines @ git+https://github.com/hill-a/stable-baselines.git
gym[mujoco]
matplotlib
numpy
pandas
pymdptoolbox
seaborn
setuptools
scipy
stable-baselines @ git+https://github.com/hill-a/stable-baselines.git
tensorflow>=1.13,<1.14
xarray
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def load_requirements(fname):
package_data={"evaluating_rewards": ["py.typed"]},
install_requires=load_requirements("requirements.txt"),
extras_require={"test": load_requirements("requirements-dev.txt")},
url="https://github.com/AdamGleave/evaluating_rewards",
url="https://github.com/HumanCompatibleAI/evaluating_rewards",
license="Apache License, Version 2.0",
classifiers=[
# Trove classifiers
Expand Down
2 changes: 1 addition & 1 deletion src/evaluating_rewards/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def register_similar(existing_name: str, new_name: str, **kwargs_delta):
**kwargs_delta: Arguments to override the specification from existing_id.
"""
existing_spec = gym.spec(existing_name)
fields = ["entry_point", "reward_threshold", "nondeterministic", "tags", "max_episode_steps"]
fields = ["entry_point", "reward_threshold", "nondeterministic", "max_episode_steps"]
kwargs = {k: getattr(existing_spec, k) for k in fields}
kwargs["kwargs"] = existing_spec._kwargs # pylint:disable=protected-access
kwargs.update(**kwargs_delta)
Expand Down
2 changes: 1 addition & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import copy
from typing import Dict, Iterator, Tuple, TypeVar

from imitation.testing import envs as test_envs
from benchmark_environments.testing import envs as test_envs
import pytest
from stable_baselines.common import vec_env

Expand Down
17 changes: 12 additions & 5 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
"evaluating_rewards/".
"""

from benchmark_environments.testing import envs as bench_test
import gym
from imitation.testing import envs as test_envs
from imitation.testing import envs as imitation_test
import pytest

from evaluating_rewards import envs # noqa: F401 pylint:disable=unused-import
Expand All @@ -38,16 +39,22 @@ class TestEnvs:
"""Simple smoke tests for custom environments."""

def test_seed(self, env, env_name):
test_envs.test_seed(env, env_name, DETERMINISTIC_ENVS)
bench_test.test_seed(env, env_name, DETERMINISTIC_ENVS)

def test_rollout(self, env):
test_envs.test_rollout(env)
def test_premature_step(self, env):
"""Test that you must call reset() before calling step()."""
bench_test.test_premature_step(
env, skip_fn=pytest.skip, raises_fn=pytest.raises,
)

def test_rollout_schema(self, env):
bench_test.test_rollout_schema(env)

def test_model_based(self, env):
"""Smoke test for each of the ModelBasedEnv methods with type checks."""
if not hasattr(env, "state_space"): # pragma: no cover
pytest.skip("This test is only for subclasses of ModelBasedEnv.")
test_envs.test_model_based(env)
imitation_test.test_model_based(env)


# pylint:enable=no-self-use

0 comments on commit faa59d0

Please sign in to comment.