diff --git a/requirements.txt b/requirements.txt
index 59b67d7..122d0b2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,7 @@
-gym[mujoco]
+benchmark-environments @ git+https://github.com/HumanCompatibleAI/benchmark-environments.git
 imitation @ git+https://github.com/HumanCompatibleAI/imitation.git@e99844
+stable-baselines @ git+https://github.com/hill-a/stable-baselines.git
+gym[mujoco]
 matplotlib
 numpy
 pandas
@@ -7,6 +9,5 @@ pymdptoolbox
 seaborn
 setuptools
 scipy
-stable-baselines @ git+https://github.com/hill-a/stable-baselines.git
 tensorflow>=1.13,<1.14
 xarray
diff --git a/setup.py b/setup.py
index 8e97310..df11f61 100644
--- a/setup.py
+++ b/setup.py
@@ -34,7 +34,7 @@ def load_requirements(fname):
     package_data={"evaluating_rewards": ["py.typed"]},
     install_requires=load_requirements("requirements.txt"),
     extras_require={"test": load_requirements("requirements-dev.txt")},
-    url="https://github.com/AdamGleave/evaluating_rewards",
+    url="https://github.com/HumanCompatibleAI/evaluating_rewards",
     license="Apache License, Version 2.0",
     classifiers=[
         # Trove classifiers
diff --git a/src/evaluating_rewards/envs/__init__.py b/src/evaluating_rewards/envs/__init__.py
index 18188fc..3134c8a 100644
--- a/src/evaluating_rewards/envs/__init__.py
+++ b/src/evaluating_rewards/envs/__init__.py
@@ -48,7 +48,7 @@ def register_similar(existing_name: str, new_name: str, **kwargs_delta):
         **kwargs_delta: Arguments to override the specification from existing_id.
     """
     existing_spec = gym.spec(existing_name)
-    fields = ["entry_point", "reward_threshold", "nondeterministic", "tags", "max_episode_steps"]
+    fields = ["entry_point", "reward_threshold", "nondeterministic", "max_episode_steps"]
     kwargs = {k: getattr(existing_spec, k) for k in fields}
     kwargs["kwargs"] = existing_spec._kwargs  # pylint:disable=protected-access
     kwargs.update(**kwargs_delta)
diff --git a/tests/common.py b/tests/common.py
index 8e53f7d..fdd20da 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -17,7 +17,7 @@
 import copy
 from typing import Dict, Iterator, Tuple, TypeVar
 
-from imitation.testing import envs as test_envs
+from benchmark_environments.testing import envs as test_envs
 import pytest
 from stable_baselines.common import vec_env
 
diff --git a/tests/test_envs.py b/tests/test_envs.py
index 768004e..afd312d 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -18,8 +18,9 @@
 "evaluating_rewards/".
 """
 
+from benchmark_environments.testing import envs as bench_test
 import gym
-from imitation.testing import envs as test_envs
+from imitation.testing import envs as imitation_test
 import pytest
 
 from evaluating_rewards import envs  # noqa: F401 pylint:disable=unused-import
@@ -38,16 +39,22 @@ class TestEnvs:
     """Simple smoke tests for custom environments."""
 
     def test_seed(self, env, env_name):
-        test_envs.test_seed(env, env_name, DETERMINISTIC_ENVS)
+        bench_test.test_seed(env, env_name, DETERMINISTIC_ENVS)
 
-    def test_rollout(self, env):
-        test_envs.test_rollout(env)
+    def test_premature_step(self, env):
+        """Test that you must call reset() before calling step()."""
+        bench_test.test_premature_step(
+            env, skip_fn=pytest.skip, raises_fn=pytest.raises,
+        )
+
+    def test_rollout_schema(self, env):
+        bench_test.test_rollout_schema(env)
 
     def test_model_based(self, env):
         """Smoke test for each of the ModelBasedEnv methods with type checks."""
         if not hasattr(env, "state_space"):  # pragma: no cover
             pytest.skip("This test is only for subclasses of ModelBasedEnv.")
-        test_envs.test_model_based(env)
+        imitation_test.test_model_based(env)
 
 
 # pylint:enable=no-self-use