diff --git a/gym/vector/__init__.py b/gym/vector/__init__.py
index c42a25caf4f..3a1d938d824 100644
--- a/gym/vector/__init__.py
+++ b/gym/vector/__init__.py
@@ -3,11 +3,19 @@
 except ImportError:
     Iterable = (tuple, list)
 
+from gym.vector.batched_vector_env import BatchedVectorEnv
 from gym.vector.async_vector_env import AsyncVectorEnv
 from gym.vector.sync_vector_env import SyncVectorEnv
 from gym.vector.vector_env import VectorEnv, VectorEnvWrapper
 
-__all__ = ["AsyncVectorEnv", "SyncVectorEnv", "VectorEnv", "VectorEnvWrapper", "make"]
+__all__ = [
+    "BatchedVectorEnv",
+    "AsyncVectorEnv",
+    "SyncVectorEnv",
+    "VectorEnv",
+    "VectorEnvWrapper",
+    "make",
+]
 
 
 def make(id, num_envs=1, asynchronous=True, wrappers=None, **kwargs):
diff --git a/gym/vector/async_vector_env.py b/gym/vector/async_vector_env.py
index b84fc2087d9..9446d8be088 100644
--- a/gym/vector/async_vector_env.py
+++ b/gym/vector/async_vector_env.py
@@ -403,6 +403,16 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
     assert shared_memory is None
     env = env_fn()
     parent_pipe.close()
+
+    def step_fn(actions):
+        observation, reward, done, info = env.step(actions)
+        # Do nothing if the env is a VectorEnv, since it will automatically
+        # reset the envs that are done if needed in the 'step' method and return
+        # the initial observation instead of the final observation.
+        if not isinstance(env, VectorEnv) and done:
+            observation = env.reset()
+        return observation, reward, done, info
+
     try:
         while True:
             command, data = pipe.recv()
@@ -440,6 +450,16 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
     env = env_fn()
     observation_space = env.observation_space
     parent_pipe.close()
+
+    def step_fn(actions):
+        observation, reward, done, info = env.step(actions)
+        # Do nothing if the env is a VectorEnv, since it will automatically
+        # reset the envs that are done if needed in the 'step' method and return
+        # the initial observation instead of the final observation.
+        if not isinstance(env, VectorEnv) and done:
+            observation = env.reset()
+        return observation, reward, done, info
+
     try:
         while True:
             command, data = pipe.recv()
@@ -451,7 +471,8 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
                 pipe.send((None, True))
             elif command == "step":
                 observation, reward, done, info = env.step(data)
-                if done:
+                # BUG: See PR #2104: currently unable to nest VectorEnvs because of this.
+                if (done if isinstance(done, bool) else all(done)):
                     observation = env.reset()
                 write_to_shared_memory(
                     index, observation, shared_memory, observation_space
diff --git a/gym/vector/batched_vector_env.py b/gym/vector/batched_vector_env.py
new file mode 100644
index 00000000000..967f9d6f2d1
--- /dev/null
+++ b/gym/vector/batched_vector_env.py
@@ -0,0 +1,386 @@
+""" A mix of AsyncVectorEnv and SyncVectorEnv, with support for 'chunking',
+i.e. running a series of environments on each worker.
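+
+Example (a minimal usage sketch; assumes the CartPole-v0 env is registered):
+
+    from functools import partial
+    import gym
+    from gym.vector.batched_vector_env import BatchedVectorEnv
+
+    env_fns = [partial(gym.make, "CartPole-v0") for _ in range(17)]
+    env = BatchedVectorEnv(env_fns, n_workers=6)
+    observations = env.reset()  # batched: shape (17, 4), no 'chunk' dimension
+    observations, rewards, dones, infos = env.step(env.action_space.sample())
+    env.close()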
+""" +import itertools +import math +import multiprocessing as mp +from collections import OrderedDict +from functools import partial +from typing import (Any, Callable, Dict, Iterable, List, Optional, Sequence, + Tuple, TypeVar, Union) +import numpy as np + +from gym import spaces, Env, Space +from gym.spaces.utils import flatten, unflatten +from gym.vector.async_vector_env import AsyncVectorEnv +from gym.vector.sync_vector_env import SyncVectorEnv +from gym.vector.utils import batch_space, concatenate, create_empty_array +from gym.vector.vector_env import VectorEnv + +T = TypeVar("T") +K = TypeVar("K") +V = TypeVar("V") +M = TypeVar("M") + + +class BatchedVectorEnv(VectorEnv): + """ Batched vectorized environment. + + Adds the following features, compared to using the vectorized Async and Sync + VectorEnvs: + + - Chunking: Running more than one environment per worker. This is done by + passing `SyncVectorEnv`s as the env_fns to the `AsyncVectorEnv`. + + - Flexible batch size: Supports any number of environments, irrespective + of the number of workers or of CPUs. The number of environments will be + spread out as equally as possible between the workers. + + For example, if you want to have a batch_size of 17, and n_workers is 6, + then the number of environments per worker will be: [3, 3, 3, 3, 3, 2]. + + Internally, this works by creating up to two AsyncVectorEnvs, env_a and + env_b. If the number of envs (batch_size) isn't a multiple of the number + of workers, then we create the second AsyncVectorEnv (env_b). + + In the first environment (env_a), each env will contain + ceil(n_envs / n_workers) each. If env_b is needed, then each of its envs + will contain floor(n_envs / n_workers) environments. + + The observations/actions/rewards are reshaped to be (n_envs, *shape), i.e. + they don't have an extra 'chunk' dimension. + + - When some environments have `done=True` while stepping, those + environments are reset, as was done previously. Additionally, the final + observation for those environments is placed in the info dict at key + FINAL_STATE_KEY (currently 'final_state'). + """ + def __init__(self, + env_fns, + n_workers: int = None, + **kwargs): + assert env_fns, "need at least one env_fn." + self.batch_size: int = len(env_fns) + + # Use one of the env_fns to get the observation/action space. + with env_fns[0]() as temp_env: + single_observation_space = temp_env.observation_space + single_action_space = temp_env.action_space + self.reward_range = temp_env.reward_range + del temp_env + + super().__init__( + num_envs=self.batch_size, + observation_space=single_observation_space, + action_space=single_action_space, + ) + + if n_workers is None: + n_workers = mp.cpu_count() + self.n_workers: int = n_workers + + if self.n_workers > self.batch_size: + self.n_workers = self.batch_size + + # Divide the env_fns as evenly as possible between the workers. + groups = distribute(env_fns, self.n_workers) + + # Find the first index where the group has a different length. + self.chunk_length_a = len(groups[0]) + self.chunk_length_b = 0 + + # First, assume there is no need for another environment (all the + # groups have the same length). + self.start_index_b = self.n_workers + for i, group in enumerate(groups): + if len(group) != self.chunk_length_a: + self.start_index_b = i + self.chunk_length_b = len(group) + break + + # Total number of envs in each environment. 
+        self.n_a = sum(map(len, groups[:self.start_index_b]))
+        self.n_b = sum(map(len, groups[self.start_index_b:]))
+
+        # Create a SyncVectorEnv per group.
+        chunk_env_fns: List[Callable[[], Env]] = [
+            partial(SyncVectorEnv, env_fns_group) for env_fns_group in groups
+        ]
+        env_a_fns = chunk_env_fns[:self.start_index_b]
+        env_b_fns = chunk_env_fns[self.start_index_b:]
+        # Create the AsyncVectorEnvs.
+        self.env_a = AsyncVectorEnv(env_fns=env_a_fns, **kwargs)
+        self.env_b: Optional[AsyncVectorEnv] = None
+        if env_b_fns:
+            self.env_b = AsyncVectorEnv(env_fns=env_b_fns, **kwargs)
+
+    def reset_async(self):
+        self.env_a.reset_async()
+        if self.env_b:
+            self.env_b.reset_async()
+
+    def reset_wait(self, timeout=None, **kwargs):
+        obs_a = self.env_a.reset_wait(timeout=timeout)
+        obs_a = unroll(self.single_observation_space, chunks=obs_a)
+        obs = (obs_a,)
+        if self.env_b:
+            obs_b = self.env_b.reset_wait(timeout=timeout)
+            obs_b = unroll(self.single_observation_space, chunks=obs_b)
+            obs = (obs_a, obs_b)
+        observations = fuse_and_batch(
+            self.single_observation_space,
+            *obs,
+            n_items=self.n_a + self.n_b,
+        )
+        return observations
+
+    def step_async(self, action: Sequence) -> None:
+        if self.env_b:
+            flat_actions_a, flat_actions_b = action[:self.n_a], action[self.n_a:]
+            actions_a = chunk(flat_actions_a, self.chunk_length_a)
+            actions_b = chunk(flat_actions_b, self.chunk_length_b)
+            self.env_a.step_async(actions_a)
+            self.env_b.step_async(actions_b)
+        else:
+            action = chunk(action, self.chunk_length_a)
+            self.env_a.step_async(action)
+
+    def step_wait(self, timeout: Optional[Union[int, float]] = None, **kwargs):
+        obs_a, rew_a, done_a, info_a = self.env_a.step_wait(timeout)
+        obs_a = unroll(self.single_observation_space, chunks=obs_a)
+        rew_a = unroll(None, chunks=rew_a)
+        done_a = unroll(None, chunks=done_a)
+        info_a = unroll(None, chunks=info_a)
+        obs = [obs_a]
+
+        rew_b = []
+        done_b = []
+        info_b = []
+        if self.env_b:
+            obs_b, rew_b, done_b, info_b = self.env_b.step_wait(timeout)
+            obs_b = unroll(self.single_observation_space, chunks=obs_b)
+            rew_b = unroll(None, chunks=rew_b)
+            done_b = unroll(None, chunks=done_b)
+            info_b = unroll(None, chunks=info_b)
+            obs = [obs_a, obs_b]
+
+        observations = fuse_and_batch(
+            self.single_observation_space,
+            *obs,
+            n_items=self.n_a + self.n_b,
+        )
+
+        rewards = np.array(rew_a + rew_b)
+        done = np.array(done_a + done_b)
+        # NOTE: 'info' isn't batched; it is a list with one dict per
+        # environment.
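+        # For example, with 17 environments, `observations` has a leading
+        # dimension of 17, while `info` is a list of 17 dicts.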
+        info = info_a + info_b
+        return observations, rewards, done, info
+
+    def seed(self, seeds: Union[int, Sequence[Optional[int]], None] = None):
+        if seeds is None:
+            seeds = [None for _ in range(self.batch_size)]
+        if isinstance(seeds, int):
+            seeds = [seeds + i for i in range(self.batch_size)]
+        assert len(seeds) == self.batch_size
+
+        seeds_a = chunk(seeds[:self.n_a], self.chunk_length_a)
+        self.env_a.seed(seeds_a)
+        if self.env_b:
+            seeds_b = chunk(seeds[self.n_a:], self.chunk_length_b)
+            self.env_b.seed(seeds_b)
+
+    def close_extras(self, **kwargs):
+        self.env_a.close_extras(**kwargs)
+        if self.env_b:
+            self.env_b.close_extras(**kwargs)
+
+    def render(self, mode: str = "rgb_array"):
+        chunked_images_a = self.env_a.render(mode="rgb_array")
+        images_a: List[np.ndarray] = unroll(None, chunks=chunked_images_a)
+        images = images_a
+
+        if self.env_b:
+            chunked_images_b = self.env_b.render(mode="rgb_array")
+            images_b = unroll(None, chunks=chunked_images_b)
+            images.extend(images_b)
+
+        image_batch = np.stack(images)
+
+        if mode == "rgb_array":
+            return image_batch
+
+        raise NotImplementedError(f"Render mode {mode} isn't implemented yet.")
+
+        # NOTE: This is only here for illustration purposes.
+        # if mode == "human":
+        #     # See PR #1624 for the tile_images function.
+        #     tiled_version = tile_images(image_batch)
+        #     if self.viewer is None:
+        #         from gym.envs.classic_control import rendering
+        #         self.viewer = rendering.SimpleImageViewer()
+        #     self.viewer.imshow(tiled_version)
+        #     return self.viewer.isopen
+
+
+def distribute(values: Sequence[T], n_groups: int) -> List[Sequence[T]]:
+    """ Distribute `values` as evenly as possible into `n_groups` groups.
+
+    >>> distribute(list(range(14)), 5)
+    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13]]
+    >>> distribute(list(range(9)), 4)
+    [[0, 1, 2], [3, 4], [5, 6], [7, 8]]
+    >>> import numpy as np
+    >>> distribute(np.arange(9), 4)
+    [array([0, 1, 2]), array([3, 4]), array([5, 6]), array([7, 8])]
+    """
+    n_values = len(values)
+    # Determine the final length of each group.
+    min_values_per_group = math.floor(n_values / n_groups)
+    max_values_per_group = math.ceil(n_values / n_groups)
+    remainder = n_values % n_groups
+    group_lengths = [
+        max_values_per_group if i < remainder else min_values_per_group
+        for i in range(n_groups)
+    ]
+    # Equivalent, but maybe a tiny bit slower:
+    # group_lengths: List[int] = [0 for _ in range(n_groups)]
+    # for i in range(len(values)):
+    #     group_lengths[i % n_groups] += 1
+    groups: List[Sequence[T]] = [[] for _ in range(n_groups)]
+
+    start_index = 0
+    for i, group_length in enumerate(group_lengths):
+        end_index = start_index + group_length
+        groups[i] = values[start_index:end_index]
+        start_index += group_length
+    return groups
+
+
+def chunk(values: Sequence[T], chunk_length: int) -> Sequence[Sequence[T]]:
+    """ Add the 'chunk'/second batch dimension to the list of items.
+
+    NOTE: This may not work with tuples as inputs, but that hasn't been a
+    problem so far because the action/reward spaces haven't been tuples yet.
+    """
+    groups = list(n_consecutive(values, chunk_length))
+    if isinstance(values, np.ndarray):
+        groups = np.array(groups)
+    return groups
+
+
+@singledispatch
+def unroll(item_space: Optional[Space], chunks: Sequence[Sequence[T]]) -> Sequence[T]:
+    """ Unroll the given chunks, returning a list of individual items.
+
+    This is the inverse operation of the `chunk` function.
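+
+    For example:
+
+    >>> unroll(None, chunks=[[1, 2], [3, 4], [5]])
+    [1, 2, 3, 4, 5]
+    >>> import numpy as np
+    >>> unroll(None, chunks=np.arange(6).reshape([2, 3]))
+    [0, 1, 2, 3, 4, 5]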
+ """ + if isinstance(chunks, np.ndarray): + # Remove the 'chunk' dimension. + return list(chunks.reshape([-1, *chunks.shape[2:]])) + + values: List[T] = [] + for chunk in chunks: + values.extend(chunk) + return values + + +@unroll.register(spaces.Dict) # type: ignore +def unroll_dict(item_space: spaces.Dict, chunks: Sequence[Dict[K, V]]) -> Dict[K, Sequence[V]]: + assert isinstance(chunks, dict), chunks + return OrderedDict( + (key, unroll(item_space[key], chunks=values)) + for key, values in chunks.items() + ) + + +@unroll.register(spaces.Tuple) # type: ignore +def unroll_tuple(item_space: spaces.Tuple, chunks: Sequence[Tuple]) -> Tuple[Sequence]: + # 'flatten out' the chunks for each index. The returned value will be a + # tuple of lists of samples. + chunked_items = list(zip(chunks)) + return tuple([ + unroll(item_space.spaces[i], chunks=chunk_items_i) + for i, chunk_items_i in enumerate(chunked_items) + ]) + + +@singledispatch +def fuse_and_batch(item_space: spaces.Space, *sequences, n_items: int): + """Concatenate sequences of items, and then fuse them into a single batch. + """ + out = create_empty_array(item_space, n=n_items) + # Concatenate the batches into a single batch of samples. + items_batch = np.concatenate([ + np.asarray(v).reshape([-1, *item_space.shape]) + for v in itertools.chain(*filter(len, sequences)) + ]) + # # Split this batch of samples into a list of items from each space. + items = [ + v.reshape(item_space.shape) for v in np.split(items_batch, n_items) + ] + # TODO: Need to add more tests to make sure this works with custom spaces and Dict spaces. + return concatenate(items, out, item_space) + + +@fuse_and_batch.register(spaces.Dict) +def fuse_and_batch_dicts(item_space: spaces.Dict, *sequences: Dict[K, List[V]], n_items: int) -> Dict[K, Sequence[V]]: + fused_values: Dict[K, Sequence[V]] = OrderedDict() + assert all(isinstance(sequence, dict) for sequence in sequences) + for key, zipped_values in zip_dicts(*sequences, missing=None): + fused_values[key] = fuse_and_batch(item_space[key], *zipped_values, n_items=n_items) + return fused_values + + +@fuse_and_batch.register(spaces.Tuple) +def fuse_and_batch_tuples(item_space: spaces.Tuple, *sequences: Sequence[Tuple[T, ...]], n_items: int) -> Tuple[Sequence[T], ...]: + # Join the list of items for each subspace of the item_space Tuple. + joined_sequences: Sequence[List[T]] = [ + sum(items, []) for items in itertools.zip_longest(*sequences, fillvalue=[]) + ] + return tuple( + fuse_and_batch(space, sequence, n_items=n_items) + for space, sequence in zip(item_space.spaces, joined_sequences) + # np.concatenate(sequence) for sequence in joined_sequences + ) + + +def n_consecutive(items: Iterable[T], n: int=2, yield_last_batch=True) -> Iterable[Tuple[T, ...]]: + """Collect data into chunks of up to `n` elements. + + When `yield_last_batch` is True, the final chunk (which might have fewer + than `n` items) will also be yielded. + + >>> list(n_consecutive("ABCDEFG", 3)) + [["A", "B", "C"], ["D", "E", "F"], ["G"]] + """ + values: List[T] = [] + for item in items: + values.append(item) + if len(values) == n: + yield tuple(values) + values.clear() + if values and yield_last_batch: + yield tuple(values) + + +def zip_dicts(*dicts: Dict[K, V], missing: M) -> Iterable[Tuple[K, Tuple[Union[M, V], ...]]]: + """Iterator over the union of all keys, giving the value from each dict if + present, else `missing`. + """ + # If any attributes are common to both the Experiment and the State, + # copy them over to the Experiment. 
+    keys = set(itertools.chain(*dicts))
+    for key in keys:
+        yield (key, tuple(d.get(key, missing) for d in dicts))
+
+
+if __name__ == "__main__":
+    import doctest
+    doctest.testmod()
diff --git a/gym/vector/sync_vector_env.py b/gym/vector/sync_vector_env.py
index 49369073a72..c53927863fb 100644
--- a/gym/vector/sync_vector_env.py
+++ b/gym/vector/sync_vector_env.py
@@ -67,27 +67,27 @@ def reset_wait(self):
         for env in self.envs:
             observation = env.reset()
             observations.append(observation)
-        self.observations = concatenate(
-            observations, self.observations, self.single_observation_space
-        )
-
+        self.observations = concatenate(observations, self.observations,
+                                        self.single_observation_space)
         return deepcopy(self.observations) if self.copy else self.observations
 
     def step_async(self, actions):
         self._actions = actions
-
+
     def step_wait(self):
         observations, infos = [], []
         for i, (env, action) in enumerate(zip(self.envs, self._actions)):
             observation, self._rewards[i], self._dones[i], info = env.step(action)
-            if self._dones[i]:
+            # Do nothing if the env is a VectorEnv, since it will automatically
+            # reset the envs that are done if needed in the 'step' method and
+            # return the initial observation instead of the final observation.
+            if not isinstance(env, VectorEnv) and self._dones[i]:
                 observation = env.reset()
             observations.append(observation)
             infos.append(info)
         self.observations = concatenate(
             observations, self.observations, self.single_observation_space
         )
-
         return (
             deepcopy(self.observations) if self.copy else self.observations,
             np.copy(self._rewards),
@@ -95,6 +95,7 @@ def step_wait(self):
             infos,
         )
 
+
     def close_extras(self, **kwargs):
         [env.close() for env in self.envs]
diff --git a/gym/vector/tests/test_batched_vector_env.py b/gym/vector/tests/test_batched_vector_env.py
new file mode 100644
index 00000000000..ebbb0de2283
--- /dev/null
+++ b/gym/vector/tests/test_batched_vector_env.py
@@ -0,0 +1,219 @@
+from collections import OrderedDict
+from functools import partial
+from multiprocessing import cpu_count
+from typing import Optional
+
+import gym
+import numpy as np
+import pytest
+from gym import spaces
+from gym.vector.batched_vector_env import BatchedVectorEnv
+from gym.vector.utils import batch_space
+
+N_CPUS = cpu_count()
+
+
+class DummyEnvironment(gym.Env):
+    """ Dummy environment for testing.
+
+    The reward is the distance between the state (a counter) and the target
+    value. The actions are:
+    0: Keep the counter the same.
+    1: Increment the counter.
+    2: Decrement the counter.
+    """
+    def __init__(self, start: int = 0, max_value: int = 10, target: int = 5):
+        self.max_value = max_value
+        self.i = start
+        self.start = start
+        self.action_space = gym.spaces.Discrete(n=3)  # type: ignore
+        self.observation_space = gym.spaces.Discrete(n=max_value)  # type: ignore
+
+        self.target = target
+        self.reward_range = (0, max(target, max_value - target))
+
+        self.done: bool = False
+        self._reset: bool = False
+
+    def step(self, action: int):
+        # The action modifies the state, producing a new state, and you get the
+        # reward associated with that transition.
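+        # For example, starting from i=0 with target=5, taking action 1 five
+        # times in a row brings the counter to the target (reward 0, done=True).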
+        if not self._reset:
+            raise RuntimeError("Need to reset before you can step.")
+        if action == 1:
+            self.i += 1
+        elif action == 2:
+            self.i -= 1
+        self.i %= self.max_value
+        done = (self.i == self.target)
+        reward = abs(self.i - self.target)
+        return self.i, reward, done, {}
+
+    def reset(self):
+        self._reset = True
+        self.i = self.start
+        return self.i
+
+
+class TupleObservationsWrapper(gym.Wrapper):
+    def __init__(self, env: gym.Env, second_space: gym.Space):
+        super().__init__(env)
+        self.observation_space: gym.Space = spaces.Tuple([
+            env.observation_space,
+            second_space,
+        ])
+
+    def step(self, action):
+        observation, reward, done, info = self.env.step(action)
+        return (observation, self.observation_space[1].sample()), reward, done, info
+
+    def reset(self):
+        observation = self.env.reset()
+        return (observation, self.observation_space[1].sample())
+
+
+@pytest.mark.parametrize("batch_size", [1, 5, 11, 24])
+@pytest.mark.parametrize("n_workers", [1, 3, None])
+@pytest.mark.parametrize("base_env", ["CartPole-v0", "Pendulum-v0"])
+def test_space_with_tuple_observations(batch_size: int, n_workers: Optional[int], base_env: str):
+
+    def make_env():
+        env = gym.make(base_env)
+        env = TupleObservationsWrapper(env, spaces.Discrete(1))
+        return env
+
+    env_fns = [make_env for _ in range(batch_size)]
+    env = BatchedVectorEnv(env_fns, n_workers=n_workers)
+    env.seed(123)
+
+    single_env = make_env()
+    assert env.single_observation_space == single_env.observation_space
+    assert env.single_action_space == single_env.action_space
+    assert env.observation_space == batch_space(single_env.observation_space, batch_size)
+    # NOTE: VectorEnvs currently use a tuple of single-env action spaces as
+    # their action space, rather than the result of the batch_space function.
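+    # For example, with batch_size == 2 and Discrete(2) actions, the batched
+    # action space is Tuple([Discrete(2), Discrete(2)]) rather than
+    # MultiDiscrete([2, 2]).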
+    # assert env.action_space == batch_space(single_env.action_space, batch_size)
+
+    single_env_obs = single_env.reset()
+    obs = env.reset()
+    assert isinstance(obs, tuple)
+    assert len(obs) == len(single_env_obs)
+
+    for i, (obs_item, single_obs_item) in enumerate(zip(obs, single_env_obs)):
+        assert obs_item.shape == (batch_size, *np.asanyarray(single_obs_item).shape)
+
+    assert obs in env.observation_space
+    assert single_env_obs in env.single_observation_space
+
+    single_action = single_env.action_space.sample()
+    assert single_action in env.single_action_space
+    actions = env.action_space.sample()
+    step_obs, rewards, done, info = env.step(actions)
+    assert step_obs in env.observation_space
+
+    assert len(rewards) == batch_size
+    assert len(done) == batch_size
+    assert all([isinstance(v, bool) for v in done.tolist()]), [type(v) for v in done]
+    assert len(info) == batch_size
+
+    env.close()
+
+
+@pytest.mark.parametrize("batch_size", [1, 5, N_CPUS, 11, 24])
+@pytest.mark.parametrize("n_workers", [1, 3, N_CPUS])
+def test_right_shapes(batch_size: int, n_workers: Optional[int]):
+    env_fn = partial(gym.make, "CartPole-v0")
+    env_fns = [env_fn for _ in range(batch_size)]
+    env = BatchedVectorEnv(env_fns, n_workers=n_workers)
+    env.seed(123)
+
+    assert env.observation_space.shape == (batch_size, 4)
+    assert isinstance(env.action_space, spaces.Tuple)
+    assert len(env.action_space) == batch_size
+
+    obs = env.reset()
+    assert obs.shape == (batch_size, 4)
+
+    for i in range(3):
+        actions = env.action_space.sample()
+        assert actions in env.action_space
+        obs, rewards, done, info = env.step(actions)
+        assert obs.shape == (batch_size, 4)
+        assert len(rewards) == batch_size
+        assert len(done) == batch_size
+        assert len(info) == batch_size
+
+    env.close()
+
+
+@pytest.mark.parametrize("batch_size", [1, 2, 5, N_CPUS, 10, 24])
+def test_ordering_of_env_fns_preserved(batch_size):
+    """ Test that the order of the env_fns is also reproduced in the order of
+    the observations, and that the actions are sent to the right environments.
+    """
+    target = 50
+    env_fns = [
+        partial(DummyEnvironment, start=i, target=target, max_value=100)
+        for i in range(batch_size)
+    ]
+    env = BatchedVectorEnv(env_fns, n_workers=4)
+    env.seed(123)
+    obs = env.reset()
+    assert obs.tolist() == list(range(batch_size))
+
+    obs, reward, done, info = env.step(np.zeros(batch_size, dtype=int))
+    assert obs.tolist() == list(range(batch_size))
+    # Increment only the 'counters' at even indices.
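+    # (For example, with batch_size == 4, the actions are [1, 0, 1, 0].)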
+    actions = [int(i % 2 == 0) for i in range(batch_size)]
+    obs, reward, done, info = env.step(actions)
+    even = np.arange(batch_size) % 2 == 0
+    odd = np.arange(batch_size) % 2 == 1
+    assert obs[even].tolist() == (np.arange(batch_size) + 1)[even].tolist()
+    assert obs[odd].tolist() == np.arange(batch_size)[odd].tolist(), (obs, obs[odd], actions)
+    assert reward.tolist() == (np.ones(batch_size) * target - obs).tolist()
+
+    env.close()
+
+
+class DictEnv(gym.Env):
+    def __init__(self):
+        super().__init__()
+        self.observation_space = spaces.Dict(OrderedDict([
+            ('position', spaces.Box(-2., 2., shape=(2,), dtype=np.float32)),
+            ('velocity', spaces.Box(0., 1., shape=(2,), dtype=np.float32)),
+        ]))
+        self.action_space = spaces.Discrete(2)
+
+    def reset(self):
+        return self.observation_space.sample()
+
+    def step(self, action):
+        observation = self.observation_space.sample()
+        reward, done = 0., False
+        return (observation, reward, done, {})
+
+
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize("n_workers", [1, 4])
+def test_dict_observations(batch_size: int, n_workers: int):
+    def make_env():
+        return DictEnv()
+
+    single_env = make_env()
+    env = BatchedVectorEnv([make_env for _ in range(batch_size)],
+                           n_workers=n_workers)
+
+    assert env.observation_space == batch_space(single_env.observation_space, batch_size)
+    observations = env.reset()
+
+    assert observations["position"].shape == (batch_size, 2)
+    assert observations["velocity"].shape == (batch_size, 2)
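+
+    # Close the worker processes, as in the other tests above.
+    env.close()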