diff --git a/metaworld/__init__.py b/metaworld/__init__.py index cc2b5b51..833ab5d1 100644 --- a/metaworld/__init__.py +++ b/metaworld/__init__.py @@ -573,25 +573,29 @@ def _ml_bench_vector_entry_point( ) register( - id=f"Meta-World/MT1", - entry_point=lambda env_name, seed=None, vector_strategy='sync': _mt_bench_vector_entry_point(env_name, vector_strategy, seed), + id="Meta-World/MT1", + entry_point=lambda env_name, seed=None, vector_strategy="sync": _mt_bench_vector_entry_point( + env_name, vector_strategy, seed + ), kwargs={}, ) for split in ["train", "test"]: register( id=f"Meta-World/ML1-{split}", - vector_entry_point=lambda env_name, vector_strategy='sync', seed=None, *args, _split=split, **kwargs: _ml_bench_vector_entry_point( - ml_bench=env_name, - split=_split, - vector_strategy=vector_strategy, - seed=seed, - *args, - **kwargs), - ) + vector_entry_point=lambda env_name, vector_strategy="sync", seed=None, *args, **kwargs: _ml_bench_vector_entry_point( + env_name, # positional arguments + split, + vector_strategy, + seed, + *args, + **kwargs, + ), + kwargs={}, + ) register( - id=f"Meta-World/goal_hidden", + id="Meta-World/goal_hidden", entry_point=lambda env_name, seed: ALL_V3_ENVIRONMENTS_GOAL_HIDDEN[env_name]( # type: ignore seed=seed ), @@ -599,7 +603,7 @@ def _ml_bench_vector_entry_point( ) register( - id=f"Meta-World/goal_observable", + id="Meta-World/goal_observable", entry_point=lambda env_name, seed: ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE[env_name]( # type: ignore seed=seed ), @@ -609,31 +613,32 @@ def _ml_bench_vector_entry_point( for mt_bench in ["MT10", "MT50"]: register( id=f"Meta-World/{mt_bench}", - vector_entry_point=lambda vector_strategy='sync', seed=None, use_one_hot=False, *args, _mt_bench=mt_bench, **kwargs: _mt_bench_vector_entry_point( - mt_bench=_mt_bench, - vector_strategy=vector_strategy, - seed=seed, - use_one_hot=use_one_hot, - *args, - **kwargs), + vector_entry_point=lambda vector_strategy="sync", seed=None, use_one_hot=False, *args, _mt_bench=mt_bench, **kwargs: _mt_bench_vector_entry_point( + _mt_bench, # positional arguments + vector_strategy, + seed, + use_one_hot, + *args, + **kwargs, + ), kwargs={}, ) - for ml_bench in ["ML10", "ML45"]: for split in ["train", "test"]: register( - id=f"Meta-World/{ml_bench}-{split}", - vector_entry_point=lambda vector_strategy='sync', seed=None, *args,_split=split, _ml_bench=ml_bench, **kwargs: _ml_bench_vector_entry_point( - ml_bench=_ml_bench, - split=_split, - vector_strategy=vector_strategy, - seed=seed, - *args, - **kwargs), + id=f"Meta-World/{ml_bench}-{split}", # Fixed f-string + vector_entry_point=lambda vector_strategy="sync", seed=None, *args, _ml_bench=ml_bench, _split=split, **kwargs: _ml_bench_vector_entry_point( + _ml_bench, + _split, + vector_strategy, + seed, + *args, + **kwargs, + ), + kwargs={}, ) - def _custom_mt_vector_entry_point( vector_strategy: str, envs_list: list[str], @@ -648,29 +653,30 @@ def _custom_mt_vector_entry_point( ) return ( vectorizer( # type: ignore - [ - partial( # type: ignore - make_mt_envs, - env_name, - num_tasks=len(envs_list), - env_id=idx, - seed=None if not seed else seed + idx, - use_one_hot=use_one_hot, - *args, - **lamb_kwargs, - ) - for idx, env_name in enumerate(envs_list) - ] - ), + [ + partial( # type: ignore + make_mt_envs, + env_name, + num_tasks=len(envs_list), + env_id=idx, + seed=None if not seed else seed + idx, + use_one_hot=use_one_hot, + *args, + **lamb_kwargs, + ) + for idx, env_name in enumerate(envs_list) + ] + ), ) register( - id=f"Meta-World/custom-mt-envs", - vector_entry_point=lambda vector_strategy, envs_list, seed=None, use_one_hot=False, num_envs=None: _custom_mt_vector_entry_point(vector_strategy, envs_list, seed, use_one_hot, num_envs), + id="Meta-World/custom-mt-envs", + vector_entry_point=lambda vector_strategy, envs_list, seed=None, use_one_hot=False, num_envs=None: _custom_mt_vector_entry_point( + vector_strategy, envs_list, seed, use_one_hot, num_envs + ), kwargs={}, ) - def _custom_ml_vector_entry_point( vector_strategy: str, train_envs: list[str], @@ -690,8 +696,10 @@ def _custom_ml_vector_entry_point( ) register( - id=f"Meta-World/custom-ml-envs", - vector_entry_point=lambda vector_strategy, train_envs, test_envs, meta_batch_size=20, seed=None, num_envs=None: _custom_ml_vector_entry_point(vector_strategy, train_envs, test_envs, meta_batch_size, seed, num_envs), + id="Meta-World/custom-ml-envs", + vector_entry_point=lambda vector_strategy, train_envs, test_envs, meta_batch_size=20, seed=None, num_envs=None: _custom_ml_vector_entry_point( + vector_strategy, train_envs, test_envs, meta_batch_size, seed, num_envs + ), kwargs={}, )