Skip to content

Commit

Permalink
feat: add obs list in config.yaml and move ComputeObservation to rl_sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
fan-ziqi committed Aug 16, 2024
1 parent 11d600a commit 4868c7d
Show file tree
Hide file tree
Showing 16 changed files with 85 additions and 86 deletions.
14 changes: 4 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,7 @@ sudo ldconfig

## Compilation

Customize the following two functions in your code to adapt to different models:

```cpp
torch::Tensor forward() override;
torch::Tensor compute_observation() override;
```

Then compile in the root directory
Compile in the root directory of the project

```bash
cd ..
Expand Down Expand Up @@ -142,8 +135,9 @@ In the following text, `<ROBOT>` represents the name of the robot
1. Create a model package named `<ROBOT>_description` in the `rl_sar/src/robots` directory. Place the robot's URDF file in the `rl_sar/src/robots/<ROBOT>_description/urdf` directory and name it `<ROBOT>.urdf`. Additionally, create a joint configuration file with the namespace `<ROBOT>_gazebo` in the `rl_sar/src/robots/<ROBOT>_description/config` directory.
2. Place the trained RL model files in the `rl_sar/src/rl_sar/models/<ROBOT>` directory.
3. In the `rl_sar/src/rl_sar/models/<ROBOT>` directory, create a `config.yaml` file, and modify its parameters based on the `rl_sar/src/rl_sar/models/a1_isaacgym/config.yaml` file.
4. If you need to run simulations, modify the launch files as needed by referring to those in the `rl_sar/src/rl_sar/launch` directory.
5. If you need to run on the physical robot, modify the file `rl_sar/src/rl_sar/src/rl_real_a1.cpp` as needed.
4. Modify the `forward()` function in the code as needed to adapt to different models.
5. If you need to run simulations, modify the launch files as needed by referring to those in the `rl_sar/src/rl_sar/launch` directory.
6. If you need to run on the physical robot, modify the file `rl_sar/src/rl_sar/src/rl_real_a1.cpp` as needed.
## Reference
Expand Down
14 changes: 4 additions & 10 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,7 @@ sudo ldconfig

## 编译

自定义代码中的以下两个函数,以适配不同的模型:

```cpp
torch::Tensor forward() override;
torch::Tensor compute_observation() override;
```

然后到根目录编译
在项目根目录编译

```bash
cd ..
Expand Down Expand Up @@ -143,8 +136,9 @@ rosrun rl_sar rl_real_a1
1. 在`rl_sar/src/robots`路径下创建名为`<ROBOT>_description`的模型包,将模型的urdf放到`rl_sar/src/robots/<ROBOT>_description/urdf`路径下并命名为`<ROBOT>.urdf`,并在`rl_sar/src/robots/<ROBOT>_description/config`路径下创建命名空间为`<ROBOT>_gazebo`的关节配置文件
2. 将训练好的RL模型文件放到`rl_sar/src/rl_sar/models/<ROBOT>`路径下
3. 在`rl_sar/src/rl_sar/models/<ROBOT>`中新建config.yaml文件,参考`rl_sar/src/rl_sar/models/a1_isaacgym/config.yaml`文件修改其中参数
4. 若需要运行仿真,则参考`rl_sar/src/rl_sar/launch`路径下的launch文件自行修改
5. 若需要运行实物,则参考`rl_sar/src/rl_sar/src/rl_real_a1.cpp`文件自行修改
4. 按需修改代码中的`forward()`函数,以适配不同的模型
5. 若需要运行仿真,则参考`rl_sar/src/rl_sar/launch`路径下的launch文件自行修改
6. 若需要运行实物,则参考`rl_sar/src/rl_sar/src/rl_real_a1.cpp`文件自行修改
## 参考
Expand Down
1 change: 0 additions & 1 deletion src/rl_sar/include/rl_real_a1.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ class RL_Real : public RL
private:
// rl functions
torch::Tensor Forward() override;
torch::Tensor ComputeObservation() override;
void GetState(RobotState<double> *state) override;
void SetCommand(const RobotCommand<double> *command) override;
void RunModel();
Expand Down
1 change: 0 additions & 1 deletion src/rl_sar/include/rl_sim.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ class RL_Sim : public RL
private:
// rl functions
torch::Tensor Forward() override;
torch::Tensor ComputeObservation() override;
void GetState(RobotState<double> *state) override;
void SetCommand(const RobotCommand<double> *command) override;
void RunModel();
Expand Down
59 changes: 43 additions & 16 deletions src/rl_sar/library/rl_sdk/rl_sdk.cpp
Original file line number Diff line number Diff line change
@@ -1,21 +1,5 @@
#include "rl_sdk.hpp"

/* You may need to override this ComputeObservation() function
torch::Tensor RL_XXX::ComputeObservation()
{
torch::Tensor obs = torch::cat({
this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel, this->params.framework) * this->params.ang_vel_scale,
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec, this->params.framework),
this->obs.commands * this->params.commands_scale,
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
this->obs.dof_vel * this->params.dof_vel_scale,
this->obs.actions
},1);
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
return clamped_obs;
}
*/

/* You may need to override this Forward() function
torch::Tensor RL_XXX::Forward()
{
Expand All @@ -27,6 +11,48 @@ torch::Tensor RL_XXX::Forward()
}
*/

torch::Tensor RL::ComputeObservation()
{
std::vector<torch::Tensor> obs_list;

for(const std::string& observation : this->params.observations)
{
if(observation == "lin_vel")
{
obs_list.push_back(this->obs.lin_vel * this->params.lin_vel_scale);
}
else if(observation == "ang_vel")
{
// obs_list.push_back(this->obs.ang_vel * this->params.ang_vel_scale); // TODO is QuatRotateInverse necessery?
obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel, this->params.framework) * this->params.ang_vel_scale);
}
else if(observation == "gravity_vec")
{
obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec, this->params.framework));
}
else if(observation == "commands")
{
obs_list.push_back(this->obs.commands * this->params.commands_scale);
}
else if(observation == "dof_pos")
{
obs_list.push_back((this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale);
}
else if(observation == "dof_vel")
{
obs_list.push_back(this->obs.dof_vel * this->params.dof_vel_scale);
}
else if(observation == "actions")
{
obs_list.push_back(this->obs.actions);
}
}

torch::Tensor obs = torch::cat(obs_list, 1);
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
return clamped_obs;
}

void RL::InitObservations()
{
this->obs.lin_vel = torch::tensor({{0.0, 0.0, 0.0}});
Expand Down Expand Up @@ -369,6 +395,7 @@ void RL::ReadYaml(std::string robot_name)
this->params.dt = config["dt"].as<double>();
this->params.decimation = config["decimation"].as<int>();
this->params.num_observations = config["num_observations"].as<int>();
this->params.observations = ReadVectorFromYaml<std::string>(config["observations"]);
this->params.clip_obs = config["clip_obs"].as<double>();
this->params.clip_actions_upper = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_upper"], this->params.framework, rows, cols)).view({1, -1});
this->params.clip_actions_lower = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_lower"], this->params.framework, rows, cols)).view({1, -1});
Expand Down
3 changes: 2 additions & 1 deletion src/rl_sar/library/rl_sdk/rl_sdk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ struct ModelParams
double dt;
int decimation;
int num_observations;
std::vector<std::string> observations;
double damping;
double stiffness;
double action_scale;
Expand Down Expand Up @@ -128,7 +129,7 @@ class RL

// rl functions
virtual torch::Tensor Forward() = 0;
virtual torch::Tensor ComputeObservation() = 0;
torch::Tensor ComputeObservation();
virtual void GetState(RobotState<double> *state) = 0;
virtual void SetCommand(const RobotCommand<double> *command) = 0;
void StateController(const RobotState<double> *state, RobotCommand<double> *command);
Expand Down
1 change: 1 addition & 0 deletions src/rl_sar/models/a1_isaacgym/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ a1_isaacgym:
dt: 0.005
decimation: 4
num_observations: 45
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
clip_obs: 100.0
clip_actions_lower: [-100, -100, -100,
-100, -100, -100,
Expand Down
1 change: 1 addition & 0 deletions src/rl_sar/models/a1_isaacsim/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ a1_isaacsim:
dt: 0.005
decimation: 4
num_observations: 45
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
clip_obs: 100.0
clip_actions_lower: [-100, -100, -100,
-100, -100, -100,
Expand Down
1 change: 1 addition & 0 deletions src/rl_sar/models/gr1t1_isaacgym/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ gr1t1_isaacgym:
dt: 0.001
decimation: 20
num_observations: 39
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
clip_obs: 100.0
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]
Expand Down
1 change: 1 addition & 0 deletions src/rl_sar/models/gr1t1_isaacsim/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ gr1t1_isaacsim:
dt: 0.001
decimation: 20
num_observations: 39
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
clip_obs: 100.0
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]
Expand Down
1 change: 1 addition & 0 deletions src/rl_sar/models/gr1t2_isaacgym/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ gr1t2_isaacgym:
dt: 0.001
decimation: 20
num_observations: 39
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
clip_obs: 100.0
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]
Expand Down
1 change: 1 addition & 0 deletions src/rl_sar/models/gr1t2_isaacsim/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ gr1t2_isaacsim:
dt: 0.001
decimation: 20
num_observations: 39
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
clip_obs: 100.0
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]
Expand Down
24 changes: 24 additions & 0 deletions src/rl_sar/scripts/rl_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(self):
self.dt = None
self.decimation = None
self.num_observations = None
self.observations = None
self.damping = None
self.stiffness = None
self.action_scale = None
Expand Down Expand Up @@ -134,6 +135,28 @@ def __init__(self):
self.output_torques = torch.zeros(1, 32)
self.output_dof_pos = torch.zeros(1, 32)

def ComputeObservation(self):
obs_list = []
for observation in self.params.observations:
if observation == "lin_vel":
obs_list.append(self.obs.lin_vel * self.params.lin_vel_scale)
elif observation == "ang_vel":
# obs_list.append(self.obs.ang_vel * self.params.ang_vel_scale) # TODO is QuatRotateInverse necessery?
obs_list.append(self.QuatRotateInverse(self.obs.base_quat, self.obs.ang_vel, self.params.framework) * self.params.ang_vel_scale)
elif observation == "gravity_vec":
obs_list.append(self.QuatRotateInverse(self.obs.base_quat, self.obs.gravity_vec, self.params.framework))
elif observation == "commands":
obs_list.append(self.obs.commands * self.params.commands_scale)
elif observation == "dof_pos":
obs_list.append((self.obs.dof_pos - self.params.default_dof_pos) * self.params.dof_pos_scale)
elif observation == "dof_vel":
obs_list.append(self.obs.dof_vel * self.params.dof_vel_scale)
elif observation == "actions":
obs_list.append(self.obs.actions)
obs = torch.cat(obs_list, dim=-1)
clamped_obs = torch.clamp(obs, -self.params.clip_obs, self.params.clip_obs)
return clamped_obs

def InitObservations(self):
self.obs.lin_vel = torch.zeros(1, 3, dtype=torch.float)
self.obs.ang_vel = torch.zeros(1, 3, dtype=torch.float)
Expand Down Expand Up @@ -359,6 +382,7 @@ def ReadYaml(self, robot_name):
self.params.dt = config["dt"]
self.params.decimation = config["decimation"]
self.params.num_observations = config["num_observations"]
self.params.observations = config["observations"]
self.params.clip_obs = config["clip_obs"]
self.params.action_scale = config["action_scale"]
self.params.hip_scale_reduction = config["hip_scale_reduction"]
Expand Down
16 changes: 1 addition & 15 deletions src/rl_sar/scripts/rl_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def JointStatesCallback(self, msg):

def RunModel(self):
if self.running_state == STATE.STATE_RL_RUNNING and self.simulation_running:
# self.obs.lin_vel = torch.tensor([[self.vel.linear.x, self.vel.linear.y, self.vel.linear.z]])
self.obs.lin_vel = torch.tensor([[self.vel.linear.x, self.vel.linear.y, self.vel.linear.z]])
self.obs.ang_vel = torch.tensor(self.robot_state.imu.gyroscope).unsqueeze(0)
# self.obs.commands = torch.tensor([[self.cmd_vel.linear.x, self.cmd_vel.linear.y, self.cmd_vel.angular.z]])
self.obs.commands = torch.tensor([[self.control.x, self.control.y, self.control.yaw]])
Expand All @@ -199,20 +199,6 @@ def RunModel(self):
tau_est = torch.tensor(self.mapped_joint_efforts).unsqueeze(0)
self.CSVLogger(self.output_torques, tau_est, self.obs.dof_pos, self.output_dof_pos, self.obs.dof_vel)

def ComputeObservation(self):
obs = torch.cat([
# self.obs.lin_vel * self.params.lin_vel_scale,
# self.obs.ang_vel * self.params.ang_vel_scale, # TODO is QuatRotateInverse necessery?
self.QuatRotateInverse(self.obs.base_quat, self.obs.ang_vel, self.params.framework) * self.params.ang_vel_scale,
self.QuatRotateInverse(self.obs.base_quat, self.obs.gravity_vec, self.params.framework),
self.obs.commands * self.params.commands_scale,
(self.obs.dof_pos - self.params.default_dof_pos) * self.params.dof_pos_scale,
self.obs.dof_vel * self.params.dof_vel_scale,
self.obs.actions
], dim = -1)
clamped_obs = torch.clamp(obs, -self.params.clip_obs, self.params.clip_obs)
return clamped_obs

def Forward(self):
torch.set_grad_enabled(False)
clamped_obs = self.ComputeObservation()
Expand Down
15 changes: 0 additions & 15 deletions src/rl_sar/src/rl_real_a1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,21 +170,6 @@ void RL_Real::RunModel()
}
}

torch::Tensor RL_Real::ComputeObservation()
{
torch::Tensor obs = torch::cat({
// this->QuatRotateInverse(this->obs.base_quat, this->obs.lin_vel) * this->params.lin_vel_scale,
this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel, this->params.framework) * this->params.ang_vel_scale,
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec, this->params.framework),
this->obs.commands * this->params.commands_scale,
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
this->obs.dof_vel * this->params.dof_vel_scale,
this->obs.actions
},1);
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
return clamped_obs;
}

torch::Tensor RL_Real::Forward()
{
torch::autograd::GradMode::set_enabled(false);
Expand Down
18 changes: 1 addition & 17 deletions src/rl_sar/src/rl_sim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ void RL_Sim::RunModel()
{
if(this->running_state == STATE_RL_RUNNING && simulation_running)
{
// this->obs.lin_vel = torch::tensor({{this->vel.linear.x, this->vel.linear.y, this->vel.linear.z}});
this->obs.lin_vel = torch::tensor({{this->vel.linear.x, this->vel.linear.y, this->vel.linear.z}});
this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0);
// this->obs.commands = torch::tensor({{this->cmd_vel.linear.x, this->cmd_vel.linear.y, this->cmd_vel.angular.z}});
this->obs.commands = torch::tensor({{this->control.x, this->control.y, this->control.yaw}});
Expand Down Expand Up @@ -243,22 +243,6 @@ void RL_Sim::RunModel()
}
}

torch::Tensor RL_Sim::ComputeObservation()
{
torch::Tensor obs = torch::cat({
// this->obs.lin_vel * this->params.lin_vel_scale,
// this->obs.ang_vel * this->params.ang_vel_scale, // TODO is QuatRotateInverse necessery?
this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel, this->params.framework) * this->params.ang_vel_scale,
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec, this->params.framework),
this->obs.commands * this->params.commands_scale,
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
this->obs.dof_vel * this->params.dof_vel_scale,
this->obs.actions
},1);
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
return clamped_obs;
}

torch::Tensor RL_Sim::Forward()
{
torch::autograd::GradMode::set_enabled(false);
Expand Down

0 comments on commit 4868c7d

Please sign in to comment.