Skip to content

Commit

Permalink
fix a bunch of bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
StoneT2000 committed Mar 5, 2024
1 parent ede8108 commit 0a10249
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 119 deletions.
222 changes: 110 additions & 112 deletions mani_skill2/utils/wrappers/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
append_dict_array,
extract_scalars_from_info,
find_max_episode_steps_value,
flatten_dict_keys,
slice_dict_array,
)
from mani_skill2.utils.io_utils import dump_json
from mani_skill2.utils.sapien_utils import batch, to_numpy
Expand All @@ -28,6 +26,16 @@
)


def slice_dict_array(x1, slice: slice):
"""Slices every array in x1 with slice and returns result. Tries to do this in place if possible"""
if isinstance(x1, np.ndarray) or isinstance(x1, list):
return x1[slice]
elif isinstance(x1, dict):
for k in x1.keys():
x1[k] = slice_dict_array(x1[k], slice)
return x1


def parse_env_info(env: gym.Env):
# spec can be None if not initialized from gymnasium.make
env = env.unwrapped
Expand Down Expand Up @@ -191,6 +199,8 @@ def __init__(
self.output_dir.mkdir(parents=True, exist_ok=True)
self.video_fps = video_fps
self._episode_id = -1
self._video_id = -1
self._closed = False

self._trajectory_buffer: Step = None
self._episode_info = {}
Expand Down Expand Up @@ -261,16 +271,28 @@ def reset(

# when we just have one env, we look at save_on_reset and clear the trajectory buffer
# when there are mutliple envs we save based on timesteps and must do more finegrained management of the buffer
if (
self.num_envs == 1
and self.save_on_reset
and self._trajectory_buffer is not None
):
if not skip_trajectory:
self.flush_trajectory(ignore_empty_transition=True)
# if (
# self.num_envs == 1
# and self.save_on_reset
# and self._trajectory_buffer is not None
# ):
# if not skip_trajectory:
# self.flush_trajectory(ignore_empty_transition=True, env_idxs_to_flush=[0])
# self.flush_video()
# else:
# self._trajectory_buffer = None

if self.save_on_reset and self._trajectory_buffer is not None:
if self.save_video and self.num_envs == 1:
self.flush_video()
else:
self._trajectory_buffer = None
if not skip_trajectory:
# if doing a full reset then we flush all trajectories including incompleted ones
if "env_idx" not in options:
self.flush_trajectory(env_idxs_to_flush=np.arange(self.num_envs))
else:
self.flush_trajectory(
env_idxs_to_flush=to_numpy(options["env_idx"])
)

self._episode_info = {}

Expand Down Expand Up @@ -307,41 +329,43 @@ def reset(
fail=np.zeros((1, self.num_envs), dtype=bool),
env_episode_ptr=np.zeros((self.num_envs,), dtype=int),
)
env_idx = np.arange(self.num_envs)
if "env_idx" in options:
env_idx = to_numpy(options["env_idx"])
if self._trajectory_buffer is None:
# Initialize trajectory buffer on the first episode based on given observation (which should be generated after all wrappers)
# TODO (stao): we do not really know the max size of the trajectory buffer since we keep it in memory until we flush?
# which for cpu env we do not know max size. gpu env we do.
self._trajectory_buffer = first_step
else:
self._trajectory_buffer.state = append_dict_array(
self._trajectory_buffer.state, first_step.state
)
self._trajectory_buffer.observation = append_dict_array(

def recursive_replace(x, y):
if isinstance(x, np.ndarray):
x[-1, env_idx] = y[-1, env_idx]
else:
for k in x.keys():
recursive_replace(x[k], y[k])

# import ipdb;ipdb.set_trace()
recursive_replace(self._trajectory_buffer.state, first_step.state)
recursive_replace(
self._trajectory_buffer.observation, first_step.observation
)
self._trajectory_buffer.action = append_dict_array(
self._trajectory_buffer.action, first_step.action
)
self._trajectory_buffer.reward = append_dict_array(
self._trajectory_buffer.reward, first_step.reward
)
self._trajectory_buffer.terminated = append_dict_array(
recursive_replace(self._trajectory_buffer.action, first_step.action)
recursive_replace(self._trajectory_buffer.reward, first_step.reward)
recursive_replace(
self._trajectory_buffer.terminated, first_step.terminated
)
self._trajectory_buffer.truncated = append_dict_array(
recursive_replace(
self._trajectory_buffer.truncated, first_step.truncated
)
self._trajectory_buffer.done = append_dict_array(
self._trajectory_buffer.done, first_step.done
)
recursive_replace(self._trajectory_buffer.done, first_step.done)
if self._trajectory_buffer.success is not None:
self._trajectory_buffer.success = append_dict_array(
recursive_replace(
self._trajectory_buffer.success, first_step.success
)
if self._trajectory_buffer.fail is not None:
self._trajectory_buffer.fail = append_dict_array(
self._trajectory_buffer.fail, first_step.fail
)
recursive_replace(self._trajectory_buffer.fail, first_step.fail)
if self.save_video:
self._render_images.append(self.capture_image())

Expand Down Expand Up @@ -394,8 +418,6 @@ def step(self, action):
else:
self._trajectory_buffer.fail = None
self._last_info = to_numpy(info)
if done.any():
self.flush_trajectory()

if self.save_video:
self._video_steps += 1
Expand All @@ -422,44 +444,18 @@ def flush_trajectory(
self,
verbose=False,
ignore_empty_transition=False,
flush_incomplete_trajectories=False,
env_idxs_to_flush=[],
):
# if ignore_empty_transition and len(self.t) == 1:
# return

# truncate the trajectory buffer

# find which trajectories completed

env_idxs_to_flush = []
if flush_incomplete_trajectories:
env_idxs_to_flush = np.arange(self.num_envs)
else:
# env_idxs_to_flush = np.argwhere(
# self._trajectory_buffer.done.sum(0) == 2
# ).flatten()
for env_idx in range(self.num_envs):
this_env_eps_dones = self._trajectory_buffer.done[
self._trajectory_buffer.env_episode_ptr[env_idx] :, env_idx
]
assert this_env_eps_dones[0] == True
if this_env_eps_dones[-1] == True:
env_idxs_to_flush.append(env_idx)

flush_count = 0
for env_idx in env_idxs_to_flush:
start_ptr = self._trajectory_buffer.env_episode_ptr[env_idx]
end_ptr = len(self._trajectory_buffer.done)
if ignore_empty_transition and end_ptr - start_ptr <= 1:
continue
self._episode_id += 1

traj_id = "traj_{}".format(self._episode_id)
group = self._h5_file.create_group(traj_id, track_order=True)
start_ptr, end_ptr = (
np.argwhere(
self._trajectory_buffer.done[
self._trajectory_buffer.env_episode_ptr[env_idx] :, env_idx
]
== True
).flatten()
+ self._trajectory_buffer.env_episode_ptr[env_idx]
)
end_ptr += 1 # inclusive

def recursive_add_to_h5py(group: h5py.Group, data: dict, key):
"""simple recursive data insertion for nested data structures into h5py, optimizing for visual data as well"""
Expand Down Expand Up @@ -574,66 +570,65 @@ def recursive_add_to_h5py(group: h5py.Group, data: dict, key):
episode_id=self._episode_id,
episode_seed=self._base_env._episode_seed,
control_mode=self._base_env.control_mode,
elapsed_steps=end_ptr - start_ptr,
elapsed_steps=end_ptr - start_ptr - 1,
)
self._json_data["episodes"].append(episode_info)
dump_json(self._json_path, self._json_data, indent=2)
flush_count += 1

if verbose:
if len(env_idxs_to_flush) == 1:
if flush_count == 1:
print(f"Recorded episode {self._episode_id}")
else:
print(
f"Recorded episodes {self._episode_id - len(env_idxs_to_flush + 1)} to {self._episode_id}"
f"Recorded episodes {self._episode_id - flush_count} to {self._episode_id}"
)
# truncate self._trajectory_buffer down to save memory

self._trajectory_buffer.env_episode_ptr[env_idxs_to_flush] = len(
self._trajectory_buffer.done
)
min_env_ptr = self._trajectory_buffer.env_episode_ptr.min()
N = len(self._trajectory_buffer.done)
self._trajectory_buffer.state = slice_dict_array(
self._trajectory_buffer.state, slice(min_env_ptr, N)
)
self._trajectory_buffer.observation = slice_dict_array(
self._trajectory_buffer.observation, slice(min_env_ptr, N)
)
self._trajectory_buffer.action = slice_dict_array(
self._trajectory_buffer.action, slice(min_env_ptr, N)
)
self._trajectory_buffer.reward = slice_dict_array(
self._trajectory_buffer.reward, slice(min_env_ptr, N)
)
self._trajectory_buffer.terminated = slice_dict_array(
self._trajectory_buffer.terminated, slice(min_env_ptr, N)
)
self._trajectory_buffer.truncated = slice_dict_array(
self._trajectory_buffer.truncated, slice(min_env_ptr, N)
)
self._trajectory_buffer.done = slice_dict_array(
self._trajectory_buffer.done, slice(min_env_ptr, N)
)
if self._trajectory_buffer.success is not None:
self._trajectory_buffer.success = slice_dict_array(
self._trajectory_buffer.success, slice(min_env_ptr, N)
if flush_count > 0:
self._trajectory_buffer.env_episode_ptr[env_idxs_to_flush] = (
len(self._trajectory_buffer.done) - 1
)
if self._trajectory_buffer.fail is not None:
self._trajectory_buffer.fail = slice_dict_array(
self._trajectory_buffer.fail, slice(min_env_ptr, N)
min_env_ptr = self._trajectory_buffer.env_episode_ptr.min()
N = len(self._trajectory_buffer.done)
self._trajectory_buffer.state = slice_dict_array(
self._trajectory_buffer.state, slice(min_env_ptr, N)
)
import ipdb

ipdb.set_trace()
self._trajectory_buffer.env_episode_ptr -= min_env_ptr
self._trajectory_buffer.observation = slice_dict_array(
self._trajectory_buffer.observation, slice(min_env_ptr, N)
)
self._trajectory_buffer.action = slice_dict_array(
self._trajectory_buffer.action, slice(min_env_ptr, N)
)
self._trajectory_buffer.reward = slice_dict_array(
self._trajectory_buffer.reward, slice(min_env_ptr, N)
)
self._trajectory_buffer.terminated = slice_dict_array(
self._trajectory_buffer.terminated, slice(min_env_ptr, N)
)
self._trajectory_buffer.truncated = slice_dict_array(
self._trajectory_buffer.truncated, slice(min_env_ptr, N)
)
self._trajectory_buffer.done = slice_dict_array(
self._trajectory_buffer.done, slice(min_env_ptr, N)
)
if self._trajectory_buffer.success is not None:
self._trajectory_buffer.success = slice_dict_array(
self._trajectory_buffer.success, slice(min_env_ptr, N)
)
if self._trajectory_buffer.fail is not None:
self._trajectory_buffer.fail = slice_dict_array(
self._trajectory_buffer.fail, slice(min_env_ptr, N)
)
self._trajectory_buffer.env_episode_ptr -= min_env_ptr

def flush_video(self, suffix="", verbose=False, ignore_empty_transition=True):
if not self.save_video or len(self._render_images) == 0:
if len(self._render_images) == 0:
return
if ignore_empty_transition and len(self._render_images) == 1:
return

video_name = "{}".format(self._episode_id)
self._video_id += 1
video_name = "{}".format(self._video_id)
if suffix:
video_name += "_" + suffix
images_to_video(
Expand All @@ -647,14 +642,17 @@ def flush_video(self, suffix="", verbose=False, ignore_empty_transition=True):
self._render_images = []

def close(self) -> None:
if self._closed:
# There is some strange bug when vector envs using record wrapper are closed/deleted, this code runs twice
return
self._closed = True
if self.save_trajectory:
# Handle the last episode only when `save_on_reset=True`
if self.save_on_reset:
traj_id = "traj_{}".format(self._episode_id)
if traj_id in self._h5_file:
logger.warning(f"{traj_id} exists in h5.")
else:
self.flush_trajectory(ignore_empty_transition=True)
self.flush_trajectory(
ignore_empty_transition=True,
env_idxs_to_flush=np.arange(self.num_envs),
)
if self.clean_on_close:
clean_trajectories(self._h5_file, self._json_data)
dump_json(self._json_path, self._json_data, indent=2)
Expand Down
22 changes: 15 additions & 7 deletions manualtest/record_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,26 @@
info_on_video=False,
video_fps=30,
save_trajectory=True,
max_steps_per_video=50,
max_steps_per_video=200,
)
env = ManiSkillVectorEnv(env)

env.reset(seed=52, options=dict(reconfigure=True))
for i in range(200):
# for i in range(180):
# env.step(env.action_space.sample())
env.step(env.action_space.sample())
env.step(env.action_space.sample())
print("partial reset")
env.reset(options=dict(env_idx=[0]))
# for i in range(50):
# env.step(env.action_space.sample())
for i in range(60):
env.step(env.action_space.sample())
print("prep close")
env.close()
# import h5py

import h5py
# data = h5py.File("videos/manual_test/PickCube-v1.h5")
# import ipdb

data = h5py.File("videos/manual_test/PickCube-v1.h5")
import ipdb

ipdb.set_trace()
# ipdb.set_trace()

0 comments on commit 0a10249

Please sign in to comment.