Is SlimSAM compatible with SAM? #4

Open
NielsRogge opened this issue Dec 25, 2023 · 5 comments

@NielsRogge

Hi,

Can we plug in the SlimSAM weights into SAM (by recombining the q, k, v weights into a single matrix per layer)?

If yes, then SlimSAM could be ported easily to the 🤗 hub. Currently I'm getting errors like:

size mismatch for vision_encoder.layers.7.mlp.lin2.bias: copying a param with shape torch.Size([168]) from checkpoint, the shape in current model is torch.Size([768]).
size mismatch for vision_encoder.layers.8.layer_norm1.weight: copying a param with shape torch.Size([168]) from checkpoint, the shape in current model is torch.Size([768]).

It seems the dimensions differ per layer because of the pruning. Is there any way to load such a state dict in PyTorch?
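
For context, here is a rough sketch of what I mean by recombining the weights, assuming SlimSAM keeps separate q/k/v projection tensors while SAM expects a fused qkv matrix (the key names below are hypothetical):

import torch

def fuse_qkv(state_dict, prefix):
    # Hypothetical key names: separate q/k/v projections get concatenated
    # into the single (3 * dim, dim) qkv matrix that SAM's attention expects.
    qkv_weight = torch.cat(
        [state_dict[f"{prefix}.{n}.weight"] for n in ("q", "k", "v")], dim=0)
    qkv_bias = torch.cat(
        [state_dict[f"{prefix}.{n}.bias"] for n in ("q", "k", "v")], dim=0)
    state_dict[f"{prefix}.qkv.weight"] = qkv_weight
    state_dict[f"{prefix}.qkv.bias"] = qkv_bias
    for n in ("q", "k", "v"):
        del state_dict[f"{prefix}.{n}.weight"]
        del state_dict[f"{prefix}.{n}.bias"]
    return state_dict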

@czg1225
Owner

czg1225 commented Dec 25, 2023

Hello @NielsRogge,
Thanks for the issue. It is currently difficult to load a state dict for SlimSAM because global pruning is used during model compression, which leaves a different intermediate dimension in each ViT block of the encoder. To make state dict loading easier, we expect to release SlimSAM versions that use local uniform pruning by next week. In the new SlimSAM models, the q, k, v weights will also be recombined into a single matrix. We will notify you once they are released. Thanks!
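
For example, printing a few per-layer shapes from a globally pruned checkpoint shows that every layer ends up with its own width, which is why a uniform-width SAM cannot load it (a rough sketch; the path is a placeholder and this assumes the checkpoint is a plain state dict):

import torch

# Placeholder path; the point is only to show that widths differ per layer.
state_dict = torch.load("checkpoints/SlimSAM-50.pth", map_location="cpu")
for name, tensor in state_dict.items():
    if name.endswith("mlp.lin2.bias") or name.endswith("norm1.weight"):
        print(name, tuple(tensor.shape))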

@NielsRogge
Author

Awesome, looking forward!

@czg1225
Owner

czg1225 commented Jan 7, 2024

Hello @NielsRogge,
Here are the locally pruned SlimSAM models:

The above models can be instantiated by running:

import torch
from segment_anything import sam_model_registry

# Pick a device for inference
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_type = 'vit_p50'
checkpoint = 'checkpoints/SlimSAM-50-uniform.pth'
SlimSAM_model = sam_model_registry[model_type](checkpoint=checkpoint)
SlimSAM_model.to(device)
SlimSAM_model.eval()
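
Once loaded, the model works with the usual segment_anything predictor API (a rough sketch; the image and point prompt are placeholders):

import numpy as np
from segment_anything import SamPredictor

predictor = SamPredictor(SlimSAM_model)
image = np.zeros((1024, 1024, 3), dtype=np.uint8)  # placeholder RGB image
predictor.set_image(image)
masks, scores, logits = predictor.predict(
    point_coords=np.array([[512, 512]]),  # placeholder point prompt (x, y)
    point_labels=np.array([1]),           # 1 = foreground point
)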

@NielsRogge
Author

NielsRogge commented Jan 7, 2024

Very cool, I just converted and pushed the checkpoints:

One can use them as explained in the Hugging Face docs.

Would you be interested in transferring these checkpoints to your account/the University of Singapore organization on the Hugging Face hub? Currently they are part of my account (nielsr).

Also, we can add some nice model cards (READMEs).

@czg1225
Owner

czg1225 commented Jan 9, 2024

Sure, thanks for your help!
https://huggingface.co/Zigeng/SlimSAM-uniform-50
https://huggingface.co/Zigeng/SlimSAM-uniform-77
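
These can be loaded with the standard SAM classes in 🤗 Transformers, following the usual SAM usage from the docs (a rough sketch; the image and point prompt are placeholders):

import torch
from PIL import Image
from transformers import SamModel, SamProcessor

model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-50")
processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-50")

image = Image.new("RGB", (1024, 1024))  # placeholder image
input_points = [[[450, 600]]]           # placeholder point prompt (x, y)
inputs = processor(image, input_points=input_points, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Resize the predicted masks back to the original image resolution
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
)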
