Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migration to gymnasium 1.0 #109

Merged
merged 13 commits into from
Oct 16, 2024
36 changes: 10 additions & 26 deletions .github/workflows/build-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# - https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/
#
# derived from https://github.com/Farama-Foundation/PettingZoo/blob/e230f4d80a5df3baf9bd905149f6d4e8ce22be31/.github/workflows/build-publish.yml
name: build-publish
name: Build artifact for PyPI

on:
push:
Expand All @@ -16,35 +16,18 @@ on:

jobs:
build-wheels:
runs-on: ${{ matrix.os }}
strategy:
matrix:
include:
- os: ubuntu-latest
python: 38
platform: manylinux_x86_64
- os: ubuntu-latest
python: 39
platform: manylinux_x86_64
- os: ubuntu-latest
python: 310
platform: manylinux_x86_64
- os: ubuntu-latest
python: 311
platform: manylinux_x86_64
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.x'
- uses: actions/checkout@v4
- uses: actions/setup-python@v5

- name: Install dependencies
run: python -m pip install --upgrade pip setuptools build
run: pipx install build
- name: Build sdist and wheels
run: python -m build
run: pyproject-build
- name: Store wheels
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v4
with:
path: dist

Expand All @@ -55,10 +38,11 @@ jobs:
if: github.event_name == 'release' && github.event.action == 'published'
steps:
- name: Download dists
uses: actions/download-artifact@v4.1.7
uses: actions/download-artifact@v4
with:
name: artifact
path: dist

- name: Publish
uses: pypa/gh-action-pypi-publish@release/v1
with:
Expand Down
10 changes: 4 additions & 6 deletions .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- run: python -m pip install pre-commit
- run: python -m pre_commit --version
- run: python -m pre_commit install
- run: python -m pre_commit run --all-files
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
- run: pipx install pre-commit
- run: pre-commit run --all-files
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v5.0.0
hooks:
- id: check-symlinks
- id: destroyed-symlinks
Expand All @@ -18,13 +18,13 @@ repos:
- id: detect-private-key
- id: debug-statements
- repo: https://github.com/codespell-project/codespell
rev: v2.2.4
rev: v2.3.0
hooks:
- id: codespell
args:
- --ignore-words-list=reacher,ure,referenc,wile,mor,ser,esr,nowe
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 7.1.1
hooks:
- id: flake8
args:
Expand All @@ -35,16 +35,16 @@ repos:
- --show-source
- --statistics
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
rev: v3.18.0
hooks:
- id: pyupgrade
args: ["--py37-plus"]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/python/black
rev: 23.1.0
rev: 24.10.0
hooks:
- id: black
- repo: https://github.com/pycqa/pydocstyle
Expand Down
2 changes: 1 addition & 1 deletion examples/envelope_minecart.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import mo_gymnasium as mo_gym
import numpy as np
from mo_gymnasium.utils import MORecordEpisodeStatistics
from mo_gymnasium.wrappers import MORecordEpisodeStatistics

from morl_baselines.multi_policy.envelope.envelope import Envelope

Expand Down
2 changes: 1 addition & 1 deletion examples/eupg_fishwood.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import mo_gymnasium as mo_gym
import numpy as np
import torch as th
from mo_gymnasium.utils import MORecordEpisodeStatistics
from mo_gymnasium.wrappers import MORecordEpisodeStatistics

from morl_baselines.common.evaluation import eval_mo_reward_conditioned
from morl_baselines.single_policy.esr.eupg import EUPG
Expand Down
2 changes: 1 addition & 1 deletion examples/mo_q_learning_DST.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import mo_gymnasium as mo_gym
import numpy as np
from mo_gymnasium.utils import MORecordEpisodeStatistics
from mo_gymnasium.wrappers import MORecordEpisodeStatistics

from morl_baselines.common.evaluation import eval_mo
from morl_baselines.common.scalarization import tchebicheff
Expand Down
2 changes: 1 addition & 1 deletion examples/mp_mo_q_learning_DST.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import mo_gymnasium as mo_gym
import numpy as np
from mo_gymnasium import MORecordEpisodeStatistics
from mo_gymnasium.wrappers import MORecordEpisodeStatistics

from morl_baselines.common.scalarization import tchebicheff
from morl_baselines.multi_policy.multi_policy_moqlearning.mp_mo_q_learning import (
Expand Down
2 changes: 1 addition & 1 deletion examples/pcn_minecart.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import mo_gymnasium as mo_gym
import numpy as np
from mo_gymnasium.utils import MORecordEpisodeStatistics
from mo_gymnasium.wrappers import MORecordEpisodeStatistics

from morl_baselines.multi_policy.pcn.pcn import PCN

Expand Down
2 changes: 1 addition & 1 deletion examples/pgmorl_halfcheetah.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
algo.train(
total_timesteps=int(5e6),
eval_env=make_env(env_id, 42, 0, "PGMORL_eval_env", gamma=0.995)(),
ref_point=np.array([0.0, -5.0]),
ref_point=np.array([-100.0, -100.0]),
known_pareto_front=None,
)
env = make_env(env_id, 422, 1, "PGMORL_test", gamma=0.995)() # idx != 0 to avoid taking videos
Expand Down
13 changes: 7 additions & 6 deletions experiments/benchmark/launch_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
import numpy as np
import requests
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gymnasium.wrappers import FlattenObservation
from gymnasium.wrappers.record_video import RecordVideo
from mo_gymnasium.utils import MORecordEpisodeStatistics
from gymnasium.wrappers import FlattenObservation, RecordVideo
from mo_gymnasium.wrappers import MORecordEpisodeStatistics

from morl_baselines.common.evaluation import seed_everything
from morl_baselines.common.experiments import (
Expand Down Expand Up @@ -90,13 +89,15 @@ def autotag() -> str:
git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("ascii").strip()
try:
# try finding the pull request number on github
prs = requests.get(f"https://api.github.com/search/issues?q=repo:LucasAlegre/morl-baselines+is:pr+{git_commit}")
prs = requests.get(
f"https://api.github.com/search/issues?q=repo:LucasAlegre/morl-baselines+is:pr+{git_commit}" # noqa
)
if prs.status_code == 200:
prs = prs.json()
if len(prs["items"]) > 0:
pr = prs["items"][0]
pr_number = pr["number"]
wandb_tag += f",pr-{pr_number}"
wandb_tag += f",pr-{pr_number}" # noqa
print(f"identified github pull request: {pr_number}")
except Exception as e:
print(e)
Expand Down Expand Up @@ -165,7 +166,7 @@ def wrap_mario(env):
TimeLimit,
)
from mo_gymnasium.envs.mario.joypad_space import JoypadSpace
from mo_gymnasium.utils import MOMaxAndSkipObservation
from mo_gymnasium.wrappers import MOMaxAndSkipObservation

env = JoypadSpace(env, SIMPLE_MOVEMENT)
env = MOMaxAndSkipObservation(env, skip=4)
Expand Down
2 changes: 1 addition & 1 deletion experiments/hyperparameter_search/launch_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import wandb
import yaml
from mo_gymnasium.utils import MORecordEpisodeStatistics
from mo_gymnasium.wrappers import MORecordEpisodeStatistics

from morl_baselines.common.evaluation import seed_everything
from morl_baselines.common.experiments import (
Expand Down
3 changes: 1 addition & 2 deletions morl_baselines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
"""MORL-Baselines contains various MORL algorithms and utility functions."""


__version__ = "1.0.0"
__version__ = "1.1.0"
1 change: 1 addition & 0 deletions morl_baselines/common/buffer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Replay buffer for multi-objective reinforcement learning."""

import numpy as np
import torch as th

Expand Down
8 changes: 6 additions & 2 deletions morl_baselines/common/diverse_buffer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Diverse Experience Replay Buffer. Code extracted from https://github.com/axelabels/DynMORL."""

from dataclasses import dataclass

import numpy as np
Expand Down Expand Up @@ -154,7 +155,7 @@ def update(self, idx: int, p, tree_id=None):
Keyword Arguments:
tree_id {object} -- Tree to be updated (default: {None})
"""
if type(p) == dict:
if isinstance(p, dict):
for k in p:
self.update(idx, p[k], k)
return
Expand Down Expand Up @@ -476,7 +477,10 @@ def get_data(self, include_indices: bool = False):
Returns:
The data
"""
all_data = list(np.arange(self.capacity) + self.capacity - 1), list(self.tree.data)
all_data = (
list(np.arange(self.capacity) + self.capacity - 1),
list(self.tree.data),
)
indices = []
data = []
for i, d in zip(all_data[0], all_data[1]):
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/common/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities related to evaluation."""

import os
import random
from typing import List, Optional, Tuple
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/common/experiments.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Common experiment utilities."""

import argparse

from morl_baselines.multi_policy.capql.capql import CAPQL
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Probabilistic ensemble of neural networks."""

import os

import numpy as np
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/common/model_based/tabular_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tabular dynamics model S_{t+1}, R_t ~ m(.,.|s,a) ."""

import random

import numpy as np
Expand Down
53 changes: 44 additions & 9 deletions morl_baselines/common/model_based/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utility functions for the model."""

from typing import Tuple

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -34,7 +35,7 @@ def termination_fn_dst(obs, act, next_obs):


def termination_fn_mountaincar(obs, act, next_obs):
"""Termination function of mountin car."""
"""Termination function of mountain car."""
assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2
position = next_obs[:, 0]
velocity = next_obs[:, 1]
Expand Down Expand Up @@ -147,16 +148,29 @@ def step(
var_obs = var_obs[0]
var_rewards = var_rewards[0]

info = {"uncertainty": uncertainties, "var_obs": var_obs, "var_rewards": var_rewards}
info = {
"uncertainty": uncertainties,
"var_obs": var_obs,
"var_rewards": var_rewards,
}

# info = {'mean': return_means, 'std': return_stds, 'log_prob': log_prob, 'dev': dev}
return next_obs, rewards, terminals, info


def visualize_eval(
agent, env, model=None, w=None, horizon=10, init_obs=None, compound=True, deterministic=False, show=False, filename=None
agent,
env,
model=None,
w=None,
horizon=10,
init_obs=None,
compound=True,
deterministic=False,
show=False,
filename=None,
):
"""Generates a plot of the evolution of the state, reward and model predicitions ove time.
"""Generates a plot of the evolution of the state, reward and model predictions over time.

Args:
agent: agent to be evaluated
Expand Down Expand Up @@ -213,10 +227,16 @@ def visualize_eval(
acts = F.one_hot(acts, num_classes=env.action_space.n).squeeze(1)
for step in range(len(real_obs)):
if compound or step == 0:
obs, r, done, info = model_env.step(th.tensor(obs).to(agent.device), acts[step], deterministic=deterministic)
obs, r, done, info = model_env.step(
th.tensor(obs).to(agent.device),
acts[step],
deterministic=deterministic,
)
else:
obs, r, done, info = model_env.step(
th.tensor(real_obs[step - 1]).to(agent.device), acts[step], deterministic=deterministic
th.tensor(real_obs[step - 1]).to(agent.device),
acts[step],
deterministic=deterministic,
)
model_obs.append(obs.copy())
model_obs_stds.append(np.sqrt(info["var_obs"].copy()))
Expand All @@ -240,11 +260,26 @@ def visualize_eval(
axs[i].set_ylabel(f"Reward {i - obs_dim}")
axs[i].grid(alpha=0.25)
if w is not None:
axs[i].plot(x, [real_vec_rewards[step][i - obs_dim] for step in x], label="Environment", color="black")
axs[i].plot(
x,
[real_vec_rewards[step][i - obs_dim] for step in x],
label="Environment",
color="black",
)
else:
axs[i].plot(x, [real_rewards[step] for step in x], label="Environment", color="black")
axs[i].plot(
x,
[real_rewards[step] for step in x],
label="Environment",
color="black",
)
if model is not None:
axs[i].plot(x, [model_rewards[step][i - obs_dim] for step in x], label="Model", color="blue")
axs[i].plot(
x,
[model_rewards[step][i - obs_dim] for step in x],
label="Model",
color="blue",
)
axs[i].fill_between(
x,
[model_rewards[step][i - obs_dim] + model_rewards_stds[step][i - obs_dim] for step in x],
Expand Down
3 changes: 2 additions & 1 deletion morl_baselines/common/morl_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""MORL algorithm base classes."""

import os
import time
from abc import ABC, abstractmethod
Expand All @@ -11,7 +12,7 @@
import torch.nn
import wandb
from gymnasium import spaces
from mo_gymnasium.utils import MOSyncVectorEnv
from mo_gymnasium.wrappers.vector import MOSyncVectorEnv

from morl_baselines.common.evaluation import (
eval_mo_reward_conditioned,
Expand Down
Loading
Loading