Skip to content

Commit

Permalink
Fix Compose input spec transform
Browse files Browse the repository at this point in the history
  • Loading branch information
Louis Faury committed Oct 4, 2024
1 parent 97ccbb7 commit 2d622bd
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
19 changes: 19 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8675,6 +8675,25 @@ def test_compose_indexing(self):
assert last_t.scale == 4
assert last_t2.scale == 4

def test_compose_action_spec(self):
# Create a Compose transform that renames "action" to "action_1" and then to "action_2"
c = Compose(
RenameTransform(in_keys=(), out_keys=(), in_keys_inv=("action",), out_keys_inv=("action_1",)),
RenameTransform(in_keys=(), out_keys=(), in_keys_inv=("action_1",), out_keys_inv=("action_2",)),
)
base_env = ContinuousActionVecMockEnv()
env = TransformedEnv(base_env, c)

# Check the `full_action_spec`s
assert "action_2" in env.full_action_spec
# Ensure intermediate keys are no longer in the action spec
assert "action_1" not in env.full_action_spec
assert "action" not in env.full_action_spec

# Final check to ensure clean sampling from the action_spec
action = env.rand_action()
assert "action_2"

@pytest.mark.parametrize("device", get_default_devices())
def test_finitetensordictcheck(self, device):
ftd = FiniteTensorDictCheck()
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ def transform_env_batch_size(self, batch_size: torch.batch_size):
return batch_size

def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
for t in self.transforms[::-1]:
for t in self.transforms:
input_spec = t.transform_input_spec(input_spec)
return input_spec

Expand Down

0 comments on commit 2d622bd

Please sign in to comment.