Add the new reward for gait, which is based on the feet body name
farbod-farshidian committed Jun 3, 2024
1 parent ff651f8 commit 9a9a311
Showing 3 changed files with 85 additions and 40 deletions.
@@ -11,8 +11,8 @@
class LinearInterpolation:
"""Linearly interpolates a sampled scalar function ``y = f(x)`` where :math:`f: R -> R`.
It assumes that the function's domain, X, is sampled in an ascending order. For the query points out of
the sampling range of X, the class does a zero-order-hold extrapolation based on the boundary values.
It assumes that the function's domain, X, is sampled in an ascending order. For the query points out of the
sampling range of X, the class does a zero-order-hold extrapolation based on the boundary values.
"""

def __init__(self, x: torch.Tensor, y: torch.Tensor, device: str):
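For readers skimming the diff, the behaviour described in the docstring above can be reproduced in a few lines of PyTorch. The snippet below is a minimal standalone sketch, not the class's actual implementation or API (which is truncated in this hunk); it only assumes that x is sampled in ascending order:

import torch

# Linear interpolation with zero-order-hold extrapolation at the boundaries (illustration only).
def interp_with_zoh(x: torch.Tensor, y: torch.Tensor, x_query: torch.Tensor) -> torch.Tensor:
    # clamp queries into the sampled domain, so out-of-range points take the boundary values
    xq = torch.clamp(x_query, min=x[0], max=x[-1])
    # index of the right-hand neighbour for each query point
    idx = torch.searchsorted(x, xq).clamp(min=1, max=len(x) - 1)
    x0, x1 = x[idx - 1], x[idx]
    y0, y1 = y[idx - 1], y[idx]
    t = (xq - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)

x = torch.tensor([0.0, 1.0, 2.0])
y = torch.tensor([0.0, 10.0, 0.0])
print(interp_with_zoh(x, y, torch.tensor([-1.0, 0.5, 3.0])))  # tensor([0., 5., 0.])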
@@ -212,9 +212,14 @@ class SpotRewardsCfg:
},
)
gait = RewardTermCfg(
func=spot_mdp.gait_reward,
func=spot_mdp.GaitReward,
weight=10.0,
params={"std": 0.1, "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*_foot")},
params={
"std": 0.1,
"max_err": 0.2,
"synced_feet_pair_names": (("fl_foot", "hr_foot"), ("fr_foot", "hl_foot")),
"sensor_cfg": SceneEntityCfg("contact_forces"),
},
)

# -- penalties
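As a side note on the new params in the gait term above: pairing ("fl_foot", "hr_foot") with ("fr_foot", "hl_foot") synchronizes diagonal feet, i.e. a trot. Under the same GaitReward interface, other two-pair quadruped gaits would be selected purely by the pairing. The tuples below are an illustrative sketch, not part of this commit, and assume the robot keeps the fl/fr/hl/hr_foot body names used above:

# Trot (as configured in this commit): diagonal feet move together.
trot_pairs = (("fl_foot", "hr_foot"), ("fr_foot", "hl_foot"))
# Pace: feet on the same side move together.
pace_pairs = (("fl_foot", "hl_foot"), ("fr_foot", "hr_foot"))
# Bound: the front pair and the hind pair each move together.
bound_pairs = (("fl_foot", "fr_foot"), ("hl_foot", "hr_foot"))

# e.g. switching to a pacing gait would only change the params entry:
# "synced_feet_pair_names": pace_pairs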
@@ -10,11 +10,12 @@
from typing import TYPE_CHECKING

from omni.isaac.lab.assets import Articulation, RigidObject
from omni.isaac.lab.managers import SceneEntityCfg
from omni.isaac.lab.managers import ManagerTermBase, SceneEntityCfg
from omni.isaac.lab.sensors import ContactSensor

if TYPE_CHECKING:
from omni.isaac.lab.envs import ManagerBasedRLEnv
from omni.isaac.lab.managers import RewardTermCfg

# -- Task Rewards

@@ -62,41 +63,80 @@ def base_linear_velocity_reward(
return torch.exp(-lin_vel_error / std) * velocity_scaling_multiple


# ! need to finalize logic, params, and docstring
def gait_reward(env: ManagerBasedRLEnv, sensor_cfg: SceneEntityCfg, std: float) -> torch.Tensor:
"""Penalize ..."""
# extract the used quantities (to enable type-hinting)
contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
if contact_sensor.cfg.track_air_time is False:
raise RuntimeError("Activate ContactSensor's track_air_time!")
# compute the reward
air_time = contact_sensor.data.current_air_time[:, sensor_cfg.body_ids]
contact_time = contact_sensor.data.current_contact_time[:, sensor_cfg.body_ids]

max_err = 0.2
indices_0 = [0, 1]
indices_1 = [2, 3]
cmd = torch.norm(env.command_manager.get_command("base_velocity"), dim=1)
asym_err_0 = torch.clip(
torch.square(air_time[:, indices_0[0]] - contact_time[:, indices_0[1]]), max=max_err**2
) + torch.clip(torch.square(contact_time[:, indices_0[0]] - air_time[:, indices_0[1]]), max=max_err**2)
asym_err_1 = torch.clip(
torch.square(air_time[:, indices_1[0]] - contact_time[:, indices_1[1]]), max=max_err**2
) + torch.clip(torch.square(contact_time[:, indices_1[0]] - air_time[:, indices_1[1]]), max=max_err**2)
asym_err_2 = torch.clip(
torch.square(air_time[:, indices_0[0]] - contact_time[:, indices_1[0]]), max=max_err**2
) + torch.clip(torch.square(contact_time[:, indices_0[0]] - air_time[:, indices_1[0]]), max=max_err**2)
asym_err_3 = torch.clip(
torch.square(air_time[:, indices_0[1]] - contact_time[:, indices_1[1]]), max=max_err**2
) + torch.clip(torch.square(contact_time[:, indices_0[1]] - air_time[:, indices_1[1]]), max=max_err**2)
sym_err_0 = torch.clip(
torch.square(air_time[:, indices_0[0]] - air_time[:, indices_1[1]]), max=max_err**2
) + torch.clip(torch.square(contact_time[:, indices_0[0]] - contact_time[:, indices_1[1]]), max=max_err**2)
sym_err_1 = torch.clip(
torch.square(air_time[:, indices_0[1]] - air_time[:, indices_1[0]]), max=max_err**2
) + torch.clip(torch.square(contact_time[:, indices_0[1]] - contact_time[:, indices_1[0]]), max=max_err**2)
gait_err = asym_err_0 + asym_err_1 + sym_err_0 + sym_err_1 + asym_err_2 + asym_err_3
return torch.where(cmd > 0.0, torch.exp(-gait_err / std), 0.0)
class GaitReward(ManagerTermBase):
"""Gait enforcing reward term for quadrupeds.
This reward penalizes contact timing differences between selected foot pairs defined in :attr:`synced_feet_pair_names`
to bias the policy towards a desired gait, i.e trotting, bounding, or pacing. Note that this reward is only for
quadrupedal gaits with two pairs of synchronized feet.
"""

def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv):
"""Initialize the term.
Args:
cfg: The configuration of the reward.
env: The RL environment instance.
"""
super().__init__(cfg, env)
self.std: float = cfg.params["std"]
self.max_err: float = cfg.params["max_err"]
self.contact_sensor: ContactSensor = env.scene.sensors[cfg.params["sensor_cfg"].name]
# match foot body names with corresponding foot body ids
synced_feet_pair_names = cfg.params["synced_feet_pair_names"]
if (
len(synced_feet_pair_names) != 2
or len(synced_feet_pair_names[0]) != 2
or len(synced_feet_pair_names[1]) != 2
):
raise ValueError("This reward only supports gaits with two pairs of synchronized feet, like trotting.")
synced_feet_pair_0 = self.contact_sensor.find_bodies(synced_feet_pair_names[0])[0]
synced_feet_pair_1 = self.contact_sensor.find_bodies(synced_feet_pair_names[1])[0]
self.synced_feet_pairs = [synced_feet_pair_0, synced_feet_pair_1]

def __call__(self, env: ManagerBasedRLEnv, std, max_err, synced_feet_pair_names, sensor_cfg) -> torch.Tensor:
"""Compute the reward.
This reward is defined as a multiplication between six terms where two of them enforce pair feet
being in sync and the other four rewards if all the other remaining pairs are out of sync
Args:
env: The RL environment instance.
Returns:
The reward value.
"""
# for synchronous feet, the contact (air) times of two feet should match
sync_reward_0 = self._sync_reward_func(self.synced_feet_pairs[0][0], self.synced_feet_pairs[0][1])
sync_reward_1 = self._sync_reward_func(self.synced_feet_pairs[1][0], self.synced_feet_pairs[1][1])
sync_reward = sync_reward_0 * sync_reward_1
# for asynchronous feet, the contact time of one foot should match the air time of the other one
async_reward_0 = self._async_reward_func(self.synced_feet_pairs[0][0], self.synced_feet_pairs[1][0])
async_reward_1 = self._async_reward_func(self.synced_feet_pairs[0][1], self.synced_feet_pairs[1][1])
async_reward_2 = self._async_reward_func(self.synced_feet_pairs[0][0], self.synced_feet_pairs[1][1])
async_reward_3 = self._async_reward_func(self.synced_feet_pairs[1][0], self.synced_feet_pairs[0][1])
async_reward = async_reward_0 * async_reward_1 * async_reward_2 * async_reward_3
# only enforce gait if cmd > 0
cmd = torch.norm(env.command_manager.get_command("base_velocity"), dim=1)
return torch.where(cmd > 0.0, sync_reward * async_reward, 0.0)
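Written out, the value returned by __call__ amounts to the following product (notation mine, summarizing the code above rather than quoted from the commit): a_i and c_i are the current air and contact times of foot i, an overline denotes clipping of the squared error at max_err**2, sigma is std, and p_1, p_2 are the two synced pairs.

R_\text{gait} = \mathbf{1}\{\lVert v_\text{cmd}\rVert > 0\}
    \prod_{(i,j) \in \{p_1, p_2\}} \exp\big(-(\overline{(a_i - a_j)^2} + \overline{(c_i - c_j)^2}) / \sigma\big)
    \prod_{(i,j) \in p_1 \times p_2} \exp\big(-(\overline{(a_i - c_j)^2} + \overline{(c_i - a_j)^2}) / \sigma\big)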

def _sync_reward_func(self, foot_0: int, foot_1: int) -> torch.Tensor:
"""Reward synchronization of two feet."""
air_time = self.contact_sensor.data.current_air_time
contact_time = self.contact_sensor.data.current_contact_time
# penalize mismatch between the two feet's most recent air times and between their contact times
se_air = torch.clip(torch.square(air_time[:, foot_0] - air_time[:, foot_1]), max=self.max_err**2)
se_contact = torch.clip(torch.square(contact_time[:, foot_0] - contact_time[:, foot_1]), max=self.max_err**2)
return torch.exp(-(se_air + se_contact) / self.std)

def _async_reward_func(self, foot_0: int, foot_1: int) -> torch.Tensor:
"""Reward anti-synchronization of two feet."""
air_time = self.contact_sensor.data.current_air_time
contact_time = self.contact_sensor.data.current_contact_time
# penalize the difference between opposing contact modes (air time of foot 0 vs. contact time of foot 1,
# and contact time of foot 0 vs. air time of foot 1) for feet pairs that should not be in sync.
se_act_0 = torch.clip(torch.square(air_time[:, foot_0] - contact_time[:, foot_1]), max=self.max_err**2)
se_act_1 = torch.clip(torch.square(contact_time[:, foot_0] - air_time[:, foot_1]), max=self.max_err**2)
return torch.exp(-(se_act_0 + se_act_1) / self.std)
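To make the shape of the two kernels above concrete, here is a self-contained sketch (illustration only, not part of the commit) that reproduces them on dummy per-environment air/contact times; the clipping at max_err**2 bounds how much a single badly timed pair can suppress the product formed in __call__:

import torch

def sync_kernel(air_0, air_1, contact_0, contact_1, std=0.1, max_err=0.2):
    # squared timing mismatch, saturated at max_err**2
    se_air = torch.clip(torch.square(air_0 - air_1), max=max_err**2)
    se_contact = torch.clip(torch.square(contact_0 - contact_1), max=max_err**2)
    return torch.exp(-(se_air + se_contact) / std)  # in (0, 1], 1.0 = perfectly synced

def async_kernel(air_0, contact_0, air_1, contact_1, std=0.1, max_err=0.2):
    # one foot's air time should match the other foot's contact time, and vice versa
    se_act_0 = torch.clip(torch.square(air_0 - contact_1), max=max_err**2)
    se_act_1 = torch.clip(torch.square(contact_0 - air_1), max=max_err**2)
    return torch.exp(-(se_act_0 + se_act_1) / std)

# two feet with identical 0.3 s air and contact times (50% duty factor)
air = torch.tensor([0.3]); contact = torch.tensor([0.3])
print(sync_kernel(air, air, contact, contact))   # tensor([1.])
print(async_kernel(air, contact, air, contact))  # tensor([1.]) only because air == contact here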


def foot_clearance_reward(
