Skip to content

Commit

Permalink
Merge pull request #70 from LucasAlegre/69-pql
Browse files Browse the repository at this point in the history
Implement a fix to #69
  • Loading branch information
ffelten authored Oct 12, 2023
2 parents b7f371f + d236724 commit d5bc675
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions morl_baselines/multi_policy/pareto_q_learning/pql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Pareto Q-Learning."""
import numbers
from typing import Callable, List, Optional

import gymnasium as gym
Expand Down Expand Up @@ -59,10 +60,28 @@ def __init__(
# Algorithm setup
self.ref_point = ref_point

self.num_actions = self.env.action_space.n
low_bound = self.env.observation_space.low
high_bound = self.env.observation_space.high
self.env_shape = (high_bound[0] - low_bound[0] + 1, high_bound[1] - low_bound[1] + 1)
if type(self.env.action_space) == gym.spaces.Discrete:
self.num_actions = self.env.action_space.n
elif type(self.env.action_space) == gym.spaces.MultiDiscrete:
self.num_actions = np.prod(self.env.action_space.nvec)
else:
raise Exception("PQL only supports (multi)discrete action spaces.")

if type(self.env.observation_space) == gym.spaces.Discrete:
self.env_shape = (self.env.observation_space.n,)
elif type(self.env.observation_space) == gym.spaces.MultiDiscrete:
self.env_shape = self.env.observation_space.nvec
elif (
type(self.env.observation_space) == gym.spaces.Box
and self.env.observation_space.is_bounded(manner="both")
and issubclass(self.env.observation_space.dtype.type, numbers.Integral)
):
low_bound = np.array(self.env.observation_space.low)
high_bound = np.array(self.env.observation_space.high)
self.env_shape = high_bound - low_bound + 1
else:
raise Exception("PQL only supports discretizable observation spaces.")

self.num_states = np.prod(self.env_shape)
self.num_objectives = self.env.reward_space.shape[0]
self.counts = np.zeros((self.num_states, self.num_actions))
Expand Down

0 comments on commit d5bc675

Please sign in to comment.