diff --git a/torchrl/_utils.py b/torchrl/_utils.py
index ddd92d58341..50be67768e9 100644
--- a/torchrl/_utils.py
+++ b/torchrl/_utils.py
@@ -24,7 +24,7 @@
 from distutils.util import strtobool
 from functools import wraps
 from importlib import import_module
-from typing import Any, Callable, cast, Dict, Sequence, TypeVar, Union
+from typing import Any, Callable, cast, Dict, Tuple, TypeVar, Union

 import numpy as np
 import torch
@@ -872,7 +872,9 @@ def set_mode(self, type: Any | None) -> None:
         self._mode = type


-def _standardize(input, exclude_dims: Sequence[int] = (), mean=None, std=None):
+def _standardize(
+    input, exclude_dims: Tuple[int] = (), mean=None, std=None, eps: float = None
+):
     """Standardizes the input tensor with the possibility of excluding specific dims from the statistics.

     Useful when processing multi-agent data to keep the agent dimensions independent.
@@ -882,8 +884,12 @@ def _standardize(input, exclude_dims: Sequence[int] = (), mean=None, std=None):
         exclude_dims (Sequence[int]): dimensions to exclude from the statistics, can be negative. Default: ().
         mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None.
         std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None.
+        eps (float): epsilon to be used for numerical stability. Default: float32 resolution.

     """
+    if eps is None:
+        eps = torch.finfo(torch.float).resolution
+
     input_shape = input.shape
     exclude_dims = [
         d if d >= 0 else d + len(input_shape) for d in exclude_dims
@@ -893,7 +899,7 @@ def _standardize(input, exclude_dims: Sequence[int] = (), mean=None, std=None):
         raise ValueError("Exclude dims has repeating elements")
     if any(dim < 0 or dim >= len(input_shape) for dim in exclude_dims):
         raise ValueError(
-            f"exclude_dims provided outside bounds for input of shape={input_shape}"
+            f"exclude_dims={exclude_dims} provided outside bounds for input of shape={input_shape}"
         )
     if len(exclude_dims) == len(input_shape):
         warnings.warn(
@@ -901,11 +907,14 @@ def _standardize(input, exclude_dims: Sequence[int] = (), mean=None, std=None):
         )
         return input

-    # Put all excluded dims in the beginning
-    permutation = list(range(len(input_shape)))
-    for dim in exclude_dims:
-        permutation.insert(0, permutation.pop(permutation.index(dim)))
-    permuted_input = input.permute(*permutation)
+    if len(exclude_dims):
+        # Put all excluded dims in the beginning
+        permutation = list(range(len(input_shape)))
+        for dim in exclude_dims:
+            permutation.insert(0, permutation.pop(permutation.index(dim)))
+        permuted_input = input.permute(*permutation)
+    else:
+        permuted_input = input
     normalized_shape_len = len(input_shape) - len(exclude_dims)

     if mean is None:
@@ -916,11 +925,15 @@ def _standardize(input, exclude_dims: Sequence[int] = (), mean=None, std=None):
         std = torch.std(
             permuted_input, keepdim=True, dim=tuple(range(-normalized_shape_len, 0))
         )
-    output = (permuted_input - mean) / std.clamp_min(1e-6)
+    output = (permuted_input - mean) / std.clamp_min(eps)

     # Reverse permutation
-    inv_permutation = torch.argsort(torch.LongTensor(permutation)).tolist()
-    output = torch.permute(output, inv_permutation)
+    if len(exclude_dims):
+        inv_permutation = torch.argsort(
+            torch.tensor(permutation, dtype=torch.long, device=input.device)
+        ).tolist()
+        output = torch.permute(output, inv_permutation)
+
     return output


diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index 36068be2aaa..5e5d4ea006e 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -8,7 +8,7 @@

 from copy import deepcopy
 from dataclasses import dataclass
-from typing import Sequence, Tuple
+from typing import Tuple

 import torch
 from tensordict import (
@@ -88,8 +88,9 @@ class PPOLoss(LossModule):
             Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
         normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
             before being used. Defaults to ``False``.
-        normalize_advantage_exclude_dims (Sequence[int], optional): dimensions to exclude from the advantage
-            standardization, can be negative. Useful in multiagent settings to exlude the agent dimension. Default: ().
+        normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
+            standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
+            where the agent (or objective) dimension may be excluded from the reductions. Default: ().
         separate_losses (bool, optional): if ``True``, shared parameters between
             policy and critic will only be trained on the policy loss.
             Defaults to ``False``, i.e., gradients are propagated to shared
@@ -314,7 +315,7 @@ def __init__(
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
-        normalize_advantage_exclude_dims: Sequence[int] = (),
+        normalize_advantage_exclude_dims: Tuple[int] = (),
         gamma: float = None,
         separate_losses: bool = False,
         advantage_key: str = None,
@@ -715,8 +716,9 @@ class ClipPPOLoss(PPOLoss):
             Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
         normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
             before being used. Defaults to ``False``.
-        normalize_advantage_exclude_dims (Sequence[int], optional): dimensions to exclude from the advantage
-            standardization, can be negative. Useful in multiagent settings to exlude the agent dimension. Default: ().
+        normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
+            standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
+            where the agent (or objective) dimension may be excluded from the reductions. Default: ().
         separate_losses (bool, optional): if ``True``, shared parameters between
             policy and critic will only be trained on the policy loss.
             Defaults to ``False``, i.e., gradients are propagated to shared
@@ -808,7 +810,7 @@ def __init__(
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
-        normalize_advantage_exclude_dims: Sequence[int] = (),
+        normalize_advantage_exclude_dims: Tuple[int] = (),
         gamma: float = None,
         separate_losses: bool = False,
         reduction: str = None,
@@ -961,8 +963,9 @@ class KLPENPPOLoss(PPOLoss):
             Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
         normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
             before being used. Defaults to ``False``.
-        normalize_advantage_exclude_dims (Sequence[int], optional): dimensions to exclude from the advantage
-            standardization, can be negative. Useful in multiagent settings to exlude the agent dimension. Default: ().
+        normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
+            standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
+            where the agent (or objective) dimension may be excluded from the reductions. Default: ().
         separate_losses (bool, optional): if ``True``, shared parameters between
             policy and critic will only be trained on the policy loss.
             Defaults to ``False``, i.e., gradients are propagated to shared
@@ -1056,7 +1059,7 @@ def __init__(
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
-        normalize_advantage_exclude_dims: Sequence[int] = (),
+        normalize_advantage_exclude_dims: Tuple[int] = (),
         gamma: float = None,
         separate_losses: bool = False,
         reduction: str = None,
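
For reference, a minimal usage sketch of the updated helper (illustration only, not part of the diff; the tensor shape is made up and _standardize is a private utility whose signature is assumed to match the change above):

    import torch

    from torchrl._utils import _standardize

    # Fake multi-agent advantage of shape [batch, time, n_agents, 1].
    adv = torch.randn(8, 16, 3, 1)

    # Exclude the agent dim (dim 2): each agent is standardized with its own
    # mean/std. eps defaults to the float32 resolution when left as None.
    adv_per_agent = _standardize(adv, exclude_dims=(2,))

    # No excluded dims: statistics are computed over the whole tensor and the
    # new fast path skips the permutation entirely.
    adv_global = _standardize(adv, exclude_dims=())

The same exclusion tuple is what PPOLoss forwards through normalize_advantage_exclude_dims when normalize_advantage=True, so multi-agent advantages keep per-agent statistics.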