Skip to content

Commit

Permalink
fix: add USE_HISTORY
Browse files Browse the repository at this point in the history
  • Loading branch information
fan-ziqi committed May 23, 2024
1 parent 13dbb89 commit 87d9b69
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/rl_sar/src/rl_sim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

// #define PLOT
// #define CSV_LOGGER
#define USE_HISTORY

RL_Sim::RL_Sim()
{
Expand Down Expand Up @@ -33,7 +34,9 @@ RL_Sim::RL_Sim()
this->InitObservations();
this->InitOutputs();

#ifdef USE_HISTORY
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
#endif

joint_positions = std::vector<double>(params.num_of_dofs, 0.0);
joint_velocities = std::vector<double>(params.num_of_dofs, 0.0);
Expand Down Expand Up @@ -175,10 +178,13 @@ torch::Tensor RL_Sim::Forward()
{
torch::Tensor obs = this->ComputeObservation();

#ifdef USE_HISTORY
history_obs_buf.insert(obs);
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});

torch::Tensor action = this->model.forward({history_obs}).toTensor();
#else
torch::Tensor action = this->model.forward({obs}).toTensor();
#endif

this->obs.actions = action;
torch::Tensor clamped = torch::clamp(action, -this->params.clip_actions, this->params.clip_actions);
Expand Down

0 comments on commit 87d9b69

Please sign in to comment.