diff --git a/src/rl_sar/config.yaml b/src/rl_sar/config.yaml index 6d4dffe..5a43f3a 100644 --- a/src/rl_sar/config.yaml +++ b/src/rl_sar/config.yaml @@ -1,5 +1,7 @@ a1: model_name: "model_0526.pt" + dt: 0.005 + decimation: 4 num_observations: 45 clip_obs: 100.0 clip_actions_lower: [-100, -100, -100, @@ -50,6 +52,8 @@ a1: gr1t1: model_name: "model_4000_jit.pt" + dt: 0.001 + decimation: 20 num_observations: 39 clip_obs: 100.0 clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991, @@ -64,8 +68,8 @@ gr1t1: 57.0, 43.0, 114.0, 114.0, 15.3] fixed_kd: [5.7, 4.3, 11.4, 11.4, 1.5, 5.7, 4.3, 11.4, 11.4, 1.5] - hip_scale_reduction: 0.5 - hip_scale_reduction_indices: [0, 3, 6, 9] + hip_scale_reduction: 1.0 + hip_scale_reduction_indices: [] num_of_dofs: 10 action_scale: 1.0 lin_vel_scale: 1.0 diff --git a/src/rl_sar/include/rl_sim.hpp b/src/rl_sar/include/rl_sim.hpp index be5ce13..d39298c 100644 --- a/src/rl_sar/include/rl_sim.hpp +++ b/src/rl_sar/include/rl_sim.hpp @@ -11,6 +11,7 @@ #include #include "robot_msgs/MotorCommand.h" #include +#include #include "matplotlibcpp.h" namespace plt = matplotlibcpp; @@ -54,7 +55,7 @@ class RL_Sim : public RL ros::Subscriber model_state_subscriber; ros::Subscriber joint_state_subscriber; ros::Subscriber cmd_vel_subscriber; - ros::ServiceClient gazebo_reset_client; + ros::ServiceClient gazebo_set_model_state_client; std::map joint_publishers; std::vector joint_publishers_commands; void ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg); diff --git a/src/rl_sar/library/loop/loop.hpp b/src/rl_sar/library/loop/loop.hpp index 5c68b27..e36410c 100644 --- a/src/rl_sar/library/loop/loop.hpp +++ b/src/rl_sar/library/loop/loop.hpp @@ -3,82 +3,109 @@ #include #include -#include #include #include +#include +#include +#include +#include +#include +#include -typedef std::function Callback; +class LoopFunc +{ + private: + std::string _name; + double _period; + std::function _func; + int _bindCPU; + std::atomic _running; + std::mutex _mutex; + std::condition_variable _cv; + std::thread _thread; -class Loop { - public: - Loop(std::string name, float period, int bindCPU = -1) - : _name(name), _period(period), _bindCPU(bindCPU) {} - ~Loop() { - if (_isrunning) { - shutdown(); // Ensure the loop is stopped when the object is destroyed - } - } + public: + LoopFunc(const std::string &name, double period, std::function func, int bindCPU = -1) + : _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {} - void start() { - _isrunning = true; - _thread = std::thread([this]() { - if (_bindCPU >= 0) { - std::lock_guard lock(_printMutex); - std::cout << "[Loop Start] named: " << _name << ", period: " << _period * 1000 << " (ms), run at cpu: " << _bindCPU << std::endl; - } else { - std::lock_guard lock(_printMutex); - std::cout << "[Loop Start] named: " << _name << ", period: " << _period * 1000 << " (ms), cpu unspecified" << std::endl; + void start() + { + _running = true; + log("[Loop Start] named: " + _name + ", period: " + formatPeriod() + "(ms)" + (_bindCPU != -1 ? ", run at cpu: " + std::to_string(_bindCPU) : ", cpu unspecified")); + if (_bindCPU != -1) + { + _thread = std::thread(&LoopFunc::loop, this); + setThreadAffinity(_thread.native_handle(), _bindCPU); } - entryFunc(); - }); // Start the loop in a new thread - } - - void shutdown() { - _isrunning = false; - if (_thread.joinable()) { - _thread.join(); // Wait for the loop thread to finish - std::lock_guard lock(_printMutex); - std::cout << "[Loop End] named: " << _name << std::endl; - } - } - - virtual void functionCB() = 0; - - protected: - void entryFunc() { - while (_isrunning) { - functionCB(); // Call the overridden functionCB in a loop - std::this_thread::sleep_for(std::chrono::duration(_period)); // Wait for the specified period + else + { + _thread = std::thread(&LoopFunc::loop, this); + } + _thread.detach(); } - } - std::string _name; - float _period; - int _bindCPU; - bool _isrunning = false; - std::thread _thread; - static std::mutex _printMutex; -}; + void shutdown() + { + { + std::unique_lock lock(_mutex); + _running = false; + _cv.notify_one(); + } + if (_thread.joinable()) + { + _thread.join(); + } + log("[Loop End] named: " + _name); + } + private: + void loop() + { + while (_running) + { + auto start = std::chrono::steady_clock::now(); -std::mutex Loop::_printMutex; + _func(); -class LoopFunc : public Loop { - public: - LoopFunc(std::string name, float period, const Callback& cb) - : Loop(name, period), _fp(cb) {} + auto end = std::chrono::steady_clock::now(); + auto elapsed = std::chrono::duration_cast(end - start); + auto sleepTime = std::chrono::milliseconds(static_cast((_period * 1000) - elapsed.count())); + if (sleepTime.count() > 0) + { + std::unique_lock lock(_mutex); + if (_cv.wait_for(lock, sleepTime, [this]{ return !_running; })) + { + break; + } + } + } + } - LoopFunc(std::string name, float period, int bindCPU, const Callback& cb) - : Loop(name, period, bindCPU), _fp(cb) {} + std::string formatPeriod() const + { + std::ostringstream stream; + stream << std::fixed << std::setprecision(0) << _period * 1000; + return stream.str(); + } - void functionCB() override { + void log(const std::string& message) { - std::lock_guard lock(_printMutex); - (_fp)(); // Call the provided callback function + static std::mutex logMutex; + std::lock_guard lock(logMutex); + std::cout << message << std::endl; } - } - private: - Callback _fp; + void setThreadAffinity(std::thread::native_handle_type threadHandle, int cpuId) + { + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(cpuId, &cpuset); + if (pthread_setaffinity_np(threadHandle, sizeof(cpu_set_t), &cpuset) != 0) + { + std::ostringstream oss; + oss << "Error setting thread affinity: CPU " << cpuId << " may not be valid or accessible."; + throw std::runtime_error(oss.str()); + } + } }; -#endif +#endif \ No newline at end of file diff --git a/src/rl_sar/library/rl_sdk/rl_sdk.cpp b/src/rl_sar/library/rl_sdk/rl_sdk.cpp index 89dfc25..0841a34 100644 --- a/src/rl_sar/library/rl_sdk/rl_sdk.cpp +++ b/src/rl_sar/library/rl_sdk/rl_sdk.cpp @@ -104,6 +104,7 @@ void RL::StateController(const RobotState *state, RobotCommand * start_state.motor_state.q[i] = now_state.motor_state.q[i]; } this->running_state = STATE_POS_GETUP; + std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETUP" << std::endl; } } // stand up (position control) @@ -111,7 +112,7 @@ void RL::StateController(const RobotState *state, RobotCommand * { if(getup_percent < 1.0) { - getup_percent += 1 / 1000.0; + getup_percent += 1 / 500.0; getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent; for(int i = 0; i < this->params.num_of_dofs; ++i) { @@ -125,9 +126,9 @@ void RL::StateController(const RobotState *state, RobotCommand * } if(this->control.control_state == STATE_RL_INIT) { - std::cout << std::endl; this->control.control_state = STATE_WAITING; this->running_state = STATE_RL_INIT; + std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_INIT" << std::endl; } else if(this->control.control_state == STATE_POS_GETDOWN) { @@ -138,6 +139,7 @@ void RL::StateController(const RobotState *state, RobotCommand * now_state.motor_state.q[i] = state->motor_state.q[i]; } this->running_state = STATE_POS_GETDOWN; + std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl; } } // init obs and start rl loop @@ -145,15 +147,18 @@ void RL::StateController(const RobotState *state, RobotCommand * { if(getup_percent == 1) { - this->running_state = STATE_RL_RUNNING; this->InitObservations(); this->InitOutputs(); this->InitControl(); + this->running_state = STATE_RL_RUNNING; + std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_RUNNING" << std::endl; } } // rl loop else if(this->running_state == STATE_RL_RUNNING) { + std::cout << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << " \r"; + for(int i = 0; i < this->params.num_of_dofs; ++i) { command->motor_command.q[i] = this->output_dof_pos[0][i].item(); @@ -171,6 +176,18 @@ void RL::StateController(const RobotState *state, RobotCommand * now_state.motor_state.q[i] = state->motor_state.q[i]; } this->running_state = STATE_POS_GETDOWN; + std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl; + } + else if(this->control.control_state == STATE_POS_GETUP) + { + this->control.control_state = STATE_WAITING; + getup_percent = 0.0; + for(int i = 0; i < this->params.num_of_dofs; ++i) + { + now_state.motor_state.q[i] = state->motor_state.q[i]; + } + this->running_state = STATE_POS_GETUP; + std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETUP" << std::endl; } } // get down (position control) @@ -178,7 +195,7 @@ void RL::StateController(const RobotState *state, RobotCommand * { if(getdown_percent < 1.0) { - getdown_percent += 1 / 1000.0; + getdown_percent += 1 / 500.0; getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent; for(int i = 0; i < this->params.num_of_dofs; ++i) { @@ -192,11 +209,11 @@ void RL::StateController(const RobotState *state, RobotCommand * } if(getdown_percent == 1) { - std::cout << std::endl; - this->running_state = STATE_WAITING; this->InitObservations(); this->InitOutputs(); this->InitControl(); + this->running_state = STATE_WAITING; + std::cout << std::endl << LOGGER::INFO << "Switching to STATE_WAITING" << std::endl; } } } @@ -226,10 +243,11 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques) double limit_lower = -this->params.torque_limits[0][index].item(); double limit_upper = this->params.torque_limits[0][index].item(); - std::cout << LOGGER::ERROR << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl; - std::cout << LOGGER::ERROR << "Switching to STATE_POS_GETDOWN"<< std::endl; + std::cout << LOGGER::WARNING << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl; } - this->control.control_state = STATE_POS_GETDOWN; + // Just a reminder, no protection + // this->control.control_state = STATE_POS_GETDOWN; + // std::cout << LOGGER::INFO << "Switching to STATE_POS_GETDOWN"<< std::endl; } } @@ -254,11 +272,6 @@ static bool kbhit() void RL::KeyboardInterface() { - if(this->running_state == STATE_RL_RUNNING) - { - std::cout << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << " \r"; - } - if(kbhit()) { int c = fgetc(stdin); @@ -308,6 +321,8 @@ void RL::ReadYaml(std::string robot_name) } this->params.model_name = config["model_name"].as(); + this->params.dt = config["dt"].as(); + this->params.decimation = config["decimation"].as(); this->params.num_observations = config["num_observations"].as(); this->params.clip_obs = config["clip_obs"].as(); this->params.clip_actions_upper = torch::tensor(ReadVectorFromYaml(config["clip_actions_upper"])).view({1, -1}); diff --git a/src/rl_sar/library/rl_sdk/rl_sdk.hpp b/src/rl_sar/library/rl_sdk/rl_sdk.hpp index 72e3776..2067265 100644 --- a/src/rl_sar/library/rl_sdk/rl_sdk.hpp +++ b/src/rl_sar/library/rl_sdk/rl_sdk.hpp @@ -69,6 +69,8 @@ struct Control struct ModelParams { std::string model_name; + double dt; + int decimation; int num_observations; double damping; double stiffness; diff --git a/src/rl_sar/src/rl_real_a1.cpp b/src/rl_sar/src/rl_real_a1.cpp index 2b46b2f..7b1718f 100644 --- a/src/rl_sar/src/rl_real_a1.cpp +++ b/src/rl_sar/src/rl_real_a1.cpp @@ -27,16 +27,17 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u this->model = torch::jit::load(model_path); // loop - this->loop_keyboard = std::make_shared("loop_keyboard", 0.05 , std::bind(&RL_Real::KeyboardInterface, this)); - this->loop_control = std::make_shared("loop_control" , 0.002, std::bind(&RL_Real::RobotControl , this)); - this->loop_udpSend = std::make_shared("loop_udpSend" , 0.002, 3, std::bind(&RL_Real::UDPSend , this)); - this->loop_udpRecv = std::make_shared("loop_udpRecv" , 0.002, 3, std::bind(&RL_Real::UDPRecv , this)); - this->loop_rl = std::make_shared("loop_rl" , 0.02 , std::bind(&RL_Real::RunModel , this)); - this->loop_keyboard->start(); + this->loop_udpSend = std::make_shared("loop_udpSend" , 0.002, std::bind(&RL_Real::UDPSend, this), 3); + this->loop_udpRecv = std::make_shared("loop_udpRecv" , 0.002, std::bind(&RL_Real::UDPRecv, this), 3); + this->loop_keyboard = std::make_shared("loop_keyboard", 0.05, std::bind(&RL_Real::KeyboardInterface, this)); + this->loop_control = std::make_shared("loop_control", this->params.dt, std::bind(&RL_Real::RobotControl, this)); + this->loop_rl = std::make_shared("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Real::RunModel, this)); this->loop_udpSend->start(); this->loop_udpRecv->start(); + this->loop_keyboard->start(); this->loop_control->start(); this->loop_rl->start(); + #ifdef PLOT this->plot_t = std::vector(this->plot_size, 0); @@ -54,9 +55,9 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u RL_Real::~RL_Real() { - this->loop_keyboard->shutdown(); this->loop_udpSend->shutdown(); this->loop_udpRecv->shutdown(); + this->loop_keyboard->shutdown(); this->loop_control->shutdown(); this->loop_rl->shutdown(); #ifdef PLOT diff --git a/src/rl_sar/src/rl_sim.cpp b/src/rl_sar/src/rl_sim.cpp index 56334f0..aab1ea3 100644 --- a/src/rl_sar/src/rl_sim.cpp +++ b/src/rl_sar/src/rl_sim.cpp @@ -56,23 +56,23 @@ RL_Sim::RL_Sim() this->joint_state_subscriber = nh.subscribe(this->ros_namespace + "joint_states", 10, &RL_Sim::JointStatesCallback, this); // service - this->gazebo_reset_client = nh.serviceClient("/gazebo/reset_simulation"); + this->gazebo_set_model_state_client = nh.serviceClient("/gazebo/set_model_state"); // loop - this->loop_keyboard = std::make_shared("loop_keyboard", 0.05 , std::bind(&RL_Sim::KeyboardInterface, this)); - this->loop_control = std::make_shared("loop_control" , 0.002, std::bind(&RL_Sim::RobotControl , this)); - this->loop_rl = std::make_shared("loop_rl" , 0.02 , std::bind(&RL_Sim::RunModel , this)); + this->loop_keyboard = std::make_shared("loop_keyboard", 0.05, std::bind(&RL_Sim::KeyboardInterface, this)); + this->loop_control = std::make_shared("loop_control", this->params.dt, std::bind(&RL_Sim::RobotControl, this)); + this->loop_rl = std::make_shared("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Sim::RunModel, this)); this->loop_keyboard->start(); this->loop_control->start(); this->loop_rl->start(); #ifdef PLOT - plot_t = std::vector(this->plot_size, 0); + this->plot_t = std::vector(this->plot_size, 0); this->plot_real_joint_pos.resize(this->params.num_of_dofs); this->plot_target_joint_pos.resize(this->params.num_of_dofs); for(auto& vector : this->plot_real_joint_pos) { vector = std::vector(this->plot_size, 0); } for(auto& vector : this->plot_target_joint_pos) { vector = std::vector(this->plot_size, 0); } - this->loop_plot = std::make_shared("loop_plot" , 0.002, std::bind(&RL_Sim::Plot, this)); + this->loop_plot = std::make_shared("loop_plot" , 0.001 , std::bind(&RL_Sim::Plot , this)); this->loop_plot->start(); #endif #ifdef CSV_LOGGER @@ -135,9 +135,14 @@ void RL_Sim::RobotControl() if(this->control.control_state == STATE_RESET_SIMULATION) { + gazebo_msgs::SetModelState set_model_state; + std::string gazebo_model_name = this->robot_name + "_gazebo"; + set_model_state.request.model_state.model_name = gazebo_model_name; + set_model_state.request.model_state.pose.position.z = 1.0; + set_model_state.request.model_state.reference_frame = "world"; + this->gazebo_set_model_state_client.call(set_model_state); + this->control.control_state = STATE_WAITING; - std_srvs::Empty srv; - this->gazebo_reset_client.call(srv); } this->GetState(&this->robot_state); @@ -263,7 +268,7 @@ void RL_Sim::Plot() plt::xlim(this->plot_t.front(), this->plot_t.back()); } // plt::legend(); - plt::pause(0.0001); + plt::pause(0.01); } void signalHandler(int signum)