From 61d1bc47708d0b2c2c0807ddc9e0b03a114d709a Mon Sep 17 00:00:00 2001 From: ViktorM Date: Tue, 20 Aug 2024 12:19:17 -0700 Subject: [PATCH] Added convnext and vit backbones support. Added preprocessing. --- rl_games/algos_torch/network_builder.py | 52 ++++++++++++++----- .../atari/ppo_pong_envpool_backbone.yaml | 22 +++----- 2 files changed, 46 insertions(+), 28 deletions(-) diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index 3323dc84..4600f016 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -1071,7 +1071,20 @@ def build(self, name, **kwargs): return net -from torchvision import models +from torchvision import models, transforms + +def preprocess_image(image): + # Normalize the image using ImageNet's mean and standard deviation + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], # Mean of ImageNet dataset + std=[0.229, 0.224, 0.225] # Std of ImageNet dataset + ) + + # Apply the normalization + image = normalize(image) + + return image + class VisionBackboneBuilder(NetworkBuilder): def __init__(self, **kwargs): @@ -1103,6 +1116,7 @@ def __init__(self, params, **kwargs): self.cnn, self.cnn_output_size = self._build_backbone(input_shape, params['backbone']) + self.resize_transform = transforms.Resize((224, 224)) mlp_input_size = self.cnn_output_size + self.proprio_size if len(self.units) == 0: out_size = self.cnn_output_size @@ -1176,16 +1190,20 @@ def forward(self, obs_dict): if self.permute_input: obs = obs.permute((0, 3, 1, 2)) + if self.preprocess_image: + obs = preprocess_image(obs) + + # Assuming your input image is a tensor or PIL image, resize it to 224x224 + #obs = self.resize_transform(obs) + dones = obs_dict.get('dones', None) bptt_len = obs_dict.get('bptt_len', 0) states = obs_dict.get('rnn_states', None) out = obs out = self.cnn(out) - #print(out.shape) out = out.flatten(1) - #print(out.shape) - #print('AAAAAAAAAAAAAAAAAaaa') + out = self.flatten_act(out) if self.proprio_size > 0: @@ -1272,12 +1290,15 @@ def load(self, params): def _build_backbone(self, input_shape, backbone_params): backbone_type = backbone_params['type'] pretrained = backbone_params.get('pretrained', False) + self.preprocess_image = backbone_params.get('preprocess_image', False) if backbone_type == 'resnet18': backbone = models.resnet18(pretrained=pretrained, zero_init_residual=True) # norm_layer=nn.LayerNorm # Modify the first convolution layer to match input shape if needed + + # TODO: add low-res parameter backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False) - backbone.maxpool = nn.Identity() + #backbone.maxpool = nn.Identity() # if input_shape[0] != 3: # model.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False) # Remove the fully connected layer @@ -1289,16 +1310,21 @@ def _build_backbone(self, input_shape, backbone_params): backbone_output_size = backbone.classifier[2].in_features backbone.classifier = nn.Identity() - # Do we need it? - # backbone = nn.Sequential(*list(backbone.children())[:-1]) - elif backbone_type == 'vit_tiny_patch16_224': - backbone = models.vit_small_patch16_224(pretrained=pretrained) - backbone_output_size = backbone.heads.head.in_features - backbone.heads.head = nn.Identity() + # Modify the first convolutional layer to work with smaller resolutions + backbone.features[0][0] = nn.Conv2d( + in_channels=input_shape[0], + out_channels=backbone.features[0][0].out_channels, + kernel_size=3, # Reduce kernel size to 3x3 + stride=1, # Reduce stride to 1 to preserve spatial resolution + padding=1, # Add padding to preserve dimensions after convolution + bias=True # False + ) - # ViT outputs a single token, so no need to remove layers - # Is it true? + elif backbone_type == 'vit_b_16': + backbone = models.vision_transformer.vit_b_16(pretrained=pretrained) + backbone_output_size = backbone.heads.head.in_features + backbone.heads.head = nn.Identity() else: raise ValueError(f'Unknown backbone type: {backbone_type}') diff --git a/rl_games/configs/atari/ppo_pong_envpool_backbone.yaml b/rl_games/configs/atari/ppo_pong_envpool_backbone.yaml index ff63181f..7afe311f 100644 --- a/rl_games/configs/atari/ppo_pong_envpool_backbone.yaml +++ b/rl_games/configs/atari/ppo_pong_envpool_backbone.yaml @@ -16,21 +16,13 @@ params: value_shape: 1 space: discrete: - - # cnn: - # permute_input: False - # conv_depths: [16, 32, 32] - # activation: relu - # initializer: - # name: default - # regularizer: - # name: 'None' backbone: - type: resnet18 #convnext_tiny #vit_tiny_patch16_224 + type: resnet18 #convnext_tiny #vit_b_16 #resnet18 pretrained: True permute_input: False - freeze: True + freeze: False + preprocess_image: False args: zero_init_residual: True @@ -49,7 +41,7 @@ params: # units: 256 # layers: 1 config: - name: pong_resnet18_pretrained_2_mini_epoch_1e-4_linear_lr_norm_value_frozen + name: pong_resnet18_maxpool env_name: envpool reward_shaper: min_val: -1 @@ -71,7 +63,7 @@ params: e_clip: 0.2 clip_value: True save_best_after: 25 - save_frequency: 50 + save_frequency: 500 num_actors: 32 horizon_length: 64 minibatch_size: 512 @@ -81,14 +73,14 @@ params: kl_threshold: 0.01 use_diagnostics: True seq_length: 8 - max_epochs: 1000 + max_epochs: 2000 #weight_decay: 0.001 env_config: env_name: Pong-v5 has_lives: False use_dict_obs_space: False #True - + stack_num: 3 player: render: True games_num: 10