Updates #254

Open · wants to merge 5 commits into master
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "rl_games"
version = "1.6.1"
version = "1.6.2"
description = ""
readme = "README.md"
authors = [
2 changes: 1 addition & 1 deletion rl_games/algos_torch/a2c_continuous.py
@@ -127,7 +127,7 @@ def calc_gradients(self, input_dict):
a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3]

loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef

if self.multi_gpu:
self.optimizer.zero_grad()
else:
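Note on the loss line above: it is the usual rl_games A2C/PPO composition, i.e. the actor (surrogate) loss, plus the critic loss scaled by 0.5 * critic_coef, minus an entropy bonus, plus a penalty keeping action means inside bounds. A minimal standalone sketch of the same arithmetic, with illustrative coefficient values not taken from this PR:

# Illustrative composition of the total loss in calc_gradients; the four
# loss terms are placeholders for the tensors computed earlier.
import torch

a_loss = torch.tensor(0.20)    # PPO surrogate (actor) loss
c_loss = torch.tensor(0.80)    # value-function (critic) loss
entropy = torch.tensor(1.50)   # policy entropy, subtracted so it is maximized
b_loss = torch.tensor(0.05)    # bounds penalty on the action means

critic_coef, entropy_coef, bounds_loss_coef = 2.0, 0.01, 0.0001
loss = a_loss + 0.5 * c_loss * critic_coef - entropy * entropy_coef + b_loss * bounds_loss_coef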
6 changes: 1 addition & 5 deletions rl_games/algos_torch/layers.py
@@ -3,8 +3,6 @@
import torch.nn as nn
import torch.nn.functional as F

import numpy as np


class NoisyLinear(nn.Linear):
def __init__(self, in_features, out_features, sigma_init=0.017, bias=True):
@@ -54,15 +52,13 @@ def forward(self, input):
noise_v = torch.mul(eps_in, eps_out)
return F.linear(input, self.weight + self.sigma_weight * noise_v, bias)



def symlog(x):
return torch.sign(x) * torch.log(torch.abs(x) + 1.0)


def symexp(x):
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)


class SymLog(nn.Module):

def __init__(self):
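Reviewer note on the helpers above: symlog/symexp are the signed-log squashing pair popularized by DreamerV3. symlog compresses large magnitudes symmetrically around zero, and symexp inverts it exactly, which is what makes the pair safe for value and reward transforms. A quick self-contained sanity check (assumes only PyTorch):

# Sanity check that symexp inverts symlog; definitions mirror the ones
# in rl_games/algos_torch/layers.py above.
import torch

def symlog(x):
    return torch.sign(x) * torch.log(torch.abs(x) + 1.0)

def symexp(x):
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)

x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
assert torch.allclose(symexp(symlog(x)), x, atol=1e-4)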
2 changes: 1 addition & 1 deletion rl_games/algos_torch/models.py
@@ -178,6 +178,7 @@ def forward(self, input_dict):
}
return result


class ModelA2CContinuous(BaseModel):
def __init__(self, network):
BaseModel.__init__(self, 'a2c')
@@ -330,7 +331,6 @@ def forward(self, input_dict):
return result



class ModelSACContinuous(BaseModel):

def __init__(self, network):
8 changes: 5 additions & 3 deletions rl_games/algos_torch/network_builder.py
@@ -4,16 +4,14 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import math
import numpy as np
from rl_games.algos_torch.d2rl import D2RLNet
from rl_games.algos_torch.sac_helper import SquashedNormal
from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones
from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue
from rl_games.algos_torch.layers import symexp, symlog


def _create_initializer(func, **kwargs):
return lambda v : func(v, **kwargs)

@@ -309,6 +307,7 @@ def forward(self, obs_dict):
seq_length = obs_dict.get('seq_length', 1)
dones = obs_dict.get('dones', None)
bptt_len = obs_dict.get('bptt_len', 0)

if self.has_cnn:
# for obs shape 4
# input expected shape (B, W, H, C)
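For readers skimming the hunk above: the CNN branch documents channels-last input (B, W, H, C), and with permute_input enabled the builder moves channels first before the conv stack. An illustrative one-liner of that layout change (not code from this PR):

# Illustrative only: channels-last -> channels-first for nn.Conv2d.
import torch
obs = torch.zeros(4, 84, 84, 3)   # (B, W, H, C), as the comment above notes
obs = obs.permute(0, 3, 1, 2)     # -> (B, C, W, H)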
@@ -769,6 +768,7 @@ def load(self, params):
self.is_multi_discrete = 'multi_discrete' in params['space']
self.value_activation = params.get('value_activation', 'None')
self.normalization = params.get('normalization', None)

if self.is_continuous:
self.space_config = params['space']['continuous']
self.fixed_sigma = self.space_config['fixed_sigma']
@@ -777,12 +777,14 @@
elif self.is_multi_discrete:
self.space_config = params['space']['multi_discrete']
self.has_rnn = 'rnn' in params

if self.has_rnn:
self.rnn_units = params['rnn']['units']
self.rnn_layers = params['rnn']['layers']
self.rnn_name = params['rnn']['name']
self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False)
self.rnn_ln = params['rnn'].get('layer_norm', False)

self.has_cnn = True
self.permute_input = params['cnn'].get('permute_input', True)
self.conv_depths = params['cnn']['conv_depths']
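To make the load() changes easier to review, here is a hypothetical params dict exercising the keys touched above; the key names match the accesses in the diff, while the values are made up for the example:

# Hypothetical config for load(); values are illustrative only.
params = {
    'space': {'continuous': {'fixed_sigma': True}},
    'value_activation': 'None',
    'normalization': None,
    'rnn': {                      # presence of 'rnn' flips has_rnn
        'units': 256,
        'layers': 1,
        'name': 'lstm',
        'before_mlp': False,      # optional, defaults to False
        'layer_norm': False,      # optional, defaults to False
    },
    'cnn': {
        'permute_input': True,    # optional, defaults to True
        'conv_depths': [16, 32, 32],
    },
}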
2 changes: 1 addition & 1 deletion rl_games/algos_torch/players.py
@@ -71,7 +71,7 @@ def get_action(self, obs, is_deterministic = False):

def restore(self, fn):
checkpoint = torch_ext.load_checkpoint(fn)
-self.model.load_state_dict(checkpoint['model'])
+self.model.load_state_dict(checkpoint['model'], strict=False)
if self.normalize_input and 'running_mean_std' in checkpoint:
self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])

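The strict=False relaxation lets restore() accept checkpoints whose parameter names do not match the current model exactly: missing or extra keys are skipped instead of raising. The same relaxation is applied to set_weights in rl_games/common/a2c_common.py below. One reviewer caveat: silent mismatches can also hide a genuinely wrong checkpoint. load_state_dict returns the mismatched keys, so a lenient-but-noisy variant (a hypothetical helper, not part of this PR) could be:

# Hypothetical: lenient load that still surfaces mismatches.
def load_lenient(model, state_dict):
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print('checkpoint missing keys:', result.missing_keys)
    if result.unexpected_keys:
        print('checkpoint has unexpected keys:', result.unexpected_keys)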
2 changes: 1 addition & 1 deletion rl_games/common/a2c_common.py
@@ -657,7 +657,7 @@ def set_stats_weights(self, weights):
self.scaler.load_state_dict(weights['scaler'])

def set_weights(self, weights):
-self.model.load_state_dict(weights['model'])
+self.model.load_state_dict(weights['model'], strict=False)
self.set_stats_weights(weights)

def get_param(self, param_name):
28 changes: 28 additions & 0 deletions rl_games/common/algo_observer.py
@@ -90,15 +90,25 @@ def after_init(self, algo):
self.mean_scores = torch_ext.AverageMeter(1, self.algo.games_to_track).to(self.algo.ppo_device)
self.ep_infos = []
self.direct_info = {}

self.histo_freq = 10
self.current_ep = 0
self.mu_datapoints = None

self.writer = self.algo.writer

def after_steps(self):
self.mu_datapoints = self.algo.dataset.values_dict['mu'][0:1000]

def process_infos(self, infos, done_indices):
if not isinstance(infos, dict):
classname = self.__class__.__name__
raise ValueError(f"{classname} expected 'infos' as dict. Received: {type(infos)}")

# store episode information
if "episode" in infos:
self.ep_infos.append(infos["episode"])

# log other variables directly
if len(infos) > 0 and isinstance(infos, dict): # allow direct logging from env
self.direct_info = {}
@@ -126,14 +136,32 @@ def after_print_stats(self, frame, epoch_num, total_time):
value = torch.mean(info_tensor)
self.writer.add_scalar("Episode/" + key, value, epoch_num)
self.ep_infos.clear()

# log scalars from env information
for k, v in self.direct_info.items():
self.writer.add_scalar(f"{k}/frame", v, frame)
self.writer.add_scalar(f"{k}/iter", v, epoch_num)
self.writer.add_scalar(f"{k}/time", v, total_time)

# log mean reward/score from the env
if self.mean_scores.current_size > 0:
mean_scores = self.mean_scores.get_mean()
self.writer.add_scalar("scores/mean", mean_scores, frame)
self.writer.add_scalar("scores/iter", mean_scores, epoch_num)
self.writer.add_scalar("scores/time", mean_scores, total_time)

mean_std = torch.mean(self.algo.model.a2c_network.sigma).exp()
max_std = torch.max(self.algo.model.a2c_network.sigma).exp()
min_std = torch.min(self.algo.model.a2c_network.sigma).exp()

self.writer.add_scalar('info/mean_std', mean_std, epoch_num)
self.writer.add_scalar('info/max_std', max_std, epoch_num)
self.writer.add_scalar('info/min_std', min_std, epoch_num)

self.current_ep += 1
if self.current_ep % self.histo_freq == 0:
for i in range(self.algo.actions_num):
self.writer.add_histogram(
'info/mu[{0}]'.format(i),
torch.clamp(self.mu_datapoints[:,i], min=-1.0, max=1.0),
epoch_num, bins='auto')
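Two notes on the new logging above. First, a2c_network.sigma is treated as a log-std parameter, hence the .exp() before the mean/max/min std scalars are written. Second, every histo_freq-th call writes one histogram per action dimension from the up-to-1000 policy means buffered in after_steps, clamped to [-1, 1]. A self-contained sketch of the writer calls with made-up tensors (SummaryWriter is the torch.utils.tensorboard class rl_games already uses):

# Standalone sketch of the new TensorBoard logging; data is made up.
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/demo')
epoch_num = 1

sigma = torch.full((8,), -1.0)         # log-std, one entry per action
writer.add_scalar('info/mean_std', sigma.mean().exp(), epoch_num)

mu_datapoints = torch.randn(1000, 8)   # buffered policy means
for i in range(mu_datapoints.shape[1]):
    writer.add_histogram('info/mu[{0}]'.format(i),
                         torch.clamp(mu_datapoints[:, i], min=-1.0, max=1.0),
                         epoch_num, bins='auto')
writer.close()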
5 changes: 5 additions & 0 deletions rl_games/common/wrappers.py
@@ -210,6 +210,7 @@ def observation(self, frame):
frame = np.expand_dims(frame, -1)
return frame


class FrameStack(gym.Wrapper):
def __init__(self, env, k, flat = False):
"""
@@ -302,6 +303,7 @@ def _get_ob(self):
frames = np.transpose(self.frames, (1, 0, 2))
return frames


class BatchedFrameStackWithStates(gym.Wrapper):
def __init__(self, env, k, transpose = False, flatten = False):
gym.Wrapper.__init__(self, env)
@@ -362,6 +364,7 @@ def process_data(self, data):
obses = np.transpose(data, (1, 0, 2))
return obses


class ProcgenStack(gym.Wrapper):
def __init__(self, env, k = 2, greyscale=True):
gym.Wrapper.__init__(self, env)
@@ -420,6 +423,7 @@ def observation(self, observation):
# with smaller replay buffers only.
return np.array(observation).astype(np.float32) / 255.0


class LazyFrames(object):
def __init__(self, frames):
"""This object ensures that common frames between the observations are only stored once.
@@ -448,6 +452,7 @@ def __len__(self):
def __getitem__(self, i):
return self._force()[i]


class ReallyDoneWrapper(gym.Wrapper):
def __init__(self, env):
"""
2 changes: 1 addition & 1 deletion rl_games/envs/diambra/diambra.py
@@ -62,7 +62,7 @@ def __init__(self, **kwargs):
key_to_add.append("oppPosition")
key_to_add.append("stage")
key_to_add.append("character")

self.env = make_diambra_env(diambraGym, env_prefix="Train" + str(self.random_seed), seed=self.random_seed,
diambra_kwargs=env_kwargs,
diambra_gym_kwargs=gym_kwargs,
7 changes: 6 additions & 1 deletion rl_games/networks/__init__.py
@@ -1,4 +1,9 @@
from rl_games.networks.tcnn_mlp import TcnnNetBuilder
+from rl_games.networks.ig_networks import EncoderMLPBuilder, TransformerBuilder, TorchTransformerBuilder
from rl_games.algos_torch import model_builder

-model_builder.register_network('tcnnnet', TcnnNetBuilder)
+
+model_builder.register_network('tcnnnet', TcnnNetBuilder)
+model_builder.register_network('enc_mlp', lambda **kwargs : EncoderMLPBuilder())
+model_builder.register_network('transformer', lambda **kwargs : TransformerBuilder())
+model_builder.register_network('torch_transformer', lambda **kwargs : TorchTransformerBuilder())
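With these registrations in place, training configs can select the new builders by their registered name, and users can hook in their own architectures the same way. A hedged sketch (MyBuilder is hypothetical and would need to implement the NetworkBuilder interface from rl_games.algos_torch.network_builder):

# Hypothetical user-side registration, mirroring the lines above.
from rl_games.algos_torch import model_builder

model_builder.register_network('my_net', lambda **kwargs: MyBuilder())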