Skip to content

Commit

Permalink
fix batching for scalar values
Browse files Browse the repository at this point in the history
  • Loading branch information
StoneT2000 committed Mar 5, 2024
1 parent 0a10249 commit cce5910
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 43 deletions.
29 changes: 1 addition & 28 deletions mani_skill2/utils/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import OrderedDict, defaultdict
from typing import Dict, Sequence, Union
from typing import Dict, Sequence

import gymnasium as gym
import numpy as np
Expand Down Expand Up @@ -27,33 +27,6 @@ def dict_merge(dct: dict, merge_dct: dict):
dct[k] = merge_dct[k]


def append_dict_array(
x1: Union[dict, Sequence, Array], x2: Union[dict, Sequence, Array]
):
"""Append `x2` in front of `x1` and returns the result. Tries to do this in place if possible.
Assumes both `x1, x2` have the same dictionary structure if they are dictionaries.
They may also both be lists/sequences in which case this is just appending like normal"""
if isinstance(x1, np.ndarray):
return np.concatenate([x1, x2])
elif isinstance(x1, list):
return x1 + x2
elif isinstance(x1, dict):
for k in x1.keys():
assert k in x2, "dct and append_dct need to have the same dictionary layout"
x1[k] = append_dict_array(x1[k], x2[k])
return x1


def slice_dict_array(x1: Union[dict, Sequence, Array], slice: slice):
"""Slices every array in x1 with slice and returns result. Tries to do this in place if possible"""
if isinstance(x1, np.ndarray) or isinstance(x1, list):
return x1[slice]
elif isinstance(x1, dict):
for k in x1.keys():
x1[k] = slice_dict_array(x1[k], slice)
return x1


def merge_dicts(ds: Sequence[Dict], asarray=False):
"""Merge multiple dicts with the same keys to a single one."""
# NOTE(jigu): To be compatible with generator, we only iterate once.
Expand Down
2 changes: 2 additions & 0 deletions mani_skill2/utils/sapien_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def _batch(array: Union[Array, Sequence]):
if isinstance(array, list):
if len(array) == 1:
return [array]
if isinstance(array, float) or isinstance(array, int) or isinstance(array, bool):
return np.array([[array]])
return array


Expand Down
27 changes: 23 additions & 4 deletions mani_skill2/utils/wrappers/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,48 @@
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union
from typing import List, Optional, Sequence, Union

import gymnasium as gym
import h5py
import numpy as np
import sapien.physx as physx
from gymnasium import spaces

from mani_skill2 import get_commit_info, logger
from mani_skill2 import get_commit_info
from mani_skill2.envs.sapien_env import BaseEnv
from mani_skill2.utils.common import (
append_dict_array,
extract_scalars_from_info,
find_max_episode_steps_value,
)
from mani_skill2.utils.io_utils import dump_json
from mani_skill2.utils.sapien_utils import batch, to_numpy
from mani_skill2.utils.structs.types import Array
from mani_skill2.utils.visualization.misc import (
images_to_video,
put_info_on_image,
tile_images,
)

# NOTE (stao): The code for record.py is quite messy and perhaps confusing as it is trying to support both recording on CPU and GPU seamlessly
# and handle partial resets. It works but can be claned up a lot.


def append_dict_array(
x1: Union[dict, Sequence, Array], x2: Union[dict, Sequence, Array]
):
"""Append `x2` in front of `x1` and returns the result. Tries to do this in place if possible.
Assumes both `x1, x2` have the same dictionary structure if they are dictionaries.
They may also both be lists/sequences in which case this is just appending like normal"""
if isinstance(x1, np.ndarray):
return np.concatenate([x1, x2])
elif isinstance(x1, list):
return x1 + x2
elif isinstance(x1, dict):
for k in x1.keys():
assert k in x2, "dct and append_dct need to have the same dictionary layout"
x1[k] = append_dict_array(x1[k], x2[k])
return x1


def slice_dict_array(x1, slice: slice):
"""Slices every array in x1 with slice and returns result. Tries to do this in place if possible"""
Expand Down
22 changes: 11 additions & 11 deletions manualtest/record_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
if __name__ == "__main__":
# sapien.set_log_level("info")
# , "StackCube-v1", "PickCube-v1", "PushCube-v1", "PickSingleYCB-v1", "OpenCabinet-v1"
num_envs = 2
num_envs = 1
for env_id in ["PickCube-v1"]:
env = gym.make(
env_id,
Expand All @@ -25,7 +25,7 @@
# control_mode="pd_ee_delta_pos",
# sim_freq=100,
# control_freq=20,
force_use_gpu_sim=True,
# force_use_gpu_sim=True,
# reconfiguration_freq=1,
)
env = RecordEpisode(
Expand All @@ -35,21 +35,21 @@
info_on_video=False,
video_fps=30,
save_trajectory=True,
max_steps_per_video=200,
max_steps_per_video=50,
)
env = ManiSkillVectorEnv(env)
# env = ManiSkillVectorEnv(env)

env.reset(seed=52, options=dict(reconfigure=True))
# for i in range(180):
# env.step(env.action_space.sample())
env.step(env.action_space.sample())
env.step(env.action_space.sample())
print("partial reset")
env.reset(options=dict(env_idx=[0]))
# for i in range(50):
# env.step(env.action_space.sample())
for i in range(60):
# env.step(env.action_space.sample())
# env.step(env.action_space.sample())
# print("partial reset")
# env.reset(options=dict(env_idx=[0]))
for i in range(50):
env.step(env.action_space.sample())
# for i in range(60):
# env.step(env.action_space.sample())
print("prep close")
env.close()
# import h5py
Expand Down

0 comments on commit cce5910

Please sign in to comment.