Skip to content

Commit

Permalink
Update record.py
Browse files Browse the repository at this point in the history
  • Loading branch information
StoneT2000 committed Mar 5, 2024
1 parent 229f679 commit 0abc5cd
Showing 1 changed file with 5 additions and 33 deletions.
38 changes: 5 additions & 33 deletions mani_skill2/utils/wrappers/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,46 +290,20 @@ def reset(
options: Optional[dict] = dict(),
**kwargs,
):
skip_trajectory = False
options.pop("save_trajectory", False)

# when we just have one env, we look at save_on_reset and clear the trajectory buffer
# when there are mutliple envs we save based on timesteps and must do more finegrained management of the buffer
# if (
# self.num_envs == 1
# and self.save_on_reset
# and self._trajectory_buffer is not None
# ):
# if not skip_trajectory:
# self.flush_trajectory(ignore_empty_transition=True, env_idxs_to_flush=[0])
# self.flush_video()
# else:
# self._trajectory_buffer = None

if self.save_on_reset and self._trajectory_buffer is not None:
if self.save_video and self.num_envs == 1:
self.flush_video()
if not skip_trajectory:
# if doing a full reset then we flush all trajectories including incompleted ones
if "env_idx" not in options:
self.flush_trajectory(env_idxs_to_flush=np.arange(self.num_envs))
else:
self.flush_trajectory(
env_idxs_to_flush=to_numpy(options["env_idx"])
)
# if doing a full reset then we flush all trajectories including incompleted ones
if "env_idx" not in options:
self.flush_trajectory(env_idxs_to_flush=np.arange(self.num_envs))
else:
self.flush_trajectory(env_idxs_to_flush=to_numpy(options["env_idx"]))

reset_kwargs = copy.deepcopy(dict(seed=seed, options=options, **kwargs))
obs, info = super().reset(*args, seed=seed, options=options, **kwargs)

if self.save_trajectory:
state_dict = self._base_env.get_state_dict()
# self._episode_info.update(
# episode_id=self._episode_id,
# episode_seed=getattr(self.unwrapped, "_episode_seed", None),
# reset_kwargs=reset_kwargs,
# control_mode=getattr(self.unwrapped, "control_mode", None),
# elapsed_steps=0,
# )
action = batch(self.action_space.sample())
first_step = Step(
state=to_numpy(batch(state_dict)),
Expand Down Expand Up @@ -360,8 +334,6 @@ def reset(
env_idx = to_numpy(options["env_idx"])
if self._trajectory_buffer is None:
# Initialize trajectory buffer on the first episode based on given observation (which should be generated after all wrappers)
# TODO (stao): we do not really know the max size of the trajectory buffer since we keep it in memory until we flush?
# which for cpu env we do not know max size. gpu env we do.
self._trajectory_buffer = first_step
else:

Expand Down

0 comments on commit 0abc5cd

Please sign in to comment.