[Feature] Real2Sim Eval Digital Twins (#536)
* work

* work

* Update base_env.py

* greenscreen trick added

* code refactors

* work

* Update widowx.py

* fixes

* fixes

* updates

* bug fixes

* align sim configs

* fixes

* Update demo_octo_eval.py

* debugged

* work

* bug fixes

* attempt to support IK

* work

* cleanup

* work

* work

* cleaned up code

* evals

* fixes

* spoon task

* Update demo_octo_eval.py

* work

* update widowx model download link and cleanup code

* fixes

* work

* bug fixes

* rt1 inference example

* bug fixes

* less eggplant rolling

* code cleanup

* GPU IK no delta controller implemented

* gpu fixes

* bug fixes

* work

* fixes

* work

* w

* cleanup

* code cleanup, assets added

* docs

* Delete demo_real2sim_eval.py

* f

* Update base_env.py

* Update base_env.py

* Delete README.md

* Update index.md
StoneT2000 authored Sep 13, 2024
1 parent 532bb97 commit 3543565
Showing 23 changed files with 1,176 additions and 58 deletions.
48 changes: 47 additions & 1 deletion docs/source/tasks/digital_twins/index.md
@@ -1 +1,47 @@
-# Digital Twins (WIP)
+# Digital Twins

ManiSkill supports two kinds of digital twins, training and evaluation, and provides a simple framework for building them. Training digital twins are tasks designed to train a robot in simulation so that it can then be deployed in the real world (sim2real). Evaluation digital twins are tasks designed to evaluate, not train, robots whose policies were trained on real-world data (real2sim).


## Training Digital Twins (WIP)

Coming soon.

## BridgeData v2 (Evaluation)

We currently support evaluation digital twins of some tasks from the [BridgeData v2](https://rail-berkeley.github.io/bridgedata/) environments, simulated based on [SimplerEnv](https://simpler-env.github.io/) by Xuanlin Li, Kyle Hsu, Jiayuan Gu, et al. These digital twins are GPU parallelized, enabling large-scale, fast evaluation of real-world generalist robotics policies. GPU simulation + rendering enables evaluating up to 60x faster than the real world and 10x faster than CPU simulation, all without human supervision. ManiSkill only provides the environments; to run policy inference of models like Octo and RT-1, see https://github.com/simpler-env/SimplerEnv/tree/maniskill3.
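
As a rough sketch of how these environments can be created (exact keyword arguments such as `num_envs` and `obs_mode` are assumptions and may differ across ManiSkill versions):

```python
import gymnasium as gym

import mani_skill.envs  # registers the ManiSkill environments with gymnasium

# Create a GPU-parallelized batch of one of the digital twin tasks listed
# below and roll it out with random actions. This only demonstrates the
# environment API; policy inference (Octo, RT-1) lives in the SimplerEnv
# maniskill3 branch linked above.
env = gym.make("PutCarrotOnPlateInScene-v1", num_envs=16, obs_mode="rgb")
obs, _ = env.reset(seed=0)
for _ in range(100):
    action = env.action_space.sample()  # replace with a policy's actions
    obs, reward, terminated, truncated, info = env.step(action)
env.close()
```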

If you use the BridgeData v2 digital twins, please cite the following in addition to ManiSkill 3:

```
@article{li24simpler,
  title={Evaluating Real-World Robot Manipulation Policies in Simulation},
  author={Xuanlin Li and Kyle Hsu and Jiayuan Gu and Karl Pertsch and Oier Mees and Homer Rich Walke and Chuyuan Fu and Ishikaa Lunawat and Isabel Sieh and Sean Kirmani and Sergey Levine and Jiajun Wu and Chelsea Finn and Hao Su and Quan Vuong and Ted Xiao},
  journal={arXiv preprint arXiv:2405.05941},
  year={2024},
}
```

### PutCarrotOnPlateInScene-v1

<video preload="auto" controls="True" width="100%">
<source src="https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/digital_twins/bridge_data_v2/PutCarrotOnPlateInScene-v1.mp4" type="video/mp4">
</video>

### PutSpoonOnTableClothInScene-v1

<video preload="auto" controls="True" width="100%">
<source src="https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/digital_twins/bridge_data_v2/PutSpoonOnTableClothInScene-v1.mp4" type="video/mp4">
</video>

### StackGreenCubeOnYellowCubeBakedTexInScene-v1

<video preload="auto" controls="True" width="100%">
<source src="https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/digital_twins/bridge_data_v2/StackGreenCubeOnYellowCubeBakedTexInScene-v1.mp4" type="video/mp4">
</video>

### PutEggplantInBasketScene-v1

<video preload="auto" controls="True" width="100%">
<source src="https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/digital_twins/bridge_data_v2/PutEggplantInBasketScene-v1.mp4" type="video/mp4">
</video>
4 changes: 4 additions & 0 deletions docs/source/user_guide/demos/index.md
@@ -151,6 +151,10 @@ Example below shows what it looks like with the GUI:

For more details check out the [motion planning page](../data_collection/motionplanning.md)

## Real2Sim Evaluation

ManiSkill3 supports extremely fast real2sim evaluation of policies like RT-1 and Octo via GPU simulation + rendering. See [this page](../../tasks/digital_twins/index.md) for more details on which environments are supported. To run inference of RT-1 and Octo, see the `maniskill3` branch of the [SimplerEnv project](https://github.com/simpler-env/SimplerEnv/tree/maniskill3).
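
A minimal sketch of how one might gauge raw rollout throughput on one of these digital twins (`num_envs` and `obs_mode` are assumed keyword arguments, and random actions stand in for a policy):

```python
import time

import gymnasium as gym

import mani_skill.envs  # registers the environments with gymnasium

# Step a batch of digital twin environments with random actions to gauge
# simulation + rendering throughput, independent of any policy.
num_envs = 64
env = gym.make("PutSpoonOnTableClothInScene-v1", num_envs=num_envs, obs_mode="rgb")
env.reset(seed=0)
steps = 100
start = time.perf_counter()
for _ in range(steps):
    env.step(env.action_space.sample())
elapsed = time.perf_counter() - start
print(f"~{steps * num_envs / elapsed:.0f} env steps/second")
env.close()
```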

## Visualize Pointcloud Data

You can run the following to visualize the pointcloud observations (requires a display to work)
4 binary files changed (not shown).
13 changes: 7 additions & 6 deletions mani_skill/agents/base_agent.py
@@ -166,7 +166,7 @@ def _load_articulation(self):

if not os.path.exists(asset_path):
print(f"Robot {self.uid} definition file not found at {asset_path}")
-if len(assets.DATA_GROUPS[self.uid]) > 0:
+if self.uid in assets.DATA_GROUPS or len(assets.DATA_GROUPS[self.uid]) > 0:
response = download_asset.prompt_yes_no(
f"Robot {self.uid} has assets available for download. Would you like to download them now?"
)
@@ -181,13 +181,15 @@ def _load_articulation(self):
print(f"Exiting as assets for robot {self.uid} are not downloaded")
exit()
else:
print(f"Exiting as assets for robot {self.uid} are not found")
print(
f"Exiting as assets for robot {self.uid} are not found. Check that this agent is properly registered with the appropriate download asset ids"
)
exit()
self.robot: Articulation = loader.load(asset_path)
assert self.robot is not None, f"Fail to load URDF/MJCF from {asset_path}"

-# Cache robot link ids
-self.robot_link_ids = [link.name for link in self.robot.get_links()]
+# Cache robot link names
+self.robot_link_names = [link.name for link in self.robot.get_links()]

def _after_loading_articulation(self):
"""Called after loading articulation and before setting up any controllers. By default this is empty."""
Expand Down Expand Up @@ -337,7 +339,7 @@ def set_state(self, state: Dict, ignore_controller=False):
# -------------------------------------------------------------------------- #
def reset(self, init_qpos: torch.Tensor = None):
"""
-Reset the robot to a clean state with zero velocity and forces. Furthermore it resets the current active controller.
+Reset the robot to a clean state with zero velocity and forces.
Args:
init_qpos (torch.Tensor): The initial qpos to set the robot to. If None, the robot's qpos is not changed.
@@ -346,7 +348,6 @@ def reset(self, init_qpos: torch.Tensor = None):
self.robot.set_qpos(init_qpos)
self.robot.set_qvel(torch.zeros(self.robot.max_dof, device=self.device))
self.robot.set_qf(torch.zeros(self.robot.max_dof, device=self.device))
-self.controller.reset()

# -------------------------------------------------------------------------- #
# Optional per-agent APIs, implemented depending on agent affordances
18 changes: 1 addition & 17 deletions mani_skill/agents/controllers/pd_ee_pose.py
@@ -30,12 +30,6 @@ def _check_gpu_sim_works(self):
assert (
self.config.frame == "root_translation"
), "currently only translation in the root frame for EE control is supported in GPU sim"
-assert (
-self.config.use_delta == True
-), "currently only delta EE control is supported in GPU sim"
-assert (
-self.config.use_target == False
-), "Currently cannot take actions relative to last target pose in GPU sim"

def _initialize_joints(self):
self.initial_qpos = None
@@ -111,6 +105,7 @@ def set_action(self, action: Array):
self.articulation.get_qpos(),
pos_only=pos_only,
action=action,
+use_delta_ik_solver=self.config.use_delta and not self.config.use_target,
)
if self._target_qpos is None:
self._target_qpos = self._start_qpos
@@ -179,12 +174,6 @@ def _check_gpu_sim_works(self):
assert (
self.config.frame == "root_translation:root_aligned_body_rotation"
), "currently only translation in the root frame for EE control is supported in GPU sim"
-assert (
-self.config.use_delta == True
-), "currently only delta EE control is supported in GPU sim"
-assert (
-self.config.use_target == False
-), "Currently cannot take actions relative to last target pose in GPU sim"

def _initialize_action_space(self):
low = np.float32(
@@ -219,11 +208,6 @@ def _clip_and_scale_action(self, action):
rot_action = rot_action * self.config.rot_lower
return torch.hstack([pos_action, rot_action])

-def compute_ik(self, target_pose: Pose, action: Array, max_iterations=100):
-return super().compute_ik(
-target_pose, action, pos_only=False, max_iterations=max_iterations
-)

def compute_target_pose(self, prev_ee_pose_at_base: Pose, action):
if self.config.use_delta:
delta_pos, delta_rot = action[:, 0:3], action[:, 3:6]
71 changes: 53 additions & 18 deletions mani_skill/agents/controllers/utils/kinematics.py
@@ -91,7 +91,7 @@ def _setup_cpu(self):
def _setup_gpu(self):
"""setup the kinematics solvers on the GPU"""
self.use_gpu_ik = True
-with open(self.urdf_path, "r") as f:
+with open(self.urdf_path, "rb") as f:
urdf_str = f.read()

# NOTE (stao): it seems that the pk library currently always outputs some complaints if there are unknown attributes in a URDF. Hide it with this contextmanager here
@@ -107,34 +107,69 @@ def suppress_stdout_stderr():
urdf_str,
end_link_name=self.end_link.name,
).to(device=self.device)
+lim = torch.tensor(self.pk_chain.get_joint_limits(), device=self.device)
+self.pik = pk.PseudoInverseIK(
+self.pk_chain,
+joint_limits=lim.T,
+early_stopping_any_converged=True,
+max_iterations=200,
+num_retries=1,
+)

self.qmask = torch.zeros(
len(self.active_ancestor_joints), dtype=bool, device=self.device
)
self.qmask[self.controlled_joints_idx_in_qmask] = 1

def compute_ik(
-self, target_pose: Pose, q0: torch.Tensor, pos_only: bool = False, action=None
+self,
+target_pose: Pose,
+q0: torch.Tensor,
+pos_only: bool = False,
+action=None,
+use_delta_ik_solver: bool = False,
):
"""Given a target pose, via inverse kinematics compute the target joint positions that will achieve the target pose"""
"""Given a target pose, via inverse kinematics compute the target joint positions that will achieve the target pose
Args:
target_pose (Pose): target pose of the end effector in the world frame. note this is not relative to the robot base frame!
q0 (torch.Tensor): initial joint positions of every active joint in the articulation
pos_only (bool): if True, only the position of the end link is considered in the IK computation
action (torch.Tensor): delta action to be applied to the articulation. Used for fast delta IK solutions on the GPU.
use_delta_ik_solver (bool): If true, returns the target joint positions that correspond with a delta IK solution. This is specifically
used for GPU simulation to determine which GPU IK algorithm to use.
"""
if self.use_gpu_ik:
q0 = q0[:, self.active_ancestor_joint_idxs]
-jacobian = self.pk_chain.jacobian(q0)
-# code commented out below is the fast kinematics method
-# jacobian = (
-# self.fast_kinematics_model.jacobian_mixed_frame_pytorch(
-# self.articulation.get_qpos()[:, self.active_ancestor_joint_idxs]
-# )
-# .view(-1, len(self.active_ancestor_joints), 6)
-# .permute(0, 2, 1)
-# )
-# jacobian = jacobian[:, :, self.qmask]
-if pos_only:
-jacobian = jacobian[:, 0:3]
+if not use_delta_ik_solver:
+tf = pk.Transform3d(
+pos=target_pose.p,
+rot=target_pose.q,
+device=self.device,
+)
+self.pik.initial_config = q0 # shape (num_retries, active_ancestor_dof)
+result = self.pik.solve(
+tf
+) # produce solutions in shape (B, num_retries/initial_configs, active_ancestor_dof)
+# TODO return mask for invalid solutions. CPU returns None at the moment
+return result.solutions[:, 0, :]
+else:
+jacobian = self.pk_chain.jacobian(q0)
+# code commented out below is the fast kinematics method
+# jacobian = (
+# self.fast_kinematics_model.jacobian_mixed_frame_pytorch(
+# self.articulation.get_qpos()[:, self.active_ancestor_joint_idxs]
+# )
+# .view(-1, len(self.active_ancestor_joints), 6)
+# .permute(0, 2, 1)
+# )
+# jacobian = jacobian[:, :, self.qmask]
+if pos_only:
+jacobian = jacobian[:, 0:3]

-# NOTE (stao): this method of IK is from https://mathweb.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf by Samuel R. Buss
-delta_joint_pos = torch.linalg.pinv(jacobian) @ action.unsqueeze(-1)
-return q0 + delta_joint_pos.squeeze(-1)
+# NOTE (stao): this method of IK is from https://mathweb.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf by Samuel R. Buss
+delta_joint_pos = torch.linalg.pinv(jacobian) @ action.unsqueeze(-1)
+return q0 + delta_joint_pos.squeeze(-1)
else:
result, success, error = self.pmodel.compute_inverse_kinematics(
self.end_link_idx,
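
For reference, a self-contained sketch of the Jacobian pseudoinverse update used in the delta-IK branch above (the method from Buss's IK survey); shapes and values here are illustrative, not taken from the library:

```python
import torch

# A plain least-squares IK step: given a batch of end-effector Jacobians
# J (B, 6, dof) mapping joint velocities to end-effector twists, the joint
# update that best realizes a small desired twist `action` (B, 6) is
# dq = pinv(J) @ action, applied on top of the current joint positions q0.
B, dof = 4, 6                      # illustrative batch size and joint count
jacobian = torch.randn(B, 6, dof)  # stand-in for pk_chain.jacobian(q0)
q0 = torch.zeros(B, dof)           # current joint positions
action = 0.01 * torch.randn(B, 6)  # small delta-pose action

delta_q = torch.linalg.pinv(jacobian) @ action.unsqueeze(-1)  # (B, dof, 1)
q_target = q0 + delta_q.squeeze(-1)                           # (B, dof)
```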
13 changes: 13 additions & 0 deletions mani_skill/agents/robots/panda/panda.py
@@ -122,6 +122,18 @@ def _controller_configs(self):
ee_link=self.ee_link_name,
urdf_path=self.urdf_path,
)
arm_pd_ee_pose = PDEEPoseControllerConfig(
joint_names=self.arm_joint_names,
pos_lower=None,
pos_upper=None,
stiffness=self.arm_stiffness,
damping=self.arm_damping,
force_limit=self.arm_force_limit,
ee_link=self.ee_link_name,
urdf_path=self.urdf_path,
use_delta=False,
normalize_action=False,
)

arm_pd_ee_target_delta_pos = deepcopy(arm_pd_ee_delta_pos)
arm_pd_ee_target_delta_pos.use_target = True
@@ -180,6 +192,7 @@ def _controller_configs(self):
pd_ee_delta_pose=dict(
arm=arm_pd_ee_delta_pose, gripper=gripper_pd_joint_pos
),
pd_ee_pose=dict(arm=arm_pd_ee_pose, gripper=gripper_pd_joint_pos),
# TODO(jigu): how to add boundaries for the following controllers
pd_joint_target_delta_pos=dict(
arm=arm_pd_joint_target_delta_pos, gripper=gripper_pd_joint_pos
50 changes: 48 additions & 2 deletions mani_skill/agents/robots/widowx/widowx.py
@@ -1,16 +1,62 @@
import numpy as np
import sapien
import torch

from mani_skill import ASSET_DIR
from mani_skill.agents.base_agent import BaseAgent
from mani_skill.agents.controllers import *
from mani_skill.agents.registration import register_agent
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import common
from mani_skill.utils.structs.actor import Actor


# TODO (stao) (xuanlin): model it properly based on real2sim
@register_agent(asset_download_ids=["widowx250s"])
class WidowX250S(BaseAgent):
uid = "widowx250s"
urdf_path = f"{ASSET_DIR}/robots/widowx250s/wx250s.urdf"
urdf_path = f"{ASSET_DIR}/robots/widowx/wx250s.urdf"
urdf_config = dict()

arm_joint_names = [
"waist",
"shoulder",
"elbow",
"forearm_roll",
"wrist_angle",
"wrist_rotate",
]
gripper_joint_names = ["left_finger", "right_finger"]

def _after_loading_articulation(self):
self.finger1_link = self.robot.links_map["left_finger_link"]
self.finger2_link = self.robot.links_map["right_finger_link"]

def is_grasping(self, object: Actor, min_force=0.5, max_angle=85):
"""Check if the robot is grasping an object
Args:
object (Actor): The object to check if the robot is grasping
min_force (float, optional): Minimum force before the robot is considered to be grasping the object in Newtons. Defaults to 0.5.
max_angle (int, optional): Maximum angle of contact to consider grasping. Defaults to 85.
"""
l_contact_forces = self.scene.get_pairwise_contact_forces(
self.finger1_link, object
)
r_contact_forces = self.scene.get_pairwise_contact_forces(
self.finger2_link, object
)
lforce = torch.linalg.norm(l_contact_forces, axis=1)
rforce = torch.linalg.norm(r_contact_forces, axis=1)

# direction to open the gripper
ldirection = self.finger1_link.pose.to_transformation_matrix()[..., :3, 1]
rdirection = -self.finger2_link.pose.to_transformation_matrix()[..., :3, 1]
langle = common.compute_angle_between(ldirection, l_contact_forces)
rangle = common.compute_angle_between(rdirection, r_contact_forces)
lflag = torch.logical_and(
lforce >= min_force, torch.rad2deg(langle) <= max_angle
)
rflag = torch.logical_and(
rforce >= min_force, torch.rad2deg(rangle) <= max_angle
)
return torch.logical_and(lflag, rflag)
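
A hypothetical sketch of how a task's success check might call `is_grasping` (the `evaluate` pattern follows ManiSkill task conventions; `self.carrot` is an illustrative actor name, not one from this commit):

```python
# Inside a hypothetical task class whose agent is a WidowX250S:
def evaluate(self):
    # success when both fingers touch the object with enough force and the
    # contact direction lies within max_angle of each finger's opening axis
    is_grasped = self.agent.is_grasping(self.carrot)
    return dict(success=is_grasped)
```

Note that the angle check matters: force magnitude alone cannot distinguish squeezing an object from merely pressing against it, which is why contacts must also point roughly along each finger's opening direction.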
9 changes: 7 additions & 2 deletions mani_skill/envs/sapien_env.py
@@ -303,6 +303,7 @@ def __init__(
torch.zeros(self.num_envs, device=self.device, dtype=torch.int32)
)
obs, _ = self.reset(seed=2022, options=dict(reconfigure=True))

self._init_raw_obs = common.to_cpu_tensor(obs)
"""the raw observation returned by the env.reset (a cpu torch tensor/dict of tensors). Useful for future observation wrappers to use to auto generate observation spaces"""
self._init_raw_state = common.to_cpu_tensor(self.get_state_dict())
@@ -549,9 +550,9 @@ def _get_obs_with_sensor_data(self, info: Dict, apply_texture_transforms: bool =
)

@property
-def robot_link_ids(self):
+def robot_link_names(self):
"""Get link ids for the robot. This is used for segmentation observations."""
-return self.agent.robot_link_ids
+return self.agent.robot_link_names

# -------------------------------------------------------------------------- #
# Reward mode
@@ -810,6 +811,10 @@ def reset(self, seed=None, options=None):
self.scene._gpu_apply_all()
self.scene.px.gpu_update_articulation_kinematics()
self.scene._gpu_fetch_all()

# we reset controllers here because some controllers depend on the agent/articulation qpos/poses
self.agent.controller.reset()

obs = self.get_obs()

return obs, dict(reconfigure=reconfigure)
2 changes: 1 addition & 1 deletion mani_skill/envs/tasks/control/cartpole.py
@@ -63,7 +63,7 @@ def _load_articulation(self):
assert self.robot is not None, f"Fail to load URDF/MJCF from {asset_path}"

# Cache robot link ids
-self.robot_link_ids = [link.name for link in self.robot.get_links()]
+self.robot_link_names = [link.name for link in self.robot.get_links()]


# @register_env("MS-CartPole-v1", max_episode_steps=500)