Support heterogeneous dicts in infos #250

Open · wants to merge 16 commits into base: main
30 changes: 23 additions & 7 deletions docs/content/basic_usage.md
@@ -5,7 +5,9 @@ title: Basic Usage

# Basic Usage

Minari is a standard dataset hosting interface for Offline Reinforcement Learning applications. Minari is compatible
with most of the RL environments that follow the Gymnasium API and facilitates Offline RL dataset handling by providing
data collection, dataset hosting, and dataset sampling capabilities.

## Installation

@@ -15,7 +17,9 @@ To install the most recent version of the Minari library run this command:
pip install minari
```

This will install the minimum required dependencies. Additional dependencies will be prompted for installation based on
your use case. To install all dependencies at once, use:

```bash
pip install "minari[all]"
```
@@ -53,15 +57,15 @@ minari list remote
│ ... │ ... │ ... │ ... │ ... │
```

To use your own server with Minari, set the `MINARI_REMOTE` environment variable in the format
`remote-type://remote-path`. For example, to set up a GCP bucket named `my-datasets`, run the following command:

```bash
export MINARI_REMOTE=gcp://my-datasets
```

Currently, only GCP is supported, but we plan to support other cloud providers in the future.


```{eval-rst}
To download any of the remote datasets into the local storage use the download command:
```
@@ -95,12 +99,14 @@ In order to use any of the dataset sampling features of Minari we first need to

```python
import minari

dataset = minari.load_dataset('D4RL/door/human-v2')
print("Observation space:", dataset.observation_space)
print("Action space:", dataset.action_space)
print("Total episodes:", dataset.total_episodes)
print("Total steps:", dataset.total_steps)
```

```
Observation space: Box(-inf, inf, (39,), float64)
Action space: Box(-1.0, 1.0, (28,), float32)
@@ -180,7 +186,6 @@ for episode in dataset:
print(f"EPISODE ID {episode.id}")
```


#### Filter Episodes

```{eval-rst}
@@ -248,7 +253,6 @@ for _ in range(100):
env.reset()
```


```{eval-rst}

.. note::
@@ -275,6 +279,7 @@ minari download D4RL/door/expert-v2
minari combine D4RL/door/human-v2 D4RL/door/expert-v2 --dataset-id=D4RL/door/all-v0
minari list local
```

```
Local Minari datasets('/Users/farama/.minari/datasets/')
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓
@@ -308,6 +313,17 @@ env = DataCollector(env, record_infos=True)
In this example, the :class:`minari.DataCollector` wraps the `'CartPole-v1'` environment from Gymnasium. We set ``record_infos=True`` so the wrapper will also collect the returned ``info`` dictionaries to create the dataset. For the full list of arguments, read the :class:`minari.DataCollector` documentation.
```

Infos can be saved either as a dictionary of `np.array`s or as a list of arbitrary dictionaries by setting the
`infos_format` parameter. The default is the dictionary format (`infos_format=None` or `infos_format="dict"`):

```python
from minari import DataCollector
import gymnasium as gym

env = gym.make('CartPole-v1')
env = DataCollector(env, record_infos=True, infos_format="list")
```
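To make the difference between the two formats concrete, here is a minimal, self-contained sketch (plain Python, no Minari involved; `collect_infos` is a hypothetical helper, not part of the library) of how per-step `info` dicts end up shaped:

```python
def collect_infos(step_infos, infos_format="dict"):
    """Aggregate per-step info dicts in either "dict" or "list" format."""
    if infos_format == "list":
        # Heterogeneous dicts are fine: each step keeps its own keys.
        return list(step_infos)
    # "dict" format: one list per key; every step must share the same keys.
    aggregated = {}
    for info in step_infos:
        for key, value in info.items():
            aggregated.setdefault(key, []).append(value)
    return aggregated

steps = [{"success": False}, {"success": True}]
print(collect_infos(steps))                       # {'success': [False, True]}
print(collect_infos(steps, infos_format="list"))  # [{'success': False}, {'success': True}]
```

The list format is what allows steps to return dictionaries with different keys, which the dict-of-arrays layout cannot represent.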

### Save Dataset

```{eval-rst}
@@ -358,6 +374,7 @@ Once the dataset has been created we can check if the Minari dataset id appears
```bash
minari list local
```

```
Local Minari datasets('/Users/farama/.minari/datasets/')
┏━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━┓
@@ -420,7 +437,6 @@ for episode_id in range(total_episodes):
env.add_to_dataset(dataset)
```


## Using Namespaces

```{eval-rst}
18 changes: 9 additions & 9 deletions docs/content/dataset_standards.md
@@ -92,15 +92,15 @@ sampled_episodes = dataset.sample_episodes(10)

The `sampled_episodes` variable will be a list of 10 `EpisodeData` elements, each containing episode data. An `EpisodeData` element is a data class consisting of the following fields:

| Field | Type | Description |
| ----------------- |---------------------------------------|------------------------------------------------------------------------------------------------------------------------------------|
| `id` | `int` | ID of the episode. |
| `observations` | `np.ndarray`, `list`, `tuple`, `dict` | Stacked observations for each step including initial observation. |
| `actions` | `np.ndarray`, `list`, `tuple`, `dict` | Stacked actions for each step. |
| `rewards` | `np.ndarray` | Rewards for each step. |
| `terminations` | `np.ndarray` | Terminations for each step. |
| `truncations` | `np.ndarray` | Truncations for each step. |
| `infos`           | `dict`, `list`                        | A dictionary of iterables (e.g. list, array), or a list of dictionaries, with additional information returned by the environment. |

As mentioned in the `Supported Spaces` section, many different observation and action spaces are supported, so the data types of these fields depend on the environment being used.

16 changes: 14 additions & 2 deletions minari/data_collector/data_collector.py
@@ -71,6 +71,7 @@ def __init__(
observation_space: Optional[gym.Space] = None,
action_space: Optional[gym.Space] = None,
data_format: Optional[str] = None,
infos_format: Optional[str] = None,
):
"""Initialize the data collector attributes and create the temporary directory for caching.

@@ -82,10 +83,12 @@
observation_space (gym.Space): Observation space of the dataset. The default value is the environment observation space.
action_space (gym.Space): Action space of the dataset. The default value is the environment action space.
data_format (str, optional): Data format to store the data in the Minari dataset. If None (defaults), it will use the default format of MinariStorage.
infos_format (str, optional): Format of the infos data, either "dict" or "list". If None (default), the "dict" format is used.
"""
super().__init__(env)
self._step_data_callback = step_data_callback()
self._episode_metadata_callback = episode_metadata_callback()
self.infos_format = infos_format or "dict"

self.datasets_path = os.environ.get("MINARI_DATASETS_PATH")
if self.datasets_path is None:
@@ -160,7 +163,11 @@ def step(
self._buffer = EpisodeBuffer(
id=self._episode_id,
observations=step_data["observation"],
infos=step_data["info"],
infos=(
step_data["info"]
if self.infos_format == "dict"
else [step_data["info"]]
),
)

return obs, rew, terminated, truncated, info
@@ -201,12 +208,17 @@ def reset(
f"Observation: {step_data['observation']}\nSpace: {self._storage.observation_space}"
)

infos = step_data["info"] if self._record_infos else None
self._buffer = EpisodeBuffer(
id=self._episode_id,
seed=seed,
options=options,
observations=step_data["observation"],
infos=step_data["info"] if self._record_infos else None,
infos=(
infos
if self.infos_format == "dict"
else [infos] if infos is not None else None
),
)
return obs, info

20 changes: 16 additions & 4 deletions minari/data_collector/episode_buffer.py
@@ -18,13 +18,14 @@ class EpisodeBuffer:
rewards: list = field(default_factory=list)
terminations: list = field(default_factory=list)
truncations: list = field(default_factory=list)
infos: Optional[dict] = None
infos: Optional[Union[dict, list]] = None

def add_step_data(self, step_data: StepData) -> EpisodeBuffer:
def add_step_data(self, step_data: StepData, infos_format: Optional[str] = None) -> EpisodeBuffer:
"""Add step data dictionary to episode buffer.

Args:
step_data (StepData): dictionary with data for a single step
infos_format (str, optional): format of the infos data, either "dict" or "list"; defaults to "dict"

Returns:
EpisodeBuffer: episode buffer with appended data
@@ -54,10 +55,21 @@ def _append(data, buffer):
else:
actions = jtu.tree_map(_append, step_data["action"], self.actions)

infos_format = infos_format or "dict"
if self.infos is None:
infos = jtu.tree_map(lambda x: [x], step_data["info"])
infos = (
jtu.tree_map(lambda x: [x], step_data["info"])
if infos_format == "dict"
else [step_data["info"]]
)
else:
infos = jtu.tree_map(_append, step_data["info"], self.infos)
if isinstance(self.infos, dict):
infos = jtu.tree_map(_append, step_data["info"], self.infos)
elif isinstance(self.infos, list):
self.infos.append(step_data["info"])
infos = self.infos
else:
raise ValueError(f"Unexpected type for infos: {type(self.infos)}")

self.rewards.append(step_data["reward"])
self.terminations.append(step_data["termination"])
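The dual-format branch in `add_step_data` above can be illustrated with a flat, self-contained sketch (`append_info` is a hypothetical helper, not Minari code; the real implementation handles nested dicts via `jtu.tree_map`):

```python
def append_info(buffer_infos, step_info, infos_format="dict"):
    """Append one step's info to the buffered infos in the given format."""
    if buffer_infos is None:
        # First step: seed the buffer in the requested format.
        return ({k: [v] for k, v in step_info.items()}
                if infos_format == "dict" else [step_info])
    if isinstance(buffer_infos, dict):
        # "dict" format: extend the per-key lists (keys must match each step).
        for key, value in step_info.items():
            buffer_infos[key].append(value)
    elif isinstance(buffer_infos, list):
        # "list" format: keep the whole dict, heterogeneous keys allowed.
        buffer_infos.append(step_info)
    else:
        raise ValueError(f"Unexpected type for infos: {type(buffer_infos)}")
    return buffer_infos

infos = append_info(None, {"score": 1})
infos = append_info(infos, {"score": 2})
print(infos)  # {'score': [1, 2]}
```

Note how the format only needs to be decided once, when the buffer is seeded; subsequent steps dispatch on the buffer's own type, mirroring the `isinstance` checks in the diff.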
49 changes: 37 additions & 12 deletions minari/dataset/_storages/arrow_storage.py
@@ -2,8 +2,9 @@

import json
import pathlib
from collections.abc import Iterable as ABCIterable
from itertools import zip_longest
from typing import Any, Dict, Iterable, Optional, Sequence
from typing import Any, Dict, Iterable, List, Optional, Sequence

import gymnasium as gym
import numpy as np
@@ -18,6 +19,11 @@
)

from minari.data_collector.episode_buffer import EpisodeBuffer
from minari.dataset._storages.serde import (
NumpyEncoder,
deserialize_dict,
serialize_dict,
)
from minari.dataset.minari_storage import MinariStorage


@@ -86,6 +92,17 @@ def get_episodes(self, episode_indices: Iterable[int]) -> Iterable[dict]:
)

def _to_dict(id, episode):
if "infos" in episode.column_names:
raw_infos = episode["infos"]
if isinstance(raw_infos, pa.lib.StringArray):
infos = decode_info_list(raw_infos)
elif isinstance(raw_infos, pa.lib.StructArray):
infos = _decode_info(episode["infos"])
else:
raise ValueError(f"Unexpected type for infos: {type(raw_infos)}")
else:
infos = None

return {
"id": id,
"observations": _decode_space(
@@ -95,11 +112,7 @@ def _to_dict(id, episode):
"rewards": np.asarray(episode["rewards"])[:-1],
"terminations": np.asarray(episode["terminations"])[:-1],
"truncations": np.asarray(episode["truncations"])[:-1],
"infos": (
_decode_info(episode["infos"])
if "infos" in episode.column_names
else {}
),
"infos": infos,
}

return map(_to_dict, episode_indices, dataset.to_batches())
@@ -119,6 +132,7 @@ def update_episodes(self, episodes: Iterable[EpisodeBuffer]):
terminations = np.asarray(episode_data.terminations).reshape(-1)
truncations = np.asarray(episode_data.truncations).reshape(-1)
pad = len(observations) - len(rewards)

actions = _encode_space(self._action_space, episode_data.actions, pad=pad)

episode_batch = {
@@ -130,7 +144,14 @@ def update_episodes(self, episodes: Iterable[EpisodeBuffer]):
"truncations": np.pad(truncations, ((0, pad))),
}
if episode_data.infos:
episode_batch["infos"] = _encode_info(episode_data.infos)
if isinstance(episode_data.infos, dict):
episode_batch["infos"] = _encode_info(episode_data.infos)
elif isinstance(episode_data.infos, list):
info_pad = len(observations) - len(episode_data.infos)
episode_batch["infos"] = encode_info_list(
episode_data.infos + [episode_data.infos[-1]] * info_pad
)

episode_batch = pa.RecordBatch.from_pydict(episode_batch)

total_steps += len(rewards)
@@ -251,6 +272,8 @@ def _encode_info(info: dict):

def _decode_info(values: pa.Array):
nested_dict = {}
if not isinstance(values.type, ABCIterable):
return nested_dict
for i, field in enumerate(values.type):
if isinstance(field, pa.StructArray):
nested_dict[field.name] = _decode_info(values.field(i))
@@ -264,8 +287,10 @@ def _decode_info(values: pa.Array):
return nested_dict


class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
return super().default(obj)
def encode_info_list(info_list: List[Dict[str, Any]]) -> pa.Array:
serialized_list = [serialize_dict(d) for d in info_list]
return pa.array(serialized_list, type=pa.string())


def decode_info_list(values: pa.Array) -> List[Dict[str, Any]]:
return [deserialize_dict(item.as_py()) for item in values]
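Assuming `serialize_dict`/`deserialize_dict` are JSON-based (an assumption; see `minari/dataset/_storages/serde.py` for the actual implementation), the round trip behaves like this stdlib-only sketch, with plain Python lists standing in for the `pa.StringArray` column:

```python
import json

# Each per-step info dict is serialized to one JSON string; the column of
# strings can then be decoded back into the original list of dicts.
def encode_info_list(info_list):
    return [json.dumps(info) for info in info_list]

def decode_info_list(encoded):
    return [json.loads(item) for item in encoded]

infos = [{"success": False}, {"success": True, "goal": [0.1, 0.2]}]
assert decode_info_list(encode_info_list(infos)) == infos
```

Storing one string per step is what makes heterogeneous dicts possible here: unlike a `StructArray`, a string column imposes no shared schema across rows.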