From 7a66edfd630ed142ce629b25d01b2ba8d26d2ca6 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Fri, 13 Dec 2024 15:49:06 -0800 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- test/test_transforms.py | 153 ++++++++++++++++++++++++++ torchrl/envs/__init__.py | 1 + torchrl/envs/transforms/__init__.py | 1 + torchrl/envs/transforms/transforms.py | 46 ++++++++ 4 files changed, 201 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index cc3ca40b059..7fae20ad7e1 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -116,6 +116,7 @@ FrameSkipTransform, GrayScale, gSDENoise, + Hash, InitTracker, MultiStepTransform, NoopResetEnv, @@ -2177,6 +2178,158 @@ def test_transform_no_env(self, device, batch): pytest.skip("TrajCounter cannot be called without env") +# TODO: Add tests that hash NonTensorStacks of strings +class TestHash(TransformBase): + @pytest.mark.parametrize("datatype", ["tensor", "str"]) + def test_transform_no_env(self, datatype): + if datatype == "tensor": + obs = torch.tensor(10) + elif datatype == "str": + obs = "abcdefg" + else: + raise RuntimeError(f"please add a test case for datatype {datatype}") + + td = TensorDict( + { + "observation": obs, + } + ) + t = Hash(in_keys=["observation"], out_keys=["hash"]) + td_hashed = t(td) + + assert td_hashed["observation"] is td["observation"] + assert td_hashed["hash"] == hash(td["observation"]) + + def test_single_trans_env_check(self): + t = Hash(in_keys=["observation"], out_keys=["hash"]) + env = TransformedEnv(CountingEnv(), t) + check_env_specs(env) + + def test_serial_trans_env_check(self): + def make_env(): + t = Hash( + in_keys=["observation"], + out_keys=["hash"], + ) + return TransformedEnv(CountingEnv(), t) + + env = SerialEnv(2, make_env) + check_env_specs(env) + + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): + def make_env(): + t = Hash(in_keys=["observation"], out_keys=["hash"]) + return TransformedEnv(CountingEnv(), t) + + env = 
maybe_fork_ParallelEnv(2, make_env) + try: + check_env_specs(env) + finally: + try: + env.close() + except RuntimeError: + pass + + def test_trans_serial_env_check(self): + t = Hash( + in_keys=["observation"], + out_keys=["hash"], + ) + + env = TransformedEnv(SerialEnv(2, CountingEnv), t) + check_env_specs(env) + + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): + t = Hash( + in_keys=["observation"], + out_keys=["hash"], + ) + + env = TransformedEnv(maybe_fork_ParallelEnv(2, CountingEnv), t) + try: + check_env_specs(env) + finally: + try: + env.close() + except RuntimeError: + pass + + @pytest.mark.parametrize("datatype", ["tensor", "str"]) + def test_transform_compose(self, datatype): + if datatype == "tensor": + obs = torch.tensor(10) + elif datatype == "str": + obs = "abcdefg" + else: + raise RuntimeError(f"please add a test case for datatype {datatype}") + + td = TensorDict( + { + "observation": obs, + } + ) + t = Hash(in_keys=["observation"], out_keys=["hash"]) + t = Compose(t) + td_hashed = t(td) + + assert td_hashed["observation"] is td["observation"] + assert td_hashed["hash"] == hash(td["observation"]) + + def test_transform_model(self): + t = Hash( + in_keys=[("next", "observation"), ("observation",)], + out_keys=[("next", "hash"), ("hash",)], + ) + model = nn.Sequential(t, nn.Identity()) + td = TensorDict( + {("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, [] + ) + td_out = model(td) + assert ("next", "hash") in td_out.keys(True) + assert ("hash",) in td_out.keys(True) + assert td_out["next", "hash"] == hash(td["next", "observation"]) + assert td_out["hash"] == hash(td["observation"]) + + @pytest.mark.skipif(not _has_gym, reason="Gym not found") + def test_transform_env(self): + t = Hash( + in_keys=["observation"], + out_keys=["hash"], + ) + env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t) + assert env.observation_spec["hash"] + assert "observation" in env.observation_spec + assert "observation" in 
env.base_env.observation_spec + check_env_specs(env) + + @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) + def test_transform_rb(self, rbclass): + t = Hash( + in_keys=[("next", "observation"), ("observation",)], + out_keys=[("next", "hash"), ("hash",)], + ) + rb = rbclass(storage=LazyTensorStorage(10)) + rb.append_transform(t) + td = TensorDict( + { + "observation": torch.randn(3, 4), + "next": TensorDict( + {"observation": torch.randn(3, 4)}, + [], + ), + }, + [], + ).expand(10) + rb.extend(td) + td = rb.sample(2) + assert "observation_out" in td.keys() + assert "observation" not in td.keys() + assert ("next", "observation") not in td.keys(True) + + def test_transform_inverse(self): + raise pytest.skip("No inverse for Hash") + + class TestStack(TransformBase): def test_single_trans_env_check(self): t = Stack( diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index b863ad0801c..85d8b993335 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -67,6 +67,7 @@ FrameSkipTransform, GrayScale, gSDENoise, + Hash, InitTracker, KLRewardTransform, MultiStepTransform, diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 77f6ecc03bf..8e7ecbf2c65 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -31,6 +31,7 @@ FrameSkipTransform, GrayScale, gSDENoise, + Hash, InitTracker, NoopResetEnv, ObservationNorm, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index f3329d085df..e3dca6ca069 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4400,6 +4400,52 @@ def __repr__(self) -> str: ) +class Hash(Transform): + """Adds a hash value to a tensordict. + + Args: + in_keys (sequence of NestedKey): the key of the data to create the hash from. + out_key (sequence of NestedKey): the key of the resulting hash. 
+ """ + + def __init__( + self, + in_keys: Sequence[NestedKey], + out_keys: Sequence[NestedKey], + ): + super().__init__(in_keys=in_keys, out_keys=out_keys) + + # TODO: If this transform is run on a tensordict like + # `TensorDict({"obs": torch.rand(2)}, batch_size=[2])`, then + # `_apply_transform` will create only one hash value for the tensor of size + # 2. Then, when `forward` tries to add the hash to the tensordict, an error + # is raised since the hash doesn't have a leading dimension of size 2. + # TODO: Add support for NonTensorStack inputs. + def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: + if isinstance(observation, NonTensorData): + obs = observation.get("data") + else: + obs = observation + return hash(obs) + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + with _set_missing_tolerance(self, True): + tensordict_reset = self._call(tensordict_reset) + return tensordict_reset + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + if not isinstance(observation_spec, Composite): + raise TypeError(f"{self}: Only specs of type Composite can be transformed") + for out_key in self.out_keys: + observation_spec.set( + out_key, + Unbounded(shape=(), dtype=torch.int64), + ) + return observation_spec + + +class Stack(Transform): + """Stacks tensors and tensordicts. 
From 6b28487a9fe4c86336e01e0f6ab619a19c73b52c Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Fri, 10 Jan 2025 12:05:23 -0800 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- docs/source/reference/envs.rst | 1 + test/test_transforms.py | 14 ++++++-------- torchrl/envs/transforms/transforms.py | 28 +++++++-------------------- 3 files changed, 14 insertions(+), 29 deletions(-) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 065d6a2e3d4..75accf68651 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -827,6 +827,7 @@ to be able to create this other composition: FlattenObservation FrameSkipTransform GrayScale + Hash InitTracker KLRewardTransform NoopResetEnv diff --git a/test/test_transforms.py b/test/test_transforms.py index 0400ca1003d..7b19940a461 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -2188,7 +2188,7 @@ def test_transform_no_env(self, datatype): hash_fn = hash elif datatype == "str": obs = "abcdefg" - hash_fn = Hash.reproducible_hash_parts + hash_fn = Hash.reproducible_hash elif datatype == "NonTensorStack": obs = torch.stack( [ @@ -2199,9 +2199,7 @@ def test_transform_no_env(self, datatype): ) def fn0(x): - return torch.stack( - [Hash.reproducible_hash_parts(x_.get("data")) for x_ in x] - ) + return torch.stack([Hash.reproducible_hash(x_.get("data")) for x_ in x]) hash_fn = fn0 else: @@ -2311,9 +2309,9 @@ def test_trans_serial_env_check(self, datatype): in_keys=["string"], out_keys=["hash"], hash_fn=lambda x: torch.stack( - [Hash.reproducible_hash_parts(x_.get("data")) for x_ in x] + [Hash.reproducible_hash(x_.get("data")) for x_ in x] ), - output_spec=Unbounded(shape=(2, 4), dtype=torch.int64), + output_spec=Unbounded(shape=(2, 32), dtype=torch.uint8), ) base_env = CountingEnvWithString @@ -2335,9 +2333,9 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype): in_keys=["string"], out_keys=["hash"], hash_fn=lambda x: torch.stack( - 
[Hash.reproducible_hash_parts(x_.get("data")) for x_ in x] + [Hash.reproducible_hash(x_.get("data")) for x_ in x] ), - output_spec=Unbounded(shape=(2, 4), dtype=torch.int64), + output_spec=Unbounded(shape=(2, 32), dtype=torch.uint8), ) base_env = CountingEnvWithString diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 06858320c64..93d9a5f9db8 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4463,9 +4463,9 @@ class Hash(UnaryTransform): out_keys (sequence of NestedKey): the keys of the resulting hashes. hash_fn (Callable, optional): the hash function to use. If ``seed`` is given, the hash function must accept it as its second argument. Default is - Python's builtin ``hash`` function. + ``Hash.reproducible_hash``. output_spec (TensorSpec, optional): the spec of the hash output. Default - is ``Unbounded(shape=(), dtype=torch.int64)``. + is ``Unbounded(shape=(32,), dtype=torch.uint8)``. seed (optional): seed to use for the hash function, if it requires one. """ @@ -4478,10 +4478,10 @@ def __init__( seed: Any | None = None, ): if hash_fn is None: - hash_fn = Hash.reproducible_hash_parts + hash_fn = Hash.reproducible_hash if output_spec is None: - output_spec = Unbounded(shape=(4,), dtype=torch.int64) + output_spec = Unbounded(shape=(32,), dtype=torch.uint8) self._seed = seed self._hash_fn = hash_fn @@ -4499,7 +4499,7 @@ def call_hash_fn(self, value): return self._hash_fn(value, self._seed) @classmethod - def reproducible_hash_parts(cls, string, seed=None): + def reproducible_hash(cls, string, seed=None): """Creates a reproducible 256-bit hash from a string using a seed. Args: @@ -4507,7 +4507,7 @@ def reproducible_hash_parts(cls, string, seed=None): seed (str, optional): The seed value. Default is ``None``. Returns: - tuple: Four 64-bit integers representing the parts of the 256-bit hash value. + Tensor: Shape ``(32,)`` with dtype ``torch.int8``. 
""" # Prepend the seed to the string if seed is not None: @@ -4524,21 +4524,7 @@ def reproducible_hash_parts(cls, string, seed=None): # Get the hash value as bytes hash_bytes = hash_object.digest() - # Split the hash bytes into four parts - part1 = hash_bytes[:8] - part2 = hash_bytes[8:16] - part3 = hash_bytes[16:24] - part4 = hash_bytes[24:] - - # Convert each part to a 64-bit integer - part1_value = int.from_bytes(part1, "big", signed=True) - part2_value = int.from_bytes(part2, "big", signed=True) - part3_value = int.from_bytes(part3, "big", signed=True) - part4_value = int.from_bytes(part4, "big", signed=True) - - return torch.tensor( - [part1_value, part2_value, part3_value, part4_value], dtype=torch.int64 - ) + return torch.frombuffer(hash_bytes, dtype=torch.uint8) class Stack(Transform): From ced9b93ec4e43fc13e922e62ea36fadc089d9c5c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 15 Jan 2025 15:26:41 +0000 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- torchrl/envs/__init__.py | 1 + torchrl/envs/transforms/__init__.py | 1 + torchrl/envs/transforms/transforms.py | 145 ++++++++++++++++++++++++-- torchrl/envs/utils.py | 3 + 4 files changed, 144 insertions(+), 6 deletions(-) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 410cf6dc7e4..3a4cde38aa2 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -98,6 +98,7 @@ TrajCounter, Transform, TransformedEnv, + UnaryTransform, UnsqueezeTransform, VC1Transform, VecGymEnvTransform, diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 29c2522755a..a25c676e378 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -59,6 +59,7 @@ TrajCounter, Transform, TransformedEnv, + UnaryTransform, UnsqueezeTransform, VecGymEnvTransform, VecNorm, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 5e94359926d..aa8c64bacc3 100644 --- 
a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4410,7 +4410,7 @@ def __repr__(self) -> str: class UnaryTransform(Transform): - """Applies a unary operation on the specified inputs. + r"""Applies a unary operation on the specified inputs. Args: in_keys (sequence of NestedKey): the keys of inputs to the unary operation. @@ -4420,10 +4420,69 @@ class UnaryTransform(Transform): Keyword Args: use_raw_nontensor (bool, optional): if ``False``, data is extracted from - ``NonTensorData``/``NonTensorStack`` inputs before ``fn`` is called - on them. If ``True``, the raw ``NonTensorData``/``NonTensorStack`` + :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called + on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs are given directly to ``fn``, which must support those inputs. Default is ``False``. + + Example: + >>> from torchrl.envs import GymEnv, UnaryTransform + >>> env = GymEnv("Pendulum-v1") + >>> env = env.append_transform( + ... UnaryTransform( + ... in_keys=["observation"], + ... out_keys=["observation_trsf"], + ... 
fn=lambda tensor: str(tensor.numpy().tobytes()))) + >>> env.observation_spec + Composite( + observation: BoundedContinuous( + shape=torch.Size([3]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), + observation_trsf: NonTensor( + shape=torch.Size([]), + space=None, + device=cpu, + dtype=None, + domain=None), + device=None, + shape=torch.Size([])) + >>> env.rollout(3) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), + observation_trsf: NonTensorStack( + ["b'\\xbe\\xbc\\x7f?8\\x859=/\\x81\\xbe;'", "b'\\x..., + batch_size=torch.Size([3]), + device=None), + reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False), + observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), + observation_trsf: NonTensorStack( + ["b'\\x9a\\xbd\\x7f?\\xb8T8=8.c>'", "b'\\xbe\\xbc\..., + batch_size=torch.Size([3]), + device=None), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False) + >>> 
env.check_env_specs() + [torchrl][INFO] check_env_specs succeeded! + """ def __init__( @@ -4518,7 +4577,7 @@ def transform_done_spec( class Hash(UnaryTransform): - """Adds a hash value to a tensordict. + r"""Adds a hash value to a tensordict. Args: in_keys (sequence of NestedKey): the keys of the values to hash. @@ -4530,10 +4589,84 @@ class Hash(UnaryTransform): Keyword Args: use_raw_nontensor (bool, optional): if ``False``, data is extracted from - ``NonTensorData``/``NonTensorStack`` inputs before ``fn`` is called - on them. If ``True``, the raw ``NonTensorData``/``NonTensorStack`` + :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called + on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs are given directly to ``fn``, which must support those inputs. Default is ``False``. + + >>> from torchrl.envs import GymEnv, UnaryTransform, Hash + >>> env = GymEnv("Pendulum-v1") + >>> # Add a string output + >>> env = env.append_transform( + ... UnaryTransform( + ... in_keys=["observation"], + ... out_keys=["observation_str"], + ... fn=lambda tensor: str(tensor.numpy().tobytes()))) + >>> # process the string output + >>> env = env.append_transform( + ... Hash( + ... in_keys=["observation_str"], + ... out_keys=["observation_hash"],) + ... 
) + >>> env.observation_spec + Composite( + observation: BoundedContinuous( + shape=torch.Size([3]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), + observation_str: NonTensor( + shape=torch.Size([]), + space=None, + device=cpu, + dtype=None, + domain=None), + observation_hash: UnboundedDiscrete( + shape=torch.Size([32]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.uint8, contiguous=True), + high=Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.uint8, contiguous=True)), + device=cpu, + dtype=torch.uint8, + domain=discrete), + device=None, + shape=torch.Size([])) + >>> env.rollout(3) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), + observation_hash: Tensor(shape=torch.Size([3, 32]), device=cpu, dtype=torch.uint8, is_shared=False), + observation_str: NonTensorStack( + ["b'g\\x08\\x8b\\xbexav\\xbf\\x00\\xee(>'", "b'\\x..., + batch_size=torch.Size([3]), + device=None), + reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False), + observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), + observation_hash: 
Tensor(shape=torch.Size([3, 32]), device=cpu, dtype=torch.uint8, is_shared=False), + observation_str: NonTensorStack( + ["b'\\xb5\\x17\\x8f\\xbe\\x88\\xccu\\xbf\\xc0Vr?'"..., + batch_size=torch.Size([3]), + device=None), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False) + >>> env.check_env_specs() + [torchrl][INFO] check_env_specs succeeded! """ def __init__( diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index f7403e6a69e..878df8dad07 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -42,6 +42,7 @@ from torchrl.data.tensor_specs import ( Composite, NO_DEFAULT_RL as NO_DEFAULT, + NonTensor, TensorSpec, Unbounded, ) @@ -907,6 +908,8 @@ def make_composite_from_td(data, unsqueeze_null_shapes: bool = True): { key: make_composite_from_td(tensor) if isinstance(tensor, TensorDictBase) + else NonTensor(shape=data.shape, device=tensor.device) + if is_non_tensor(tensor) else Unbounded( dtype=tensor.dtype, device=tensor.device,