Skip to content

Commit

Permalink
Added a few features to make rl_games more comparable to the brax wit…
Browse files Browse the repository at this point in the history
…h jax ppo (#314)

* Updated

* cleanup and simplenet

---------

Co-authored-by: Denys Makoviichuk <[email protected]>
  • Loading branch information
Denys88 and DenSumy authored Nov 30, 2024
1 parent 42c076e commit 6612eaf
Show file tree
Hide file tree
Showing 7 changed files with 331 additions and 80 deletions.
2 changes: 2 additions & 0 deletions rl_games/algos_torch/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def __init__(self):
lambda network, **kwargs: models.ModelSACContinuous(network))
self.model_factory.register_builder('central_value',
lambda network, **kwargs: models.ModelCentralValue(network))
self.model_factory.register_builder('continuous_a2c_tanh',
lambda network, **kwargs: models.ModelA2CContinuousTanh(network))
self.network_builder = NetworkBuilder()

def get_network_builder(self):
Expand Down
131 changes: 131 additions & 0 deletions rl_games/algos_torch/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from rl_games.algos_torch.sac_helper import SquashedNormal
from rl_games.algos_torch.running_mean_std import RunningMeanStd, RunningMeanStdObs
from rl_games.algos_torch.moving_mean_std import GeneralizedMovingStats
from torch.distributions import Normal, TransformedDistribution, TanhTransform
import math

class BaseModel():
def __init__(self, model_class):
Expand Down Expand Up @@ -251,6 +253,7 @@ def forward(self, input_dict):
return result



class ModelA2CContinuousLogStd(BaseModel):
def __init__(self, network):
BaseModel.__init__(self, 'a2c')
Expand Down Expand Up @@ -310,6 +313,59 @@ def neglogp(self, x, mean, std, logstd):
+ 0.5 * np.log(2.0 * np.pi) * x.size()[-1] \
+ logstd.sum(dim=-1)

class ModelA2CContinuousTanh(BaseModel):
def __init__(self, network):
BaseModel.__init__(self, 'a2c')
self.network_builder = network

class Network(BaseModelNetwork):
def __init__(self, a2c_network, **kwargs):
BaseModelNetwork.__init__(self, **kwargs)
self.a2c_network = a2c_network
def get_aux_loss(self):
return self.a2c_network.get_aux_loss()

def is_rnn(self):
return self.a2c_network.is_rnn()

def get_value_layer(self):
return self.a2c_network.get_value_layer()

def get_default_rnn_state(self):
return self.a2c_network.get_default_rnn_state()

def forward(self, input_dict):
is_train = input_dict.get('is_train', True)
prev_actions = input_dict.get('prev_actions', None)
input_dict['obs'] = self.norm_obs(input_dict['obs'])
mu, logstd, value, states = self.a2c_network(input_dict)
sigma = torch.nn.functional.softplus(logstd + 0.001)
main_distr = NormalTanhDistribution(mu.size(-1))
if is_train:
entropy = main_distr.entropy(mu, logstd)
prev_neglogp = -main_distr.log_prob(mu, logstd, main_distr.inverse_post_process(prev_actions))
result = {
'prev_neglogp' : torch.squeeze(prev_neglogp),
'values' : value,
'entropy' : entropy,
'rnn_states' : states,
'mus' : mu,
'sigmas' : sigma
}
return result
else:
selected_action = main_distr.sample_no_postprocessing(mu, logstd)
neglogp = -main_distr.log_prob(mu, logstd, selected_action)
result = {
'neglogpacs' : torch.squeeze(neglogp),
'values' : self.denorm_value(value),
'actions' : main_distr.post_process(selected_action),
'rnn_states' : states,
'mus' : mu,
'sigmas' : sigma
}
return result


class ModelCentralValue(BaseModel):
def __init__(self, network):
Expand Down Expand Up @@ -385,4 +441,79 @@ def forward(self, input_dict):
return dist


class TanhBijector:
"""Tanh Bijector."""

def forward(self, x):
return torch.tanh(x)

def inverse(self, y):
y = torch.clamp(y, -0.99999997, 0.99999997)
return 0.5 * (y.log1p() - (-y).log1p())

def forward_log_det_jacobian(self, x):
# Log of the absolute value of the determinant of the Jacobian
return 2. * (math.log(2.) - x - F.softplus(-2. * x))

class NormalTanhDistribution:
"""Normal distribution followed by tanh."""

def __init__(self, event_size, min_std=0.001, var_scale=1.0):
"""Initialize the distribution.
Args:
event_size (int): The size of events (i.e., actions).
min_std (float): Minimum standard deviation for the Gaussian.
var_scale (float): Scaling factor for the Gaussian's scale parameter.
"""
self.param_size = event_size
self._min_std = min_std
self._var_scale = var_scale
self._event_ndims = 1 # Rank of events
self._postprocessor = TanhBijector()

def create_dist(self, loc, scale):
scale = (F.softplus(scale) + self._min_std) * self._var_scale
return torch.distributions.Normal(loc=loc, scale=scale)

def sample_no_postprocessing(self, loc, scale):
dist = self.create_dist(loc, scale)
return dist.rsample()

def sample(self, loc, scale):
"""Returns a sample from the postprocessed distribution."""
pre_tanh_sample = self.sample_no_postprocessing(loc, scale)
return self._postprocessor.forward(pre_tanh_sample)

def post_process(self, pre_tanh_sample):
"""Returns a postprocessed sample."""
return self._postprocessor.forward(pre_tanh_sample)

def inverse_post_process(self, post_tanh_sample):
"""Returns a postprocessed sample."""
return self._postprocessor.inverse(post_tanh_sample)

def mode(self, loc, scale):
"""Returns the mode of the postprocessed distribution."""
dist = self.create_dist(loc, scale)
pre_tanh_mode = dist.mean # Mode of a normal distribution is its mean
return self._postprocessor.forward(pre_tanh_mode)

def log_prob(self, loc, scale, actions):
"""Compute the log probability of actions."""
dist = self.create_dist(loc, scale)
log_probs = dist.log_prob(actions)
log_probs -= self._postprocessor.forward_log_det_jacobian(actions)
if self._event_ndims == 1:
log_probs = log_probs.sum(dim=-1) # Sum over action dimension
return log_probs

def entropy(self, loc, scale):
"""Return the entropy of the given distribution."""
dist = self.create_dist(loc, scale)
entropy = dist.entropy()
sample = dist.rsample()
entropy += self._postprocessor.forward_log_det_jacobian(sample)
if self._event_ndims == 1:
entropy = entropy.sum(dim=-1)
return entropy
26 changes: 19 additions & 7 deletions rl_games/common/a2c_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def __init__(self, base_name, params):
self.save_freq = config.get('save_frequency', 0)
self.save_best_after = config.get('save_best_after', 100)
self.print_stats = config.get('print_stats', True)
self.epochs_between_resets = config.get('epochs_between_resets', 0)
self.rnn_states = None
self.name = base_name

Expand Down Expand Up @@ -382,6 +383,12 @@ def set_eval(self):
self.model.eval()
if self.normalize_rms_advantage:
self.advantage_mean_std.eval()
if self.epochs_between_resets > 0:
if self.epoch_num % self.epochs_between_resets == 0:
self.reset_envs()
self.init_current_rewards()
print(f"Forcing env reset after {self.epoch_num} epochs")


def set_train(self):
self.model.train()
Expand Down Expand Up @@ -466,10 +473,7 @@ def init_tensors(self):

val_shape = (self.horizon_length, batch_size, self.value_size)
current_rewards_shape = (batch_size, self.value_size)
self.current_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device)
self.current_shaped_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device)
self.current_lengths = torch.zeros(batch_size, dtype=torch.float32, device=self.ppo_device)
self.dones = torch.ones((batch_size,), dtype=torch.uint8, device=self.ppo_device)
self.init_current_rewards(batch_size, current_rewards_shape)

if self.is_rnn:
self.rnn_states = self.model.get_default_rnn_state()
Expand All @@ -480,6 +484,12 @@ def init_tensors(self):
assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0)
self.mb_rnn_states = [torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states]

def init_current_rewards(self, batch_size, current_rewards_shape):
self.current_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device)
self.current_shaped_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device)
self.current_lengths = torch.zeros(batch_size, dtype=torch.float32, device=self.ppo_device)
self.dones = torch.ones((batch_size,), dtype=torch.uint8, device=self.ppo_device)

def init_rnn_from_model(self, model):
self.is_rnn = self.model.is_rnn()

Expand Down Expand Up @@ -571,12 +581,12 @@ def discount_values_masks(self, fdones, last_extrinsic_values, mb_fdones, mb_ext
mb_advs[t] = lastgaelam = (delta + self.gamma * self.tau * nextnonterminal * lastgaelam) * masks_t
return mb_advs

def clear_stats(self):
batch_size = self.num_agents * self.num_actors
def clear_stats(self, clean_rewards= True):
self.game_rewards.clear()
self.game_shaped_rewards.clear()
self.game_lengths.clear()
self.mean_rewards = self.last_mean_rewards = -100500
if clean_rewards:
self.mean_rewards = self.last_mean_rewards = -100500
self.algo_observer.after_clear_stats()

def update_epoch(self):
Expand Down Expand Up @@ -772,6 +782,7 @@ def play_steps(self):
self.current_rewards += rewards
self.current_shaped_rewards += shaped_rewards
self.current_lengths += 1

all_done_indices = self.dones.nonzero(as_tuple=False)
env_done_indices = all_done_indices[::self.num_agents]

Expand Down Expand Up @@ -947,6 +958,7 @@ def train_epoch(self):

for mini_ep in range(0, self.mini_epochs_num):
ep_kls = []
self.dataset.apply_permutation()
for i in range(len(self.dataset)):
a_loss, c_loss, entropy, kl, last_lr, lr_mul = self.train_actor_critic(self.dataset[i])
a_losses.append(a_loss)
Expand Down
101 changes: 57 additions & 44 deletions rl_games/common/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
from torch.utils.data import Dataset


class PPODataset(Dataset):
import torch
from torch.utils.data import Dataset
import random

def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_length):
class PPODataset(Dataset):
def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_length, permute=False):
if batch_size % minibatch_size != 0:
raise ValueError("Batch size must be divisible by minibatch size.")
if batch_size % seq_length != 0:
raise ValueError("Batch size must be divisible by sequence length.")

self.is_rnn = is_rnn
self.seq_length = seq_length
Expand All @@ -15,70 +22,76 @@ def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_
self.length = self.batch_size // self.minibatch_size
self.is_discrete = is_discrete
self.is_continuous = not is_discrete
total_games = self.batch_size // self.seq_length
self.num_games_batch = self.minibatch_size // self.seq_length
self.game_indexes = torch.arange(total_games, dtype=torch.long, device=self.device)
self.flat_indexes = torch.arange(total_games * self.seq_length, dtype=torch.long, device=self.device).reshape(total_games, self.seq_length)


self.special_names = ['rnn_states']
self.permute = permute
if self.permute:
self.permutation_indices = torch.arange(self.batch_size, dtype=torch.long, device=self.device)

def update_values_dict(self, values_dict):
self.values_dict = values_dict

def update_mu_sigma(self, mu, sigma):
start = self.last_range[0]
end = self.last_range[1]
self.values_dict['mu'][start:end] = mu
self.values_dict['sigma'][start:end] = sigma
"""Update the internal values dictionary."""
self.values_dict = values_dict

def update_mu_sigma(self, mu, sigma):
"""Update the mu and sigma values in the dataset."""
start, end = self.last_range
# Ensure the permutation does not break the logic for updating.
#if self.permute:
# original_indices = self.permutation_indices[start:end]
# self.values_dict['mu'][original_indices] = mu
# self.values_dict['sigma'][original_indices] = sigma
#else:
self.values_dict['mu'][start:end] = mu
self.values_dict['sigma'][start:end] = sigma

def apply_permutation(self):
"""Permute the dataset indices if the permutation flag is enabled."""
if self.permute and not self.is_rnn:
self.permutation_indices = torch.randperm(self.batch_size, device=self.device, dtype=torch.long)
for key, value in self.values_dict.items():
if key not in self.special_names and value is not None:
if isinstance(value, dict):
for k, v in value.items():
self.values_dict[key][k] = v[self.permutation_indices]
else:
self.values_dict[key] = value[self.permutation_indices]

def __len__(self):
return self.length
def _slice_data(self, data, start, end):
"""Slice data from start to end, handling dictionaries."""
if isinstance(data, dict):
return {k: v[start:end] for k, v in data.items()}
return data[start:end] if data is not None else None

def _get_item_rnn(self, idx):
"""Retrieve a batch of data for RNN training."""
gstart = idx * self.num_games_batch
gend = (idx + 1) * self.num_games_batch
start = gstart * self.seq_length
end = gend * self.seq_length
self.last_range = (start, end)

input_dict = {}
for k,v in self.values_dict.items():
if k not in self.special_names:
if isinstance(v, dict):
v_dict = {kd:vd[start:end] for kd, vd in v.items()}
input_dict[k] = v_dict
else:
if v is not None:
input_dict[k] = v[start:end]
else:
input_dict[k] = None

rnn_states = self.values_dict['rnn_states']
input_dict['rnn_states'] = [s[:, gstart:gend, :].contiguous() for s in rnn_states]

input_dict = {k: self._slice_data(v, start, end) for k, v in self.values_dict.items() if k not in self.special_names}
input_dict['rnn_states'] = [s[:, gstart:gend, :].contiguous() for s in self.values_dict['rnn_states']]
return input_dict

def _get_item(self, idx):
"""Retrieve a minibatch of data."""
start = idx * self.minibatch_size
end = (idx + 1) * self.minibatch_size
self.last_range = (start, end)
input_dict = {}
for k,v in self.values_dict.items():
if k not in self.special_names and v is not None:
if type(v) is dict:
v_dict = { kd:vd[start:end] for kd, vd in v.items() }
input_dict[k] = v_dict
else:
input_dict[k] = v[start:end]


input_dict = {k: self._slice_data(v, start, end) for k, v in self.values_dict.items() if k not in self.special_names and v is not None}
return input_dict

def __getitem__(self, idx):
if self.is_rnn:
sample = self._get_item_rnn(idx)
else:
sample = self._get_item(idx)
return sample
"""Retrieve an item based on the dataset type (RNN or not)."""
return self._get_item_rnn(idx) if self.is_rnn else self._get_item(idx)

def __len__(self):
"""Return the number of minibatches."""
return self.length




Expand Down
Loading

0 comments on commit 6612eaf

Please sign in to comment.