From 2c9c0abd092e8d1047cfb2856e696390e4af2c93 Mon Sep 17 00:00:00 2001 From: "John U. Balis" Date: Mon, 31 Jul 2023 20:18:03 -0400 Subject: [PATCH] some polish, removed tester file --- minari/dataset/minari_dataset.py | 10 +- minari/dataset/minari_storage.py | 9 +- tester.py | 262 ------------------------------- 3 files changed, 6 insertions(+), 275 deletions(-) delete mode 100644 tester.py diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index 7f69d72e..8705a71b 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -144,18 +144,16 @@ def __init__( total_steps = self._data.total_steps else: total_steps = sum( - self._data.apply( - lambda episode: episode["total_timesteps"], - episode_indices=episode_indices, + self._data.apply( + lambda episode: episode["total_timesteps"], + episode_indices=episode_indices, + ) ) - ) self._episode_indices = episode_indices assert self._episode_indices is not None - - self.spec = MinariDatasetSpec( env_spec=self._data.env_spec, total_episodes=self._episode_indices.size, diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index a508b07d..3d3cae39 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -211,12 +211,9 @@ def update_from_collector_env( file.attrs.modify( "total_episodes", last_episode_id + new_data_total_episodes ) - file.attrs.modify( - "total_steps", self._total_steps - ) + file.attrs.modify("total_steps", self._total_steps) self._total_episodes = int(file.attrs["total_episodes"].item()) - def update_from_buffer(self, buffer: List[dict], data_path: str): additional_steps = 0 with h5py.File(data_path, "a", track_order=True) as file: @@ -254,9 +251,7 @@ def update_from_buffer(self, buffer: List[dict], data_path: str): self._total_episodes = last_episode_id + len(buffer) file.attrs.modify("total_episodes", self._total_episodes) - file.attrs.modify( - "total_steps", self._total_steps - ) + file.attrs.modify("total_steps", self._total_steps) self._total_episodes = int(file.attrs["total_episodes"].item()) diff --git a/tester.py b/tester.py deleted file mode 100644 index 99d8a6bb..00000000 --- a/tester.py +++ /dev/null @@ -1,262 +0,0 @@ -import copy -from collections import OrderedDict -from typing import Dict -import datetime -import random -from operator import itemgetter - -import gymnasium as gym -import numpy as np -import pytest -from gymnasium import spaces -import pickle - -import minari -from minari import DataCollectorV0, MinariDataset -from tests.common import ( - register_dummy_envs, -) - - -NUM_EPISODES = 10000 -EPISODE_SAMPLE_COUNT = 10 - -register_dummy_envs() - - -def test_generate_dataset_with_collector_env(dataset_id, env_id): - """Test DataCollectorV0 wrapper and Minari dataset creation.""" - # dataset_id = "cartpole-test-v0" - # delete the test dataset if it already exists - local_datasets = minari.list_local_datasets() - if dataset_id in local_datasets: - minari.delete_dataset(dataset_id) - - env = gym.make(env_id) - - env = DataCollectorV0(env) - - # Step the environment, DataCollectorV0 wrapper will do the data collection job - env.reset(seed=42) - - for episode in range(NUM_EPISODES): - terminated = False - truncated = False - while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function - _, _, terminated, truncated, _ = env.step(action) - if terminated or truncated: - assert not env._buffer[-1] - else: - assert env._buffer[-1] - - env.reset() - - # Create Minari dataset and store locally - dataset = minari.create_dataset_from_collector_env( - dataset_id=dataset_id, - collector_env=env, - algorithm_name="random_policy", - code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py", - author="WillDudley", - author_email="wdudley@farama.org", - ) - - - -def test_generate_dataset_with_external_buffer(dataset_id, env_id): - """Test create dataset from external buffers without using DataCollectorV0.""" - buffer = [] - # dataset_id = "cartpole-test-v0" - - - env = gym.make(env_id) - - observations = [] - actions = [] - rewards = [] - terminations = [] - truncations = [] - - - observation, info = env.reset(seed=42) - - # Step the environment, DataCollectorV0 wrapper will do the data collection job - observation, _ = env.reset() - observations.append(observation) - for episode in range(NUM_EPISODES): - terminated = False - truncated = False - - while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function - observation, reward, terminated, truncated, _ = env.step(action) - observations.append(observation) - actions.append(action) - rewards.append(reward) - terminations.append(terminated) - truncations.append(truncated) - - episode_buffer = { - "observations": copy.deepcopy(observations), - "actions": copy.deepcopy(actions), - "rewards": np.asarray(rewards), - "terminations": np.asarray(terminations), - "truncations": np.asarray(truncations), - } - buffer.append(episode_buffer) - - observations.clear() - actions.clear() - rewards.clear() - terminations.clear() - truncations.clear() - - observation, _ = env.reset() - observations.append(observation) - - # Create Minari dataset and store locally - dataset = minari.create_dataset_from_buffers( - dataset_id=dataset_id, - env=env, - buffer=buffer, - algorithm_name="random_policy", - code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py", - author="WillDudley", - author_email="wdudley@farama.org", - ) - - - -def test_generate_dataset_pickle(dataset_id, env_id): - """Test create dataset from external buffers without using DataCollectorV0.""" - buffer = [] - # dataset_id = "cartpole-test-v0" - - - env = gym.make(env_id) - - observations = [] - actions = [] - rewards = [] - terminations = [] - truncations = [] - - - observation, info = env.reset(seed=42) - - # Step the environment, DataCollectorV0 wrapper will do the data collection job - observation, _ = env.reset() - observations.append(observation) - for episode in range(NUM_EPISODES): - terminated = False - truncated = False - - while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function - observation, reward, terminated, truncated, _ = env.step(action) - observations.append(observation) - actions.append(action) - rewards.append(reward) - terminations.append(terminated) - truncations.append(truncated) - - episode_buffer = { - "observations": copy.deepcopy(observations), - "actions": copy.deepcopy(actions), - "rewards": np.asarray(rewards), - "terminations": np.asarray(terminations), - "truncations": np.asarray(truncations), - } - buffer.append(episode_buffer) - - observations.clear() - actions.clear() - rewards.clear() - terminations.clear() - truncations.clear() - - observation, _ = env.reset() - observations.append(observation) - - # Create Minari dataset and store locally with pickle - with open("test.pkl", "wb") as test_file: - pickle.dump(buffer,test_file) - - #with open("test.pkl", "rb") as test_file: - # test = pickle.load(test_file) - - -def test_sample_n_random_episodes_from_minari_dataset(dataset_id): - dataset = minari.load_dataset(dataset_id) - episodes = dataset.sample_episodes(EPISODE_SAMPLE_COUNT) - # print(episodes) - -def test_sample_n_random_episodes_from_pickle_dataset(): - with open("test.pkl", "rb") as test_file: - test = pickle.load(test_file) - - indices = random.sample(range(0,len(test)),EPISODE_SAMPLE_COUNT ) - - result = itemgetter(*indices)(test) - - - -def measure(function, args): - before = datetime.datetime.now() - function(*args) - after = datetime.datetime.now() - return (after-before).total_seconds() - - -if __name__ == "__main__": - - - environment_list = [ - ("cartpole-test-v0", "CartPole-v1"), - ("dummy-dict-test-v0", "DummyDictEnv-v0"), - ("dummy-tuple-test-v0", "DummyTupleEnv-v0"), - ("dummy-text-test-v0", "DummyTextEnv-v0"), - ("dummy-combo-test-v0", "DummyComboEnv-v0"), - ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), - ] - - - measurements = {} - - - - for dataset_id, env_id in environment_list: - - #dataset_id, env_id = ("cartpole-test-v0", "CartPole-v1") - - - # delete the test dataset if it already exists - local_datasets = minari.list_local_datasets() - if dataset_id in local_datasets: - minari.delete_dataset(dataset_id) - - result = measure(test_generate_dataset_with_collector_env, (dataset_id, env_id)) - print(f"Time to generate {NUM_EPISODES} episodes with {env_id} using test_generate_dataset_with_collector_env: {str(result)}") - - # delete the test dataset if it already exists - local_datasets = minari.list_local_datasets() - if dataset_id in local_datasets: - minari.delete_dataset(dataset_id) - - - result = measure(test_generate_dataset_with_external_buffer, (dataset_id, env_id)) - print(f"Time to generate {NUM_EPISODES} episodes with {env_id} using test_generate_dataset_with_external_buffer: {str(result)}") - - - - result = measure(test_generate_dataset_pickle, (dataset_id, env_id)) - print(f"Time to generate {NUM_EPISODES} episodes with {env_id} using test_generate_dataset_pickle: {str(result)}") - - result = measure(test_sample_n_random_episodes_from_minari_dataset, (dataset_id,)) - print(f"Time to sample {EPISODE_SAMPLE_COUNT} episodes from {env_id} using test_sample_n_random_episodes_from_minari_dataset: {str(result)}") - - - result = measure(test_sample_n_random_episodes_from_pickle_dataset, ()) - print(f"Time to sample {EPISODE_SAMPLE_COUNT} episodes from {env_id} test_sample_n_random_episodes_from_pickle_dataset: {str(result)}") - \ No newline at end of file