diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index 065d6a2e3d4..75accf68651 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -827,6 +827,7 @@ to be able to create this other composition:
     FlattenObservation
     FrameSkipTransform
     GrayScale
+    Hash
     InitTracker
     KLRewardTransform
     NoopResetEnv
diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index 6f666290376..0531bff10df 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -4,6 +4,8 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations

+import random
+import string
 from typing import Dict, List, Optional

 import torch
@@ -1066,6 +1068,34 @@ def _step(
         return tensordict


+class CountingEnvWithString(CountingEnv):
+    """A ``CountingEnv`` that also emits a random lowercase string observation at each reset and step."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.observation_spec.set(
+            "string",
+            NonTensor(
+                shape=self.batch_size,
+                device=self.device,
+            ),
+        )
+
+    def get_random_string(self):
+        size = random.randint(4, 30)
+        return "".join(random.choice(string.ascii_lowercase) for _ in range(size))
+
+    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+        res = super()._reset(tensordict, **kwargs)
+        random_string = self.get_random_string()
+        res["string"] = random_string
+        return res
+
+    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
+        res = super()._step(tensordict)
+        random_string = self.get_random_string()
+        res["string"] = random_string
+        return res
+
+
 class MultiAgentCountingEnv(EnvBase):
     """A multi-agent env that is done after a given number of steps.

diff --git a/test/test_transforms.py b/test/test_transforms.py
index 44ebce72c5c..7b19940a461 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -39,6 +39,7 @@
     CountingBatchedEnv,
     CountingEnv,
     CountingEnvCountPolicy,
+    CountingEnvWithString,
     DiscreteActionConvMockEnv,
     DiscreteActionConvMockEnvNumpy,
     EnvWithScalarAction,
@@ -66,6 +67,7 @@
     CountingBatchedEnv,
     CountingEnv,
     CountingEnvCountPolicy,
+    CountingEnvWithString,
     DiscreteActionConvMockEnv,
     DiscreteActionConvMockEnvNumpy,
     EnvWithScalarAction,
@@ -77,7 +79,7 @@
     MultiKeyCountingEnvPolicy,
     NestedCountingEnv,
 )
-from tensordict import TensorDict, TensorDictBase, unravel_key
+from tensordict import NonTensorData, TensorDict, TensorDictBase, unravel_key
 from tensordict.nn import TensorDictSequential
 from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td
 from torch import multiprocessing as mp, nn, Tensor
@@ -116,6 +118,7 @@
     FrameSkipTransform,
     GrayScale,
     gSDENoise,
+    Hash,
     InitTracker,
     MultiStepTransform,
     NoopResetEnv,
@@ -2177,6 +2180,259 @@ def test_transform_no_env(self, device, batch):
         pytest.skip("TrajCounter cannot be called without env")


+class TestHash(TransformBase):
+    @pytest.mark.parametrize("datatype", ["tensor", "str", "NonTensorStack"])
+    def test_transform_no_env(self, datatype):
+        if datatype == "tensor":
+            obs = torch.tensor(10)
+            hash_fn = hash
+        elif datatype == "str":
+            obs = "abcdefg"
+            hash_fn = Hash.reproducible_hash
+        elif datatype == "NonTensorStack":
+            obs = torch.stack(
+                [
+                    NonTensorData(data="abcde"),
+                    NonTensorData(data="fghij"),
+                    NonTensorData(data="klmno"),
+                ]
+            )
+
+            def fn0(x):
+                return torch.stack(
+                    [Hash.reproducible_hash(x_.get("data")) for x_ in x]
+                )
+
+            hash_fn = fn0
+        else:
+            raise RuntimeError(f"please add a test case for datatype {datatype}")
+
+        td = TensorDict(
+            {
+                "observation": obs,
+            }
+        )
+
+        t = Hash(in_keys=["observation"], out_keys=["hash"], hash_fn=hash_fn)
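+        # The transform should leave the input entry untouched and write the
+        # digest under the out_key only.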
+        td_hashed = t(td)
+
+        assert td_hashed.get("observation") is td.get("observation")
+
+        if datatype == "NonTensorStack":
+            assert (td_hashed["hash"] == hash_fn(td.get("observation"))).all()
+        elif datatype == "str":
+            assert all(td_hashed["hash"] == hash_fn(td["observation"]))
+        else:
+            assert td_hashed["hash"] == hash_fn(td["observation"])
+
+    @pytest.mark.parametrize("datatype", ["tensor", "str"])
+    def test_single_trans_env_check(self, datatype):
+        if datatype == "tensor":
+            t = Hash(
+                in_keys=["observation"],
+                out_keys=["hash"],
+                hash_fn=hash,
+                output_spec=Unbounded(shape=(), dtype=torch.int64),
+            )
+            base_env = CountingEnv()
+        elif datatype == "str":
+            t = Hash(
+                in_keys=["string"],
+                out_keys=["hash"],
+            )
+            base_env = CountingEnvWithString()
+        env = TransformedEnv(base_env, t)
+        check_env_specs(env)
+
+    @pytest.mark.parametrize("datatype", ["tensor", "str"])
+    def test_serial_trans_env_check(self, datatype):
+        def make_env():
+            if datatype == "tensor":
+                t = Hash(
+                    in_keys=["observation"],
+                    out_keys=["hash"],
+                    hash_fn=hash,
+                    output_spec=Unbounded(shape=(), dtype=torch.int64),
+                )
+                base_env = CountingEnv()
+
+            elif datatype == "str":
+                t = Hash(
+                    in_keys=["string"],
+                    out_keys=["hash"],
+                )
+                base_env = CountingEnvWithString()
+
+            return TransformedEnv(base_env, t)
+
+        env = SerialEnv(2, make_env)
+        check_env_specs(env)
+
+    @pytest.mark.parametrize("datatype", ["tensor", "str"])
+    def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv, datatype):
+        def make_env():
+            if datatype == "tensor":
+                t = Hash(
+                    in_keys=["observation"],
+                    out_keys=["hash"],
+                    hash_fn=hash,
+                    output_spec=Unbounded(shape=(), dtype=torch.int64),
+                )
+                base_env = CountingEnv()
+            elif datatype == "str":
+                t = Hash(
+                    in_keys=["string"],
+                    out_keys=["hash"],
+                )
+                base_env = CountingEnvWithString()
+            return TransformedEnv(base_env, t)
+
+        env = maybe_fork_ParallelEnv(2, make_env)
+        try:
+            check_env_specs(env)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.parametrize("datatype", ["tensor", "str"])
+    def test_trans_serial_env_check(self, datatype):
+        if datatype == "tensor":
+            t = Hash(
+                in_keys=["observation"],
+                out_keys=["hash"],
+                hash_fn=lambda x: [hash(x[0]), hash(x[1])],
+                output_spec=Unbounded(shape=(2,), dtype=torch.int64),
+            )
+            base_env = CountingEnv
+        elif datatype == "str":
+            t = Hash(
+                in_keys=["string"],
+                out_keys=["hash"],
+                hash_fn=lambda x: torch.stack(
+                    [Hash.reproducible_hash(x_.get("data")) for x_ in x]
+                ),
+                output_spec=Unbounded(shape=(2, 32), dtype=torch.uint8),
+            )
+            base_env = CountingEnvWithString
+
+        env = TransformedEnv(SerialEnv(2, base_env), t)
+        check_env_specs(env)
+
+    @pytest.mark.parametrize("datatype", ["tensor", "str"])
+    def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype):
+        if datatype == "tensor":
+            t = Hash(
+                in_keys=["observation"],
+                out_keys=["hash"],
+                hash_fn=lambda x: [hash(x[0]), hash(x[1])],
+                output_spec=Unbounded(shape=(2,), dtype=torch.int64),
+            )
+            base_env = CountingEnv
+        elif datatype == "str":
+            t = Hash(
+                in_keys=["string"],
+                out_keys=["hash"],
+                hash_fn=lambda x: torch.stack(
+                    [Hash.reproducible_hash(x_.get("data")) for x_ in x]
+                ),
+                output_spec=Unbounded(shape=(2, 32), dtype=torch.uint8),
+            )
+            base_env = CountingEnvWithString
+
+        env = TransformedEnv(maybe_fork_ParallelEnv(2, base_env), t)
+        try:
+            check_env_specs(env)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.parametrize("datatype", ["tensor", "str"])
+    def test_transform_compose(self, datatype):
+        if datatype == "tensor":
+            obs = torch.tensor(10)
+        elif datatype == "str":
+            obs = "abcdefg"
+
+        td = TensorDict(
+            {
+                "observation": obs,
+            }
+        )
+        t = Hash(
+            in_keys=["observation"],
+            out_keys=["hash"],
+            hash_fn=hash,
+            output_spec=Unbounded(shape=(), dtype=torch.int64),
+        )
+        t = Compose(t)
+        td_hashed = t(td)
+
+        assert td_hashed["observation"] is td["observation"]
+        assert td_hashed["hash"] == hash(td["observation"])
+
+    def test_transform_model(self):
+        t = Hash(
+            in_keys=[("next", "observation"), ("observation",)],
+            out_keys=[("next", "hash"), ("hash",)],
+            hash_fn=hash,
+            output_spec=Unbounded(shape=(), dtype=torch.int64),
+        )
+        model = nn.Sequential(t, nn.Identity())
+        td = TensorDict(
+            {("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, []
+        )
+        td_out = model(td)
+        assert ("next", "hash") in td_out.keys(True)
+        assert ("hash",) in td_out.keys(True)
+        assert td_out["next", "hash"] == hash(td["next", "observation"])
+        assert td_out["hash"] == hash(td["observation"])
+
+    @pytest.mark.skipif(not _has_gym, reason="Gym not found")
+    def test_transform_env(self):
+        t = Hash(
+            in_keys=["observation"],
+            out_keys=["hash"],
+            hash_fn=hash,
+            output_spec=Unbounded(shape=(), dtype=torch.int64),
+        )
+        env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t)
+        assert env.observation_spec["hash"]
+        assert "observation" in env.observation_spec
+        assert "observation" in env.base_env.observation_spec
+        check_env_specs(env)
+
+    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
+    def test_transform_rb(self, rbclass):
+        t = Hash(
+            in_keys=[("next", "observation"), ("observation",)],
+            out_keys=[("next", "hash"), ("hash",)],
+            hash_fn=lambda x: [hash(x[0]), hash(x[1])],
+            output_spec=Unbounded(shape=(2,), dtype=torch.int64),
+        )
+        rb = rbclass(storage=LazyTensorStorage(10))
+        rb.append_transform(t)
+        td = TensorDict(
+            {
+                "observation": torch.randn(3, 4),
+                "next": TensorDict(
+                    {"observation": torch.randn(3, 4)},
+                    [],
+                ),
+            },
+            [],
+        ).expand(10)
+        rb.extend(td)
+        td = rb.sample(2)
+        assert "hash" in td.keys()
+        assert "observation" in td.keys()
+        assert ("next", "observation") in td.keys(True)
+
+    def test_transform_inverse(self):
+        raise pytest.skip("No inverse for Hash")
+
+
 class TestStack(TransformBase):
     def test_single_trans_env_check(self):
         t = Stack(
diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py
index b863ad0801c..85d8b993335 100644
--- a/torchrl/envs/__init__.py
+++ b/torchrl/envs/__init__.py
@@ -67,6 +67,7 @@
     FrameSkipTransform,
     GrayScale,
     gSDENoise,
+    Hash,
     InitTracker,
     KLRewardTransform,
     MultiStepTransform,
diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py
index 77f6ecc03bf..8e7ecbf2c65 100644
--- a/torchrl/envs/transforms/__init__.py
+++ b/torchrl/envs/transforms/__init__.py
@@ -31,6 +31,7 @@
     FrameSkipTransform,
     GrayScale,
     gSDENoise,
+    Hash,
     InitTracker,
     NoopResetEnv,
     ObservationNorm,
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 8e074fa8679..93d9a5f9db8 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -6,6 +6,7 @@
 from __future__ import annotations

 import functools
+import hashlib
 import importlib.util
 import multiprocessing as mp
 import warnings
@@ -4400,6 +4401,132 @@ def __repr__(self) -> str:
         )


+class UnaryTransform(Transform):
+    """Applies a unary operation on the specified inputs.
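+
+    Examples:
+        >>> # A minimal sketch (illustrative, not taken from the tests): double an
+        >>> # integer observation and store the result under a new key.
+        >>> t = UnaryTransform(
+        ...     in_keys=["observation"],
+        ...     out_keys=["observation_x2"],
+        ...     fn=lambda x: x * 2,
+        ...     output_spec=Unbounded(shape=(), dtype=torch.int64),
+        ... )
+        >>> td = t(TensorDict({"observation": torch.tensor(10)}))
+        >>> td["observation_x2"]
+        tensor(20)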
+
+    Args:
+        in_keys (sequence of NestedKey): the keys of inputs to the unary operation.
+        out_keys (sequence of NestedKey): the keys of the outputs of the unary operation.
+        fn (Callable): the function to use as the unary operation.
+        output_spec (TensorSpec): the spec of the output of the operation.
+    """
+
+    def __init__(
+        self,
+        in_keys: Sequence[NestedKey],
+        out_keys: Sequence[NestedKey],
+        fn: Callable,
+        output_spec: TensorSpec,
+    ):
+        super().__init__(in_keys=in_keys, out_keys=out_keys)
+        self._fn = fn
+        self._output_spec = output_spec
+
+    def _apply_transform(self, value):
+        if isinstance(value, NonTensorData):
+            value = value.get("data")
+        return self._fn(value)
+
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        with _set_missing_tolerance(self, True):
+            tensordict_reset = self._call(tensordict_reset)
+        return tensordict_reset
+
+    def _transform_spec(self, spec: TensorSpec) -> TensorSpec:
+        if not isinstance(spec, Composite):
+            raise TypeError(f"{self}: Only specs of type Composite can be transformed")
+
+        spec_keys = set(spec.keys(include_nested=True))
+
+        for in_key, out_key in zip(self.in_keys, self.out_keys):
+            if in_key in spec_keys:
+                spec.set(out_key, self._output_spec)
+        return spec
+
+    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
+        return self._transform_spec(observation_spec)
+
+    def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
+        return self._transform_spec(reward_spec)
+
+    def transform_done_spec(self, done_spec: TensorSpec) -> TensorSpec:
+        return self._transform_spec(done_spec)
+
+
+class Hash(UnaryTransform):
+    """Adds a hash value to a tensordict.
+
+    Args:
+        in_keys (sequence of NestedKey): the keys of the values to hash.
+        out_keys (sequence of NestedKey): the keys of the resulting hashes.
+        hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
+            the hash function must accept it as its second argument. Default is
+            ``Hash.reproducible_hash``.
+        output_spec (TensorSpec, optional): the spec of the hash output. Default
+            is ``Unbounded(shape=(32,), dtype=torch.uint8)``.
+        seed (optional): seed to pass to the hash function, if it requires one.
+    """
+
+    def __init__(
+        self,
+        in_keys: Sequence[NestedKey],
+        out_keys: Sequence[NestedKey],
+        hash_fn: Callable | None = None,
+        output_spec: TensorSpec | None = None,
+        seed: Any | None = None,
+    ):
+        if hash_fn is None:
+            hash_fn = Hash.reproducible_hash
+
+        if output_spec is None:
+            output_spec = Unbounded(shape=(32,), dtype=torch.uint8)
+
+        self._seed = seed
+        self._hash_fn = hash_fn
+        super().__init__(
+            in_keys=in_keys,
+            out_keys=out_keys,
+            fn=self.call_hash_fn,
+            output_spec=output_spec,
+        )
+
+    def call_hash_fn(self, value):
+        if self._seed is None:
+            return self._hash_fn(value)
+        else:
+            return self._hash_fn(value, self._seed)
+
+    @classmethod
+    def reproducible_hash(cls, string, seed=None):
+        """Creates a reproducible 256-bit hash from a string, optionally seeded.
+
+        Args:
+            string (str): The input string.
+            seed (str, optional): The seed value. Default is ``None``.
+
+        Returns:
+            Tensor: Shape ``(32,)`` with dtype ``torch.uint8``.
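+
+        Examples:
+            >>> # SHA-256 digests are 32 bytes long, exposed here as a uint8 tensor;
+            >>> # the same input (and seed) always maps to the same hash.
+            >>> h = Hash.reproducible_hash("abcdefg")
+            >>> h.shape, h.dtype
+            (torch.Size([32]), torch.uint8)
+            >>> bool((h == Hash.reproducible_hash("abcdefg")).all())
+            True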
+        """
+        # Prepend the seed to the string, if one was given
+        if seed is not None:
+            seeded_string = seed + string
+        else:
+            seeded_string = string
+
+        # Create a new SHA-256 hash object
+        hash_object = hashlib.sha256()
+
+        # Update the hash object with the seeded string
+        hash_object.update(seeded_string.encode("utf-8"))
+
+        # Get the hash value as bytes
+        hash_bytes = hash_object.digest()
+
+        # Copy the digest into a writable buffer so that torch.frombuffer does
+        # not warn about sharing memory with a read-only object
+        return torch.frombuffer(bytearray(hash_bytes), dtype=torch.uint8)
+
+
 class Stack(Transform):
     """Stacks tensors and tensordicts.