From 6612eafbd86b761c1296ad276f79288db1198897 Mon Sep 17 00:00:00 2001 From: Denys Makoviichuk Date: Sat, 30 Nov 2024 13:12:32 -0800 Subject: [PATCH] Added a few features to make rl_games more comparable to the brax with jax ppo (#314) * Updated * cleanup and simplenet --------- Co-authored-by: Denys Makoviichuk --- rl_games/algos_torch/model_builder.py | 2 + rl_games/algos_torch/models.py | 131 ++++++++++++++++++++++++++ rl_games/common/a2c_common.py | 26 +++-- rl_games/common/datasets.py | 101 +++++++++++--------- rl_games/common/player.py | 101 ++++++++++++++------ rl_games/envs/__init__.py | 3 +- rl_games/envs/test_network.py | 47 +++++++++ 7 files changed, 331 insertions(+), 80 deletions(-) diff --git a/rl_games/algos_torch/model_builder.py b/rl_games/algos_torch/model_builder.py index c2045c5e..359bdda9 100644 --- a/rl_games/algos_torch/model_builder.py +++ b/rl_games/algos_torch/model_builder.py @@ -46,6 +46,8 @@ def __init__(self): lambda network, **kwargs: models.ModelSACContinuous(network)) self.model_factory.register_builder('central_value', lambda network, **kwargs: models.ModelCentralValue(network)) + self.model_factory.register_builder('continuous_a2c_tanh', + lambda network, **kwargs: models.ModelA2CContinuousTanh(network)) self.network_builder = NetworkBuilder() def get_network_builder(self): diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py index 93d4001d..e57ff4d1 100644 --- a/rl_games/algos_torch/models.py +++ b/rl_games/algos_torch/models.py @@ -9,6 +9,8 @@ from rl_games.algos_torch.sac_helper import SquashedNormal from rl_games.algos_torch.running_mean_std import RunningMeanStd, RunningMeanStdObs from rl_games.algos_torch.moving_mean_std import GeneralizedMovingStats +from torch.distributions import Normal, TransformedDistribution, TanhTransform +import math class BaseModel(): def __init__(self, model_class): @@ -251,6 +253,7 @@ def forward(self, input_dict): return result + class ModelA2CContinuousLogStd(BaseModel): def __init__(self, network): BaseModel.__init__(self, 'a2c') @@ -310,6 +313,59 @@ def neglogp(self, x, mean, std, logstd): + 0.5 * np.log(2.0 * np.pi) * x.size()[-1] \ + logstd.sum(dim=-1) +class ModelA2CContinuousTanh(BaseModel): + def __init__(self, network): + BaseModel.__init__(self, 'a2c') + self.network_builder = network + + class Network(BaseModelNetwork): + def __init__(self, a2c_network, **kwargs): + BaseModelNetwork.__init__(self, **kwargs) + self.a2c_network = a2c_network + def get_aux_loss(self): + return self.a2c_network.get_aux_loss() + + def is_rnn(self): + return self.a2c_network.is_rnn() + + def get_value_layer(self): + return self.a2c_network.get_value_layer() + + def get_default_rnn_state(self): + return self.a2c_network.get_default_rnn_state() + + def forward(self, input_dict): + is_train = input_dict.get('is_train', True) + prev_actions = input_dict.get('prev_actions', None) + input_dict['obs'] = self.norm_obs(input_dict['obs']) + mu, logstd, value, states = self.a2c_network(input_dict) + sigma = torch.nn.functional.softplus(logstd + 0.001) + main_distr = NormalTanhDistribution(mu.size(-1)) + if is_train: + entropy = main_distr.entropy(mu, logstd) + prev_neglogp = -main_distr.log_prob(mu, logstd, main_distr.inverse_post_process(prev_actions)) + result = { + 'prev_neglogp' : torch.squeeze(prev_neglogp), + 'values' : value, + 'entropy' : entropy, + 'rnn_states' : states, + 'mus' : mu, + 'sigmas' : sigma + } + return result + else: + selected_action = main_distr.sample_no_postprocessing(mu, logstd) + neglogp = -main_distr.log_prob(mu, logstd, selected_action) + result = { + 'neglogpacs' : torch.squeeze(neglogp), + 'values' : self.denorm_value(value), + 'actions' : main_distr.post_process(selected_action), + 'rnn_states' : states, + 'mus' : mu, + 'sigmas' : sigma + } + return result + class ModelCentralValue(BaseModel): def __init__(self, network): @@ -385,4 +441,79 @@ def forward(self, input_dict): return dist +class TanhBijector: + """Tanh Bijector.""" + + def forward(self, x): + return torch.tanh(x) + + def inverse(self, y): + y = torch.clamp(y, -0.99999997, 0.99999997) + return 0.5 * (y.log1p() - (-y).log1p()) + + def forward_log_det_jacobian(self, x): + # Log of the absolute value of the determinant of the Jacobian + return 2. * (math.log(2.) - x - F.softplus(-2. * x)) + +class NormalTanhDistribution: + """Normal distribution followed by tanh.""" + def __init__(self, event_size, min_std=0.001, var_scale=1.0): + """Initialize the distribution. + + Args: + event_size (int): The size of events (i.e., actions). + min_std (float): Minimum standard deviation for the Gaussian. + var_scale (float): Scaling factor for the Gaussian's scale parameter. + """ + self.param_size = event_size + self._min_std = min_std + self._var_scale = var_scale + self._event_ndims = 1 # Rank of events + self._postprocessor = TanhBijector() + + def create_dist(self, loc, scale): + scale = (F.softplus(scale) + self._min_std) * self._var_scale + return torch.distributions.Normal(loc=loc, scale=scale) + + def sample_no_postprocessing(self, loc, scale): + dist = self.create_dist(loc, scale) + return dist.rsample() + + def sample(self, loc, scale): + """Returns a sample from the postprocessed distribution.""" + pre_tanh_sample = self.sample_no_postprocessing(loc, scale) + return self._postprocessor.forward(pre_tanh_sample) + + def post_process(self, pre_tanh_sample): + """Returns a postprocessed sample.""" + return self._postprocessor.forward(pre_tanh_sample) + + def inverse_post_process(self, post_tanh_sample): + """Returns a postprocessed sample.""" + return self._postprocessor.inverse(post_tanh_sample) + + def mode(self, loc, scale): + """Returns the mode of the postprocessed distribution.""" + dist = self.create_dist(loc, scale) + pre_tanh_mode = dist.mean # Mode of a normal distribution is its mean + return self._postprocessor.forward(pre_tanh_mode) + + def log_prob(self, loc, scale, actions): + """Compute the log probability of actions.""" + dist = self.create_dist(loc, scale) + log_probs = dist.log_prob(actions) + log_probs -= self._postprocessor.forward_log_det_jacobian(actions) + if self._event_ndims == 1: + log_probs = log_probs.sum(dim=-1) # Sum over action dimension + return log_probs + + def entropy(self, loc, scale): + """Return the entropy of the given distribution.""" + dist = self.create_dist(loc, scale) + entropy = dist.entropy() + sample = dist.rsample() + entropy += self._postprocessor.forward_log_det_jacobian(sample) + if self._event_ndims == 1: + entropy = entropy.sum(dim=-1) + return entropy diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 54a5cda1..f728045e 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -158,6 +158,7 @@ def __init__(self, base_name, params): self.save_freq = config.get('save_frequency', 0) self.save_best_after = config.get('save_best_after', 100) self.print_stats = config.get('print_stats', True) + self.epochs_between_resets = config.get('epochs_between_resets', 0) self.rnn_states = None self.name = base_name @@ -382,6 +383,12 @@ def set_eval(self): self.model.eval() if self.normalize_rms_advantage: self.advantage_mean_std.eval() + if self.epochs_between_resets > 0: + if self.epoch_num % self.epochs_between_resets == 0: + self.reset_envs() + self.init_current_rewards() + print(f"Forcing env reset after {self.epoch_num} epochs") + def set_train(self): self.model.train() @@ -466,10 +473,7 @@ def init_tensors(self): val_shape = (self.horizon_length, batch_size, self.value_size) current_rewards_shape = (batch_size, self.value_size) - self.current_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device) - self.current_shaped_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device) - self.current_lengths = torch.zeros(batch_size, dtype=torch.float32, device=self.ppo_device) - self.dones = torch.ones((batch_size,), dtype=torch.uint8, device=self.ppo_device) + self.init_current_rewards(batch_size, current_rewards_shape) if self.is_rnn: self.rnn_states = self.model.get_default_rnn_state() @@ -480,6 +484,12 @@ def init_tensors(self): assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0) self.mb_rnn_states = [torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states] + def init_current_rewards(self, batch_size, current_rewards_shape): + self.current_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device) + self.current_shaped_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device) + self.current_lengths = torch.zeros(batch_size, dtype=torch.float32, device=self.ppo_device) + self.dones = torch.ones((batch_size,), dtype=torch.uint8, device=self.ppo_device) + def init_rnn_from_model(self, model): self.is_rnn = self.model.is_rnn() @@ -571,12 +581,12 @@ def discount_values_masks(self, fdones, last_extrinsic_values, mb_fdones, mb_ext mb_advs[t] = lastgaelam = (delta + self.gamma * self.tau * nextnonterminal * lastgaelam) * masks_t return mb_advs - def clear_stats(self): - batch_size = self.num_agents * self.num_actors + def clear_stats(self, clean_rewards= True): self.game_rewards.clear() self.game_shaped_rewards.clear() self.game_lengths.clear() - self.mean_rewards = self.last_mean_rewards = -100500 + if clean_rewards: + self.mean_rewards = self.last_mean_rewards = -100500 self.algo_observer.after_clear_stats() def update_epoch(self): @@ -772,6 +782,7 @@ def play_steps(self): self.current_rewards += rewards self.current_shaped_rewards += shaped_rewards self.current_lengths += 1 + all_done_indices = self.dones.nonzero(as_tuple=False) env_done_indices = all_done_indices[::self.num_agents] @@ -947,6 +958,7 @@ def train_epoch(self): for mini_ep in range(0, self.mini_epochs_num): ep_kls = [] + self.dataset.apply_permutation() for i in range(len(self.dataset)): a_loss, c_loss, entropy, kl, last_lr, lr_mul = self.train_actor_critic(self.dataset[i]) a_losses.append(a_loss) diff --git a/rl_games/common/datasets.py b/rl_games/common/datasets.py index 3a48f3cf..e8bd62ae 100644 --- a/rl_games/common/datasets.py +++ b/rl_games/common/datasets.py @@ -3,9 +3,16 @@ from torch.utils.data import Dataset -class PPODataset(Dataset): +import torch +from torch.utils.data import Dataset +import random - def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_length): +class PPODataset(Dataset): + def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_length, permute=False): + if batch_size % minibatch_size != 0: + raise ValueError("Batch size must be divisible by minibatch size.") + if batch_size % seq_length != 0: + raise ValueError("Batch size must be divisible by sequence length.") self.is_rnn = is_rnn self.seq_length = seq_length @@ -15,70 +22,76 @@ def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_ self.length = self.batch_size // self.minibatch_size self.is_discrete = is_discrete self.is_continuous = not is_discrete - total_games = self.batch_size // self.seq_length self.num_games_batch = self.minibatch_size // self.seq_length - self.game_indexes = torch.arange(total_games, dtype=torch.long, device=self.device) - self.flat_indexes = torch.arange(total_games * self.seq_length, dtype=torch.long, device=self.device).reshape(total_games, self.seq_length) - + self.special_names = ['rnn_states'] + self.permute = permute + if self.permute: + self.permutation_indices = torch.arange(self.batch_size, dtype=torch.long, device=self.device) def update_values_dict(self, values_dict): - self.values_dict = values_dict - - def update_mu_sigma(self, mu, sigma): - start = self.last_range[0] - end = self.last_range[1] - self.values_dict['mu'][start:end] = mu - self.values_dict['sigma'][start:end] = sigma + """Update the internal values dictionary.""" + self.values_dict = values_dict + + def update_mu_sigma(self, mu, sigma): + """Update the mu and sigma values in the dataset.""" + start, end = self.last_range + # Ensure the permutation does not break the logic for updating. + #if self.permute: + # original_indices = self.permutation_indices[start:end] + # self.values_dict['mu'][original_indices] = mu + # self.values_dict['sigma'][original_indices] = sigma + #else: + self.values_dict['mu'][start:end] = mu + self.values_dict['sigma'][start:end] = sigma + + def apply_permutation(self): + """Permute the dataset indices if the permutation flag is enabled.""" + if self.permute and not self.is_rnn: + self.permutation_indices = torch.randperm(self.batch_size, device=self.device, dtype=torch.long) + for key, value in self.values_dict.items(): + if key not in self.special_names and value is not None: + if isinstance(value, dict): + for k, v in value.items(): + self.values_dict[key][k] = v[self.permutation_indices] + else: + self.values_dict[key] = value[self.permutation_indices] - def __len__(self): - return self.length + def _slice_data(self, data, start, end): + """Slice data from start to end, handling dictionaries.""" + if isinstance(data, dict): + return {k: v[start:end] for k, v in data.items()} + return data[start:end] if data is not None else None def _get_item_rnn(self, idx): + """Retrieve a batch of data for RNN training.""" gstart = idx * self.num_games_batch gend = (idx + 1) * self.num_games_batch start = gstart * self.seq_length end = gend * self.seq_length self.last_range = (start, end) - input_dict = {} - for k,v in self.values_dict.items(): - if k not in self.special_names: - if isinstance(v, dict): - v_dict = {kd:vd[start:end] for kd, vd in v.items()} - input_dict[k] = v_dict - else: - if v is not None: - input_dict[k] = v[start:end] - else: - input_dict[k] = None - - rnn_states = self.values_dict['rnn_states'] - input_dict['rnn_states'] = [s[:, gstart:gend, :].contiguous() for s in rnn_states] - + input_dict = {k: self._slice_data(v, start, end) for k, v in self.values_dict.items() if k not in self.special_names} + input_dict['rnn_states'] = [s[:, gstart:gend, :].contiguous() for s in self.values_dict['rnn_states']] return input_dict def _get_item(self, idx): + """Retrieve a minibatch of data.""" start = idx * self.minibatch_size end = (idx + 1) * self.minibatch_size self.last_range = (start, end) - input_dict = {} - for k,v in self.values_dict.items(): - if k not in self.special_names and v is not None: - if type(v) is dict: - v_dict = { kd:vd[start:end] for kd, vd in v.items() } - input_dict[k] = v_dict - else: - input_dict[k] = v[start:end] - + + input_dict = {k: self._slice_data(v, start, end) for k, v in self.values_dict.items() if k not in self.special_names and v is not None} return input_dict def __getitem__(self, idx): - if self.is_rnn: - sample = self._get_item_rnn(idx) - else: - sample = self._get_item(idx) - return sample + """Retrieve an item based on the dataset type (RNN or not).""" + return self._get_item_rnn(idx) if self.is_rnn else self._get_item(idx) + + def __len__(self): + """Return the number of minibatches.""" + return self.length + diff --git a/rl_games/common/player.py b/rl_games/common/player.py index 98be6501..77849bab 100644 --- a/rl_games/common/player.py +++ b/rl_games/common/player.py @@ -25,7 +25,7 @@ def __init__(self, params): self.env_info = self.config.get('env_info') self.clip_actions = config.get('clip_actions', True) self.seed = self.env_config.pop('seed', None) - + self.balance_env_rewards = self.player_config.get('balance_env_rewards', False) if self.env_info is None: use_vecenv = self.player_config.get('use_vecenv', False) if use_vecenv: @@ -282,7 +282,6 @@ def run(self): games_played = 0 has_masks = False has_masks_func = getattr(self.env, "has_action_mask", None) is not None - op_agent = getattr(self.env, "create_agent", None) if op_agent: agent_inited = True @@ -311,6 +310,13 @@ def run(self): cr = torch.zeros(batch_size, dtype=torch.float32) steps = torch.zeros(batch_size, dtype=torch.float32) + # Initialize per-environment accumulators if balance_env_rewards is enabled + if self.balance_env_rewards: + per_env_rewards = torch.zeros(batch_size, dtype=torch.float32) + per_env_steps = torch.zeros(batch_size, dtype=torch.float32) + per_env_game_res = torch.zeros(batch_size, dtype=torch.float32) + per_env_games_played = torch.zeros(batch_size, dtype=torch.float32) + print_game_res = False for n in range(self.max_steps): @@ -336,21 +342,11 @@ def run(self): done_indices = all_done_indices[::self.num_agents] done_count = len(done_indices) games_played += done_count - if done_count > 0: if self.is_rnn: for s in self.states: s[:, all_done_indices, :] = s[:, - all_done_indices, :] * 0.0 - - cur_rewards = cr[done_indices].sum().item() - cur_steps = steps[done_indices].sum().item() - - cr = cr * (1.0 - done.float()) - steps = steps * (1.0 - done.float()) - sum_rewards += cur_rewards - sum_steps += cur_steps - + all_done_indices, :] * 0.0 game_res = 0.0 if isinstance(info, dict): if 'battle_won' in info: @@ -360,25 +356,74 @@ def run(self): print_game_res = True game_res = info.get('scores', 0.5) - if self.print_stats: - cur_rewards_done = cur_rewards/done_count - cur_steps_done = cur_steps/done_count + if self.balance_env_rewards: + # Update per-environment accumulators + per_env_rewards[done_indices] += cr[done_indices] + per_env_steps[done_indices] += steps[done_indices] + per_env_games_played[done_indices] += 1 if print_game_res: - print(f'reward: {cur_rewards_done:.2f} steps: {cur_steps_done:.1f} w: {game_res}') - else: - print(f'reward: {cur_rewards_done:.2f} steps: {cur_steps_done:.1f}') - - sum_game_res += game_res - if batch_size//self.num_agents == 1 or games_played >= n_games: + per_env_game_res[done_indices] += game_res + + # Reset current rewards and steps for done environments + cr[done_indices] = 0 + steps[done_indices] = 0 + else: + # Original accumulation + cur_rewards = cr[done_indices].sum().item() + cur_steps = steps[done_indices].sum().item() + + cr = cr * (1.0 - done.float()) + steps = steps * (1.0 - done.float()) + sum_rewards += cur_rewards + sum_steps += cur_steps + sum_game_res += game_res + + if self.print_stats: + cur_rewards_done = cur_rewards / done_count + cur_steps_done = cur_steps / done_count + if print_game_res: + print( + f'reward: {cur_rewards_done:.2f} steps: {cur_steps_done:.1f} w: {game_res}') + else: + print( + f'reward: {cur_rewards_done:.2f} steps: {cur_steps_done:.1f}') + + if batch_size // self.num_agents == 1 or games_played >= n_games: break - print(sum_rewards) - if print_game_res: - print('av reward:', sum_rewards / games_played * n_game_life, 'av steps:', sum_steps / - games_played * n_game_life, 'winrate:', sum_game_res / games_played * n_game_life) + if self.balance_env_rewards: + # Calculate per-environment average rewards + valid_envs = per_env_games_played > 0 + per_env_avg_rewards = torch.zeros(batch_size, dtype=torch.float32) + per_env_avg_steps = torch.zeros(batch_size, dtype=torch.float32) + per_env_avg_game_res = torch.zeros(batch_size, dtype=torch.float32) + + per_env_avg_rewards[valid_envs] = ( + per_env_rewards[valid_envs] / per_env_games_played[valid_envs]) + per_env_avg_steps[valid_envs] = ( + per_env_steps[valid_envs] / per_env_games_played[valid_envs]) + + overall_avg_reward = per_env_avg_rewards[valid_envs].mean().item() + overall_avg_steps = per_env_avg_steps[valid_envs].mean().item() + + if print_game_res: + per_env_avg_game_res[valid_envs] = ( + per_env_game_res[valid_envs] / per_env_games_played[valid_envs]) + overall_winrate = per_env_avg_game_res[valid_envs].mean().item() + print('av reward:', overall_avg_reward * n_game_life, 'av steps:', overall_avg_steps * + n_game_life, 'winrate:', overall_winrate * n_game_life) + else: + print('av reward:', overall_avg_reward * n_game_life, + 'av steps:', overall_avg_steps * n_game_life) else: - print('av reward:', sum_rewards / games_played * n_game_life, - 'av steps:', sum_steps / games_played * n_game_life) + print(sum_rewards) + if print_game_res: + print('av reward:', sum_rewards / games_played * n_game_life, 'av steps:', sum_steps / + games_played * n_game_life, 'winrate:', sum_game_res / games_played * n_game_life) + else: + print('av reward:', sum_rewards / games_played * n_game_life, + 'av steps:', sum_steps / games_played * n_game_life) + def get_batch_size(self, obses, batch_size): obs_shape = self.obs_shape diff --git a/rl_games/envs/__init__.py b/rl_games/envs/__init__.py index b906c43d..4e72221b 100644 --- a/rl_games/envs/__init__.py +++ b/rl_games/envs/__init__.py @@ -1,7 +1,8 @@ -from rl_games.envs.test_network import TestNetBuilder, TestNetAuxLossBuilder +from rl_games.envs.test_network import TestNetBuilder, TestNetAuxLossBuilder, SimpleNetBuilder from rl_games.algos_torch import model_builder model_builder.register_network('testnet', TestNetBuilder) +model_builder.register_network('simplenet', SimpleNetBuilder) model_builder.register_network('testnet_aux_loss', TestNetAuxLossBuilder) \ No newline at end of file diff --git a/rl_games/envs/test_network.py b/rl_games/envs/test_network.py index 7adfae90..51e28117 100644 --- a/rl_games/envs/test_network.py +++ b/rl_games/envs/test_network.py @@ -114,5 +114,52 @@ def load(self, params): def build(self, name, **kwargs): return TestNetWithAuxLoss(self.params, **kwargs) + def __call__(self, name, **kwargs): + return self.build(name, **kwargs) + + + +class SimpleNet(NetworkBuilder.BaseNetwork): + def __init__(self, params, **kwargs): + nn.Module.__init__(self) + actions_num = kwargs.pop('actions_num') + input_shape = kwargs.pop('input_shape') + num_inputs =input_shape[0] + self.actions_num = actions_num + self.central_value = params.get('central_value', False) + self.value_size = kwargs.pop('value_size', 1) + self.linear = torch.nn.Sequential( + nn.Linear(num_inputs, 512), + nn.SiLU(), + nn.Linear(512, 256), + nn.SiLU(), + nn.Linear(256, 128), + nn.SiLU(), + nn.Linear(128, actions_num + 1), + ) + self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True) + + def is_rnn(self): + return False + @torch.compile + def forward(self, obs): + obs = obs['obs'] + x = self.linear(obs) + mu, value = torch.split(x, [self.actions_num, 1], dim=-1) + return mu, self.sigma.unsqueeze(0).expand(mu.size()[0], self.actions_num), value, None + + + + +class SimpleNetBuilder(NetworkBuilder): + def __init__(self, **kwargs): + NetworkBuilder.__init__(self) + + def load(self, params): + self.params = params + + def build(self, name, **kwargs): + return SimpleNet(self.params, **kwargs) + def __call__(self, name, **kwargs): return self.build(name, **kwargs) \ No newline at end of file