Skip to content

Commit

Permalink
Removed hard-coded check to CartPole Space shape
Browse files (browse the repository at this point in the history)
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Aug 3, 2021
1 parent 2921a2c commit b274060
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions gym/vector/tests/test_batched_vector_env.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
from gym import spaces
from gym.vector.batched_vector_env import BatchedVectorEnv
from gym.vector.utils import batch_space

N_CPUS = cpu_count()

Expand Down Expand Up @@ -72,10 +73,11 @@ def reset(self):

@pytest.mark.parametrize("batch_size", [1, 5, 11, 24])
@pytest.mark.parametrize("n_workers", [1, 3, None])
def test_space_with_tuple_observations(batch_size: int, n_workers: Optional[int]):
@pytest.mark.parametrize("base_env", ["CartPole-v0", "Pendulum-v0"])
def test_space_with_tuple_observations(batch_size: int, n_workers: Optional[int], base_env: str):

def make_env():
env = gym.make("CartPole-v0")
env = gym.make(base_env)
env = TupleObservationsWrapper(env, spaces.Discrete(1))
return env

Expand All @@ -84,17 +86,27 @@ def make_env():
env = BatchedVectorEnv(env_fns, n_workers=n_workers)
env.seed(123)

assert env.single_observation_space[0].shape == (4,)
assert env.single_observation_space[1] == spaces.Discrete(1)

assert env.observation_space[0].shape == (batch_size, 4)
assert env.observation_space[1] == spaces.MultiDiscrete(np.ones(batch_size))
single_env = make_env()
assert env.single_observation_space == single_env.observation_space
assert env.single_action_space == single_env.action_space
assert env.observation_space == batch_space(single_env.observation_space, batch_size)
# NOTE: VectorEnvs currently have a tuple of actions as the action space,
# rather than using the result of the batch_space function.
# assert env.action_space == batch_space(single_env.action_space, batch_size)

single_env_obs = single_env.reset()
obs = env.reset()
assert obs[0].shape == env.observation_space[0].shape
assert obs[1].shape == env.observation_space[1].shape
assert isinstance(obs, tuple)
assert len(obs) == len(single_env_obs)

for i, (obs_item, single_obs_item) in enumerate(zip(obs, single_env_obs)):
assert obs_item.shape == (batch_size, *np.asanyarray(single_obs_item).shape)

assert obs in env.observation_space
assert single_env_obs in env.single_observation_space

single_action = single_env.action_space.sample()
assert single_action in env.single_action_space
actions = env.action_space.sample()
step_obs, rewards, done, info = env.step(actions)
assert step_obs in env.observation_space
Expand Down Expand Up @@ -187,14 +199,14 @@ def step(self, action):
reward, done = 0., False
return (observation, reward, done, {})


@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("n_workers", [1, 4])
def test_dict_observations(batch_size: int, n_workers: int):
def make_env(seed):
def _make_env():
return DictEnv()
return _make_env
from gym.vector.utils import batch_space

single_env = make_env(0)()

Expand Down

0 comments on commit b274060

Please sign in to comment.