From 8b24de60275061d6998e0714c7b4d15ad559b598 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 20 Nov 2020 14:28:38 -0500 Subject: [PATCH] Bugfix: Allow nesting of Sync/Async VectorEnvs Signed-off-by: Fabrice Normandin --- gym/vector/async_vector_env.py | 37 +++++++++---- gym/vector/sync_vector_env.py | 22 +++++--- gym/vector/tests/test_vector_env.py | 82 ++++++++++++++++++++++++++++- 3 files changed, 122 insertions(+), 19 deletions(-) diff --git a/gym/vector/async_vector_env.py b/gym/vector/async_vector_env.py index b84fc2087d9..82b07b78a92 100644 --- a/gym/vector/async_vector_env.py +++ b/gym/vector/async_vector_env.py @@ -403,16 +403,24 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): assert shared_memory is None env = env_fn() parent_pipe.close() + + def step_fn(actions): + observation, reward, done, info = env.step(actions) + # Do nothing if the env is a VectorEnv, since it will automatically + # reset the envs that are done if needed in the 'step' method and return + # the initial observation instead of the final observation. + if not isinstance(env, VectorEnv) and done: + observation = env.reset() + return observation, reward, done, info + try: while True: command, data = pipe.recv() if command == "reset": observation = env.reset() pipe.send((observation, True)) - elif command == "step": - observation, reward, done, info = env.step(data) - if done: - observation = env.reset() + elif command == 'step': + observation, reward, done, info = step_fn(data) pipe.send(((observation, reward, done, info), True)) elif command == "seed": env.seed(data) @@ -440,6 +448,16 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error env = env_fn() observation_space = env.observation_space parent_pipe.close() + + def step_fn(actions): + observation, reward, done, info = env.step(actions) + # Do nothing if the env is a VectorEnv, since it will automatically + # reset the envs that are done if needed in the 'step' method and return + # the initial observation instead of the final observation. + if not isinstance(env, VectorEnv) and done: + observation = env.reset() + return observation, reward, done, info + try: while True: command, data = pipe.recv() @@ -449,13 +467,10 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error index, observation, shared_memory, observation_space ) pipe.send((None, True)) - elif command == "step": - observation, reward, done, info = env.step(data) - if done: - observation = env.reset() - write_to_shared_memory( - index, observation, shared_memory, observation_space - ) + elif command == 'step': + observation, reward, done, info = step_fn(data) + write_to_shared_memory(index, observation, shared_memory, + observation_space) pipe.send(((None, reward, done, info), True)) elif command == "seed": env.seed(data) diff --git a/gym/vector/sync_vector_env.py b/gym/vector/sync_vector_env.py index 49369073a72..699c87deb0d 100644 --- a/gym/vector/sync_vector_env.py +++ b/gym/vector/sync_vector_env.py @@ -44,11 +44,14 @@ def __init__(self, env_fns, observation_space=None, action_space=None, copy=True ) self._check_observation_spaces() - self.observations = create_empty_array( - self.single_observation_space, n=self.num_envs, fn=np.zeros - ) - self._rewards = np.zeros((self.num_envs,), dtype=np.float64) - self._dones = np.zeros((self.num_envs,), dtype=np.bool_) + self.observations = create_empty_array(self.single_observation_space, + n=self.num_envs, fn=np.zeros) + + shape = (self.num_envs,) + if isinstance(self.envs[0].unwrapped, VectorEnv): + shape += (self.envs[0].num_envs,) + self._rewards = np.zeros(shape, dtype=np.float64) + self._dones = np.zeros(shape, dtype=np.bool_) self._actions = None def seed(self, seeds=None): @@ -67,6 +70,7 @@ def reset_wait(self): for env in self.envs: observation = env.reset() observations.append(observation) + self.observations = concatenate( observations, self.observations, self.single_observation_space ) @@ -75,12 +79,15 @@ def reset_wait(self): def step_async(self, actions): self._actions = actions - + def step_wait(self): observations, infos = [], [] for i, (env, action) in enumerate(zip(self.envs, self._actions)): observation, self._rewards[i], self._dones[i], info = env.step(action) - if self._dones[i]: + # Do nothing if the env is a VectorEnv, since it will automatically + # reset the envs that are done if needed in the 'step' method and + # return the initial observation instead of the final observation. + if not isinstance(env, VectorEnv) and self._dones[i]: observation = env.reset() observations.append(observation) infos.append(info) @@ -95,6 +102,7 @@ def step_wait(self): infos, ) + def close_extras(self, **kwargs): [env.close() for env in self.envs] diff --git a/gym/vector/tests/test_vector_env.py b/gym/vector/tests/test_vector_env.py index bcbfd156fa2..8485a1a8e83 100644 --- a/gym/vector/tests/test_vector_env.py +++ b/gym/vector/tests/test_vector_env.py @@ -1,7 +1,9 @@ +from functools import partial import pytest import numpy as np -from gym.spaces import Tuple +from gym import spaces +from gym.spaces import Tuple, Box from gym.vector.tests.utils import CustomSpace, make_env from gym.vector.async_vector_env import AsyncVectorEnv @@ -54,3 +56,81 @@ def test_custom_space_vector_env(): assert isinstance(env.single_action_space, CustomSpace) assert isinstance(env.action_space, Tuple) + + +@pytest.mark.parametrize('base_env', ["CubeCrash-v0", "CartPole-v0"]) +@pytest.mark.parametrize('async_inner', [False, True]) +@pytest.mark.parametrize('async_outer', [False, True]) +@pytest.mark.parametrize('inner_envs', [1, 4, 7]) +@pytest.mark.parametrize('outer_envs', [1, 4, 7]) +def test_nesting_vector_envs(base_env: str, + async_inner: bool, + async_outer: bool, + inner_envs: int, + outer_envs: int): + inner_vector_wrapper = AsyncVectorEnv if async_inner else SyncVectorEnv + # When nesting AsyncVectorEnvs, only the "innermost" envs can have + # `daemon=True`, otherwise the "daemonic processes are not allowed to have + # children" AssertionError is raised in `multiprocessing.process`. + outer_vector_wrapper = ( + partial(AsyncVectorEnv, daemon=False) if async_outer + else SyncVectorEnv + ) + + env = outer_vector_wrapper([ # type: ignore + partial(inner_vector_wrapper, [ + make_env(base_env, inner_envs * i + j) for j in range(inner_envs) + ]) for i in range(outer_envs) + ]) + + # Create a single test environment. + with make_env(base_env, 0)() as temp_single_env: + single_observation_space = temp_single_env.observation_space + single_action_space = temp_single_env.action_space + + assert isinstance(single_observation_space, Box) + assert isinstance(env.observation_space, Box) + assert env.observation_space.shape == (outer_envs, inner_envs, *single_observation_space.shape) + assert env.observation_space.dtype == single_observation_space.dtype + + assert isinstance(env.action_space, spaces.Tuple) + assert len(env.action_space.spaces) == outer_envs + assert all( + isinstance(outer_action_space, spaces.Tuple) and + len(outer_action_space.spaces) == inner_envs + for outer_action_space in env.action_space.spaces + ) + assert all([ + len(inner_action_space.spaces) == inner_envs + for inner_action_space in env.action_space.spaces + ]) + assert all([ + inner_action_space.spaces[i] == single_action_space + for inner_action_space in env.action_space.spaces + for i in range(inner_envs) + ]) + + with env: + observations = env.reset() + assert observations in env.observation_space + + actions = env.action_space.sample() + assert actions in env.action_space + + observations, rewards, dones, _ = env.step(actions) + assert observations in env.observation_space + + assert isinstance(env.observation_space, Box) + assert isinstance(observations, np.ndarray) + assert observations.dtype == env.observation_space.dtype + assert observations.shape == (outer_envs, inner_envs) + single_observation_space.shape + + assert isinstance(rewards, np.ndarray) + assert isinstance(rewards[0], np.ndarray) + assert rewards.ndim == 2 + assert rewards.shape == (outer_envs, inner_envs) + + assert isinstance(dones, np.ndarray) + assert dones.dtype == np.bool_ + assert dones.ndim == 2 + assert dones.shape == (outer_envs, inner_envs)