From 2d622bd72ccab6c61b8e9ca0b5d4b652631b4f58 Mon Sep 17 00:00:00 2001 From: Louis Faury Date: Fri, 4 Oct 2024 11:02:22 +0200 Subject: [PATCH] Fix Compose input spec transform --- test/test_transforms.py | 19 +++++++++++++++++++ torchrl/envs/transforms/transforms.py | 2 +- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index fc5048569fb..d0bc785f5fe 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -8675,6 +8675,25 @@ def test_compose_indexing(self): assert last_t.scale == 4 assert last_t2.scale == 4 + def test_compose_action_spec(self): + # Create a Compose transform that renames "action" to "action_1" and then to "action_2" + c = Compose( + RenameTransform(in_keys=(), out_keys=(), in_keys_inv=("action",), out_keys_inv=("action_1",)), + RenameTransform(in_keys=(), out_keys=(), in_keys_inv=("action_1",), out_keys_inv=("action_2",)), + ) + base_env = ContinuousActionVecMockEnv() + env = TransformedEnv(base_env, c) + + # Check the `full_action_spec`s + assert "action_2" in env.full_action_spec + # Ensure intermediate keys are no longer in the action spec + assert "action_1" not in env.full_action_spec + assert "action" not in env.full_action_spec + + # Final check to ensure clean sampling from the action_spec + action = env.rand_action() + assert "action_2" + @pytest.mark.parametrize("device", get_default_devices()) def test_finitetensordictcheck(self, device): ftd = FiniteTensorDictCheck() diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index d95f598944a..a95a14d42ad 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1094,7 +1094,7 @@ def transform_env_batch_size(self, batch_size: torch.batch_size): return batch_size def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: - for t in self.transforms[::-1]: + for t in self.transforms: input_spec = t.transform_input_spec(input_spec) return input_spec