Skip to content

Commit

Permalink
style: code format
Browse files Browse the repository at this point in the history
  • Loading branch information
fan-ziqi committed Oct 8, 2024
1 parent 1080833 commit 8b6a2d6
Show file tree
Hide file tree
Showing 22 changed files with 317 additions and 270 deletions.
52 changes: 26 additions & 26 deletions src/rl_sar/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GAZEBO_CXX_FLAGS}")
find_package(gazebo REQUIRED)

find_package(catkin REQUIRED COMPONENTS
controller_manager
genmsg
joint_state_controller
robot_state_publisher
roscpp
gazebo_ros
std_msgs
tf
geometry_msgs
robot_msgs
robot_joint_controller
rospy
controller_manager
genmsg
joint_state_controller
robot_state_publisher
roscpp
gazebo_ros
std_msgs
tf
geometry_msgs
robot_msgs
robot_joint_controller
rospy
)

find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
Expand All @@ -36,33 +36,33 @@ link_directories(/usr/local/lib)
include_directories(${YAML_CPP_INCLUDE_DIR})

catkin_package(
CATKIN_DEPENDS
robot_joint_controller
rospy
CATKIN_DEPENDS
robot_joint_controller
rospy
)

include_directories(library/unitree_legged_sdk_3.2/include)
link_directories(library/unitree_legged_sdk_3.2/lib)
set(EXTRA_LIBS -pthread libunitree_legged_sdk_amd64.so lcm)

include_directories(
include
${catkin_INCLUDE_DIRS}
${unitree_legged_sdk_INCLUDE_DIRS}
library/matplotlibcpp
library/observation_buffer
library/rl_sdk
library/loop
include
${catkin_INCLUDE_DIRS}
${unitree_legged_sdk_INCLUDE_DIRS}
library/matplotlibcpp
library/observation_buffer
library/rl_sdk
library/loop
)

add_library(rl_sdk library/rl_sdk/rl_sdk.cpp)
target_link_libraries(rl_sdk "${TORCH_LIBRARIES}" Python3::Python Python3::Module)
set_property(TARGET rl_sdk PROPERTY CXX_STANDARD 14)
find_package(Python3 COMPONENTS NumPy)
if(Python3_NumPy_FOUND)
target_link_libraries(rl_sdk Python3::NumPy)
target_link_libraries(rl_sdk Python3::NumPy)
else()
target_compile_definitions(rl_sdk WITHOUT_NUMPY)
target_compile_definitions(rl_sdk WITHOUT_NUMPY)
endif()

add_library(observation_buffer library/observation_buffer/observation_buffer.cpp)
Expand All @@ -71,13 +71,13 @@ set_property(TARGET observation_buffer PROPERTY CXX_STANDARD 14)

add_executable(rl_sim src/rl_sim.cpp )
target_link_libraries(rl_sim
${catkin_LIBRARIES} ${EXTRA_LIBS}
${catkin_LIBRARIES} ${EXTRA_LIBS}
rl_sdk observation_buffer yaml-cpp
)

add_executable(rl_real_a1 src/rl_real_a1.cpp )
target_link_libraries(rl_real_a1
${catkin_LIBRARIES} ${EXTRA_LIBS}
${catkin_LIBRARIES} ${EXTRA_LIBS}
rl_sdk observation_buffer yaml-cpp
)

Expand Down
7 changes: 4 additions & 3 deletions src/rl_sar/include/rl_real_a1.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class RL_Real : public RL
public:
RL_Real();
~RL_Real();

private:
// rl functions
torch::Tensor Forward() override;
Expand Down Expand Up @@ -43,8 +44,8 @@ class RL_Real : public RL
void Plot();

// unitree interface
void UDPSend(){unitree_udp.Send();}
void UDPRecv(){unitree_udp.Recv();}
void UDPSend() { unitree_udp.Send(); }
void UDPRecv() { unitree_udp.Recv(); }
UNITREE_LEGGED_SDK::Safety unitree_safe;
UNITREE_LEGGED_SDK::UDP unitree_udp;
UNITREE_LEGGED_SDK::LowCmd unitree_low_command = {0};
Expand All @@ -59,4 +60,4 @@ class RL_Real : public RL
int state_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
};

#endif
#endif // RL_REAL_HPP
5 changes: 3 additions & 2 deletions src/rl_sar/include/rl_sim.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class RL_Sim : public RL
public:
RL_Sim();
~RL_Sim();

private:
// rl functions
torch::Tensor Forward() override;
Expand Down Expand Up @@ -69,7 +70,7 @@ class RL_Sim : public RL
std::vector<double> mapped_joint_positions;
std::vector<double> mapped_joint_velocities;
std::vector<double> mapped_joint_efforts;
void MapData(const std::vector<double>& source_data, std::vector<double>& target_data);
void MapData(const std::vector<double> &source_data, std::vector<double> &target_data);
};

#endif
#endif // RL_SIM_HPP
4 changes: 2 additions & 2 deletions src/rl_sar/launch/gazebo_a1_isaacgym.launch
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
<arg name="debug" default="false"/>
<!-- Debug mode will hung up the robot, use "true" or "false" to switch it. -->
<arg name="user_debug" default="false"/>

<include file="$(find gazebo_ros)/launch/empty_world.launch">
<arg name="world_name" value="$(find rl_sar)/worlds/$(arg wname).world"/>
<arg name="debug" value="$(arg debug)"/>
Expand All @@ -26,7 +26,7 @@

<!-- Load the URDF into the ROS Parameter Server -->
<param name="robot_description"
command="$(find xacro)/xacro --inorder '$(arg dollar)$(arg robot_path)/xacro/robot.xacro'
command="$(find xacro)/xacro --inorder '$(arg dollar)$(arg robot_path)/xacro/robot.xacro'
DEBUG:=$(arg user_debug)"/>

<!-- Run a python script to the send a service call to gazebo_ros to spawn a URDF robot -->
Expand Down
4 changes: 2 additions & 2 deletions src/rl_sar/launch/gazebo_a1_isaacsim.launch
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
<arg name="debug" default="false"/>
<!-- Debug mode will hung up the robot, use "true" or "false" to switch it. -->
<arg name="user_debug" default="false"/>

<include file="$(find gazebo_ros)/launch/empty_world.launch">
<arg name="world_name" value="$(find rl_sar)/worlds/$(arg wname).world"/>
<arg name="debug" value="$(arg debug)"/>
Expand All @@ -26,7 +26,7 @@

<!-- Load the URDF into the ROS Parameter Server -->
<param name="robot_description"
command="$(find xacro)/xacro --inorder '$(arg dollar)$(arg robot_path)/xacro/robot.xacro'
command="$(find xacro)/xacro --inorder '$(arg dollar)$(arg robot_path)/xacro/robot.xacro'
DEBUG:=$(arg user_debug)"/>

<!-- Run a python script to the send a service call to gazebo_ros to spawn a URDF robot -->
Expand Down
2 changes: 1 addition & 1 deletion src/rl_sar/launch/gazebo_gr1t1_isaacgym.launch
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
<arg name="gui" default="true"/>
<arg name="headless" default="false"/>
<arg name="debug" default="false"/>

<include file="$(find gazebo_ros)/launch/empty_world.launch">
<arg name="world_name" value="$(find rl_sar)/worlds/$(arg wname).world"/>
<arg name="debug" value="$(arg debug)"/>
Expand Down
2 changes: 1 addition & 1 deletion src/rl_sar/launch/gazebo_gr1t2_isaacgym.launch
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
<arg name="gui" default="true"/>
<arg name="headless" default="false"/>
<arg name="debug" default="false"/>

<include file="$(find gazebo_ros)/launch/empty_world.launch">
<arg name="world_name" value="$(find rl_sar)/worlds/$(arg wname).world"/>
<arg name="debug" value="$(arg debug)"/>
Expand Down
39 changes: 20 additions & 19 deletions src/rl_sar/library/loop/loop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,9 @@

class LoopFunc
{
private:
std::string _name;
double _period;
std::function<void()> _func;
int _bindCPU;
std::atomic<bool> _running;
std::mutex _mutex;
std::condition_variable _cv;
std::thread _thread;

public:
public:
LoopFunc(const std::string &name, double period, std::function<void()> func, int bindCPU = -1)
: _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {}
: _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {}

void start()
{
Expand Down Expand Up @@ -57,12 +47,22 @@ class LoopFunc
}
log("[Loop End] named: " + _name);
}
private:

private:
std::string _name;
double _period;
std::function<void()> _func;
int _bindCPU;
std::atomic<bool> _running;
std::mutex _mutex;
std::condition_variable _cv;
std::thread _thread;

void loop()
{
while (_running)
{
auto start = std::chrono::steady_clock::now();
while (_running)
{
auto start = std::chrono::steady_clock::now();

_func();

Expand All @@ -72,7 +72,8 @@ class LoopFunc
if (sleepTime.count() > 0)
{
std::unique_lock<std::mutex> lock(_mutex);
if (_cv.wait_for(lock, sleepTime, [this]{ return !_running; }))
if (_cv.wait_for(lock, sleepTime, [this]
{ return !_running; }))
{
break;
}
Expand All @@ -87,7 +88,7 @@ class LoopFunc
return stream.str();
}

void log(const std::string& message)
void log(const std::string &message)
{
static std::mutex logMutex;
std::lock_guard<std::mutex> lock(logMutex);
Expand All @@ -108,4 +109,4 @@ class LoopFunc
}
};

#endif
#endif // LOOP_H
13 changes: 7 additions & 6 deletions src/rl_sar/library/observation_buffer/observation_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

ObservationBuffer::ObservationBuffer() {}

ObservationBuffer::ObservationBuffer(int num_envs,
int num_obs,
int include_history_steps)
: num_envs(num_envs),
num_obs(num_obs),
ObservationBuffer::ObservationBuffer(int num_envs,
int num_obs,
int include_history_steps)
: num_envs(num_envs),
num_obs(num_obs),
include_history_steps(include_history_steps)
{
num_obs_total = num_obs * include_history_steps;
Expand All @@ -16,7 +16,8 @@ ObservationBuffer::ObservationBuffer(int num_envs,
void ObservationBuffer::reset(std::vector<int> reset_idxs, torch::Tensor new_obs)
{
std::vector<torch::indexing::TensorIndex> indices;
for (int idx : reset_idxs) {
for (int idx : reset_idxs)
{
indices.push_back(torch::indexing::Slice(idx));
}
obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps}));
Expand Down
5 changes: 3 additions & 2 deletions src/rl_sar/library/observation_buffer/observation_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
#include <torch/torch.h>
#include <vector>

class ObservationBuffer {
class ObservationBuffer
{
public:
ObservationBuffer(int num_envs, int num_obs, int include_history_steps);
ObservationBuffer();
Expand All @@ -21,4 +22,4 @@ class ObservationBuffer {
torch::Tensor obs_buf;
};

#endif // OBSERVATION_BUFFER_HPP
#endif // OBSERVATION_BUFFER_HPP
Loading

0 comments on commit 8b6a2d6

Please sign in to comment.