From 7eb268414613595720b8afe78a61430300bc2d69 Mon Sep 17 00:00:00 2001 From: N00bcak Date: Thu, 15 Aug 2024 18:26:13 +0800 Subject: [PATCH 1/4] draft implementation of `ConsistentDropout`/`ConsistentDropoutModule` --- test/test_exploration.py | 122 +++++++++++++++++++++++++- torchrl/modules/models/exploration.py | 94 ++++++++++++++++++++ 2 files changed, 215 insertions(+), 1 deletion(-) diff --git a/test/test_exploration.py b/test/test_exploration.py index 3bb05708d83..6b156871101 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -31,7 +31,10 @@ NormalParamExtractor, TanhNormal, ) -from torchrl.modules.models.exploration import LazygSDEModule +from torchrl.modules.models.exploration import ( + LazygSDEModule, + ConsistentDropoutModule +) from torchrl.modules.tensordict_module.actors import ( Actor, ProbabilisticActor, @@ -738,6 +741,123 @@ def test_gsde_init(sigma_init, state_dim, action_dim, mean, std, device, learn_s ), f"failed: mean={mean}, std={std}, sigma_init={sigma_init}, actual: {sigma.mean()}" +@pytest.mark.parametrize("dropout_p", [0.0, 0.1, 0.5]) +@pytest.mark.parametrize("parallel_spec", [False, True]) +@pytest.mark.parametrize( + "device", + [torch.device("cuda:0") if torch.cuda.device_count() else torch.device("cpu")], +) +def test_consistent_dropout(dropout_p, parallel_spec, device): + ''' + + This preliminary test seeks to ensure two things for both + ConsistentDropout and ConsistentDropoutModule: + 1. Rollout transitions generate a dropout mask as desired. + - We can easily verify the existence of a mask + 2. The dropout mask is correctly applied. + - We will check with stochastic policies whether or not + the loc and scale are the same. + ''' + torch.manual_seed(0) + + # NOTE: Please only put a module with one dropout layer. + # That's how this test is constructed anyways. + @torch.no_grad + def inner_verify_routine(module, env): + # Perform transitions. 
+ collector = SyncDataCollector( + create_env_fn=env, + policy=module, + frames_per_batch=640, + total_frames=1280, + device=device, + ) + for frames in collector: + masks = [ + (key, value) + for key, value in frames.items() + if 'mask_' in key + ] + # Assert rollouts do indeed correctly generate the masks. + assert len(masks) == 1, ( + "Expected exactly ONE mask since we only put " + f"one dropout module, got {len(masks)}." + ) + + # Verify that the result for this batch is the same. + # Kind of Monte Carlo, to be honest. + sentinel_mask = masks[0][1].clone() + sentinel_outputs = frames.select("loc", "scale").clone() + + desired_dropout_mask = torch.full_like(sentinel_mask, 1 / (1 - dropout_p)) + desired_dropout_mask[sentinel_mask == 0.] = 0. + # As of 15/08/24, :meth:`~torch.nn.functional.dropout` + # is being used. Never hurts to be safe. + assert torch.allclose(sentinel_mask, desired_dropout_mask), ( + "Dropout was not scaled properly." + ) + + infer_mask = module(frames)[masks[0][0]] + infer_outputs = module(frames).select("loc", "scale") + assert (infer_mask == sentinel_mask).all(), ( + "Mask does not match" + ) + + assert all([torch.allclose( + infer_outputs[key], + sentinel_outputs[key] + ) for key in ('loc', 'scale')]), ( + "Outputs do not match:\n " + f"{infer_outputs['loc']}\n--- vs ---\n{sentinel_outputs['loc']}" + f"{infer_outputs['scale']}\n--- vs ---\n{sentinel_outputs['scale']}" + ) + + env = SerialEnv( + 2, + ContinuousActionVecMockEnv, + ) + env = TransformedEnv(env.to(device), InitTracker()) + env = env.to(device) + # the module must work with the action spec of a single env or a serial env + if parallel_spec: + action_spec = env.action_spec + else: + action_spec = ContinuousActionVecMockEnv(device=device).action_spec + d_act = action_spec.shape[-1] + + # NOTE: Please only put a module with one dropout layer. + # That's how this test is constructed anyways. 
+ module_td_seq = TensorDictSequential( + TensorDictModule( + nn.LazyLinear(2 * d_act), + in_keys = ["observation"], + out_keys = ["out"] + ), + ConsistentDropoutModule( + p = dropout_p, + in_key = "out" + ), + TensorDictModule( + NormalParamExtractor(), + in_keys=["out"], + out_keys=["loc", "scale"] + ) + ) + + policy_td_seq = ProbabilisticActor( + module=module_td_seq, + in_keys=["loc", "scale"], + distribution_class=TanhNormal, + default_interaction_type=InteractionType.RANDOM, + spec=action_spec, + ).to(device) + + # Wake up the policies + policy_td_seq(env.reset()) + + # Test. + inner_verify_routine(policy_td_seq, env) + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index 16c6ac5ff30..08d29c185c2 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -8,9 +8,14 @@ import torch from torch import distributions as d, nn +from torch.nn import functional as F +from torch.nn.modules.dropout import _DropoutNd from torch.nn.modules.lazy import LazyModuleMixin from torch.nn.parameter import UninitializedBuffer, UninitializedParameter +from tensordict.nn import TensorDictModuleBase +from tensordict.utils import NestedKey + from torchrl._utils import prod from torchrl.data.utils import DEVICE_TYPING, DEVICE_TYPING_ARGS from torchrl.envs.utils import exploration_type, ExplorationType @@ -520,3 +525,92 @@ def initialize_parameters( ) self._sigma.materialize((action_dim, state_dim)) self._sigma.data.copy_(self.sigma_init.expand_as(self._sigma)) + + +class ConsistentDropout(_DropoutNd): + """ + Implements the Dropout variant proposed in `"Consistent Dropout for + Policy Gradient Reinforcement Learning" (Hausknecht & Wagener, 2022) ` + + This implementation capitalizes on the extensibility of TensorDicts + by storing generated dropout 
masks in the transitions themselves. + + There is otherwise little conceptual deviance from the original + :class:`~torch.nn.Dropout` implementation. Although, there is probably a lot of + `room for improvement... ` + + NOTE: TorchRL's data collectors perform rollouts in :meth:`~torch.no_grad` mode, + so the dropout masks ARE in fact still applied. + + See + - :class:`~torchrl.collectors.SyncDataCollector`: rollout() and iterator() + - :class:`~torchrl.collectors.MultiSyncDataCollector`: Uses + :meth:`~torchrl.collectors.collectors._main_async_collector` + (SyncDataCollector) under the hood + """ + + def __init__(self, p=0.5): + super().__init__() + self.p = p + + def forward(self, x, mask=None): + ''' + During training (rollouts & updates), this call masks a tensor full of + ones before multiplying with the input tensor. + + During evaluation, this call results in a no-op. + ''' + if self.training: + if mask is None: + mask = F.dropout(torch.ones_like(x), self.p, self.training, inplace = False) + return x * mask, mask + + return x + +class ConsistentDropoutModule(TensorDictModuleBase): + """ + Examples: + >>> from tensordict import TensorDict + >>> module = ConsistentDropoutModule(p = 0.1) + >>> td = TensorDict({"x": torch.randn(3, 4)}, [3]) + >>> module(td) + TensorDict( + fields={ + mask_6127171760: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.bool, is_shared=False), + x: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False) + """ + def __init__(self, p: float, in_key: NestedKey=None, in_keys=None, out_keys=None): + if in_key is None: + in_key = "x" + if in_keys is None: + in_keys = [in_key, f"mask_{id(self)}"] + elif len(in_keys) != 2: + raise ValueError("in_keys and out_keys length must be 2 for consistent dropout.") + if out_keys is None: + out_keys = [in_key, f"mask_{id(self)}"] + elif len(out_keys) != 2: + raise ValueError("in_keys and out_keys length 
must be 2 for consistent dropout.") + self.in_keys = in_keys + self.out_keys = out_keys + super().__init__() + + if not 0 <= p < 1: + raise ValueError("p must be in [0,1)!") + + self.consistent_dropout = ConsistentDropout(p) + + def forward(self, tensordict): + x = tensordict.get(self.in_keys[0]) + mask = tensordict.get(self.in_keys[1], default=None) + if self.training: + x, mask = self.consistent_dropout(x, mask=mask) + tensordict.set(self.out_keys[0], x) + tensordict.set(self.out_keys[1], mask) + else: + x = self.consistent_dropout(x, mask=mask) + tensordict.set(self.out_keys[0], x) + + return tensordict \ No newline at end of file From 5bc41e408d768b7ee72ae85f68bb024d3dcbd541 Mon Sep 17 00:00:00 2001 From: N00bcak Date: Thu, 15 Aug 2024 23:23:13 +0800 Subject: [PATCH 2/4] draft changes to documentation --- docs/source/reference/modules.rst | 1 + torchrl/modules/__init__.py | 1 + torchrl/modules/models/__init__.py | 2 +- torchrl/modules/models/exploration.py | 57 +++++++++++++++++++-------- 4 files changed, 44 insertions(+), 17 deletions(-) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 62cf1dedf35..1e894392b2c 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -444,6 +444,7 @@ Regular modules SqueezeLayer Squeeze2dLayer BatchRenorm1d + ConsistentDropoutModule Algorithm-specific modules ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index c246b553e95..4b41368591f 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -48,6 +48,7 @@ Squeeze2dLayer, SqueezeLayer, VDNMixer, + ConsistentDropoutModule ) from .tensordict_module import ( Actor, diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 9a814e35477..c710f6e887c 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -9,7 +9,7 @@ from .batchrenorm import BatchRenorm1d from 
.decision_transformer import DecisionTransformer -from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise +from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise, ConsistentDropoutModule from .model_based import ( DreamerActor, ObsDecoder, diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index 08d29c185c2..01f678549dc 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -528,28 +528,38 @@ def initialize_parameters( class ConsistentDropout(_DropoutNd): - """ - Implements the Dropout variant proposed in `"Consistent Dropout for - Policy Gradient Reinforcement Learning" (Hausknecht & Wagener, 2022) ` + ''' + Implements the :class:`~torch.nn.Dropout` variant proposed in `"Consistent Dropout for + Policy Gradient Reinforcement Learning" (Hausknecht & Wagener, 2022) `_. - This implementation capitalizes on the extensibility of TensorDicts - by storing generated dropout masks in the transitions themselves. + This :class:`~torch.nn.Dropout` variant attempts to increase training stability and + reduce update variance by caching the dropout masks used during rollout + and reusing them during the update phase. + + TorchRL's implementation capitalizes on the extensibility of + ``TensorDict``s by storing generated dropout masks + in the transition ``TensorDict`` themselves. - There is otherwise little conceptual deviance from the original - :class:`~torch.nn.Dropout` implementation. Although, there is probably a lot of - `room for improvement... ` + There is otherwise little conceptual deviance from the PyTorch + :class:`~torch.nn.Dropout` implementation. NOTE: TorchRL's data collectors perform rollouts in :meth:`~torch.no_grad` mode, so the dropout masks ARE in fact still applied. 
- See - - :class:`~torchrl.collectors.SyncDataCollector`: rollout() and iterator() - - :class:`~torchrl.collectors.MultiSyncDataCollector`: Uses - :meth:`~torchrl.collectors.collectors._main_async_collector` - (SyncDataCollector) under the hood - """ + See + + - :class:`~torchrl.collectors.SyncDataCollector`: :meth:`~torchrl.collectors.SyncDataCollector.rollout()` and :meth:`~torchrl.collectors.SyncDataCollector.iterator()` + + - :class:`~torchrl.collectors.MultiSyncDataCollector`: Uses :meth:`~torchrl.collectors.collectors._main_async_collector` (:class:`~torchrl.collectors.SyncDataCollector`) under the hood + + - :class:`~torchrl.collectors.MultiaSyncDataCollector`, :class:`~torchrl.collectors.aSyncDataCollector`: Ditto. + ''' def __init__(self, p=0.5): + ''' + Parameters: + p (float, optional): Dropout probability. Default: ``0.5``. + ''' super().__init__() self.p = p @@ -568,7 +578,20 @@ def forward(self, x, mask=None): return x class ConsistentDropoutModule(TensorDictModuleBase): - """ + ''' + + Parameters: + p (float, optional): Dropout probability. Default: ``0.5``. + + in_key (str, optional): The key to be read from input tensordict. + Only used if ``in_keys`` is not specified. + + in_keys (iterable of NestedKeys, Dict[NestedStr, str]): keys to be read + from input tensordict and passed to this module. Default: ``None``. + + out_keys (iterable of str): keys to be written to the input tensordict. + Default: ``None``. 
+ Examples: >>> from tensordict import TensorDict >>> module = ConsistentDropoutModule(p = 0.1) @@ -581,7 +604,9 @@ class ConsistentDropoutModule(TensorDictModuleBase): batch_size=torch.Size([3]), device=None, is_shared=False) - """ + ''' + __doc__ = f"{ConsistentDropout.__doc__}\n{__doc__}" + def __init__(self, p: float, in_key: NestedKey=None, in_keys=None, out_keys=None): if in_key is None: in_key = "x" From c1847bee6566e3d53b07b9d08bbabf8a435a7182 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 30 Aug 2024 18:49:51 +0100 Subject: [PATCH 3/4] make primers - edit doc --- docs/source/reference/modules.rst | 17 ++- test/test_exploration.py | 75 ++++------ torchrl/envs/transforms/transforms.py | 7 +- torchrl/modules/__init__.py | 3 +- torchrl/modules/models/__init__.py | 7 +- torchrl/modules/models/exploration.py | 177 +++++++++++++++-------- torchrl/modules/tensordict_module/rnn.py | 10 +- 7 files changed, 178 insertions(+), 118 deletions(-) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 1e894392b2c..64c83daead3 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -57,8 +57,8 @@ projected (in a L1-manner) into the desired domain. SafeSequential TanhModule -Exploration wrappers -~~~~~~~~~~~~~~~~~~~~ +Exploration wrappers and modules +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ To efficiently explore the environment, TorchRL proposes a series of wrappers that will override the action sampled by the policy by a noisier version. @@ -66,7 +66,7 @@ Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_mode`: if the exploration is set to ``"random"``, the exploration is active. In all other cases, the action written in the tensordict is simply the network output. -.. currentmodule:: torchrl.modules.tensordict_module +.. currentmodule:: torchrl.modules .. 
autosummary:: :toctree: generated/ @@ -74,6 +74,7 @@ other cases, the action written in the tensordict is simply the network output. AdditiveGaussianModule AdditiveGaussianWrapper + ConsistentDropoutModule EGreedyModule EGreedyWrapper OrnsteinUhlenbeckProcessModule @@ -438,13 +439,13 @@ Regular modules :toctree: generated/ :template: rl_template_noinherit.rst - MLP - ConvNet + BatchRenorm1d + ConsistentDropout Conv3dNet - SqueezeLayer + ConvNet + MLP Squeeze2dLayer - BatchRenorm1d - ConsistentDropoutModule + SqueezeLayer Algorithm-specific modules ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/test/test_exploration.py b/test/test_exploration.py index 6b156871101..64b445f3f2c 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -31,10 +31,7 @@ NormalParamExtractor, TanhNormal, ) -from torchrl.modules.models.exploration import ( - LazygSDEModule, - ConsistentDropoutModule -) +from torchrl.modules.models.exploration import ConsistentDropoutModule, LazygSDEModule from torchrl.modules.tensordict_module.actors import ( Actor, ProbabilisticActor, @@ -748,16 +745,16 @@ def test_gsde_init(sigma_init, state_dim, action_dim, mean, std, device, learn_s [torch.device("cuda:0") if torch.cuda.device_count() else torch.device("cpu")], ) def test_consistent_dropout(dropout_p, parallel_spec, device): - ''' - + """ + This preliminary test seeks to ensure two things for both ConsistentDropout and ConsistentDropoutModule: 1. Rollout transitions generate a dropout mask as desired. - We can easily verify the existence of a mask 2. The dropout mask is correctly applied. - - We will check with stochastic policies whether or not + - We will check with stochastic policies whether or not the loc and scale are the same. - ''' + """ torch.manual_seed(0) # NOTE: Please only put a module with one dropout layer. 
@@ -773,40 +770,36 @@ def inner_verify_routine(module, env): device=device, ) for frames in collector: - masks = [ - (key, value) - for key, value in frames.items() - if 'mask_' in key - ] + masks = [(key, value) for key, value in frames.items() if "mask_" in key] # Assert rollouts do indeed correctly generate the masks. assert len(masks) == 1, ( - "Expected exactly ONE mask since we only put " - f"one dropout module, got {len(masks)}." - ) - + "Expected exactly ONE mask since we only put " + f"one dropout module, got {len(masks)}." + ) + # Verify that the result for this batch is the same. # Kind of Monte Carlo, to be honest. sentinel_mask = masks[0][1].clone() sentinel_outputs = frames.select("loc", "scale").clone() desired_dropout_mask = torch.full_like(sentinel_mask, 1 / (1 - dropout_p)) - desired_dropout_mask[sentinel_mask == 0.] = 0. + desired_dropout_mask[sentinel_mask == 0.0] = 0.0 # As of 15/08/24, :meth:`~torch.nn.functional.dropout` # is being used. Never hurts to be safe. - assert torch.allclose(sentinel_mask, desired_dropout_mask), ( - "Dropout was not scaled properly." - ) + assert torch.allclose( + sentinel_mask, desired_dropout_mask + ), "Dropout was not scaled properly." infer_mask = module(frames)[masks[0][0]] infer_outputs = module(frames).select("loc", "scale") - assert (infer_mask == sentinel_mask).all(), ( - "Mask does not match" - ) - - assert all([torch.allclose( - infer_outputs[key], - sentinel_outputs[key] - ) for key in ('loc', 'scale')]), ( + assert (infer_mask == sentinel_mask).all(), "Mask does not match" + + assert all( + [ + torch.allclose(infer_outputs[key], sentinel_outputs[key]) + for key in ("loc", "scale") + ] + ), ( "Outputs do not match:\n " f"{infer_outputs['loc']}\n--- vs ---\n{sentinel_outputs['loc']}" f"{infer_outputs['scale']}\n--- vs ---\n{sentinel_outputs['scale']}" @@ -828,21 +821,14 @@ def inner_verify_routine(module, env): # NOTE: Please only put a module with one dropout layer. 
# That's how this test is constructed anyways. module_td_seq = TensorDictSequential( - TensorDictModule( - nn.LazyLinear(2 * d_act), - in_keys = ["observation"], - out_keys = ["out"] - ), - ConsistentDropoutModule( - p = dropout_p, - in_key = "out" - ), - TensorDictModule( - NormalParamExtractor(), - in_keys=["out"], - out_keys=["loc", "scale"] - ) - ) + TensorDictModule( + nn.LazyLinear(2 * d_act), in_keys=["observation"], out_keys=["out"] + ), + ConsistentDropoutModule(p=dropout_p, in_key="out"), + TensorDictModule( + NormalParamExtractor(), in_keys=["out"], out_keys=["loc", "scale"] + ), + ) policy_td_seq = ProbabilisticActor( module=module_td_seq, @@ -858,6 +844,7 @@ def inner_verify_routine(module, env): # Test. inner_verify_routine(policy_td_seq, env) + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 2e2883c33bf..ddcd837be23 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4656,10 +4656,15 @@ def __init__( def reset_key(self): reset_key = self.__dict__.get("_reset_key", None) if reset_key is None: + if self.parent is None: + raise RuntimeError( + "Missing parent, cannot infer reset_key automatically." + ) reset_keys = self.parent.reset_keys if len(reset_keys) > 1: raise RuntimeError( - f"Got more than one reset key in env {self.container}, cannot infer which one to use. Consider providing the reset key in the {type(self)} constructor." + f"Got more than one reset key in env {self.container}, cannot infer which one to use. " + f"Consider providing the reset key in the {type(self)} constructor." 
) reset_key = self._reset_key = reset_keys[0] return reset_key diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 4b41368591f..f65461842bb 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -21,6 +21,7 @@ ) from .models import ( BatchRenorm1d, + ConsistentDropoutModule, Conv3dNet, ConvNet, DdpgCnnActor, @@ -48,7 +49,6 @@ Squeeze2dLayer, SqueezeLayer, VDNMixer, - ConsistentDropoutModule ) from .tensordict_module import ( Actor, @@ -86,4 +86,5 @@ VmapModule, WorldModelWrapper, ) +from .utils import get_primers_from_module from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index c710f6e887c..90b9fadd747 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -9,7 +9,12 @@ from .batchrenorm import BatchRenorm1d from .decision_transformer import DecisionTransformer -from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise, ConsistentDropoutModule +from .exploration import ( + ConsistentDropoutModule, + NoisyLazyLinear, + NoisyLinear, + reset_noise, +) from .model_based import ( DreamerActor, ObsDecoder, diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index 01f678549dc..01e0ed015db 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -2,21 +2,22 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import functools import math import warnings -from typing import Optional, Sequence, Union +from typing import List, Optional, Sequence, Union import torch + +from tensordict.nn import TensorDictModuleBase +from tensordict.utils import NestedKey from torch import distributions as d, nn from torch.nn import functional as F from torch.nn.modules.dropout import _DropoutNd from torch.nn.modules.lazy import LazyModuleMixin from torch.nn.parameter import UninitializedBuffer, UninitializedParameter - -from tensordict.nn import TensorDictModuleBase -from tensordict.utils import NestedKey - from torchrl._utils import prod +from torchrl.data.tensor_specs import Unbounded from torchrl.data.utils import DEVICE_TYPING, DEVICE_TYPING_ARGS from torchrl.envs.utils import exploration_type, ExplorationType from torchrl.modules.distributions.utils import _cast_transform_device @@ -528,69 +529,89 @@ def initialize_parameters( class ConsistentDropout(_DropoutNd): - ''' - Implements the :class:`~torch.nn.Dropout` variant proposed in `"Consistent Dropout for - Policy Gradient Reinforcement Learning" (Hausknecht & Wagener, 2022) `_. - - This :class:`~torch.nn.Dropout` variant attempts to increase training stability and + """Implements a :class:`~torch.nn.Dropout` variant with consistent dropout. + + This method is proposed in `"Consistent Dropout for Policy Gradient Reinforcement Learning" (Hausknecht & Wagener, 2022) `_. + + This :class:`~torch.nn.Dropout` variant attempts to increase training stability and reduce update variance by caching the dropout masks used during rollout and reusing them during the update phase. - TorchRL's implementation capitalizes on the extensibility of - ``TensorDict``s by storing generated dropout masks - in the transition ``TensorDict`` themselves. + TorchRL's implementation capitalizes on the extensibility of ``TensorDict``s by storing generated dropout masks + in the transition ``TensorDict`` themselves. 
This class can be used through :class:`~torchrl.modules.ConsistentDropoutModule` + within policies coded using the :class:`~tensordict.nn.TensorDictModuleBase` API. See this class for a detailed + explanation as well as usage examples. There is otherwise little conceptual deviance from the PyTorch - :class:`~torch.nn.Dropout` implementation. + :class:`~torch.nn.Dropout` implementation. - NOTE: TorchRL's data collectors perform rollouts in :meth:`~torch.no_grad` mode, - so the dropout masks ARE in fact still applied. + ..note:: TorchRL's data collectors perform rollouts in :meth:`~torch.no_grad` mode but not in `eval` mode, + so the dropout masks will be applied unless the policy passed to the collector is in eval mode. - See + Args: + p (float, optional): Dropout probability. Defaults to ``0.5``. - - :class:`~torchrl.collectors.SyncDataCollector`: :meth:`~torchrl.collectors.SyncDataCollector.rollout()` and :meth:`~torchrl.collectors.SyncDataCollector.iterator()` + .. seealso:: - - :class:`~torchrl.collectors.MultiSyncDataCollector`: Uses :meth:`~torchrl.collectors.collectors._main_async_collector` (:class:`~torchrl.collectors.SyncDataCollector`) under the hood + - :class:`~torchrl.collectors.SyncDataCollector`: + :meth:`~torchrl.collectors.SyncDataCollector.rollout()` and :meth:`~torchrl.collectors.SyncDataCollector.iterator()` + - :class:`~torchrl.collectors.MultiSyncDataCollector`: + Uses :meth:`~torchrl.collectors.collectors._main_async_collector` (:class:`~torchrl.collectors.SyncDataCollector`) + under the hood + - :class:`~torchrl.collectors.MultiaSyncDataCollector`, :class:`~torchrl.collectors.aSyncDataCollector`: Ditto. - - :class:`~torchrl.collectors.MultiaSyncDataCollector`, :class:`~torchrl.collectors.aSyncDataCollector`: Ditto. - ''' + """ - def __init__(self, p=0.5): - ''' - Parameters: - p (float, optional): Dropout probability. Default: ``0.5``. 
- ''' + def __init__(self, p: float = 0.5): super().__init__() self.p = p - def forward(self, x, mask=None): - ''' - During training (rollouts & updates), this call masks a tensor full of - ones before multiplying with the input tensor. + def forward( + self, x: torch.Tensor, mask: torch.Tensor | None = None + ) -> torch.Tensor: + """During training (rollouts & updates), this call masks a tensor full of ones before multiplying with the input tensor. + + During evaluation, this call results in a no-op and only the input is returned. + + Args: + x (torch.Tensor): the input tensor. + mask (torch.Tensor, optional): the optional mask for the dropout. - During evaluation, this call results in a no-op. - ''' + Returns: a tensor and a corresponding mask in train mode, and only a tensor in eval mode. + """ if self.training: if mask is None: - mask = F.dropout(torch.ones_like(x), self.p, self.training, inplace = False) + mask = self.make_mask(input=x) return x * mask, mask - + return x - + + def make_mask(self, *, input=None, shape=None): + if input is not None: + return F.dropout( + torch.ones_like(input), self.p, self.training, inplace=False + ) + elif shape is not None: + return F.dropout(torch.ones(shape), self.p, self.training, inplace=False) + else: + raise RuntimeError("input or shape must be passed to make_mask.") + + class ConsistentDropoutModule(TensorDictModuleBase): - ''' + """A TensorDictModule wrapper for :class:`~ConsistentDropout`. - Parameters: + Args: p (float, optional): Dropout probability. Default: ``0.5``. + in_keys (NestedKey or list of NestedKeys): keys to be read + from input tensordict and passed to this module. + out_keys (NestedKey or iterable of NestedKeys): keys to be written to the input tensordict. + Defaults to ``in_keys`` values. - in_key (str, optional): The key to be read from input tensordict. - Only used if ``in_keys`` is not specified. 
- - in_keys (iterable of NestedKeys, Dict[NestedStr, str]): keys to be read - from input tensordict and passed to this module. Default: ``None``. - - out_keys (iterable of str): keys to be written to the input tensordict. - Default: ``None``. + Keyword Args: + input_shape (tuple, optional): the shape of the input (non-batched), used to generate the + tensordict primers with :meth:`~.make_tensordict_primer`. + input_dtype (torch.dtype, optional): the dtype of the input for the primer. If none is passed, + ``torch.get_default_dtype`` is assumed. Examples: >>> from tensordict import TensorDict @@ -604,33 +625,41 @@ class ConsistentDropoutModule(TensorDictModuleBase): batch_size=torch.Size([3]), device=None, is_shared=False) - ''' - __doc__ = f"{ConsistentDropout.__doc__}\n{__doc__}" - - def __init__(self, p: float, in_key: NestedKey=None, in_keys=None, out_keys=None): - if in_key is None: - in_key = "x" - if in_keys is None: - in_keys = [in_key, f"mask_{id(self)}"] - elif len(in_keys) != 2: - raise ValueError("in_keys and out_keys length must be 2 for consistent dropout.") + """ + + def __init__( + self, + p: float, + in_keys: NestedKey | List[NestedKey], + out_keys: NestedKey | List[NestedKey] | None = None, + input_shape: torch.Size = None, + input_dtype: torch.dtype | None = None, + ): + if isinstance(in_keys, NestedKey): + in_keys = [in_keys, f"mask_{id(self)}"] if out_keys is None: - out_keys = [in_key, f"mask_{id(self)}"] - elif len(out_keys) != 2: - raise ValueError("in_keys and out_keys length must be 2 for consistent dropout.") + out_keys = list(in_keys) + if isinstance(out_keys, NestedKey): + out_keys = [out_keys, f"mask_{id(self)}"] + if len(in_keys) != 2 or len(out_keys) != 2: + raise ValueError( + "in_keys and out_keys length must be 2 for consistent dropout." 
+ ) self.in_keys = in_keys self.out_keys = out_keys + self.input_shape = input_shape + self.input_dtype = input_dtype super().__init__() if not 0 <= p < 1: - raise ValueError("p must be in [0,1)!") + raise ValueError(f"p must be in [0,1), got p={p: 4.4f}.") self.consistent_dropout = ConsistentDropout(p) def forward(self, tensordict): x = tensordict.get(self.in_keys[0]) mask = tensordict.get(self.in_keys[1], default=None) - if self.training: + if self.consistent_dropout.training: x, mask = self.consistent_dropout(x, mask=mask) tensordict.set(self.out_keys[0], x) tensordict.set(self.out_keys[1], mask) @@ -638,4 +667,30 @@ def forward(self, tensordict): x = self.consistent_dropout(x, mask=mask) tensordict.set(self.out_keys[0], x) - return tensordict \ No newline at end of file + return tensordict + + def make_tensordict_primer(self): + """Makes a tensordict primer for the environment to generate random masks during reset calls. + + .. seealso:: :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given + module. + + """ + from torchrl.envs import TensorDictPrimer + + shape = self.input_shape + dtype = self.input_dtype + if dtype is None: + dtype = torch.get_default_dtype() + if shape is None: + raise RuntimeError( + "Cannot infer the shape of the input automatically. " + "Please pass the shape of the tensor to `ConstistentDropoutModule` during construction " + "with the `input_shape` kwarg." + ) + return TensorDictPrimer( + primers={self.in_keys[1]: Unbounded(dtype=dtype, shape=shape)}, + default_value=functools.partial( + self.consistent_dropout.make_mask, shape=shape + ), + ) diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 48756683c11..1f19478f631 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -387,7 +387,7 @@ class LSTMModule(ModuleBase): .. 
note:: This module relies on specific ``recurrent_state`` keys being present in the input TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically add hidden states to the environment TensorDicts, use the method :func:`~torchrl.modules.rnn.LSTMModule.make_tensordict_primer`. - If this class is a submodule in a larger module, the method :func:`~torchrl.models.utils.get_primers_from_module` can be called + If this class is a submodule in a larger module, the method :func:`~torchrl.modules.utils.get_primers_from_module` can be called on the parent module to automatically generate the primer transforms required for all submodules, including this one. @@ -534,6 +534,9 @@ def make_tensordict_primer(self): tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states are not registered within the environment specs. + See :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given + module. + Examples: >>> from torchrl.collectors import SyncDataCollector >>> from torchrl.envs import TransformedEnv, InitTracker @@ -1108,7 +1111,7 @@ class GRUModule(ModuleBase): .. note:: This module relies on specific ``recurrent_state`` keys being present in the input TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically add hidden states to the environment TensorDicts, use the method :func:`~torchrl.modules.rnn.GRUModule.make_tensordict_primer`. - If this class is a submodule in a larger module, the method :func:`~torchrl.models.utils.get_primers_from_module` can be called + If this class is a submodule in a larger module, the method :func:`~torchrl.modules.utils.get_primers_from_module` can be called on the parent module to automatically generate the primer transforms required for all submodules, including this one. 
Examples: @@ -1280,6 +1283,9 @@ def make_tensordict_primer(self): tensordict, which the :meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states are not registered within the environment specs. + See :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given + module. + Examples: >>> from torchrl.collectors import SyncDataCollector >>> from torchrl.envs import TransformedEnv, InitTracker From 4605595e95398ed3bacb86ca8a0ac4c774dee5ba Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 2 Sep 2024 12:13:49 +0100 Subject: [PATCH 4/4] minor edits --- docs/source/reference/modules.rst | 7 ++++++- test/test_exploration.py | 6 +++--- torchrl/modules/models/exploration.py | 5 +++++ 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 64c83daead3..2d6a6344970 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -60,12 +60,17 @@ projected (in a L1-manner) into the desired domain. Exploration wrappers and modules ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -To efficiently explore the environment, TorchRL proposes a series of wrappers +To efficiently explore the environment, TorchRL proposes a series of modules that will override the action sampled by the policy by a noisier version. Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_mode`: if the exploration is set to ``"random"``, the exploration is active. In all other cases, the action written in the tensordict is simply the network output. +.. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule` + uses the ``train``/``eval`` mode to comply with the regular ``Dropout`` API in PyTorch. + The :func:`~torchrl.envs.utils.set_exploration_mode` context manager will have no effect on + this module. + .. currentmodule:: torchrl.modules ..
autosummary:: diff --git a/test/test_exploration.py b/test/test_exploration.py index 64b445f3f2c..bbf1189431d 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -742,7 +742,7 @@ def test_gsde_init(sigma_init, state_dim, action_dim, mean, std, device, learn_s @pytest.mark.parametrize("parallel_spec", [False, True]) @pytest.mark.parametrize( "device", - [torch.device("cuda:0") if torch.cuda.device_count() else torch.device("cpu")], + get_default_devices() ) def test_consistent_dropout(dropout_p, parallel_spec, device): """ @@ -770,7 +770,7 @@ def inner_verify_routine(module, env): device=device, ) for frames in collector: - masks = [(key, value) for key, value in frames.items() if "mask_" in key] + masks = [(key, value) for key, value in frames.items() if key.startswith("mask_")] # Assert rollouts do indeed correctly generate the masks. assert len(masks) == 1, ( "Expected exactly ONE mask since we only put " @@ -824,7 +824,7 @@ def inner_verify_routine(module, env): TensorDictModule( nn.LazyLinear(2 * d_act), in_keys=["observation"], out_keys=["out"] ), - ConsistentDropoutModule(p=dropout_p, in_key="out"), + ConsistentDropoutModule(p=dropout_p, in_keys="out"), TensorDictModule( NormalParamExtractor(), in_keys=["out"], out_keys=["loc", "scale"] ), diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index 01e0ed015db..66b262c6df2 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -548,6 +548,11 @@ class ConsistentDropout(_DropoutNd): .. note:: TorchRL's data collectors perform rollouts in :meth:`~torch.no_grad` mode but not in `eval` mode, so the dropout masks will be applied unless the policy passed to the collector is in eval mode. + .. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule` + uses the ``train``/``eval`` mode to comply with the regular ``Dropout`` API in PyTorch.
+ The :func:`~torchrl.envs.utils.set_exploration_mode` context manager will have no effect on + this module. + Args: p (float, optional): Dropout probability. Defaults to ``0.5``.