From 16e5a077aa39d1554898f00edb2931356f4e5010 Mon Sep 17 00:00:00 2001 From: Stone Tao Date: Tue, 5 Mar 2024 17:30:48 -0800 Subject: [PATCH] Update record.py --- mani_skill2/utils/wrappers/record.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mani_skill2/utils/wrappers/record.py b/mani_skill2/utils/wrappers/record.py index 713a036b6..077568693 100644 --- a/mani_skill2/utils/wrappers/record.py +++ b/mani_skill2/utils/wrappers/record.py @@ -299,9 +299,6 @@ def reset( else: self.flush_trajectory(env_idxs_to_flush=to_numpy(options["env_idx"])) - self.last_reset_kwargs = copy.deepcopy( - dict(seed=seed, options=options, **kwargs) - ) obs, info = super().reset(*args, seed=seed, options=options, **kwargs) if self.save_trajectory: @@ -365,7 +362,11 @@ def recursive_replace(x, y): ) if self._trajectory_buffer.fail is not None: recursive_replace(self._trajectory_buffer.fail, first_step.fail) - + if "env_idx" in options: + options["env_idx"] = to_numpy(options["env_idx"]) + self.last_reset_kwargs = copy.deepcopy( + dict(seed=seed, options=options, **kwargs) + ) return obs, info def step(self, action):