Skip to content

Commit

Permalink
feat: 1.fix lower upper bug
Browse files Browse the repository at this point in the history
2. ComputeTorques return origin value
3. change RobotState-IMU std::vector<T>
4. Update RunModel && ComputeObservation && Forward funcs
5. update a1 pt model
  • Loading branch information
fan-ziqi committed May 27, 2024
1 parent 109ded0 commit 03cccc1
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 63 deletions.
8 changes: 4 additions & 4 deletions src/rl_sar/config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
a1:
model_name: "model_0522.pt"
model_name: "model_0526.pt"
num_observations: 45
clip_obs: 100.0
clip_actions_lower: [-100, -100, -100,
Expand Down Expand Up @@ -34,7 +34,7 @@ a1:
ang_vel_scale: 0.25
dof_pos_scale: 1.0
dof_vel_scale: 0.05
commands_scale: [2.0, 2.0, 0.5]
commands_scale: [2.0, 2.0, 1.0]
torque_limits: [33.5, 33.5, 33.5,
33.5, 33.5, 33.5,
33.5, 33.5, 33.5,
Expand All @@ -52,10 +52,10 @@ gr1t1:
model_name: "model_4000_jit.pt"
num_observations: 39
clip_obs: 100.0
clip_actions_upper: [1.1391, 1.0491, 1.0491, 2.2691, 0.8691,
0.4391, 1.0491, 1.0491, 2.2691, 0.8691]
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]
clip_actions_upper: [1.1391, 1.0491, 1.0491, 2.2691, 0.8691,
0.4391, 1.0491, 1.0491, 2.2691, 0.8691]
rl_kp: [57.0, 43.0, 114.0, 114.0, 15.3,
57.0, 43.0, 114.0, 114.0, 15.3]
rl_kd: [5.7, 4.3, 11.4, 11.4, 1.5,
Expand Down
1 change: 0 additions & 1 deletion src/rl_sar/include/rl_real_a1.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ class RL_Real : public RL
std::vector<double> mapped_joint_velocities;
int command_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
int state_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
int hip_scale_reduction_indices[4] = {0, 3, 6, 9};
};

#endif
5 changes: 1 addition & 4 deletions src/rl_sar/library/rl_sdk/rl_sdk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ torch::Tensor RL::Forward()
}
*/



void RL::InitObservations()
{
this->obs.lin_vel = torch::tensor({{0.0, 0.0, 0.0}});
Expand Down Expand Up @@ -68,8 +66,7 @@ torch::Tensor RL::ComputeTorques(torch::Tensor actions)
{
torch::Tensor actions_scaled = actions * this->params.action_scale;
torch::Tensor output_torques = this->params.rl_kp * (actions_scaled + this->params.default_dof_pos - this->obs.dof_pos) - this->params.rl_kd * this->obs.dof_vel;
torch::Tensor clamped = torch::clamp(output_torques, -(this->params.torque_limits), this->params.torque_limits);
return clamped;
return output_torques;
}

torch::Tensor RL::ComputePosition(torch::Tensor actions)
Expand Down
6 changes: 3 additions & 3 deletions src/rl_sar/library/rl_sdk/rl_sdk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ struct RobotState
{
struct IMU
{
T quaternion[4] = {1.0, 0.0, 0.0, 0.0}; // w, x, y, z
T gyroscope[3] = {0.0, 0.0, 0.0};
T accelerometer[3] = {0.0, 0.0, 0.0};
std::vector<T> quaternion = {1.0, 0.0, 0.0, 0.0}; // w, x, y, z
std::vector<T> gyroscope = {0.0, 0.0, 0.0};
std::vector<T> accelerometer = {0.0, 0.0, 0.0};
} imu;

struct MotorState
Expand Down
Binary file added src/rl_sar/models/a1/model_0526.pt
Binary file not shown.
58 changes: 25 additions & 33 deletions src/rl_sar/src/rl_real_a1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,14 @@ void RL_Real::GetState(RobotState<double> *state)
keyboard.keyboard_state = STATE_POS_GETDOWN;
}

for(int i = 0; i < 4; ++i)
{
state->imu.quaternion[i] = unitree_low_state.imu.quaternion[i];
}
state->imu.quaternion[3] = unitree_low_state.imu.quaternion[0]; // w
state->imu.quaternion[0] = unitree_low_state.imu.quaternion[1]; // x
state->imu.quaternion[1] = unitree_low_state.imu.quaternion[2]; // y
state->imu.quaternion[2] = unitree_low_state.imu.quaternion[3]; // z
for(int i = 0; i < 3; ++i)
{
state->imu.gyroscope[i] = unitree_low_state.imu.gyroscope[i];
}

// state->imu.accelerometer

for(int i = 0; i < params.num_of_dofs; ++i)
{
state->motor_state.q[i] = unitree_low_state.motorState[state_mapping[i]].q;
Expand All @@ -116,8 +113,8 @@ void RL_Real::SetCommand(const RobotCommand<double> *command)
unitree_low_command.motorCmd[i].tau = command->motor_command.tau[command_mapping[i]];
}

unitree_safe.PowerProtect(unitree_low_command, unitree_low_state, 8);
// safe.PositionProtect(unitree_low_command, unitree_low_state);
unitree_safe.PowerProtect(unitree_low_command, unitree_low_state, 6);
// unitree_safe.PositionProtect(unitree_low_command, unitree_low_state);
unitree_udp.SetSend(unitree_low_command);
}

Expand All @@ -134,34 +131,27 @@ void RL_Real::RunModel()
{
if(running_state == STATE_RL_RUNNING)
{
this->obs.ang_vel = torch::tensor({{unitree_low_state.imu.gyroscope[0], unitree_low_state.imu.gyroscope[1], unitree_low_state.imu.gyroscope[2]}});
this->obs.ang_vel = torch::tensor(robot_state.imu.gyroscope).unsqueeze(0);
this->obs.commands = torch::tensor({{unitree_joy.ly, -unitree_joy.rx, -unitree_joy.lx}});
this->obs.base_quat = torch::tensor({{unitree_low_state.imu.quaternion[1], unitree_low_state.imu.quaternion[2], unitree_low_state.imu.quaternion[3], unitree_low_state.imu.quaternion[0]}});
this->obs.dof_pos = torch::tensor({{unitree_low_state.motorState[3].q, unitree_low_state.motorState[4].q, unitree_low_state.motorState[5].q,
unitree_low_state.motorState[0].q, unitree_low_state.motorState[1].q, unitree_low_state.motorState[2].q,
unitree_low_state.motorState[9].q, unitree_low_state.motorState[10].q, unitree_low_state.motorState[11].q,
unitree_low_state.motorState[6].q, unitree_low_state.motorState[7].q, unitree_low_state.motorState[8].q}});
this->obs.dof_vel = torch::tensor({{unitree_low_state.motorState[3].dq, unitree_low_state.motorState[4].dq, unitree_low_state.motorState[5].dq,
unitree_low_state.motorState[0].dq, unitree_low_state.motorState[1].dq, unitree_low_state.motorState[2].dq,
unitree_low_state.motorState[9].dq, unitree_low_state.motorState[10].dq, unitree_low_state.motorState[11].dq,
unitree_low_state.motorState[6].dq, unitree_low_state.motorState[7].dq, unitree_low_state.motorState[8].dq}});
this->obs.base_quat = torch::tensor(robot_state.imu.quaternion).unsqueeze(0);
this->obs.dof_pos = torch::tensor(robot_state.motor_state.q).narrow(0, 0, params.num_of_dofs).unsqueeze(0);
this->obs.dof_vel = torch::tensor(robot_state.motor_state.dq).narrow(0, 0, params.num_of_dofs).unsqueeze(0);

torch::Tensor clamped_actions = this->Forward();

for (int i : hip_scale_reduction_indices)
for (int i : this->params.hip_scale_reduction_indices)
{
clamped_actions[0][i] *= this->params.hip_scale_reduction;
}

this->obs.actions = clamped_actions;

output_torques = this->ComputeTorques(clamped_actions);
output_dof_pos = this->ComputePosition(clamped_actions);
// torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);
// output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
output_dof_pos = this->ComputePosition(this->obs.actions);

#ifdef CSV_LOGGER
torch::Tensor tau_est = torch::tensor({{unitree_low_state.motorState[3].tauEst, unitree_low_state.motorState[4].tauEst, unitree_low_state.motorState[5].tauEst,
unitree_low_state.motorState[0].tauEst, unitree_low_state.motorState[1].tauEst, unitree_low_state.motorState[2].tauEst,
unitree_low_state.motorState[9].tauEst, unitree_low_state.motorState[10].tauEst, unitree_low_state.motorState[11].tauEst,
unitree_low_state.motorState[6].tauEst, unitree_low_state.motorState[7].tauEst, unitree_low_state.motorState[8].tauEst}});
torch::Tensor tau_est = torch::tensor(robot_state.motor_state.tauEst).unsqueeze(0);
CSVLogger(output_torques, tau_est, this->obs.dof_pos, output_dof_pos, this->obs.dof_vel);
#endif
}
Expand All @@ -177,22 +167,24 @@ torch::Tensor RL_Real::ComputeObservation()
this->obs.dof_vel * this->params.dof_vel_scale,
this->obs.actions
},1);
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
return obs;
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
return clamped_obs;
}

torch::Tensor RL_Real::Forward()
{
torch::Tensor obs = this->ComputeObservation();
torch::autograd::GradMode::set_enabled(false);

torch::Tensor clamped_obs = this->ComputeObservation();

history_obs_buf.insert(obs);
history_obs_buf.insert(clamped_obs);
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});

torch::Tensor action = this->model.forward({history_obs}).toTensor();
torch::Tensor actions = this->model.forward({history_obs}).toTensor();

torch::Tensor clamped = torch::clamp(action, this->params.clip_actions_upper, this->params.clip_actions_lower);
torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);

return clamped;
return clamped_actions;
}

void RL_Real::Plot()
Expand Down
38 changes: 20 additions & 18 deletions src/rl_sar/src/rl_sim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ RL_Sim::~RL_Sim()

void RL_Sim::GetState(RobotState<double> *state)
{
state->imu.quaternion[0] = pose.orientation.w;
state->imu.quaternion[1] = pose.orientation.x;
state->imu.quaternion[2] = pose.orientation.y;
state->imu.quaternion[3] = pose.orientation.z;
state->imu.quaternion[3] = pose.orientation.w;
state->imu.quaternion[0] = pose.orientation.x;
state->imu.quaternion[1] = pose.orientation.y;
state->imu.quaternion[2] = pose.orientation.z;

state->imu.gyroscope[0] = vel.angular.x;
state->imu.gyroscope[1] = vel.angular.y;
Expand Down Expand Up @@ -133,8 +133,8 @@ void RL_Sim::SetCommand(const RobotCommand<double> *command)

void RL_Sim::RobotControl()
{
std::cout << "running_state " << keyboard.keyboard_state
<< " x" << keyboard.x << " y" << keyboard.y << " yaw" << keyboard.yaw
std::cout << "running_state:" << keyboard.keyboard_state
<< " x:" << keyboard.x << " y:" << keyboard.y << " yaw:" << keyboard.yaw
<< " \r";

motiontime++;
Expand Down Expand Up @@ -182,12 +182,12 @@ void RL_Sim::RunModel()
if(running_state == STATE_RL_RUNNING)
{
// this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}});
this->obs.ang_vel = torch::tensor({{vel.angular.x, vel.angular.y, vel.angular.z}});
this->obs.ang_vel = torch::tensor(robot_state.imu.gyroscope).unsqueeze(0);
// this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}});
this->obs.commands = torch::tensor({{keyboard.x, keyboard.y, keyboard.yaw}});
this->obs.base_quat = torch::tensor({{pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w}});
this->obs.dof_pos = torch::tensor(mapped_joint_positions).unsqueeze(0);
this->obs.dof_vel = torch::tensor(mapped_joint_velocities).unsqueeze(0);
this->obs.base_quat = torch::tensor(robot_state.imu.quaternion).unsqueeze(0);
this->obs.dof_pos = torch::tensor(robot_state.motor_state.q).narrow(0, 0, params.num_of_dofs).unsqueeze(0);
this->obs.dof_vel = torch::tensor(robot_state.motor_state.dq).narrow(0, 0, params.num_of_dofs).unsqueeze(0);

torch::Tensor clamped_actions = this->Forward();

Expand All @@ -198,8 +198,9 @@ void RL_Sim::RunModel()

this->obs.actions = clamped_actions;

// output_torques = this->ComputeTorques(clamped_actions);
output_dof_pos = this->ComputePosition(clamped_actions);
// torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);
// output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
output_dof_pos = this->ComputePosition(this->obs.actions);

#ifdef CSV_LOGGER
torch::Tensor tau_est = torch::tensor(mapped_joint_efforts).unsqueeze(0);
Expand All @@ -211,34 +212,35 @@ void RL_Sim::RunModel()
torch::Tensor RL_Sim::ComputeObservation()
{
torch::Tensor obs = torch::cat({// this->obs.lin_vel * this->params.lin_vel_scale,
this->obs.ang_vel * this->params.ang_vel_scale,
this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
// this->obs.ang_vel * this->params.ang_vel_scale, // TODO
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec),
this->obs.commands * this->params.commands_scale,
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
this->obs.dof_vel * this->params.dof_vel_scale,
this->obs.actions
},1);
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
return obs;
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
return clamped_obs;
}

torch::Tensor RL_Sim::Forward()
{
torch::autograd::GradMode::set_enabled(false);

torch::Tensor obs = this->ComputeObservation();
torch::Tensor clamped_obs = this->ComputeObservation();

torch::Tensor actions;

if(use_history)
{
history_obs_buf.insert(obs);
history_obs_buf.insert(clamped_obs);
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
actions = this->model.forward({history_obs}).toTensor();
}
else
{
actions = this->model.forward({obs}).toTensor();
actions = this->model.forward({clamped_obs}).toTensor();
}

torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
Expand Down

0 comments on commit 03cccc1

Please sign in to comment.