Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Sep 20, 2024
1 parent 33aa1cb commit fac26d1
Show file tree
Hide file tree
Showing 18 changed files with 32 additions and 133 deletions.
2 changes: 0 additions & 2 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -979,11 +979,9 @@ Helpers

RandomPolicy
check_env_specs
exploration_mode #deprecated
exploration_type
get_available_libraries
make_composite_from_td
set_exploration_mode #deprecated
set_exploration_type
step_mdp
terminated_or_truncated
Expand Down
6 changes: 3 additions & 3 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ Exploration wrappers and modules

To efficiently explore the environment, TorchRL provides a series of modules
that replace the action sampled by the policy with a noisier version.
Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_mode`:
if the exploration is set to ``"random"``, the exploration is active. In all
Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_type`:
if the exploration is set to ``ExplorationType.RANDOM``, the exploration is active. In all
other cases, the action written in the tensordict is simply the network output.

.. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule`
uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch.
The :func:`~torchrl.envs.utils.set_exploration_mode` context manager will have no effect on
The :func:`~torchrl.envs.utils.set_exploration_type` context manager will have no effect on
this module.

.. currentmodule:: torchrl.modules
Expand Down
4 changes: 2 additions & 2 deletions examples/distributed/collectors/multi_nodes/ray_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs, set_exploration_mode
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
Expand Down Expand Up @@ -201,7 +201,7 @@
stepcount_str = f"step count (max): {logs['step_count'][-1]}"
logs["lr"].append(optim.param_groups[0]["lr"])
lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
with set_exploration_mode("mean"), torch.no_grad():
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
# execute a rollout with the trained policy
eval_rollout = env.rollout(1000, policy_module)
logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
Expand Down
8 changes: 3 additions & 5 deletions sota-implementations/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
)
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.envs.libs.gym import set_gym_backend
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
DTActor,
OnlineDTActor,
Expand Down Expand Up @@ -374,13 +374,12 @@ def make_odt_model(cfg):
module=actor_module,
distribution_class=dist_class,
distribution_kwargs=dist_kwargs,
default_interaction_mode="random",
cache_dist=False,
return_log_prob=False,
)

# init the lazy layers
with torch.no_grad(), set_exploration_mode("random"):
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
td = proof_environment.rollout(max_steps=100)
td["action"] = td["next", "action"]
actor(td)
Expand Down Expand Up @@ -428,13 +427,12 @@ def make_dt_model(cfg):
module=actor_module,
distribution_class=dist_class,
distribution_kwargs=dist_kwargs,
default_interaction_mode="random",
cache_dist=False,
return_log_prob=False,
)

# init the lazy layers
with torch.no_grad(), set_exploration_mode("random"):
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
td = proof_environment.rollout(max_steps=100)
td["action"] = td["next", "action"]
actor(td)
Expand Down
1 change: 0 additions & 1 deletion sota-implementations/redq/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ collector:
multi_step: 1
n_steps_return: 3
max_frames_per_traj: -1
exploration_mode: random

logger:
backend: wandb
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/redq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ def make_collector_offpolicy(
"init_random_frames": cfg.collector.init_random_frames,
"split_trajs": True,
# trajectories must be separated if multi-step is used
"exploration_type": ExplorationType.from_str(cfg.collector.exploration_mode),
"exploration_type": cfg.collector.exploration_type,
}

collector = collector_helper(**collector_helper_kwargs)
Expand Down
25 changes: 1 addition & 24 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,6 @@ def __init__(
postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
split_trajs: bool | None = None,
exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
exploration_mode: str | None = None,
return_same_td: bool = False,
reset_when_done: bool = True,
interruptor=None,
Expand All @@ -456,9 +455,6 @@ def __init__(
from torchrl.envs.batched_envs import BatchedEnvBase

self.closed = True
exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
)
if create_env_kwargs is None:
create_env_kwargs = {}
if not isinstance(create_env_fn, EnvBase):
Expand Down Expand Up @@ -1421,7 +1417,7 @@ class _MultiDataCollector(DataCollectorBase):
A ``cat_results`` value of ``-1`` will always concatenate results along the
time dimension. This should be preferred over the default. Intermediate values
are also accepted.
Defaults to ``0``.
Defaults to ``"stack"``.
.. note:: From v0.5, this argument will default to ``"stack"`` for a better
interoperability with the rest of the library.
Expand Down Expand Up @@ -1462,7 +1458,6 @@ def __init__(
postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
split_trajs: Optional[bool] = None,
exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
exploration_mode=None,
reset_when_done: bool = True,
update_at_each_batch: bool = False,
preemptive_threshold: float = None,
Expand All @@ -1474,9 +1469,6 @@ def __init__(
replay_buffer: ReplayBuffer | None = None,
replay_buffer_chunk: bool = True,
):
exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
)
self.closed = True
self.num_workers = len(create_env_fn)

Expand Down Expand Up @@ -2156,19 +2148,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
cat_results = self.cat_results
if cat_results is None:
cat_results = "stack"
warnings.warn(
f"`cat_results` was not specified in the constructor of {type(self).__name__}. "
f"For MultiSyncDataCollector, `cat_results` indicates how the data should "
f"be packed: the preferred option and current default is `cat_results='stack'` "
f"which provides the best interoperability across torchrl components. "
f"Other accepted values are `cat_results=0` (previous behavior) and "
f"`cat_results=-1` (cat along time dimension). Among these two, the latter "
f"should be preferred for consistency across environment configurations. "
f"Currently, the default value is `'stack'`."
f"From v0.6 onward, this warning will be removed. "
f"To suppress this warning, set `cat_results` to the desired value.",
category=DeprecationWarning,
)

self.buffers = {}
dones = [False for _ in range(self.num_workers)]
Expand Down Expand Up @@ -2749,7 +2728,6 @@ def __init__(
postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
split_trajs: Optional[bool] = None,
exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
exploration_mode=None,
reset_when_done: bool = True,
update_at_each_batch: bool = False,
preemptive_threshold: float = None,
Expand All @@ -2774,7 +2752,6 @@ def __init__(
env_device=env_device,
storing_device=storing_device,
exploration_type=exploration_type,
exploration_mode=exploration_mode,
reset_when_done=reset_when_done,
update_at_each_batch=update_at_each_batch,
preemptive_threshold=preemptive_threshold,
Expand Down
4 changes: 0 additions & 4 deletions torchrl/collectors/distributed/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,6 @@ def __init__(
postproc: Callable | None = None,
split_trajs: bool = False,
exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa
exploration_mode: str = None,
collector_class: Type = SyncDataCollector,
collector_kwargs: dict = None,
num_workers_per_collector: int = 1,
Expand All @@ -438,9 +437,6 @@ def __init__(
launcher: str = "submitit",
tcp_port: int = None,
):
exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
)

if collector_class == "async":
collector_class = MultiaSyncDataCollector
Expand Down
4 changes: 0 additions & 4 deletions torchrl/collectors/distributed/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,6 @@ def __init__(
postproc: Callable | None = None,
split_trajs: bool = False,
exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa
exploration_mode: str = None,
collector_class=SyncDataCollector,
collector_kwargs=None,
num_workers_per_collector=1,
Expand All @@ -288,9 +287,6 @@ def __init__(
visible_devices=None,
tensorpipe_options=None,
):
exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
)
if collector_class == "async":
collector_class = MultiaSyncDataCollector
elif collector_class == "sync":
Expand Down
4 changes: 0 additions & 4 deletions torchrl/collectors/distributed/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,6 @@ def __init__(
postproc: Callable | None = None,
split_trajs: bool = False,
exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa
exploration_mode: str = None,
collector_class=SyncDataCollector,
collector_kwargs=None,
num_workers_per_collector=1,
Expand All @@ -302,9 +301,6 @@ def __init__(
launcher="submitit",
tcp_port=None,
):
exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
)

if collector_class == "async":
collector_class = MultiaSyncDataCollector
Expand Down
2 changes: 0 additions & 2 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,10 @@
from .utils import (
check_env_specs,
check_marl_grouping,
exploration_mode,
exploration_type,
ExplorationType,
make_composite_from_td,
MarlGroupMapType,
set_exploration_mode,
set_exploration_type,
step_mdp,
)
23 changes: 12 additions & 11 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2136,15 +2136,23 @@ class UnsqueezeTransform(Transform):
"""Inserts a dimension of size one at the specified position.
Args:
unsqueeze_dim (int): dimension to unsqueeze. Must be negative (or allow_positive_dim
dim (int): dimension to unsqueeze. Must be negative (or allow_positive_dim
must be turned on).
Keyword Args:
allow_positive_dim (bool, optional): if ``True``, positive dimensions are accepted.
:obj:`UnsqueezeTransform` will map these to the n^th feature dimension
``UnsqueezeTransform`` will map these to the n^th feature dimension
(ie n^th dimension after batch size of parent env) of the input tensor,
independently from the tensordict batch size (ie positive dims may be
independently of the tensordict batch size (i.e. positive dims may be
dangerous in contexts where tensordicts with different batch dimensions
are passed).
Defaults to False, ie. non-negative dimensions are not permitted.
in_keys (list of NestedKeys): input entries (read).
out_keys (list of NestedKeys): output entries (write). Defaults to ``in_keys`` if
not provided.
in_keys_inv (list of NestedKeys): input entries (read) during :meth:`~.inv` calls.
out_keys_inv (list of NestedKeys): output entries (write) during :meth:`~.inv` calls.
Defaults to ``in_keys_inv`` if not provided.
"""

invertible = True
Expand All @@ -2157,20 +2165,13 @@ def __new__(cls, *args, **kwargs):
def __init__(
self,
dim: int = None,
*,
allow_positive_dim: bool = False,
in_keys: Sequence[NestedKey] | None = None,
out_keys: Sequence[NestedKey] | None = None,
in_keys_inv: Sequence[NestedKey] | None = None,
out_keys_inv: Sequence[NestedKey] | None = None,
**kwargs,
):
if "unsqueeze_dim" in kwargs:
warnings.warn(
"The `unsqueeze_dim` kwarg will be removed in v0.6. Please use `dim` instead."
)
dim = kwargs["unsqueeze_dim"]
elif dim is None:
raise TypeError("dim must be provided.")
if in_keys is None:
in_keys = [] # default
if out_keys is None:
Expand Down
11 changes: 1 addition & 10 deletions torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,8 @@
from tensordict.base import _is_leaf_nontensor
from tensordict.nn import TensorDictModule, TensorDictModuleBase
from tensordict.nn.probabilistic import ( # noqa
# Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated!
# Please use the `set_/interaction_type` ones above with the InteractionType enum instead.
# See more details: https://github.com/pytorch/rl/issues/1016
interaction_mode as exploration_mode,
interaction_type as exploration_type,
InteractionType as ExplorationType,
set_interaction_mode as set_exploration_mode,
set_interaction_type as set_exploration_type,
)
from tensordict.utils import is_non_tensor, NestedKey
Expand All @@ -55,9 +50,7 @@
from torchrl.data.utils import check_no_exclusive_keys

__all__ = [
"exploration_mode",
"exploration_type",
"set_exploration_mode",
"set_exploration_type",
"ExplorationType",
"check_env_specs",
Expand All @@ -79,9 +72,7 @@
)


def _convert_exploration_type(*, exploration_mode, exploration_type):
if exploration_mode is not None:
return ExplorationType.from_str(exploration_mode)
def _convert_exploration_type(*, exploration_type):
return exploration_type


Expand Down
Loading

0 comments on commit fac26d1

Please sign in to comment.