
Unnecessary conv3 Layer Weights in DAE with MONAI v0.8.0 Leading to Compatibility Issues with v0.9.0+ #367

Open
simojens opened this issue Mar 31, 2024 · 0 comments

Description

In MONAI version 0.8.0, UnetResBlock creates a conv3 layer unconditionally, so the SwinUNETR architecture in the DAE SSL weights contains conv3 weights that are never used. Specifically, the conv3 layers in encoder2, encoder3, encoder4, and encoder10 are redundant, since these blocks have identical input and output feature dimensions. This becomes a problem when upgrading to MONAI v0.9.0+, where conv3 is only instantiated when needed: loading the old weights then fails because the expected layers are missing.

The MONAI v0.8.0 snippet below shows the unconditional creation (abridged):

    def __init__(self, ...):
        # Rest of code
        self.conv3 = get_conv_layer(
            spatial_dims, in_channels, out_channels, kernel_size=1, stride=stride, dropout=dropout, conv_only=True
        )
        self.downsample = in_channels != out_channels
        stride_np = np.atleast_1d(stride)
        if not np.all(stride_np == 1):
            self.downsample = True

    def forward(self, inp):
        residual = inp
        # Rest of code
        if self.downsample:
            residual = self.conv3(residual)
            residual = self.norm3(residual)

In MONAI v0.9.0+, the conv3 layer is created conditionally, based on the input and output channels. The DAE SSL weights include conv3 in all UnetResBlocks, even in blocks where it is never used. When loading the SSL weights for fine-tuning with the DAE/BTCV_Finetune repository under MONAI v0.9.0+, the code throws an error because the expected conv3 layers are never created.
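The version difference can be illustrated with a minimal stand-in (a sketch of the v0.9.0+ logic, not the actual MONAI source; the class and attribute values here are illustrative):

```python
import numpy as np

class ResBlockSketch:
    """Toy stand-in for UnetResBlock: conv3 exists only when downsampling."""

    def __init__(self, in_channels, out_channels, stride=1):
        self.downsample = in_channels != out_channels
        if not np.all(np.atleast_1d(stride) == 1):
            self.downsample = True
        if self.downsample:
            # In MONAI v0.9.0+ the 1x1 projection is only built here.
            self.conv3 = f"conv {in_channels}->{out_channels}, stride {stride}"

# encoder2-style block: same in/out channels, stride 1 -> no conv3 attribute,
# so copying a checkpoint's conv3 weight into it raises an error
assert not hasattr(ResBlockSketch(48, 48), "conv3")
# a channel-changing or strided block does get a conv3
assert hasattr(ResBlockSketch(48, 96), "conv3")
assert hasattr(ResBlockSketch(48, 48, stride=2), "conv3")
```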

The load_from method in DAE/BTCV_Finetune/swin_unetr_og.py includes the following lines:

    self.encoder1.layer.conv1.conv.weight.copy_(weights["model"]["encoder1.layer.conv1.conv.weight"])
    self.encoder1.layer.conv2.conv.weight.copy_(weights["model"]["encoder1.layer.conv2.conv.weight"])
    self.encoder1.layer.conv3.conv.weight.copy_(weights["model"]["encoder1.layer.conv3.conv.weight"])
  
    self.encoder2.layer.conv1.conv.weight.copy_(weights["model"]["encoder2.layer.conv1.conv.weight"])
    self.encoder2.layer.conv2.conv.weight.copy_(weights["model"]["encoder2.layer.conv2.conv.weight"])
    self.encoder2.layer.conv3.conv.weight.copy_(weights["model"]["encoder2.layer.conv3.conv.weight"])
  
    self.encoder3.layer.conv1.conv.weight.copy_(weights["model"]["encoder3.layer.conv1.conv.weight"])
    self.encoder3.layer.conv2.conv.weight.copy_(weights["model"]["encoder3.layer.conv2.conv.weight"])
    self.encoder3.layer.conv3.conv.weight.copy_(weights["model"]["encoder3.layer.conv3.conv.weight"])
  
    self.encoder4.layer.conv1.conv.weight.copy_(weights["model"]["encoder4.layer.conv1.conv.weight"])
    self.encoder4.layer.conv2.conv.weight.copy_(weights["model"]["encoder4.layer.conv2.conv.weight"])
    self.encoder4.layer.conv3.conv.weight.copy_(weights["model"]["encoder4.layer.conv3.conv.weight"])
  
    self.encoder10.layer.conv1.conv.weight.copy_(weights["model"]["encoder10.layer.conv1.conv.weight"])
    self.encoder10.layer.conv2.conv.weight.copy_(weights["model"]["encoder10.layer.conv2.conv.weight"])
    self.encoder10.layer.conv3.conv.weight.copy_(weights["model"]["encoder10.layer.conv3.conv.weight"])

Proposed solution

Avoid loading weights for conv3 in blocks encoder2, encoder3, encoder4, and encoder10, where the input and output feature sizes are the same, to ensure compatibility with newer MONAI versions.
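One way to do this is to guard each conv3 copy with a hasattr check, so the same loading code works whether or not the block was built with a conv3. A sketch (copy_conv3_if_present is a hypothetical helper, not part of the repository; the attribute layout follows the load_from snippet above):

```python
def copy_conv3_if_present(block, weights, prefix):
    # Hypothetical helper: only copy the conv3 weight when the block was
    # actually built with a conv3 layer (MONAI v0.9.0+ omits it when the
    # input and output channels match and the stride is 1).
    key = f"{prefix}.conv3.conv.weight"
    if hasattr(block, "conv3") and key in weights["model"]:
        block.conv3.conv.weight.copy_(weights["model"][key])
        return True
    return False
```

load_from could then call, e.g., copy_conv3_if_present(self.encoder2.layer, weights, "encoder2.layer") in place of the unconditional copy_ calls for encoder2, encoder3, encoder4, and encoder10.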

It could also be beneficial to remove the mentioned weights from DAE_SSL_WEIGHTS to avoid further confusion.
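A sketch of that cleanup (key names follow the load_from snippet above; loading and re-saving the checkpoint itself, e.g. via torch.load/torch.save, is omitted):

```python
def strip_redundant_conv3(state):
    """Drop conv3 weights for the blocks whose input and output feature
    sizes match (encoder2/3/4/10), so the checkpoint matches the layers
    actually created under MONAI v0.9.0+."""
    for i in (2, 3, 4, 10):
        state.pop(f"encoder{i}.layer.conv3.conv.weight", None)
    return state
```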
