Skip to content

Commit

Permalink
Fix: everything in PickSequentialTask in gpu mem
Browse files Browse the repository at this point in the history
  • Loading branch information
arth-shukla committed Mar 2, 2024
1 parent af02a7c commit 9893a01
Showing 1 changed file with 88 additions and 88 deletions.
176 changes: 88 additions & 88 deletions mani_skill2/envs/scenes/tasks/pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,91 +131,92 @@ def _get_navigable_spawn_positions_with_rots_and_dists(self, center_x, center_y)


def reconfigure(self):
# run reconfiguration
super().reconfigure()

self.scene_builder.initialize(torch.arange(self.num_envs))

if physx.is_gpu_enabled():
self._scene._gpu_apply_all()
self._scene.px.gpu_update_articulation_kinematics()
self._scene._gpu_fetch_all()

# links and entities for force tracking
force_rew_ignore_links = [
self.agent.finger1_link, self.agent.finger2_link, self.agent.tcp,
]
self.force_articulation_link_ids = [
link.name for link in self.agent.robot.get_links() if link not in force_rew_ignore_links
]

# NOTE (arth): targ obj should be same merged actor
obj = self.subtask_objs[0]

spawn_loc_rots = []
spawn_dists = []
for env_idx in range(self.num_envs):
center = obj.pose.p[env_idx, :2]
slr, dists = self._get_navigable_spawn_positions_with_rots_and_dists(
center[0], center[1]
)
spawn_loc_rots.append(slr)
spawn_dists.append(dists)

num_spawn_loc_rots = torch.tensor([len(slr) for slr in spawn_loc_rots])
spawn_loc_rots = pad_sequence(spawn_loc_rots, batch_first=True, padding_value=0).transpose(1, 0)
spawn_dists = pad_sequence(spawn_dists, batch_first=True, padding_value=0).transpose(1, 0)

qpos = torch.tensor(
self.agent.RESTING_QPOS[..., None].repeat(self.num_envs, axis=-1).transpose(1, 0)
).float()
accept_spawn_loc_rots = [[]] * self.num_envs
accept_dists = [[]] * self.num_envs
bounding_box_corners = [
torch.tensor([dx, dy, 0]) for dx, dy in itertools.product([0.1, -0.1], [0.1, -0.1])
]
for slr_num, (slrs, dists) in tqdm(
enumerate(zip(spawn_loc_rots, spawn_dists)), total=spawn_loc_rots.size(0)
):

slrs_within_range = slr_num < num_spawn_loc_rots
robot_force = torch.zeros(self.num_envs)

for shift in bounding_box_corners:
shifted_slrs = slrs + shift

self.agent.controller.reset()
qpos[..., 2] = shifted_slrs[..., 2]
self.agent.reset(qpos)

# ad-hoc use z-rot dim a z-height dim, set using default setting
shifted_slrs[..., 2] = self.agent.robot.pose.p[..., 2]
self.agent.robot.set_pose(Pose.create_from_pq(p=shifted_slrs.float()))

if physx.is_gpu_enabled():
self._scene._gpu_apply_all()
self._scene.px.gpu_update_articulation_kinematics()
self._scene._gpu_fetch_all()

self._scene.step()
with torch.device(self.device):
# run reconfiguration
super().reconfigure()

self.scene_builder.initialize(torch.arange(self.num_envs))

if physx.is_gpu_enabled():
self._scene._gpu_apply_all()
self._scene.px.gpu_update_articulation_kinematics()
self._scene._gpu_fetch_all()

# links and entities for force tracking
force_rew_ignore_links = [
self.agent.finger1_link, self.agent.finger2_link, self.agent.tcp,
]
self.force_articulation_link_ids = [
link.name for link in self.agent.robot.get_links() if link not in force_rew_ignore_links
]

# NOTE (arth): targ obj should be same merged actor
obj = self.subtask_objs[0]

spawn_loc_rots = []
spawn_dists = []
for env_idx in range(self.num_envs):
center = obj.pose.p[env_idx, :2]
slr, dists = self._get_navigable_spawn_positions_with_rots_and_dists(
center[0], center[1]
)
spawn_loc_rots.append(slr)
spawn_dists.append(dists)

num_spawn_loc_rots = torch.tensor([len(slr) for slr in spawn_loc_rots])
spawn_loc_rots = pad_sequence(spawn_loc_rots, batch_first=True, padding_value=0).transpose(1, 0)
spawn_dists = pad_sequence(spawn_dists, batch_first=True, padding_value=0).transpose(1, 0)

qpos = torch.tensor(
self.agent.RESTING_QPOS[..., None].repeat(self.num_envs, axis=-1).transpose(1, 0)
).float()
accept_spawn_loc_rots = [[]] * self.num_envs
accept_dists = [[]] * self.num_envs
bounding_box_corners = [
torch.tensor([dx, dy, 0]) for dx, dy in itertools.product([0.1, -0.1], [0.1, -0.1])
]
for slr_num, (slrs, dists) in tqdm(
enumerate(zip(spawn_loc_rots, spawn_dists)), total=spawn_loc_rots.size(0)
):

slrs_within_range = slr_num < num_spawn_loc_rots
robot_force = torch.zeros(self.num_envs)

for shift in bounding_box_corners:
shifted_slrs = slrs + shift

robot_force += self.agent.robot.get_net_contact_forces(
self.force_articulation_link_ids
).norm(dim=-1).sum(dim=-1).to(torch.device("cpu"))

for i in torch.where(slrs_within_range & (robot_force < 1e-3))[0]:
accept_spawn_loc_rots[i].append(slrs[i].cpu().numpy().tolist())
accept_dists[i].append(dists[i].cpu().numpy().tolist())


self.num_spawn_loc_rots = torch.tensor([len(x) for x in accept_spawn_loc_rots])
self.spawn_loc_rots = pad_sequence([
torch.tensor(x) for x in accept_spawn_loc_rots
], batch_first=True, padding_value=0,)

self.closest_spawn_loc_rots = torch.stack([
self.spawn_loc_rots[i][torch.argmin(torch.tensor(x))] for i, x in enumerate(accept_dists)
], dim=0)
self.agent.controller.reset()
qpos[..., 2] = shifted_slrs[..., 2]
self.agent.reset(qpos)

# ad-hoc use z-rot dim a z-height dim, set using default setting
shifted_slrs[..., 2] = self.agent.robot.pose.p[..., 2]
self.agent.robot.set_pose(Pose.create_from_pq(p=shifted_slrs.float()))

if physx.is_gpu_enabled():
self._scene._gpu_apply_all()
self._scene.px.gpu_update_articulation_kinematics()
self._scene._gpu_fetch_all()

self._scene.step()

robot_force += self.agent.robot.get_net_contact_forces(
self.force_articulation_link_ids
).norm(dim=-1).sum(dim=-1)

for i in torch.where(slrs_within_range & (robot_force < 1e-3))[0]:
accept_spawn_loc_rots[i].append(slrs[i].cpu().numpy().tolist())
accept_dists[i].append(dists[i].cpu().numpy().tolist())


self.num_spawn_loc_rots = torch.tensor([len(x) for x in accept_spawn_loc_rots])
self.spawn_loc_rots = pad_sequence([
torch.tensor(x) for x in accept_spawn_loc_rots
], batch_first=True, padding_value=0,)

self.closest_spawn_loc_rots = torch.stack([
self.spawn_loc_rots[i][torch.argmin(torch.tensor(x))] for i, x in enumerate(accept_dists)
], dim=0)

# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down Expand Up @@ -248,12 +249,11 @@ def _initialize_agent(self, env_idx):
# NOTE (arth): it is assumed that scene builder spawns agent with some qpos
qpos = self.agent.robot.get_qpos()

cpu_device = torch.device("cpu")
if self.randomize_loc:
idxs = torch.tensor([
torch.randint(max_idx.item(), (1,), device=cpu_device) for max_idx in self.num_spawn_loc_rots
], device=cpu_device)
loc_rot = self.spawn_loc_rots[torch.arange(self.num_envs, device=cpu_device), idxs].to(self.device)
torch.randint(max_idx.item(), (1,)) for max_idx in self.num_spawn_loc_rots
])
loc_rot = self.spawn_loc_rots[torch.arange(self.num_envs), idxs].to(self.device)
else:
loc_rot = self.closest_spawn_loc_rots.to(self.device)

Expand Down

0 comments on commit 9893a01

Please sign in to comment.