From b7a0d11e52a7b9adeb8baa7817b3a03bebec0d01 Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Wed, 15 Jan 2025 17:04:22 +0100 Subject: [PATCH 1/3] [Feature] multiagent data standardization: PPO advantages (#2677) Co-authored-by: Vincent Moens --- test/test_cost.py | 14 +++++++- torchrl/_utils.py | 68 +++++++++++++++++++++++++++++++++++++-- torchrl/objectives/ppo.py | 58 +++++++++++++++++++++++++++------ 3 files changed, 128 insertions(+), 12 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 1f191e41db6..a0283e0e276 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -19,7 +19,6 @@ import torch from packaging import version, version as pack_version - from tensordict import assert_allclose_td, TensorDict, TensorDictBase from tensordict._C import unravel_keys from tensordict.nn import ( @@ -38,6 +37,7 @@ from tensordict.nn.utils import Buffer from tensordict.utils import unravel_key from torch import autograd, nn +from torchrl._utils import _standardize from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded from torchrl.data.postprocs.postprocs import MultiStep from torchrl.envs.model_based.dreamer import DreamerEnv @@ -15848,6 +15848,18 @@ class _AcceptedKeys: class TestUtils: + def test_standardization(self): + t = torch.arange(3 * 4 * 5 * 6, dtype=torch.float32).view(3, 4, 5, 6) + std_t0 = _standardize(t, exclude_dims=(1, 3)) + std_t1 = (t - t.mean((0, 2), keepdim=True)) / t.std((0, 2), keepdim=True).clamp( + 1 - 6 + ) + torch.testing.assert_close(std_t0, std_t1) + std_t = _standardize(t, (), -1, 2) + torch.testing.assert_close(std_t, (t + 1) / 2) + std_t = _standardize(t, ()) + torch.testing.assert_close(std_t, (t - t.mean()) / t.std()) + @pytest.mark.parametrize("B", [None, (1, ), (4, ), (2, 2, ), (1, 2, 8, )]) # fmt: skip @pytest.mark.parametrize("T", [1, 10]) @pytest.mark.parametrize("device", get_default_devices()) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 6a2f80aeffb..f999fa96c1d 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -24,7 +24,7 @@ from distutils.util import strtobool from functools import wraps from importlib import import_module -from typing import Any, Callable, cast, Dict, TypeVar, Union +from typing import Any, Callable, cast, Dict, Tuple, TypeVar, Union import numpy as np import torch @@ -32,7 +32,7 @@ from tensordict import unravel_key from tensordict.utils import NestedKey -from torch import multiprocessing as mp +from torch import multiprocessing as mp, Tensor try: from torch.compiler import is_compiling @@ -872,6 +872,70 @@ def set_mode(self, type: Any | None) -> None: self._mode = type +def _standardize( + input: Tensor, + exclude_dims: Tuple[int] = (), + mean: Tensor | None = None, + std: Tensor | None = None, + eps: float | None = None, +): + """Standardizes the input tensor with the possibility of excluding specific dims from the statistics. + + Useful when processing multi-agent data to keep the agent dimensions independent. + + Args: + input (Tensor): the input tensor to be standardized. + exclude_dims (Tuple[int]): dimensions to exclude from the statistics, can be negative. Default: (). + mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None. + std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None. + eps (float): epsilon to be used for numerical stability. Default: float32 resolution. 
+ + """ + if eps is None: + if input.dtype.is_floating_point: + eps = torch.finfo(torch.float).resolution + else: + eps = 1e-6 + + len_exclude_dims = len(exclude_dims) + if not len_exclude_dims: + if mean is None: + mean = input.mean() + else: + # Assume dtypes are compatible + mean = torch.as_tensor(mean, device=input.device) + if std is None: + std = input.std() + else: + # Assume dtypes are compatible + std = torch.as_tensor(std, device=input.device) + return (input - mean) / std.clamp_min(eps) + + input_shape = input.shape + exclude_dims = [ + d if d >= 0 else d + len(input_shape) for d in exclude_dims + ] # Make negative dims positive + + if len(set(exclude_dims)) != len_exclude_dims: + raise ValueError("Exclude dims has repeating elements") + if any(dim < 0 or dim >= len(input_shape) for dim in exclude_dims): + raise ValueError( + f"exclude_dims={exclude_dims} provided outside bounds for input of shape={input_shape}" + ) + if len_exclude_dims == len(input_shape): + warnings.warn( + "_standardize called but all dims were excluded from the statistics, returning unprocessed input" + ) + return input + + included_dims = tuple(d for d in range(len(input_shape)) if d not in exclude_dims) + if mean is None: + mean = torch.mean(input, keepdim=True, dim=included_dims) + if std is None: + std = torch.std(input, keepdim=True, dim=included_dims) + return (input - mean) / std.clamp_min(eps) + + @wraps(torch.compile) def compile_with_warmup(*args, warmup: int = 1, **kwargs): """Compile a model with warm-up. diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index eb9a916dfc1..3d1b3bd5088 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -27,6 +27,7 @@ from tensordict.utils import NestedKey from torch import distributions as d +from torchrl._utils import _standardize from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( @@ -46,6 +47,7 @@ TDLambdaEstimator, VTrace, ) +from yaml import warnings class PPOLoss(LossModule): @@ -87,6 +89,9 @@ class PPOLoss(LossModule): Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. normalize_advantage (bool, optional): if ``True``, the advantage will be normalized before being used. Defaults to ``False``. + normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage + standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings + where the agent (or objective) dimension may be excluded from the reductions. Default: (). separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. 
Defaults to ``False``, i.e., gradients are propagated to shared @@ -311,6 +316,7 @@ def __init__( critic_coef: float = 1.0, loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False, + normalize_advantage_exclude_dims: Tuple[int] = (), gamma: float = None, separate_losses: bool = False, advantage_key: str = None, @@ -381,6 +387,8 @@ def __init__( self.critic_coef = None self.loss_critic_type = loss_critic_type self.normalize_advantage = normalize_advantage + self.normalize_advantage_exclude_dims = normalize_advantage_exclude_dims + if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self._set_deprecated_ctor_keys( @@ -606,9 +614,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) advantage = tensordict.get(self.tensor_keys.advantage) if self.normalize_advantage and advantage.numel() > 1: - loc = advantage.mean() - scale = advantage.std().clamp_min(1e-6) - advantage = (advantage - loc) / scale + if advantage.numel() > tensordict.batch_size.numel() and not len( + self.normalize_advantage_exclude_dims + ): + warnings.warn( + "You requested advantage normalization and the advantage key has more dimensions" + " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` " + "if you want to keep any dimension independent while computing normalization statistics. " + "If you are working in multi-agent/multi-objective settings this is highly suggested." + ) + advantage = _standardize(advantage, self.normalize_advantage_exclude_dims) log_weight, dist, kl_approx = self._log_weight(tensordict) if is_tensor_collection(log_weight): @@ -711,6 +726,9 @@ class ClipPPOLoss(PPOLoss): Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. normalize_advantage (bool, optional): if ``True``, the advantage will be normalized before being used. Defaults to ``False``. + normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage + standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings + where the agent (or objective) dimension may be excluded from the reductions. Default: (). separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. Defaults to ``False``, i.e., gradients are propagated to shared @@ -802,6 +820,7 @@ def __init__( critic_coef: float = 1.0, loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False, + normalize_advantage_exclude_dims: Tuple[int] = (), gamma: float = None, separate_losses: bool = False, reduction: str = None, @@ -821,6 +840,7 @@ def __init__( critic_coef=critic_coef, loss_critic_type=loss_critic_type, normalize_advantage=normalize_advantage, + normalize_advantage_exclude_dims=normalize_advantage_exclude_dims, gamma=gamma, separate_losses=separate_losses, reduction=reduction, @@ -871,9 +891,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) advantage = tensordict.get(self.tensor_keys.advantage) if self.normalize_advantage and advantage.numel() > 1: - loc = advantage.mean() - scale = advantage.std().clamp_min(1e-6) - advantage = (advantage - loc) / scale + if advantage.numel() > tensordict.batch_size.numel() and not len( + self.normalize_advantage_exclude_dims + ): + warnings.warn( + "You requested advantage normalization and the advantage key has more dimensions" + " than the tensordict batch. 
Make sure to pass `normalize_advantage_exclude_dims` " + "if you want to keep any dimension independent while computing normalization statistics. " + "If you are working in multi-agent/multi-objective settings this is highly suggested." + ) + advantage = _standardize(advantage, self.normalize_advantage_exclude_dims) log_weight, dist, kl_approx = self._log_weight(tensordict) # ESS for logging @@ -955,6 +982,9 @@ class KLPENPPOLoss(PPOLoss): Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. normalize_advantage (bool, optional): if ``True``, the advantage will be normalized before being used. Defaults to ``False``. + normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage + standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings + where the agent (or objective) dimension may be excluded from the reductions. Default: (). separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. Defaults to ``False``, i.e., gradients are propagated to shared @@ -1048,6 +1078,7 @@ def __init__( critic_coef: float = 1.0, loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False, + normalize_advantage_exclude_dims: Tuple[int] = (), gamma: float = None, separate_losses: bool = False, reduction: str = None, @@ -1063,6 +1094,7 @@ def __init__( critic_coef=critic_coef, loss_critic_type=loss_critic_type, normalize_advantage=normalize_advantage, + normalize_advantage_exclude_dims=normalize_advantage_exclude_dims, gamma=gamma, separate_losses=separate_losses, reduction=reduction, @@ -1151,9 +1183,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: ) advantage = tensordict_copy.get(self.tensor_keys.advantage) if self.normalize_advantage and advantage.numel() > 1: - loc = advantage.mean() - scale = advantage.std().clamp_min(1e-6) - advantage = (advantage - loc) / scale + if advantage.numel() > tensordict.batch_size.numel() and not len( + self.normalize_advantage_exclude_dims + ): + warnings.warn( + "You requested advantage normalization and the advantage key has more dimensions" + " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` " + "if you want to keep any dimension independent while computing normalization statistics. " + "If you are working in multi-agent/multi-objective settings this is highly suggested." 
+ ) + advantage = _standardize(advantage, self.normalize_advantage_exclude_dims) + log_weight, dist, kl_approx = self._log_weight(tensordict_copy) neg_loss = log_weight.exp() * advantage From 4b3279a3f9a28486549e2916a2e0cb730c389714 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 15 Jan 2025 16:14:52 +0000 Subject: [PATCH 2/3] [BE] Add type annotation for tensor_keys to facilitate auto-complete ghstack-source-id: b4a8fe38e7c6b028759eef082f65f26036bc0250 Pull Request resolved: https://github.com/pytorch/rl/pull/2696 --- torchrl/objectives/a2c.py | 1 + torchrl/objectives/common.py | 1 + torchrl/objectives/cql.py | 2 ++ torchrl/objectives/crossq.py | 1 + torchrl/objectives/ddpg.py | 1 + torchrl/objectives/decision_transformer.py | 2 ++ torchrl/objectives/deprecated.py | 1 + torchrl/objectives/dqn.py | 2 ++ torchrl/objectives/dreamer.py | 12 ++++++++++++ torchrl/objectives/gail.py | 1 + torchrl/objectives/iql.py | 2 ++ torchrl/objectives/multiagent/qmixer.py | 1 + torchrl/objectives/ppo.py | 1 + torchrl/objectives/redq.py | 1 + torchrl/objectives/reinforce.py | 1 + torchrl/objectives/sac.py | 2 ++ torchrl/objectives/td3.py | 1 + torchrl/objectives/td3_bc.py | 1 + torchrl/objectives/value/advantages.py | 1 + 19 files changed, 35 insertions(+) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index f324e491298..e90a188331c 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -242,6 +242,7 @@ class _AcceptedKeys: terminated: NestedKey = "terminated" sample_log_prob: NestedKey = "sample_log_prob" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator: ValueEstimators = ValueEstimators.GAE diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index c2627770de9..0cda513e419 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -128,6 +128,7 @@ class _AcceptedKeys: pass + tensor_keys: _AcceptedKeys _vmap_randomness = None default_value_estimator: ValueEstimators = None diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 6e056589a8c..4c320dec46e 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -260,6 +260,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -1024,6 +1025,7 @@ class _AcceptedKeys: terminated: NestedKey = "terminated" pred_val: NestedKey = "pred_val" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 out_keys = [ diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 22e84673641..9af95581344 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -242,6 +242,7 @@ class _AcceptedKeys: terminated: NestedKey = "terminated" log_prob: NestedKey = "_log_prob" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 7dc6b23212a..26f7d128601 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -173,6 +173,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator: ValueEstimators = ValueEstimators.TD0 out_keys = [ diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 
a0d193acbfc..5e24edf548e 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -70,6 +70,7 @@ class _AcceptedKeys: # the "action" output from the model action_pred: NestedKey = "action" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() actor_network: TensorDictModule @@ -280,6 +281,7 @@ class _AcceptedKeys: # the "action" output from the model action_pred: NestedKey = "action" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() actor_network: TensorDictModule diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 2a4124c80de..7f795706640 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -127,6 +127,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() delay_actor: bool = False default_value_estimator = ValueEstimators.TD0 diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 72638893aa4..d025018e9c7 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -164,6 +164,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 out_keys = ["loss"] @@ -435,6 +436,7 @@ class _AcceptedKeys: terminated: NestedKey = "terminated" steps_to_next_obs: NestedKey = "steps_to_next_obs" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 418ec95b677..a92d5bfeedd 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -89,8 +89,13 @@ class _AcceptedKeys: pixels: NestedKey = "pixels" reco_pixels: NestedKey = "reco_pixels" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() + decoder: TensorDictModule + reward_model: TensorDictModule + world_mdel: TensorDictModule + def __init__( self, world_model: TensorDictModule, @@ -238,9 +243,13 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TDLambda + value_model: TensorDictModule + actor_model: TensorDictModule + def __init__( self, actor_model: TensorDictModule, @@ -392,8 +401,11 @@ class _AcceptedKeys: value: NestedKey = "state_value" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() + value_model: TensorDictModule + def __init__( self, value_model: TensorDictModule, diff --git a/torchrl/objectives/gail.py b/torchrl/objectives/gail.py index bbac2581199..ff95b0036ee 100644 --- a/torchrl/objectives/gail.py +++ b/torchrl/objectives/gail.py @@ -59,6 +59,7 @@ class _AcceptedKeys: collector_observation: NestedKey = "collector_observation" discriminator_pred: NestedKey = "d_logits" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() discriminator_network: TensorDictModule diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 039d5fc1c34..300105c1ba7 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -233,6 +233,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 out_keys = [ @@ -709,6 +710,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = 
"terminated" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 out_keys = [ diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index 39777c59e26..793a335f1b9 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -179,6 +179,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 out_keys = ["loss"] diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 3d1b3bd5088..cd22e03323c 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -295,6 +295,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.GAE diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index e234df1a512..9ed3a7f8f3e 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -231,6 +231,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() delay_actor: bool = False default_value_estimator = ValueEstimators.TD0 diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 2f207339b2f..4334016503f 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -211,6 +211,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.GAE out_keys = ["loss_actor", "loss_value"] diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index eae6b7feb34..9d9790d52b5 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -290,6 +290,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 @@ -1029,6 +1030,7 @@ class _AcceptedKeys: terminated: NestedKey = "terminated" log_prob: NestedKey = "log_prob" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 delay_actor: bool = False diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 42e04ec2212..20dcf19dce3 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -204,6 +204,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 out_keys = [ diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index 998fbde6ea4..45a76e80a53 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -217,6 +217,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 out_keys = [ diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 3b08780e24c..dd3f9cc4589 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -143,6 +143,7 @@ class _AcceptedKeys: steps_to_next_obs: NestedKey = "steps_to_next_obs" 
sample_log_prob: NestedKey = "sample_log_prob" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys() value_network: Union[TensorDictModule, Callable] _vmap_randomness = None From dc25a55a7f5fecdbfde7c38822f6ac6dc4e590f6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 14 Jan 2025 08:26:25 +0000 Subject: [PATCH 3/3] [BugFix,Doc] Fix BATCHED_PIPE_TIMEOUT refs and doc ghstack-source-id: 6e43c4ff1c319545cf0952abf6f35f3e7ed473e0 Pull Request resolved: https://github.com/pytorch/rl/pull/2695 --- torchrl/envs/batched_envs.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 5b6763f6910..2a70f70a3e2 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1217,6 +1217,10 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta): __doc__ += BatchedEnvBase.__doc__ __doc__ += """ + .. note:: ParallelEnv will timeout after one of the worker is idle for a determinate amount of time. + This can be controlled via the BATCHED_PIPE_TIMEOUT environment variable, which in turn modifies + the torchrl._utils.BATCHED_PIPE_TIMEOUT integer. The default timeout value is 10000 seconds. + .. warning:: TorchRL's ParallelEnv is quite stringent when it comes to env specs, since these are used to build shared memory buffers for inter-process communication. @@ -1353,7 +1357,10 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta): """ def _start_workers(self) -> None: + import torchrl + self._timeout = 10.0 + self.BATCHED_PIPE_TIMEOUT = torchrl._utils.BATCHED_PIPE_TIMEOUT from torchrl.envs.env_creator import EnvCreator @@ -1606,7 +1613,7 @@ def step_and_maybe_reset( for i in workers_range: event = self._events[i] - event.wait(self._timeout) + event.wait(self.BATCHED_PIPE_TIMEOUT) event.clear() if self._non_tensor_keys: @@ -1796,7 +1803,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: for i in workers_range: event = self._events[i] - event.wait(self._timeout) + event.wait(self.BATCHED_PIPE_TIMEOUT) event.clear() if self._non_tensor_keys: @@ -1965,7 +1972,7 @@ def tentative_update(val, other): for i, _ in outs: event = self._events[i] - event.wait(self._timeout) + event.wait(self.BATCHED_PIPE_TIMEOUT) event.clear() workers_nontensor = [] @@ -2023,7 +2030,7 @@ def _shutdown_workers(self) -> None: for channel in self.parent_channels: channel.close() for proc in self._workers: - proc.join(timeout=1.0) + proc.join(timeout=self._timeout) finally: for proc in self._workers: if proc.is_alive():
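
Usage sketch (not part of the patches): the snippet below illustrates the `_standardize` helper introduced in the first patch on a hypothetical multi-agent advantage tensor of shape [batch, time, n_agents, 1]; the tensor shape, the agent-dimension index and the variable names are illustrative assumptions, not taken from the patch.

    import torch

    from torchrl._utils import _standardize  # helper added by the first patch

    # Hypothetical multi-agent advantage tensor: [batch, time, n_agents, 1].
    advantage = torch.randn(8, 32, 3, 1)

    # Exclude dim 2 (the agent dim) from the statistics, so each agent is
    # standardized with its own mean and standard deviation.
    per_agent = _standardize(advantage, exclude_dims=(2,))

    # Reference computation over the non-excluded dims (0, 1, 3), using the
    # default eps for floating-point inputs (float32 resolution).
    mean = advantage.mean(dim=(0, 1, 3), keepdim=True)
    std = advantage.std(dim=(0, 1, 3), keepdim=True)
    eps = torch.finfo(torch.float).resolution
    torch.testing.assert_close(per_agent, (advantage - mean) / std.clamp_min(eps))

With the PPO losses patched above, the same per-agent statistics would be obtained by constructing the loss with `normalize_advantage=True` and `normalize_advantage_exclude_dims=(-2,)`, assuming the agent dimension sits second to last in the advantage tensor.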
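
Usage sketch (not part of the patches): the third patch documents that ParallelEnv worker waits are bounded by torchrl._utils.BATCHED_PIPE_TIMEOUT (default 10000 seconds), which can be controlled through the BATCHED_PIPE_TIMEOUT environment variable. The sketch below shows both knobs; the 300-second value and the "Pendulum-v1" worker env are illustrative, and it is assumed the environment variable is read when torchrl is first imported, so it must be set beforehand.

    import os

    # Option 1: environment variable, assumed to be read at torchrl import time.
    os.environ["BATCHED_PIPE_TIMEOUT"] = "300"  # seconds (illustrative value)

    import torchrl
    from torchrl.envs import GymEnv, ParallelEnv

    # Option 2: override the module-level constant directly. The patched
    # _start_workers reads it when the ParallelEnv is built, so set it first.
    torchrl._utils.BATCHED_PIPE_TIMEOUT = 300

    env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))  # hypothetical worker env
    env.close()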