Skip to content

Commit

Permalink
feat: add TorqueProtect
Browse files Browse the repository at this point in the history
  • Loading branch information
fan-ziqi committed May 27, 2024
1 parent 03cccc1 commit e84f2d8
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 14 deletions.
48 changes: 42 additions & 6 deletions src/rl_sar/library/rl_sdk/rl_sdk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
// stand up (position control)
else if(running_state == STATE_POS_GETUP)
{
if(getup_percent != 1)
if(getup_percent < 1.0)
{
getup_percent += 1 / 1000.0;
getup_percent = getup_percent > 1 ? 1 : getup_percent;
getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent;
for(int i = 0; i < params.num_of_dofs; ++i)
{
command->motor_command.q[i] = (1 - getup_percent) * now_pos[i] + getup_percent * params.default_dof_pos[0][i].item<double>();
Expand All @@ -122,10 +122,11 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
command->motor_command.kd[i] = params.fixed_kd[0][i].item<double>();
command->motor_command.tau[i] = 0;
}
printf("getting up %.3f%%\r", getup_percent*100.0);
printf("Getting up %.3f%%\r", getup_percent*100.0);
}
if(keyboard.keyboard_state == STATE_RL_INIT)
{
std::cout << std::endl;
keyboard.keyboard_state = STATE_WAITING;
running_state = STATE_RL_INIT;
}
Expand Down Expand Up @@ -154,6 +155,8 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
// rl loop
else if(running_state == STATE_RL_RUNNING)
{
std::cout << "[RL Controller] x:" << keyboard.x << " y:" << keyboard.y << " yaw:" << keyboard.yaw << " \r";

for(int i = 0; i < params.num_of_dofs; ++i)
{
command->motor_command.q[i] = output_dof_pos[0][i].item<double>();
Expand All @@ -176,10 +179,10 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
// get down (position control)
else if(running_state == STATE_POS_GETDOWN)
{
if(getdown_percent != 1)
if(getdown_percent < 1.0)
{
getdown_percent += 1 / 1000.0;
getdown_percent = getdown_percent > 1 ? 1 : getdown_percent;
getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent;
for(int i = 0; i < params.num_of_dofs; ++i)
{
command->motor_command.q[i] = (1 - getdown_percent) * now_pos[i] + getdown_percent * start_pos[i];
Expand All @@ -188,10 +191,11 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
command->motor_command.kd[i] = params.fixed_kd[0][i].item<double>();
command->motor_command.tau[i] = 0;
}
printf("getting down %.3f%%\r", getdown_percent*100.0);
printf("Getting down %.3f%%\r", getdown_percent*100.0);
}
if(getdown_percent == 1)
{
std::cout << std::endl;
running_state = STATE_WAITING;
this->InitObservations();
this->InitOutputs();
Expand All @@ -200,6 +204,38 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
}
}

void RL::TorqueProtect(torch::Tensor origin_output_torques)
{
std::vector<int> out_of_range_indices;
std::vector<double> out_of_range_values;
for(int i = 0; i < origin_output_torques.size(1); ++i)
{
double torque_value = origin_output_torques[0][i].item<double>();
double limit_lower = -this->params.torque_limits[0][i].item<double>();
double limit_upper = this->params.torque_limits[0][i].item<double>();

if(torque_value < limit_lower || torque_value > limit_upper)
{
out_of_range_indices.push_back(i);
out_of_range_values.push_back(torque_value);
}
}
if(!out_of_range_indices.empty())
{
std::cout << "Error: origin_output_torques is out of range at indices: ";
for(int i = 0; i < out_of_range_indices.size(); ++i)
{
std::cout << out_of_range_indices[i] << " (value: " << out_of_range_values[i] << ")";
if(i < out_of_range_indices.size() - 1)
{
std::cout << ", ";
}
}
std::cout << std::endl;
keyboard.keyboard_state = STATE_POS_GETDOWN;
}
}

#include <termios.h>
#include <sys/ioctl.h>
static bool kbhit()
Expand Down
3 changes: 3 additions & 0 deletions src/rl_sar/library/rl_sdk/rl_sdk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ class RL
std::string robot_name;
STATE running_state = STATE_WAITING;

// protect func
void TorqueProtect(torch::Tensor origin_output_torques);

protected:
// rl module
torch::jit::script::Module model;
Expand Down
7 changes: 5 additions & 2 deletions src/rl_sar/src/rl_real_a1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,11 @@ void RL_Real::RunModel()

this->obs.actions = clamped_actions;

// torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);
// output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);

TorqueProtect(origin_output_torques);

output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
output_dof_pos = this->ComputePosition(this->obs.actions);

#ifdef CSV_LOGGER
Expand Down
11 changes: 5 additions & 6 deletions src/rl_sar/src/rl_sim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,6 @@ void RL_Sim::SetCommand(const RobotCommand<double> *command)

void RL_Sim::RobotControl()
{
std::cout << "running_state:" << keyboard.keyboard_state
<< " x:" << keyboard.x << " y:" << keyboard.y << " yaw:" << keyboard.yaw
<< " \r";

motiontime++;

if(keyboard.keyboard_state == STATE_RESET_SIMULATION)
Expand Down Expand Up @@ -198,8 +194,11 @@ void RL_Sim::RunModel()

this->obs.actions = clamped_actions;

// torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);
// output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);

TorqueProtect(origin_output_torques);

output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
output_dof_pos = this->ComputePosition(this->obs.actions);

#ifdef CSV_LOGGER
Expand Down

0 comments on commit e84f2d8

Please sign in to comment.