From 39e0ff805f1ae3932a4416d41d767924ec48e694 Mon Sep 17 00:00:00 2001 From: ViktorM Date: Tue, 10 Sep 2024 11:54:39 -0700 Subject: [PATCH] Aux loss is now optional. Config fix. --- .../maniskill/maniskill_pickcube_vision.yaml | 4 +- rl_games/envs/maniskill.py | 7 ++-- rl_games/networks/vision_networks.py | 41 +++++++++++-------- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml b/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml index 18f3039c..6b687393 100644 --- a/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml +++ b/rl_games/configs/maniskill/maniskill_pickcube_vision.yaml @@ -55,7 +55,7 @@ params: concat_output: True config: - name: PickCube_RGB_resnet18_LSTM_norm_embedding_64envs_auxloss + name: PickCube_RGB_resnet18_LSTM_norm_embedding_128envs_2e-4_linear_lr_first_layer_retrain env_name: maniskill reward_shaper: scale_value: 1.0 @@ -72,7 +72,7 @@ params: scale_value: 1.0 gamma: 0.99 tau : 0.95 - learning_rate: 1e-4 + learning_rate: 2e-4 lr_schedule: linear kl_threshold: 0.008 max_epochs: 20000 diff --git a/rl_games/envs/maniskill.py b/rl_games/envs/maniskill.py index 40e86b5c..c744907a 100644 --- a/rl_games/envs/maniskill.py +++ b/rl_games/envs/maniskill.py @@ -66,8 +66,9 @@ def observation(self, observation: Dict): # print("Observation:", observation.keys()) # for key, value in observation.items(): # print(key, value.keys()) - aux_target = observation['extra']['aux_target'] - del observation['extra']['aux_target'] + if self.aux_loss: + aux_target = observation['extra']['aux_target'] + del observation['extra']['aux_target'] # print("Input Obs:", observation.keys()) # print("Input Obs Agent:", observation['agent'].keys()) # print("Input Obs Extra:", observation['extra'].keys()) @@ -109,7 +110,7 @@ def __init__(self, config_name, num_envs, **kwargs): # an observation type and space, see https://maniskill.readthedocs.io/en/latest/user_guide/concepts/observation.html for
details self.obs_mode = kwargs.pop('obs_mode', 'state') # can be one of ['pointcloud', 'rgbd', 'state_dict', 'state'] - self.aux_loss = kwargs.pop('aux_loss', True) + self.aux_loss = kwargs.pop('aux_loss', False) # a controller type / action space, see https://maniskill.readthedocs.io/en/latest/user_guide/concepts/controllers.html for a full list # can be one of ['pd_ee_delta_pose', 'pd_ee_delta_pos', 'pd_joint_delta_pos', 'arm_pd_joint_pos_vel'] diff --git a/rl_games/networks/vision_networks.py b/rl_games/networks/vision_networks.py index ad503076..d4d136b3 100644 --- a/rl_games/networks/vision_networks.py +++ b/rl_games/networks/vision_networks.py @@ -70,12 +70,14 @@ def __init__(self, params, **kwargs): } self.mlp = self._build_mlp(**mlp_args) - - self.aux_loss_linear = nn.Linear(out_size, self.target_shape) - self.aux_loss_map = { - 'aux_dist_loss': None - } + # TODO: implement for Impala + self.aux_loss_map = None + if self.use_aux_loss: + self.aux_loss_linear = nn.Linear(out_size, self.target_shape) + self.aux_loss_map = { + 'aux_dist_loss': None + } self.value = self._build_value_layer(out_size, self.value_size) self.value_act = self.activations_factory.create(self.value_activation) @@ -283,9 +285,13 @@ def __init__(self, params, **kwargs): print('full_input_shape: ', full_input_shape) - self.target_key = 'aux_target' - self.target_shape = full_input_shape[self.target_key] - print("Target shape: ", self.target_shape) + self.use_aux_loss = kwargs.pop('use_aux_loss', False) + + if self.use_aux_loss: + self.target_key = 'aux_target' + if 'aux_target' in full_input_shape: + self.target_shape = full_input_shape[self.target_key] + print("Target shape: ", self.target_shape) print("Observations shape: ", full_input_shape) @@ -341,11 +347,12 @@ def __init__(self, params, **kwargs): self.mlp = self._build_mlp(**mlp_args) - self.aux_loss_linear = nn.Linear(out_size, self.target_shape[0]) - - self.aux_loss_map = { - 'aux_dist_loss': None - } + self.aux_loss_map = None + 
if self.use_aux_loss: + self.aux_loss_linear = nn.Linear(out_size, self.target_shape) + self.aux_loss_map = { + 'aux_dist_loss': None + } self.value = self._build_value_layer(out_size, self.value_size) self.value_act = self.activations_factory.create(self.value_activation) @@ -392,7 +399,8 @@ def forward(self, obs_dict): else: obs = obs_dict['obs'] - target_obs = obs_dict['obs'][self.target_key] + if self.use_aux_loss: + target_obs = obs_dict['obs'][self.target_key] # print('obs.min(): ', obs.min()) # print('obs.max(): ', obs.max()) @@ -452,8 +460,9 @@ def forward(self, obs_dict): value = self.value_act(self.value(out)) - y = self.aux_loss_linear(out) - self.aux_loss_map['aux_dist_loss'] = torch.nn.functional.mse_loss(y, target_obs) + if self.use_aux_loss: + y = self.aux_loss_linear(out) + self.aux_loss_map['aux_dist_loss'] = torch.nn.functional.mse_loss(y, target_obs) if self.is_discrete: logits = self.logits(out)