diff --git a/src/rl_sar/src/rl_real_a1.cpp b/src/rl_sar/src/rl_real_a1.cpp index eaf83ca..0360d2f 100644 --- a/src/rl_sar/src/rl_real_a1.cpp +++ b/src/rl_sar/src/rl_real_a1.cpp @@ -3,6 +3,7 @@ #define ROBOT_NAME "a1" // #define PLOT +#define CSV_LOGGER RL_Real rl_sar; @@ -23,9 +24,9 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL) this->vq = torch::jit::load(vq_path); this->InitObservations(); this->InitOutputs(); - + this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); - + plot_real_joint_pos.resize(12); plot_target_joint_pos.resize(12); @@ -37,11 +38,15 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL) loop_udpSend->start(); loop_udpRecv->start(); loop_control->start(); - + #ifdef PLOT loop_plot = std::make_shared("loop_plot" , 0.002, boost::bind(&RL_Real::Plot, this)); loop_plot->start(); #endif + +#ifdef CSV_LOGGER + CSVInit(ROBOT_NAME); +#endif } RL_Real::~RL_Real() @@ -97,7 +102,7 @@ void RL_Real::RobotControl() cmd.motorCmd[i].Kd = 3; cmd.motorCmd[i].tau = 0; } - printf("getting up %.3f%%\r", getup_percent*100.0); + printf("getting up %.3f%%\r", getup_percent * 100.0); } if((int)_keyData.btn.components.R1 == 1) { @@ -167,7 +172,7 @@ void RL_Real::RobotControl() cmd.motorCmd[i].Kd = 3; cmd.motorCmd[i].tau = 0; } - printf("getting down %.3f%%\r", getdown_percent*100.0); + printf("getting down %.3f%%\r", getdown_percent * 100.0); } if(getdown_percent == 1) { @@ -205,7 +210,7 @@ void RL_Real::RunModel() // state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, // state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, // state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq); - + this->obs.ang_vel = torch::tensor({{state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]}}); this->obs.commands = torch::tensor({{_keyData.ly, -_keyData.rx, -_keyData.lx}}); this->obs.base_quat = torch::tensor({{state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]}}); @@ -217,7 +222,7 @@ void RL_Real::RunModel() state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq}}); - + torch::Tensor actions = this->Forward(); for (int i : hip_scale_reduction_indices) @@ -227,8 +232,14 @@ void RL_Real::RunModel() output_torques = this->ComputeTorques(actions); output_dof_pos = this->ComputePosition(actions); +#ifdef CSV_LOGGER + torch::Tensor tau_est = torch::tensor({{state.motorState[FL_0].tauEst, state.motorState[FL_1].tauEst, state.motorState[FL_2].tauEst, + state.motorState[FR_0].tauEst, state.motorState[FR_1].tauEst, state.motorState[FR_2].tauEst, + state.motorState[RL_0].tauEst, state.motorState[RL_1].tauEst, state.motorState[RL_2].tauEst, + state.motorState[RR_0].tauEst, state.motorState[RR_1].tauEst, state.motorState[RR_2].tauEst}}); + CSVLogger(output_torques, tau_est, this->obs.dof_pos, output_dof_pos, this->obs.dof_vel); +#endif } - } torch::Tensor RL_Real::ComputeObservation() @@ -275,10 +286,10 @@ void RL_Real::Plot() { plot_real_joint_pos[i].push_back(state.motorState[i].q); plot_target_joint_pos[i].push_back(cmd.motorCmd[i].q); - plt::subplot(4, 3, i+1); + plt::subplot(4, 3, i + 1); plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r"); plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b"); - plt::xlim(motiontime-10000, motiontime); + plt::xlim(motiontime - 10000, motiontime); } // plt::legend(); plt::pause(0.0001);