Skip to content

Commit

Permalink
Clean-up and fixed imports.
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorM committed Dec 1, 2024
1 parent 7e0d74d commit aeacb7d
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 109 deletions.
20 changes: 9 additions & 11 deletions rl_games/algos_torch/a2c_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from rl_games.common import datasets

from torch import optim
import torch
import torch


class A2CAgent(a2c_common.ContinuousA2CBase):
Expand Down Expand Up @@ -34,7 +34,7 @@ def __init__(self, base_name, params):
'normalize_value' : self.normalize_value,
'normalize_input': self.normalize_input,
}

self.model = self.network.build(build_config)
self.model.to(self.ppo_device)
self.states = None
Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(self, base_name, params):
def update_epoch(self):
self.epoch_num += 1
return self.epoch_num

def save(self, fn):
state = self.get_full_state_weights()
torch_ext.save_checkpoint(fn, state)
Expand Down Expand Up @@ -114,8 +114,8 @@ def calc_gradients(self, input_dict):

batch_dict = {
'is_train': True,
'prev_actions': actions_batch,
'obs' : obs_batch,
'prev_actions': actions_batch,
'obs': obs_batch,
}

rnn_masks = None
Expand All @@ -125,9 +125,9 @@ def calc_gradients(self, input_dict):
batch_dict['seq_length'] = self.seq_length

if self.zero_rnn_on_done:
batch_dict['dones'] = input_dict['dones']
batch_dict['dones'] = input_dict['dones']

with torch.cuda.amp.autocast(enabled=self.mixed_precision):
with torch.amp.autocast('cuda', enabled=self.mixed_precision):
res_dict = self.model(batch_dict)
action_log_probs = res_dict['prev_neglogp']
values = res_dict['values']
Expand All @@ -138,7 +138,7 @@ def calc_gradients(self, input_dict):
a_loss = self.actor_loss_func(old_action_log_probs_batch, action_log_probs, advantage, self.ppo, curr_e_clip)

if self.has_value_loss:
c_loss = common_losses.critic_loss(self.model,value_preds_batch, values, curr_e_clip, return_batch, self.clip_value)
c_loss = common_losses.critic_loss(self.model, value_preds_batch, values, curr_e_clip, return_batch, self.clip_value)
else:
c_loss = torch.zeros(1, device=self.ppo_device)
if self.bound_loss_type == 'regularisation':
Expand Down Expand Up @@ -183,7 +183,7 @@ def calc_gradients(self, input_dict):
'new_neglogp' : action_log_probs,
'old_neglogp' : old_action_log_probs_batch,
'masks' : rnn_masks
}, curr_e_clip, 0)
}, curr_e_clip, 0)

self.train_result = (a_loss, c_loss, entropy, \
kl_dist, self.last_lr, lr_mul, \
Expand All @@ -209,5 +209,3 @@ def bound_loss(self, mu):
else:
b_loss = 0
return b_loss


40 changes: 20 additions & 20 deletions rl_games/algos_torch/network_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import torch.nn as nn

from rl_games.algos_torch.d2rl import D2RLNet
from rl_games.algos_torch.sac_helper import SquashedNormal
from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones
from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue
from rl_games.algos_torch.sac_helper import SquashedNormal
from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones
from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue
from rl_games.algos_torch.spatial_softmax import SpatialSoftArgmax


Expand Down Expand Up @@ -192,7 +192,6 @@ def _build_value_layer(self, input_size, output_size, value_type='legacy'):
raise ValueError('value type is not "default", "legacy" or "two_hot_encoded"')



class A2CBuilder(NetworkBuilder):
def __init__(self, **kwargs):
NetworkBuilder.__init__(self)
Expand Down Expand Up @@ -339,7 +338,7 @@ def forward(self, obs_dict):
a_out = a_out.contiguous().view(a_out.size(0), -1)

c_out = self.critic_cnn(c_out)
c_out = c_out.contiguous().view(c_out.size(0), -1)
c_out = c_out.contiguous().view(c_out.size(0), -1)

if self.has_rnn:
seq_length = obs_dict.get('seq_length', 1)
Expand All @@ -359,23 +358,23 @@ def forward(self, obs_dict):
a_out = a_out.reshape(num_seqs, seq_length, -1)
c_out = c_out.reshape(num_seqs, seq_length, -1)

a_out = a_out.transpose(0,1)
c_out = c_out.transpose(0,1)
a_out = a_out.transpose(0, 1)
c_out = c_out.transpose(0, 1)
if dones is not None:
dones = dones.reshape(num_seqs, seq_length, -1)
dones = dones.transpose(0,1)
dones = dones.transpose(0, 1)

if len(states) == 2:
a_states = states[0]
c_states = states[1]
else:
a_states = states[:2]
c_states = states[2:]
c_states = states[2:]
a_out, a_states = self.a_rnn(a_out, a_states, dones, bptt_len)
c_out, c_states = self.c_rnn(c_out, c_states, dones, bptt_len)

a_out = a_out.transpose(0,1)
c_out = c_out.transpose(0,1)
a_out = a_out.transpose(0, 1)
c_out = c_out.transpose(0, 1)
a_out = a_out.contiguous().reshape(a_out.size()[0] * a_out.size()[1], -1)
c_out = c_out.contiguous().reshape(c_out.size()[0] * c_out.size()[1], -1)

Expand All @@ -398,7 +397,7 @@ def forward(self, obs_dict):
else:
a_out = self.actor_mlp(a_out)
c_out = self.critic_mlp(c_out)

value = self.value_act(self.value(c_out))

if self.is_discrete:
Expand All @@ -420,7 +419,7 @@ def forward(self, obs_dict):
else:
out = obs
out = self.actor_cnn(out)
out = out.flatten(1)
out = out.flatten(1)

if self.has_rnn:
seq_length = obs_dict.get('seq_length', 1)
Expand Down Expand Up @@ -474,7 +473,7 @@ def forward(self, obs_dict):
else:
sigma = self.sigma_act(self.sigma(out))
return mu, mu*0 + sigma, value, states

def is_separate_critic(self):
return self.separate

Expand Down Expand Up @@ -555,6 +554,7 @@ def build(self, name, **kwargs):
net = A2CBuilder.Network(self.params, **kwargs)
return net


class Conv2dAuto(nn.Conv2d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -623,6 +623,7 @@ def forward(self, x):
x = self.res_block2(x)
return x


class A2CResnetBuilder(NetworkBuilder):
def __init__(self, **kwargs):
NetworkBuilder.__init__(self)
Expand Down Expand Up @@ -842,10 +843,10 @@ def is_rnn(self):
def get_default_rnn_state(self):
num_layers = self.rnn_layers
if self.rnn_name == 'lstm':
return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),
return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),
torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
else:
return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)))

def build(self, name, **kwargs):
net = A2CResnetBuilder.Network(self.params, **kwargs)
Expand Down Expand Up @@ -952,7 +953,7 @@ def __init__(self, params, **kwargs):
self.critic = self._build_critic(1, **critic_mlp_args)
print("Building Critic Target")
self.critic_target = self._build_critic(1, **critic_mlp_args)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic_target.load_state_dict(self.critic.state_dict())

mlp_init = self.init_factory.create(**self.initializer)
for m in self.modules():
Expand All @@ -976,7 +977,7 @@ def forward(self, obs_dict):
obs = obs_dict['obs']
mu, sigma = self.actor(obs)
return mu, sigma

def is_separate_critic(self):
return self.separate

Expand All @@ -997,12 +998,11 @@ def load(self, params):

if self.has_space:
self.is_discrete = 'discrete' in params['space']
self.is_continuous = 'continuous'in params['space']
self.is_continuous = 'continuous' in params['space']
if self.is_continuous:
self.space_config = params['space']['continuous']
elif self.is_discrete:
self.space_config = params['space']['discrete']
else:
self.is_discrete = False
self.is_continuous = False

Loading

0 comments on commit aeacb7d

Please sign in to comment.