Skip to content

Commit

Permalink
feat: humanoid works
Browse files Browse the repository at this point in the history
  • Loading branch information
fan-ziqi committed May 30, 2024
1 parent 5371380 commit e3dfb68
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 95 deletions.
8 changes: 6 additions & 2 deletions src/rl_sar/config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
a1:
model_name: "model_0526.pt"
dt: 0.005
decimation: 4
num_observations: 45
clip_obs: 100.0
clip_actions_lower: [-100, -100, -100,
Expand Down Expand Up @@ -50,6 +52,8 @@ a1:

gr1t1:
model_name: "model_4000_jit.pt"
dt: 0.001
decimation: 20
num_observations: 39
clip_obs: 100.0
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
Expand All @@ -64,8 +68,8 @@ gr1t1:
57.0, 43.0, 114.0, 114.0, 15.3]
fixed_kd: [5.7, 4.3, 11.4, 11.4, 1.5,
5.7, 4.3, 11.4, 11.4, 1.5]
hip_scale_reduction: 0.5
hip_scale_reduction_indices: [0, 3, 6, 9]
hip_scale_reduction: 1.0
hip_scale_reduction_indices: []
num_of_dofs: 10
action_scale: 1.0
lin_vel_scale: 1.0
Expand Down
3 changes: 2 additions & 1 deletion src/rl_sar/include/rl_sim.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <geometry_msgs/Twist.h>
#include "robot_msgs/MotorCommand.h"
#include <csignal>
#include <gazebo_msgs/SetModelState.h>

#include "matplotlibcpp.h"
namespace plt = matplotlibcpp;
Expand Down Expand Up @@ -54,7 +55,7 @@ class RL_Sim : public RL
ros::Subscriber model_state_subscriber;
ros::Subscriber joint_state_subscriber;
ros::Subscriber cmd_vel_subscriber;
ros::ServiceClient gazebo_reset_client;
ros::ServiceClient gazebo_set_model_state_client;
std::map<std::string, ros::Publisher> joint_publishers;
std::vector<robot_msgs::MotorCommand> joint_publishers_commands;
void ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
Expand Down
151 changes: 89 additions & 62 deletions src/rl_sar/library/loop/loop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,82 +3,109 @@

#include <iostream>
#include <thread>
#include <mutex>
#include <chrono>
#include <functional>
#include <mutex>
#include <condition_variable>
#include <atomic>
#include <vector>
#include <sstream>
#include <iomanip>

typedef std::function<void()> Callback;
class LoopFunc
{
private:
std::string _name;
double _period;
std::function<void()> _func;
int _bindCPU;
std::atomic<bool> _running;
std::mutex _mutex;
std::condition_variable _cv;
std::thread _thread;

class Loop {
public:
Loop(std::string name, float period, int bindCPU = -1)
: _name(name), _period(period), _bindCPU(bindCPU) {}
~Loop() {
if (_isrunning) {
shutdown(); // Ensure the loop is stopped when the object is destroyed
}
}
public:
LoopFunc(const std::string &name, double period, std::function<void()> func, int bindCPU = -1)
: _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {}

void start() {
_isrunning = true;
_thread = std::thread([this]() {
if (_bindCPU >= 0) {
std::lock_guard<std::mutex> lock(_printMutex);
std::cout << "[Loop Start] named: " << _name << ", period: " << _period * 1000 << " (ms), run at cpu: " << _bindCPU << std::endl;
} else {
std::lock_guard<std::mutex> lock(_printMutex);
std::cout << "[Loop Start] named: " << _name << ", period: " << _period * 1000 << " (ms), cpu unspecified" << std::endl;
void start()
{
_running = true;
log("[Loop Start] named: " + _name + ", period: " + formatPeriod() + "(ms)" + (_bindCPU != -1 ? ", run at cpu: " + std::to_string(_bindCPU) : ", cpu unspecified"));
if (_bindCPU != -1)
{
_thread = std::thread(&LoopFunc::loop, this);
setThreadAffinity(_thread.native_handle(), _bindCPU);
}
entryFunc();
}); // Start the loop in a new thread
}

void shutdown() {
_isrunning = false;
if (_thread.joinable()) {
_thread.join(); // Wait for the loop thread to finish
std::lock_guard<std::mutex> lock(_printMutex);
std::cout << "[Loop End] named: " << _name << std::endl;
}
}

virtual void functionCB() = 0;

protected:
void entryFunc() {
while (_isrunning) {
functionCB(); // Call the overridden functionCB in a loop
std::this_thread::sleep_for(std::chrono::duration<float>(_period)); // Wait for the specified period
else
{
_thread = std::thread(&LoopFunc::loop, this);
}
_thread.detach();
}
}

std::string _name;
float _period;
int _bindCPU;
bool _isrunning = false;
std::thread _thread;
static std::mutex _printMutex;
};
void shutdown()
{
{
std::unique_lock<std::mutex> lock(_mutex);
_running = false;
_cv.notify_one();
}
if (_thread.joinable())
{
_thread.join();
}
log("[Loop End] named: " + _name);
}
private:
void loop()
{
while (_running)
{
auto start = std::chrono::steady_clock::now();

std::mutex Loop::_printMutex;
_func();

class LoopFunc : public Loop {
public:
LoopFunc(std::string name, float period, const Callback& cb)
: Loop(name, period), _fp(cb) {}
auto end = std::chrono::steady_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
auto sleepTime = std::chrono::milliseconds(static_cast<int>((_period * 1000) - elapsed.count()));
if (sleepTime.count() > 0)
{
std::unique_lock<std::mutex> lock(_mutex);
if (_cv.wait_for(lock, sleepTime, [this]{ return !_running; }))
{
break;
}
}
}
}

LoopFunc(std::string name, float period, int bindCPU, const Callback& cb)
: Loop(name, period, bindCPU), _fp(cb) {}
std::string formatPeriod() const
{
std::ostringstream stream;
stream << std::fixed << std::setprecision(0) << _period * 1000;
return stream.str();
}

void functionCB() override {
void log(const std::string& message)
{
std::lock_guard<std::mutex> lock(_printMutex);
(_fp)(); // Call the provided callback function
static std::mutex logMutex;
std::lock_guard<std::mutex> lock(logMutex);
std::cout << message << std::endl;
}
}

private:
Callback _fp;
void setThreadAffinity(std::thread::native_handle_type threadHandle, int cpuId)
{
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
CPU_SET(cpuId, &cpuset);
if (pthread_setaffinity_np(threadHandle, sizeof(cpu_set_t), &cpuset) != 0)
{
std::ostringstream oss;
oss << "Error setting thread affinity: CPU " << cpuId << " may not be valid or accessible.";
throw std::runtime_error(oss.str());
}
}
};

#endif
#endif
43 changes: 29 additions & 14 deletions src/rl_sar/library/rl_sdk/rl_sdk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,15 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
start_state.motor_state.q[i] = now_state.motor_state.q[i];
}
this->running_state = STATE_POS_GETUP;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETUP" << std::endl;
}
}
// stand up (position control)
else if(this->running_state == STATE_POS_GETUP)
{
if(getup_percent < 1.0)
{
getup_percent += 1 / 1000.0;
getup_percent += 1 / 500.0;
getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent;
for(int i = 0; i < this->params.num_of_dofs; ++i)
{
Expand All @@ -125,9 +126,9 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
}
if(this->control.control_state == STATE_RL_INIT)
{
std::cout << std::endl;
this->control.control_state = STATE_WAITING;
this->running_state = STATE_RL_INIT;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_INIT" << std::endl;
}
else if(this->control.control_state == STATE_POS_GETDOWN)
{
Expand All @@ -138,22 +139,26 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
now_state.motor_state.q[i] = state->motor_state.q[i];
}
this->running_state = STATE_POS_GETDOWN;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl;
}
}
// init obs and start rl loop
else if(this->running_state == STATE_RL_INIT)
{
if(getup_percent == 1)
{
this->running_state = STATE_RL_RUNNING;
this->InitObservations();
this->InitOutputs();
this->InitControl();
this->running_state = STATE_RL_RUNNING;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_RUNNING" << std::endl;
}
}
// rl loop
else if(this->running_state == STATE_RL_RUNNING)
{
std::cout << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << " \r";

for(int i = 0; i < this->params.num_of_dofs; ++i)
{
command->motor_command.q[i] = this->output_dof_pos[0][i].item<double>();
Expand All @@ -171,14 +176,26 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
now_state.motor_state.q[i] = state->motor_state.q[i];
}
this->running_state = STATE_POS_GETDOWN;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl;
}
else if(this->control.control_state == STATE_POS_GETUP)
{
this->control.control_state = STATE_WAITING;
getup_percent = 0.0;
for(int i = 0; i < this->params.num_of_dofs; ++i)
{
now_state.motor_state.q[i] = state->motor_state.q[i];
}
this->running_state = STATE_POS_GETUP;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETUP" << std::endl;
}
}
// get down (position control)
else if(this->running_state == STATE_POS_GETDOWN)
{
if(getdown_percent < 1.0)
{
getdown_percent += 1 / 1000.0;
getdown_percent += 1 / 500.0;
getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent;
for(int i = 0; i < this->params.num_of_dofs; ++i)
{
Expand All @@ -192,11 +209,11 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
}
if(getdown_percent == 1)
{
std::cout << std::endl;
this->running_state = STATE_WAITING;
this->InitObservations();
this->InitOutputs();
this->InitControl();
this->running_state = STATE_WAITING;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_WAITING" << std::endl;
}
}
}
Expand Down Expand Up @@ -226,10 +243,11 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques)
double limit_lower = -this->params.torque_limits[0][index].item<double>();
double limit_upper = this->params.torque_limits[0][index].item<double>();

std::cout << LOGGER::ERROR << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
std::cout << LOGGER::ERROR << "Switching to STATE_POS_GETDOWN"<< std::endl;
std::cout << LOGGER::WARNING << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
}
this->control.control_state = STATE_POS_GETDOWN;
// Just a reminder, no protection
// this->control.control_state = STATE_POS_GETDOWN;
// std::cout << LOGGER::INFO << "Switching to STATE_POS_GETDOWN"<< std::endl;
}
}

Expand All @@ -254,11 +272,6 @@ static bool kbhit()

void RL::KeyboardInterface()
{
if(this->running_state == STATE_RL_RUNNING)
{
std::cout << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << " \r";
}

if(kbhit())
{
int c = fgetc(stdin);
Expand Down Expand Up @@ -308,6 +321,8 @@ void RL::ReadYaml(std::string robot_name)
}

this->params.model_name = config["model_name"].as<std::string>();
this->params.dt = config["dt"].as<double>();
this->params.decimation = config["decimation"].as<int>();
this->params.num_observations = config["num_observations"].as<int>();
this->params.clip_obs = config["clip_obs"].as<double>();
this->params.clip_actions_upper = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_upper"])).view({1, -1});
Expand Down
2 changes: 2 additions & 0 deletions src/rl_sar/library/rl_sdk/rl_sdk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ struct Control
struct ModelParams
{
std::string model_name;
double dt;
int decimation;
int num_observations;
double damping;
double stiffness;
Expand Down
15 changes: 8 additions & 7 deletions src/rl_sar/src/rl_real_a1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,17 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
this->model = torch::jit::load(model_path);

// loop
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , std::bind(&RL_Real::KeyboardInterface, this));
this->loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, std::bind(&RL_Real::RobotControl , this));
this->loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, 3, std::bind(&RL_Real::UDPSend , this));
this->loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, 3, std::bind(&RL_Real::UDPRecv , this));
this->loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , std::bind(&RL_Real::RunModel , this));
this->loop_keyboard->start();
this->loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, std::bind(&RL_Real::UDPSend, this), 3);
this->loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, std::bind(&RL_Real::UDPRecv, this), 3);
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05, std::bind(&RL_Real::KeyboardInterface, this));
this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Real::RobotControl, this));
this->loop_rl = std::make_shared<LoopFunc>("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Real::RunModel, this));
this->loop_udpSend->start();
this->loop_udpRecv->start();
this->loop_keyboard->start();
this->loop_control->start();
this->loop_rl->start();


#ifdef PLOT
this->plot_t = std::vector<int>(this->plot_size, 0);
Expand All @@ -54,9 +55,9 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u

RL_Real::~RL_Real()
{
this->loop_keyboard->shutdown();
this->loop_udpSend->shutdown();
this->loop_udpRecv->shutdown();
this->loop_keyboard->shutdown();
this->loop_control->shutdown();
this->loop_rl->shutdown();
#ifdef PLOT
Expand Down
Loading

0 comments on commit e3dfb68

Please sign in to comment.