Skip to content

Commit

Permalink
[Feature] multiagent data standardization: PPO advantages (#2677)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <[email protected]>
  • Loading branch information
matteobettini and vmoens authored Jan 15, 2025
1 parent 50011dc commit b7a0d11
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 12 deletions.
14 changes: 13 additions & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import torch

from packaging import version, version as pack_version

from tensordict import assert_allclose_td, TensorDict, TensorDictBase
from tensordict._C import unravel_keys
from tensordict.nn import (
Expand All @@ -38,6 +37,7 @@
from tensordict.nn.utils import Buffer
from tensordict.utils import unravel_key
from torch import autograd, nn
from torchrl._utils import _standardize
from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded
from torchrl.data.postprocs.postprocs import MultiStep
from torchrl.envs.model_based.dreamer import DreamerEnv
Expand Down Expand Up @@ -15848,6 +15848,18 @@ class _AcceptedKeys:


class TestUtils:
def test_standardization(self):
    """Check ``_standardize`` against hand-computed reference values."""
    t = torch.arange(3 * 4 * 5 * 6, dtype=torch.float32).view(3, 4, 5, 6)

    # Statistics are reduced over dims (0, 2) only; dims 1 and 3 are excluded.
    std_t0 = _standardize(t, exclude_dims=(1, 3))
    # The reference clamps the std from below with 1e-6. The previous
    # ``clamp(1 - 6)`` evaluated to ``clamp(min=-5)``, a no-op on a std.
    std_t1 = (t - t.mean((0, 2), keepdim=True)) / t.std(
        (0, 2), keepdim=True
    ).clamp_min(1e-6)
    torch.testing.assert_close(std_t0, std_t1)

    # Negative dims address the same axes as their positive counterparts.
    torch.testing.assert_close(_standardize(t, exclude_dims=(-3, -1)), std_t1)

    # User-provided mean (-1) and std (2), no excluded dims.
    std_t = _standardize(t, (), -1, 2)
    torch.testing.assert_close(std_t, (t + 1) / 2)

    # Default behaviour: global mean and std over every element.
    std_t = _standardize(t, ())
    torch.testing.assert_close(std_t, (t - t.mean()) / t.std())

@pytest.mark.parametrize("B", [None, (1, ), (4, ), (2, 2, ), (1, 2, 8, )]) # fmt: skip
@pytest.mark.parametrize("T", [1, 10])
@pytest.mark.parametrize("device", get_default_devices())
Expand Down
68 changes: 66 additions & 2 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@
from distutils.util import strtobool
from functools import wraps
from importlib import import_module
from typing import Any, Callable, cast, Dict, TypeVar, Union
from typing import Any, Callable, cast, Dict, Tuple, TypeVar, Union

import numpy as np
import torch
from packaging.version import parse
from tensordict import unravel_key

from tensordict.utils import NestedKey
from torch import multiprocessing as mp
from torch import multiprocessing as mp, Tensor

try:
from torch.compiler import is_compiling
Expand Down Expand Up @@ -872,6 +872,70 @@ def set_mode(self, type: Any | None) -> None:
self._mode = type


def _standardize(
input: Tensor,
exclude_dims: Tuple[int] = (),
mean: Tensor | None = None,
std: Tensor | None = None,
eps: float | None = None,
):
"""Standardizes the input tensor with the possibility of excluding specific dims from the statistics.
Useful when processing multi-agent data to keep the agent dimensions independent.
Args:
input (Tensor): the input tensor to be standardized.
exclude_dims (Tuple[int]): dimensions to exclude from the statistics, can be negative. Default: ().
mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None.
std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None.
eps (float): epsilon to be used for numerical stability. Default: float32 resolution.
"""
if eps is None:
if input.dtype.is_floating_point:
eps = torch.finfo(torch.float).resolution
else:
eps = 1e-6

len_exclude_dims = len(exclude_dims)
if not len_exclude_dims:
if mean is None:
mean = input.mean()
else:
# Assume dtypes are compatible
mean = torch.as_tensor(mean, device=input.device)
if std is None:
std = input.std()
else:
# Assume dtypes are compatible
std = torch.as_tensor(std, device=input.device)
return (input - mean) / std.clamp_min(eps)

input_shape = input.shape
exclude_dims = [
d if d >= 0 else d + len(input_shape) for d in exclude_dims
] # Make negative dims positive

if len(set(exclude_dims)) != len_exclude_dims:
raise ValueError("Exclude dims has repeating elements")
if any(dim < 0 or dim >= len(input_shape) for dim in exclude_dims):
raise ValueError(
f"exclude_dims={exclude_dims} provided outside bounds for input of shape={input_shape}"
)
if len_exclude_dims == len(input_shape):
warnings.warn(
"_standardize called but all dims were excluded from the statistics, returning unprocessed input"
)
return input

included_dims = tuple(d for d in range(len(input_shape)) if d not in exclude_dims)
if mean is None:
mean = torch.mean(input, keepdim=True, dim=included_dims)
if std is None:
std = torch.std(input, keepdim=True, dim=included_dims)
return (input - mean) / std.clamp_min(eps)


@wraps(torch.compile)
def compile_with_warmup(*args, warmup: int = 1, **kwargs):
"""Compile a model with warm-up.
Expand Down
58 changes: 49 additions & 9 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tensordict.utils import NestedKey
from torch import distributions as d

from torchrl._utils import _standardize
from torchrl.objectives.common import LossModule

from torchrl.objectives.utils import (
Expand All @@ -46,6 +47,7 @@
TDLambdaEstimator,
VTrace,
)
from yaml import warnings


class PPOLoss(LossModule):
Expand Down Expand Up @@ -87,6 +89,9 @@ class PPOLoss(LossModule):
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
before being used. Defaults to ``False``.
normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
where the agent (or objective) dimension may be excluded from the reductions. Default: ().
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
Expand Down Expand Up @@ -311,6 +316,7 @@ def __init__(
critic_coef: float = 1.0,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
normalize_advantage_exclude_dims: Tuple[int] = (),
gamma: float = None,
separate_losses: bool = False,
advantage_key: str = None,
Expand Down Expand Up @@ -381,6 +387,8 @@ def __init__(
self.critic_coef = None
self.loss_critic_type = loss_critic_type
self.normalize_advantage = normalize_advantage
self.normalize_advantage_exclude_dims = normalize_advantage_exclude_dims

if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self._set_deprecated_ctor_keys(
Expand Down Expand Up @@ -606,9 +614,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)
advantage = tensordict.get(self.tensor_keys.advantage)
if self.normalize_advantage and advantage.numel() > 1:
loc = advantage.mean()
scale = advantage.std().clamp_min(1e-6)
advantage = (advantage - loc) / scale
if advantage.numel() > tensordict.batch_size.numel() and not len(
self.normalize_advantage_exclude_dims
):
warnings.warn(
"You requested advantage normalization and the advantage key has more dimensions"
" than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
"if you want to keep any dimension independent while computing normalization statistics. "
"If you are working in multi-agent/multi-objective settings this is highly suggested."
)
advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)

log_weight, dist, kl_approx = self._log_weight(tensordict)
if is_tensor_collection(log_weight):
Expand Down Expand Up @@ -711,6 +726,9 @@ class ClipPPOLoss(PPOLoss):
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
before being used. Defaults to ``False``.
normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
where the agent (or objective) dimension may be excluded from the reductions. Default: ().
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
Expand Down Expand Up @@ -802,6 +820,7 @@ def __init__(
critic_coef: float = 1.0,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
normalize_advantage_exclude_dims: Tuple[int] = (),
gamma: float = None,
separate_losses: bool = False,
reduction: str = None,
Expand All @@ -821,6 +840,7 @@ def __init__(
critic_coef=critic_coef,
loss_critic_type=loss_critic_type,
normalize_advantage=normalize_advantage,
normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
gamma=gamma,
separate_losses=separate_losses,
reduction=reduction,
Expand Down Expand Up @@ -871,9 +891,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)
advantage = tensordict.get(self.tensor_keys.advantage)
if self.normalize_advantage and advantage.numel() > 1:
loc = advantage.mean()
scale = advantage.std().clamp_min(1e-6)
advantage = (advantage - loc) / scale
if advantage.numel() > tensordict.batch_size.numel() and not len(
self.normalize_advantage_exclude_dims
):
warnings.warn(
"You requested advantage normalization and the advantage key has more dimensions"
" than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
"if you want to keep any dimension independent while computing normalization statistics. "
"If you are working in multi-agent/multi-objective settings this is highly suggested."
)
advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)

log_weight, dist, kl_approx = self._log_weight(tensordict)
# ESS for logging
Expand Down Expand Up @@ -955,6 +982,9 @@ class KLPENPPOLoss(PPOLoss):
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
before being used. Defaults to ``False``.
normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
where the agent (or objective) dimension may be excluded from the reductions. Default: ().
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
Expand Down Expand Up @@ -1048,6 +1078,7 @@ def __init__(
critic_coef: float = 1.0,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
normalize_advantage_exclude_dims: Tuple[int] = (),
gamma: float = None,
separate_losses: bool = False,
reduction: str = None,
Expand All @@ -1063,6 +1094,7 @@ def __init__(
critic_coef=critic_coef,
loss_critic_type=loss_critic_type,
normalize_advantage=normalize_advantage,
normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
gamma=gamma,
separate_losses=separate_losses,
reduction=reduction,
Expand Down Expand Up @@ -1151,9 +1183,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
)
advantage = tensordict_copy.get(self.tensor_keys.advantage)
if self.normalize_advantage and advantage.numel() > 1:
loc = advantage.mean()
scale = advantage.std().clamp_min(1e-6)
advantage = (advantage - loc) / scale
if advantage.numel() > tensordict.batch_size.numel() and not len(
self.normalize_advantage_exclude_dims
):
warnings.warn(
"You requested advantage normalization and the advantage key has more dimensions"
" than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
"if you want to keep any dimension independent while computing normalization statistics. "
"If you are working in multi-agent/multi-objective settings this is highly suggested."
)
advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)

log_weight, dist, kl_approx = self._log_weight(tensordict_copy)
neg_loss = log_weight.exp() * advantage

Expand Down

0 comments on commit b7a0d11

Please sign in to comment.