
Commit

Reverted RNN changes.
ViktorM committed Dec 1, 2024
1 parent 7ea2ffb commit 5e5511e
Showing 4 changed files with 50 additions and 46 deletions.
rl_games/algos_torch/central_value.py: 2 additions & 0 deletions
@@ -77,6 +77,8 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng
 
         if self.is_rnn:
             self.rnn_states = self.model.get_default_rnn_state()
+            self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]
+
             total_agents = self.num_actors #* self.num_agents
             num_seqs = self.horizon_length // self.seq_length
             assert ((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0)
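Note: the line added above restores the pattern of building the default RNN states on the CPU (see the network_builder.py hunks below, which drop the device argument from get_default_rnn_state) and moving them to the training device in the caller. A minimal standalone sketch of that flow, with hypothetical sizes and device standing in for the real config values:

import torch

# Hypothetical shapes standing in for the real config; the actual values come
# from the network builder (num_layers, num_seqs, rnn_units).
num_layers, num_seqs, rnn_units = 1, 4, 64
ppo_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

def get_default_rnn_state():
    # LSTM-style default state: a zero-initialized (h, c) pair created on the CPU,
    # mirroring what the builder returns after this revert.
    return (torch.zeros((num_layers, num_seqs, rnn_units)),
            torch.zeros((num_layers, num_seqs, rnn_units)))

# Caller-side device placement, as re-added in central_value.py and a2c_common.py.
rnn_states = get_default_rnn_state()
rnn_states = [s.to(ppo_device) for s in rnn_states]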
rl_games/algos_torch/models.py: 12 additions & 12 deletions
@@ -349,24 +349,24 @@ def forward(self, input_dict):
                 entropy = main_distr.entropy(mu, logstd)
                 prev_neglogp = -main_distr.log_prob(mu, logstd, main_distr.inverse_post_process(prev_actions))
                 result = {
-                    'prev_neglogp' : torch.squeeze(prev_neglogp),
-                    'values' : value,
-                    'entropy' : entropy,
-                    'rnn_states' : states,
-                    'mus' : mu,
-                    'sigmas' : sigma
+                    'prev_neglogp': torch.squeeze(prev_neglogp),
+                    'values': value,
+                    'entropy': entropy,
+                    'rnn_states': states,
+                    'mus': mu,
+                    'sigmas': sigma
                 }
                 return result
             else:
                 selected_action = main_distr.sample_no_postprocessing(mu, logstd)
                 neglogp = -main_distr.log_prob(mu, logstd, selected_action)
                 result = {
-                    'neglogpacs' : torch.squeeze(neglogp),
-                    'values' : self.denorm_value(value),
-                    'actions' : main_distr.post_process(selected_action),
-                    'rnn_states' : states,
-                    'mus' : mu,
-                    'sigmas' : sigma
+                    'neglogpacs': torch.squeeze(neglogp),
+                    'values': self.denorm_value(value),
+                    'actions': main_distr.post_process(selected_action),
+                    'rnn_states': states,
+                    'mus': mu,
+                    'sigmas': sigma
                 }
                 return result
 
rl_games/algos_torch/network_builder.py: 34 additions & 34 deletions
@@ -212,21 +212,21 @@ def __init__(self, params, **kwargs):
             self.critic_cnn = nn.Sequential()
             self.actor_mlp = nn.Sequential()
             self.critic_mlp = nn.Sequential()
 
             if self.has_cnn:
                 if self.permute_input:
                     input_shape = torch_ext.shape_whc_to_cwh(input_shape)
                 cnn_args = {
-                    'ctype' : self.cnn['type'],
-                    'input_shape' : input_shape,
-                    'convs' :self.cnn['convs'],
-                    'activation' : self.cnn['activation'],
-                    'norm_func_name' : self.normalization,
+                    'ctype': self.cnn['type'],
+                    'input_shape': input_shape,
+                    'convs': self.cnn['convs'],
+                    'activation': self.cnn['activation'],
+                    'norm_func_name': self.normalization,
                 }
                 self.actor_cnn = self._build_conv(**cnn_args)
 
                 if self.separate:
-                    self.critic_cnn = self._build_conv( **cnn_args)
+                    self.critic_cnn = self._build_conv(**cnn_args)
 
             cnn_output_size = self._calc_input_size(input_shape, self.actor_cnn)
 
@@ -264,13 +264,13 @@ def __init__(self, params, **kwargs):
                     self.layer_norm = torch.nn.LayerNorm(self.rnn_units)
 
             mlp_args = {
-                'input_size' : mlp_input_size,
-                'units' : self.units,
-                'activation' : self.activation,
-                'norm_func_name' : self.normalization,
-                'dense_func' : torch.nn.Linear,
-                'd2rl' : self.is_d2rl,
-                'norm_only_first_layer' : self.norm_only_first_layer
+                'input_size': mlp_input_size,
+                'units': self.units,
+                'activation': self.activation,
+                'norm_func_name': self.normalization,
+                'dense_func': torch.nn.Linear,
+                'd2rl': self.is_d2rl,
+                'norm_only_first_layer': self.norm_only_first_layer
             }
             self.actor_mlp = self._build_mlp(**mlp_args)
             if self.separate:
@@ -310,14 +310,14 @@ def __init__(self, params, **kwargs):
                 if isinstance(m, nn.Linear):
                     mlp_init(m.weight)
                     if getattr(m, "bias", None) is not None:
-                        torch.nn.init.zeros_(m.bias)
+                        torch.nn.init.zeros_(m.bias)
 
             if self.is_continuous:
                 mu_init(self.mu.weight)
                 if self.fixed_sigma:
                     sigma_init(self.sigma)
                 else:
-                    sigma_init(self.sigma.weight)
+                    sigma_init(self.sigma.weight)
 
         def forward(self, obs_dict):
             obs = obs_dict['obs']
@@ -484,25 +484,26 @@ def get_default_rnn_state(self):
             if not self.has_rnn:
                 return None
             num_layers = self.rnn_layers
+
             if self.rnn_name == 'identity':
                 rnn_units = 1
             else:
                 rnn_units = self.rnn_units
             if self.rnn_name == 'lstm':
                 if self.separate:
-                    return (torch.zeros((num_layers, self.num_seqs, rnn_units), device=self.device),
-                            torch.zeros((num_layers, self.num_seqs, rnn_units), device=self.device),
-                            torch.zeros((num_layers, self.num_seqs, rnn_units), device=self.device),
-                            torch.zeros((num_layers, self.num_seqs, rnn_units), device=self.device))
+                    return (torch.zeros((num_layers, self.num_seqs, rnn_units)),
+                            torch.zeros((num_layers, self.num_seqs, rnn_units)),
+                            torch.zeros((num_layers, self.num_seqs, rnn_units)),
+                            torch.zeros((num_layers, self.num_seqs, rnn_units)))
                 else:
-                    return (torch.zeros((num_layers, self.num_seqs, rnn_units), device=self.device),
-                            torch.zeros((num_layers, self.num_seqs, rnn_units), device=self.device))
+                    return (torch.zeros((num_layers, self.num_seqs, rnn_units)),
+                            torch.zeros((num_layers, self.num_seqs, rnn_units)))
             else:
                 if self.separate:
-                    return (torch.zeros((num_layers, self.num_seqs, rnn_units), device=self.device),
-                            torch.zeros((num_layers, self.num_seqs, rnn_units), device=self.device))
+                    return (torch.zeros((num_layers, self.num_seqs, rnn_units)),
+                            torch.zeros((num_layers, self.num_seqs, rnn_units)))
                 else:
-                    return (torch.zeros((num_layers, self.num_seqs, rnn_units), ),)
+                    return (torch.zeros((num_layers, self.num_seqs, rnn_units)),)
 
         def load(self, params):
             self.separate = params.get('separate', False)
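Note: every state returned by get_default_rnn_state follows the (num_layers, num_seqs, rnn_units) layout; the number of tensors in the tuple depends on the cell type and on whether actor and critic use separate networks. A small sketch of the resulting shapes, using made-up sizes rather than anything from this commit:

import torch

num_layers, num_seqs, rnn_units = 2, 4, 256  # illustrative values only

def zeros():
    return torch.zeros((num_layers, num_seqs, rnn_units))

# Shared actor-critic LSTM: one zero-initialized (h, c) pair.
shared_lstm_state = (zeros(), zeros())

# Separate actor and critic LSTMs: two (h, c) pairs flattened into one tuple,
# which is why the 'separate' branch above returns four tensors.
separate_lstm_state = (zeros(), zeros(), zeros(), zeros())

# A non-LSTM cell (e.g. GRU) carries only a hidden state, so the shared case is
# a 1-tuple and the separate case a 2-tuple.
shared_gru_state = (zeros(),)

# After this revert the tensors are created on the CPU; the trainer moves them later:
# rnn_states = [s.to(ppo_device) for s in rnn_states]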
@@ -671,11 +672,11 @@ def __init__(self, params, **kwargs):
             #self.layer_norm = torch.nn.LayerNorm(self.rnn_units)
 
             mlp_args = {
-                'input_size' : mlp_input_size,
-                'units' :self.units,
-                'activation' : self.activation,
-                'norm_func_name' : self.normalization,
-                'dense_func' : torch.nn.Linear
+                'input_size': mlp_input_size,
+                'units':self.units,
+                'activation': self.activation,
+                'norm_func_name': self.normalization,
+                'dense_func': torch.nn.Linear
             }
 
             self.mlp = self._build_mlp(**mlp_args)
@@ -741,7 +742,6 @@ def forward(self, obs_dict):
             out = self.flatten_act(out)
 
             if self.has_rnn:
-                #seq_length = obs_dict['seq_length']
                 seq_length = obs_dict.get('seq_length', 1)
 
                 out_in = out
@@ -843,10 +843,10 @@ def is_rnn(self):
         def get_default_rnn_state(self):
             num_layers = self.rnn_layers
             if self.rnn_name == 'lstm':
-                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units), device=self.device),
-                        torch.zeros((num_layers, self.num_seqs, self.rnn_units), device=self.device))
+                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),
+                        torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
             else:
-                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units), device=self.device),)
+                return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),)
 
     def build(self, name, **kwargs):
         net = A2CResnetBuilder.Network(self.params, **kwargs)
rl_games/common/a2c_common.py: 2 additions & 0 deletions
@@ -461,6 +461,7 @@ def device(self):
     def reset_envs(self):
         if self.is_rnn:
             self.rnn_states = self.model.get_default_rnn_state()
+            self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]
         self.obs = self.env_reset()
 
     def init_tensors(self):
@@ -479,6 +480,7 @@ def init_tensors(self):
 
         if self.is_rnn:
             self.rnn_states = self.model.get_default_rnn_state()
+            self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]
 
             total_agents = self.num_agents * self.num_actors
             num_seqs = self.horizon_length // self.seq_length
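Note: the num_seqs computation above and the assert shown in the central_value.py hunk both require that a rollout splits evenly into BPTT sequences and that every minibatch holds whole sequences. A worked example with purely illustrative numbers (the real ones come from the training config):

# Illustrative values only; not taken from this commit.
horizon_length = 32   # steps collected per environment per rollout
seq_length = 8        # truncated-BPTT sequence length
num_actors = 16
num_agents = 1
num_minibatches = 4

total_agents = num_agents * num_actors
num_seqs = horizon_length // seq_length                                    # 4 sequences per env
samples_per_minibatch = horizon_length * total_agents // num_minibatches   # 512 // 4 = 128

# Each minibatch must contain only whole sequences:
assert samples_per_minibatch % seq_length == 0                             # 128 % 8 == 0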

