From 1a6c9e2d03c3a7577b320dfd37a23a30621b9a28 Mon Sep 17 00:00:00 2001
From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
Date: Wed, 15 Jan 2025 16:02:44 +0100
Subject: [PATCH] [BugFix] PettingZoo dict action spaces (#2692)

---
 torchrl/envs/libs/pettingzoo.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py
index 9853e8d516d..5936c939cec 100644
--- a/torchrl/envs/libs/pettingzoo.py
+++ b/torchrl/envs/libs/pettingzoo.py
@@ -9,6 +9,7 @@
 import warnings
 from typing import Dict, List, Tuple, Union
 
+import numpy as np
 import packaging
 import torch
 from tensordict import TensorDictBase
@@ -72,6 +73,19 @@ def _load_available_envs() -> Dict:
     return all_environments
 
 
+def _extract_nested_with_index(
+    data: Union[np.ndarray, Dict[str, np.ndarray]], index: int
+):
+    if isinstance(data, np.ndarray):
+        return data[index]
+    elif isinstance(data, dict):
+        return {
+            key: _extract_nested_with_index(value, index) for key, value in data.items()
+        }
+    else:
+        raise NotImplementedError(f"Invalid type of data {data}")
+
+
 class PettingZooWrapper(_EnvWrapper):
     """PettingZoo environment wrapper.
 
@@ -735,7 +749,9 @@ def _step_parallel(
                 "full_action_spec", group, "action"
             ].to_numpy(group_action)
             for index, agent in enumerate(agents):
-                action_dict[agent] = group_action_np[index]
+                # group_action_np can be a dict or an array. We need to recursively index it
+                action = _extract_nested_with_index(group_action_np, index)
+                action_dict[agent] = action
 
         return self._env.step(action_dict)
 
@@ -750,7 +766,8 @@ def _step_aec(
                     group_action_np = self.input_spec[
                         "full_action_spec", group, "action"
                     ].to_numpy(group_action)
-                    action = group_action_np[agent_index]
+                    # group_action_np can be a dict or an array. We need to recursively index it
+                    action = _extract_nested_with_index(group_action_np, agent_index)
                     break
 
         self._env.step(action)