From 15bd59aad57f521447f7d1d53c9b7d281875b6d9 Mon Sep 17 00:00:00 2001
From: Denys Makoviichuk
Date: Sat, 30 Nov 2024 13:07:53 -0800
Subject: [PATCH] cleanup and simplenet

---
 rl_games/common/a2c_common.py | 25 +++++++++++++------
 rl_games/common/datasets.py   |  7 +++---
 rl_games/common/player.py     |  1 -
 rl_games/envs/__init__.py     |  3 ++-
 rl_games/envs/test_network.py | 47 +++++++++++++++++++++++++++++++++++
 5 files changed, 70 insertions(+), 13 deletions(-)

diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index 2ae2e544..f728045e 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -158,6 +158,7 @@ def __init__(self, base_name, params):
         self.save_freq = config.get('save_frequency', 0)
         self.save_best_after = config.get('save_best_after', 100)
         self.print_stats = config.get('print_stats', True)
+        self.epochs_between_resets = config.get('epochs_between_resets', 0)
 
         self.rnn_states = None
         self.name = base_name
@@ -382,6 +383,12 @@ def set_eval(self):
         self.model.eval()
         if self.normalize_rms_advantage:
             self.advantage_mean_std.eval()
+        if self.epochs_between_resets > 0 and self.epoch_num % self.epochs_between_resets == 0:
+            batch_size = self.num_agents * self.num_actors
+            self.reset_envs()
+            self.init_current_rewards(batch_size, (batch_size, self.value_size))
+            print(f"Forcing env reset after {self.epoch_num} epochs")
+
 
     def set_train(self):
         self.model.train()
@@ -466,10 +473,7 @@ def init_tensors(self):
 
         val_shape = (self.horizon_length, batch_size, self.value_size)
         current_rewards_shape = (batch_size, self.value_size)
-        self.current_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device)
-        self.current_shaped_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device)
-        self.current_lengths = torch.zeros(batch_size, dtype=torch.float32, device=self.ppo_device)
-        self.dones = torch.ones((batch_size,), dtype=torch.uint8, device=self.ppo_device)
+        self.init_current_rewards(batch_size, current_rewards_shape)
 
         if self.is_rnn:
             self.rnn_states = self.model.get_default_rnn_state()
@@ -480,6 +484,12 @@ def init_tensors(self):
             assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0)
             self.mb_rnn_states = [torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states]
 
+    def init_current_rewards(self, batch_size, current_rewards_shape):
+        self.current_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device)
+        self.current_shaped_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device)
+        self.current_lengths = torch.zeros(batch_size, dtype=torch.float32, device=self.ppo_device)
+        self.dones = torch.ones((batch_size,), dtype=torch.uint8, device=self.ppo_device)
+
     def init_rnn_from_model(self, model):
         self.is_rnn = self.model.is_rnn()
 
@@ -571,12 +581,12 @@ def discount_values_masks(self, fdones, last_extrinsic_values, mb_fdones, mb_ext
             mb_advs[t] = lastgaelam = (delta + self.gamma * self.tau * nextnonterminal * lastgaelam) * masks_t
         return mb_advs
 
-    def clear_stats(self):
-        batch_size = self.num_agents * self.num_actors
+    def clear_stats(self, clean_rewards=True):
         self.game_rewards.clear()
         self.game_shaped_rewards.clear()
         self.game_lengths.clear()
-        self.mean_rewards = self.last_mean_rewards = -100500
+        if clean_rewards:
+            self.mean_rewards = self.last_mean_rewards = -100500
         self.algo_observer.after_clear_stats()
 
     def update_epoch(self):
@@ -772,6 +782,7 @@ def play_steps(self):
             self.current_rewards += rewards
             self.current_shaped_rewards += shaped_rewards
             self.current_lengths += 1
+
             all_done_indices = self.dones.nonzero(as_tuple=False)
             env_done_indices = all_done_indices[::self.num_agents]
 
diff --git a/rl_games/common/datasets.py b/rl_games/common/datasets.py
index 60031ed6..e8bd62ae 100644
--- a/rl_games/common/datasets.py
+++ b/rl_games/common/datasets.py
@@ -22,13 +22,12 @@ def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_
         self.length = self.batch_size // self.minibatch_size
         self.is_discrete = is_discrete
        self.is_continuous = not is_discrete
-        total_games = self.batch_size // self.seq_length
         self.num_games_batch = self.minibatch_size // self.seq_length
-        self.game_indexes = torch.arange(total_games, dtype=torch.long, device=self.device)
-        self.flat_indexes = torch.arange(total_games * self.seq_length, dtype=torch.long, device=self.device).reshape(total_games, self.seq_length)
+
         self.special_names = ['rnn_states']
         self.permute = permute
-        self.permutation_indices = torch.arange(self.batch_size, dtype=torch.long, device=self.device)
+        if self.permute:
+            self.permutation_indices = torch.arange(self.batch_size, dtype=torch.long, device=self.device)
 
     def update_values_dict(self, values_dict):
         """Update the internal values dictionary."""
diff --git a/rl_games/common/player.py b/rl_games/common/player.py
index 8ae96936..77849bab 100644
--- a/rl_games/common/player.py
+++ b/rl_games/common/player.py
@@ -342,7 +342,6 @@ def run(self):
                 done_indices = all_done_indices[::self.num_agents]
                 done_count = len(done_indices)
                 games_played += done_count
-                print(games_played)
                 if done_count > 0:
                     if self.is_rnn:
                         for s in self.states:
diff --git a/rl_games/envs/__init__.py b/rl_games/envs/__init__.py
index b906c43d..4e72221b 100644
--- a/rl_games/envs/__init__.py
+++ b/rl_games/envs/__init__.py
@@ -1,7 +1,8 @@
-from rl_games.envs.test_network import TestNetBuilder, TestNetAuxLossBuilder
+from rl_games.envs.test_network import TestNetBuilder, TestNetAuxLossBuilder, SimpleNetBuilder
 from rl_games.algos_torch import model_builder
 
 
 
 model_builder.register_network('testnet', TestNetBuilder)
+model_builder.register_network('simplenet', SimpleNetBuilder)
 model_builder.register_network('testnet_aux_loss', TestNetAuxLossBuilder)
\ No newline at end of file
diff --git a/rl_games/envs/test_network.py b/rl_games/envs/test_network.py
index 7adfae90..51e28117 100644
--- a/rl_games/envs/test_network.py
+++ b/rl_games/envs/test_network.py
@@ -114,5 +114,52 @@ def load(self, params):
     def build(self, name, **kwargs):
         return TestNetWithAuxLoss(self.params, **kwargs)
 
+    def __call__(self, name, **kwargs):
+        return self.build(name, **kwargs)
+
+
+
+class SimpleNet(NetworkBuilder.BaseNetwork):
+    def __init__(self, params, **kwargs):
+        nn.Module.__init__(self)
+        actions_num = kwargs.pop('actions_num')
+        input_shape = kwargs.pop('input_shape')
+        num_inputs = input_shape[0]
+        self.actions_num = actions_num
+        self.central_value = params.get('central_value', False)
+        self.value_size = kwargs.pop('value_size', 1)
+        self.linear = torch.nn.Sequential(
+            nn.Linear(num_inputs, 512),
+            nn.SiLU(),
+            nn.Linear(512, 256),
+            nn.SiLU(),
+            nn.Linear(256, 128),
+            nn.SiLU(),
+            nn.Linear(128, actions_num + 1),
+        )
+        self.sigma = nn.Parameter(torch.zeros(actions_num, dtype=torch.float32), requires_grad=True)
+
+    def is_rnn(self):
+        return False
+
+    @torch.compile
+    def forward(self, obs):
+        obs = obs['obs']
+        x = self.linear(obs)
+        mu, value = torch.split(x, [self.actions_num, 1], dim=-1)
+        return mu, self.sigma.unsqueeze(0).expand(mu.size()[0], self.actions_num), value, None
+
+
+
+class SimpleNetBuilder(NetworkBuilder):
+    def __init__(self, **kwargs):
+        NetworkBuilder.__init__(self)
+
+    def load(self, params):
+        self.params = params
+
+    def build(self, name, **kwargs):
+        return SimpleNet(self.params, **kwargs)
+
     def __call__(self, name, **kwargs):
         return self.build(name, **kwargs)
\ No newline at end of file
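
Below is a quick, hypothetical smoke test for the newly registered 'simplenet' network. It is not part of the patch; it assumes the patch is applied and PyTorch 2.x is installed, and the action count (6), observation size (24), and batch size (4) are made-up example values.

    # Hypothetical usage sketch: build SimpleNet through its builder and run one forward pass.
    import torch
    from rl_games.envs.test_network import SimpleNetBuilder

    builder = SimpleNetBuilder()
    builder.load({})  # SimpleNet only reads optional keys such as 'central_value' from params

    # actions_num=6 and input_shape=(24,) are arbitrary example dimensions.
    net = builder.build('simplenet', actions_num=6, input_shape=(24,), value_size=1)

    obs = {'obs': torch.randn(4, 24)}          # batch of 4 random observations
    mu, sigma, value, _ = net(obs)             # forward returns (mu, sigma, value, states)
    print(mu.shape, sigma.shape, value.shape)  # torch.Size([4, 6]) torch.Size([4, 6]) torch.Size([4, 1])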