
Commit

changed dataset_size method dir, changed the integration with collector env, added dataset
Shreyans Jain authored and committed Sep 22, 2023
1 parent ea1de9a commit 8268a00
Showing 5 changed files with 211 additions and 47 deletions.
3 changes: 1 addition & 2 deletions minari/dataset/minari_dataset.py
@@ -11,8 +11,7 @@
from gymnasium.envs.registration import EnvSpec

from minari.data_collector import DataCollectorV0
-from minari.dataset.minari_storage import MinariStorage, PathLike
-from minari.helpers import get_dataset_size
+from minari.dataset.minari_storage import MinariStorage, PathLike, get_dataset_size


DATASET_ID_RE = re.compile(
34 changes: 34 additions & 0 deletions minari/dataset/minari_storage.py
@@ -7,12 +7,14 @@
import gymnasium as gym
import h5py
import numpy as np
from google.cloud import storage # pyright: ignore [reportGeneralTypeIssues]
from gymnasium.envs.registration import EnvSpec
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from packaging.version import Version

from minari.data_collector import DataCollectorV0
from minari.serialization import deserialize_space
from minari.storage.datasets_root_dir import get_dataset_path


# Use importlib due to circular import when: "from minari import __version__"
@@ -354,3 +356,35 @@ def clear_episode_buffer(episode_buffer: Dict, episode_group: h5py.Group) -> h5p
            episode_group.create_dataset(key, data=data, chunks=True)

    return episode_group


def get_dataset_size(dataset_id: str):
    """Returns the dataset size in MB.
    Args:
        dataset_id (str): name id of Minari Dataset
    Returns:
        datasize (float): size of the dataset in MB
    """
    file_path = get_dataset_path(dataset_id)
    data_path = os.path.join(file_path, "data")
    datasize_list = []
    if os.path.exists(data_path):

        for filename in os.listdir(data_path):
            if ".hdf5" in filename:
                datasize = os.path.getsize(os.path.join(data_path, filename))
                datasize_list.append(datasize)

    else:
        storage_client = storage.Client.create_anonymous_client()
        bucket = storage_client.bucket(bucket_name="minari-datasets")

        blobs = bucket.list_blobs(prefix=dataset_id)
        for blob in blobs:
            if ".hdf5" in blob.name:
                datasize_list.append(bucket.get_blob(blob.name).size)
    datasize = np.round(np.sum(datasize_list) / 1000000, 1)

    return datasize
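
For orientation (not part of the commit), a minimal usage sketch of the new helper. It assumes google-cloud-storage is installed and that the dataset id either exists under the local Minari root (~/.minari/datasets) or is hosted as .hdf5 blobs under that prefix in the public minari-datasets bucket; "cartpole-test-v0" is created by the tests added below, and "pen-human-v0" is only an example of a remotely hosted id.

from minari.dataset.minari_storage import get_dataset_size

# Local dataset: the sizes of the .hdf5 files under ~/.minari/datasets/<id>/data are summed.
local_size_mb = get_dataset_size("cartpole-test-v0")

# Dataset not on disk: an anonymous GCS client lists blobs under the id prefix instead.
remote_size_mb = get_dataset_size("pen-human-v0")  # example id, assumed to exist in the bucket

print(local_size_mb, remote_size_mb)  # both rounded to one decimal place, in MB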
39 changes: 0 additions & 39 deletions minari/helpers.py

This file was deleted.

8 changes: 4 additions & 4 deletions minari/utils.py
@@ -18,8 +18,7 @@

from minari import DataCollectorV0
from minari.dataset.minari_dataset import MinariDataset
-from minari.dataset.minari_storage import clear_episode_buffer
-from minari.helpers import get_dataset_size
+from minari.dataset.minari_storage import clear_episode_buffer, get_dataset_size
from minari.serialization import serialize_space
from minari.storage.datasets_root_dir import get_dataset_path

@@ -628,19 +627,20 @@ def create_dataset_from_collector_env(
"num_episodes_average_score": num_episodes_average_score,
}
)
dataset_size = get_dataset_size(dataset_id)
collector_env.save_to_disk(
data_path,
dataset_metadata={
"dataset_id": str(dataset_id),
"dataset_size": str(dataset_size),
"algorithm_name": str(algorithm_name),
"author": str(author),
"author_email": str(author_email),
"code_permalink": str(code_permalink),
"minari_version": minari_version,
},
)
with h5py.File(data_path, "r+", track_order=True) as file:
file.attrs["dataset_size"] = get_dataset_size(dataset_id)

return MinariDataset(data_path)
else:
raise ValueError(
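
With this change the size is no longer passed through the save_to_disk metadata; it is written afterwards as an HDF5 attribute on the main data file. A minimal read-back sketch (not part of the commit, assuming a dataset created locally as in the tests below, e.g. "cartpole-test-v0"):

import os

import h5py

from minari.storage.datasets_root_dir import get_dataset_path

# Read the "dataset_size" attribute written by create_dataset_from_collector_env.
data_path = os.path.join(get_dataset_path("cartpole-test-v0"), "data", "main_data.hdf5")
with h5py.File(data_path, "r") as f:
    print(f.attrs["dataset_size"])  # size in MB, as computed by get_dataset_size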
174 changes: 172 additions & 2 deletions tests/dataset/test_minari_storage.py
@@ -1,12 +1,24 @@
import copy
import os

import gymnasium as gym
import h5py
import numpy as np
import pytest

-from minari import __version__
-from minari.dataset.minari_storage import MinariStorage
+import minari
+from minari import DataCollectorV0, __version__
+from minari.dataset.minari_storage import MinariStorage, get_dataset_size
+from minari.utils import get_dataset_path
+from tests.common import (
+    check_data_integrity,
+    check_load_and_delete_dataset,
+    register_dummy_envs,
+)


register_dummy_envs()

file_path = os.path.join(os.path.expanduser("~"), ".minari", "datasets")


@@ -39,3 +51,161 @@ def test_minari_storage_missing_env_module():
        MinariStorage(os.path.join(file_path, "dummy-test-v0.hdf5"))

    os.remove(os.path.join(file_path, "dummy-test-v0.hdf5"))


@pytest.mark.parametrize(
    "dataset_id,env_id",
    [
        ("cartpole-test-v0", "CartPole-v1"),
        ("dummy-dict-test-v0", "DummyDictEnv-v0"),
        ("dummy-box-test-v0", "DummyBoxEnv-v0"),
        ("dummy-tuple-test-v0", "DummyTupleEnv-v0"),
        ("dummy-text-test-v0", "DummyTextEnv-v0"),
        ("dummy-combo-test-v0", "DummyComboEnv-v0"),
        ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"),
    ],
)
def test_minari_get_dataset_size_from_collector_env(dataset_id, env_id):
    """Test get_dataset_size method for dataset made using create_dataset_from_collector_env method."""
    # dataset_id = "cartpole-test-v0"
    # delete the test dataset if it already exists
    local_datasets = minari.list_local_datasets()
    if dataset_id in local_datasets:
        minari.delete_dataset(dataset_id)

    env = gym.make(env_id)

    env = DataCollectorV0(env)
    num_episodes = 100

    # Step the environment, DataCollectorV0 wrapper will do the data collection job
    env.reset(seed=42)

    for episode in range(num_episodes):
        terminated = False
        truncated = False
        while not terminated and not truncated:
            action = env.action_space.sample()  # User-defined policy function
            _, _, terminated, truncated, _ = env.step(action)
            if terminated or truncated:
                assert not env._buffer[-1]
            else:
                assert env._buffer[-1]

        env.reset()

    # Create Minari dataset and store locally
    dataset = minari.create_dataset_from_collector_env(
        dataset_id=dataset_id,
        collector_env=env,
        algorithm_name="random_policy",
        code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py",
        author="WillDudley",
        author_email="[email protected]",
    )

    file_path = get_dataset_path(dataset_id)
    data_path = os.path.join(file_path, "data", "main_data.hdf5")
    original_dataset_size = os.path.getsize(data_path)
    original_dataset_size = np.round(original_dataset_size / 1000000, 1)

    assert get_dataset_size(dataset_id) == original_dataset_size

    check_data_integrity(dataset._data, dataset.episode_indices)

    env.close()

    check_load_and_delete_dataset(dataset_id)


@pytest.mark.parametrize(
    "dataset_id,env_id",
    [
        ("cartpole-test-v0", "CartPole-v1"),
        ("dummy-dict-test-v0", "DummyDictEnv-v0"),
        ("dummy-box-test-v0", "DummyBoxEnv-v0"),
        ("dummy-tuple-test-v0", "DummyTupleEnv-v0"),
        ("dummy-text-test-v0", "DummyTextEnv-v0"),
        ("dummy-combo-test-v0", "DummyComboEnv-v0"),
        ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"),
    ],
)
def test_minari_get_dataset_size_from_buffer(dataset_id, env_id):
    """Test get_dataset_size method for dataset made using create_dataset_from_buffers method."""
    buffer = []
    # dataset_id = "cartpole-test-v0"

    # delete the test dataset if it already exists
    local_datasets = minari.list_local_datasets()
    if dataset_id in local_datasets:
        minari.delete_dataset(dataset_id)

    env = gym.make(env_id)

    observations = []
    actions = []
    rewards = []
    terminations = []
    truncations = []

    num_episodes = 10

    observation, info = env.reset(seed=42)

    # Step the environment, DataCollectorV0 wrapper will do the data collection job
    observation, _ = env.reset()
    observations.append(observation)
    for episode in range(num_episodes):
        terminated = False
        truncated = False

        while not terminated and not truncated:
            action = env.action_space.sample()  # User-defined policy function
            observation, reward, terminated, truncated, _ = env.step(action)
            observations.append(observation)
            actions.append(action)
            rewards.append(reward)
            terminations.append(terminated)
            truncations.append(truncated)

        episode_buffer = {
            "observations": copy.deepcopy(observations),
            "actions": copy.deepcopy(actions),
            "rewards": np.asarray(rewards),
            "terminations": np.asarray(terminations),
            "truncations": np.asarray(truncations),
        }
        buffer.append(episode_buffer)

        observations.clear()
        actions.clear()
        rewards.clear()
        terminations.clear()
        truncations.clear()

        observation, _ = env.reset()
        observations.append(observation)

    # Create Minari dataset and store locally
    dataset = minari.create_dataset_from_buffers(
        dataset_id=dataset_id,
        env=env,
        buffer=buffer,
        algorithm_name="random_policy",
        code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py",
        author="WillDudley",
        author_email="[email protected]",
    )

    file_path = get_dataset_path(dataset_id)
    data_path = os.path.join(file_path, "data", "main_data.hdf5")
    original_dataset_size = os.path.getsize(data_path)
    original_dataset_size = np.round(original_dataset_size / 1000000, 1)

    assert get_dataset_size(dataset_id) == original_dataset_size

    check_data_integrity(dataset._data, dataset.episode_indices)

    env.close()

    check_load_and_delete_dataset(dataset_id)
