Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 15, 2025
2 parents d05004b + dc25a55 commit 5ca4d08
Show file tree
Hide file tree
Showing 21 changed files with 172 additions and 15 deletions.
13 changes: 13 additions & 0 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from tensordict.nn.utils import Buffer
from tensordict.utils import unravel_key
from torch import autograd, nn
from torchrl._utils import _standardize
from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded
from torchrl.data.postprocs.postprocs import MultiStep
from torchrl.envs.model_based.dreamer import DreamerEnv
Expand Down Expand Up @@ -16044,6 +16045,18 @@ def _composite_log_prob(self):
yield
setter.unset()

def test_standardization(self):
t = torch.arange(3 * 4 * 5 * 6, dtype=torch.float32).view(3, 4, 5, 6)
std_t0 = _standardize(t, exclude_dims=(1, 3))
std_t1 = (t - t.mean((0, 2), keepdim=True)) / t.std((0, 2), keepdim=True).clamp(
1 - 6
)
torch.testing.assert_close(std_t0, std_t1)
std_t = _standardize(t, (), -1, 2)
torch.testing.assert_close(std_t, (t + 1) / 2)
std_t = _standardize(t, ())
torch.testing.assert_close(std_t, (t - t.mean()) / t.std())

@pytest.mark.parametrize("B", [None, (1, ), (4, ), (2, 2, ), (1, 2, 8, )]) # fmt: skip
@pytest.mark.parametrize("T", [1, 10])
@pytest.mark.parametrize("device", get_default_devices())
Expand Down
68 changes: 66 additions & 2 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@
from distutils.util import strtobool
from functools import wraps
from importlib import import_module
from typing import Any, Callable, cast, Dict, TypeVar, Union
from typing import Any, Callable, cast, Dict, Tuple, TypeVar, Union

import numpy as np
import torch
from packaging.version import parse
from tensordict import unravel_key

from tensordict.utils import NestedKey
from torch import multiprocessing as mp
from torch import multiprocessing as mp, Tensor

try:
from torch.compiler import is_compiling
Expand Down Expand Up @@ -872,6 +872,70 @@ def set_mode(self, type: Any | None) -> None:
self._mode = type


def _standardize(
input: Tensor,
exclude_dims: Tuple[int] = (),
mean: Tensor | None = None,
std: Tensor | None = None,
eps: float | None = None,
):
"""Standardizes the input tensor with the possibility of excluding specific dims from the statistics.
Useful when processing multi-agent data to keep the agent dimensions independent.
Args:
input (Tensor): the input tensor to be standardized.
exclude_dims (Tuple[int]): dimensions to exclude from the statistics, can be negative. Default: ().
mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None.
std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None.
eps (float): epsilon to be used for numerical stability. Default: float32 resolution.
"""
if eps is None:
if input.dtype.is_floating_point:
eps = torch.finfo(torch.float).resolution
else:
eps = 1e-6

len_exclude_dims = len(exclude_dims)
if not len_exclude_dims:
if mean is None:
mean = input.mean()
else:
# Assume dtypes are compatible
mean = torch.as_tensor(mean, device=input.device)
if std is None:
std = input.std()
else:
# Assume dtypes are compatible
std = torch.as_tensor(std, device=input.device)
return (input - mean) / std.clamp_min(eps)

input_shape = input.shape
exclude_dims = [
d if d >= 0 else d + len(input_shape) for d in exclude_dims
] # Make negative dims positive

if len(set(exclude_dims)) != len_exclude_dims:
raise ValueError("Exclude dims has repeating elements")
if any(dim < 0 or dim >= len(input_shape) for dim in exclude_dims):
raise ValueError(
f"exclude_dims={exclude_dims} provided outside bounds for input of shape={input_shape}"
)
if len_exclude_dims == len(input_shape):
warnings.warn(
"_standardize called but all dims were excluded from the statistics, returning unprocessed input"
)
return input

included_dims = tuple(d for d in range(len(input_shape)) if d not in exclude_dims)
if mean is None:
mean = torch.mean(input, keepdim=True, dim=included_dims)
if std is None:
std = torch.std(input, keepdim=True, dim=included_dims)
return (input - mean) / std.clamp_min(eps)


@wraps(torch.compile)
def compile_with_warmup(*args, warmup: int = 1, **kwargs):
"""Compile a model with warm-up.
Expand Down
15 changes: 11 additions & 4 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,10 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta):
__doc__ += BatchedEnvBase.__doc__
__doc__ += """
.. note:: ParallelEnv will timeout after one of the worker is idle for a determinate amount of time.
This can be controlled via the BATCHED_PIPE_TIMEOUT environment variable, which in turn modifies
the torchrl._utils.BATCHED_PIPE_TIMEOUT integer. The default timeout value is 10000 seconds.
.. warning::
TorchRL's ParallelEnv is quite stringent when it comes to env specs, since
these are used to build shared memory buffers for inter-process communication.
Expand Down Expand Up @@ -1353,7 +1357,10 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta):
"""

def _start_workers(self) -> None:
import torchrl

self._timeout = 10.0
self.BATCHED_PIPE_TIMEOUT = torchrl._utils.BATCHED_PIPE_TIMEOUT

from torchrl.envs.env_creator import EnvCreator

Expand Down Expand Up @@ -1606,7 +1613,7 @@ def step_and_maybe_reset(

for i in workers_range:
event = self._events[i]
event.wait(self._timeout)
event.wait(self.BATCHED_PIPE_TIMEOUT)
event.clear()

if self._non_tensor_keys:
Expand Down Expand Up @@ -1796,7 +1803,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:

for i in workers_range:
event = self._events[i]
event.wait(self._timeout)
event.wait(self.BATCHED_PIPE_TIMEOUT)
event.clear()

if self._non_tensor_keys:
Expand Down Expand Up @@ -1965,7 +1972,7 @@ def tentative_update(val, other):

for i, _ in outs:
event = self._events[i]
event.wait(self._timeout)
event.wait(self.BATCHED_PIPE_TIMEOUT)
event.clear()

workers_nontensor = []
Expand Down Expand Up @@ -2023,7 +2030,7 @@ def _shutdown_workers(self) -> None:
for channel in self.parent_channels:
channel.close()
for proc in self._workers:
proc.join(timeout=1.0)
proc.join(timeout=self._timeout)
finally:
for proc in self._workers:
if proc.is_alive():
Expand Down
1 change: 1 addition & 0 deletions torchrl/modules/tensordict_module/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import warnings
from typing import Dict, List, Optional, Type, Union
Expand Down
1 change: 1 addition & 0 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ def __post_init__(self):
self.sample_log_prob = "action_log_prob"

default_keys = _AcceptedKeys
tensor_keys: _AcceptedKeys
default_value_estimator: ValueEstimators = ValueEstimators.GAE

actor_network: TensorDictModule
Expand Down
1 change: 1 addition & 0 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class _AcceptedKeys:

pass

tensor_keys: _AcceptedKeys
_vmap_randomness = None
default_value_estimator: ValueEstimators = None

Expand Down
2 changes: 2 additions & 0 deletions torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ class _AcceptedKeys:
done: NestedKey = "done"
terminated: NestedKey = "terminated"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0

Expand Down Expand Up @@ -1024,6 +1025,7 @@ class _AcceptedKeys:
terminated: NestedKey = "terminated"
pred_val: NestedKey = "pred_val"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0
out_keys = [
Expand Down
1 change: 1 addition & 0 deletions torchrl/objectives/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ class _AcceptedKeys:
terminated: NestedKey = "terminated"
log_prob: NestedKey = "_log_prob"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0

Expand Down
1 change: 1 addition & 0 deletions torchrl/objectives/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ class _AcceptedKeys:
done: NestedKey = "done"
terminated: NestedKey = "terminated"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys
default_value_estimator: ValueEstimators = ValueEstimators.TD0
out_keys = [
Expand Down
2 changes: 2 additions & 0 deletions torchrl/objectives/decision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class _AcceptedKeys:
# the "action" output from the model
action_pred: NestedKey = "action"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys

actor_network: TensorDictModule
Expand Down Expand Up @@ -280,6 +281,7 @@ class _AcceptedKeys:
# the "action" output from the model
action_pred: NestedKey = "action"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys

actor_network: TensorDictModule
Expand Down
2 changes: 2 additions & 0 deletions torchrl/objectives/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class _AcceptedKeys:
done: NestedKey = "done"
terminated: NestedKey = "terminated"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0
out_keys = ["loss"]
Expand Down Expand Up @@ -435,6 +436,7 @@ class _AcceptedKeys:
terminated: NestedKey = "terminated"
steps_to_next_obs: NestedKey = "steps_to_next_obs"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0

Expand Down
12 changes: 12 additions & 0 deletions torchrl/objectives/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,13 @@ class _AcceptedKeys:
pixels: NestedKey = "pixels"
reco_pixels: NestedKey = "reco_pixels"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys

decoder: TensorDictModule
reward_model: TensorDictModule
world_mdel: TensorDictModule

def __init__(
self,
world_model: TensorDictModule,
Expand Down Expand Up @@ -238,9 +243,13 @@ class _AcceptedKeys:
done: NestedKey = "done"
terminated: NestedKey = "terminated"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TDLambda

value_model: TensorDictModule
actor_model: TensorDictModule

def __init__(
self,
actor_model: TensorDictModule,
Expand Down Expand Up @@ -392,8 +401,11 @@ class _AcceptedKeys:

value: NestedKey = "state_value"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys

value_model: TensorDictModule

def __init__(
self,
value_model: TensorDictModule,
Expand Down
1 change: 1 addition & 0 deletions torchrl/objectives/gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class _AcceptedKeys:
collector_observation: NestedKey = "collector_observation"
discriminator_pred: NestedKey = "d_logits"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys

discriminator_network: TensorDictModule
Expand Down
2 changes: 2 additions & 0 deletions torchrl/objectives/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ class _AcceptedKeys:
done: NestedKey = "done"
terminated: NestedKey = "terminated"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0
out_keys = [
Expand Down Expand Up @@ -709,6 +710,7 @@ class _AcceptedKeys:
done: NestedKey = "done"
terminated: NestedKey = "terminated"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0
out_keys = [
Expand Down
1 change: 1 addition & 0 deletions torchrl/objectives/multiagent/qmixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class _AcceptedKeys:
done: NestedKey = "done"
terminated: NestedKey = "terminated"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0
out_keys = ["loss"]
Expand Down
Loading

0 comments on commit 5ca4d08

Please sign in to comment.