Skip to content

Commit

Permalink
fix: add this to class member
Browse files Browse the repository at this point in the history
  • Loading branch information
fan-ziqi committed May 29, 2024
1 parent be79e7f commit 8327b52
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 236 deletions.
5 changes: 2 additions & 3 deletions src/rl_sar/include/rl_sim.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,19 @@ class RL_Sim : public RL
geometry_msgs::Twist vel;
geometry_msgs::Pose pose;
geometry_msgs::Twist cmd_vel;
std::vector<std::string> torque_command_topics;
ros::Subscriber model_state_subscriber;
ros::Subscriber joint_state_subscriber;
ros::Subscriber cmd_vel_subscriber;
std::map<std::string, ros::Publisher> torque_publishers;
ros::ServiceClient gazebo_reset_client;
std::map<std::string, ros::Publisher> joint_publishers;
std::vector<robot_msgs::MotorCommand> joint_publishers_commands;
void ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
void JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
void CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);

// others
int motiontime = 0;
std::map<std::string, size_t> sorted_to_original_index;
std::vector<robot_msgs::MotorCommand> motor_commands;
std::vector<double> mapped_joint_positions;
std::vector<double> mapped_joint_velocities;
std::vector<double> mapped_joint_efforts;
Expand Down
104 changes: 52 additions & 52 deletions src/rl_sar/library/rl_sdk/rl_sdk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ void RL::InitObservations()
this->obs.commands = torch::tensor({{0.0, 0.0, 0.0}});
this->obs.base_quat = torch::tensor({{0.0, 0.0, 0.0, 1.0}});
this->obs.dof_pos = this->params.default_dof_pos;
this->obs.dof_vel = torch::zeros({1, params.num_of_dofs});
this->obs.actions = torch::zeros({1, params.num_of_dofs});
this->obs.dof_vel = torch::zeros({1, this->params.num_of_dofs});
this->obs.actions = torch::zeros({1, this->params.num_of_dofs});
}

void RL::InitOutputs()
{
this->output_torques = torch::zeros({1, params.num_of_dofs});
this->output_dof_pos = params.default_dof_pos;
this->output_torques = torch::zeros({1, this->params.num_of_dofs});
this->output_dof_pos = this->params.default_dof_pos;
}

void RL::InitControl()
Expand Down Expand Up @@ -88,112 +88,112 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
static float getdown_percent = 0.0;

// waiting
if(running_state == STATE_WAITING)
if(this->running_state == STATE_WAITING)
{
for(int i = 0; i < params.num_of_dofs; ++i)
for(int i = 0; i < this->params.num_of_dofs; ++i)
{
command->motor_command.q[i] = state->motor_state.q[i];
}
if(control.control_state == STATE_POS_GETUP)
if(this->control.control_state == STATE_POS_GETUP)
{
control.control_state = STATE_WAITING;
this->control.control_state = STATE_WAITING;
getup_percent = 0.0;
for(int i = 0; i < params.num_of_dofs; ++i)
for(int i = 0; i < this->params.num_of_dofs; ++i)
{
now_state.motor_state.q[i] = state->motor_state.q[i];
start_state.motor_state.q[i] = now_state.motor_state.q[i];
}
running_state = STATE_POS_GETUP;
this->running_state = STATE_POS_GETUP;
}
}
// stand up (position control)
else if(running_state == STATE_POS_GETUP)
else if(this->running_state == STATE_POS_GETUP)
{
if(getup_percent < 1.0)
{
getup_percent += 1 / 1000.0;
getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent;
for(int i = 0; i < params.num_of_dofs; ++i)
for(int i = 0; i < this->params.num_of_dofs; ++i)
{
command->motor_command.q[i] = (1 - getup_percent) * now_state.motor_state.q[i] + getup_percent * params.default_dof_pos[0][i].item<double>();
command->motor_command.q[i] = (1 - getup_percent) * now_state.motor_state.q[i] + getup_percent * this->params.default_dof_pos[0][i].item<double>();
command->motor_command.dq[i] = 0;
command->motor_command.kp[i] = params.fixed_kp[0][i].item<double>();
command->motor_command.kd[i] = params.fixed_kd[0][i].item<double>();
command->motor_command.kp[i] = this->params.fixed_kp[0][i].item<double>();
command->motor_command.kd[i] = this->params.fixed_kd[0][i].item<double>();
command->motor_command.tau[i] = 0;
}
std::cout << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << "%\r";
}
if(control.control_state == STATE_RL_INIT)
if(this->control.control_state == STATE_RL_INIT)
{
std::cout << std::endl;
control.control_state = STATE_WAITING;
running_state = STATE_RL_INIT;
this->control.control_state = STATE_WAITING;
this->running_state = STATE_RL_INIT;
}
else if(control.control_state == STATE_POS_GETDOWN)
else if(this->control.control_state == STATE_POS_GETDOWN)
{
control.control_state = STATE_WAITING;
this->control.control_state = STATE_WAITING;
getdown_percent = 0.0;
for(int i = 0; i < params.num_of_dofs; ++i)
for(int i = 0; i < this->params.num_of_dofs; ++i)
{
now_state.motor_state.q[i] = state->motor_state.q[i];
}
running_state = STATE_POS_GETDOWN;
this->running_state = STATE_POS_GETDOWN;
}
}
// init obs and start rl loop
else if(running_state == STATE_RL_INIT)
else if(this->running_state == STATE_RL_INIT)
{
if(getup_percent == 1)
{
running_state = STATE_RL_RUNNING;
this->running_state = STATE_RL_RUNNING;
this->InitObservations();
this->InitOutputs();
this->InitControl();
}
}
// rl loop
else if(running_state == STATE_RL_RUNNING)
else if(this->running_state == STATE_RL_RUNNING)
{
for(int i = 0; i < params.num_of_dofs; ++i)
for(int i = 0; i < this->params.num_of_dofs; ++i)
{
command->motor_command.q[i] = output_dof_pos[0][i].item<double>();
command->motor_command.q[i] = this->output_dof_pos[0][i].item<double>();
command->motor_command.dq[i] = 0;
command->motor_command.kp[i] = params.rl_kp[0][i].item<double>();
command->motor_command.kd[i] = params.rl_kd[0][i].item<double>();
command->motor_command.kp[i] = this->params.rl_kp[0][i].item<double>();
command->motor_command.kd[i] = this->params.rl_kd[0][i].item<double>();
command->motor_command.tau[i] = 0;
}
if(control.control_state == STATE_POS_GETDOWN)
if(this->control.control_state == STATE_POS_GETDOWN)
{
control.control_state = STATE_WAITING;
this->control.control_state = STATE_WAITING;
getdown_percent = 0.0;
for(int i = 0; i < params.num_of_dofs; ++i)
for(int i = 0; i < this->params.num_of_dofs; ++i)
{
now_state.motor_state.q[i] = state->motor_state.q[i];
}
running_state = STATE_POS_GETDOWN;
this->running_state = STATE_POS_GETDOWN;
}
}
// get down (position control)
else if(running_state == STATE_POS_GETDOWN)
else if(this->running_state == STATE_POS_GETDOWN)
{
if(getdown_percent < 1.0)
{
getdown_percent += 1 / 1000.0;
getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent;
for(int i = 0; i < params.num_of_dofs; ++i)
for(int i = 0; i < this->params.num_of_dofs; ++i)
{
command->motor_command.q[i] = (1 - getdown_percent) * now_state.motor_state.q[i] + getdown_percent * start_state.motor_state.q[i];
command->motor_command.dq[i] = 0;
command->motor_command.kp[i] = params.fixed_kp[0][i].item<double>();
command->motor_command.kd[i] = params.fixed_kd[0][i].item<double>();
command->motor_command.kp[i] = this->params.fixed_kp[0][i].item<double>();
command->motor_command.kd[i] = this->params.fixed_kd[0][i].item<double>();
command->motor_command.tau[i] = 0;
}
std::cout << LOGGER::INFO << "Getting down " << std::fixed << std::setprecision(2) << getdown_percent * 100.0 << "%\r";
}
if(getdown_percent == 1)
{
std::cout << std::endl;
running_state = STATE_WAITING;
this->running_state = STATE_WAITING;
this->InitObservations();
this->InitOutputs();
this->InitControl();
Expand Down Expand Up @@ -229,7 +229,7 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques)
std::cout << LOGGER::ERROR << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
std::cout << LOGGER::ERROR << "Switching to STATE_POS_GETDOWN"<< std::endl;
}
control.control_state = STATE_POS_GETDOWN;
this->control.control_state = STATE_POS_GETDOWN;
}
}

Expand All @@ -254,30 +254,30 @@ static bool kbhit()

void RL::KeyboardInterface()
{
if(running_state == STATE_RL_RUNNING)
if(this->running_state == STATE_RL_RUNNING)
{
std::cout << LOGGER::INFO << "RL Controller x:" << control.x << " y:" << control.y << " yaw:" << control.yaw << " \r";
std::cout << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << " \r";
}

if(kbhit())
{
int c = fgetc(stdin);
switch(c)
{
case '0': control.control_state = STATE_POS_GETUP; break;
case 'p': control.control_state = STATE_RL_INIT; break;
case '1': control.control_state = STATE_POS_GETDOWN; break;
case '0': this->control.control_state = STATE_POS_GETUP; break;
case 'p': this->control.control_state = STATE_RL_INIT; break;
case '1': this->control.control_state = STATE_POS_GETDOWN; break;
case 'q': break;
case 'w': control.x += 0.1; break;
case 's': control.x -= 0.1; break;
case 'a': control.yaw += 0.1; break;
case 'd': control.yaw -= 0.1; break;
case 'w': this->control.x += 0.1; break;
case 's': this->control.x -= 0.1; break;
case 'a': this->control.yaw += 0.1; break;
case 'd': this->control.yaw -= 0.1; break;
case 'i': break;
case 'k': break;
case 'j': control.y += 0.1; break;
case 'l': control.y -= 0.1; break;
case ' ': control.x = 0; control.y = 0; control.yaw = 0; break;
case 'r': control.control_state = STATE_RESET_SIMULATION; break;
case 'j': this->control.y += 0.1; break;
case 'l': this->control.y -= 0.1; break;
case ' ': this->control.x = 0; this->control.y = 0; this->control.yaw = 0; break;
case 'r': this->control.control_state = STATE_RESET_SIMULATION; break;
default: break;
}
}
Expand Down
Loading

0 comments on commit 8327b52

Please sign in to comment.