Add custom space serialization tutorial #151
Conversation
It looks good, thank you a lot!
I left some minor comments and a major one about the choice of the space.
Also, we may want to link the future page on the documentation in this error:
Minari/minari/serialization.py
Line 14 in c43a612
raise NotImplementedError(f"No serialization method available for {space}")
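One way the suggested link could be surfaced is sketched below. This is only an illustration: `serialize_space` here is a stand-in built on `functools.singledispatch`, not Minari's module, `UnregisteredSpace` is a hypothetical type, and `DOCS_URL` is a placeholder because the tutorial page is not published yet.

```python
from functools import singledispatch

# Placeholder only: the tutorial page does not exist yet, so no real URL is assumed.
DOCS_URL = "<link to the custom space serialization tutorial>"


@singledispatch
def serialize_space(space):
    """Fallback reached when no serializer is registered for type(space)."""
    raise NotImplementedError(
        f"No serialization method available for {space}. "
        f"See {DOCS_URL} for how to register one for a custom space."
    )


class UnregisteredSpace:
    """Hypothetical space type with no registered serializer."""

    def __repr__(self):
        return "UnregisteredSpace()"
```

With this wording, a user who hits the fallback is pointed at the registration instructions instead of only learning that the space is unsupported.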
=========================================
"""
# %%%
# In this tutorial you'll learn how to serialize a custom Gym observation space and use that
Gym -> Gymnasium
print(f"Observation space: {env.observation_space}")
print(f"Action space: {env.action_space}")
Maybe it is good to add the outputs of this on the doc page?
)

# %% [markdown]
# Now that we have a custom observation space we need to define functions that properly serialize it.
functions that properly serialize it
-> functions that serialize and deserialize it
or a function that serializes it
#
# When creating a Minari dataset, the space data gets `serialized <https://minari.farama.org/content/dataset_standards/#space-serialization>`_
# to a JSON format when saving to disk. The `serialize_space <https://github.com/Farama-Foundation/Minari/blob/main/minari/serialization.py#L13C5-L13C20>`_
# function takes care of this conversion for various supported Gym spaces. To enable serialization for a custom space we can register 2 new functions that
Gym -> Gymnasium
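The "register 2 new functions" idea in the quoted tutorial text can be sketched with the standard library alone. Everything below is an assumption made for illustration: `serialize_space`, `register_deserializer`, and `deserialize_space` are local stand-ins rather than imports from Minari, and `ChoiceTextSpace` is a hypothetical custom space.

```python
import json
from functools import singledispatch


# Stand-in serializer hook: dispatches on the type of the space.
@singledispatch
def serialize_space(space, to_string=True):
    raise NotImplementedError(f"No serialization method available for {space}")


# Stand-in deserializer hook: dispatches on the "type" field of the payload.
_DESERIALIZERS = {}


def register_deserializer(type_name):
    def decorator(func):
        _DESERIALIZERS[type_name] = func
        return func
    return decorator


def deserialize_space(payload):
    space_dict = json.loads(payload) if isinstance(payload, str) else payload
    return _DESERIALIZERS[space_dict["type"]](space_dict)


class ChoiceTextSpace:
    """Hypothetical custom space: a string drawn from a fixed set of options."""

    def __init__(self, options):
        self.options = tuple(options)

    def __eq__(self, other):
        return isinstance(other, ChoiceTextSpace) and self.options == other.options


@serialize_space.register(ChoiceTextSpace)
def _serialize_choice_text(space, to_string=True):
    result = {"type": "ChoiceTextSpace", "options": list(space.options)}
    return json.dumps(result) if to_string else result


@register_deserializer("ChoiceTextSpace")
def _deserialize_choice_text(space_dict):
    return ChoiceTextSpace(space_dict["options"])


# Round trip: the restored space should compare equal to the original.
space = ChoiceTextSpace(["go left", "go right"])
restored = deserialize_space(serialize_space(space))
```

The two functions are registered as a pair: the serializer writes a `"type"` tag into the JSON, and the deserializer is looked up by that same tag when the data is read back.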
# delete the test dataset if it already exists
local_datasets = minari.list_local_datasets()
if dataset_id in local_datasets:
    minari.delete_dataset(dataset_id)
I would place it at the end of the tutorial (just remove the dataset without if, with a comment saying you remove it as you created it only for the purpose of the tutorial) to avoid interrupting the flow with something not really related
custom_observation_space = CartPoleObservationSpace(
    low=np.array([0, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38], dtype=np.float32),
    high=np.array([4.8, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38], dtype=np.float32),
    shape=(4,),
    dtype=np.float32
)


class CustomSpaceStepDataCallback(StepDataCallback):
    def __call__(self, env, **kwargs):
        step_data = super().__call__(env, **kwargs)
        step_data["observations"][0] = max(step_data["observations"][0], 0)
        return step_data
I don't know if doing it with a StepDataCallback is the best way for the user to understand (it also introduces the StepDataCallback that may not be known by the user). What about creating a very simple environment (subclassing CartPole), defining step, and reset functions?
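The subclassing alternative could look roughly like the sketch below. To keep it runnable without dependencies, `BaseCartEnv` is a hypothetical stand-in for a CartPole-like environment (it is not Gymnasium's class); only the shape of the `reset` and `step` return values follows the Gymnasium API.

```python
import random


class BaseCartEnv:
    """Hypothetical stand-in for a CartPole-like env with 4-float observations."""

    def __init__(self, seed=0):
        self._rng = random.Random(seed)

    def _observe(self):
        return [self._rng.uniform(-1.0, 1.0) for _ in range(4)]

    def reset(self):
        # Gymnasium-style reset: (observation, info)
        return self._observe(), {}

    def step(self, action):
        # Gymnasium-style step: (observation, reward, terminated, truncated, info)
        return self._observe(), 1.0, False, False, {}


class NonNegativePositionEnv(BaseCartEnv):
    """Clamp the cart position (index 0) at zero in both reset and step."""

    @staticmethod
    def _clip(obs):
        obs = list(obs)
        obs[0] = max(obs[0], 0.0)
        return obs

    def reset(self):
        obs, info = super().reset()
        return self._clip(obs), info

    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)
        return self._clip(obs), reward, terminated, truncated, info
```

Compared with the callback approach, the modified observation is produced by the environment itself, so the reader does not need to know about `StepDataCallback` to follow the tutorial.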
class CartPoleObservationSpace(Space):
    def __init__(
        self,
        low: NDArray[Any],
        high: NDArray[Any],
        shape: Sequence[int],
        dtype: type[np.floating[Any]] = np.float32,
    ):
        self.low = np.full(shape, low, dtype=dtype)
        self.high = np.full(shape, high, dtype=dtype)
        super().__init__(shape, dtype)

    def sample(self) -> NDArray[Any]:
        """Sample a random observation according to low/high boundaries"""
        sample = np.empty(self.shape)
        sample = self.np_random.uniform(low=self.low, high=self.high, size=self.shape)
        return sample.astype(self.dtype)

    def contains(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space"""
        if not isinstance(x, np.ndarray):
            try:
                x = np.asarray(x, dtype=self.dtype)
            except (ValueError, TypeError):
                return False

        return bool(
            np.can_cast(x.dtype, self.dtype)
            and x.shape == self.shape
            and np.all(x >= self.low)
            and np.all(x <= self.high)
        )
This is basically the Box space; it can be confusing for the user. I am wondering if we can have a custom space that makes more sense and is still simple.
For example, minigrid has the mission space; it is basically a TextSpace where the text is sampled from a set. Do you have other ideas? I will think a bit more about it and let you know if I have something better.
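A space "where the text is sampled from a set" can be illustrated with a few lines of plain Python. This is a deliberately simplified, hypothetical `ChoiceTextSpace`; it does not reproduce minigrid's mission space and does not subclass Gymnasium's `Space`.

```python
import random


class ChoiceTextSpace:
    """Hypothetical simplified space: valid values are strings from a fixed set."""

    def __init__(self, options, seed=None):
        if not options:
            raise ValueError("options must not be empty")
        self.options = tuple(options)
        self._rng = random.Random(seed)

    def sample(self):
        """Draw one of the allowed strings at random."""
        return self._rng.choice(self.options)

    def contains(self, x):
        """Return True only for strings that belong to the fixed set."""
        return isinstance(x, str) and x in self.options


space = ChoiceTextSpace(["pick up the key", "open the door"], seed=0)
mission = space.sample()
```

Unlike the Box-like example, this kind of space has no numeric bounds to store, so its serialized form only needs the list of allowed strings.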
Yeah, I thought about making a custom action space instead since that's a little simpler to implement (ex: maybe an action space that only allows moving left), but thought that might not get across the idea of a good example for when someone should create a custom space. I'll take a look at the minigrid environment and see if there's a modification that makes sense and is simple to implement.
To clarify: what I meant is that the MissionSpace of minigrid is not serializable, so this can be a use case for the tutorial
Just two minor changes, then it is ready to be merged. Thanks!
# supported and if we try to serialize it with:

# %%
serialize_space(env.observation_space['mission'])
It is better not to include code that throws an error, as we aim to add automatic testing for the docs (so the file should run). You can catch the error and print it, or just use text to explain it.
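A minimal sketch of the "catch the error and print it" option, so the tutorial file still runs end to end. `serialize_space` is again a local stand-in and `MissionLikeSpace` is a hypothetical unsupported space; in the tutorial the real function and `env.observation_space['mission']` would take their place.

```python
from functools import singledispatch


@singledispatch
def serialize_space(space):
    # Stand-in for the library function: unregistered types fall through to this error.
    raise NotImplementedError(f"No serialization method available for {space}")


class MissionLikeSpace:
    """Hypothetical unsupported space used only for this sketch."""

    def __repr__(self):
        return "MissionLikeSpace()"


def try_serialize(space):
    """Return (result, None) on success or (None, message) if the space is unsupported."""
    try:
        return serialize_space(space), None
    except NotImplementedError as err:
        return None, str(err)


result, message = try_serialize(MissionLikeSpace())
print(message)
```

The reader still sees the exact error text, but automatic doc tests are not interrupted by an uncaught exception.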
# %% [markdown]
# To get an idea of what the serialization is doing under the hood we can directly call
# the ``serialize_custom_space`` function we defined earlier and see the JSON string it returns.

# %%
serialize_custom_space(env.observation_space['mission'])
Instead of this block, show that you can actually reload the dataset:
del dataset
dataset = minari.load_dataset(dataset_id)
And the dataset.observation_space is correct
LGTM, thank you!
Description
This PR adds a new tutorial that goes over serializing custom Gym spaces and shows an example of how to apply a custom observation space to the CartPole environment.
Type of change
Screenshots
Screenshot of tutorial page:
Checklist:
pre-commit checks with pre-commit run --all-files (see CONTRIBUTING.md instructions to set it up)
pytest -v and no errors are present.
pytest -v has generated that are related to my code to the best of my knowledge.