Skip to content

Commit

Permalink
[Feature,Doc] get_stateful_net and document loss initialization
Browse files Browse the repository at this point in the history
ghstack-source-id: 98830fd00d790dc622f1fec7b7220f2f9e747d81
Pull Request resolved: #2310
  • Loading branch information
vmoens committed Jul 23, 2024
1 parent 1e654c3 commit fa6efcd
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 3 deletions.
9 changes: 9 additions & 0 deletions docs/source/reference/objectives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ The main characteristics of TorchRL losses are:

>>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))

.. note::
Initializing parameters in losses can be done via a query to :meth:`~torchrl.objectives.LossModule.get_stateful_net`
which will return a stateful version of the network that can be initialized like any other module.
If the modification is done in-place, it will be downstreamed to any other module that uses the same parameter
set (within and outside of the loss): for instance, modifying the ``actor_network`` parameters from the loss
will also modify the actor in the collector.
If the parameters are modified out-of-place, :meth:`~torchrl.objectives.LossModule.from_stateful_net` can be
used to reset the parameters in the loss to the new value.

Training value functions
------------------------

Expand Down
56 changes: 56 additions & 0 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -14549,6 +14549,62 @@ def __init__(self, compare_against, expand_dim):
for key in ["module.1.bias", "module.1.weight"]:
loss_module.module_b_params.flatten_keys()[key].requires_grad

def test_init_params(self):
class MyLoss(LossModule):
module_a: TensorDictModule
module_b: TensorDictModule
module_a_params: TensorDict
module_b_params: TensorDict
target_module_a_params: TensorDict
target_module_b_params: TensorDict

def __init__(self, expand_dim=2):
super().__init__()
module1 = nn.Linear(3, 4)
module2 = nn.Linear(3, 4)
module3 = nn.Linear(3, 4)
module_a = TensorDictModule(
nn.Sequential(module1, module2), in_keys=["a"], out_keys=["c"]
)
module_b = TensorDictModule(
nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"]
)
self.convert_to_functional(module_a, "module_a")
self.convert_to_functional(
module_b,
"module_b",
compare_against=module_a.parameters(),
expand_dim=expand_dim,
)

loss = MyLoss()

module_a = loss.get_stateful_net("module_a", copy=False)
assert module_a is loss.module_a

module_a = loss.get_stateful_net("module_a")
assert module_a is not loss.module_a

def init(mod):
if hasattr(mod, "weight"):
mod.weight.data.zero_()
if hasattr(mod, "bias"):
mod.bias.data.zero_()

module_a.apply(init)
assert (loss.module_a_params == 0).all()

def init(mod):
if hasattr(mod, "weight"):
mod.weight = torch.nn.Parameter(mod.weight.data + 1)
if hasattr(mod, "bias"):
mod.bias = torch.nn.Parameter(mod.bias.data + 1)

module_a.apply(init)
assert (loss.module_a_params == 0).all()
loss.from_stateful_net("module_a", module_a)
assert (loss.module_a_params == 1).all()

def test_tensordict_keys(self):
"""Test configurable tensordict key behavior with derived classes."""

Expand Down
5 changes: 2 additions & 3 deletions torchrl/modules/models/multiagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ def get_stateful_net(self, copy: bool = True):
This can be used to initialize parameters.
Such networks will generally not be callable out-of-the-box and will require some `vmap`
execution. to work
Such networks will often not be callable out-of-the-box and will require a `vmap` call
to be executable.
Args:
copy (bool, optional): if ``True``, a deepcopy of the network is made.
Expand Down Expand Up @@ -203,7 +203,6 @@ def get_stateful_net(self, copy: bool = True):
self.params.to_module(net)
return net

@abc.abstractmethod
def from_stateful_net(self, stateful_net: nn.Module):
"""Populates the parameters given a stateful version of the network.
Expand Down
62 changes: 62 additions & 0 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import abc
import functools
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Iterator, List, Optional, Tuple

Expand Down Expand Up @@ -138,6 +139,67 @@ def __init__(self):
self._tensor_keys = self._AcceptedKeys()
self.register_forward_pre_hook(_updater_check_forward_prehook)

@property
def functional(self):
"""Whether the module is functional.
Unless it has been specifically designed not to be functional, all losses are functional.
"""
return True

def get_stateful_net(self, network_name: str, copy: bool | None = None):
"""Returns a stateful version of the network.
This can be used to initialize parameters.
Such networks will often not be callable out-of-the-box and will require a `vmap` call
to be executable.
Args:
network_name (str): the network name to gather.
copy (bool, optional): if ``True``, a deepcopy of the network is made.
Defaults to ``True``.
.. note:: if the module is not functional, no copy is made.
"""
net = getattr(self, network_name)
if not self.functional:
if copy is not None and copy:
raise RuntimeError("Cannot copy module in non-functional mode.")
return net
copy = True if copy is None else copy
if copy:
net = deepcopy(net)
params = getattr(self, network_name + "_params")
params.to_module(net)
return net

def from_stateful_net(self, network_name: str, stateful_net: nn.Module):
"""Populates the parameters of a model given a stateful version of the network.
See :meth:`~.get_stateful_net` for details on how to gather a stateful version of the network.
Args:
network_name (str): the network name to reset.
stateful_net (nn.Module): the stateful network from which the params should be
gathered.
"""
if not self.functional:
getattr(self, network_name).load_state_dict(stateful_net.state_dict())
return
params = TensorDict.from_module(stateful_net, as_module=True)
keyset0 = set(params.keys(True, True))
self_params = getattr(self, network_name + "_params")
keyset1 = set(self_params.keys(True, True))
if keyset0 != keyset1:
raise RuntimeError(
f"The keys of params and provided module differ: "
f"{keyset1-keyset0} are in self.params and not in the module, "
f"{keyset0-keyset1} are in the module but not in self.params."
)
self_params.data.update_(params.data)

def _set_deprecated_ctor_keys(self, **kwargs) -> None:
for key, value in kwargs.items():
if value is not None:
Expand Down

0 comments on commit fa6efcd

Please sign in to comment.