Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
reginald-mclean committed Nov 6, 2024
1 parent b677a18 commit 279136c
Showing 1 changed file with 56 additions and 48 deletions.
104 changes: 56 additions & 48 deletions metaworld/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,33 +573,37 @@ def _ml_bench_vector_entry_point(
)

register(
id=f"Meta-World/MT1",
entry_point=lambda env_name, seed=None, vector_strategy='sync': _mt_bench_vector_entry_point(env_name, vector_strategy, seed),
id="Meta-World/MT1",
entry_point=lambda env_name, seed=None, vector_strategy="sync": _mt_bench_vector_entry_point(
env_name, vector_strategy, seed
),
kwargs={},
)

for split in ["train", "test"]:
register(
id=f"Meta-World/ML1-{split}",
vector_entry_point=lambda env_name, vector_strategy='sync', seed=None, *args, _split=split, **kwargs: _ml_bench_vector_entry_point(
ml_bench=env_name,
split=_split,
vector_strategy=vector_strategy,
seed=seed,
*args,
**kwargs),
)
vector_entry_point=lambda env_name, vector_strategy="sync", seed=None, *args, **kwargs: _ml_bench_vector_entry_point(
env_name, # positional arguments
split,
vector_strategy,
seed,
*args,
**kwargs,
),
kwargs={},
)

register(
id=f"Meta-World/goal_hidden",
id="Meta-World/goal_hidden",
entry_point=lambda env_name, seed: ALL_V3_ENVIRONMENTS_GOAL_HIDDEN[env_name]( # type: ignore
seed=seed
),
kwargs={},
)

register(
id=f"Meta-World/goal_observable",
id="Meta-World/goal_observable",
entry_point=lambda env_name, seed: ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE[env_name]( # type: ignore
seed=seed
),
Expand All @@ -609,31 +613,32 @@ def _ml_bench_vector_entry_point(
for mt_bench in ["MT10", "MT50"]:
register(
id=f"Meta-World/{mt_bench}",
vector_entry_point=lambda vector_strategy='sync', seed=None, use_one_hot=False, *args, _mt_bench=mt_bench, **kwargs: _mt_bench_vector_entry_point(
mt_bench=_mt_bench,
vector_strategy=vector_strategy,
seed=seed,
use_one_hot=use_one_hot,
*args,
**kwargs),
vector_entry_point=lambda vector_strategy="sync", seed=None, use_one_hot=False, *args, _mt_bench=mt_bench, **kwargs: _mt_bench_vector_entry_point(
_mt_bench, # positional arguments
vector_strategy,
seed,
use_one_hot,
*args,
**kwargs,
),
kwargs={},
)


for ml_bench in ["ML10", "ML45"]:
for split in ["train", "test"]:
register(
id=f"Meta-World/{ml_bench}-{split}",
vector_entry_point=lambda vector_strategy='sync', seed=None, *args,_split=split, _ml_bench=ml_bench, **kwargs: _ml_bench_vector_entry_point(
ml_bench=_ml_bench,
split=_split,
vector_strategy=vector_strategy,
seed=seed,
*args,
**kwargs),
id=f"Meta-World/{ml_bench}-{split}", # Fixed f-string
vector_entry_point=lambda vector_strategy="sync", seed=None, *args, _ml_bench=ml_bench, _split=split, **kwargs: _ml_bench_vector_entry_point(
_ml_bench,
_split,
vector_strategy,
seed,
*args,
**kwargs,
),
kwargs={},
)


def _custom_mt_vector_entry_point(
vector_strategy: str,
envs_list: list[str],
Expand All @@ -648,29 +653,30 @@ def _custom_mt_vector_entry_point(
)
return (
vectorizer( # type: ignore
[
partial( # type: ignore
make_mt_envs,
env_name,
num_tasks=len(envs_list),
env_id=idx,
seed=None if not seed else seed + idx,
use_one_hot=use_one_hot,
*args,
**lamb_kwargs,
)
for idx, env_name in enumerate(envs_list)
]
),
[
partial( # type: ignore
make_mt_envs,
env_name,
num_tasks=len(envs_list),
env_id=idx,
seed=None if not seed else seed + idx,
use_one_hot=use_one_hot,
*args,
**lamb_kwargs,
)
for idx, env_name in enumerate(envs_list)
]
),
)

register(
id=f"Meta-World/custom-mt-envs",
vector_entry_point=lambda vector_strategy, envs_list, seed=None, use_one_hot=False, num_envs=None: _custom_mt_vector_entry_point(vector_strategy, envs_list, seed, use_one_hot, num_envs),
id="Meta-World/custom-mt-envs",
vector_entry_point=lambda vector_strategy, envs_list, seed=None, use_one_hot=False, num_envs=None: _custom_mt_vector_entry_point(
vector_strategy, envs_list, seed, use_one_hot, num_envs
),
kwargs={},
)


def _custom_ml_vector_entry_point(
vector_strategy: str,
train_envs: list[str],
Expand All @@ -690,8 +696,10 @@ def _custom_ml_vector_entry_point(
)

register(
id=f"Meta-World/custom-ml-envs",
vector_entry_point=lambda vector_strategy, train_envs, test_envs, meta_batch_size=20, seed=None, num_envs=None: _custom_ml_vector_entry_point(vector_strategy, train_envs, test_envs, meta_batch_size, seed, num_envs),
id="Meta-World/custom-ml-envs",
vector_entry_point=lambda vector_strategy, train_envs, test_envs, meta_batch_size=20, seed=None, num_envs=None: _custom_ml_vector_entry_point(
vector_strategy, train_envs, test_envs, meta_batch_size, seed, num_envs
),
kwargs={},
)

Expand Down

0 comments on commit 279136c

Please sign in to comment.