Skip to content

Commit

Permalink
feat: Handle situations that do not require clamp
Browse files Browse the repository at this point in the history
  • Loading branch information
fan-ziqi committed Aug 26, 2024
1 parent 4868c7d commit db09012
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 11 deletions.
12 changes: 10 additions & 2 deletions src/rl_sar/library/rl_sdk/rl_sdk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,16 @@ void RL::ReadYaml(std::string robot_name)
this->params.num_observations = config["num_observations"].as<int>();
this->params.observations = ReadVectorFromYaml<std::string>(config["observations"]);
this->params.clip_obs = config["clip_obs"].as<double>();
this->params.clip_actions_upper = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_upper"], this->params.framework, rows, cols)).view({1, -1});
this->params.clip_actions_lower = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_lower"], this->params.framework, rows, cols)).view({1, -1});
if(config["clip_actions_lower"] && config["clip_actions_upper"])
{
this->params.clip_actions_upper = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_upper"], this->params.framework, rows, cols)).view({1, -1});
this->params.clip_actions_lower = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_lower"], this->params.framework, rows, cols)).view({1, -1});
}
else
{
this->params.clip_actions_upper = torch::tensor({}).view({1, -1});
this->params.clip_actions_lower = torch::tensor({}).view({1, -1});
}
this->params.action_scale = config["action_scale"].as<double>();
this->params.hip_scale_reduction = config["hip_scale_reduction"].as<double>();
this->params.hip_scale_reduction_indices = ReadVectorFromYaml<int>(config["hip_scale_reduction_indices"]);
Expand Down
8 changes: 6 additions & 2 deletions src/rl_sar/scripts/rl_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,12 @@ def ReadYaml(self, robot_name):
self.params.action_scale = config["action_scale"]
self.params.hip_scale_reduction = config["hip_scale_reduction"]
self.params.hip_scale_reduction_indices = config["hip_scale_reduction_indices"]
self.params.clip_actions_upper = torch.tensor(self.ReadVectorFromYaml(config["clip_actions_upper"], self.params.framework, rows, cols)).view(1, -1)
self.params.clip_actions_lower = torch.tensor(self.ReadVectorFromYaml(config["clip_actions_lower"], self.params.framework, rows, cols)).view(1, -1)
if config["clip_actions_upper"] and config["clip_actions_upper"]:
self.params.clip_actions_upper = torch.tensor(self.ReadVectorFromYaml(config["clip_actions_upper"], self.params.framework, rows, cols)).view(1, -1)
self.params.clip_actions_lower = torch.tensor(self.ReadVectorFromYaml(config["clip_actions_lower"], self.params.framework, rows, cols)).view(1, -1)
else:
self.params.clip_actions_upper = None
self.params.clip_actions_lower = None
self.params.num_of_dofs = config["num_of_dofs"]
self.params.lin_vel_scale = config["lin_vel_scale"]
self.params.ang_vel_scale = config["ang_vel_scale"]
Expand Down
6 changes: 4 additions & 2 deletions src/rl_sar/scripts/rl_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,10 @@ def Forward(self):
actions = self.model.forward(history_obs)
else:
actions = self.model.forward(clamped_obs)
clamped_actions = torch.clamp(actions, self.params.clip_actions_lower, self.params.clip_actions_upper)
return clamped_actions
if self.params.clip_actions_lower is not None and self.params.clip_actions_upper is not None:
return torch.clamp(actions, self.params.clip_actions_lower, self.params.clip_actions_upper)
else:
return actions

def ThreadControl(self):
thread_period = self.params.dt
Expand Down
11 changes: 8 additions & 3 deletions src/rl_sar/src/rl_real_a1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,14 @@ torch::Tensor RL_Real::Forward()

torch::Tensor actions = this->model.forward({this->history_obs}).toTensor();

torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);

return clamped_actions;
if(this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0)
{
return torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
}
else
{
return actions;
}
}

void RL_Real::Plot()
Expand Down
10 changes: 8 additions & 2 deletions src/rl_sar/src/rl_sim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,14 @@ torch::Tensor RL_Sim::Forward()
actions = this->model.forward({clamped_obs}).toTensor();
}

torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
return clamped_actions;
if(this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0)
{
return torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
}
else
{
return actions;
}
}

void RL_Sim::Plot()
Expand Down

0 comments on commit db09012

Please sign in to comment.