diff --git a/pyproject.toml b/pyproject.toml
index b07f464..957b7a0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,4 +33,5 @@ dependencies = [
 "Bug Tracker" = "https://github.com/Stanford-ILIAD/PantheonRL/issues"
 
 [tool.pylint]
-disable = ["protected-access", "too-many-arguments", "too-many-instance-attributes", "too-many-statements", "too-many-branches"]
\ No newline at end of file
+disable = ["protected-access", "too-many-arguments", "too-many-instance-attributes", "too-many-statements", "too-many-branches", "too-many-locals", "duplicate-code"]
+generated-members = ["numpy.*", "torch.*"]
\ No newline at end of file
diff --git a/src/pantheonrl/__init__.py b/src/pantheonrl/__init__.py
index 7dba3ae..6a8d6c7 100644
--- a/src/pantheonrl/__init__.py
+++ b/src/pantheonrl/__init__.py
@@ -1,7 +1,18 @@
 """
-`PantheonRL `_ is a package for training and testing multi-agent reinforcement learning environments. The goal of PantheonRL is to provide a modular and extensible framework for training agent policies, fine-tuning agent policies, ad-hoc pairing of agents, and more.
+`PantheonRL `_ is a
+package for training and testing multi-agent reinforcement learning
+environments. The goal of PantheonRL is to provide a modular and
+extensible framework for training agent policies, fine-tuning agent
+policies, ad-hoc pairing of agents, and more.
 
-PantheonRL is built to support Stable-Baselines3 (SB3), allowing direct access to many of SB3's standard RL training algorithms such as PPO. PantheonRL currently follows a decentralized training paradigm -- each agent is equipped with its own replay buffer and update algorithm. The agents objects are designed to be easily manipulable. They can be saved, loaded and plugged into different training procedures such as self-play, ad-hoc / cross-play, round-robin training, or finetuning.
+PantheonRL is built to support Stable-Baselines3 (SB3), allowing
+direct access to many of SB3's standard RL training algorithms such as
+PPO. PantheonRL currently follows a decentralized training paradigm --
+each agent is equipped with its own replay buffer and update
+algorithm. The agent objects are designed to be easily
+manipulable. They can be saved, loaded and plugged into different
+training procedures such as self-play, ad-hoc / cross-play,
+round-robin training, or finetuning.
""" import pantheonrl.envs @@ -9,14 +20,14 @@ Agent, StaticPolicyAgent, OnPolicyAgent, - OffPolicyAgent + OffPolicyAgent, ) from pantheonrl.common.multiagentenv import ( DummyEnv, MultiAgentEnv, TurnBasedEnv, - SimultaneousEnv + SimultaneousEnv, ) from pantheonrl.common.observation import Observation diff --git a/src/pantheonrl/algos/adap/adap_learn.py b/src/pantheonrl/algos/adap/adap_learn.py index 75f2830..97cca80 100644 --- a/src/pantheonrl/algos/adap/adap_learn.py +++ b/src/pantheonrl/algos/adap/adap_learn.py @@ -1,15 +1,20 @@ +""" +Modified implementation of PPO to support ADAP +""" import warnings -from typing import Any, Dict, Optional, Type, Union, Tuple +from typing import Any, Dict, Optional, Type, Union import numpy as np -import torch as th -import gymnasium as gym +import torch from gymnasium import spaces from torch.nn import functional as F from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.common.type_aliases import (GymEnv, MaybeCallback, - Schedule) +from stable_baselines3.common.type_aliases import ( + GymEnv, + MaybeCallback, + Schedule, +) from stable_baselines3.common.utils import explained_variance, get_schedule_fn from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.common.callbacks import BaseCallback @@ -95,6 +100,7 @@ def __init__( gae_lambda: float = 0.95, clip_range: Union[float, Schedule] = 0.2, clip_range_vf: Union[None, float, Schedule] = None, + normalize_advantage: bool = True, ent_coef: float = 0.0, vf_coef: float = 0.5, max_grad_norm: float = 0.5, @@ -102,22 +108,21 @@ def __init__( sde_sample_freq: int = -1, target_kl: Optional[float] = None, tensorboard_log: Optional[str] = None, - create_eval_env: bool = False, policy_kwargs: Optional[Dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, - device: Union[th.device, str] = "auto", + device: Union[torch.device, str] = "auto", _init_setup_model: bool = True, context_loss_coeff: float = 0.1, context_size: int = 3, num_context_samples: int = 5, context_sampler: str = "l2", - num_state_samples: int = 32 + num_state_samples: int = 32, ): if policy_kwargs is None: policy_kwargs = {} - policy_kwargs['context_size'] = context_size - super(ADAP, self).__init__( + policy_kwargs["context_size"] = context_size + super().__init__( policy, env, learning_rate=learning_rate, @@ -133,7 +138,6 @@ def __init__( policy_kwargs=policy_kwargs, verbose=verbose, device=device, - create_eval_env=create_eval_env, seed=seed, _init_setup_model=False, supported_action_spaces=( @@ -146,9 +150,10 @@ def __init__( # Sanity check, otherwise it will lead to noisy gradient and NaN # because of the advantage normalization - assert ( - batch_size > 1 - ), "`batch_size` must be greater than 1. \ + if normalize_advantage: + assert ( + batch_size > 1 + ), "`batch_size` must be greater than 1. \ See https://github.com/DLR-RM/stable-baselines3/issues/440" if self.env is not None: @@ -156,34 +161,25 @@ def __init__( # when doing advantage normalization if self.env.action_space == spaces.Box: - self.action_dist = 'gaussian' + self.action_dist = "gaussian" else: - self.action_dist = 'categorical' + self.action_dist = "categorical" buffer_size = self.env.num_envs * self.n_steps - assert ( - buffer_size > 1 - ), f"`n_steps * n_envs` must be greater than 1. 
Currently n_steps=\ - {self.n_steps} and n_envs={self.env.num_envs}" - # Check that rollout buffer size is a multiple of mini-batch size - untruncated_batches = buffer_size // batch_size + assert buffer_size > 1 or ( + not normalize_advantage + ), f"`n_steps * n_envs` must be greater than 1. Currently \ + n_steps={self.n_steps} and n_envs={self.env.num_envs}" if buffer_size % batch_size > 0: warnings.warn( f"You have specified a mini-batch size of {batch_size}," - f" but because the `RolloutBuffer` is of size \ + f" but the `RolloutBuffer` is of size \ `n_steps * n_envs = {buffer_size}`," - f" after every {untruncated_batches} untruncated \ - mini-batches," - f" there will be a truncated mini-batch of size \ - {buffer_size % batch_size}\n" - f"We recommend using a `batch_size` that is a factor of \ - `n_steps * n_envs`.\n" - f"Info: (n_steps={self.n_steps} and \ - n_envs={self.env.num_envs})" ) self.batch_size = batch_size self.n_epochs = n_epochs self.clip_range_raw = clip_range self.clip_range_vf_raw = clip_range_vf + self.normalize_advantage = normalize_advantage self.target_kl = target_kl self.context_loss_coeff = context_loss_coeff @@ -198,19 +194,21 @@ def __init__( self.full_obs_shape = None - def set_env(self, env): - super(ADAP, self).set_env(env) + def set_env(self, env, force_reset=True): + """Set the env to use""" + super().set_env(env, force_reset=force_reset) if self.env.action_space == spaces.Box: - self.action_dist = 'gaussian' + self.action_dist = "gaussian" else: - self.action_dist = 'categorical' + self.action_dist = "categorical" def _setup_model(self) -> None: - super(ADAP, self)._setup_model() + super()._setup_model() sampled_context = SAMPLERS[self.context_sampler]( - ctx_size=self.context_size, num=1, torch=True) + ctx_size=self.context_size, num=1, use_torch=True + ) self.policy.set_context(sampled_context) @@ -218,9 +216,10 @@ def _setup_model(self) -> None: self.clip_range = get_schedule_fn(self.clip_range_raw) if self.clip_range_vf_raw is not None: if isinstance(self.clip_range_vf_raw, (float, int)): - assert self.clip_range_vf_raw > 0, \ - "`clip_range_vf` must be positive, " \ + assert self.clip_range_vf_raw > 0, ( + "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping" + ) self.clip_range_vf = get_schedule_fn(self.clip_range_vf_raw) else: @@ -230,6 +229,8 @@ def train(self) -> None: """ Update policy using the currently gathered rollout buffer. 
""" + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) # Update optimizer learning rate self._update_learning_rate(self.policy.optimizer) # Compute current clip range @@ -237,14 +238,14 @@ def train(self) -> None: # Optional: clip range for the value function if self.clip_range_vf is not None: clip_range_vf = self.clip_range_vf( - self._current_progress_remaining) + self._current_progress_remaining + ) entropy_losses = [] pg_losses, value_losses = [], [] clip_fractions = [] continue_training = True - # train for n_epochs epochs for epoch in range(self.n_epochs): approx_kl_divs = [] @@ -257,44 +258,48 @@ def train(self) -> None: actions = rollout_data.actions.long().flatten() # Re-sample the noise matrix because the log_std has changed - # TODO: investigate why there is no issue with the gradient - # if that line is commented (as in SAC) if self.use_sde: self.policy.reset_noise(self.batch_size) values, log_prob, entropy = self.policy.evaluate_actions( - rollout_data.observations, actions) + rollout_data.observations, actions + ) values = values.flatten() # Normalize advantage advantages = rollout_data.advantages - advantages = (advantages - advantages.mean()) / \ - (advantages.std() + 1e-8) + # Normalization does not make sense if mini batchsize == 1 + if self.normalize_advantage and len(advantages) > 1: + advantages = (advantages - advantages.mean()) / ( + advantages.std() + 1e-8 + ) - # ratio between old and new policy - ratio = th.exp(log_prob - rollout_data.old_log_prob) + # ratio between old and new policy, should be one at the first + ratio = torch.exp(log_prob - rollout_data.old_log_prob) # clipped surrogate loss policy_loss_1 = advantages * ratio - policy_loss_2 = advantages * \ - th.clamp(ratio, 1 - clip_range, 1 + clip_range) - policy_loss = -th.min(policy_loss_1, policy_loss_2).mean() + policy_loss_2 = advantages * torch.clamp( + ratio, 1 - clip_range, 1 + clip_range + ) + policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean() # Logging pg_losses.append(policy_loss.item()) - clip_fraction = th.mean( - (th.abs(ratio - 1) > clip_range).float()).item() + clip_fraction = torch.mean( + (torch.abs(ratio - 1) > clip_range).float() + ).item() clip_fractions.append(clip_fraction) if self.clip_range_vf is None: # No clipping values_pred = values else: - # Clip the different between old and new value + # Clip the difference between old and new value # NOTE: this depends on the reward scaling - values_pred = rollout_data.old_values + th.clamp( + values_pred = rollout_data.old_values + torch.clamp( values - rollout_data.old_values, -clip_range_vf, - clip_range_vf + clip_range_vf, ) # Value loss using the TD(gae_lambda) target value_loss = F.mse_loss(rollout_data.returns, values_pred) @@ -303,77 +308,85 @@ def train(self) -> None: # Entropy loss favor exploration if entropy is None: # Approximate entropy when no analytical form - entropy_loss = -th.mean(-log_prob) + entropy_loss = -torch.mean(-log_prob) else: - entropy_loss = -th.mean(entropy) + entropy_loss = -torch.mean(entropy) entropy_losses.append(entropy_loss.item()) # Context loss for ADAP algorithm - context_loss = get_context_kl_loss(self, - self.policy, rollout_data) + context_loss = get_context_kl_loss( + self, self.policy, rollout_data + ) context_kl_divs.append(context_loss.detach().numpy()) - loss = policy_loss + self.ent_coef * entropy_loss \ - + self.vf_coef * value_loss \ + loss = ( + policy_loss + + self.ent_coef * entropy_loss + + self.vf_coef * value_loss + 
self.context_loss_coeff * context_loss + ) - # Calculate approximate form of reverse KL Divergence - with th.no_grad(): + with torch.no_grad(): log_ratio = log_prob - rollout_data.old_log_prob - approx_kl_div = th.mean( - (th.exp(log_ratio) - 1) - log_ratio).cpu().numpy() + approx_kl_div = ( + torch.mean((torch.exp(log_ratio) - 1) - log_ratio) + .cpu() + .numpy() + ) approx_kl_divs.append(approx_kl_div) - if self.target_kl is not None and \ - approx_kl_div > 1.5 * self.target_kl: + if ( + self.target_kl is not None + and approx_kl_div > 1.5 * self.target_kl + ): continue_training = False if self.verbose >= 1: print( - f"Early stopping at step {epoch} due \ - to reaching max kl: {approx_kl_div: .2f}") + f"Early stopping at step {epoch} due to \ + reaching max kl: {approx_kl_div:.2f}" + ) break # Optimization step self.policy.optimizer.zero_grad() loss.backward() # Clip grad norm - th.nn.utils.clip_grad_norm_( - self.policy.parameters(), self.max_grad_norm) + torch.nn.utils.clip_grad_norm_( + self.policy.parameters(), self.max_grad_norm + ) self.policy.optimizer.step() + self._n_updates += 1 if not continue_training: break - self._n_updates += self.n_epochs explained_var = explained_variance( self.rollout_buffer.values.flatten(), - self.rollout_buffer.returns.flatten()) + self.rollout_buffer.returns.flatten(), + ) # Logs self.logger.record("train/entropy_loss", np.mean(entropy_losses)) self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) self.logger.record("train/value_loss", np.mean(value_losses)) self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) - self.logger.record("train/context_kl_loss", np.mean(context_kl_divs)) self.logger.record("train/clip_fraction", np.mean(clip_fractions)) self.logger.record("train/loss", loss.item()) self.logger.record("train/explained_variance", explained_var) if hasattr(self.policy, "log_std"): self.logger.record( - "train/std", th.exp(self.policy.log_std).mean().item()) + "train/std", torch.exp(self.policy.log_std).mean().item() + ) - self.logger.record("train/n_updates", - self._n_updates, exclude="tensorboard") + self.logger.record( + "train/n_updates", self._n_updates, exclude="tensorboard" + ) self.logger.record("train/clip_range", clip_range) if self.clip_range_vf is not None: self.logger.record("train/clip_range_vf", clip_range_vf) - _last_obs: np.ndarray - _last_episode_starts: np.ndarray - full_obs_shape: Optional[Tuple[int, ]] - def collect_rollouts( self, env: VecEnv, @@ -382,28 +395,36 @@ def collect_rollouts( n_rollout_steps: int, ) -> bool: """ - Nearly identical to OnPolicyAlgorithm's collect_rollouts, but it also - resamples the context every episode. - Collect experiences using the current policy and fill a ``RolloutBuffer``. + The term rollout here refers to the model-free notion and should not be used with the concept of rollout used in model-based RL or planning. + :param env: The training environment :param callback: Callback that will be called at each step (and at the beginning and end of the rollout) :param rollout_buffer: Buffer to fill with rollouts - :param n_steps: Number of experiences to collect per environment + :param n_rollout_steps: Number of steps to collect per environment :return: True if function returned with at least `n_rollout_steps` collected, False if callback terminated rollout prematurely. 
""" - assert self._last_obs is not None, "No previous observation provided" + assert ( + self._last_obs is not None + ), "No previous observation was provided" + # Switch to eval mode (this affects batch norm / dropout) + self.policy.set_training_mode(False) + n_steps = 0 + + # ADAP ADDITION if self.full_obs_shape is None: self.full_obs_shape = ( - rollout_buffer.obs_shape[0] + self.context_size,) + rollout_buffer.obs_shape[0] + self.context_size, + ) rollout_buffer.obs_shape = tuple(self.full_obs_shape) + # ADAP END rollout_buffer.reset() # Sample new weights for the state dependent exploration @@ -413,25 +434,39 @@ def collect_rollouts( callback.on_rollout_start() while n_steps < n_rollout_steps: - if self.use_sde and self.sde_sample_freq > 0 and \ - n_steps % self.sde_sample_freq == 0: + if ( + self.use_sde + and self.sde_sample_freq > 0 + and n_steps % self.sde_sample_freq == 0 + ): # Sample a new noise matrix self.policy.reset_noise(env.num_envs) - with th.no_grad(): + with torch.no_grad(): # Convert to pytorch tensor or to TensorDict obs_tensor = obs_as_tensor(self._last_obs, self.device) - actions, values, log_probs = self.policy.forward(obs_tensor) + actions, values, log_probs = self.policy(obs_tensor) actions = actions.cpu().numpy() # Rescale and perform action clipped_actions = actions - # Clip the actions to avoid out of bound error - if isinstance(self.action_space, gym.spaces.Box): - clipped_actions = np.clip( - actions, self.action_space.low, self.action_space.high) + + if isinstance(self.action_space, spaces.Box): + if self.policy.squash_output: + # Unscale the actions to match env bounds + # if they were previously squashed (scaled in [-1, 1]) + clipped_actions = self.policy.unscale_action( + clipped_actions + ) + else: + # Otherwise, clip the actions to avoid out of bound error + # as we are sampling from an unbounded Gaussian distribution + clipped_actions = np.clip( + actions, self.action_space.low, self.action_space.high + ) new_obs, rewards, dones, infos = env.step(clipped_actions) + self.num_timesteps += env.num_envs # Give access to local variables @@ -442,31 +477,59 @@ def collect_rollouts( self._update_info_buffer(infos) n_steps += 1 - if isinstance(self.action_space, gym.spaces.Discrete): + if isinstance(self.action_space, spaces.Discrete): # Reshape in case of discrete action actions = actions.reshape(-1, 1) - rollout_buffer.add(np.concatenate( - (self._last_obs, - self.policy.get_context()), - axis=None), - actions, rewards, - self._last_episode_starts, values, log_probs) - self._last_obs = new_obs + + # Handle timeout by bootstraping with value function + # see GitHub issue #633 + for idx, done in enumerate(dones): + if ( + done + and infos[idx].get("terminal_observation") is not None + and infos[idx].get("TimeLimit.truncated", False) + ): + terminal_obs = self.policy.obs_to_tensor( + infos[idx]["terminal_observation"] + )[0] + with torch.no_grad(): + terminal_value = self.policy.predict_values( + terminal_obs + )[0] + rewards[idx] += self.gamma * terminal_value + + rollout_buffer.add( + np.concatenate( + (self._last_obs, self.policy.get_context()), axis=None + ), + # self._last_obs, # type: ignore[arg-type] + actions, + rewards, + self._last_episode_starts, # type: ignore[arg-type] + values, + log_probs, + ) + self._last_obs = new_obs # type: ignore[assignment] self._last_episode_starts = dones # ADAP CHANGE: resample context if dones[0]: sampled_context = SAMPLERS[self.context_sampler]( - ctx_size=self.context_size, num=1, torch=True) + 
ctx_size=self.context_size, num=1, use_torch=True + ) self.policy.set_context(sampled_context) - with th.no_grad(): + with torch.no_grad(): # Compute value for the last timestep - obs_tensor = obs_as_tensor(new_obs, self.device) - _, values, _ = self.policy.forward(obs_tensor) + _, values, _ = self.policy.forward( + obs_as_tensor(new_obs, self.device) + ) rollout_buffer.compute_returns_and_advantage( - last_values=values, dones=dones) + last_values=values, dones=dones + ) + + callback.update_locals(locals()) callback.on_rollout_end() @@ -477,21 +540,15 @@ def learn( total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, - eval_env: Optional[GymEnv] = None, - eval_freq: int = -1, - n_eval_episodes: int = 5, tb_log_name: str = "ADAP", - eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> "ADAP": - return super(ADAP, self).learn( + progress_bar: bool = False, + ): + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, - eval_env=eval_env, - eval_freq=eval_freq, - n_eval_episodes=n_eval_episodes, tb_log_name=tb_log_name, - eval_log_path=eval_log_path, reset_num_timesteps=reset_num_timesteps, + progress_bar=progress_bar, ) diff --git a/src/pantheonrl/algos/adap/agent.py b/src/pantheonrl/algos/adap/agent.py index d79a243..55b100e 100644 --- a/src/pantheonrl/algos/adap/agent.py +++ b/src/pantheonrl/algos/adap/agent.py @@ -1,18 +1,28 @@ +""" +Module defining the ADAP partner agent. +""" from typing import Optional -from collections import deque +import time -import numpy as np -import torch as th +import copy +import sys + +import torch -from pantheonrl.common.util import (action_from_policy, clip_actions, - resample_noise) +from gymnasium import spaces -from stable_baselines3.common.utils import configure_logger -from stable_baselines3.common.utils import safe_mean +import numpy as np from pantheonrl.common.agents import OnPolicyAgent from pantheonrl.common.observation import Observation + +from stable_baselines3.common.utils import ( + safe_mean, + obs_as_tensor, +) + + from .adap_learn import ADAP from .util import SAMPLERS from .policies import AdapPolicy @@ -26,126 +36,192 @@ class AdapAgent(OnPolicyAgent): from ``OnPolicyAlgorithm``. :param model: Model representing the agent's learning algorithm + :param log_interval: Optional log interval for policy logging + :param working_timesteps: Estimate for number of timesteps to train for. 
+ :param callback: Optional callback fed into the OnPolicyAlgorithm + :param tb_log_name: Name for tensorboard log """ - def __init__(self, - model: ADAP, - log_interval=None, - tensorboard_log=None, - tb_log_name="AdapAgent", - latent_syncer: Optional[AdapPolicy] = None): - self.model = model - self._last_episode_starts = [True] - self.n_steps = 0 - self.values: th.Tensor = th.empty(0) - - self.model.set_logger(configure_logger( - self.model.verbose, tensorboard_log, tb_log_name)) - - self.name = tb_log_name - self.num_timesteps = 0 - self.log_interval = log_interval or (1 if model.verbose else None) - self.iteration = 0 - self.model.ep_info_buffer = deque([{"r": 0, "l": 0}], maxlen=100) + def __init__( + self, + model: ADAP, + log_interval=None, + working_timesteps=1000, + callback=None, + tb_log_name="AdapAgent", + latent_syncer: Optional[AdapPolicy] = None, + ): + super().__init__( + model, log_interval, working_timesteps, callback, tb_log_name + ) self.latent_syncer = latent_syncer - buf = self.model.rollout_buffer - self.model.full_obs_shape = ( - buf.obs_shape[0] + self.model.context_size,) - buf.obs_shape = self.model.full_obs_shape - buf.reset() - - def get_action(self, obs: Observation, record: bool = True) -> np.ndarray: + def get_action(self, obs: Observation) -> np.ndarray: """ Return an action given an observation. - When `record` is True, the agent saves the last transition into its - buffer. It also updates the model if the buffer is full. + The agent saves the last transition into its buffer. It also updates + the model if the buffer is full. :param obs: The observation to use - :param record: Whether to record the obs, action (True when training) :returns: The action to take """ obs = obs.obs - if self.latent_syncer is not None: - self.model.policy.set_context(self.latent_syncer.get_context()) + if not isinstance(obs, np.ndarray): + obs = np.array([obs]) + callback = self.callback + rollout_buffer = self.model.rollout_buffer + if self.model.full_obs_shape is None: + self.model.full_obs_shape = ( + rollout_buffer.obs_shape[0] + self.model.context_size, + ) + + rollout_buffer.obs_shape = tuple(self.model.full_obs_shape) + rollout_buffer.reset() - buf = self.model.rollout_buffer + n_rollout_steps = self.model.n_steps - # train the model if the buffer is full - if record and self.n_steps >= self.model.n_steps: - buf.compute_returns_and_advantage( - last_values=self.values, - dones=self._last_episode_starts[0] + if self.model.num_timesteps >= self.total_timesteps: + self.callback.on_training_end() + self.iteration = 0 + self.total_timesteps, self.callback = self.model._setup_learn( + self.working_timesteps, + self.original_callback, + False, + self.tb_log_name, + False, ) - if self.log_interval is not None and \ - self.iteration % self.log_interval == 0: - self.model.logger.record( - "name", self.name, exclude="tensorboard") - self.model.logger.record( - "time/iterations", self.iteration, exclude="tensorboard") + self.callback.on_training_start(locals(), globals()) + + if self.n_steps >= n_rollout_steps: + with torch.no_grad(): + values = self.model.policy.predict_values( + obs_as_tensor(obs, self.model.device).unsqueeze(0) + ) + rollout_buffer.compute_returns_and_advantage( + last_values=values, dones=self.model._last_episode_starts + ) + self.old_buffer = copy.deepcopy(rollout_buffer) + callback.update_locals(locals()) + callback.on_rollout_end() - if len(self.model.ep_info_buffer) > 0 and \ - len(self.model.ep_info_buffer[0]) > 0: - last_exclude = 
self.model.ep_info_buffer.pop() - rews = [ep["r"] for ep in self.model.ep_info_buffer] - lens = [ep["l"] for ep in self.model.ep_info_buffer] + self.iteration += 1 + self.model._update_current_progress_remaining( + self.model.num_timesteps, self.working_timesteps + ) + + if ( + self.log_interval is not None + and self.iteration % self.log_interval == 0 + ): + assert self.model.ep_info_buffer is not None + time_elapsed = max( + (time.time_ns() - self.model.start_time) / 1e9, + sys.float_info.epsilon, + ) + fps = int( + ( + self.model.num_timesteps + - self.model._num_timesteps_at_start + ) + / time_elapsed + ) + self.model.logger.record( + "time/iterations", self.iteration, exclude="tensorboard" + ) + if ( + len(self.model.ep_info_buffer) > 0 + and len(self.model.ep_info_buffer[0]) > 0 + ): self.model.logger.record( - "rollout/ep_rew_mean", safe_mean(rews)) + "rollout/ep_rew_mean", + safe_mean( + [ + ep_info["r"] + for ep_info in self.model.ep_info_buffer + ] + ), + ) self.model.logger.record( - "rollout/ep_len_mean", safe_mean(lens)) - self.model.ep_info_buffer.append(last_exclude) - + "rollout/ep_len_mean", + safe_mean( + [ + ep_info["l"] + for ep_info in self.model.ep_info_buffer + ] + ), + ) + self.model.logger.record("time/fps", fps) self.model.logger.record( - "time/total_timesteps", self.num_timesteps, - exclude="tensorboard") - self.model.logger.dump(step=self.num_timesteps) - + "time/time_elapsed", + int(time_elapsed), + exclude="tensorboard", + ) + self.model.logger.record( + "time/total_timesteps", + self.model.num_timesteps, + exclude="tensorboard", + ) + self.model.logger.dump(step=self.model.num_timesteps) self.model.train() - self.iteration += 1 - buf.reset() - self.n_steps = 0 - resample_noise(self.model, self.n_steps) - - actions, values, log_probs = action_from_policy(obs, self.model.policy) + # Restarting + self.model.policy.set_training_mode(False) + self.n_steps = 0 + rollout_buffer.reset() + if self.model.use_sde: + self.model.policy.reset_noise(1) + self.callback.on_rollout_start() + + if ( + self.model.use_sde + and self.model.sde_sample_freq > 0 + and self.n_steps % self.model.sde_sample_freq == 0 + ): + self.model.policy.reset_noise(1) + + with torch.no_grad(): + obs_tensor = obs_as_tensor(obs, self.model.device) + actions, values, log_probs = self.model.policy( + obs_tensor.unsqueeze(0) + ) + actions = actions.cpu().numpy() + clipped_actions = actions + + if isinstance(self.model.action_space, spaces.Box): + clipped_actions = np.clip( + actions, + self.model.action_space.low, + self.model.action_space.high, + ) - # modify the rollout buffer with newest info + self.in_progress_info["l"] += 1 + self.model.num_timesteps += 1 + self.n_steps += 1 + if isinstance(self.model.action_space, spaces.Discrete): + actions = actions.reshape(-1, 1) + print(obs.shape) obs = np.concatenate((np.reshape(obs, (1, -1)), self.model.policy.get_context()), axis=1) - if record: - obs_shape = self.model.policy.observation_space.shape - act_shape = self.model.policy.action_space.shape - buf.add( - np.reshape(obs, (1,) + obs_shape), - np.reshape(actions, (1,) + act_shape), - [0], - self._last_episode_starts, - values, - log_probs - ) - self.n_steps += 1 - self.num_timesteps += 1 - self.values = values - return clip_actions(actions, self.model)[0] + rollout_buffer.add( + obs, + actions, + [0], + self.model._last_episode_starts, + values, + log_probs, + ) + return clipped_actions[0] def update(self, reward: float, done: bool) -> None: - """ - Add new rewards and done information. 
- - The rewards are added to buffer entry corresponding to the most recent - recorded action. - - :param reward: The reward receieved from the previous action step - :param done: Whether the game is done - """ - super(AdapAgent, self).update(reward, done) + super().update(reward, done) if done and self.latent_syncer is None: sampled_context = SAMPLERS[self.model.context_sampler]( - ctx_size=self.model.context_size, num=1, torch=True) + ctx_size=self.model.context_size, num=1, use_torch=True + ) self.model.policy.set_context(sampled_context) diff --git a/src/pantheonrl/algos/adap/policies.py b/src/pantheonrl/algos/adap/policies.py index 399505b..210be1e 100644 --- a/src/pantheonrl/algos/adap/policies.py +++ b/src/pantheonrl/algos/adap/policies.py @@ -1,7 +1,12 @@ +""" +Module defining the Policy for ADAP +""" +# pylint: disable=locally-disabled, not-callable + from typing import Any, Dict, Optional, Type, Union, List, Tuple from itertools import zip_longest -import torch as th +import torch import gymnasium as gym from torch import nn @@ -14,11 +19,15 @@ from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, FlattenExtractor, - MlpExtractor + MlpExtractor, ) class AdapPolicy(ActorCriticPolicy): + """ + Base Policy for the ADAP Actor-critic policy + """ + def __init__( self, observation_space: gym.spaces.Space, @@ -33,15 +42,18 @@ def __init__( sde_net_arch: Optional[List[int]] = None, use_expln: bool = False, squash_output: bool = False, - features_extractor_class: Type[BaseFeaturesExtractor] = - FlattenExtractor, + features_extractor_class: Type[ + BaseFeaturesExtractor + ] = FlattenExtractor, features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Type[Optimizer] = Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, - context_size: int = 3 + context_size: int = 3, ): self.context_size = context_size + self.context = None + self.mlp_extractor = None super().__init__( observation_space=observation_space, action_space=action_space, @@ -52,7 +64,7 @@ def __init__( use_sde=use_sde, log_std_init=log_std_init, full_std=full_std, - sde_net_arch=sde_net_arch, + # sde_net_arch=sde_net_arch, use_expln=use_expln, squash_output=squash_output, features_extractor_class=features_extractor_class, @@ -60,12 +72,15 @@ def __init__( normalize_images=normalize_images, optimizer_class=optimizer_class, optimizer_kwargs=optimizer_kwargs, + share_features_extractor=True ) def set_context(self, ctxt): + """ Set the context """ self.context = ctxt def get_context(self): + """ Get the current context """ return self.context def _build_mlp_extractor(self) -> None: @@ -83,8 +98,9 @@ def _build_mlp_extractor(self) -> None: device=self.device, ) - def _get_latent(self, - obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + def _get_latent( + self, obs: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Get the latent code (activations of the last layer of each network) for the different networks. 
@@ -94,9 +110,9 @@ def _get_latent(self, """ # Preprocess the observation if needed features = self.extract_features(obs) - features = th.cat( - (features, self.context.repeat(features.size()[0], 1)), - dim=1) + features = torch.cat( + (features, self.context.repeat(features.size()[0], 1)), dim=1 + ) latent_pi, latent_vf = self.mlp_extractor(features) # Features for sde @@ -105,10 +121,22 @@ def _get_latent(self, latent_sde = self.sde_features_extractor(features) return latent_pi, latent_vf, latent_sde - def evaluate_actions(self, - obs: th.Tensor, - actions: th.Tensor - ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + # def forward(self, obs: torch.Tensor, deterministic: bool = False): + # # Preprocess the observation if needed + # features = self.extract_features(obs) + # features = torch.cat((features, obs[:, -self.context_size :]), dim=1).float() + # latent_pi, latent_vf = self.mlp_extractor(features) + # # Evaluate the values for the given observations + # values = self.value_net(latent_vf) + # distribution = self._get_action_dist_from_latent(latent_pi) + # actions = distribution.get_actions(deterministic=deterministic) + # log_prob = distribution.log_prob(actions) + # actions = actions.reshape((-1, *self.action_space.shape)) + # return actions, values, log_prob + + def evaluate_actions( + self, obs: torch.Tensor, actions: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Evaluate actions according to the current policy, given the observations. @@ -117,10 +145,8 @@ def evaluate_actions(self, :return: estimated value, log likelihood of taking those actions and entropy of the action distribution. """ - features = self.extract_features(obs[:, :-self.context_size]) - features = th.cat( - (features, obs[:, -self.context_size:]), - dim=1) + features = self.extract_features(obs[:, : -self.context_size]) + features = torch.cat((features, obs[:, -self.context_size :]), dim=1) latent_pi, latent_vf = self.mlp_extractor(features) # Features for sde @@ -133,16 +159,13 @@ def evaluate_actions(self, return values, log_prob, distribution.entropy() -class MultModel(MlpExtractor): +class MultModel(nn.Module): + """ Neural Network representing multiplicative layers """ + def __init__( - self, - feature_dim, - net_arch, - activation_fn, - device, - context_size - ): - nn.Module.__init__(self) + self, feature_dim, net_arch, activation_fn, device, context_size + ): + super().__init__() self.obs_space_size = feature_dim + context_size self.context_size = context_size @@ -158,23 +181,25 @@ def __init__( # Iterate through shared layers and build shared parts of the network for layer in net_arch: if isinstance(layer, int): # Check that this is a shared layer - # TODO: give layer a meaningful name # add linear of size layer shared_net.append(nn.Linear(last_layer_dim_shared, layer)) shared_net.append(activation_fn()) last_layer_dim_shared = layer else: - assert isinstance(layer, dict), \ - "Error: the net_arch list can only contain ints and dicts" + assert isinstance( + layer, dict + ), "Error: the net_arch list can only contain ints and dicts" if "pi" in layer: - assert isinstance(layer["pi"], list), \ - "Error: net_arch[-1]['pi'] must \ + assert isinstance( + layer["pi"], list + ), "Error: net_arch[-1]['pi'] must \ contain a list of integers." policy_only_layers = layer["pi"] if "vf" in layer: - assert isinstance(layer["vf"], list), \ - "Error: net_arch[-1]['vf'] must \ + assert isinstance( + layer["vf"], list + ), "Error: net_arch[-1]['vf'] must \ contain a list of integers." 
value_only_layers = layer["vf"] break @@ -183,18 +208,21 @@ def __init__( last_layer_dim_vf = last_layer_dim_shared # Build the non-shared part of the network - for pi_layer_size, vf_layer_size in zip_longest(policy_only_layers, - value_only_layers): + for pi_layer_size, vf_layer_size in zip_longest( + policy_only_layers, value_only_layers + ): if pi_layer_size is not None: - assert isinstance(pi_layer_size, int), \ - "Error: net_arch[-1]['pi'] must only contain integers." + assert isinstance( + pi_layer_size, int + ), "Error: net_arch[-1]['pi'] must only contain integers." policy_net.append(nn.Linear(last_layer_dim_pi, pi_layer_size)) policy_net.append(activation_fn()) last_layer_dim_pi = pi_layer_size if vf_layer_size is not None: - assert isinstance(vf_layer_size, int), \ - "Error: net_arch[-1]['vf'] must only contain integers." + assert isinstance( + vf_layer_size, int + ), "Error: net_arch[-1]['vf'] must only contain integers." value_net.append(nn.Linear(last_layer_dim_vf, vf_layer_size)) value_net.append(activation_fn()) last_layer_dim_vf = vf_layer_size @@ -211,7 +239,7 @@ def __init__( self.agent_branch_1 = nn.Sequential(*policy_net[0:2]).to(device) self.agent_scaling = nn.Sequential( nn.Linear(self.hidden_dim1, self.hidden_dim1 * self.context_size), - activation_fn() + activation_fn(), ).to(device) self.agent_branch_2 = nn.Sequential(*policy_net[2:]).to(device) @@ -219,52 +247,61 @@ def __init__( self.value_branch_1 = nn.Sequential(*value_net[0:2]).to(device) self.value_scaling = nn.Sequential( nn.Linear(self.hidden_dim2, self.hidden_dim2 * self.context_size), - activation_fn() + activation_fn(), ).to(device) self.value_branch_2 = nn.Sequential(*value_net[2:]).to(device) def get_input_size_excluding_ctx(self): + """ Returns input size excluding the size of context """ return self.obs_space_size - self.context_size def get_input_size_inluding_ctx(self): + """ Returns full input size """ return self.obs_space_size - def policies(self, observations: th.Tensor, - contexts: th.Tensor) -> th.Tensor: - + def policies( + self, observations: torch.Tensor, contexts: torch.Tensor + ) -> torch.Tensor: + """ Returns the logits from the policy function """ batch_size = observations.shape[0] x = self.agent_branch_1(observations) x_a = self.agent_scaling(x) # reshape to do context multiplication x_a = x_a.view((batch_size, self.hidden_dim1, self.context_size)) - x_a_out = th.matmul(x_a, contexts.unsqueeze(-1)).squeeze(-1) + x_a_out = torch.matmul(x_a, contexts.unsqueeze(-1)).squeeze(-1) logits = self.agent_branch_2(x + x_a_out) return logits - def values(self, observations: th.Tensor, - contexts: th.Tensor) -> th.Tensor: - + def values( + self, observations: torch.Tensor, contexts: torch.Tensor + ) -> torch.Tensor: + """ Returns the response from the value function """ batch_size = observations.shape[0] x = self.value_branch_1(observations) x_a = self.value_scaling(x) # reshape to do context multiplication x_a = x_a.view((batch_size, self.hidden_dim2, self.context_size)) - x_a_out = th.matmul(x_a, contexts.unsqueeze(-1)).squeeze(-1) + x_a_out = torch.matmul(x_a, contexts.unsqueeze(-1)).squeeze(-1) values = self.value_branch_2(x + x_a_out) # values = self.value_branch_2(x_a_out) return values - def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ Returns the action logits and values """ features = self.shared_net(features) - observations = features[:, :-self.context_size] - contexts = 
features[:, -self.context_size:] - return self.policies(observations, contexts), \ - self.values(observations, contexts) + observations = features[:, : -self.context_size] + contexts = features[:, -self.context_size :] + return self.policies(observations, contexts), self.values( + observations, contexts + ) class AdapPolicyMult(AdapPolicy): + """ + Multiplicative Policy for the ADAP Actor-critic policy + """ def _build_mlp_extractor(self) -> None: """ @@ -279,5 +316,5 @@ def _build_mlp_extractor(self) -> None: net_arch=self.net_arch, activation_fn=self.activation_fn, device=self.device, - context_size=self.context_size + context_size=self.context_size, ) diff --git a/src/pantheonrl/algos/adap/util.py b/src/pantheonrl/algos/adap/util.py index a90115f..4dd2639 100644 --- a/src/pantheonrl/algos/adap/util.py +++ b/src/pantheonrl/algos/adap/util.py @@ -1,20 +1,28 @@ +""" +Collection of helper functions for ADAP +""" + import copy from itertools import combinations -import torch as th +from typing import TYPE_CHECKING + +import torch import numpy as np from torch.distributions import kl from stable_baselines3.common import distributions from stable_baselines3.common.buffers import RolloutBufferSamples -from typing import TYPE_CHECKING + if TYPE_CHECKING: from .adap_learn import ADAP from .policies import AdapPolicy -def kl_divergence(dist_true: distributions.Distribution, - dist_pred: distributions.Distribution) -> th.Tensor: +def kl_divergence( + dist_true: distributions.Distribution, + dist_pred: distributions.Distribution, +) -> torch.Tensor: """ Wrapper for the PyTorch implementation of the full form KL Divergence :param dist_true: the p distribution @@ -22,97 +30,108 @@ def kl_divergence(dist_true: distributions.Distribution, :return: KL(dist_true||dist_pred) """ # KL Divergence for different distribution types is out of scope - assert dist_true.__class__ == dist_pred.__class__, \ - "Error: input distributions should be the same type" + assert ( + dist_true.__class__ == dist_pred.__class__ + ), "Error: input distributions should be the same type" # MultiCategoricalDistribution is not a PyTorch Distribution subclass # so we need to implement it ourselves! 
if isinstance(dist_pred, distributions.MultiCategoricalDistribution): - return th.stack( - [kl.kl_divergence(p, q) for p, q in zip( - dist_true.distribution, dist_pred.distribution)], + return torch.stack( + [ + kl.kl_divergence(p, q) + for p, q in zip(dist_true.distribution, dist_pred.distribution) + ], dim=1, ).sum(dim=1) # Use the PyTorch kl_divergence implementation - else: - return kl.kl_divergence(dist_true.distribution, dist_pred.distribution) + return kl.kl_divergence(dist_true.distribution, dist_pred.distribution) -def get_L2_sphere(ctx_size, num, torch=False): - if torch: - ctxs = th.rand(num, ctx_size, device='cpu') * 2 - 1 - ctxs = ctxs / (th.sum((ctxs)**2, dim=-1).reshape(num, 1))**(1/2) - ctxs = ctxs.to('cpu') +def get_l2_sphere(ctx_size, num, use_torch=False): + """ Samples from l2 sphere """ + if use_torch: + ctxs = torch.rand(num, ctx_size, device="cpu") * 2 - 1 + ctxs = ctxs / (((ctxs) ** 2).sum(dim=-1).reshape(num, 1)) ** (1 / 2) + ctxs = ctxs.to("cpu") else: ctxs = np.random.rand(num, ctx_size) * 2 - 1 - ctxs = ctxs / (np.sum((ctxs)**2, axis=-1).reshape(num, 1))**(1/2) + ctxs = ctxs / (np.sum((ctxs) ** 2, axis=-1).reshape(num, 1)) ** (1 / 2) return ctxs -def get_unit_square(ctx_size, num, torch=False): - if torch: - ctxs = th.rand(num, ctx_size) * 2 - 1 +def get_unit_square(ctx_size, num, use_torch=False): + """ Samples from unit square centered at 0 """ + if use_torch: + ctxs = torch.rand(num, ctx_size) * 2 - 1 else: ctxs = np.random.rand(num, ctx_size) * 2 - 1 return ctxs -def get_positive_square(ctx_size, num, torch=False): - if torch: - ctxs = th.rand(num, ctx_size) +def get_positive_square(ctx_size, num, use_torch=False): + """ Samples from the square with axes between 0 and 1 """ + if use_torch: + ctxs = torch.rand(num, ctx_size) else: ctxs = np.random.rand(num, ctx_size) return ctxs -def get_categorical(ctx_size, num, torch=False): - if torch: - ctxs = th.zeros(num, ctx_size) - ctxs[th.arange(num), th.randint(0, ctx_size, size=(num,))] = 1 +def get_categorical(ctx_size, num, use_torch=False): + """ Samples from categorical distribution """ + if use_torch: + ctxs = torch.zeros(num, ctx_size) + ctxs[torch.arange(num), torch.randint(0, ctx_size, size=(num,))] = 1 else: ctxs = np.zeros((num, ctx_size)) ctxs[np.arange(num), np.random.randint(0, ctx_size, size=(num,))] = 1 return ctxs -def get_natural_number(ctx_size, num, torch=False): - ''' +def get_natural_number(ctx_size, num, use_torch=False): + """ Returns context vector of shape (num,1) with numbers in range [0, ctx_size] - ''' - if torch: - ctxs = th.randint(0, ctx_size, size=(num, 1)) + """ + if use_torch: + ctxs = torch.randint(0, ctx_size, size=(num, 1)) else: ctxs = np.random.randint(0, ctx_size, size=(num, 1)) return ctxs -SAMPLERS = {"l2": get_L2_sphere, - "unit_square": get_unit_square, - "positive_square": get_positive_square, - "categorical": get_categorical, - "natural_numbers": get_natural_number} +SAMPLERS = { + "l2": get_l2_sphere, + "unit_square": get_unit_square, + "positive_square": get_positive_square, + "categorical": get_categorical, + "natural_numbers": get_natural_number, +} -def get_context_kl_loss(policy: 'ADAP', model: 'AdapPolicy', - train_batch: RolloutBufferSamples): +def get_context_kl_loss( + policy: "ADAP", model: "AdapPolicy", train_batch: RolloutBufferSamples +): + """ Gets the KL loss for ADAP """ - original_obs = train_batch.observations[:, :-policy.context_size] + original_obs = train_batch.observations[:, : -policy.context_size] context_size = policy.context_size 
num_context_samples = policy.num_context_samples num_state_samples = policy.num_state_samples - indices = th.randperm(original_obs.shape[0])[:num_state_samples] + indices = torch.randperm(original_obs.shape[0])[:num_state_samples] sampled_states = original_obs[indices] num_state_samples = min(num_state_samples, sampled_states.shape[0]) all_contexts = set() all_action_dists = [] old_context = model.get_context() - for i in range(0, num_context_samples): # 10 sampled contexts + for _ in range(0, num_context_samples): # 10 sampled contexts sampled_context = SAMPLERS[policy.context_sampler]( - ctx_size=context_size, num=1, torch=True) + ctx_size=context_size, num=1, use_torch=True + ) if sampled_context in all_contexts: continue @@ -121,11 +140,14 @@ def get_context_kl_loss(policy: 'ADAP', model: 'AdapPolicy', model.set_context(sampled_context) latent_pi, _, latent_sde = model._get_latent(sampled_states) context_action_dist = model._get_action_dist_from_latent( - latent_pi, latent_sde) + latent_pi, latent_sde + ) all_action_dists.append(copy.copy(context_action_dist)) model.set_context(old_context) - all_CLs = [th.mean(th.exp(-kl_divergence(a, b))) - for a, b in combinations(all_action_dists, 2)] - rawans = sum(all_CLs)/len(all_CLs) + all_cls = [ + torch.mean(torch.exp(-kl_divergence(a, b))) + for a, b in combinations(all_action_dists, 2) + ] + rawans = sum(all_cls) / len(all_cls) return rawans diff --git a/src/pantheonrl/algos/bc.py b/src/pantheonrl/algos/bc.py index f32ff6f..674a1da 100644 --- a/src/pantheonrl/algos/bc.py +++ b/src/pantheonrl/algos/bc.py @@ -1,4 +1,5 @@ -"""Behavioural Cloning (BC). +""" +Behavioural Cloning (BC). Trains policy by applying supervised learning to a fixed dataset of (observation, action) pairs generated by some expert demonstrator. @@ -7,33 +8,54 @@ """ import contextlib -from typing import (Any, Callable, Dict, Iterable, Mapping, - Optional, Tuple, Type, Union) +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Dict, + Iterable, + Mapping, + Optional, + Tuple, + Type, + Union, +) import gymnasium as gym import numpy as np -import torch as th +import torch import torch.utils.data as th_data from torch.optim.optimizer import Optimizer from torch.optim.adam import Adam import tqdm.autonotebook as tqdm from stable_baselines3.common import policies, utils -from pantheonrl.common.trajsaver import (TransitionsMinimal, - transitions_collate_fn) +from pantheonrl.common.trajsaver import ( + TransitionsMinimal, + transitions_collate_fn, +) from pantheonrl.common.util import FeedForward32Policy log = utils.configure_logger(verbose=0) # change to 1 for debugging +@dataclass class BCShell: - def __init__(self, policy): - self.policy = policy + """ Shell class for BC policy """ + policy: FeedForward32Policy + + def get_policy(self): + """ Get the current policy """ + return self.policy + + def set_policy(self, new_policy): + """ Set the BC policy """ + self.policy = new_policy def reconstruct_policy( policy_path: str, - device: Union[th.device, str] = "auto", + device: Union[torch.device, str] = "auto", ) -> policies.BasePolicy: """Reconstruct a saved policy. Args: @@ -42,7 +64,7 @@ def reconstruct_policy( Returns: policy: policy with reloaded weights. 
""" - policy = th.load(policy_path, map_location=utils.get_device(device)) + policy = torch.load(policy_path, map_location=utils.get_device(device)) assert isinstance(policy, policies.BasePolicy) return policy @@ -63,8 +85,29 @@ def __call__(self, _): """ return self.lr + def set_lr(self, new_lr): + """ Sets a new learning rate """ + self.lr = new_lr + class EpochOrBatchIteratorWithProgress: + """ + Wraps DataLoader so that all BC batches can be processed in a one + for-loop. Also uses `tqdm` to show progress in stdout. + Args: + data_loader: An iterable over data dicts, as used in `BC`. + n_epochs: The number of epochs to iterate through in one call to + __iter__. Exactly one of `n_epochs` and `n_batches` should be + provided. + n_batches: The number of batches to iterate through in one call to + __iter__. Exactly one of `n_epochs` and `n_batches` should be + provided. + on_epoch_end: A callback function without parameters to be called + at the end of every epoch. + on_batch_end: A callback function without parameters to be called + at the end of every batch. + """ + def __init__( self, data_loader: Iterable[dict], @@ -73,22 +116,6 @@ def __init__( on_epoch_end: Optional[Callable[[], None]] = None, on_batch_end: Optional[Callable[[], None]] = None, ): - """ - Wraps DataLoader so that all BC batches can be processed in a one - for-loop. Also uses `tqdm` to show progress in stdout. - Args: - data_loader: An iterable over data dicts, as used in `BC`. - n_epochs: The number of epochs to iterate through in one call to - __iter__. Exactly one of `n_epochs` and `n_batches` should be - provided. - n_batches: The number of batches to iterate through in one call to - __iter__. Exactly one of `n_epochs` and `n_batches` should be - provided. - on_epoch_end: A callback function without parameters to be called - at the end of every epoch. - on_batch_end: A callback function without parameters to be called - at the end of every batch. - """ if n_epochs is not None and n_batches is None: self.use_epochs = True elif n_epochs is None and n_batches is not None: @@ -135,11 +162,11 @@ def update_desc(): batch_size = len(batch["obs"]) assert batch_size > 0 samples_so_far += batch_size - stats = dict( - epoch_num=epoch_num, - batch_num=batch_num, - samples_so_far=samples_so_far, - ) + stats = { + "epoch_num": epoch_num, + "batch_num": batch_num, + "samples_so_far": samples_so_far, + } yield batch, stats if self.on_batch_end is not None: self.on_batch_end() @@ -168,8 +195,33 @@ def update_desc(): if epoch_num >= self.n_epochs: return + def set_data_loader(self, new_data_loader): + """ Set the data loader to new value """ + self.data_loader = new_data_loader + class BC: + """ + Behavioral cloning (BC). + + Recovers a policy via supervised learning on observation-action Tensor + pairs, sampled from a Torch DataLoader or any Iterator that ducktypes + `torch.utils.data.DataLoader`. + Args: + observation_space: the observation space of the environment. + action_space: the action space of the environment. + policy_class: used to instantiate imitation policy. + policy_kwargs: keyword arguments passed to policy's constructor. + expert_data: If not None, then immediately call + `self.set_expert_data_loader(expert_data)` during + initialization. + optimizer_cls: optimiser to use for supervised training. + optimizer_kwargs: keyword arguments, excluding learning rate and + weight decay, for optimiser construction. + ent_weight: scaling applied to the policy's entropy regularization. 
+ l2_weight: scaling applied to the policy's L2 regularization. + device: name/identity of device to place policy on. + """ DEFAULT_BATCH_SIZE: int = 32 """ @@ -189,41 +241,23 @@ def __init__( optimizer_kwargs: Optional[Dict[str, Any]] = None, ent_weight: float = 1e-3, l2_weight: float = 0.0, - device: Union[str, th.device] = "auto", + device: Union[str, torch.device] = "auto", ): - """Behavioral cloning (BC). - Recovers a policy via supervised learning on observation-action Tensor - pairs, sampled from a Torch DataLoader or any Iterator that ducktypes - `torch.utils.data.DataLoader`. - Args: - observation_space: the observation space of the environment. - action_space: the action space of the environment. - policy_class: used to instantiate imitation policy. - policy_kwargs: keyword arguments passed to policy's constructor. - expert_data: If not None, then immediately call - `self.set_expert_data_loader(expert_data)` during - initialization. - optimizer_cls: optimiser to use for supervised training. - optimizer_kwargs: keyword arguments, excluding learning rate and - weight decay, for optimiser construction. - ent_weight: scaling applied to the policy's entropy regularization. - l2_weight: scaling applied to the policy's L2 regularization. - device: name/identity of device to place policy on. - """ if optimizer_kwargs: if "weight_decay" in optimizer_kwargs: raise ValueError( - "Use the parameter l2_weight instead of weight_decay.") + "Use the parameter l2_weight instead of weight_decay." + ) self.action_space = action_space self.observation_space = observation_space self.policy_class = policy_class self.device = device = utils.get_device(device) - self.policy_kwargs = dict( - observation_space=self.observation_space, - action_space=self.action_space, - lr_schedule=ConstantLRSchedule(), - ) + self.policy_kwargs = { + "observation_space": self.observation_space, + "action_space": self.action_space, + "lr_schedule": ConstantLRSchedule(), + } self.policy_kwargs.update(policy_kwargs or {}) self.device = utils.get_device(device) @@ -233,7 +267,8 @@ def __init__( optimizer_kwargs = optimizer_kwargs or {} self.optimizer = optimizer_cls( - self.policy.parameters(), **optimizer_kwargs) + self.policy.parameters(), **optimizer_kwargs + ) self.expert_data_loader: Optional[Iterable[Mapping]] = None self.ent_weight = ent_weight @@ -269,9 +304,9 @@ def set_expert_data_loader( def _calculate_loss( self, - obs: Union[th.Tensor, np.ndarray], - acts: Union[th.Tensor, np.ndarray], - ) -> Tuple[th.Tensor, Dict[str, float]]: + obs: Union[torch.Tensor, np.ndarray], + acts: Union[torch.Tensor, np.ndarray], + ) -> Tuple[torch.Tensor, Dict[str, float]]: """ Calculate the supervised learning loss used to train the behavioral clone. @@ -285,15 +320,15 @@ def _calculate_loss( optimize. stats_dict: Statistics about the learning process to be logged. 
""" - obs = th.as_tensor(obs, device=self.device).detach() - acts = th.as_tensor(acts, device=self.device).detach() + obs = torch.as_tensor(obs, device=self.device).detach() + acts = torch.as_tensor(acts, device=self.device).detach() _, log_prob, entropy = self.policy.evaluate_actions(obs, acts) - prob_true_act = th.exp(log_prob).mean() + prob_true_act = log_prob.exp().mean() log_prob = log_prob.mean() entropy = entropy.mean() - l2_norms = [th.sum(th.square(w)) for w in self.policy.parameters()] + l2_norms = [w.square().sum() for w in self.policy.parameters()] # divide by 2 to cancel with gradient of square l2_norm = sum(l2_norms) / 2 @@ -302,15 +337,15 @@ def _calculate_loss( l2_loss = self.l2_weight * l2_norm loss = neglogp + ent_loss + l2_loss - stats_dict = dict( - neglogp=neglogp.item(), - loss=loss.item(), - entropy=entropy.item(), - ent_loss=ent_loss.item(), - prob_true_act=prob_true_act.item(), - l2_norm=l2_norm.item(), - l2_loss=l2_loss.item(), - ) + stats_dict = { + "neglogp": neglogp.item(), + "loss": loss.item(), + "entropy": entropy.item(), + "ent_loss": ent_loss.item(), + "prob_true_act": prob_true_act.item(), + "l2_norm": l2_norm.item(), + "l2_loss": l2_loss.item(), + } return loss, stats_dict @@ -349,7 +384,8 @@ def train( batch_num = 0 for batch, stats_dict_it in it: loss, stats_dict_loss = self._calculate_loss( - batch["obs"], batch["acts"]) + batch["obs"], batch["acts"] + ) self.optimizer.zero_grad() loss.backward() @@ -363,8 +399,8 @@ def train( batch_num += 1 def save_policy(self, policy_path: str) -> None: - """Save policy to a path. Can be reloaded by `.reconstruct_policy()`. + """Save policy to a patorch. Can be reloaded by `.reconstruct_policy()`. Args: policy_path: path to save policy to. """ - th.save(self.policy, policy_path) + torch.save(self.policy, policy_path) diff --git a/src/pantheonrl/algos/modular/learn.py b/src/pantheonrl/algos/modular/learn.py index d734bc4..be989b3 100644 --- a/src/pantheonrl/algos/modular/learn.py +++ b/src/pantheonrl/algos/modular/learn.py @@ -1,23 +1,33 @@ +""" +Implementation of the Modular Algorithm. 
+""" + import time -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Optional, Type, Union + +import warnings import gymnasium as gym import numpy as np -import torch as th +import torch from gymnasium import spaces from torch.nn import functional as F from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.common import logger -from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.policies import ActorCriticPolicy -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import ( + GymEnv, + MaybeCallback, + Schedule, +) from stable_baselines3.common.utils import safe_mean from stable_baselines3.common.vec_env import VecEnv -from stable_baselines3.common.utils import explained_variance, get_schedule_fn +from stable_baselines3.common.utils import get_schedule_fn + + class ModularAlgorithm(OnPolicyAlgorithm): """ @@ -43,18 +53,16 @@ def __init__( sde_sample_freq: int = -1, target_kl: Optional[float] = None, tensorboard_log: Optional[str] = None, - create_eval_env: bool = False, policy_kwargs: Optional[Dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, - device: Union[th.device, str] = "auto", + device: Union[torch.device, str] = "auto", _init_setup_model: bool = True, - # my additional arguments - marginal_reg_coef : float = 0.0, + marginal_reg_coef: float = 0.0, ): - super(ModularAlgorithm, self).__init__( + super().__init__( policy, env, learning_rate=learning_rate, @@ -70,7 +78,6 @@ def __init__( policy_kwargs=policy_kwargs, verbose=verbose, device=device, - create_eval_env=create_eval_env, seed=seed, _init_setup_model=False, supported_action_spaces=( @@ -80,14 +87,15 @@ def __init__( spaces.MultiBinary, ), ) - + self.marginal_reg_coef = marginal_reg_coef # Sanity check, otherwise it will lead to noisy gradient and NaN # because of the advantage normalization assert ( batch_size > 1 - ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440" + ), "`batch_size` must be greater than 1. \ + See https://github.com/DLR-RM/stable-baselines3/issues/440" if self.env is not None: # Check that `n_steps * n_envs > 1` to avoid NaN @@ -95,17 +103,15 @@ def __init__( buffer_size = self.env.num_envs * self.n_steps assert ( buffer_size > 1 - ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" - # Check that the rollout buffer size is a multiple of the mini-batch size - untruncated_batches = buffer_size // batch_size + ), f"`n_steps * n_envs` must be greater than 1. 
\ + Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" + # Check that the rollout buffer size is + # a multiple of the mini-batch size if buffer_size % batch_size > 0: warnings.warn( f"You have specified a mini-batch size of {batch_size}," - f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`," - f" after every {untruncated_batches} untruncated mini-batches," - f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n" - f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n" - f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})" + f" but the `RolloutBuffer` is of size \ + `n_steps * n_envs = {buffer_size}`." ) self.batch_size = batch_size self.n_epochs = n_epochs @@ -113,11 +119,13 @@ def __init__( self.clip_range_vf = clip_range_vf self.target_kl = target_kl + self._last_dones = None + if _init_setup_model: self._setup_model() def _setup_model(self) -> None: - + # OnPolicyAlgorithm's _setup_model self._setup_lr_schedule() self.set_random_seed(self.seed) @@ -127,46 +135,57 @@ def _setup_model(self) -> None: self.action_space, self.lr_schedule, use_sde=self.use_sde, - **self.policy_kwargs # pytype:disable=not-instantiable + **self.policy_kwargs, # pytype:disable=not-instantiable ) self.policy = self.policy.to(self.device) - - buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RolloutBuffer + self.rollout_buffer = [ + RolloutBuffer( + self.n_steps, + self.observation_space, + self.action_space, + self.device, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + n_envs=self.n_envs, + ) + for _ in range(self.policy.num_partners) + ] - self.rollout_buffer = [buffer_cls( - self.n_steps, - self.observation_space, - self.action_space, - self.device, - gamma=self.gamma, - gae_lambda=self.gae_lambda, - n_envs=self.n_envs, - ) for _ in range(self.policy.num_partners)] - # PPO's _setup_model # Initialize schedules for policy/value clipping self.clip_range = get_schedule_fn(self.clip_range) if self.clip_range_vf is not None: if isinstance(self.clip_range_vf, (float, int)): - assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping" + assert self.clip_range_vf > 0, ( + "`clip_range_vf` must be positive, " + "pass `None` to deactivate vf clipping" + ) self.clip_range_vf = get_schedule_fn(self.clip_range_vf) - + def collect_rollouts( - self, env: VecEnv, callback: BaseCallback, rollout_buffer: RolloutBuffer, n_rollout_steps: int, partner_idx: int + self, + env: VecEnv, + callback: BaseCallback, + rollout_buffer: RolloutBuffer, + n_rollout_steps: int, + partner_idx: int = 0, ) -> bool: """ Collect rollouts using the current policy and fill a `RolloutBuffer`. :param env: (VecEnv) The training environment - :param callback: (BaseCallback) Callback that will be called at each step - (and at the beginning and end of the rollout) + :param callback: (BaseCallback) Callback that will be called at each + step (and at the beginning and end of the rollout) :param rollout_buffer: (RolloutBuffer) Buffer to fill with rollouts :param n_steps: (int) Number of experiences to collect per environment - :return: (bool) True if function returned with at least `n_rollout_steps` - collected, False if callback terminated rollout prematurely. + :return: (bool) True if function returned with at least + `n_rollout_steps` collected, False if callback terminated rollout + prematurely. 
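+        :param partner_idx: (int) Index of the partner-specific module (and
+            environment partner) to collect rollouts with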
""" - assert self._last_obs is not None, "No previous observation was provided" + assert ( + self._last_obs is not None + ), "No previous observation was provided" n_steps = 0 rollout_buffer.reset() # Sample new weights for the state dependent exploration @@ -177,22 +196,31 @@ def collect_rollouts( self._last_dones = None while n_steps < n_rollout_steps: - if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: + self.env.envs[0].set_partnerid(partner_idx) + if ( + self.use_sde + and self.sde_sample_freq > 0 + and n_steps % self.sde_sample_freq == 0 + ): # Sample a new noise matrix self.policy.reset_noise(env.num_envs) - with th.no_grad(): + with torch.no_grad(): # Convert to pytorch tensor - obs_tensor = th.as_tensor(self._last_obs).to(self.device) - #actions, values, log_probs = self.policy.forward(obs_tensor) - actions, values, log_probs = self.policy.forward(obs_tensor, partner_idx=partner_idx) + obs_tensor = torch.as_tensor(self._last_obs).to(self.device) + # actions, values, log_probs = self.policy.forward(obs_tensor) + actions, values, log_probs = self.policy.forward( + obs_tensor, partner_idx=partner_idx + ) actions = actions.cpu().numpy() # Rescale and perform action clipped_actions = actions # Clip the actions to avoid out of bound error if isinstance(self.action_space, gym.spaces.Box): - clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + clipped_actions = np.clip( + actions, self.action_space.low, self.action_space.high + ) env.envs[0].set_partnerid(partner_idx) new_obs, rewards, dones, infos = env.step(clipped_actions) @@ -207,7 +235,14 @@ def collect_rollouts( if isinstance(self.action_space, gym.spaces.Discrete): # Reshape in case of discrete action actions = actions.reshape(-1, 1) - rollout_buffer.add(self._last_obs, actions, rewards, self._last_dones, values, log_probs) + rollout_buffer.add( + self._last_obs, + actions, + rewards, + self._last_dones, + values, + log_probs, + ) self._last_obs = new_obs self._last_dones = dones @@ -216,8 +251,7 @@ def collect_rollouts( callback.on_rollout_end() return True - - + def train(self) -> None: """ Update policy using the currently gathered @@ -229,7 +263,9 @@ def train(self) -> None: clip_range = self.clip_range(self._current_progress_remaining) # Optional: clip range for the value function if self.clip_range_vf is not None: - clip_range_vf = self.clip_range_vf(self._current_progress_remaining) + clip_range_vf = self.clip_range_vf( + self._current_progress_remaining + ) entropy_losses, all_kl_divs = [], [] pg_losses, value_losses = [], [] @@ -241,36 +277,48 @@ def train(self) -> None: approx_kl_divs = [] # Do a complete pass on the rollout buffer # for rollout_data in self.rollout_buffer.get(self.batch_size): - for rollout_data in self.rollout_buffer[partner_idx].get(self.batch_size): + for rollout_data in self.rollout_buffer[partner_idx].get( + self.batch_size + ): actions = rollout_data.actions if isinstance(self.action_space, spaces.Discrete): # Convert discrete action from float to long actions = rollout_data.actions.long().flatten() - # Re-sample the noise matrix because the log_std has changed - # TODO: investigate why there is no issue with the gradient + # Re-sample the noise matrix because the log_std changed + # investigate why there is no issue with the gradient # if that line is commented (as in SAC) if self.use_sde: self.policy.reset_noise(self.batch_size) - #values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions) - 
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions, partner_idx=partner_idx) + values, log_prob, entropy = self.policy.evaluate_actions( + rollout_data.observations, + actions, + partner_idx=partner_idx, + ) values = values.flatten() # Normalize advantage advantages = rollout_data.advantages - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + advantages = (advantages - advantages.mean()) / ( + advantages.std() + 1e-8 + ) - # ratio between old and new policy, should be one at the first iteration - ratio = th.exp(log_prob - rollout_data.old_log_prob) + # ratio between old and new policy, should be + # one at the first iteration + ratio = torch.exp(log_prob - rollout_data.old_log_prob) # clipped surrogate loss policy_loss_1 = advantages * ratio - policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) - policy_loss = -th.min(policy_loss_1, policy_loss_2).mean() + policy_loss_2 = advantages * torch.clamp( + ratio, 1 - clip_range, 1 + clip_range + ) + policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean() # Logging pg_losses.append(policy_loss.item()) - clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() + clip_fraction = torch.mean( + (torch.abs(ratio - 1) > clip_range).float() + ).item() clip_fractions.append(clip_fraction) if self.clip_range_vf is None: @@ -279,8 +327,10 @@ def train(self) -> None: else: # Clip the different between old and new value # NOTE: this depends on the reward scaling - values_pred = rollout_data.old_values + th.clamp( - values - rollout_data.old_values, -clip_range_vf, clip_range_vf + values_pred = rollout_data.old_values + torch.clamp( + values - rollout_data.old_values, + -clip_range_vf, + clip_range_vf, ) # Value loss using the TD(gae_lambda) target value_loss = F.mse_loss(rollout_data.returns, values_pred) @@ -291,81 +341,118 @@ def train(self) -> None: # Approximate entropy when no analytical form entropy_loss = log_prob.mean() else: - entropy_loss = -th.mean(entropy) + entropy_loss = -torch.mean(entropy) entropy_losses.append(entropy_loss.item()) ########### # Marginal Regularization ########### - # each action_dist is a Distribution object containing self.batch_size observations - # dist.distribution.probs returns a tensor of shape (self.batch_size, self.action_space) + # each action_dist is a Distribution object containing + # self.batch_size observations dist.distribution.probs + # returns shape (self.batch_size, self.action_space) # dist.sample() returns a tensor of shape (self.batch_size) - # careful: must extract torch distribution object from stable_baseline Distribution object, otherwise old references get overwritten - main_logits, partner_logits = zip( *[self.policy.get_action_logits_from_obs(rollout_data.observations, partner_idx=idx) for idx in range(self.policy.num_partners)] ) - main_logits = th.stack([logits for logits in main_logits]) # (num_partners, self.batch_size, self.action_space) - partner_logits = th.stack([logits for logits in partner_logits]) # (num_partners, self.batch_size, self.action_space) + # careful: must extract torch distribution object from + # stable_baseline Distribution object, otherwise old + # references get overwritten + main_logits, partner_logits = zip( + *[ + self.policy.get_action_logits_from_obs( + rollout_data.observations, partner_idx=idx + ) + for idx in range(self.policy.num_partners) + ] + ) + main_logits = torch.stack( + list(main_logits) + ) # (num_partners, self.batch_size, self.action_space) 
+ partner_logits = torch.stack( + list(partner_logits) + ) # (num_partners, self.batch_size, self.action_space) composed_logits = main_logits + partner_logits # Regularize main prob to be the marginals - # Wasserstein metric with unitary distances (for categorical actions) - main_probs = th.mean( th.exp(main_logits - main_logits.logsumexp(dim=-1, keepdim=True)), dim=0 ) - composed_probs = th.mean( th.exp(composed_logits - composed_logits.logsumexp(dim=-1, keepdim=True)), dim=0 ) - marginal_regularization_loss = th.mean(th.sum( th.abs(main_probs - composed_probs), dim=1)) + # Wasserstein metric with unitary distances + # (for categorical actions) + main_probs = torch.mean( + torch.exp( + main_logits + - main_logits.logsumexp(dim=-1, keepdim=True) + ), + dim=0, + ) + composed_probs = torch.mean( + torch.exp( + composed_logits + - composed_logits.logsumexp(dim=-1, keepdim=True) + ), + dim=0, + ) + marginal_regularization_loss = torch.mean( + torch.sum(torch.abs(main_probs - composed_probs), dim=1) + ) ########### - loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + self.marginal_reg_coef * marginal_regularization_loss + loss = ( + policy_loss + + self.ent_coef * entropy_loss + + self.vf_coef * value_loss + + self.marginal_reg_coef * marginal_regularization_loss + ) # Optimization step self.policy.optimizer.zero_grad() loss.backward() # Clip grad norm - th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + torch.nn.utils.clip_grad_norm_( + self.policy.parameters(), self.max_grad_norm + ) self.policy.optimizer.step() - approx_kl_divs.append(th.mean(rollout_data.old_log_prob - log_prob).detach().cpu().numpy()) + approx_kl_divs.append( + torch.mean(rollout_data.old_log_prob - log_prob) + .detach() + .cpu() + .numpy() + ) all_kl_divs.append(np.mean(approx_kl_divs)) - if self.target_kl is not None and np.mean(approx_kl_divs) > 1.5 * self.target_kl: - print(f"Early stopping at step {epoch} due to reaching max kl: {np.mean(approx_kl_divs):.2f}") + if ( + self.target_kl is not None + and np.mean(approx_kl_divs) > 1.5 * self.target_kl + ): + print( + f"Early stopping at step {epoch} due to reaching \ + max kl: {np.mean(approx_kl_divs):.2f}" + ) break self._n_updates += self.n_epochs - # explained_var = explained_variance(self.rollout_buffer.returns.flatten(), self.rollout_buffer.values.flatten()) - + # Logs self.logger.record("train/entropy_loss", np.mean(entropy_losses)) self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) self.logger.record("train/value_loss", np.mean(value_losses)) - # self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) - # self.logger.record("train/clip_fraction", np.mean(clip_fractions)) - # self.logger.record("train/loss", loss.item()) - # self.logger.record("train/explained_variance", explained_var) - # if hasattr(self.policy, "log_std"): - # self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) - - # self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") - # self.logger.record("train/clip_range", clip_range) - # if self.clip_range_vf is not None: - # self.logger.record("train/clip_range_vf", clip_range_vf) def learn( self, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, - eval_env: Optional[GymEnv] = None, - eval_freq: int = -1, - n_eval_episodes: int = 5, tb_log_name: str = "OnPolicyAlgorithm", - eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, + progress_bar: bool = False ) -> "OnPolicyAlgorithm": 
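+        """
+        Run the modular training loop: collect rollouts for every partner
+        into its own buffer, update the policy, and repeat until
+        `total_timesteps` is reached.
+        """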
iteration = 0 + self.env.envs[0].set_resample_policy("null") total_timesteps, callback = self._setup_learn( - total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name + total_timesteps, + callback, + reset_num_timesteps, + tb_log_name, + progress_bar ) callback.on_training_start(locals(), globals()) @@ -373,28 +460,56 @@ def learn( while self.num_timesteps < total_timesteps: for partner_idx in range(self.policy.num_partners): - try: self.env.envs[0].set_partnerid(partner_idx) - except: - print("unable to switch") - pass - continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer[partner_idx], n_rollout_steps=self.n_steps, partner_idx=partner_idx) + self.env.envs[0].set_partnerid(partner_idx) + continue_training = self.collect_rollouts( + self.env, + callback, + self.rollout_buffer[partner_idx], + n_rollout_steps=self.n_steps, + partner_idx=partner_idx, + ) if continue_training is False: break iteration += 1 - self._update_current_progress_remaining(self.num_timesteps, total_timesteps) + self._update_current_progress_remaining( + self.num_timesteps, total_timesteps + ) # Display training infos if log_interval is not None and iteration % log_interval == 0: fps = int(self.num_timesteps / (time.time() - self.start_time)) - self.logger.record("time/iterations", iteration, exclude="tensorboard") - if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: - self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) - self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) + self.logger.record( + "time/iterations", iteration, exclude="tensorboard" + ) + if ( + len(self.ep_info_buffer) > 0 + and len(self.ep_info_buffer[0]) > 0 + ): + self.logger.record( + "rollout/ep_rew_mean", + safe_mean( + [ep_info["r"] for ep_info in self.ep_info_buffer] + ), + ) + self.logger.record( + "rollout/ep_len_mean", + safe_mean( + [ep_info["l"] for ep_info in self.ep_info_buffer] + ), + ) self.logger.record("time/fps", fps) - self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard") - self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") + self.logger.record( + "time/time_elapsed", + int(time.time() - self.start_time), + exclude="tensorboard", + ) + self.logger.record( + "time/total_timesteps", + self.num_timesteps, + exclude="tensorboard", + ) self.logger.dump(step=self.num_timesteps) self.train() diff --git a/src/pantheonrl/algos/modular/policies.py b/src/pantheonrl/algos/modular/policies.py index 7a85a37..8242de4 100644 --- a/src/pantheonrl/algos/modular/policies.py +++ b/src/pantheonrl/algos/modular/policies.py @@ -1,25 +1,32 @@ -from abc import ABC, abstractmethod -import collections +""" +Implementation of the policy for the ModularAlgorithm +""" +# pylint: disable=locally-disabled, no-value-for-parameter, unexpected-keyword-arg from typing import Union, Type, Dict, List, Tuple, Optional, Any, Callable from functools import partial import gymnasium as gym -import torch as th -import torch.nn as nn +import torch +from torch import nn import numpy as np -from stable_baselines3 import PPO -from stable_baselines3.common.preprocessing import preprocess_obs, is_image_space, get_action_dim -from stable_baselines3.common.torch_layers import (FlattenExtractor, BaseFeaturesExtractor, create_mlp, - NatureCNN, MlpExtractor) -from stable_baselines3.common.utils import 
get_device, is_vectorized_observation -from stable_baselines3.common.vec_env import VecTransposeImage -from stable_baselines3.common.distributions import (make_proba_distribution, Distribution, - DiagGaussianDistribution, CategoricalDistribution, - MultiCategoricalDistribution, BernoulliDistribution, - StateDependentNoiseDistribution) +from stable_baselines3.common.torch_layers import ( + FlattenExtractor, + BaseFeaturesExtractor, + MlpExtractor, +) +from stable_baselines3.common.distributions import ( + make_proba_distribution, + Distribution, + DiagGaussianDistribution, + CategoricalDistribution, + MultiCategoricalDistribution, + BernoulliDistribution, + StateDependentNoiseDistribution, +) from stable_baselines3.common.policies import BasePolicy + class ModularPolicy(BasePolicy): """ Policy class for actor-critic algorithms (has both policy and value prediction). @@ -28,7 +35,7 @@ class ModularPolicy(BasePolicy): :param action_space: (gym.spaces.Space) Action space :param lr_schedule: (Callable) Learning rate schedule (could be constant) :param net_arch: ([int or dict]) The specification of the policy and value networks. - :param device: (str or th.device) Device on which the code should run. + :param device: (str or torch.device) Device on which the code should run. :param activation_fn: (Type[nn.Module]) Activation function :param ortho_init: (bool) Whether to use or not orthogonal initialization :param use_sde: (bool) Whether to use State Dependent Exploration or not @@ -48,78 +55,88 @@ class ModularPolicy(BasePolicy): to pass to the feature extractor. :param normalize_images: (bool) Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, - ``th.optim.Adam`` by default + :param optimizer_class: (Type[torch.optim.Optimizer]) The optimizer to use, + ``torch.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer """ - def __init__(self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - lr_schedule: Callable[[float], float], - net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, - device: Union[th.device, str] = 'auto', - activation_fn: Type[nn.Module] = nn.Tanh, - ortho_init: bool = True, - use_sde: bool = False, - log_std_init: float = 0.0, - full_std: bool = True, - sde_net_arch: Optional[List[int]] = None, - use_expln: bool = False, - squash_output: bool = False, - features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, - normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - - # my additional arguments - num_partners: int = 1, - partner_net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, # net arch for each partner-specific module - baseline: bool = False, - nomain: bool = False, - ): + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Callable[[float], float], + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[ + 
BaseFeaturesExtractor + ] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + # my additional arguments + num_partners: int = 1, + partner_net_arch: Optional[ + List[Union[int, Dict[str, List[int]]]] + ] = None, # net arch for each partner-specific module + baseline: bool = False, + nomain: bool = False, + ): if optimizer_kwargs is None: optimizer_kwargs = {} # Small values to avoid NaN in Adam optimizer - if optimizer_class == th.optim.Adam: - optimizer_kwargs['eps'] = 1e-5 - - super(ModularPolicy, self).__init__(observation_space, - action_space, - features_extractor_class, - features_extractor_kwargs, - optimizer_class=optimizer_class, - optimizer_kwargs=optimizer_kwargs, - squash_output=squash_output) + if optimizer_class == torch.optim.Adam: + optimizer_kwargs["eps"] = 1e-5 + + super().__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=squash_output, + ) self.num_partners = num_partners - print("CUDA: ", th.cuda.is_available()) + print("CUDA: ", torch.cuda.is_available()) if partner_net_arch is None: if features_extractor_class == FlattenExtractor: - partner_net_arch = [dict(pi=[64, 64], vf=[64, 64])] + partner_net_arch = {"pi":[64, 64], "vf":[64, 64]} else: partner_net_arch = [] self.partner_net_arch = partner_net_arch self.baseline = baseline self.nomain = nomain - # Default network architecture, from stable-baselines if net_arch is None: if features_extractor_class == FlattenExtractor: - net_arch = [dict(pi=[64, 64], vf=[64, 64])] + net_arch = {"pi":[64, 64], "vf":[64, 64]} else: net_arch = [] + + self.value_net = None + self.log_std = None + self.action_net = None + self.mlp_extractor = None self.net_arch = net_arch self.activation_fn = activation_fn self.ortho_init = ortho_init - self.features_extractor = features_extractor_class(self.observation_space, - **self.features_extractor_kwargs) + self.features_extractor = features_extractor_class( + self.observation_space, **self.features_extractor_kwargs + ) self.features_dim = self.features_extractor.features_dim self.normalize_images = normalize_images @@ -128,10 +145,10 @@ def __init__(self, # Keyword arguments for gSDE distribution if use_sde: dist_kwargs = { - 'full_std': full_std, - 'squash_output': squash_output, - 'use_expln': use_expln, - 'learn_features': sde_net_arch is not None + "full_std": full_std, + "squash_output": squash_output, + "use_expln": use_expln, + "learn_features": sde_net_arch is not None, } self.sde_features_extractor = None @@ -140,85 +157,86 @@ def __init__(self, self.dist_kwargs = dist_kwargs # Action distribution - self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs) + self.action_dist = make_proba_distribution( + action_space, use_sde=use_sde, dist_kwargs=dist_kwargs + ) self.lr_schedule = lr_schedule self._build(self.lr_schedule) - # freeze / unfreeze the module networks def set_freeze_module(self, module, freeze): + """ freeze / unfreeze the module networks """ for param in module.parameters(): param.requires_grad = not freeze + def set_freeze_main(self, freeze): + """ freeze / unfreeze main modules """ self.set_freeze_module(self.mlp_extractor, freeze) self.set_freeze_module(self.action_net, freeze) 
self.set_freeze_module(self.value_net, freeze) + def set_freeze_partner(self, freeze): + """ freeze / unfreeze partner modules """ for partner_idx in range(self.num_partners): - self.set_freeze_module(self.partner_mlp_extractor[partner_idx], freeze) - self.set_freeze_module(self.partner_action_net[partner_idx], freeze) + self.set_freeze_module( + self.partner_mlp_extractor[partner_idx], freeze + ) + self.set_freeze_module( + self.partner_action_net[partner_idx], freeze + ) self.set_freeze_module(self.partner_value_net[partner_idx], freeze) - def _get_data(self) -> Dict[str, Any]: - data = super()._get_data() - - default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) - - data.update(dict( - net_arch=self.net_arch, - activation_fn=self.activation_fn, - use_sde=self.use_sde, - log_std_init=self.log_std_init, - squash_output=default_none_kwargs['squash_output'], - full_std=default_none_kwargs['full_std'], - sde_net_arch=default_none_kwargs['sde_net_arch'], - use_expln=default_none_kwargs['use_expln'], - lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone - ortho_init=self.ortho_init, - optimizer_class=self.optimizer_class, - optimizer_kwargs=self.optimizer_kwargs, - features_extractor_class=self.features_extractor_class, - features_extractor_kwargs=self.features_extractor_kwargs - )) - return data - - def reset_noise(self, n_envs: int = 1) -> None: - """ - Sample new weights for the exploration matrix. - :param n_envs: (int) - """ - assert isinstance(self.action_dist, - StateDependentNoiseDistribution), 'reset_noise() is only available when using gSDE' - self.action_dist.sample_weights(self.log_std, batch_size=n_envs) - - def make_action_dist_net(self, latent_dim_pi: int, latent_sde_dim: int = 0): + def make_action_dist_net( + self, latent_dim_pi: int, latent_sde_dim: int = 0 + ): + """ Make the action distribution network """ action_net, log_std = None, None if isinstance(self.action_dist, DiagGaussianDistribution): - action_net, log_std = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi, - log_std_init=self.log_std_init) + action_net, log_std = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi, log_std_init=self.log_std_init + ) elif isinstance(self.action_dist, StateDependentNoiseDistribution): - latent_sde_dim = latent_dim_pi if self.sde_net_arch is None else latent_sde_dim - action_net, log_std = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi, - latent_sde_dim=latent_sde_dim, - log_std_init=self.log_std_init) + latent_sde_dim = ( + latent_dim_pi if self.sde_net_arch is None else latent_sde_dim + ) + action_net, log_std = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi, + log_std_init=self.log_std_init, + ) elif isinstance(self.action_dist, CategoricalDistribution): - action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + action_net = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi + ) elif isinstance(self.action_dist, MultiCategoricalDistribution): - action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + action_net = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi + ) elif isinstance(self.action_dist, BernoulliDistribution): - action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + action_net = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi + ) else: - raise NotImplementedError(f"Unsupported distribution 
'{self.action_dist}'.") + raise NotImplementedError( + f"Unsupported distribution '{self.action_dist}'." + ) return action_net, log_std def build_mlp_action_value_net(self, input_dim, net_arch): - mlp_extractor = MlpExtractor(input_dim, net_arch=net_arch, - activation_fn=self.activation_fn, device=self.device) - action_net, log_std = self.make_action_dist_net(mlp_extractor.latent_dim_pi) + """ Build the action and value networks """ + mlp_extractor = MlpExtractor( + input_dim, + net_arch=net_arch, + activation_fn=self.activation_fn, + device=self.device, + ) + action_net, log_std = self.make_action_dist_net( + mlp_extractor.latent_dim_pi + ) value_net = nn.Linear(mlp_extractor.latent_dim_vf, 1) return mlp_extractor, action_net, log_std, value_net def do_init_weights(self, init_main=False, init_partner=False): + """ Initialize the weights """ # Values from stable-baselines. # feature_extractor/mlp values are # originally from openai/baselines (default gains/init_scales). @@ -250,55 +268,92 @@ def _build(self, lr_schedule: Callable[[float], float]) -> None: # net_arch here is an empty list and mlp_extractor does not # really contain any layers (acts like an identity module). - self.mlp_extractor, self.action_net, self.log_std, self.value_net = self.build_mlp_action_value_net(input_dim=self.features_dim, net_arch=self.net_arch) - - partner_builds = [self.build_mlp_action_value_net(input_dim=self.mlp_extractor.latent_dim_pi, net_arch=self.partner_net_arch) for _ in range(self.num_partners)] - if self.baseline: # use the same partner module for all partners + ( + self.mlp_extractor, + self.action_net, + self.log_std, + self.value_net, + ) = self.build_mlp_action_value_net( + input_dim=self.features_dim, net_arch=self.net_arch + ) + + partner_builds = [ + self.build_mlp_action_value_net( + input_dim=self.mlp_extractor.latent_dim_pi, + net_arch=self.partner_net_arch, + ) + for _ in range(self.num_partners) + ] + if self.baseline: # use the same partner module for all partners print("Baseline architecture: using the same partner module.") partner_builds = [partner_builds[0]] * self.num_partners - self.partner_mlp_extractor, self.partner_action_net, self.partner_log_std, self.partner_value_net = zip(*partner_builds) + ( + self.partner_mlp_extractor, + self.partner_action_net, + self.partner_log_std, + self.partner_value_net, + ) = zip(*partner_builds) self.partner_mlp_extractor = nn.ModuleList(self.partner_mlp_extractor) self.partner_action_net = nn.ModuleList(self.partner_action_net) self.partner_value_net = nn.ModuleList(self.partner_value_net) # Setup optimizer with initial learning rate - self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + self.optimizer = self.optimizer_class( + self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs + ) self.do_init_weights(init_main=True, init_partner=True) def overwrite_main(self, other): - self.mlp_extractor, self.action_net, self.log_std, self.value_net = other.mlp_extractor, other.action_net, other.log_std, other.value_net - self.optimizer = self.optimizer_class(self.parameters(), lr=self.lr_schedule(1), **self.optimizer_kwargs) - - def forward(self, obs: th.Tensor, - partner_idx: int, - deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ Overwrite the main weights """ + self.mlp_extractor, self.action_net, self.log_std, self.value_net = ( + other.mlp_extractor, + other.action_net, + other.log_std, + other.value_net, + ) + self.optimizer = self.optimizer_class( + 
self.parameters(), lr=self.lr_schedule(1), **self.optimizer_kwargs + ) + + def forward( + self, obs: torch.Tensor, partner_idx: int, deterministic: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Forward pass in all the networks (actor and critic) - :param obs: (th.Tensor) Observation + :param obs: (torch.Tensor) Observation :param deterministic: (bool) Whether to sample or use deterministic actions - :return: (Tuple[th.Tensor, th.Tensor, th.Tensor]) action, value and log probability of the action + :return: (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) action, value + and log probability of the action """ latent_pi, latent_vf, _ = self._get_latent(obs=obs) - partner_latent_pi, partner_latent_vf = self.partner_mlp_extractor[partner_idx](latent_pi) + partner_latent_pi, partner_latent_vf = self.partner_mlp_extractor[ + partner_idx + ](latent_pi) - distribution = self._get_action_dist_from_latent(latent_pi, partner_latent_pi, partner_idx=partner_idx) + distribution = self._get_action_dist_from_latent( + latent_pi, partner_latent_pi, partner_idx=partner_idx + ) actions = distribution.get_actions(deterministic=deterministic) log_prob = distribution.log_prob(actions) - values = self.value_net(latent_vf) + self.partner_value_net[partner_idx](partner_latent_vf) + values = self.value_net(latent_vf) + self.partner_value_net[ + partner_idx + ](partner_latent_vf) return actions, values, log_prob - def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + def _get_latent( + self, obs: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Get the latent code (i.e., activations of the last layer of each network) for the different networks. - :param obs: (th.Tensor) Observation - :return: (Tuple[th.Tensor, th.Tensor, th.Tensor]) Latent codes + :param obs: (torch.Tensor) Observation + :return: (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) Latent codes for the actor, the value function and for gSDE function """ # Preprocess the observation if needed - features = self.extract_features(obs) + features = self.extract_features(obs, self.features_extractor) latent_pi, latent_vf = self.mlp_extractor(features) # Features for sde @@ -308,89 +363,136 @@ def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: return latent_pi, latent_vf, latent_sde - def _get_action_dist_from_latent(self, latent_pi: th.Tensor, - partner_latent_pi: th.Tensor, - partner_idx: int, - latent_sde: Optional[th.Tensor] = None, - action_mask: Optional[th.Tensor] = None) -> Distribution: + def _get_action_dist_from_latent( + self, + latent_pi: torch.Tensor, + partner_latent_pi: torch.Tensor, + partner_idx: int, + latent_sde: Optional[torch.Tensor] = None, + action_mask: Optional[torch.Tensor] = None, + ) -> Distribution: """ Retrieve action distribution given the latent codes. 
- :param latent_pi: (th.Tensor) Latent code for the actor - :param latent_sde: (Optional[th.Tensor]) Latent code for the gSDE exploration function + :param latent_pi: (torch.Tensor) Latent code for the actor + :param latent_sde: (Optional[torch.Tensor]) Latent code for the gSDE exploration function :return: (Distribution) Action distribution """ main_logits = self.action_net(latent_pi) - partner_logits = self.partner_action_net[partner_idx](partner_latent_pi) + partner_logits = self.partner_action_net[partner_idx]( + partner_latent_pi + ) if self.nomain: mean_actions = partner_logits else: mean_actions = main_logits + partner_logits - + large_exponent = 30 if action_mask is not None: action_mask = action_mask.to(mean_actions.device) - mean_actions = mean_actions - large_exponent*(~action_mask) - th.clamp(mean_actions, min=-1*large_exponent) + mean_actions = mean_actions - large_exponent * (~action_mask) + torch.clamp(mean_actions, min=-1 * large_exponent) if isinstance(self.action_dist, DiagGaussianDistribution): log_std = self.log_std + self.partner_log_std[partner_idx] return self.action_dist.proba_distribution(mean_actions, log_std) - elif isinstance(self.action_dist, CategoricalDistribution): + if isinstance(self.action_dist, CategoricalDistribution): # Here mean_actions are the logits before the softmax - return self.action_dist.proba_distribution(action_logits=mean_actions) - elif isinstance(self.action_dist, MultiCategoricalDistribution): + return self.action_dist.proba_distribution( + mean_actions + ) + if isinstance(self.action_dist, MultiCategoricalDistribution): # Here mean_actions are the flattened logits - return self.action_dist.proba_distribution(action_logits=mean_actions) - elif isinstance(self.action_dist, BernoulliDistribution): + return self.action_dist.proba_distribution( + mean_actions + ) + if isinstance(self.action_dist, BernoulliDistribution): # Here mean_actions are the logits (before rounding to get the binary actions) - return self.action_dist.proba_distribution(action_logits=mean_actions) - elif isinstance(self.action_dist, StateDependentNoiseDistribution): + return self.action_dist.proba_distribution( + mean_actions + ) + if isinstance(self.action_dist, StateDependentNoiseDistribution): log_std = self.log_std + self.partner_log_std[partner_idx] - return self.action_dist.proba_distribution(mean_actions, log_std, latent_sde) - else: - raise ValueError('Invalid action distribution') - - def _predict(self, observation: th.Tensor, partner_idx: int, deterministic: bool = False) -> th.Tensor: + return self.action_dist.proba_distribution( + mean_actions=mean_actions, + log_std=log_std, + latent_sde=latent_sde + ) + raise ValueError("Invalid action distribution") + + def _predict( + self, + observation: torch.Tensor, + deterministic: bool = False, + partner_idx: int = 0, + ) -> torch.Tensor: """ Get the action according to the policy for a given observation. 
- :param observation: (th.Tensor) + :param observation: (torch.Tensor) :param deterministic: (bool) Whether to use stochastic or deterministic actions - :return: (th.Tensor) Taken action according to the policy + :return: (torch.Tensor) Taken action according to the policy """ - actions, _, _ = self.forward(obs=observation, partner_idx=partner_idx, deterministic=deterministic) + actions, _, _ = self.forward( + obs=observation, + partner_idx=partner_idx, + deterministic=deterministic, + ) return actions - def evaluate_actions(self, obs: th.Tensor, - actions: th.Tensor, - partner_idx: int, - action_mask: Optional[th.Tensor] = None) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + def evaluate_actions( + self, + obs: torch.Tensor, + actions: torch.Tensor, + partner_idx: int, + action_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Evaluate actions according to the current policy, given the observations. - :param obs: (th.Tensor) - :param actions: (th.Tensor) - :return: (th.Tensor, th.Tensor, th.Tensor) estimated value, log likelihood of taking those actions - and entropy of the action distribution. + :param obs: (torch.Tensor) + :param actions: (torch.Tensor) + :return: (torch.Tensor, torch.Tensor, torch.Tensor) estimated value, log likelihood of + taking those actions and entropy of the action distribution. """ latent_pi, latent_vf, _ = self._get_latent(obs=obs) - partner_latent_pi, partner_latent_vf = self.partner_mlp_extractor[partner_idx](latent_pi) - - distribution = self._get_action_dist_from_latent(latent_pi, partner_latent_pi, partner_idx=partner_idx, action_mask=action_mask) + partner_latent_pi, partner_latent_vf = self.partner_mlp_extractor[ + partner_idx + ](latent_pi) + + distribution = self._get_action_dist_from_latent( + latent_pi, + partner_latent_pi, + partner_idx=partner_idx, + action_mask=action_mask, + ) log_prob = distribution.log_prob(actions) - values = self.value_net(latent_vf) + self.partner_value_net[partner_idx](partner_latent_vf) + values = self.value_net(latent_vf) + self.partner_value_net[ + partner_idx + ](partner_latent_vf) return values, log_prob, distribution.entropy() - def get_action_logits_from_obs(self, obs: th.Tensor, partner_idx: int, action_mask: Optional[th.Tensor] = None) -> th.Tensor: + def get_action_logits_from_obs( + self, + obs: torch.Tensor, + partner_idx: int, + action_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ Get the action logits from the observation """ latent_pi, _, _ = self._get_latent(obs=obs) - partner_latent_pi, _ = self.partner_mlp_extractor[partner_idx](latent_pi) + partner_latent_pi, _ = self.partner_mlp_extractor[partner_idx]( + latent_pi + ) main_logits = self.action_net(latent_pi) - partner_logits = self.partner_action_net[partner_idx](partner_latent_pi) - - if action_mask: - main_logits = main_logits * action_mask # set masked out options to 0 + partner_logits = self.partner_action_net[partner_idx]( + partner_latent_pi + ) + + if action_mask: + main_logits = ( + main_logits * action_mask + ) # set masked out options to 0 partner_logits = partner_logits * action_mask return main_logits, partner_logits diff --git a/src/pantheonrl/common/multiagentenv.py b/src/pantheonrl/common/multiagentenv.py index 924be08..fd40f4f 100644 --- a/src/pantheonrl/common/multiagentenv.py +++ b/src/pantheonrl/common/multiagentenv.py @@ -341,6 +341,9 @@ def resample_random(self) -> None: self.np_random.integers(0, len(plist)) for plist in self.partners ] + def resample_null(self) -> None: 
+ """Do not resample each partner policy""" + def resample_round_robin(self) -> None: """ Sets the partner policy to the next option on the list for round-robin @@ -355,7 +358,7 @@ def set_resample_policy(self, resample_policy: str) -> None: Set the resample_partner method to round "robin" or "random" :param resample_policy: The new resampling policy to use. - Valid values are: "default", "robin", "random" + Valid values are: "default", "robin", "random", or "null" """ if resample_policy == "default": resample_policy = "robin" if self.n_players == 2 else "random" @@ -369,6 +372,8 @@ def set_resample_policy(self, resample_policy: str) -> None: self.resample_partner = self.resample_round_robin elif resample_policy == "random": self.resample_partner = self.resample_random + elif resample_policy == "null": + self.resample_partner = self.resample_null else: raise PlayerException( f"Invalid resampling policy: {resample_policy}" diff --git a/src/pantheonrl/common/trajsaver.py b/src/pantheonrl/common/trajsaver.py index edc3377..8cf6951 100644 --- a/src/pantheonrl/common/trajsaver.py +++ b/src/pantheonrl/common/trajsaver.py @@ -28,7 +28,7 @@ def transitions_collate_fn( Use this as the `collate_fn` argument to `DataLoader` if using an instance of `TransitionsMinimal` as the `dataset` argument. """ - batch_no_infos = [sample.items() for sample in batch] + batch_no_infos = list(batch) result = default_collate(batch_no_infos) assert isinstance(result, dict) diff --git a/tests/test_adap.py b/tests/test_adap.py new file mode 100644 index 0000000..beadec6 --- /dev/null +++ b/tests/test_adap.py @@ -0,0 +1,43 @@ +import pytest + +from stable_baselines3 import PPO + +import gymnasium as gym + +import overcookedgym + +from pantheonrl.algos.adap.adap_learn import ADAP +from pantheonrl.algos.adap.policies import AdapPolicy, AdapPolicyMult +from pantheonrl.algos.adap.agent import AdapAgent + + +def make_env(option): + if option == 0: + env = gym.make('OvercookedMultiEnv-v0', layout_name='simple') + elif option == 1: + env = gym.make('RPS-v0') + elif option == 2: + env = gym.make('LiarsDice-v0') + env.np_random, _ = gym.utils.seeding.np_random(0) + return env + + +def run_standard(ALGO, timesteps, option, n_steps): + env = make_env(option) + ego = ALGO(AdapPolicy, env, n_steps=n_steps, verbose=0) + env.unwrapped.ego_ind = 0 + partner = AdapAgent(ALGO(AdapPolicy, env, n_steps=n_steps, verbose=0), latent_syncer=ego) + env.unwrapped.add_partner_agent(partner) + + ego.learn(total_timesteps=timesteps) + +@pytest.mark.timeout(60) +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +@pytest.mark.filterwarnings("ignore::UserWarning") +@pytest.mark.parametrize("ALGO", [ADAP]) +@pytest.mark.parametrize("epochs", [20]) +@pytest.mark.parametrize("option", [0]) +@pytest.mark.parametrize("n_steps", [40]) +def test_onpolicy(ALGO, epochs, option, n_steps): + run_standard(ALGO, n_steps * epochs, option, n_steps) + diff --git a/tests/test_bc.py b/tests/test_bc.py new file mode 100644 index 0000000..d6c6b6d --- /dev/null +++ b/tests/test_bc.py @@ -0,0 +1,71 @@ +import pytest + +import gymnasium as gym +from stable_baselines3 import PPO + +import overcookedgym + +from pantheonrl import OnPolicyAgent, StaticPolicyAgent +from pantheonrl.common.agents import RecordingAgentWrapper +from pantheonrl.algos.bc import BC + + +def make_env(option): + if option == 0: + env = gym.make('OvercookedMultiEnv-v0', layout_name='simple') + elif option == 1: + env = gym.make('RPS-v0') + elif option == 2: + env = gym.make('LiarsDice-v0') + elif 
option == 3: + env = gym.make('BlockEnv-v0') + elif option == 4: + env = gym.make('BlockEnv-v1') + env.np_random, _ = gym.utils.seeding.np_random(0) + return env + + +def run_standard(ALGO, timesteps, option, n_steps): + env = make_env(option) + ego = ALGO('MlpPolicy', env, n_steps=n_steps, verbose=0) + partner = RecordingAgentWrapper(OnPolicyAgent(PPO('MlpPolicy', env.unwrapped.get_dummy_env(1), verbose=0, n_steps=64))) + env.unwrapped.ego_ind = 0 + env.unwrapped.add_partner_agent(partner) + + ego.learn(total_timesteps=timesteps) + return ego.policy, partner.get_transitions() + + +def do_bc(option, data): + full_env = make_env(option) + env = full_env.unwrapped.get_dummy_env(1) + clone = BC(observation_space=env.observation_space, + action_space=env.action_space, + expert_data=data, + l2_weight=0.2) + + clone.train(n_epochs=10) + return clone + + +def do_test_standard(ALGO, timesteps, option, n_steps, clone): + env = make_env(option) + ego = ALGO('MlpPolicy', env, n_steps=n_steps, verbose=0) + partner = StaticPolicyAgent(clone.policy) + env.unwrapped.ego_ind = 0 + env.unwrapped.add_partner_agent(partner) + + ego.learn(total_timesteps=timesteps) + + +@pytest.mark.timeout(60) +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +@pytest.mark.filterwarnings("ignore::UserWarning") +@pytest.mark.parametrize("ALGO", [PPO]) +@pytest.mark.parametrize("epochs", [1]) +@pytest.mark.parametrize("option", [0, 1, 2, 3, 4]) +@pytest.mark.parametrize("n_steps", [400]) +def test_onpolicy(ALGO, epochs, option, n_steps): + model1, rb1 = run_standard(ALGO, n_steps * epochs, option, n_steps) + clone = do_bc(option, rb1) + do_test_standard(ALGO, n_steps * epochs, option, n_steps, clone) diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index fbc0120..9cce43a 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -51,3 +51,4 @@ def test_DQN(env_name): ego.learn(total_timesteps=1000) except Exception as e: assert False, f"Exception raised on {env_name}: {e}" + diff --git a/tests/test_modular.py b/tests/test_modular.py new file mode 100644 index 0000000..2e7bb83 --- /dev/null +++ b/tests/test_modular.py @@ -0,0 +1,72 @@ +import pytest + +from stable_baselines3 import PPO + +import gymnasium as gym + +import overcookedgym + +from pantheonrl.algos.modular.learn import ModularAlgorithm +from pantheonrl.algos.modular.policies import ModularPolicy +from pantheonrl import Agent, OnPolicyAgent +from pantheonrl.algos.bc import ConstantLRSchedule + +class CounterAgent(Agent): + def __init__(self, agent, idx): + self.agent = agent + self.steps = 0 + self.idx = idx + + def get_action(self, obs): + self.steps += 1 + toreturn = self.agent.get_action(obs) + # print("ACTION is", toreturn) + return toreturn + + def update(self, reward, done): + self.agent.update(reward, done) + +def make_env(option): + if option == 0: + env = gym.make('OvercookedMultiEnv-v0', layout_name='simple') + elif option == 1: + env = gym.make('RPS-v0') + elif option == 2: + env = gym.make('LiarsDice-v0') + elif option == 3: + env = gym.make('BlockEnv-v0') + elif option == 4: + env = gym.make('BlockEnv-v1') + env.np_random, _ = gym.utils.seeding.np_random(0) + env.unwrapped.set_resample_policy("null") + return env + + +def run_standard(ALGO, timesteps, option, n_steps): + env = make_env(option) + pkwargs = {"num_partners":8} + ego = ModularAlgorithm(ModularPolicy, env, n_steps=n_steps, verbose=0, policy_kwargs=pkwargs) + env.unwrapped.ego_ind = 0 + for i in range(12): + partner = 
CounterAgent(OnPolicyAgent(ALGO('MlpPolicy', env.unwrapped.get_dummy_env(1), verbose=0, n_steps=64)), i) + env.unwrapped.add_partner_agent(partner) + + ego.learn(total_timesteps=timesteps) + + print([env.unwrapped.partners[0][i].steps for i in range(12)]) + for i in range(12): + if i < 8: + assert env.unwrapped.partners[0][i].steps > 0 + else: + assert env.unwrapped.partners[0][i].steps == 0 + +@pytest.mark.timeout(60) +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +@pytest.mark.filterwarnings("ignore::UserWarning") +@pytest.mark.parametrize("ALGO", [PPO]) +@pytest.mark.parametrize("epochs", [20]) +@pytest.mark.parametrize("option", [0, 1, 2, 3, 4]) +@pytest.mark.parametrize("n_steps", [40]) +def test_onpolicy(ALGO, epochs, option, n_steps): + run_standard(ALGO, n_steps * epochs, option, n_steps) + diff --git a/tests/test_overcooked.py b/tests/test_overcooked.py index 1e9e197..08b3c1d 100644 --- a/tests/test_overcooked.py +++ b/tests/test_overcooked.py @@ -56,3 +56,4 @@ def test_DQN(env_name): ego.learn(total_timesteps=1000) except Exception as e: assert False, f"Exception raised on {env_name}: {e}" + diff --git a/tests/test_pettingzoo.py b/tests/test_pettingzoo.py index 5a73d38..13023a4 100644 --- a/tests/test_pettingzoo.py +++ b/tests/test_pettingzoo.py @@ -60,3 +60,4 @@ def test_PPO(option): ego.learn(total_timesteps=128) except Exception as e: assert False, f"Exception raised on {option}: {e}" + diff --git a/tests/test_sarl_reproducibility.py b/tests/test_sarl_reproducibility.py index d39ae6a..1a29e41 100644 --- a/tests/test_sarl_reproducibility.py +++ b/tests/test_sarl_reproducibility.py @@ -233,11 +233,12 @@ def learn_thread(): @pytest.mark.parametrize("option", [0, 1]) @pytest.mark.parametrize("n_steps", [10, 100, 1000]) def test_dqn(ALGO, epochs, option, n_steps): + init_count = threading.active_count() model1 = run_standard_dqn(ALGO, n_steps * epochs, option, n_steps) model2 = run_reversed_dqn(ALGO, n_steps * epochs, option, n_steps) assert check_equivalent_models(model1, model2), "NOT IDENTICAL MODELS" - assert threading.active_count() == 1, "DID NOT KILL THREADS" + assert threading.active_count() == init_count, "DID NOT KILL THREADS" @pytest.mark.timeout(60) @@ -248,11 +249,12 @@ def test_dqn(ALGO, epochs, option, n_steps): @pytest.mark.parametrize("option", [0, 1, 2, 3, 4]) @pytest.mark.parametrize("n_steps", [10, 100, 1000]) def test_sarl(ALGO, epochs, option, n_steps): + init_count = threading.active_count() model1, rb1 = run_standard(ALGO, n_steps * epochs, option, n_steps) model2, rb2 = run_reversed(ALGO, n_steps * epochs, option, n_steps) assert check_equivalent_models(model1, model2), "NOT IDENTICAL MODELS" - assert threading.active_count() == 1, "DID NOT KILL THREADS" + assert threading.active_count() == init_count, "DID NOT KILL THREADS" # def printifdiff(r1, r2, val):
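The pieces added above are exercised together in tests/test_modular.py; as a
rough usage sketch (the layout name, partner count, and timestep values below
are illustrative, not prescribed by the patch):

    import gymnasium as gym
    from stable_baselines3 import PPO
    import overcookedgym  # registers OvercookedMultiEnv-v0
    from pantheonrl import OnPolicyAgent
    from pantheonrl.algos.modular.learn import ModularAlgorithm
    from pantheonrl.algos.modular.policies import ModularPolicy

    env = gym.make('OvercookedMultiEnv-v0', layout_name='simple')
    env.unwrapped.ego_ind = 0
    # "null" keeps the current partner selection instead of resampling
    env.unwrapped.set_resample_policy("null")

    ego = ModularAlgorithm(ModularPolicy, env, n_steps=40, verbose=0,
                           policy_kwargs={"num_partners": 2})
    for _ in range(2):
        partner = OnPolicyAgent(
            PPO('MlpPolicy', env.unwrapped.get_dummy_env(1),
                verbose=0, n_steps=64))
        env.unwrapped.add_partner_agent(partner)

    ego.learn(total_timesteps=800)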