Skip to content

Commit

Permalink
[Feature] Check number of kwargs matches num_workers (#2465)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <[email protected]>
  • Loading branch information
antoinebrl and vmoens authored Oct 10, 2024
1 parent e127d9a commit f411f93
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
39 changes: 39 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,26 @@ def test_mb_env_batch_lock(self, device, seed=0):


class TestParallel:
def test_create_env_fn(self, maybe_fork_ParallelEnv):
def make_env():
return GymEnv(PENDULUM_VERSIONED())

with pytest.raises(
RuntimeError, match="len\\(create_env_fn\\) and num_workers mismatch"
):
maybe_fork_ParallelEnv(4, [make_env, make_env])

def test_create_env_kwargs(self, maybe_fork_ParallelEnv):
def make_env():
return GymEnv(PENDULUM_VERSIONED())

with pytest.raises(
RuntimeError, match="len\\(create_env_kwargs\\) and num_workers mismatch"
):
maybe_fork_ParallelEnv(
4, make_env, create_env_kwargs=[{"seed": 0}, {"seed": 1}]
)

@pytest.mark.skipif(
not torch.cuda.device_count(), reason="No cuda device detected."
)
Expand Down Expand Up @@ -1121,6 +1141,25 @@ def env_fn2(seed):
env1.close()
env2.close()

@pytest.mark.parametrize("parallel", [True, False])
def test_parallel_env_update_kwargs(self, parallel, maybe_fork_ParallelEnv):
def make_env(seed=None):
env = DiscreteActionConvMockEnv()
if seed is not None:
env.set_seed(seed)
return env

_class = maybe_fork_ParallelEnv if parallel else SerialEnv
env = _class(
num_workers=2,
create_env_fn=make_env,
create_env_kwargs=[{"seed": 0}, {"seed": 1}],
)
with pytest.raises(
RuntimeError, match="len\\(kwargs\\) and num_workers mismatch"
):
env.update_kwargs([{"seed": 42}])

@pytest.mark.parametrize("batch_size", [(32, 5), (4,), (1,), ()])
@pytest.mark.parametrize("n_workers", [2, 1])
def test_parallel_env_reset_flag(
Expand Down
17 changes: 14 additions & 3 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
TensorDictBase,
unravel_key,
)
from tensordict.utils import _zip_strict
from torch import multiprocessing as mp
from torchrl._utils import (
_check_for_faulty_process,
Expand Down Expand Up @@ -318,14 +319,20 @@ def __init__(
create_env_fn = [create_env_fn for _ in range(num_workers)]
elif len(create_env_fn) != num_workers:
raise RuntimeError(
f"num_workers and len(create_env_fn) mismatch, "
f"got {len(create_env_fn)} and {num_workers}"
f"len(create_env_fn) and num_workers mismatch, "
f"got {len(create_env_fn)} and {num_workers}."
)

create_env_kwargs = {} if create_env_kwargs is None else create_env_kwargs
if isinstance(create_env_kwargs, dict):
create_env_kwargs = [
deepcopy(create_env_kwargs) for _ in range(num_workers)
]
elif len(create_env_kwargs) != num_workers:
raise RuntimeError(
f"len(create_env_kwargs) and num_workers mismatch, "
f"got {len(create_env_kwargs)} and {num_workers}."
)

self.policy_proof = policy_proof
self.num_workers = num_workers
Expand Down Expand Up @@ -534,7 +541,11 @@ def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None:
for _kwargs in self.create_env_kwargs:
_kwargs.update(kwargs)
else:
for _kwargs, _new_kwargs in zip(self.create_env_kwargs, kwargs):
if len(kwargs) != self.num_workers:
raise RuntimeError(
f"len(kwargs) and num_workers mismatch, got {len(kwargs)} and {self.num_workers}."
)
for _kwargs, _new_kwargs in _zip_strict(self.create_env_kwargs, kwargs):
_kwargs.update(_new_kwargs)

def _get_in_keys_to_exclude(self, tensordict):
Expand Down

0 comments on commit f411f93

Please sign in to comment.