Skip to content

Commit

Permalink
Fixes issue where the indices were not created correctly. (#1660)
Browse files Browse the repository at this point in the history
# Description

This PR fixes an issue in articulation object from omni.isaac.lab. The
functions `write_root_com_pose_to_sim` and `write_root_link_to_sim` both
result in an error where a variable is called before being assigned. I
checked and there is a test for this, but I'm not sure why they don't
catch it...

The two functions have been changed to match the behaviors of
`write_root_link_pose_to_sim` and `write_root_com_velocity_to_sim`.



Fixes #1659 

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------

Co-authored-by: Kelly Guo <[email protected]>
  • Loading branch information
AntoineRichard and kellyguo11 authored Jan 14, 2025
1 parent 8b5fd06 commit a73b63c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -385,11 +385,13 @@ def write_root_com_pose_to_sim(self, root_pose: torch.Tensor, env_ids: Sequence[
env_ids: Environment indices. If None, then all indices are used.
"""
# resolve all indices
physx_env_ids = env_ids
if env_ids is None:
local_env_ids = slice(None)
env_ids = slice(None)
physx_env_ids = self._ALL_INDICES

com_pos = self.data.com_pos_b[local_env_ids, 0, :]
com_quat = self.data.com_quat_b[local_env_ids, 0, :]
com_pos = self.data.com_pos_b[env_ids, 0, :]
com_quat = self.data.com_quat_b[env_ids, 0, :]

root_link_pos, root_link_quat = math_utils.combine_frame_transforms(
root_pose[..., :3],
Expand All @@ -399,7 +401,7 @@ def write_root_com_pose_to_sim(self, root_pose: torch.Tensor, env_ids: Sequence[
)

root_link_pose = torch.cat((root_link_pos, root_link_quat), dim=-1)
self.write_root_link_pose_to_sim(root_pose=root_link_pose, env_ids=env_ids)
self.write_root_link_pose_to_sim(root_pose=root_link_pose, env_ids=physx_env_ids)

def write_root_velocity_to_sim(self, root_velocity: torch.Tensor, env_ids: Sequence[int] | None = None):
"""Set the root center of mass velocity over selected environment indices into the simulation.
Expand Down Expand Up @@ -458,18 +460,20 @@ def write_root_link_velocity_to_sim(self, root_velocity: torch.Tensor, env_ids:
env_ids: Environment indices. If None, then all indices are used.
"""
# resolve all indices
physx_env_ids = env_ids
if env_ids is None:
local_env_ids = slice(None)
env_ids = slice(None)
physx_env_ids = self._ALL_INDICES

root_com_velocity = root_velocity.clone()
quat = self.data.root_link_state_w[local_env_ids, 3:7]
com_pos_b = self.data.com_pos_b[local_env_ids, 0, :]
quat = self.data.root_link_state_w[env_ids, 3:7]
com_pos_b = self.data.com_pos_b[env_ids, 0, :]
# transform given velocity to center of mass
root_com_velocity[:, :3] += torch.linalg.cross(
root_com_velocity[:, 3:], math_utils.quat_rotate(quat, com_pos_b), dim=-1
)
# write center of mass velocity to sim
self.write_root_com_velocity_to_sim(root_velocity=root_com_velocity, env_ids=env_ids)
self.write_root_com_velocity_to_sim(root_velocity=root_com_velocity, env_ids=physx_env_ids)

def write_joint_state_to_sim(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,7 @@ def test_write_root_state(self):
# make quaternion a unit vector
rand_state[..., 3:7] = torch.nn.functional.normalize(rand_state[..., 3:7], dim=-1)

env_idx = env_idx.to(device)
for i in range(10):

# perform step
Expand All @@ -1096,9 +1097,15 @@ def test_write_root_state(self):
articulation.update(sim.cfg.dt)

if state_location == "com":
articulation.write_root_com_state_to_sim(rand_state)
if i % 2 == 0:
articulation.write_root_com_state_to_sim(rand_state)
else:
articulation.write_root_com_state_to_sim(rand_state, env_ids=env_idx)
elif state_location == "link":
articulation.write_root_link_state_to_sim(rand_state)
if i % 2 == 0:
articulation.write_root_link_state_to_sim(rand_state)
else:
articulation.write_root_link_state_to_sim(rand_state, env_ids=env_idx)

if state_location == "com":
torch.testing.assert_close(rand_state, articulation.data.root_com_state_w)
Expand Down

0 comments on commit a73b63c

Please sign in to comment.