
review comments
matteobettini committed Jan 10, 2025
1 parent 3d1e978 commit 9cfe213
Showing 2 changed files with 37 additions and 21 deletions.
35 changes: 24 additions & 11 deletions torchrl/_utils.py
@@ -24,7 +24,7 @@
from distutils.util import strtobool
from functools import wraps
from importlib import import_module
from typing import Any, Callable, cast, Dict, Sequence, TypeVar, Union
from typing import Any, Callable, cast, Dict, Tuple, TypeVar, Union

import numpy as np
import torch
@@ -872,7 +872,9 @@ def set_mode(self, type: Any | None) -> None:
self._mode = type


def _standardize(input, exclude_dims: Sequence[int] = (), mean=None, std=None):
def _standardize(
input, exclude_dims: Tuple[int] = (), mean=None, std=None, eps: float = None
):
"""Standardizes the input tensor with the possibility of excluding specific dims from the statistics.
Useful when processing multi-agent data to keep the agent dimensions independent.
@@ -882,8 +884,12 @@ def _standardize(input, exclude_dims: Sequence[int] = (), mean=None, std=None):
exclude_dims (Sequence[int]): dimensions to exclude from the statistics, can be negative. Default: ().
mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None.
std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None.
eps (float): epsilon to be used for numerical stability. Default: float32 resolution.
"""
if eps is None:
eps = torch.finfo(torch.float32).resolution

input_shape = input.shape
exclude_dims = [
d if d >= 0 else d + len(input_shape) for d in exclude_dims
@@ -893,19 +899,22 @@ def _standardize(input, exclude_dims: Sequence[int] = (), mean=None, std=None):
raise ValueError("Exclude dims has repeating elements")
if any(dim < 0 or dim >= len(input_shape) for dim in exclude_dims):
raise ValueError(
f"exclude_dims provided outside bounds for input of shape={input_shape}"
f"exclude_dims={exclude_dims} provided outside bounds for input of shape={input_shape}"
)
if len(exclude_dims) == len(input_shape):
warnings.warn(
"standardize called but all dims were excluded from the statistics, returning unprocessed input"
)
return input

# Put all excluded dims in the beginning
permutation = list(range(len(input_shape)))
for dim in exclude_dims:
permutation.insert(0, permutation.pop(permutation.index(dim)))
permuted_input = input.permute(*permutation)
if len(exclude_dims):
# Put all excluded dims in the beginning
permutation = list(range(len(input_shape)))
for dim in exclude_dims:
permutation.insert(0, permutation.pop(permutation.index(dim)))
permuted_input = input.permute(*permutation)
else:
permuted_input = input
normalized_shape_len = len(input_shape) - len(exclude_dims)

if mean is None:
@@ -916,11 +925,15 @@ def _standardize(input, exclude_dims: Sequence[int] = (), mean=None, std=None):
std = torch.std(
permuted_input, keepdim=True, dim=tuple(range(-normalized_shape_len, 0))
)
output = (permuted_input - mean) / std.clamp_min(1e-6)
output = (permuted_input - mean) / std.clamp_min(eps)

# Reverse permutation
inv_permutation = torch.argsort(torch.LongTensor(permutation)).tolist()
output = torch.permute(output, inv_permutation)
if len(exclude_dims):
inv_permutation = torch.argsort(
torch.tensor(permutation, dtype=torch.long, device=input.device)
).tolist()
output = torch.permute(output, inv_permutation)

return output
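The inverse permutation above works because permuting by a list and then by the argsort of that list restores the original layout. A minimal standalone sketch of this trick (not part of the commit):

import torch

perm = [2, 0, 1]                                  # excluded dims moved to the front
x = torch.randn(4, 5, 6)
y = x.permute(*perm)                              # shape (6, 4, 5)
inv = torch.argsort(torch.tensor(perm)).tolist()  # [1, 2, 0]
assert torch.equal(y.permute(*inv), x)            # argsort(perm) undoes perm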


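A possible usage sketch of the updated helper on multi-agent data, assuming a [batch, n_agents, 1] advantage tensor; `_standardize` is a private helper, so importing it directly here is purely illustrative:

import torch
from torchrl._utils import _standardize

adv = torch.randn(32, 3, 1)  # [batch, n_agents, 1] advantages
# Excluding the agent dim (-2) keeps a separate mean/std per agent.
out = _standardize(adv, exclude_dims=(-2,))
assert out.shape == adv.shape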
23 changes: 13 additions & 10 deletions torchrl/objectives/ppo.py
@@ -8,7 +8,7 @@

from copy import deepcopy
from dataclasses import dataclass
from typing import Sequence, Tuple
from typing import Tuple

import torch
from tensordict import (
@@ -88,8 +88,9 @@ class PPOLoss(LossModule):
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
before being used. Defaults to ``False``.
normalize_advantage_exclude_dims (Sequence[int], optional): dimensions to exclude from the advantage
standardization, can be negative. Useful in multiagent settings to exlude the agent dimension. Default: ().
normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
where the agent (or objective) dimension may be excluded from the reductions. Default: ().
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
@@ -314,7 +315,7 @@ def __init__(
critic_coef: float = 1.0,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
normalize_advantage_exclude_dims: Sequence[int] = (),
normalize_advantage_exclude_dims: Tuple[int] = (),
gamma: float = None,
separate_losses: bool = False,
advantage_key: str = None,
@@ -715,8 +716,9 @@ class ClipPPOLoss(PPOLoss):
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
before being used. Defaults to ``False``.
normalize_advantage_exclude_dims (Sequence[int], optional): dimensions to exclude from the advantage
standardization, can be negative. Useful in multiagent settings to exlude the agent dimension. Default: ().
normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
where the agent (or objective) dimension may be excluded from the reductions. Default: ().
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
@@ -808,7 +810,7 @@ def __init__(
critic_coef: float = 1.0,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
normalize_advantage_exclude_dims: Sequence[int] = (),
normalize_advantage_exclude_dims: Tuple[int] = (),
gamma: float = None,
separate_losses: bool = False,
reduction: str = None,
@@ -961,8 +963,9 @@ class KLPENPPOLoss(PPOLoss):
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
before being used. Defaults to ``False``.
normalize_advantage_exclude_dims (Sequence[int], optional): dimensions to exclude from the advantage
standardization, can be negative. Useful in multiagent settings to exlude the agent dimension. Default: ().
normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
where the agent (or objective) dimension may be excluded from the reductions. Default: ().
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
@@ -1056,7 +1059,7 @@ def __init__(
critic_coef: float = 1.0,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
normalize_advantage_exclude_dims: Sequence[int] = (),
normalize_advantage_exclude_dims: Tuple[int] = (),
gamma: float = None,
separate_losses: bool = False,
reduction: str = None,
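A sketch of how the new flag could be used in a multi-agent PPO setup. It assumes `policy_module` and `value_module` are actor and critic TensorDictModules built elsewhere (their construction is not part of this diff) and that the advantage carries the agent dimension at -2; keyword names follow the current torchrl API:

from torchrl.objectives import ClipPPOLoss

loss_module = ClipPPOLoss(
    actor_network=policy_module,    # assumed pre-built actor
    critic_network=value_module,    # assumed pre-built critic
    normalize_advantage=True,
    # Exclude the agent dim so each agent keeps its own advantage statistics.
    normalize_advantage_exclude_dims=(-2,),
)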
