diff --git a/mani_skill2/utils/wrappers/record.py b/mani_skill2/utils/wrappers/record.py
index f38edd398..5f7d03320 100644
--- a/mani_skill2/utils/wrappers/record.py
+++ b/mani_skill2/utils/wrappers/record.py
@@ -290,46 +290,20 @@ def reset(
         options: Optional[dict] = dict(),
         **kwargs,
     ):
-        skip_trajectory = False
-        options.pop("save_trajectory", False)
-
-        # when we just have one env, we look at save_on_reset and clear the trajectory buffer
-        # when there are mutliple envs we save based on timesteps and must do more finegrained management of the buffer
-        # if (
-        #     self.num_envs == 1
-        #     and self.save_on_reset
-        #     and self._trajectory_buffer is not None
-        # ):
-        #     if not skip_trajectory:
-        #         self.flush_trajectory(ignore_empty_transition=True, env_idxs_to_flush=[0])
-        #         self.flush_video()
-        #     else:
-        #         self._trajectory_buffer = None
-
         if self.save_on_reset and self._trajectory_buffer is not None:
             if self.save_video and self.num_envs == 1:
                 self.flush_video()
-            if not skip_trajectory:
-                # if doing a full reset then we flush all trajectories including incompleted ones
-                if "env_idx" not in options:
-                    self.flush_trajectory(env_idxs_to_flush=np.arange(self.num_envs))
-                else:
-                    self.flush_trajectory(
-                        env_idxs_to_flush=to_numpy(options["env_idx"])
-                    )
+            # if doing a full reset then we flush all trajectories including incomplete ones
+            if "env_idx" not in options:
+                self.flush_trajectory(env_idxs_to_flush=np.arange(self.num_envs))
+            else:
+                self.flush_trajectory(env_idxs_to_flush=to_numpy(options["env_idx"]))
 
         reset_kwargs = copy.deepcopy(dict(seed=seed, options=options, **kwargs))
         obs, info = super().reset(*args, seed=seed, options=options, **kwargs)
 
         if self.save_trajectory:
             state_dict = self._base_env.get_state_dict()
-            # self._episode_info.update(
-            #     episode_id=self._episode_id,
-            #     episode_seed=getattr(self.unwrapped, "_episode_seed", None),
-            #     reset_kwargs=reset_kwargs,
-            #     control_mode=getattr(self.unwrapped, "control_mode", None),
-            #     elapsed_steps=0,
-            # )
             action = batch(self.action_space.sample())
             first_step = Step(
                 state=to_numpy(batch(state_dict)),
@@ -360,8 +334,6 @@ def reset(
             env_idx = to_numpy(options["env_idx"])
             if self._trajectory_buffer is None:
                 # Initialize trajectory buffer on the first episode based on given observation (which should be generated after all wrappers)
-                # TODO (stao): we do not really know the max size of the trajectory buffer since we keep it in memory until we flush?
-                # which for cpu env we do not know max size. gpu env we do.
                 self._trajectory_buffer = first_step
             else:
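
For context, a minimal usage sketch of the new reset semantics. The wrapper class name (`RecordEpisode`), the env id, and the constructor arguments below are assumptions for illustration; only the flush-on-reset behavior itself comes from the diff above.

```python
import gymnasium as gym
import numpy as np

# Assumed import path and class name; this file defines the recording wrapper.
from mani_skill2.utils.wrappers.record import RecordEpisode

# Assumed env id and constructor arguments, for illustration only.
env = RecordEpisode(gym.make("PickCube-v0"), output_dir="demos")

obs, info = env.reset()  # first reset: trajectory buffer is still None, nothing flushed
# ... step the env, accumulating the trajectory buffer ...

# Full reset: "env_idx" is absent from options, so trajectories for all
# num_envs sub-environments are flushed, including incomplete ones.
obs, info = env.reset()

# Partial reset (vectorized setting): only the trajectories of the listed
# sub-environments are flushed; the others keep accumulating.
obs, info = env.reset(options=dict(env_idx=np.array([0, 2])))
```

Note that the removed `skip_trajectory` flag was dead code: the result of `options.pop("save_trajectory", False)` was never assigned, so `skip_trajectory` was always `False` and the `if not skip_trajectory:` branch always ran. With it gone, every reset under `save_on_reset` flushes, and the only remaining control is which sub-environments to flush via `options["env_idx"]`.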