Skip to content

Commit

Permalink
Fixes to RenameTransform
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasbbrunner committed Sep 19, 2024
1 parent e294c68 commit b0afcc2
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6713,7 +6713,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
out = tensordict.select(*self.in_keys, strict=not self._missing_tolerance)
for in_key, out_key in zip(self.in_keys, self.out_keys):
try:
tensordict.rename_key_(in_key, out_key)
out.rename_key_(in_key, out_key)
except KeyError:
if not self._missing_tolerance:
raise
Expand Down Expand Up @@ -6743,7 +6743,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
)
for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
try:
out.rename_key_(out_key, in_key)
out.rename_key_(in_key, out_key)
except KeyError:
if not self._missing_tolerance:
raise
Expand All @@ -6752,7 +6752,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
else:
for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
try:
tensordict.rename_key_(out_key, in_key)
tensordict.rename_key_(in_key, out_key)
except KeyError:
if not self._missing_tolerance:
raise
Expand Down Expand Up @@ -6802,29 +6802,29 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:

def transform_input_spec(self, input_spec: Composite) -> Composite:
for action_key in self.parent.action_keys:
if action_key in self.in_keys:
for i, out_key in enumerate(self.out_keys): # noqa: B007
if self.in_keys[i] == action_key:
if action_key in self.out_keys_inv:
for i, in_key_inv in enumerate(self.in_keys_inv): # noqa: B007
if self.out_keys_inv[i] == action_key:
break
else:
# unreachable
raise RuntimeError
input_spec["full_action_spec"][out_key] = input_spec[
input_spec["full_action_spec"][in_key_inv] = input_spec[
"full_action_spec"
][action_key].clone()
if not self.create_copy:
del input_spec["full_action_spec"][action_key]
for state_key in self.parent.full_state_spec.keys(True):
if state_key in self.in_keys:
for i, out_key in enumerate(self.out_keys): # noqa: B007
if self.in_keys[i] == state_key:
if state_key in self.out_keys_inv:
for i, in_key_inv in enumerate(self.in_keys_inv): # noqa: B007
if self.out_keys_inv[i] == state_key:
break
else:
# unreachable
raise RuntimeError
input_spec["full_state_spec"][out_key] = input_spec["full_state_spec"][
state_key
].clone()
input_spec["full_state_spec"][in_key_inv] = input_spec[
"full_state_spec"
][state_key].clone()
if not self.create_copy:
del input_spec["full_state_spec"][state_key]
return input_spec
Expand Down

0 comments on commit b0afcc2

Please sign in to comment.