Skip to content

Commit

Permalink
[BugFix] PettingZoo dict action spaces (#2692)
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini authored Jan 15, 2025
1 parent 61e05b3 commit 1a6c9e2
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions torchrl/envs/libs/pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings
from typing import Dict, List, Tuple, Union

import numpy as np
import packaging
import torch
from tensordict import TensorDictBase
Expand Down Expand Up @@ -72,6 +73,19 @@ def _load_available_envs() -> Dict:
return all_environments


def _extract_nested_with_index(
data: Union[np.ndarray, Dict[str, np.ndarray]], index: int
):
if isinstance(data, np.ndarray):
return data[index]
elif isinstance(data, dict):
return {
key: _extract_nested_with_index(value, index) for key, value in data.items()
}
else:
raise NotImplementedError(f"Invalid type of data {data}")


class PettingZooWrapper(_EnvWrapper):
"""PettingZoo environment wrapper.
Expand Down Expand Up @@ -735,7 +749,9 @@ def _step_parallel(
"full_action_spec", group, "action"
].to_numpy(group_action)
for index, agent in enumerate(agents):
action_dict[agent] = group_action_np[index]
# group_action_np can be a dict or an array. We need to recursively index it
action = _extract_nested_with_index(group_action_np, index)
action_dict[agent] = action

return self._env.step(action_dict)

Expand All @@ -750,7 +766,8 @@ def _step_aec(
group_action_np = self.input_spec[
"full_action_spec", group, "action"
].to_numpy(group_action)
action = group_action_np[agent_index]
# group_action_np can be a dict or an array. We need to recursively index it
action = _extract_nested_with_index(group_action_np, agent_index)
break

self._env.step(action)
Expand Down

0 comments on commit 1a6c9e2

Please sign in to comment.