diff --git a/src/rl_sar/CMakeLists.txt b/src/rl_sar/CMakeLists.txt index c60ac70..b7972e0 100644 --- a/src/rl_sar/CMakeLists.txt +++ b/src/rl_sar/CMakeLists.txt @@ -15,18 +15,18 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GAZEBO_CXX_FLAGS}") find_package(gazebo REQUIRED) find_package(catkin REQUIRED COMPONENTS - controller_manager - genmsg - joint_state_controller - robot_state_publisher - roscpp - gazebo_ros - std_msgs - tf - geometry_msgs - robot_msgs - robot_joint_controller - rospy + controller_manager + genmsg + joint_state_controller + robot_state_publisher + roscpp + gazebo_ros + std_msgs + tf + geometry_msgs + robot_msgs + robot_joint_controller + rospy ) find_package(Python3 COMPONENTS Interpreter Development REQUIRED) @@ -36,9 +36,9 @@ link_directories(/usr/local/lib) include_directories(${YAML_CPP_INCLUDE_DIR}) catkin_package( - CATKIN_DEPENDS - robot_joint_controller - rospy + CATKIN_DEPENDS + robot_joint_controller + rospy ) include_directories(library/unitree_legged_sdk_3.2/include) @@ -46,13 +46,13 @@ link_directories(library/unitree_legged_sdk_3.2/lib) set(EXTRA_LIBS -pthread libunitree_legged_sdk_amd64.so lcm) include_directories( - include - ${catkin_INCLUDE_DIRS} - ${unitree_legged_sdk_INCLUDE_DIRS} - library/matplotlibcpp - library/observation_buffer - library/rl_sdk - library/loop + include + ${catkin_INCLUDE_DIRS} + ${unitree_legged_sdk_INCLUDE_DIRS} + library/matplotlibcpp + library/observation_buffer + library/rl_sdk + library/loop ) add_library(rl_sdk library/rl_sdk/rl_sdk.cpp) @@ -60,9 +60,9 @@ target_link_libraries(rl_sdk "${TORCH_LIBRARIES}" Python3::Python Python3::Modul set_property(TARGET rl_sdk PROPERTY CXX_STANDARD 14) find_package(Python3 COMPONENTS NumPy) if(Python3_NumPy_FOUND) - target_link_libraries(rl_sdk Python3::NumPy) + target_link_libraries(rl_sdk Python3::NumPy) else() - target_compile_definitions(rl_sdk WITHOUT_NUMPY) + target_compile_definitions(rl_sdk WITHOUT_NUMPY) endif() add_library(observation_buffer library/observation_buffer/observation_buffer.cpp) @@ -71,13 +71,13 @@ set_property(TARGET observation_buffer PROPERTY CXX_STANDARD 14) add_executable(rl_sim src/rl_sim.cpp ) target_link_libraries(rl_sim - ${catkin_LIBRARIES} ${EXTRA_LIBS} + ${catkin_LIBRARIES} ${EXTRA_LIBS} rl_sdk observation_buffer yaml-cpp ) add_executable(rl_real_a1 src/rl_real_a1.cpp ) target_link_libraries(rl_real_a1 - ${catkin_LIBRARIES} ${EXTRA_LIBS} + ${catkin_LIBRARIES} ${EXTRA_LIBS} rl_sdk observation_buffer yaml-cpp ) diff --git a/src/rl_sar/include/rl_real_a1.hpp b/src/rl_sar/include/rl_real_a1.hpp index f60218e..2b15f29 100644 --- a/src/rl_sar/include/rl_real_a1.hpp +++ b/src/rl_sar/include/rl_real_a1.hpp @@ -16,6 +16,7 @@ class RL_Real : public RL public: RL_Real(); ~RL_Real(); + private: // rl functions torch::Tensor Forward() override; @@ -43,8 +44,8 @@ class RL_Real : public RL void Plot(); // unitree interface - void UDPSend(){unitree_udp.Send();} - void UDPRecv(){unitree_udp.Recv();} + void UDPSend() { unitree_udp.Send(); } + void UDPRecv() { unitree_udp.Recv(); } UNITREE_LEGGED_SDK::Safety unitree_safe; UNITREE_LEGGED_SDK::UDP unitree_udp; UNITREE_LEGGED_SDK::LowCmd unitree_low_command = {0}; @@ -59,4 +60,4 @@ class RL_Real : public RL int state_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8}; }; -#endif \ No newline at end of file +#endif // RL_REAL_HPP diff --git a/src/rl_sar/include/rl_sim.hpp b/src/rl_sar/include/rl_sim.hpp index da0c6f3..9601024 100644 --- a/src/rl_sar/include/rl_sim.hpp +++ b/src/rl_sar/include/rl_sim.hpp @@ -21,6 +21,7 @@ class RL_Sim : public RL public: RL_Sim(); ~RL_Sim(); + private: // rl functions torch::Tensor Forward() override; @@ -69,7 +70,7 @@ class RL_Sim : public RL std::vector mapped_joint_positions; std::vector mapped_joint_velocities; std::vector mapped_joint_efforts; - void MapData(const std::vector& source_data, std::vector& target_data); + void MapData(const std::vector &source_data, std::vector &target_data); }; -#endif \ No newline at end of file +#endif // RL_SIM_HPP diff --git a/src/rl_sar/launch/gazebo_a1_isaacgym.launch b/src/rl_sar/launch/gazebo_a1_isaacgym.launch index d6e511b..94c7457 100644 --- a/src/rl_sar/launch/gazebo_a1_isaacgym.launch +++ b/src/rl_sar/launch/gazebo_a1_isaacgym.launch @@ -14,7 +14,7 @@ - + @@ -26,7 +26,7 @@ diff --git a/src/rl_sar/launch/gazebo_a1_isaacsim.launch b/src/rl_sar/launch/gazebo_a1_isaacsim.launch index 245ff5f..208004d 100644 --- a/src/rl_sar/launch/gazebo_a1_isaacsim.launch +++ b/src/rl_sar/launch/gazebo_a1_isaacsim.launch @@ -14,7 +14,7 @@ - + @@ -26,7 +26,7 @@ diff --git a/src/rl_sar/launch/gazebo_gr1t1_isaacgym.launch b/src/rl_sar/launch/gazebo_gr1t1_isaacgym.launch index e92169c..d653416 100644 --- a/src/rl_sar/launch/gazebo_gr1t1_isaacgym.launch +++ b/src/rl_sar/launch/gazebo_gr1t1_isaacgym.launch @@ -12,7 +12,7 @@ - + diff --git a/src/rl_sar/launch/gazebo_gr1t2_isaacgym.launch b/src/rl_sar/launch/gazebo_gr1t2_isaacgym.launch index 92cea76..59119d0 100644 --- a/src/rl_sar/launch/gazebo_gr1t2_isaacgym.launch +++ b/src/rl_sar/launch/gazebo_gr1t2_isaacgym.launch @@ -12,7 +12,7 @@ - + diff --git a/src/rl_sar/library/loop/loop.hpp b/src/rl_sar/library/loop/loop.hpp index e36410c..76b2a02 100644 --- a/src/rl_sar/library/loop/loop.hpp +++ b/src/rl_sar/library/loop/loop.hpp @@ -14,19 +14,9 @@ class LoopFunc { - private: - std::string _name; - double _period; - std::function _func; - int _bindCPU; - std::atomic _running; - std::mutex _mutex; - std::condition_variable _cv; - std::thread _thread; - - public: +public: LoopFunc(const std::string &name, double period, std::function func, int bindCPU = -1) - : _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {} + : _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {} void start() { @@ -57,12 +47,22 @@ class LoopFunc } log("[Loop End] named: " + _name); } - private: + +private: + std::string _name; + double _period; + std::function _func; + int _bindCPU; + std::atomic _running; + std::mutex _mutex; + std::condition_variable _cv; + std::thread _thread; + void loop() { - while (_running) - { - auto start = std::chrono::steady_clock::now(); + while (_running) + { + auto start = std::chrono::steady_clock::now(); _func(); @@ -72,7 +72,8 @@ class LoopFunc if (sleepTime.count() > 0) { std::unique_lock lock(_mutex); - if (_cv.wait_for(lock, sleepTime, [this]{ return !_running; })) + if (_cv.wait_for(lock, sleepTime, [this] + { return !_running; })) { break; } @@ -87,7 +88,7 @@ class LoopFunc return stream.str(); } - void log(const std::string& message) + void log(const std::string &message) { static std::mutex logMutex; std::lock_guard lock(logMutex); @@ -108,4 +109,4 @@ class LoopFunc } }; -#endif \ No newline at end of file +#endif // LOOP_H diff --git a/src/rl_sar/library/observation_buffer/observation_buffer.cpp b/src/rl_sar/library/observation_buffer/observation_buffer.cpp index 4f0adea..32376f8 100644 --- a/src/rl_sar/library/observation_buffer/observation_buffer.cpp +++ b/src/rl_sar/library/observation_buffer/observation_buffer.cpp @@ -2,11 +2,11 @@ ObservationBuffer::ObservationBuffer() {} -ObservationBuffer::ObservationBuffer(int num_envs, - int num_obs, - int include_history_steps) - : num_envs(num_envs), - num_obs(num_obs), +ObservationBuffer::ObservationBuffer(int num_envs, + int num_obs, + int include_history_steps) + : num_envs(num_envs), + num_obs(num_obs), include_history_steps(include_history_steps) { num_obs_total = num_obs * include_history_steps; @@ -16,7 +16,8 @@ ObservationBuffer::ObservationBuffer(int num_envs, void ObservationBuffer::reset(std::vector reset_idxs, torch::Tensor new_obs) { std::vector indices; - for (int idx : reset_idxs) { + for (int idx : reset_idxs) + { indices.push_back(torch::indexing::Slice(idx)); } obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps})); diff --git a/src/rl_sar/library/observation_buffer/observation_buffer.hpp b/src/rl_sar/library/observation_buffer/observation_buffer.hpp index 72be75c..cb4afea 100644 --- a/src/rl_sar/library/observation_buffer/observation_buffer.hpp +++ b/src/rl_sar/library/observation_buffer/observation_buffer.hpp @@ -4,7 +4,8 @@ #include #include -class ObservationBuffer { +class ObservationBuffer +{ public: ObservationBuffer(int num_envs, int num_obs, int include_history_steps); ObservationBuffer(); @@ -21,4 +22,4 @@ class ObservationBuffer { torch::Tensor obs_buf; }; -#endif // OBSERVATION_BUFFER_HPP \ No newline at end of file +#endif // OBSERVATION_BUFFER_HPP diff --git a/src/rl_sar/library/rl_sdk/rl_sdk.cpp b/src/rl_sar/library/rl_sdk/rl_sdk.cpp index 9529b2b..ea0cb70 100644 --- a/src/rl_sar/library/rl_sdk/rl_sdk.cpp +++ b/src/rl_sar/library/rl_sdk/rl_sdk.cpp @@ -15,34 +15,34 @@ torch::Tensor RL::ComputeObservation() { std::vector obs_list; - for(const std::string& observation : this->params.observations) + for (const std::string &observation : this->params.observations) { - if(observation == "lin_vel") + if (observation == "lin_vel") { obs_list.push_back(this->obs.lin_vel * this->params.lin_vel_scale); } - else if(observation == "ang_vel") + else if (observation == "ang_vel") { // obs_list.push_back(this->obs.ang_vel * this->params.ang_vel_scale); // TODO is QuatRotateInverse necessery? obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel, this->params.framework) * this->params.ang_vel_scale); } - else if(observation == "gravity_vec") + else if (observation == "gravity_vec") { obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec, this->params.framework)); } - else if(observation == "commands") + else if (observation == "commands") { obs_list.push_back(this->obs.commands * this->params.commands_scale); } - else if(observation == "dof_pos") + else if (observation == "dof_pos") { obs_list.push_back((this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale); } - else if(observation == "dof_vel") + else if (observation == "dof_vel") { obs_list.push_back(this->obs.dof_vel * this->params.dof_vel_scale); } - else if(observation == "actions") + else if (observation == "actions") { obs_list.push_back(this->obs.actions); } @@ -92,22 +92,22 @@ torch::Tensor RL::ComputePosition(torch::Tensor actions) return actions_scaled + this->params.default_dof_pos; } -torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v, const std::string& framework) +torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v, const std::string &framework) { torch::Tensor q_w; torch::Tensor q_vec; - if(framework == "isaacsim") + if (framework == "isaacsim") { q_w = q.index({torch::indexing::Slice(), 0}); q_vec = q.index({torch::indexing::Slice(), torch::indexing::Slice(1, 4)}); } - else if(framework == "isaacgym") + else if (framework == "isaacgym") { q_w = q.index({torch::indexing::Slice(), 3}); q_vec = q.index({torch::indexing::Slice(), torch::indexing::Slice(0, 3)}); } c10::IntArrayRef shape = q.sizes(); - + torch::Tensor a = v * (2.0 * torch::pow(q_w, 2) - 1.0).unsqueeze(-1); torch::Tensor b = torch::cross(q_vec, v, -1) * q_w.unsqueeze(-1) * 2.0; torch::Tensor c = q_vec * torch::bmm(q_vec.view({shape[0], 1, 3}), v.view({shape[0], 3, 1})).squeeze(-1) * 2.0; @@ -122,17 +122,17 @@ void RL::StateController(const RobotState *state, RobotCommand * static float getdown_percent = 0.0; // waiting - if(this->running_state == STATE_WAITING) + if (this->running_state == STATE_WAITING) { - for(int i = 0; i < this->params.num_of_dofs; ++i) + for (int i = 0; i < this->params.num_of_dofs; ++i) { command->motor_command.q[i] = state->motor_state.q[i]; } - if(this->control.control_state == STATE_POS_GETUP) + if (this->control.control_state == STATE_POS_GETUP) { this->control.control_state = STATE_WAITING; getup_percent = 0.0; - for(int i = 0; i < this->params.num_of_dofs; ++i) + for (int i = 0; i < this->params.num_of_dofs; ++i) { now_state.motor_state.q[i] = state->motor_state.q[i]; start_state.motor_state.q[i] = now_state.motor_state.q[i]; @@ -142,13 +142,13 @@ void RL::StateController(const RobotState *state, RobotCommand * } } // stand up (position control) - else if(this->running_state == STATE_POS_GETUP) + else if (this->running_state == STATE_POS_GETUP) { - if(getup_percent < 1.0) + if (getup_percent < 1.0) { getup_percent += 1 / 500.0; getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent; - for(int i = 0; i < this->params.num_of_dofs; ++i) + for (int i = 0; i < this->params.num_of_dofs; ++i) { command->motor_command.q[i] = (1 - getup_percent) * now_state.motor_state.q[i] + getup_percent * this->params.default_dof_pos[0][i].item(); command->motor_command.dq[i] = 0; @@ -158,17 +158,17 @@ void RL::StateController(const RobotState *state, RobotCommand * } std::cout << "\r" << std::flush << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << std::flush; } - if(this->control.control_state == STATE_RL_INIT) + if (this->control.control_state == STATE_RL_INIT) { this->control.control_state = STATE_WAITING; this->running_state = STATE_RL_INIT; std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_INIT" << std::endl; } - else if(this->control.control_state == STATE_POS_GETDOWN) + else if (this->control.control_state == STATE_POS_GETDOWN) { this->control.control_state = STATE_WAITING; getdown_percent = 0.0; - for(int i = 0; i < this->params.num_of_dofs; ++i) + for (int i = 0; i < this->params.num_of_dofs; ++i) { now_state.motor_state.q[i] = state->motor_state.q[i]; } @@ -177,9 +177,9 @@ void RL::StateController(const RobotState *state, RobotCommand * } } // init obs and start rl loop - else if(this->running_state == STATE_RL_INIT) + else if (this->running_state == STATE_RL_INIT) { - if(getup_percent == 1) + if (getup_percent == 1) { this->InitObservations(); this->InitOutputs(); @@ -189,10 +189,10 @@ void RL::StateController(const RobotState *state, RobotCommand * } } // rl loop - else if(this->running_state == STATE_RL_RUNNING) + else if (this->running_state == STATE_RL_RUNNING) { std::cout << "\r" << std::flush << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << std::flush; - for(int i = 0; i < this->params.num_of_dofs; ++i) + for (int i = 0; i < this->params.num_of_dofs; ++i) { command->motor_command.q[i] = this->output_dof_pos[0][i].item(); command->motor_command.dq[i] = 0; @@ -200,22 +200,22 @@ void RL::StateController(const RobotState *state, RobotCommand * command->motor_command.kd[i] = this->params.rl_kd[0][i].item(); command->motor_command.tau[i] = 0; } - if(this->control.control_state == STATE_POS_GETDOWN) + if (this->control.control_state == STATE_POS_GETDOWN) { this->control.control_state = STATE_WAITING; getdown_percent = 0.0; - for(int i = 0; i < this->params.num_of_dofs; ++i) + for (int i = 0; i < this->params.num_of_dofs; ++i) { now_state.motor_state.q[i] = state->motor_state.q[i]; } this->running_state = STATE_POS_GETDOWN; std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl; } - else if(this->control.control_state == STATE_POS_GETUP) + else if (this->control.control_state == STATE_POS_GETUP) { this->control.control_state = STATE_WAITING; getup_percent = 0.0; - for(int i = 0; i < this->params.num_of_dofs; ++i) + for (int i = 0; i < this->params.num_of_dofs; ++i) { now_state.motor_state.q[i] = state->motor_state.q[i]; } @@ -224,13 +224,13 @@ void RL::StateController(const RobotState *state, RobotCommand * } } // get down (position control) - else if(this->running_state == STATE_POS_GETDOWN) + else if (this->running_state == STATE_POS_GETDOWN) { - if(getdown_percent < 1.0) + if (getdown_percent < 1.0) { getdown_percent += 1 / 500.0; getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent; - for(int i = 0; i < this->params.num_of_dofs; ++i) + for (int i = 0; i < this->params.num_of_dofs; ++i) { command->motor_command.q[i] = (1 - getdown_percent) * now_state.motor_state.q[i] + getdown_percent * start_state.motor_state.q[i]; command->motor_command.dq[i] = 0; @@ -240,7 +240,7 @@ void RL::StateController(const RobotState *state, RobotCommand * } std::cout << "\r" << std::flush << LOGGER::INFO << "Getting down " << std::fixed << std::setprecision(2) << getdown_percent * 100.0 << std::flush; } - if(getdown_percent == 1) + if (getdown_percent == 1) { this->InitObservations(); this->InitOutputs(); @@ -255,28 +255,28 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques) { std::vector out_of_range_indices; std::vector out_of_range_values; - for(int i = 0; i < origin_output_torques.size(1); ++i) + for (int i = 0; i < origin_output_torques.size(1); ++i) { double torque_value = origin_output_torques[0][i].item(); double limit_lower = -this->params.torque_limits[0][i].item(); double limit_upper = this->params.torque_limits[0][i].item(); - if(torque_value < limit_lower || torque_value > limit_upper) + if (torque_value < limit_lower || torque_value > limit_upper) { out_of_range_indices.push_back(i); out_of_range_values.push_back(torque_value); } } - if(!out_of_range_indices.empty()) + if (!out_of_range_indices.empty()) { - for(int i = 0; i < out_of_range_indices.size(); ++i) + for (int i = 0; i < out_of_range_indices.size(); ++i) { int index = out_of_range_indices[i]; double value = out_of_range_values[i]; double limit_lower = -this->params.torque_limits[0][index].item(); double limit_upper = this->params.torque_limits[0][index].item(); - std::cout << LOGGER::WARNING << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl; + std::cout << LOGGER::WARNING << "Torque(" << index + 1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl; } // Just a reminder, no protection // this->control.control_state = STATE_POS_GETDOWN; @@ -290,79 +290,109 @@ static bool kbhit() { termios term; tcgetattr(0, &term); - + termios term2 = term; term2.c_lflag &= ~ICANON; tcsetattr(0, TCSANOW, &term2); - + int byteswaiting; ioctl(0, FIONREAD, &byteswaiting); - + tcsetattr(0, TCSANOW, &term); - + return byteswaiting > 0; } void RL::KeyboardInterface() { - if(kbhit()) + if (kbhit()) { int c = fgetc(stdin); - switch(c) + switch (c) { - case '0': this->control.control_state = STATE_POS_GETUP; break; - case 'p': this->control.control_state = STATE_RL_INIT; break; - case '1': this->control.control_state = STATE_POS_GETDOWN; break; - case 'q': break; - case 'w': this->control.x += 0.1; break; - case 's': this->control.x -= 0.1; break; - case 'a': this->control.yaw += 0.1; break; - case 'd': this->control.yaw -= 0.1; break; - case 'i': break; - case 'k': break; - case 'j': this->control.y += 0.1; break; - case 'l': this->control.y -= 0.1; break; - case ' ': this->control.x = 0; this->control.y = 0; this->control.yaw = 0; break; - case 'r': this->control.control_state = STATE_RESET_SIMULATION; break; - case '\n': this->control.control_state = STATE_TOGGLE_SIMULATION; break; - default: break; + case '0': + this->control.control_state = STATE_POS_GETUP; + break; + case 'p': + this->control.control_state = STATE_RL_INIT; + break; + case '1': + this->control.control_state = STATE_POS_GETDOWN; + break; + case 'q': + break; + case 'w': + this->control.x += 0.1; + break; + case 's': + this->control.x -= 0.1; + break; + case 'a': + this->control.yaw += 0.1; + break; + case 'd': + this->control.yaw -= 0.1; + break; + case 'i': + break; + case 'k': + break; + case 'j': + this->control.y += 0.1; + break; + case 'l': + this->control.y -= 0.1; + break; + case ' ': + this->control.x = 0; + this->control.y = 0; + this->control.yaw = 0; + break; + case 'r': + this->control.control_state = STATE_RESET_SIMULATION; + break; + case '\n': + this->control.control_state = STATE_TOGGLE_SIMULATION; + break; + default: + break; } } } -template -std::vector ReadVectorFromYaml(const YAML::Node& node) +template +std::vector ReadVectorFromYaml(const YAML::Node &node) { std::vector values; - for(const auto& val : node) + for (const auto &val : node) { values.push_back(val.as()); } return values; } -template -std::vector ReadVectorFromYaml(const YAML::Node& node, const std::string& framework, const int& rows, const int& cols) +template +std::vector ReadVectorFromYaml(const YAML::Node &node, const std::string &framework, const int &rows, const int &cols) { std::vector values; - for(const auto& val : node) + for (const auto &val : node) { values.push_back(val.as()); } - if(framework == "isaacsim") + if (framework == "isaacsim") { std::vector transposed_values(cols * rows); - for(int r = 0; r < rows; ++r) + for (int r = 0; r < rows; ++r) { - for(int c = 0; c < cols; ++c) + for (int c = 0; c < cols; ++c) { transposed_values[c * rows + r] = values[r * cols + c]; } } return transposed_values; } - else if(framework == "isaacgym") + else if (framework == "isaacgym") { return values; } @@ -380,7 +410,8 @@ void RL::ReadYaml(std::string robot_name) try { config = YAML::LoadFile(config_path)[robot_name]; - } catch(YAML::BadFile &e) + } + catch (YAML::BadFile &e) { std::cout << LOGGER::ERROR << "The file '" << config_path << "' does not exist" << std::endl; return; @@ -396,7 +427,7 @@ void RL::ReadYaml(std::string robot_name) this->params.num_observations = config["num_observations"].as(); this->params.observations = ReadVectorFromYaml(config["observations"]); this->params.clip_obs = config["clip_obs"].as(); - if(config["clip_actions_lower"].IsNull() && config["clip_actions_upper"].IsNull()) + if (config["clip_actions_lower"].IsNull() && config["clip_actions_upper"].IsNull()) { this->params.clip_actions_upper = torch::tensor({}).view({1, -1}); this->params.clip_actions_lower = torch::tensor({}).view({1, -1}); @@ -440,11 +471,11 @@ void RL::CSVInit(std::string robot_name) csv_filename += ".csv"; std::ofstream file(csv_filename.c_str()); - for(int i = 0; i < 12; ++i) {file << "tau_cal_" << i << ",";} - for(int i = 0; i < 12; ++i) {file << "tau_est_" << i << ",";} - for(int i = 0; i < 12; ++i) {file << "joint_pos_" << i << ",";} - for(int i = 0; i < 12; ++i) {file << "joint_pos_target_" << i << ",";} - for(int i = 0; i < 12; ++i) {file << "joint_vel_" << i << ",";} + for(int i = 0; i < 12; ++i) { file << "tau_cal_" << i << ","; } + for(int i = 0; i < 12; ++i) { file << "tau_est_" << i << ","; } + for(int i = 0; i < 12; ++i) { file << "joint_pos_" << i << ","; } + for(int i = 0; i < 12; ++i) { file << "joint_pos_target_" << i << ","; } + for(int i = 0; i < 12; ++i) { file << "joint_vel_" << i << ","; } file << std::endl; @@ -455,13 +486,13 @@ void RL::CSVLogger(torch::Tensor torque, torch::Tensor tau_est, torch::Tensor jo { std::ofstream file(csv_filename.c_str(), std::ios_base::app); - for(int i = 0; i < 12; ++i) {file << torque[0][i].item() << ",";} - for(int i = 0; i < 12; ++i) {file << tau_est[0][i].item() << ",";} - for(int i = 0; i < 12; ++i) {file << joint_pos[0][i].item() << ",";} - for(int i = 0; i < 12; ++i) {file << joint_pos_target[0][i].item() << ",";} - for(int i = 0; i < 12; ++i) {file << joint_vel[0][i].item() << ",";} + for(int i = 0; i < 12; ++i) { file << torque[0][i].item() << ","; } + for(int i = 0; i < 12; ++i) { file << tau_est[0][i].item() << ","; } + for(int i = 0; i < 12; ++i) { file << joint_pos[0][i].item() << ","; } + for(int i = 0; i < 12; ++i) { file << joint_pos_target[0][i].item() << ","; } + for(int i = 0; i < 12; ++i) { file << joint_vel[0][i].item() << ","; } file << std::endl; file.close(); -} \ No newline at end of file +} diff --git a/src/rl_sar/library/rl_sdk/rl_sdk.hpp b/src/rl_sar/library/rl_sdk/rl_sdk.hpp index 703770b..597de22 100644 --- a/src/rl_sar/library/rl_sdk/rl_sdk.hpp +++ b/src/rl_sar/library/rl_sdk/rl_sdk.hpp @@ -8,14 +8,15 @@ #include -namespace LOGGER { - const char* const INFO = "\033[0;37m[INFO]\033[0m "; - const char* const WARNING = "\033[0;33m[WARNING]\033[0m "; - const char* const ERROR = "\033[0;31m[ERROR]\033[0m "; - const char* const DEBUG = "\033[0;32m[DEBUG]\033[0m "; +namespace LOGGER +{ + const char *const INFO = "\033[0;37m[INFO]\033[0m "; + const char *const WARNING = "\033[0;33m[WARNING]\033[0m "; + const char *const ERROR = "\033[0;31m[ERROR]\033[0m "; + const char *const DEBUG = "\033[0;32m[DEBUG]\033[0m "; } -template +template struct RobotCommand { struct MotorCommand @@ -28,7 +29,7 @@ struct RobotCommand } motor_command; }; -template +template struct RobotState { struct IMU @@ -48,7 +49,8 @@ struct RobotState } motor_state; }; -enum STATE { +enum STATE +{ STATE_WAITING = 0, STATE_POS_GETUP, STATE_RL_INIT, @@ -100,21 +102,21 @@ struct ModelParams struct Observations { - torch::Tensor lin_vel; - torch::Tensor ang_vel; - torch::Tensor gravity_vec; - torch::Tensor commands; - torch::Tensor base_quat; - torch::Tensor dof_pos; - torch::Tensor dof_vel; + torch::Tensor lin_vel; + torch::Tensor ang_vel; + torch::Tensor gravity_vec; + torch::Tensor commands; + torch::Tensor base_quat; + torch::Tensor dof_pos; + torch::Tensor dof_vel; torch::Tensor actions; }; class RL { public: - RL(){}; - ~RL(){}; + RL() {}; + ~RL() {}; ModelParams params; Observations obs; @@ -135,7 +137,7 @@ class RL void StateController(const RobotState *state, RobotCommand *command); torch::Tensor ComputeTorques(torch::Tensor actions); torch::Tensor ComputePosition(torch::Tensor actions); - torch::Tensor QuatRotateInverse(torch::Tensor q, torch::Tensor v, const std::string& framework); + torch::Tensor QuatRotateInverse(torch::Tensor q, torch::Tensor v, const std::string &framework); // yaml params void ReadYaml(std::string robot_name); @@ -165,4 +167,4 @@ class RL torch::Tensor output_dof_pos; }; -#endif \ No newline at end of file +#endif // RL_SDK_HPP diff --git a/src/rl_sar/package.xml b/src/rl_sar/package.xml index d665e58..ccb0b2a 100644 --- a/src/rl_sar/package.xml +++ b/src/rl_sar/package.xml @@ -8,7 +8,7 @@ TODO - catkin + catkin genmsg controller_manager joint_state_controller diff --git a/src/rl_sar/scripts/actuator_net.py b/src/rl_sar/scripts/actuator_net.py index 9c0acb7..b62b705 100644 --- a/src/rl_sar/scripts/actuator_net.py +++ b/src/rl_sar/scripts/actuator_net.py @@ -95,7 +95,7 @@ def load_data(data_path): for key in data_dict.keys(): data_dict[key] = np.array(data_dict[key]).T - + return data_dict, num_motors def process_data(data_dict, num_motors, step): @@ -122,7 +122,7 @@ def process_data(data_dict, num_motors, step): xs_joint = torch.cat(xs_joint, dim=1) xs.append(xs_joint) ys.append(tau_ests_joint) - + xs = torch.cat(xs, dim=0) ys = torch.cat(ys, dim=0) return xs, ys diff --git a/src/rl_sar/scripts/observation_buffer.py b/src/rl_sar/scripts/observation_buffer.py index ddb53b5..a1bb7a4 100644 --- a/src/rl_sar/scripts/observation_buffer.py +++ b/src/rl_sar/scripts/observation_buffer.py @@ -23,7 +23,7 @@ def insert(self, new_obs): def get_obs_vec(self, obs_ids): """Gets history of observations indexed by obs_ids. - + Arguments: obs_ids: An array of integers with which to index the desired observations, where 0 is the latest observation and diff --git a/src/rl_sar/scripts/rl_sdk.py b/src/rl_sar/scripts/rl_sdk.py index 938fc5e..dbebf1c 100644 --- a/src/rl_sar/scripts/rl_sdk.py +++ b/src/rl_sar/scripts/rl_sdk.py @@ -124,7 +124,7 @@ def __init__(self): self.robot_name = "" self.running_state = STATE.STATE_RL_RUNNING # default running_state set to STATE_RL_RUNNING self.simulation_running = False - + ### protected in cpp ### # rl module self.model = None @@ -156,7 +156,7 @@ def ComputeObservation(self): obs = torch.cat(obs_list, dim=-1) clamped_obs = torch.clamp(obs, -self.params.clip_obs, self.params.clip_obs) return clamped_obs - + def InitObservations(self): self.obs.lin_vel = torch.zeros(1, 3, dtype=torch.float) self.obs.ang_vel = torch.zeros(1, 3, dtype=torch.float) @@ -409,35 +409,35 @@ def ReadYaml(self, robot_name): def CSVInit(self, robot_name): self.csv_filename = os.path.join(BASE_PATH, "models", robot_name, 'motor') - + # Uncomment these lines if need timestamp for file name # now = datetime.now() # timestamp = now.strftime("%Y%m%d%H%M%S") # self.csv_filename += f"_{timestamp}" - + self.csv_filename += ".csv" - + with open(self.csv_filename, 'w', newline='') as file: writer = csv.writer(file) - + header = [] header += [f"tau_cal_{i}" for i in range(12)] header += [f"tau_est_{i}" for i in range(12)] header += [f"joint_pos_{i}" for i in range(12)] header += [f"joint_pos_target_{i}" for i in range(12)] header += [f"joint_vel_{i}" for i in range(12)] - + writer.writerow(header) def CSVLogger(self, torque, tau_est, joint_pos, joint_pos_target, joint_vel): with open(self.csv_filename, 'a', newline='') as file: writer = csv.writer(file) - + row = [] row += [torque[0][i].item() for i in range(12)] row += [tau_est[0][i].item() for i in range(12)] row += [joint_pos[0][i].item() for i in range(12)] row += [joint_pos_target[0][i].item() for i in range(12)] row += [joint_vel[0][i].item() for i in range(12)] - + writer.writerow(row) diff --git a/src/rl_sar/scripts/rl_sim.py b/src/rl_sar/scripts/rl_sim.py index 22803ac..660d665 100644 --- a/src/rl_sar/scripts/rl_sim.py +++ b/src/rl_sar/scripts/rl_sim.py @@ -234,4 +234,3 @@ def ThreadRL(self): if __name__ == "__main__": rl_sim = RL_Sim() rospy.spin() - \ No newline at end of file diff --git a/src/rl_sar/src/rl_real_a1.cpp b/src/rl_sar/src/rl_real_a1.cpp index a6db08a..ac7000c 100644 --- a/src/rl_sar/src/rl_real_a1.cpp +++ b/src/rl_sar/src/rl_real_a1.cpp @@ -28,10 +28,10 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u this->model = torch::jit::load(model_path); // loop - this->loop_udpSend = std::make_shared("loop_udpSend" , 0.002, std::bind(&RL_Real::UDPSend, this), 3); - this->loop_udpRecv = std::make_shared("loop_udpRecv" , 0.002, std::bind(&RL_Real::UDPRecv, this), 3); + this->loop_udpSend = std::make_shared("loop_udpSend", 0.002, std::bind(&RL_Real::UDPSend, this), 3); + this->loop_udpRecv = std::make_shared("loop_udpRecv", 0.002, std::bind(&RL_Real::UDPRecv, this), 3); this->loop_keyboard = std::make_shared("loop_keyboard", 0.05, std::bind(&RL_Real::KeyboardInterface, this)); - this->loop_control = std::make_shared("loop_control", this->params.dt, std::bind(&RL_Real::RobotControl, this)); + this->loop_control = std::make_shared("loop_control", this->params.dt, std::bind(&RL_Real::RobotControl, this)); this->loop_rl = std::make_shared("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Real::RunModel, this)); this->loop_udpSend->start(); this->loop_udpRecv->start(); @@ -39,14 +39,13 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u this->loop_control->start(); this->loop_rl->start(); - #ifdef PLOT this->plot_t = std::vector(this->plot_size, 0); this->plot_real_joint_pos.resize(this->params.num_of_dofs); this->plot_target_joint_pos.resize(this->params.num_of_dofs); - for(auto& vector : this->plot_real_joint_pos) { vector = std::vector(this->plot_size, 0); } - for(auto& vector : this->plot_target_joint_pos) { vector = std::vector(this->plot_size, 0); } - this->loop_plot = std::make_shared("loop_plot" , 0.002, std::bind(&RL_Real::Plot, this)); + for (auto &vector : this->plot_real_joint_pos) { vector = std::vector(this->plot_size, 0); } + for (auto &vector : this->plot_target_joint_pos) { vector = std::vector(this->plot_size, 0); } + this->loop_plot = std::make_shared("loop_plot", 0.002, std::bind(&RL_Real::Plot, this)); this->loop_plot->start(); #endif #ifdef CSV_LOGGER @@ -72,27 +71,27 @@ void RL_Real::GetState(RobotState *state) this->unitree_udp.GetRecv(this->unitree_low_state); memcpy(&this->unitree_joy, this->unitree_low_state.wirelessRemote, 40); - if((int)this->unitree_joy.btn.components.R2 == 1) + if ((int)this->unitree_joy.btn.components.R2 == 1) { this->control.control_state = STATE_POS_GETUP; } - else if((int)this->unitree_joy.btn.components.R1 == 1) + else if ((int)this->unitree_joy.btn.components.R1 == 1) { this->control.control_state = STATE_RL_INIT; } - else if((int)this->unitree_joy.btn.components.L2 == 1) + else if ((int)this->unitree_joy.btn.components.L2 == 1) { this->control.control_state = STATE_POS_GETDOWN; } - if(this->params.framework == "isaacgym") + if (this->params.framework == "isaacgym") { state->imu.quaternion[3] = this->unitree_low_state.imu.quaternion[0]; // w state->imu.quaternion[0] = this->unitree_low_state.imu.quaternion[1]; // x state->imu.quaternion[1] = this->unitree_low_state.imu.quaternion[2]; // y state->imu.quaternion[2] = this->unitree_low_state.imu.quaternion[3]; // z } - else if(this->params.framework == "isaacsim") + else if (this->params.framework == "isaacsim") { state->imu.quaternion[0] = this->unitree_low_state.imu.quaternion[0]; // w state->imu.quaternion[1] = this->unitree_low_state.imu.quaternion[1]; // x @@ -100,11 +99,11 @@ void RL_Real::GetState(RobotState *state) state->imu.quaternion[3] = this->unitree_low_state.imu.quaternion[3]; // z } - for(int i = 0; i < 3; ++i) + for (int i = 0; i < 3; ++i) { state->imu.gyroscope[i] = this->unitree_low_state.imu.gyroscope[i]; } - for(int i = 0; i < this->params.num_of_dofs; ++i) + for (int i = 0; i < this->params.num_of_dofs; ++i) { state->motor_state.q[i] = this->unitree_low_state.motorState[state_mapping[i]].q; state->motor_state.dq[i] = this->unitree_low_state.motorState[state_mapping[i]].dq; @@ -114,7 +113,7 @@ void RL_Real::GetState(RobotState *state) void RL_Real::SetCommand(const RobotCommand *command) { - for(int i = 0; i < this->params.num_of_dofs; ++i) + for (int i = 0; i < this->params.num_of_dofs; ++i) { this->unitree_low_command.motorCmd[i].mode = 0x0A; this->unitree_low_command.motorCmd[i].q = command->motor_command.q[command_mapping[i]]; @@ -140,7 +139,7 @@ void RL_Real::RobotControl() void RL_Real::RunModel() { - if(this->running_state == STATE_RL_RUNNING) + if (this->running_state == STATE_RL_RUNNING) { this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0); this->obs.commands = torch::tensor({{this->unitree_joy.ly, -this->unitree_joy.rx, -this->unitree_joy.lx}}); @@ -182,7 +181,7 @@ torch::Tensor RL_Real::Forward() torch::Tensor actions = this->model.forward({this->history_obs}).toTensor(); - if(this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0) + if (this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0) { return torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper); } @@ -198,7 +197,7 @@ void RL_Real::Plot() this->plot_t.push_back(this->motiontime); plt::cla(); plt::clf(); - for(int i = 0; i < this->params.num_of_dofs; ++i) + for (int i = 0; i < this->params.num_of_dofs; ++i) { this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin()); this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin()); @@ -222,10 +221,10 @@ int main(int argc, char **argv) { signal(SIGINT, signalHandler); - while(1) + while (1) { sleep(10); - }; + } return 0; } diff --git a/src/rl_sar/src/rl_sim.cpp b/src/rl_sar/src/rl_sim.cpp index 255cd68..a48b498 100644 --- a/src/rl_sar/src/rl_sim.cpp +++ b/src/rl_sar/src/rl_sim.cpp @@ -12,7 +12,7 @@ RL_Sim::RL_Sim() this->ReadYaml(this->robot_name); // history - if(this->params.use_history) + if (this->params.use_history) { this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); } @@ -21,7 +21,7 @@ RL_Sim::RL_Sim() // the mapping table is established according to the order defined in the YAML file std::vector sorted_joint_controller_names = this->params.joint_controller_names; std::sort(sorted_joint_controller_names.begin(), sorted_joint_controller_names.end()); - for(size_t i = 0; i < this->params.joint_controller_names.size(); ++i) + for (size_t i = 0; i < this->params.joint_controller_names.size(); ++i) { this->sorted_to_original_index[sorted_joint_controller_names[i]] = i; } @@ -46,8 +46,8 @@ RL_Sim::RL_Sim() for (int i = 0; i < this->params.num_of_dofs; ++i) { // joint need to rename as xxx_joint - this->joint_publishers[this->params.joint_controller_names[i]] = nh.advertise( - this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10); + this->joint_publishers[this->params.joint_controller_names[i]] = + nh.advertise(this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10); } // subscriber @@ -62,7 +62,7 @@ RL_Sim::RL_Sim() this->gazebo_unpause_physics_client = nh.serviceClient("/gazebo/unpause_physics"); // loop - this->loop_control = std::make_shared("loop_control", this->params.dt, std::bind(&RL_Sim::RobotControl, this)); + this->loop_control = std::make_shared("loop_control", this->params.dt, std::bind(&RL_Sim::RobotControl, this)); this->loop_rl = std::make_shared("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Sim::RunModel, this)); this->loop_control->start(); this->loop_rl->start(); @@ -75,9 +75,9 @@ RL_Sim::RL_Sim() this->plot_t = std::vector(this->plot_size, 0); this->plot_real_joint_pos.resize(this->params.num_of_dofs); this->plot_target_joint_pos.resize(this->params.num_of_dofs); - for(auto& vector : this->plot_real_joint_pos) { vector = std::vector(this->plot_size, 0); } - for(auto& vector : this->plot_target_joint_pos) { vector = std::vector(this->plot_size, 0); } - this->loop_plot = std::make_shared("loop_plot" , 0.001 , std::bind(&RL_Sim::Plot , this)); + for (auto &vector : this->plot_real_joint_pos) { vector = std::vector(this->plot_size, 0); } + for (auto &vector : this->plot_target_joint_pos) { vector = std::vector(this->plot_size, 0); } + this->loop_plot = std::make_shared("loop_plot", 0.001, std::bind(&RL_Sim::Plot, this)); this->loop_plot->start(); #endif #ifdef CSV_LOGGER @@ -100,14 +100,14 @@ RL_Sim::~RL_Sim() void RL_Sim::GetState(RobotState *state) { - if(this->params.framework == "isaacgym") + if (this->params.framework == "isaacgym") { state->imu.quaternion[3] = this->pose.orientation.w; state->imu.quaternion[0] = this->pose.orientation.x; state->imu.quaternion[1] = this->pose.orientation.y; state->imu.quaternion[2] = this->pose.orientation.z; } - else if(this->params.framework == "isaacsim") + else if (this->params.framework == "isaacsim") { state->imu.quaternion[0] = this->pose.orientation.w; state->imu.quaternion[1] = this->pose.orientation.x; @@ -121,7 +121,7 @@ void RL_Sim::GetState(RobotState *state) // state->imu.accelerometer - for(int i = 0; i < this->params.num_of_dofs; ++i) + for (int i = 0; i < this->params.num_of_dofs; ++i) { state->motor_state.q[i] = this->mapped_joint_positions[i]; state->motor_state.dq[i] = this->mapped_joint_velocities[i]; @@ -131,7 +131,7 @@ void RL_Sim::GetState(RobotState *state) void RL_Sim::SetCommand(const RobotCommand *command) { - for(int i = 0; i < this->params.num_of_dofs; ++i) + for (int i = 0; i < this->params.num_of_dofs; ++i) { this->joint_publishers_commands[i].q = command->motor_command.q[i]; this->joint_publishers_commands[i].dq = command->motor_command.dq[i]; @@ -140,7 +140,7 @@ void RL_Sim::SetCommand(const RobotCommand *command) this->joint_publishers_commands[i].tau = command->motor_command.tau[i]; } - for(int i = 0; i < this->params.num_of_dofs; ++i) + for (int i = 0; i < this->params.num_of_dofs; ++i) { this->joint_publishers[this->params.joint_controller_names[i]].publish(this->joint_publishers_commands[i]); } @@ -148,7 +148,7 @@ void RL_Sim::SetCommand(const RobotCommand *command) void RL_Sim::RobotControl() { - if(this->control.control_state == STATE_RESET_SIMULATION) + if (this->control.control_state == STATE_RESET_SIMULATION) { gazebo_msgs::SetModelState set_model_state; set_model_state.request.model_state.model_name = this->gazebo_model_name; @@ -158,10 +158,10 @@ void RL_Sim::RobotControl() this->control.control_state = STATE_WAITING; } - if(this->control.control_state == STATE_TOGGLE_SIMULATION) + if (this->control.control_state == STATE_TOGGLE_SIMULATION) { std_srvs::Empty empty; - if(simulation_running) + if (simulation_running) { this->gazebo_pause_physics_client.call(empty); std::cout << std::endl << LOGGER::INFO << "Simulation Stop" << std::endl; @@ -174,7 +174,7 @@ void RL_Sim::RobotControl() simulation_running = !simulation_running; this->control.control_state = STATE_WAITING; } - if(simulation_running) + if (simulation_running) { this->motiontime++; this->GetState(&this->robot_state); @@ -194,9 +194,9 @@ void RL_Sim::CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg) this->cmd_vel = *msg; } -void RL_Sim::MapData(const std::vector& source_data, std::vector& target_data) +void RL_Sim::MapData(const std::vector &source_data, std::vector &target_data) { - for(size_t i = 0; i < source_data.size(); ++i) + for (size_t i = 0; i < source_data.size(); ++i) { target_data[i] = source_data[this->sorted_to_original_index[this->params.joint_controller_names[i]]]; } @@ -211,7 +211,7 @@ void RL_Sim::JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg) void RL_Sim::RunModel() { - if(this->running_state == STATE_RL_RUNNING && simulation_running) + if (this->running_state == STATE_RL_RUNNING && simulation_running) { this->obs.lin_vel = torch::tensor({{this->vel.linear.x, this->vel.linear.y, this->vel.linear.z}}); this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0); @@ -249,7 +249,7 @@ torch::Tensor RL_Sim::Forward() torch::autograd::GradMode::set_enabled(false); torch::Tensor clamped_obs = this->ComputeObservation(); torch::Tensor actions; - if(this->params.use_history) + if (this->params.use_history) { this->history_obs_buf.insert(clamped_obs); this->history_obs = this->history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5}); @@ -260,7 +260,7 @@ torch::Tensor RL_Sim::Forward() actions = this->model.forward({clamped_obs}).toTensor(); } - if(this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0) + if (this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0) { return torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper); } @@ -276,13 +276,13 @@ void RL_Sim::Plot() this->plot_t.push_back(this->motiontime); plt::cla(); plt::clf(); - for(int i = 0; i < this->params.num_of_dofs; ++i) + for (int i = 0; i < this->params.num_of_dofs; ++i) { this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin()); this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin()); this->plot_real_joint_pos[i].push_back(this->mapped_joint_positions[i]); this->plot_target_joint_pos[i].push_back(this->joint_publishers_commands[i].q); - plt::subplot(4, 3, i+1); + plt::subplot(4, 3, i + 1); plt::named_plot("_real_joint_pos", this->plot_t, this->plot_real_joint_pos[i], "r"); plt::named_plot("_target_joint_pos", this->plot_t, this->plot_target_joint_pos[i], "b"); plt::xlim(this->plot_t.front(), this->plot_t.back()); diff --git a/src/rl_sar/worlds/stairs.world b/src/rl_sar/worlds/stairs.world index 615ee47..faf0f47 100644 --- a/src/rl_sar/worlds/stairs.world +++ b/src/rl_sar/worlds/stairs.world @@ -9,16 +9,16 @@ 0 0 -9.81 - quick - 50 + quick + 50 1.3 - + 0.0 0.2 10.0 0.001 - + diff --git a/src/robot_joint_controller/include/robot_joint_controller.h b/src/robot_joint_controller/include/robot_joint_controller.h index 66b4151..a58e4fa 100644 --- a/src/robot_joint_controller/include/robot_joint_controller.h +++ b/src/robot_joint_controller/include/robot_joint_controller.h @@ -20,10 +20,10 @@ #include #include -#define PosStopF (2.146E+9f) -#define VelStopF (16000.0f) +#define PosStopF (2.146E+9f) +#define VelStopF (16000.0f) -typedef struct +typedef struct { uint8_t mode; double pos; @@ -35,15 +35,15 @@ typedef struct namespace robot_joint_controller { - class RobotJointController: public controller_interface::Controller + class RobotJointController : public controller_interface::Controller { -private: + private: hardware_interface::JointHandle joint; ros::Subscriber sub_command, sub_ft; control_toolbox::Pid pid_controller_; - std::unique_ptr > controller_state_publisher_ ; + std::unique_ptr> controller_state_publisher_; -public: + public: std::string name_space; std::string joint_name; urdf::JointConstSharedPtr joint_urdf; @@ -55,10 +55,10 @@ namespace robot_joint_controller RobotJointController(); ~RobotJointController(); virtual bool init(hardware_interface::EffortJointInterface *robot, ros::NodeHandle &n); - virtual void starting(const ros::Time& time); - virtual void update(const ros::Time& time, const ros::Duration& period); + virtual void starting(const ros::Time &time); + virtual void update(const ros::Time &time, const ros::Duration &period); virtual void stopping(); - void setCommandCB(const robot_msgs::MotorCommandConstPtr& msg); + void setCommandCB(const robot_msgs::MotorCommandConstPtr &msg); void positionLimits(double &position); void velocityLimits(double &velocity); void effortLimits(double &effort); @@ -66,7 +66,6 @@ namespace robot_joint_controller void setGains(const double &p, const double &i, const double &d, const double &i_max, const double &i_min, const bool &antiwindup = false); void getGains(double &p, double &i, double &d, double &i_max, double &i_min, bool &antiwindup); void getGains(double &p, double &i, double &d, double &i_max, double &i_min); - }; } diff --git a/src/robot_joint_controller/src/robot_joint_controller.cpp b/src/robot_joint_controller/src/robot_joint_controller.cpp index bc287b2..41d652c 100644 --- a/src/robot_joint_controller/src/robot_joint_controller.cpp +++ b/src/robot_joint_controller/src/robot_joint_controller.cpp @@ -3,30 +3,36 @@ // #define rqtTune // use rqt or not -double clamp(double& value, double min, double max) { - if (value < min) { +double clamp(double &value, double min, double max) +{ + if (value < min) + { value = min; - } else if (value > max) { + } + else if (value > max) + { value = max; } return value; } -namespace robot_joint_controller +namespace robot_joint_controller { - RobotJointController::RobotJointController(){ + RobotJointController::RobotJointController() + { memset(&lastCommand, 0, sizeof(robot_msgs::MotorCommand)); memset(&lastState, 0, sizeof(robot_msgs::MotorState)); memset(&servoCommand, 0, sizeof(ServoCommand)); } - RobotJointController::~RobotJointController(){ + RobotJointController::~RobotJointController() + { sub_ft.shutdown(); sub_command.shutdown(); } - void RobotJointController::setCommandCB(const robot_msgs::MotorCommandConstPtr& msg) + void RobotJointController::setCommandCB(const robot_msgs::MotorCommandConstPtr &msg) { lastCommand.q = msg->q; lastCommand.kp = msg->kp; @@ -43,28 +49,31 @@ namespace robot_joint_controller bool RobotJointController::init(hardware_interface::EffortJointInterface *robot, ros::NodeHandle &n) { name_space = n.getNamespace(); - if (!n.getParam("joint", joint_name)){ + if (!n.getParam("joint", joint_name)) + { ROS_ERROR("No joint given in namespace: '%s')", n.getNamespace().c_str()); return false; } - - // load pid param from ymal only if rqt need + + // load pid param from ymal only if rqt need #ifdef rqtTune - // Load PID Controller using gains set on parameter server - if (!pid_controller_.init(ros::NodeHandle(n, "pid"))) - return false; + // Load PID Controller using gains set on parameter server + if (!pid_controller_.init(ros::NodeHandle(n, "pid"))) + return false; #endif urdf::Model urdf; // Get URDF info about joint - if (!urdf.initParamWithNodeHandle("robot_description", n)){ + if (!urdf.initParamWithNodeHandle("robot_description", n)) + { ROS_ERROR("Failed to parse urdf file"); return false; } joint_urdf = urdf.getJoint(joint_name); - if (!joint_urdf){ + if (!joint_urdf) + { ROS_ERROR("Could not find joint '%s' in urdf", joint_name.c_str()); return false; - } + } joint = robot->getHandle(joint_name); // Start command subscriber @@ -72,29 +81,29 @@ namespace robot_joint_controller // Start realtime state publisher controller_state_publisher_.reset( - new realtime_tools::RealtimePublisher(n, name_space + "/state", 1)); + new realtime_tools::RealtimePublisher(n, name_space + "/state", 1)); return true; } void RobotJointController::setGains(const double &p, const double &i, const double &d, const double &i_max, const double &i_min, const bool &antiwindup) { - pid_controller_.setGains(p,i,d,i_max,i_min,antiwindup); + pid_controller_.setGains(p, i, d, i_max, i_min, antiwindup); } void RobotJointController::getGains(double &p, double &i, double &d, double &i_max, double &i_min, bool &antiwindup) { - pid_controller_.getGains(p,i,d,i_max,i_min,antiwindup); + pid_controller_.getGains(p, i, d, i_max, i_min, antiwindup); } void RobotJointController::getGains(double &p, double &i, double &d, double &i_max, double &i_min) { bool dummy; - pid_controller_.getGains(p,i,d,i_max,i_min,dummy); + pid_controller_.getGains(p, i, d, i_max, i_min, dummy); } // Controller startup in realtime - void RobotJointController::starting(const ros::Time& time) + void RobotJointController::starting(const ros::Time &time) { double init_pos = joint.getPosition(); lastCommand.q = init_pos; @@ -109,7 +118,7 @@ namespace robot_joint_controller } // Controller update loop in realtime - void RobotJointController::update(const ros::Time& time, const ros::Duration& period) + void RobotJointController::update(const ros::Time &time, const ros::Duration &period) { double currentPos, currentVel, calcTorque; lastCommand = *(command.readFromRT()); @@ -118,27 +127,29 @@ namespace robot_joint_controller servoCommand.pos = lastCommand.q; positionLimits(servoCommand.pos); servoCommand.posStiffness = lastCommand.kp; - if(fabs(lastCommand.q - PosStopF) < 0.00001){ + if (fabs(lastCommand.q - PosStopF) < 0.00001) + { servoCommand.posStiffness = 0; } servoCommand.vel = lastCommand.dq; velocityLimits(servoCommand.vel); servoCommand.velStiffness = lastCommand.kd; - if(fabs(lastCommand.dq - VelStopF) < 0.00001){ + if (fabs(lastCommand.dq - VelStopF) < 0.00001) + { servoCommand.velStiffness = 0; } servoCommand.torque = lastCommand.tau; effortLimits(servoCommand.torque); - + // rqt set P D gains #ifdef rqtTune - double i, i_max, i_min; - getGains(servoCommand.posStiffness,i,servoCommand.velStiffness,i_max,i_min); + double i, i_max, i_min; + getGains(servoCommand.posStiffness, i, servoCommand.velStiffness, i_max, i_min); #endif currentPos = joint.getPosition(); // currentVel = computeVel(currentPos, (double)lastState.q, (double)lastState.dq, period.toSec()); - // calcTorque = computeTorque(currentPos, currentVel, servoCommand); + // calcTorque = computeTorque(currentPos, currentVel, servoCommand); currentVel = (currentPos - (double)lastState.q) / period.toSec(); calcTorque = servoCommand.posStiffness * (servoCommand.pos - currentPos) + servoCommand.velStiffness * (servoCommand.vel - currentVel) + servoCommand.torque; effortLimits(calcTorque); @@ -151,7 +162,8 @@ namespace robot_joint_controller lastState.tauEst = joint.getEffort(); // publish state - if (controller_state_publisher_ && controller_state_publisher_->trylock()) { + if (controller_state_publisher_ && controller_state_publisher_->trylock()) + { controller_state_publisher_->msg_.q = lastState.q; controller_state_publisher_->msg_.dq = lastState.dq; controller_state_publisher_->msg_.tauEst = lastState.tauEst; @@ -160,7 +172,7 @@ namespace robot_joint_controller } // Controller stopping in realtime - void RobotJointController::stopping(){} + void RobotJointController::stopping() {} void RobotJointController::positionLimits(double &position) {