diff --git a/mani_skill2/agents/robots/__init__.py b/mani_skill2/agents/robots/__init__.py index db95d0511..f3db0e3e1 100644 --- a/mani_skill2/agents/robots/__init__.py +++ b/mani_skill2/agents/robots/__init__.py @@ -5,6 +5,7 @@ from .panda import Panda from .xarm import XArm7Ability from .xmate3 import Xmate3Robotiq +from .allegro_hand import AllegroHandRightTouch, AllegroHandRight, AllegroHandLeft ROBOTS = { "panda": Panda, @@ -15,6 +16,9 @@ # Dexterous Hand "dclaw": DClaw, "xarm7_ability": XArm7Ability, + "allegro_hand_right": AllegroHandRight, + "allegro_hand_left": AllegroHandLeft, + "allegro_hand_right_touch": AllegroHandRightTouch, # Locomotion "anymal-c": ANYmalC, } diff --git a/mani_skill2/agents/robots/allegro_hand/__init__.py b/mani_skill2/agents/robots/allegro_hand/__init__.py new file mode 100644 index 000000000..a5bdb3985 --- /dev/null +++ b/mani_skill2/agents/robots/allegro_hand/__init__.py @@ -0,0 +1,2 @@ +from .allegro import AllegroHandLeft, AllegroHandRight +from .allegro_touch import AllegroHandRightTouch diff --git a/mani_skill2/agents/robots/allegro_hand/allegro.py b/mani_skill2/agents/robots/allegro_hand/allegro.py new file mode 100644 index 000000000..e9cd288f4 --- /dev/null +++ b/mani_skill2/agents/robots/allegro_hand/allegro.py @@ -0,0 +1,151 @@ +from copy import deepcopy +from typing import List + +import sapien +import torch + +from mani_skill2 import PACKAGE_ASSET_DIR +from mani_skill2.agents.base_agent import BaseAgent +from mani_skill2.agents.controllers import * +from mani_skill2.utils.sapien_utils import ( + get_obj_by_name, +) +from mani_skill2.utils.sapien_utils import get_objs_by_names +from mani_skill2.utils.structs.pose import vectorize_pose + + +class AllegroHandRight(BaseAgent): + uid = "allegro_hand_right" + urdf_path = f"{PACKAGE_ASSET_DIR}/robots/allegro/allegro_hand_right_glb.urdf" + urdf_config = dict( + _materials=dict( + tip=dict(static_friction=2.0, dynamic_friction=1.0, restitution=0.0) + ), + link={ + "link_3.0_tip": dict( + material="tip", patch_radius=0.1, min_patch_radius=0.1 + ), + "link_7.0_tip": dict( + material="tip", patch_radius=0.1, min_patch_radius=0.1 + ), + "link_11.0_tip": dict( + material="tip", patch_radius=0.1, min_patch_radius=0.1 + ), + "link_15.0_tip": dict( + material="tip", patch_radius=0.1, min_patch_radius=0.1 + ), + }, + ) + sensor_configs = {} + + def __init__(self, *args, **kwargs): + self.joint_names = [ + "joint_0.0", + "joint_1.0", + "joint_2.0", + "joint_3.0", + "joint_4.0", + "joint_5.0", + "joint_6.0", + "joint_7.0", + "joint_8.0", + "joint_9.0", + "joint_10.0", + "joint_11.0", + "joint_12.0", + "joint_13.0", + "joint_14.0", + "joint_15.0", + ] + + self.joint_stiffness = 4e2 + self.joint_damping = 1e1 + self.joint_force_limit = 5e1 + + # Order: thumb finger, index finger, middle finger, ring finger + self.tip_link_names = [ + "link_15.0_tip", + "link_3.0_tip", + "link_7.0_tip", + "link_11.0_tip", + ] + + self.palm_link_name = "palm" + super().__init__(*args, **kwargs) + + def _after_init(self): + self.tip_links: List[sapien.Entity] = get_objs_by_names( + self.robot.get_links(), self.tip_link_names + ) + self.palm_link: sapien.Entity = get_obj_by_name( + self.robot.get_links(), self.palm_link_name + ) + + @property + def controller_configs(self): + # -------------------------------------------------------------------------- # + # Arm + # -------------------------------------------------------------------------- # + joint_pos = PDJointPosControllerConfig( + self.joint_names, + None, + None, + self.joint_stiffness, + self.joint_damping, + self.joint_force_limit, + normalize_action=False, + ) + joint_delta_pos = PDJointPosControllerConfig( + self.joint_names, + -0.1, + 0.1, + self.joint_stiffness, + self.joint_damping, + self.joint_force_limit, + use_delta=True, + ) + joint_target_delta_pos = deepcopy(joint_delta_pos) + joint_target_delta_pos.use_target = True + + controller_configs = dict( + pd_joint_delta_pos=joint_delta_pos, + pd_joint_pos=joint_pos, + pd_joint_target_delta_pos=joint_target_delta_pos, + ) + + # Make a deepcopy in case users modify any config + return deepcopy_dict(controller_configs) + + def get_proprioception(self): + """ + Get the proprioceptive state of the agent. + """ + obs = super().get_proprioception() + obs.update( + { + "palm_pose": self.palm_pose, + "tip_poses": self.tip_poses.reshape(-1, len(self.tip_links) * 7), + } + ) + + return obs + + @property + def tip_poses(self): + """ + Get the tip pose for each of the finger, four fingers in total + """ + tip_poses = [vectorize_pose(link.pose) for link in self.tip_links] + return torch.stack(tip_poses, dim=-2) + + @property + def palm_pose(self): + """ + Get the palm pose for allegro hand + """ + return vectorize_pose(self.palm_link.pose) + + +class AllegroHandLeft(AllegroHandRight): + uid = "allegro_hand_left" + urdf_path = f"{PACKAGE_ASSET_DIR}/robots/allegro/allegro_hand_left.urdf" diff --git a/mani_skill2/agents/robots/allegro_hand/allegro_touch.py b/mani_skill2/agents/robots/allegro_hand/allegro_touch.py new file mode 100644 index 000000000..f08c284bf --- /dev/null +++ b/mani_skill2/agents/robots/allegro_hand/allegro_touch.py @@ -0,0 +1,153 @@ +import itertools +from typing import List, Dict, Tuple, Optional + +import numpy as np +import sapien +import torch +from sapien import physx + +from mani_skill2 import PACKAGE_ASSET_DIR +from mani_skill2.agents.robots.allegro_hand.allegro import AllegroHandRight +from mani_skill2.utils.sapien_utils import ( + compute_total_impulse, + get_multiple_pairwise_contacts, + get_actors_contacts, +) +from mani_skill2.utils.sapien_utils import get_objs_by_names +from mani_skill2.utils.structs.actor import Actor + + +class AllegroHandRightTouch(AllegroHandRight): + uid = "allegro_hand_right_touch" + urdf_path = f"{PACKAGE_ASSET_DIR}/robots/allegro/variation/allegro_hand_right_fsr_simple.urdf" + + def __init__(self, *args, **kwargs): + # Order: thumb finger, index finger, middle finger, ring finger, from finger root to fingertip + self.finger_fsr_link_names = [ + # allegro thumb has a different hardware design compared with other fingers + "link_14.0_fsr", + "link_15.0_fsr", + "link_15.0_tip_fsr", + # the hardware design of index, middle and ring finger are the same + "link_1.0_fsr", + "link_2.0_fsr", + "link_3.0_tip_fsr", + "link_5.0_fsr", + "link_6.0_fsr", + "link_7.0_tip_fsr", + "link_9.0_fsr", + "link_10.0_fsr", + "link_11.0_tip_fsr", + ] + self.palm_fsr_link_names = [ + "link_base_fsr", + "link_0.0_fsr", + "link_4.0_fsr", + "link_8.0_fsr", + ] + + super().__init__(*args, **kwargs) + + self.pair_query: Dict[ + str, Tuple[physx.PhysxGpuContactPairImpulseQuery, Tuple[int, int, int]] + ] = dict() + self.body_query: Optional[ + Tuple[physx.PhysxGpuContactBodyImpulseQuery, Tuple[int, int, int]] + ] = None + + def _after_init(self): + super()._after_init() + self.fsr_links: List[Actor] = get_objs_by_names( + self.robot.get_links(), + self.palm_fsr_link_names + self.finger_fsr_link_names, + ) + + def get_fsr_obj_impulse(self, obj: Actor = None): + if physx.is_gpu_enabled(): + px: sapien.physx.PhysxGpuSystem = self.scene.px + # Create contact query if it is not existed + if obj.name not in self.pair_query: + bodies = list(zip(*[link._bodies for link in self.fsr_links])) + bodies = list(itertools.chain(*bodies)) + obj_bodies = [ + elem for item in obj._bodies for elem in itertools.repeat(item, 2) + ] + body_pairs = list(zip(bodies, obj_bodies)) + query = px.gpu_create_contact_pair_impulse_query(body_pairs) + self.pair_query[obj.name] = ( + query, + (len(obj._bodies), len(self.fsr_links), 3), + ) + + # Query contact buffer + query, contacts_shape = self.pair_query[obj.name] + px.gpu_query_contact_pair_impulses(query) + contacts = ( + query.cuda_impulses.torch() + .clone() + .reshape((len(self.fsr_links), *contacts_shape)) + ) # [n, 16, 3] + + return contacts + + else: + internal_fsr_links = [link._bodies[0].entity for link in self.fsr_links] + contacts = self.scene.get_contacts() + obj_contacts = get_multiple_pairwise_contacts( + contacts, obj._bodies[0].entity, internal_fsr_links + ) + sorted_contacts = [obj_contacts[link] for link in internal_fsr_links] + contact_forces = [ + compute_total_impulse(contact) for contact in sorted_contacts + ] + + return np.stack(contact_forces) + + def get_fsr_impulse(self): + if physx.is_gpu_enabled(): + px: sapien.physx.PhysxGpuSystem = self.scene.px + # Create contact query if it is not existed + if self.body_query is None: + # Convert the order of links so that the link from the same sub-scene will come together + # It makes life easier for reshape + bodies = list(zip(*[link._bodies for link in self.fsr_links])) + bodies = list(itertools.chain(*bodies)) + + query = px.gpu_create_contact_body_impulse_query(bodies) + self.body_query = ( + query, + (len(self.fsr_links[0]._bodies), len(self.fsr_links), 3), + ) + + # Query contact buffer + query, contacts_shape = self.body_query + px.gpu_query_contact_body_impulses(query) + contacts = ( + query.cuda_impulses.torch().clone().reshape(*contacts_shape) + ) # [n, 16, 3] + + return contacts + + else: + internal_fsr_links = [link._bodies[0].entity for link in self.fsr_links] + contacts = self.scene.get_contacts() + contact_map = get_actors_contacts(contacts, internal_fsr_links) + sorted_contacts = [contact_map[link] for link in internal_fsr_links] + contact_forces = [ + compute_total_impulse(contact) for contact in sorted_contacts + ] + + contact_impulse = torch.from_numpy( + np.stack(contact_forces)[None, ...] + ) # [1, 16, 3] + return contact_impulse + + def get_proprioception(self): + """ + Get the proprioceptive state of the agent. + """ + obs = super().get_proprioception() + fsr_impulse = self.get_fsr_impulse() + obs.update({"fsr_impulse": torch.linalg.norm(fsr_impulse, dim=-1)}) + + return obs diff --git a/mani_skill2/agents/robots/dclaw/dclaw.py b/mani_skill2/agents/robots/dclaw/dclaw.py index 9ededc392..77464e0a6 100644 --- a/mani_skill2/agents/robots/dclaw/dclaw.py +++ b/mani_skill2/agents/robots/dclaw/dclaw.py @@ -1,7 +1,6 @@ from copy import deepcopy from typing import List -import sapien import torch from mani_skill2 import PACKAGE_ASSET_DIR @@ -13,6 +12,8 @@ from mani_skill2.utils.sapien_utils import ( get_objs_by_names, ) +from mani_skill2.utils.structs.joint import Joint +from mani_skill2.utils.structs.link import Link from mani_skill2.utils.structs.pose import vectorize_pose @@ -53,10 +54,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def _after_init(self): - self.tip_links: List[sapien.Entity] = get_objs_by_names( + self.tip_links: List[Link] = get_objs_by_names( self.robot.get_links(), self.tip_link_names ) - self.root_joints = [ + self.root_joints: List[Joint] = [ self.robot.find_joint_by_name(n) for n in self.root_joint_names ] self.root_joint_indices = get_active_joint_indices( @@ -89,43 +90,10 @@ def controller_configs(self): joint_target_delta_pos = deepcopy(joint_delta_pos) joint_target_delta_pos.use_target = True - # PD joint velocity - pd_joint_vel = PDJointVelControllerConfig( - self.joint_names, - -1.0, - 1.0, - self.joint_damping, # this might need to be tuned separately - self.joint_force_limit, - ) - - # PD joint position and velocity - joint_pos_vel = PDJointPosVelControllerConfig( - self.joint_names, - None, - None, - self.joint_stiffness, - self.joint_damping, - self.joint_force_limit, - normalize_action=False, - ) - joint_delta_pos_vel = PDJointPosVelControllerConfig( - self.joint_names, - -0.1, - 0.1, - self.joint_stiffness, - self.joint_damping, - self.joint_force_limit, - use_delta=True, - ) - controller_configs = dict( pd_joint_delta_pos=dict(joint=joint_delta_pos), pd_joint_pos=dict(joint=joint_pos), pd_joint_target_delta_pos=dict(joint=joint_target_delta_pos), - # Caution to use the following controllers - pd_joint_vel=dict(joint=pd_joint_vel), - pd_joint_pos_vel=dict(joint=joint_pos_vel), - pd_joint_delta_pos_vel=dict(joint=joint_delta_pos_vel), ) # Make a deepcopy in case users modify any config @@ -136,7 +104,7 @@ def get_proprioception(self): Get the proprioceptive state of the agent. """ obs = super().get_proprioception() - obs.update({"tip_poses": self.tip_poses.view(-1, 21)}) + obs.update({"tip_poses": self.tip_poses.view(-1, len(self.tip_links) * 7)}) return obs @@ -146,4 +114,4 @@ def tip_poses(self): Get the tip pose for each of the finger, three fingers in total """ tip_poses = [vectorize_pose(link.pose) for link in self.tip_links] - return torch.stack(tip_poses, dim=-1) + return torch.stack(tip_poses, dim=-2) diff --git a/mani_skill2/envs/tasks/__init__.py b/mani_skill2/envs/tasks/__init__.py index 23f4a2446..9606f63b9 100644 --- a/mani_skill2/envs/tasks/__init__.py +++ b/mani_skill2/envs/tasks/__init__.py @@ -10,3 +10,4 @@ from .stack_cube import StackCubeEnv from .two_robot_pick_cube import TwoRobotPickCube from .two_robot_stack_cube import TwoRobotStackCube +from .dexterity import RotateValveEnv, RotateSingleObjectInHand diff --git a/mani_skill2/envs/tasks/dexterity/__init__.py b/mani_skill2/envs/tasks/dexterity/__init__.py index 5e02a351e..118c0bd67 100644 --- a/mani_skill2/envs/tasks/dexterity/__init__.py +++ b/mani_skill2/envs/tasks/dexterity/__init__.py @@ -1,2 +1,3 @@ # isort: off from .rotate_valve import RotateValveEnv +from .rotate_single_object_in_hand import RotateSingleObjectInHand diff --git a/mani_skill2/envs/tasks/dexterity/rotate_single_object_in_hand.py b/mani_skill2/envs/tasks/dexterity/rotate_single_object_in_hand.py new file mode 100644 index 000000000..58aeac9d2 --- /dev/null +++ b/mani_skill2/envs/tasks/dexterity/rotate_single_object_in_hand.py @@ -0,0 +1,346 @@ +from collections import OrderedDict +from typing import Union, Dict, Any, List + +import numpy as np +import sapien +import torch +import torch.nn.functional as F + +from mani_skill2.agents.robots import ( + AllegroHandRightTouch, +) +from mani_skill2.envs.sapien_env import BaseEnv +from mani_skill2.sensors.camera import CameraConfig +from mani_skill2.utils.building.actors import ( + build_cube, + build_actor_ycb, + MODEL_DBS, + _load_ycb_dataset, +) +from mani_skill2.utils.geometry.rotation_conversions import quaternion_apply +from mani_skill2.utils.registration import register_env +from mani_skill2.utils.sapien_utils import look_at +from mani_skill2.utils.scene_builder.table.table_scene_builder import TableSceneBuilder +from mani_skill2.utils.structs.actor import Actor +from mani_skill2.utils.structs.pose import Pose, vectorize_pose +from mani_skill2.utils.structs.types import Array + + +@register_env("RotateSingleObjectInHand-v1", max_episode_steps=300) +class RotateSingleObjectInHand(BaseEnv): + agent: Union[AllegroHandRightTouch] + _clearance = 0.003 + hand_init_height = 0.25 + + def __init__( + self, + *args, + robot_init_qpos_noise=0.02, + obj_init_pos_noise=0.02, + difficulty_level: int = -1, + **kwargs, + ): + self.robot_init_qpos_noise = robot_init_qpos_noise + self.obj_init_pos_noise = obj_init_pos_noise + self.obj_heights: torch.Tensor = torch.Tensor() + _load_ycb_dataset() + + if ( + not isinstance(difficulty_level, int) + or difficulty_level >= 4 + or difficulty_level < 0 + ): + raise ValueError( + f"Difficulty level must be a int within 0-3, but get {difficulty_level}" + ) + self.difficulty_level = difficulty_level + + num_envs = kwargs.get("num_envs") + if num_envs > 1: + sapien.physx.set_gpu_memory_config( + max_rigid_contact_count=num_envs * max(1024, num_envs) * 8, + max_rigid_patch_count=num_envs * max(1024, num_envs) * 2, + found_lost_pairs_capacity=2**26, + ) + + super().__init__(*args, robot_uids="allegro_hand_right_touch", **kwargs) + + with torch.device(self.device): + self.prev_unit_vector = torch.zeros((self.num_envs, 3)) + self.cum_rotation_angle = torch.zeros((self.num_envs,)) + + def _register_sensors(self): + pose = look_at(eye=[0.15, 0, 0.45], target=[-0.1, 0, self.hand_init_height]) + return [ + CameraConfig("base_camera", pose.p, pose.q, 128, 128, np.pi / 2, 0.01, 10) + ] + + def _register_human_render_cameras(self): + pose = look_at([0.2, 0.4, 0.4], [0.0, 0.0, 0.1]) + return CameraConfig("render_camera", pose.p, pose.q, 512, 512, 1, 0.01, 10) + + def _load_actors(self): + self.table_scene = TableSceneBuilder( + env=self, robot_init_qpos_noise=self.robot_init_qpos_noise + ) + self.table_scene.build() + + obj_heights = [] + if self.difficulty_level == 0: + self.obj = build_cube( + self._scene, + half_size=0.04, + color=np.array([255, 255, 255, 255]) / 255, + name="cube", + body_type="dynamic", + ) + obj_heights.append(0.03) + elif self.difficulty_level == 1: + half_sizes = (torch.randn(self.num_envs) * 0.1 + 1) * 0.04 + actors: List[Actor] = [] + for i, half_size in enumerate(half_sizes): + builder = self._scene.create_actor_builder() + builder.add_box_collision( + half_size=[half_size] * 3, + ) + builder.add_box_visual( + half_size=[half_size] * 3, + material=sapien.render.RenderMaterial( + base_color=np.array([255, 255, 255, 255]) / 255, + ), + ) + scene_mask = np.zeros(self.num_envs, dtype=bool) + scene_mask[i] = True + builder.set_scene_mask(scene_mask) + actors.append(builder.build(name=f"cube-{i}")) + obj_heights.append(half_size) + self.obj = Actor.merge(actors, name="cube") + elif self.difficulty_level >= 2: + all_model_ids = np.array(list(MODEL_DBS["YCB"]["model_data"].keys())) + rand_idx = torch.randperm(len(all_model_ids)) + model_ids = all_model_ids[rand_idx] + model_ids = np.concatenate( + [model_ids] * np.ceil(self.num_envs / len(all_model_ids)).astype(int) + )[: self.num_envs] + actors: List[Actor] = [] + for i, model_id in enumerate(model_ids): + builder, obj_height = build_actor_ycb( + model_id, self._scene, name=model_id, return_builder=True + ) + scene_mask = np.zeros(self.num_envs, dtype=bool) + scene_mask[i] = True + builder.set_scene_mask(scene_mask) + actors.append(builder.build(name=f"{model_id}-{i}")) + obj_heights.append(obj_height) + self.obj = Actor.merge(actors, name="ycb_object") + else: + raise ValueError( + f"Difficulty level must be a int within 0-4, but get {self.difficulty_level}" + ) + + self.obj_heights = torch.from_numpy(np.array(obj_heights)).to(self.device) + + def _initialize_actors(self, env_idx: torch.Tensor): + with torch.device(self.device): + b = len(env_idx) + # Initialize object pose + self.table_scene.initialize() + pose = self.obj.pose + new_pos = torch.randn((b, 3)) * self.obj_init_pos_noise + # hand_init_height is robot hand position while the 0.03 is a margin to ensure + new_pos[:, 2] = ( + torch.abs(new_pos[:, 2]) + self.hand_init_height + self.obj_heights + ) + pose.raw_pose[:, 0:3] = new_pos + pose.raw_pose[:, 3:7] = torch.tensor([[1, 0, 0, 0]]) + self.obj.set_pose(pose) + + # Initialize object axis + if self.difficulty_level <= 2: + axis = torch.ones((b,), dtype=torch.long) * 2 + else: + axis = torch.randint(0, 3, (b,), dtype=torch.long) + self.rot_dir = F.one_hot(axis, num_classes=3) + + # Sample a unit vector on the tangent plane of rotating axis + vector_axis = (axis + 1) % 3 + vector = F.one_hot(vector_axis, num_classes=3) + + # Initialize task related cache + self.unit_vector = vector + self.prev_unit_vector = vector.clone() + self.success_threshold = torch.pi * 4 + self.cum_rotation_angle = torch.zeros((b,)) + + # Controller parameters + stiffness = torch.tensor(self.agent.controller.config.stiffness) + damping = torch.tensor(self.agent.controller.config.damping) + force_limit = torch.tensor(self.agent.controller.config.force_limit) + self.controller_param = ( + stiffness.expand(b, self.agent.robot.dof[0]), + damping.expand(b, self.agent.robot.dof[0]), + force_limit.expand(b, self.agent.robot.dof[0]), + ) + + def _initialize_agent(self, env_idx: torch.Tensor): + with torch.device(self.device): + b = len(env_idx) + dof = self.agent.robot.dof + if isinstance(dof, torch.Tensor): + dof = dof[0] + init_qpos = torch.zeros((b, dof)) + self.agent.reset(init_qpos) + self.agent.robot.set_pose( + Pose.create_from_pq( + torch.tensor([0.0, 0, self.hand_init_height]), + torch.tensor([-0.707, 0, 0.707, 0]), + ) + ) + + def _get_obs_extra(self, info: Dict): + with torch.device(self.device): + obs = OrderedDict(rotate_dir=self.rot_dir) + if self._obs_mode in ["state", "state_dict"]: + obs.update( + obj_pose=vectorize_pose(self.obj.pose), + obj_tip_vec=info["obj_tip_vec"].view(self.num_envs, 12), + ) + return obs + + def evaluate(self, **kwargs) -> dict: + with torch.device(self.device): + # 1. rotation angle + obj_pose = self.obj.pose + new_unit_vector = quaternion_apply(obj_pose.q, self.unit_vector) + new_unit_vector -= ( + torch.sum(new_unit_vector * self.rot_dir, dim=-1, keepdim=True) + * self.rot_dir + ) + new_unit_vector = new_unit_vector / torch.linalg.norm( + new_unit_vector, dim=-1, keepdim=True + ) + angle = torch.acos( + torch.clip( + torch.sum(new_unit_vector * self.prev_unit_vector, dim=-1), 0, 1 + ) + ) + # We do not expect the rotation angle for a single step to be so large + angle = torch.clip(angle, -torch.pi / 20, torch.pi / 20) + self.prev_unit_vector = new_unit_vector + + # 2. object velocity + obj_vel = torch.linalg.norm(self.obj.get_linear_velocity(), dim=-1) + + # 3. object falling + obj_fall = (obj_pose.p[:, 2] < self.hand_init_height - 0.05).to(torch.bool) + + # 4. finger object distance + tip_poses = [vectorize_pose(link.pose) for link in self.agent.tip_links] + tip_poses = torch.stack(tip_poses, dim=1) # (b, 4, 7) + obj_tip_vec = tip_poses[..., :3] - obj_pose.p[:, None, :] # (b, 4, 3) + obj_tip_dist = torch.linalg.norm(obj_tip_vec, dim=-1) # (b, 4) + + # 5. cum rotation angle + self.cum_rotation_angle += angle + success = self.cum_rotation_angle > self.success_threshold + + # 6. controller effort + qpos_target = self.agent.controller._target_qpos + qpos_error = qpos_target - self.agent.robot.qpos + qvel = self.agent.robot.qvel + qf = qpos_error * self.controller_param[0] - qvel * self.controller_param[1] + qf = torch.clip(qf, -self.controller_param[2], self.controller_param[2]) + power = torch.sum(qf * qvel, dim=-1) + + return dict( + rotation_angle=angle, + obj_vel=obj_vel, + obj_fall=obj_fall, + obj_tip_vec=obj_tip_vec, + obj_tip_dist=obj_tip_dist, + success=success, + qf=qf, + power=power, + fail=obj_fall, + ) + + def compute_dense_reward(self, obs: Any, action: Array, info: Dict): + # 1. rotation reward + angle = info["rotation_angle"] + reward = 20 * angle + + # 2. velocity penalty + obj_vel = info["obj_vel"] + reward += -0.1 * obj_vel + + # 3. falling penalty + obj_fall = info["obj_fall"] + reward += -50.0 * obj_fall + + # 4. effort penalty + power = torch.abs(info["power"]) + reward += -0.0003 * power + + # 5. torque penalty + qf = info["qf"] + qf_norm = torch.linalg.norm(qf, dim=-1) + reward += -0.0003 * qf_norm + + # 6. finger object distance reward + obj_tip_dist = info["obj_tip_dist"] + distance_rew = 0.1 / (0.02 + 4 * obj_tip_dist) + reward += torch.mean(torch.clip(distance_rew, 0, 1), dim=-1) + + return reward + + def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict): + # this should be equal to compute_dense_reward / max possible reward + return self.compute_dense_reward(obs=obs, action=action, info=info) / 4.0 + + +@register_env("RotateSingleObjectInHandLevel0-v1", max_episode_steps=300) +class RotateSingleObjectInHandLevel0(RotateSingleObjectInHand): + def __init__(self, *args, **kwargs): + super().__init__( + *args, + robot_init_qpos_noise=0.02, + obj_init_pos_noise=0.02, + difficulty_level=0, + **kwargs, + ) + + +@register_env("RotateSingleObjectInHandLevel1-v1", max_episode_steps=300) +class RotateSingleObjectInHandLevel1(RotateSingleObjectInHand): + def __init__(self, *args, **kwargs): + super().__init__( + *args, + robot_init_qpos_noise=0.02, + obj_init_pos_noise=0.02, + difficulty_level=1, + **kwargs, + ) + + +@register_env("RotateSingleObjectInHandLevel2-v1", max_episode_steps=300) +class RotateSingleObjectInHandLevel2(RotateSingleObjectInHand): + def __init__(self, *args, **kwargs): + super().__init__( + *args, + robot_init_qpos_noise=0.02, + obj_init_pos_noise=0.02, + difficulty_level=2, + **kwargs, + ) + + +@register_env("RotateSingleObjectInHandLevel3-v1", max_episode_steps=300) +class RotateSingleObjectInHandLevel3(RotateSingleObjectInHand): + def __init__(self, *args, **kwargs): + super().__init__( + *args, + robot_init_qpos_noise=0.02, + obj_init_pos_noise=0.02, + difficulty_level=3, + **kwargs, + ) diff --git a/mani_skill2/utils/sapien_utils.py b/mani_skill2/utils/sapien_utils.py index 3a57f9a8c..dceca5328 100644 --- a/mani_skill2/utils/sapien_utils.py +++ b/mani_skill2/utils/sapien_utils.py @@ -349,6 +349,30 @@ def get_pairwise_contacts( return pairwise_contacts +def get_multiple_pairwise_contacts( + contacts: List[physx.PhysxContact], + actor0: sapien.Entity, + actor1_list: List[sapien.Entity], +) -> Dict[sapien.Entity, List[Tuple[physx.PhysxContact, bool]]]: + """ + Given a list of contacts, return the dict of contacts involving the one actor and actors + This function is used to avoid double for-loop when using `get_pairwise_contacts` with multiple actors + """ + pairwise_contacts = {actor: [] for actor in actor1_list} + for contact in contacts: + if ( + contact.bodies[0].entity == actor0 + and contact.bodies[1].entity in actor1_list + ): + pairwise_contacts[contact.bodies[1].entity].append((contact, True)) + elif ( + contact.bodies[0].entity in actor1_list + and contact.bodies[1].entity == actor0 + ): + pairwise_contacts[contact.bodies[0].entity].append((contact, False)) + return pairwise_contacts + + def compute_total_impulse(contact_infos: List[Tuple[physx.PhysxContact, bool]]): total_impulse = np.zeros(3) for contact, flag in contact_infos: @@ -378,6 +402,21 @@ def get_actor_contacts( return entity_contacts +def get_actors_contacts( + contacts: List[physx.PhysxContact], actors: List[sapien.Entity] +) -> Dict[sapien.Entity, List[Tuple[physx.PhysxContact, bool]]]: + """ + This function is used to avoid double for-loop when using `get_actor_contacts` with multiple actors + """ + entity_contacts = {actor: [] for actor in actors} + for contact in contacts: + if contact.bodies[0].entity in actors: + entity_contacts[contact.bodies[0].entity].append((contact, True)) + elif contact.bodies[1].entity in actors: + entity_contacts[contact.bodies[1].entity].append((contact, False)) + return entity_contacts + + def get_articulation_contacts( contacts: List[physx.PhysxContact], articulation: physx.PhysxArticulation,