Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try multiprocessing on dataset docs generation #254

Merged
29 commits merged on Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions .github/workflows/docs-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@ jobs:
with:
python-version: '3.9'

- name: Install dependencies
run: pip install -r docs/requirements.txt

- name: Install Minari
run: pip install .[all,testing]

- name: Build
- name: Install docs dependencies
run: pip install -r docs/requirements.txt

- name: Build documentation
run: sphinx-build -b dirhtml -v docs _build

- name: Run markdown documentation tests
run: pytest tests/test_docs.py

- name: Install tutorial dependencies
run: pip install -r docs/tutorials/requirements.txt

- name: Run tutorial documentation tests
run: pytest --nbmake docs/tutorials/**/*.ipynb --nbmake-timeout=600
41 changes: 33 additions & 8 deletions docs/_scripts/gen_dataset_md.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

import logging
import os
import pathlib
import shutil
import subprocess
import sys
import venv
import warnings
from collections import defaultdict
from multiprocessing import Pool
from typing import Dict, OrderedDict

import generate_gif
Expand All @@ -28,6 +32,8 @@ def _md_table(table_dict: Dict[str, str]) -> str:


def main():
os.environ["TQDM_DISABLE"] = "1"

remote_datasets = minari.list_remote_datasets(latest_version=True)
for i, (dataset_id, metadata) in enumerate(remote_datasets.items()):
namespace, dataset_name, version = parse_dataset_id(dataset_id)
Expand Down Expand Up @@ -58,31 +64,48 @@ def main():
"display_name": versioned_name,
}

_generate_dataset_page(dataset_id, metadata)

for namespace, content in NAMESPACE_CONTENTS.items():
_generate_namespace_page(namespace, content)

with Pool(processes=16) as pool:
pool.map(_generate_dataset_page, remote_datasets.items())

del os.environ["TQDM_DISABLE"]


def _generate_dataset_page(dataset_id, metadata):
def _generate_dataset_page(arg):
dataset_id, metadata = arg
_, dataset_name, version = parse_dataset_id(dataset_id)
versioned_name = gen_dataset_id(None, dataset_name, version)

description = metadata.get("description")
try:
requirements = metadata.get("requirements", [])
for req in requirements:
subprocess.check_call([sys.executable, "-m", "pip", "install", req])
venv.create(dataset_id, with_pip=True)

requirements = [
"minari[gcs,hdf5] @ git+https://github.com/Farama-Foundation/Minari.git",
"imageio",
"absl-py",
]
requirements.extend(metadata.get("requirements", []))
pip_path = pathlib.Path(dataset_id) / "bin" / "pip"
req_args = [pip_path, "install", *requirements]
subprocess.check_call(req_args, stdout=subprocess.DEVNULL)
logging.info(f"Installed requirements for {dataset_id}")

minari.download_dataset(dataset_id)

python_path = pathlib.Path(dataset_id) / "bin" / "python"
subprocess.check_call(
[
sys.executable,
python_path,
generate_gif.__file__,
f"--dataset_id={dataset_id}",
f"--path={DATASET_FOLDER}",
]
)
minari.delete_dataset(dataset_id)
shutil.rmtree(dataset_id)
img_link_str = f'<img src="../{versioned_name}.gif" width="200" style="display: block; margin:0 auto"/>'
except Exception as e:
warnings.warn(f"Failed to generate gif for {dataset_id}: {e}")
Expand Down Expand Up @@ -204,6 +227,8 @@ def _generate_dataset_page(dataset_id, metadata):
file.write(content)
file.close()

logging.info(f"Generated dataset page for {dataset_id}")


def _generate_namespace_page(namespace: str, namespace_content):
namespace_path = DATASET_FOLDER.joinpath(namespace)
Expand Down
9 changes: 1 addition & 8 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
torch
torchrl
sphinx
sphinx-autobuild
sphinx_github_changelog
myst-parser
matplotlib
gymnasium-robotics>=1.2.1
minigrid>=2.2.0
rl_zoo3>=2.0.0
imageio>=2.14.1
nbmake
git+https://github.com/sphinx-gallery/sphinx-gallery.git@4006662c8c1984453a247dc6d3df6260e5b00f4b#egg=sphinx_gallery
git+https://github.com/Farama-Foundation/Celshast#egg=furo
torchrl
pyvirtualdisplay
absl-py
moviepy
7 changes: 7 additions & 0 deletions docs/tutorials/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
torch
torchrl
matplotlib
gymnasium-robotics>=1.2.1
minigrid>=2.2.0
rl_zoo3>=2.0.0
imageio>=2.14.1
4 changes: 2 additions & 2 deletions docs/tutorials/using_datasets/IQL_torchrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
# %%
# .. code-block:: bash
#
# ! pip install "torchrl-nightly>=2023.12.30"
# ! pip install torchrl
# ! pip install matplotlib minari gymnasium-robotics

# %%
Expand Down Expand Up @@ -329,7 +329,7 @@
"max": action_spec.space.high,
"tanh_loc": False,
},
default_interaction_type=ExplorationType.MODE,
default_interaction_type=ExplorationType.DETERMINISTIC,
)

# %%
Expand Down
2 changes: 1 addition & 1 deletion minari/storage/hosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def blob_to_metadata(blob):

cloud_storage = get_cloud_storage()
blobs = cloud_storage.list_blobs()
with ThreadPoolExecutor(max_workers=32) as executor:
with ThreadPoolExecutor(max_workers=10) as executor:
remote_metadatas = executor.map(blob_to_metadata, blobs)

remote_datasets = {}
Expand Down
6 changes: 3 additions & 3 deletions minari/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import numpy as np
from gymnasium.core import ActType, ObsType
from gymnasium.envs.registration import EnvSpec
from gymnasium.error import NameNotFound
from gymnasium.wrappers import RecordEpisodeStatistics # type: ignore

from minari.data_collector.episode_buffer import EpisodeBuffer
Expand Down Expand Up @@ -467,10 +466,11 @@ def get_normalized_score(dataset: MinariDataset, returns: np.ndarray) -> np.ndar
def get_env_spec_dict(env_spec: EnvSpec) -> Dict[str, str]:
"""Create dict of the environment specs, including observation and action space."""
try:
env = gym.make(env_spec.id)
env = gym.make(env_spec)
action_space_table = env.action_space.__repr__().replace("\n", "")
observation_space_table = env.observation_space.__repr__().replace("\n", "")
except NameNotFound:
except Exception as e:
warnings.warn(f"Failed to make env {env_spec.id}, {e}")
action_space_table, observation_space_table = None, None

md_dict = {"ID": env_spec.id}
Expand Down
Loading