Training SHAP on ViT for custom dataset #9

Open
mdabedr opened this issue Apr 25, 2024 · 3 comments

Comments

@mdabedr

mdabedr commented Apr 25, 2024

Hello, could you please provide some guidelines on how to obtain SHAP values for a fine-tuned vision transformer on a custom dataset?

I am fine-tuning google/vit-base-patch16-224-in21k with a classifier head on my own dataset. How can I get Shapley values for it?

@chanwkimlab
Collaborator

Hi, thanks for your interest in our work. Once you have your fine-tuned ViT classifier, the next step is to train a surrogate model: your ViT model is fine-tuned with random masking so that it can accommodate held-out image patches. The final step is to train an explainer model using our custom loss function. The scripts for each step are available here.
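To make the random-masking step more concrete, here is a minimal sketch of sampling a patch mask for a patch16, 224x224 ViT (14 x 14 = 196 patches). The function name `sample_patch_mask` and the sampling scheme (draw the number of kept patches uniformly, then choose which ones at random) are assumptions for illustration. The repository's scripts may sample masks differently, and how the mask is fed to the model is not shown here.

```python
import random

IMAGE_SIZE = 224  # input resolution of the patch16-224 ViT variants
PATCH_SIZE = 16   # side length of one square patch
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2  # 14 * 14 = 196


def sample_patch_mask(num_patches, rng):
    """Return a list of 0/1 flags: 1 keeps a patch, 0 holds it out.

    Hypothetical helper: the subset size is drawn uniformly from
    0..num_patches, then that many patch indices are picked at random.
    """
    num_kept = rng.randint(0, num_patches)  # both endpoints are inclusive
    kept = set(rng.sample(range(num_patches), num_kept))
    return [1 if index in kept else 0 for index in range(num_patches)]


rng = random.Random(0)  # seeded only so the sketch is reproducible
mask = sample_patch_mask(NUM_PATCHES, rng)
print(f"kept {sum(mask)} of {NUM_PATCHES} patches")
```

During surrogate training, a fresh mask of this kind would typically be drawn for every image in every batch, so that the model sees many different subsets of patches.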

@MirekJara

Hi @chanwkimlab,

I'm currently trying to use the scripts for training the surrogate model. Based on these lines of code in main.py (lines 63-70):

    if datasets == "MURA":
        datamodule = MURADataModule(**dataset_parameters)
    elif datasets == "ImageNette":
        datamodule = ImageNetteDataModule(**dataset_parameters)
    elif datasets == "Pet":
        datamodule = PetDataModule(**dataset_parameters)
    else:
        ValueError("Invalid 'datasets' configuration")

I assume that I need to implement Dataset and DataModule classes for my own datasets. Is that right, or is there a more straightforward way to do this? If you know of a repository that does this, that would also be a huge help.
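For illustration, wiring a custom datamodule into the dispatch quoted above might look like the sketch below. `MyDataModule`, the `"MyDataset"` key, and `build_datamodule` are hypothetical names, not part of the repository. Note that, as quoted, the `else` branch constructs a `ValueError` without raising it, so this sketch raises explicitly.

```python
class MyDataModule:
    """Hypothetical stand-in for a custom datamodule.

    A real one would follow the same interface as the repository's
    MURADataModule, ImageNetteDataModule, and PetDataModule.
    """

    def __init__(self, **dataset_parameters):
        self.dataset_parameters = dataset_parameters


# Map configuration names to datamodule classes. The repository's existing
# datamodules could be registered here in the same way.
DATAMODULES = {
    "MyDataset": MyDataModule,  # hypothetical configuration key
}


def build_datamodule(datasets, dataset_parameters):
    """Instantiate the datamodule selected by the 'datasets' setting."""
    try:
        datamodule_cls = DATAMODULES[datasets]
    except KeyError:
        # Raise explicitly so an unknown name fails immediately.
        raise ValueError(
            f"Invalid 'datasets' configuration: {datasets!r}"
        ) from None
    return datamodule_cls(**dataset_parameters)


datamodule = build_datamodule("MyDataset", {"batch_size": 32})
```

Adding a further `elif` branch to the existing chain would work equally well; the dictionary only keeps the lookup and the error handling in one place.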

Anyway, thanks in advance!

@chanwkimlab
Collaborator

You may need to slightly modify the dataset implementation to fit your data, because the current ViT Shapley implementation expects the __getitem__ function to return a dictionary in a specific format: {"images": img, "labels": label, "path": img_path}. See https://github.com/suinleelab/vit-shapley/blob/master/vit_shapley/datamodules/datasets/base_dataset.py#L233
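As a rough sketch of that output format, a custom map-style dataset could look like the following. The class name `FolderImageDataset`, the one-subfolder-per-class layout, and the pluggable `loader` argument are all assumptions made for illustration. The repository's base dataset may expect different types (for example, preprocessed tensors for "images"), so check it against the linked file.

```python
from pathlib import Path


class FolderImageDataset:
    """Hypothetical dataset for a root/<class_name>/<image file> layout.

    `loader` is any callable that turns a file path into an image. In
    practice it would open the file and apply the ViT preprocessing; it is
    left pluggable so this sketch has no third-party dependencies.
    """

    def __init__(self, root, loader, extensions=(".jpg", ".jpeg", ".png")):
        root = Path(root)
        self.loader = loader
        # Sorted class folder names give stable integer labels.
        self.classes = sorted(d.name for d in root.iterdir() if d.is_dir())
        self.samples = [
            (path, label)
            for label, name in enumerate(self.classes)
            for path in sorted((root / name).iterdir())
            if path.suffix.lower() in extensions
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        # Same keys as the format described above.
        return {"images": self.loader(path), "labels": label, "path": str(path)}
```

An object with `__len__` and `__getitem__` like this one can be wrapped in a PyTorch `DataLoader`; whether the repository's datamodules additionally require subclassing their base dataset is not something this sketch settles.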
