From 0000349ab1080a05c38f2073b574ccbdafaae600 Mon Sep 17 00:00:00 2001
From: Antoine Broyelle
Date: Fri, 4 Oct 2024 22:37:20 +0200
Subject: [PATCH 1/3] Check number of kwargs matches num_workers

---
 test/test_env.py             | 39 ++++++++++++++++++++++++++++++++++++
 torchrl/envs/batched_envs.py | 18 ++++++++++++++---
 2 files changed, 54 insertions(+), 3 deletions(-)

diff --git a/test/test_env.py b/test/test_env.py
index bbec29a0d78..9602c596f22 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -491,6 +491,26 @@ def test_mb_env_batch_lock(self, device, seed=0):
 
 
 class TestParallel:
+    def test_create_env_fn(self, maybe_fork_ParallelEnv):
+        def make_env():
+            return GymEnv(PENDULUM_VERSIONED())
+
+        with pytest.raises(
+            RuntimeError, match="len\\(create_env_fn\\) and num_workers mismatch"
+        ):
+            maybe_fork_ParallelEnv(4, [make_env, make_env])
+
+    def test_create_env_kwargs(self, maybe_fork_ParallelEnv):
+        def make_env():
+            return GymEnv(PENDULUM_VERSIONED())
+
+        with pytest.raises(
+            RuntimeError, match="len\\(create_env_kwargs\\) and num_workers mismatch"
+        ):
+            maybe_fork_ParallelEnv(
+                4, make_env, create_env_kwargs=[{"seed": 0}, {"seed": 1}]
+            )
+
     @pytest.mark.skipif(
         not torch.cuda.device_count(), reason="No cuda device detected."
     )
@@ -1121,6 +1141,25 @@ def env_fn2(seed):
         env1.close()
         env2.close()
 
+    @pytest.mark.parametrize("parallel", [True, False])
+    def test_parallel_env_update_kwargs(self, parallel, maybe_fork_ParallelEnv):
+        def make_env(seed=None):
+            env = DiscreteActionConvMockEnv()
+            if seed is not None:
+                env.set_seed(seed)
+            return env
+
+        _class = maybe_fork_ParallelEnv if parallel else SerialEnv
+        env = _class(
+            num_workers=2,
+            create_env_fn=make_env,
+            create_env_kwargs=[{"seed": 0}, {"seed": 1}],
+        )
+        with pytest.raises(
+            RuntimeError, match="len\\(kwargs\\) and num_workers mismatch"
+        ):
+            env.update_kwargs([{"seed": 42}])
+
     @pytest.mark.parametrize("batch_size", [(32, 5), (4,), (1,), ()])
     @pytest.mark.parametrize("n_workers", [2, 1])
     def test_parallel_env_reset_flag(
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index eff1808af34..62d4ba7ab12 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -318,14 +318,20 @@ def __init__(
             create_env_fn = [create_env_fn for _ in range(num_workers)]
         elif len(create_env_fn) != num_workers:
             raise RuntimeError(
-                f"num_workers and len(create_env_fn) mismatch, "
-                f"got {len(create_env_fn)} and {num_workers}"
+                f"len(create_env_fn) and num_workers mismatch, "
+                f"got {len(create_env_fn)} and {num_workers}."
             )
+
         create_env_kwargs = {} if create_env_kwargs is None else create_env_kwargs
         if isinstance(create_env_kwargs, dict):
             create_env_kwargs = [
                 deepcopy(create_env_kwargs) for _ in range(num_workers)
             ]
+        elif len(create_env_kwargs) != num_workers:
+            raise RuntimeError(
+                f"len(create_env_kwargs) and num_workers mismatch, "
+                f"got {len(create_env_kwargs)} and {num_workers}."
+            )
 
         self.policy_proof = policy_proof
         self.num_workers = num_workers
@@ -534,7 +540,13 @@ def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None:
             for _kwargs in self.create_env_kwargs:
                 _kwargs.update(kwargs)
         else:
-            for _kwargs, _new_kwargs in zip(self.create_env_kwargs, kwargs):
+            if len(kwargs) != self.num_workers:
+                raise RuntimeError(
+                    f"len(kwargs) and num_workers mismatch, got {len(kwargs)} and {self.num_workers}."
+                )
+            for _kwargs, _new_kwargs in zip(
+                self.create_env_kwargs, kwargs, strict=True
+            ):
                 _kwargs.update(_new_kwargs)
 
     def _get_in_keys_to_exclude(self, tensordict):

From 36936edf5c9fe7af2318d8dd408865ac7fa11cd1 Mon Sep 17 00:00:00 2001
From: Antoine Broyelle
Date: Tue, 8 Oct 2024 21:29:42 +0200
Subject: [PATCH 2/3] Use _zip_strict

---
 torchrl/envs/batched_envs.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 62d4ba7ab12..068d37f8135 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -22,6 +22,7 @@
 import torch
 
 from tensordict import (
+    _zip_strict,
     is_tensor_collection,
     LazyStackedTensorDict,
     TensorDict,
@@ -544,9 +545,7 @@ def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None:
                 raise RuntimeError(
                     f"len(kwargs) and num_workers mismatch, got {len(kwargs)} and {self.num_workers}."
                 )
-            for _kwargs, _new_kwargs in zip(
-                self.create_env_kwargs, kwargs, strict=True
-            ):
+            for _kwargs, _new_kwargs in _zip_strict(self.create_env_kwargs, kwargs):
                 _kwargs.update(_new_kwargs)
 
     def _get_in_keys_to_exclude(self, tensordict):

From 684e15a9d815c6b1ef87f230bcd8be64dc26b2d5 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 9 Oct 2024 13:16:16 +0100
Subject: [PATCH 3/3] Fix _zip_strict import

---
 torchrl/envs/batched_envs.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 068d37f8135..02c7f5893dc 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -22,13 +22,13 @@
 import torch
 
 from tensordict import (
-    _zip_strict,
     is_tensor_collection,
     LazyStackedTensorDict,
     TensorDict,
     TensorDictBase,
     unravel_key,
 )
+from tensordict.utils import _zip_strict
 from torch import multiprocessing as mp
 from torchrl._utils import (
     _check_for_faulty_process,
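
A minimal sketch of how the new length checks surface to a caller. The
`make_env` factory and the "Pendulum-v1" id are illustrative assumptions
(the tests above use PENDULUM_VERSIONED() and mock envs instead); the
constructor arguments mirror the ones exercised by the patched code.

    from torchrl.envs import GymEnv, ParallelEnv

    def make_env(seed=None):
        # Each create_env_kwargs dict is forwarded to this factory.
        env = GymEnv("Pendulum-v1")  # assumed gym id, depends on installed gym
        if seed is not None:
            env.set_seed(seed)
        return env

    # One factory fanned out to three workers, one kwargs dict per worker.
    # After PATCH 1/3, a kwargs list of any other length raises
    # "len(create_env_kwargs) and num_workers mismatch" at construction time,
    # and a create_env_fn list of the wrong length raises
    # "len(create_env_fn) and num_workers mismatch".
    env = ParallelEnv(
        num_workers=3,
        create_env_fn=make_env,
        create_env_kwargs=[{"seed": 0}, {"seed": 1}, {"seed": 2}],
    )

    # Updating with a wrong-length list now raises as well, instead of a
    # plain zip() silently stopping at the shorter argument.
    try:
        env.update_kwargs([{"seed": 42}])
    except RuntimeError as err:
        print(err)  # len(kwargs) and num_workers mismatch, got 1 and 3.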