diff --git a/test/test_env.py b/test/test_env.py
index bbec29a0d78..9602c596f22 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -491,6 +491,26 @@ def test_mb_env_batch_lock(self, device, seed=0):
 
 
 class TestParallel:
+    def test_create_env_fn(self, maybe_fork_ParallelEnv):
+        def make_env():
+            return GymEnv(PENDULUM_VERSIONED())
+
+        with pytest.raises(
+            RuntimeError, match="len\\(create_env_fn\\) and num_workers mismatch"
+        ):
+            maybe_fork_ParallelEnv(4, [make_env, make_env])
+
+    def test_create_env_kwargs(self, maybe_fork_ParallelEnv):
+        def make_env():
+            return GymEnv(PENDULUM_VERSIONED())
+
+        with pytest.raises(
+            RuntimeError, match="len\\(create_env_kwargs\\) and num_workers mismatch"
+        ):
+            maybe_fork_ParallelEnv(
+                4, make_env, create_env_kwargs=[{"seed": 0}, {"seed": 1}]
+            )
+
     @pytest.mark.skipif(
         not torch.cuda.device_count(), reason="No cuda device detected."
     )
@@ -1121,6 +1141,25 @@ def env_fn2(seed):
         env1.close()
         env2.close()
 
+    @pytest.mark.parametrize("parallel", [True, False])
+    def test_parallel_env_update_kwargs(self, parallel, maybe_fork_ParallelEnv):
+        def make_env(seed=None):
+            env = DiscreteActionConvMockEnv()
+            if seed is not None:
+                env.set_seed(seed)
+            return env
+
+        _class = maybe_fork_ParallelEnv if parallel else SerialEnv
+        env = _class(
+            num_workers=2,
+            create_env_fn=make_env,
+            create_env_kwargs=[{"seed": 0}, {"seed": 1}],
+        )
+        with pytest.raises(
+            RuntimeError, match="len\\(kwargs\\) and num_workers mismatch"
+        ):
+            env.update_kwargs([{"seed": 42}])
+
     @pytest.mark.parametrize("batch_size", [(32, 5), (4,), (1,), ()])
     @pytest.mark.parametrize("n_workers", [2, 1])
     def test_parallel_env_reset_flag(
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index eff1808af34..02c7f5893dc 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -28,6 +28,7 @@
     TensorDictBase,
     unravel_key,
 )
+from tensordict.utils import _zip_strict
 from torch import multiprocessing as mp
 from torchrl._utils import (
     _check_for_faulty_process,
@@ -318,14 +319,20 @@ def __init__(
             create_env_fn = [create_env_fn for _ in range(num_workers)]
         elif len(create_env_fn) != num_workers:
             raise RuntimeError(
-                f"num_workers and len(create_env_fn) mismatch, "
-                f"got {len(create_env_fn)} and {num_workers}"
+                f"len(create_env_fn) and num_workers mismatch, "
+                f"got {len(create_env_fn)} and {num_workers}."
             )
 
+        create_env_kwargs = {} if create_env_kwargs is None else create_env_kwargs
         if isinstance(create_env_kwargs, dict):
             create_env_kwargs = [
                 deepcopy(create_env_kwargs) for _ in range(num_workers)
             ]
+        elif len(create_env_kwargs) != num_workers:
+            raise RuntimeError(
+                f"len(create_env_kwargs) and num_workers mismatch, "
+                f"got {len(create_env_kwargs)} and {num_workers}."
+            )
 
         self.policy_proof = policy_proof
         self.num_workers = num_workers
@@ -534,7 +541,11 @@ def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None:
             for _kwargs in self.create_env_kwargs:
                 _kwargs.update(kwargs)
         else:
-            for _kwargs, _new_kwargs in zip(self.create_env_kwargs, kwargs):
+            if len(kwargs) != self.num_workers:
+                raise RuntimeError(
+                    f"len(kwargs) and num_workers mismatch, got {len(kwargs)} and {self.num_workers}."
+                )
+            for _kwargs, _new_kwargs in _zip_strict(self.create_env_kwargs, kwargs):
                 _kwargs.update(_new_kwargs)
 
     def _get_in_keys_to_exclude(self, tensordict):
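
For reference, a minimal sketch (not part of the patch) of the behavior the new checks enforce. It assumes a local torchrl install with a Gym backend, and uses the standard Gym id "Pendulum-v1" in place of the test-suite helper PENDULUM_VERSIONED():

from torchrl.envs import GymEnv, ParallelEnv


def make_env(seed=None):
    env = GymEnv("Pendulum-v1")
    if seed is not None:
        env.set_seed(seed)
    return env


if __name__ == "__main__":
    # One kwargs dict per worker: accepted by the new constructor check.
    env = ParallelEnv(
        num_workers=2,
        create_env_fn=make_env,
        create_env_kwargs=[{"seed": 0}, {"seed": 1}],
    )
    # Before this patch, a too-short kwargs list was silently truncated by
    # zip(); with _zip_strict and the explicit length check it now raises
    # RuntimeError: "len(kwargs) and num_workers mismatch, got 1 and 2."
    env.update_kwargs([{"seed": 42}])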