diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index ede1421ffc9..9b5bcadd433 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -827,6 +827,7 @@ to be able to create this other composition: FlattenObservation FrameSkipTransform GrayScale + Hash InitTracker KLRewardTransform LineariseReward @@ -853,6 +854,7 @@ to be able to create this other composition: TimeMaxPool ToTensorImage TrajCounter + UnaryTransform UnsqueezeTransform VC1Transform VIPRewardTransform diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 6f666290376..0531bff10df 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import random +import string from typing import Dict, List, Optional import torch @@ -1066,6 +1068,34 @@ def _step( return tensordict +class CountingEnvWithString(CountingEnv): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.observation_spec.set( + "string", + NonTensor( + shape=self.batch_size, + device=self.device, + ), + ) + + def get_random_string(self): + size = random.randint(4, 30) + return "".join(random.choice(string.ascii_lowercase) for _ in range(size)) + + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + res = super()._reset(tensordict, **kwargs) + random_string = self.get_random_string() + res["string"] = random_string + return res + + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + res = super()._step(tensordict) + random_string = self.get_random_string() + res["string"] = random_string + return res + + class MultiAgentCountingEnv(EnvBase): """A multi-agent env that is done after a given number of steps. 
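For context, a minimal usage sketch (not part of the patch) showing how the new CountingEnvWithString mock pairs with the Hash transform introduced below; the import path of the test-only mock and the no-argument constructor are assumptions taken from the test code in this patch:

    import torch
    from torchrl.envs import TransformedEnv
    from torchrl.envs.transforms import Hash
    from mocking_classes import CountingEnvWithString  # test-only helper added above

    # Hash defaults to Hash.reproducible_hash, which maps the "string" entry to a
    # SHA-256 digest stored under "hashing" as a (32,) uint8 tensor.
    env = TransformedEnv(
        CountingEnvWithString(),
        Hash(in_keys=["string"], out_keys=["hashing"]),
    )
    td = env.reset()
    assert td["hashing"].dtype == torch.uint8 and td["hashing"].shape[-1] == 32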
diff --git a/test/test_transforms.py b/test/test_transforms.py index 7a01acdaeef..ec413b2b34c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -39,6 +39,7 @@ CountingBatchedEnv, CountingEnv, CountingEnvCountPolicy, + CountingEnvWithString, DiscreteActionConvMockEnv, DiscreteActionConvMockEnvNumpy, EnvWithScalarAction, @@ -66,6 +67,7 @@ CountingBatchedEnv, CountingEnv, CountingEnvCountPolicy, + CountingEnvWithString, DiscreteActionConvMockEnv, DiscreteActionConvMockEnvNumpy, EnvWithScalarAction, @@ -77,7 +79,7 @@ MultiKeyCountingEnvPolicy, NestedCountingEnv, ) -from tensordict import TensorDict, TensorDictBase, unravel_key +from tensordict import NonTensorData, TensorDict, TensorDictBase, unravel_key from tensordict.nn import TensorDictSequential from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td from torch import multiprocessing as mp, nn, Tensor @@ -118,6 +120,7 @@ FrameSkipTransform, GrayScale, gSDENoise, + Hash, InitTracker, LineariseRewards, MultiStepTransform, @@ -2180,6 +2183,246 @@ def test_transform_no_env(self, device, batch): pytest.skip("TrajCounter cannot be called without env") +class TestHash(TransformBase): + @pytest.mark.parametrize("datatype", ["tensor", "str", "NonTensorStack"]) + def test_transform_no_env(self, datatype): + if datatype == "tensor": + obs = torch.tensor(10) + hash_fn = hash + elif datatype == "str": + obs = "abcdefg" + hash_fn = Hash.reproducible_hash + elif datatype == "NonTensorStack": + obs = torch.stack( + [ + NonTensorData(data="abcde"), + NonTensorData(data="fghij"), + NonTensorData(data="klmno"), + ] + ) + + def fn0(x): + return torch.stack([Hash.reproducible_hash(x_) for x_ in x]) + + hash_fn = fn0 + else: + raise RuntimeError(f"please add a test case for datatype {datatype}") + + td = TensorDict( + { + "observation": obs, + } + ) + + t = Hash(in_keys=["observation"], out_keys=["hashing"], hash_fn=hash_fn) + td_hashed = t(td) + + assert td_hashed.get("observation") is td.get("observation") + + if datatype == "NonTensorStack": + assert ( + td_hashed["hashing"] == hash_fn(td.get("observation").tolist()) + ).all() + elif datatype == "str": + assert all(td_hashed["hashing"] == hash_fn(td["observation"])) + else: + assert td_hashed["hashing"] == hash_fn(td["observation"]) + + @pytest.mark.parametrize("datatype", ["tensor", "str"]) + def test_single_trans_env_check(self, datatype): + if datatype == "tensor": + t = Hash( + in_keys=["observation"], + out_keys=["hashing"], + hash_fn=hash, + ) + base_env = CountingEnv() + elif datatype == "str": + t = Hash( + in_keys=["string"], + out_keys=["hashing"], + ) + base_env = CountingEnvWithString() + env = TransformedEnv(base_env, t) + check_env_specs(env) + + @pytest.mark.parametrize("datatype", ["tensor", "str"]) + def test_serial_trans_env_check(self, datatype): + def make_env(): + if datatype == "tensor": + t = Hash( + in_keys=["observation"], + out_keys=["hashing"], + hash_fn=hash, + ) + base_env = CountingEnv() + + elif datatype == "str": + t = Hash( + in_keys=["string"], + out_keys=["hashing"], + ) + base_env = CountingEnvWithString() + + return TransformedEnv(base_env, t) + + env = SerialEnv(2, make_env) + check_env_specs(env) + + @pytest.mark.parametrize("datatype", ["tensor", "str"]) + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv, datatype): + def make_env(): + if datatype == "tensor": + t = Hash( + in_keys=["observation"], + out_keys=["hashing"], + hash_fn=hash, + ) + base_env = CountingEnv() + elif datatype == "str": + t = Hash( + 
in_keys=["string"], + out_keys=["hashing"], + ) + base_env = CountingEnvWithString() + return TransformedEnv(base_env, t) + + env = maybe_fork_ParallelEnv(2, make_env) + try: + check_env_specs(env) + finally: + try: + env.close() + except RuntimeError: + pass + + @pytest.mark.parametrize("datatype", ["tensor", "str"]) + def test_trans_serial_env_check(self, datatype): + if datatype == "tensor": + t = Hash( + in_keys=["observation"], + out_keys=["hashing"], + hash_fn=lambda x: [hash(x[0]), hash(x[1])], + ) + base_env = CountingEnv + elif datatype == "str": + t = Hash( + in_keys=["string"], + out_keys=["hashing"], + hash_fn=lambda x: torch.stack([Hash.reproducible_hash(x_) for x_ in x]), + ) + base_env = CountingEnvWithString + + env = TransformedEnv(SerialEnv(2, base_env), t) + check_env_specs(env) + + @pytest.mark.parametrize("datatype", ["tensor", "str"]) + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype): + if datatype == "tensor": + t = Hash( + in_keys=["observation"], + out_keys=["hashing"], + hash_fn=lambda x: [hash(x[0]), hash(x[1])], + ) + base_env = CountingEnv + elif datatype == "str": + t = Hash( + in_keys=["string"], + out_keys=["hashing"], + hash_fn=lambda x: torch.stack([Hash.reproducible_hash(x_) for x_ in x]), + ) + base_env = CountingEnvWithString + + env = TransformedEnv(maybe_fork_ParallelEnv(2, base_env), t) + try: + check_env_specs(env) + finally: + try: + env.close() + except RuntimeError: + pass + + @pytest.mark.parametrize("datatype", ["tensor", "str"]) + def test_transform_compose(self, datatype): + if datatype == "tensor": + obs = torch.tensor(10) + elif datatype == "str": + obs = "abcdefg" + + td = TensorDict( + { + "observation": obs, + } + ) + t = Hash( + in_keys=["observation"], + out_keys=["hashing"], + hash_fn=hash, + ) + t = Compose(t) + td_hashed = t(td) + + assert td_hashed["observation"] is td["observation"] + assert td_hashed["hashing"] == hash(td["observation"]) + + def test_transform_model(self): + t = Hash( + in_keys=[("next", "observation"), ("observation",)], + out_keys=[("next", "hashing"), ("hashing",)], + hash_fn=hash, + ) + model = nn.Sequential(t, nn.Identity()) + td = TensorDict( + {("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, [] + ) + td_out = model(td) + assert ("next", "hashing") in td_out.keys(True) + assert ("hashing",) in td_out.keys(True) + assert td_out["next", "hashing"] == hash(td["next", "observation"]) + assert td_out["hashing"] == hash(td["observation"]) + + @pytest.mark.skipif(not _has_gym, reason="Gym not found") + def test_transform_env(self): + t = Hash( + in_keys=["observation"], + out_keys=["hashing"], + hash_fn=hash, + ) + env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t) + assert env.observation_spec["hashing"] + assert "observation" in env.observation_spec + assert "observation" in env.base_env.observation_spec + check_env_specs(env) + + @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) + def test_transform_rb(self, rbclass): + t = Hash( + in_keys=[("next", "observation"), ("observation",)], + out_keys=[("next", "hashing"), ("hashing",)], + hash_fn=lambda x: [hash(x[0]), hash(x[1])], + ) + rb = rbclass(storage=LazyTensorStorage(10)) + rb.append_transform(t) + td = TensorDict( + { + "observation": torch.randn(3, 4), + "next": TensorDict( + {"observation": torch.randn(3, 4)}, + [], + ), + }, + [], + ).expand(10) + rb.extend(td) + td = rb.sample(2) + assert "hashing" in td.keys() + assert "observation" in td.keys() + assert ("next", 
"observation") in td.keys(True) + + def test_transform_inverse(self): + raise pytest.skip("No inverse for Hash") + + class TestStack(TransformBase): def test_single_trans_env_check(self): t = Stack( diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index bcb50899549..3a4cde38aa2 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -67,6 +67,7 @@ FrameSkipTransform, GrayScale, gSDENoise, + Hash, InitTracker, KLRewardTransform, LineariseRewards, @@ -97,6 +98,7 @@ TrajCounter, Transform, TransformedEnv, + UnaryTransform, UnsqueezeTransform, VC1Transform, VecGymEnvTransform, diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py index 9853e8d516d..5936c939cec 100644 --- a/torchrl/envs/libs/pettingzoo.py +++ b/torchrl/envs/libs/pettingzoo.py @@ -9,6 +9,7 @@ import warnings from typing import Dict, List, Tuple, Union +import numpy as np import packaging import torch from tensordict import TensorDictBase @@ -72,6 +73,19 @@ def _load_available_envs() -> Dict: return all_environments +def _extract_nested_with_index( + data: Union[np.ndarray, Dict[str, np.ndarray]], index: int +): + if isinstance(data, np.ndarray): + return data[index] + elif isinstance(data, dict): + return { + key: _extract_nested_with_index(value, index) for key, value in data.items() + } + else: + raise NotImplementedError(f"Invalid type of data {data}") + + class PettingZooWrapper(_EnvWrapper): """PettingZoo environment wrapper. @@ -735,7 +749,9 @@ def _step_parallel( "full_action_spec", group, "action" ].to_numpy(group_action) for index, agent in enumerate(agents): - action_dict[agent] = group_action_np[index] + # group_action_np can be a dict or an array. We need to recursively index it + action = _extract_nested_with_index(group_action_np, index) + action_dict[agent] = action return self._env.step(action_dict) @@ -750,7 +766,8 @@ def _step_aec( group_action_np = self.input_spec[ "full_action_spec", group, "action" ].to_numpy(group_action) - action = group_action_np[agent_index] + # group_action_np can be a dict or an array. 
We need to recursively index it + action = _extract_nested_with_index(group_action_np, agent_index) break self._env.step(action) diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 9e261eee8f2..a25c676e378 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -31,6 +31,7 @@ FrameSkipTransform, GrayScale, gSDENoise, + Hash, InitTracker, LineariseRewards, NoopResetEnv, @@ -58,6 +59,7 @@ TrajCounter, Transform, TransformedEnv, + UnaryTransform, UnsqueezeTransform, VecGymEnvTransform, VecNorm, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ea4b32d7300..aa8c64bacc3 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -6,6 +6,7 @@ from __future__ import annotations import functools +import hashlib import importlib.util import multiprocessing as mp import warnings @@ -35,6 +36,7 @@ is_tensor_collection, LazyStackedTensorDict, NonTensorData, + NonTensorStack, set_lazy_legacy, TensorDict, TensorDictBase, @@ -81,7 +83,12 @@ _set_missing_tolerance, check_finite, ) -from torchrl.envs.utils import _sort_keys, _update_during_reset, step_mdp +from torchrl.envs.utils import ( + _sort_keys, + _update_during_reset, + make_composite_from_td, + step_mdp, +) from torchrl.objectives.value.functional import reward2go _has_tv = importlib.util.find_spec("torchvision", None) is not None @@ -4402,6 +4409,325 @@ def __repr__(self) -> str: ) +class UnaryTransform(Transform): + r"""Applies a unary operation on the specified inputs. + + Args: + in_keys (sequence of NestedKey): the keys of inputs to the unary operation. + out_keys (sequence of NestedKey): the keys of the outputs of the unary operation. + fn (Callable): the function to use as the unary operation. If it accepts + a non-tensor input, it must also accept ``None``. + + Keyword Args: + use_raw_nontensor (bool, optional): if ``False``, data is extracted from + :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called + on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` + inputs are given directly to ``fn``, which must support those + inputs. Default is ``False``. + + Example: + >>> from torchrl.envs import GymEnv, UnaryTransform + >>> env = GymEnv("Pendulum-v1") + >>> env = env.append_transform( + ... UnaryTransform( + ... in_keys=["observation"], + ... out_keys=["observation_trsf"], + ... 
fn=lambda tensor: str(tensor.numpy().tobytes()))) + >>> env.observation_spec + Composite( + observation: BoundedContinuous( + shape=torch.Size([3]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), + observation_trsf: NonTensor( + shape=torch.Size([]), + space=None, + device=cpu, + dtype=None, + domain=None), + device=None, + shape=torch.Size([])) + >>> env.rollout(3) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), + observation_trsf: NonTensorStack( + ["b'\\xbe\\xbc\\x7f?8\\x859=/\\x81\\xbe;'", "b'\\x..., + batch_size=torch.Size([3]), + device=None), + reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False), + observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), + observation_trsf: NonTensorStack( + ["b'\\x9a\\xbd\\x7f?\\xb8T8=8.c>'", "b'\\xbe\\xbc\..., + batch_size=torch.Size([3]), + device=None), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False) + >>> env.check_env_specs() + [torchrl][INFO] check_env_specs succeeded! + + """ + + def __init__( + self, + in_keys: Sequence[NestedKey], + out_keys: Sequence[NestedKey], + fn: Callable, + *, + use_raw_nontensor: bool = False, + ): + super().__init__(in_keys=in_keys, out_keys=out_keys) + self._fn = fn + self._use_raw_nontensor = use_raw_nontensor + + def _apply_transform(self, value): + if not self._use_raw_nontensor: + if isinstance(value, NonTensorData): + if value.dim() == 0: + value = value.get("data") + else: + value = value.tolist() + elif isinstance(value, NonTensorStack): + value = value.tolist() + return self._fn(value) + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + with _set_missing_tolerance(self, True): + tensordict_reset = self._call(tensordict_reset) + return tensordict_reset + + def transform_output_spec(self, output_spec: Composite) -> Composite: + output_spec = output_spec.clone() + + # Make a generic input from the spec, call the transform with that + # input, and then generate the output spec from the output. 
+ zero_input_ = output_spec.zero() + test_input = ( + zero_input_["full_observation_spec"] + .update(zero_input_["full_reward_spec"]) + .update(zero_input_["full_done_spec"]) + ) + test_output = self.forward(test_input) + test_output_spec = make_composite_from_td( + test_output, unsqueeze_null_shapes=False + ) + + output_spec["full_observation_spec"] = self.transform_observation_spec( + output_spec["full_observation_spec"], + test_output_spec, + ) + if "full_reward_spec" in output_spec.keys(): + output_spec["full_reward_spec"] = self.transform_reward_spec( + output_spec["full_reward_spec"], + test_output_spec, + ) + if "full_done_spec" in output_spec.keys(): + output_spec["full_done_spec"] = self.transform_done_spec( + output_spec["full_done_spec"], + test_output_spec, + ) + return output_spec + + def _transform_spec( + self, spec: TensorSpec, test_output_spec: TensorSpec + ) -> TensorSpec: + if not isinstance(spec, Composite): + raise TypeError(f"{self}: Only specs of type Composite can be transformed") + + spec_keys = set(spec.keys(include_nested=True)) + + for in_key, out_key in zip(self.in_keys, self.out_keys): + if in_key in spec_keys: + spec.set(out_key, test_output_spec[out_key]) + return spec + + def transform_observation_spec( + self, observation_spec: TensorSpec, test_output_spec: TensorSpec + ) -> TensorSpec: + return self._transform_spec(observation_spec, test_output_spec) + + def transform_reward_spec( + self, reward_spec: TensorSpec, test_output_spec: TensorSpec + ) -> TensorSpec: + return self._transform_spec(reward_spec, test_output_spec) + + def transform_done_spec( + self, done_spec: TensorSpec, test_output_spec: TensorSpec + ) -> TensorSpec: + return self._transform_spec(done_spec, test_output_spec) + + +class Hash(UnaryTransform): + r"""Adds a hash value to a tensordict. + + Args: + in_keys (sequence of NestedKey): the keys of the values to hash. + out_keys (sequence of NestedKey): the keys of the resulting hashes. + hash_fn (Callable, optional): the hash function to use. If ``seed`` is given, + the hash function must accept it as its second argument. Default is + ``Hash.reproducible_hash``. + seed (optional): seed to use for the hash function, if it requires one. + + Keyword Args: + use_raw_nontensor (bool, optional): if ``False``, data is extracted from + :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called + on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` + inputs are given directly to ``fn``, which must support those + inputs. Default is ``False``. + + >>> from torchrl.envs import GymEnv, UnaryTransform, Hash + >>> env = GymEnv("Pendulum-v1") + >>> # Add a string output + >>> env = env.append_transform( + ... UnaryTransform( + ... in_keys=["observation"], + ... out_keys=["observation_str"], + ... fn=lambda tensor: str(tensor.numpy().tobytes()))) + >>> # process the string output + >>> env = env.append_transform( + ... Hash( + ... in_keys=["observation_str"], + ... out_keys=["observation_hash"],) + ... 
) + >>> env.observation_spec + Composite( + observation: BoundedContinuous( + shape=torch.Size([3]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), + observation_str: NonTensor( + shape=torch.Size([]), + space=None, + device=cpu, + dtype=None, + domain=None), + observation_hash: UnboundedDiscrete( + shape=torch.Size([32]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.uint8, contiguous=True), + high=Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.uint8, contiguous=True)), + device=cpu, + dtype=torch.uint8, + domain=discrete), + device=None, + shape=torch.Size([])) + >>> env.rollout(3) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), + observation_hash: Tensor(shape=torch.Size([3, 32]), device=cpu, dtype=torch.uint8, is_shared=False), + observation_str: NonTensorStack( + ["b'g\\x08\\x8b\\xbexav\\xbf\\x00\\xee(>'", "b'\\x..., + batch_size=torch.Size([3]), + device=None), + reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False), + observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), + observation_hash: Tensor(shape=torch.Size([3, 32]), device=cpu, dtype=torch.uint8, is_shared=False), + observation_str: NonTensorStack( + ["b'\\xb5\\x17\\x8f\\xbe\\x88\\xccu\\xbf\\xc0Vr?'"..., + batch_size=torch.Size([3]), + device=None), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False) + >>> env.check_env_specs() + [torchrl][INFO] check_env_specs succeeded! + """ + + def __init__( + self, + in_keys: Sequence[NestedKey], + out_keys: Sequence[NestedKey], + hash_fn: Callable = None, + seed: Any | None = None, + *, + use_raw_nontensor: bool = False, + ): + if hash_fn is None: + hash_fn = Hash.reproducible_hash + + self._seed = seed + self._hash_fn = hash_fn + super().__init__( + in_keys=in_keys, + out_keys=out_keys, + fn=self.call_hash_fn, + use_raw_nontensor=use_raw_nontensor, + ) + + def call_hash_fn(self, value): + if self._seed is None: + return self._hash_fn(value) + else: + return self._hash_fn(value, self._seed) + + @classmethod + def reproducible_hash(cls, string, seed=None): + """Creates a reproducible 256-bit hash from a string using a seed. + + Args: + string (str or None): The input string. If ``None``, null string ``""`` is used. + seed (str, optional): The seed value. Default is ``None``. + + Returns: + Tensor: Shape ``(32,)`` with dtype ``torch.uint8``. 
+ """ + if string is None: + string = "" + + # Prepend the seed to the string + if seed is not None: + seeded_string = seed + string + else: + seeded_string = string + + # Create a new SHA-256 hash object + hash_object = hashlib.sha256() + + # Update the hash object with the seeded string + hash_object.update(seeded_string.encode("utf-8")) + + # Get the hash value as bytes + hash_bytes = bytearray(hash_object.digest()) + + return torch.frombuffer(hash_bytes, dtype=torch.uint8) + + class Stack(Transform): """Stacks tensors and tensordicts. diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index f7403e6a69e..878df8dad07 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -42,6 +42,7 @@ from torchrl.data.tensor_specs import ( Composite, NO_DEFAULT_RL as NO_DEFAULT, + NonTensor, TensorSpec, Unbounded, ) @@ -907,6 +908,8 @@ def make_composite_from_td(data, unsqueeze_null_shapes: bool = True): { key: make_composite_from_td(tensor) if isinstance(tensor, TensorDictBase) + else NonTensor(shape=data.shape, device=tensor.device) + if is_non_tensor(tensor) else Unbounded( dtype=tensor.dtype, device=tensor.device,