Skip to content

Commit

Permalink
Bugfix: Allow nesting of Sync/Async VectorEnvs
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Aug 3, 2021
1 parent 3344918 commit 8b24de6
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 19 deletions.
37 changes: 26 additions & 11 deletions gym/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,16 +403,24 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
assert shared_memory is None
env = env_fn()
parent_pipe.close()

def step_fn(actions):
    """Step ``env`` once, auto-resetting it when a single env is done.

    Returns the ``(observation, reward, done, info)`` values produced by
    ``env.step(actions)``. When ``env`` is not a ``VectorEnv`` instance and
    ``done`` is truthy, the observation is replaced by ``env.reset()``.
    """
    obs, rew, finished, extra = env.step(actions)
    if isinstance(env, VectorEnv):
        # Leave a VectorEnv untouched: per the original author's note, it
        # resets its finished sub-envs inside ``step`` and already returns
        # the initial observation rather than the final one.
        # NOTE(review): this check does not look through wrappers, whereas
        # SyncVectorEnv.__init__ inspects ``.unwrapped`` -- confirm intent.
        return obs, rew, finished, extra
    if finished:
        obs = env.reset()
    return obs, rew, finished, extra

try:
while True:
command, data = pipe.recv()
if command == "reset":
observation = env.reset()
pipe.send((observation, True))
elif command == "step":
observation, reward, done, info = env.step(data)
if done:
observation = env.reset()
elif command == 'step':
observation, reward, done, info = step_fn(data)
pipe.send(((observation, reward, done, info), True))
elif command == "seed":
env.seed(data)
Expand Down Expand Up @@ -440,6 +448,16 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
env = env_fn()
observation_space = env.observation_space
parent_pipe.close()

def step_fn(actions):
    """Run one ``env.step`` and reset a finished non-vector environment.

    The four values returned by ``env.step(actions)`` are passed back
    unchanged, except that the observation becomes ``env.reset()`` when
    ``env`` is not a ``VectorEnv`` and ``done`` is truthy.
    """
    step_result = env.step(actions)
    observation, reward, done, info = step_result
    # A VectorEnv is skipped here: per the original author's note, it resets
    # its own finished sub-envs in ``step`` and returns the initial
    # observation instead of the final one.
    # NOTE(review): plain isinstance does not see through wrappers, while
    # SyncVectorEnv.__init__ checks ``.unwrapped`` -- confirm intent.
    is_single_env = not isinstance(env, VectorEnv)
    if is_single_env and done:
        observation = env.reset()
    return observation, reward, done, info

try:
while True:
command, data = pipe.recv()
Expand All @@ -449,13 +467,10 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
index, observation, shared_memory, observation_space
)
pipe.send((None, True))
elif command == "step":
observation, reward, done, info = env.step(data)
if done:
observation = env.reset()
write_to_shared_memory(
index, observation, shared_memory, observation_space
)
elif command == 'step':
observation, reward, done, info = step_fn(data)
write_to_shared_memory(index, observation, shared_memory,
observation_space)
pipe.send(((None, reward, done, info), True))
elif command == "seed":
env.seed(data)
Expand Down
22 changes: 15 additions & 7 deletions gym/vector/sync_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,14 @@ def __init__(self, env_fns, observation_space=None, action_space=None, copy=True
)

self._check_observation_spaces()
self.observations = create_empty_array(
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
self.observations = create_empty_array(self.single_observation_space,
n=self.num_envs, fn=np.zeros)

shape = (self.num_envs,)
if isinstance(self.envs[0].unwrapped, VectorEnv):
shape += (self.envs[0].num_envs,)
self._rewards = np.zeros(shape, dtype=np.float64)
self._dones = np.zeros(shape, dtype=np.bool_)
self._actions = None

def seed(self, seeds=None):
Expand All @@ -67,6 +70,7 @@ def reset_wait(self):
for env in self.envs:
observation = env.reset()
observations.append(observation)

self.observations = concatenate(
observations, self.observations, self.single_observation_space
)
Expand All @@ -75,12 +79,15 @@ def reset_wait(self):

def step_async(self, actions):
    """Remember ``actions`` on the instance for the next ``step_wait``.

    Nothing is executed here; ``step_wait`` later pairs ``self._actions``
    with ``self.envs`` and performs the actual stepping.
    """
    self._actions = actions

def step_wait(self):
observations, infos = [], []
for i, (env, action) in enumerate(zip(self.envs, self._actions)):
observation, self._rewards[i], self._dones[i], info = env.step(action)
if self._dones[i]:
# Do nothing if the env is a VectorEnv, since it will automatically
# reset the envs that are done if needed in the 'step' method and
# return the initial observation instead of the final observation.
if not isinstance(env, VectorEnv) and self._dones[i]:
observation = env.reset()
observations.append(observation)
infos.append(info)
Expand All @@ -95,6 +102,7 @@ def step_wait(self):
infos,
)


def close_extras(self, **kwargs):
    """Close every sub-environment stored in ``self.envs``.

    ``kwargs`` is accepted for interface compatibility and is not used.
    Environments are closed in list order; nothing is returned.
    """
    # A plain loop rather than a list comprehension: ``close`` is called
    # only for its side effect, so no throwaway list should be built.
    for env in self.envs:
        env.close()

Expand Down
82 changes: 81 additions & 1 deletion gym/vector/tests/test_vector_env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from functools import partial
import pytest
import numpy as np

from gym.spaces import Tuple
from gym import spaces
from gym.spaces import Tuple, Box
from gym.vector.tests.utils import CustomSpace, make_env

from gym.vector.async_vector_env import AsyncVectorEnv
Expand Down Expand Up @@ -54,3 +56,81 @@ def test_custom_space_vector_env():

assert isinstance(env.single_action_space, CustomSpace)
assert isinstance(env.action_space, Tuple)


@pytest.mark.parametrize('base_env', ["CubeCrash-v0", "CartPole-v0"])
@pytest.mark.parametrize('async_inner', [False, True])
@pytest.mark.parametrize('async_outer', [False, True])
@pytest.mark.parametrize('inner_envs', [1, 4, 7])
@pytest.mark.parametrize('outer_envs', [1, 4, 7])
def test_nesting_vector_envs(base_env: str,
                             async_inner: bool,
                             async_outer: bool,
                             inner_envs: int,
                             outer_envs: int):
    """Check spaces and step outputs of a vector env nested in another.

    ``outer_envs`` vector envs, each holding ``inner_envs`` copies of
    ``base_env``, are wrapped in an outer vector env. Observations, rewards
    and dones are expected to carry a leading
    ``(outer_envs, inner_envs)`` batch shape.
    """
    inner_vector_wrapper = AsyncVectorEnv if async_inner else SyncVectorEnv
    # When nesting AsyncVectorEnvs, only the "innermost" envs can have
    # `daemon=True`, otherwise the "daemonic processes are not allowed to have
    # children" AssertionError is raised in `multiprocessing.process`.
    outer_vector_wrapper = (
        partial(AsyncVectorEnv, daemon=False) if async_outer
        else SyncVectorEnv
    )

    # Read the reference spaces from a single throwaway environment first,
    # so a failure here happens before any nested env is created.
    with make_env(base_env, 0)() as temp_single_env:
        single_observation_space = temp_single_env.observation_space
        single_action_space = temp_single_env.action_space
    assert isinstance(single_observation_space, Box)

    batch_shape = (outer_envs, inner_envs)
    expected_obs_shape = batch_shape + tuple(single_observation_space.shape)

    env = outer_vector_wrapper([  # type: ignore
        partial(inner_vector_wrapper, [
            make_env(base_env, inner_envs * i + j) for j in range(inner_envs)
        ]) for i in range(outer_envs)
    ])

    # Enter the context right away: every assertion below then runs with
    # cleanup guaranteed, so a failing check cannot leave the (possibly
    # non-daemonic) worker processes unclosed.
    with env:
        assert isinstance(env.observation_space, Box)
        assert env.observation_space.shape == expected_obs_shape
        assert env.observation_space.dtype == single_observation_space.dtype

        assert isinstance(env.action_space, spaces.Tuple)
        assert len(env.action_space.spaces) == outer_envs
        assert all(
            isinstance(outer_action_space, spaces.Tuple) and
            len(outer_action_space.spaces) == inner_envs
            for outer_action_space in env.action_space.spaces
        )
        assert all(
            inner_action_space == single_action_space
            for outer_action_space in env.action_space.spaces
            for inner_action_space in outer_action_space.spaces
        )

        observations = env.reset()
        assert observations in env.observation_space

        actions = env.action_space.sample()
        assert actions in env.action_space

        observations, rewards, dones, _ = env.step(actions)
        assert observations in env.observation_space

        assert isinstance(observations, np.ndarray)
        assert observations.dtype == env.observation_space.dtype
        assert observations.shape == expected_obs_shape

        assert isinstance(rewards, np.ndarray)
        assert isinstance(rewards[0], np.ndarray)
        assert rewards.ndim == 2
        assert rewards.shape == batch_shape

        assert isinstance(dones, np.ndarray)
        assert dones.dtype == np.bool_
        assert dones.ndim == 2
        assert dones.shape == batch_shape

0 comments on commit 8b24de6

Please sign in to comment.