diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 62d4ba7ab12..068d37f8135 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -22,6 +22,7 @@ import torch from tensordict import ( + _zip_strict, is_tensor_collection, LazyStackedTensorDict, TensorDict, @@ -544,9 +545,7 @@ def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None: raise RuntimeError( f"len(kwargs) and num_workers mismatch, got {len(kwargs)} and {self.num_workers}." ) - for _kwargs, _new_kwargs in zip( - self.create_env_kwargs, kwargs, strict=True - ): + for _kwargs, _new_kwargs in _zip_strict(self.create_env_kwargs, kwargs): _kwargs.update(_new_kwargs) def _get_in_keys_to_exclude(self, tensordict):