Add pretrained MAE weights, option to load checkpoints in ViT builder
ebsmothers committed Oct 4, 2023
1 parent 0793eb4 commit ff7d657
Showing 2 changed files with 11 additions and 0 deletions.
7 changes: 7 additions & 0 deletions torchmultimodal/models/masked_auto_encoder/model.py
@@ -20,6 +20,13 @@
)


MAE_MODEL_MAPPING = {
"vit_b16_image": "https://download.pytorch.org/models/multimodal/mae/mae_pretrained_vit_base.pth",
"vit_l16_image": "https://download.pytorch.org/models/multimodal/mae/mae_pretrained_vit_large.pth",
"vit_b16_audio": "https://download.pytorch.org/models/multimodal/audio_mae/audio_mae_pretrained_vit_base.pth",
}


class MAEOutput(NamedTuple):
encoder_output: Union[TransformerOutput, Tensor]
decoder_pred: Optional[Tensor] = None
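A minimal sketch (not from this diff) of fetching one of the mapped checkpoints directly, assuming each URL points at a plain state dict that torch.hub can download:

import torch

from torchmultimodal.models.masked_auto_encoder.model import MAE_MODEL_MAPPING

# Download (and cache) the MAE-pretrained ViT-B/16 image weights.
# Assumes the checkpoint is a plain state dict; adjust if it is wrapped.
state_dict = torch.hub.load_state_dict_from_url(
    MAE_MODEL_MAPPING["vit_b16_image"], map_location="cpu"
)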
4 changes: 4 additions & 0 deletions torchmultimodal/modules/encoders/vision_transformer.py
@@ -14,6 +14,7 @@
TransformerEncoder,
TransformerOutput,
)
from torchmultimodal.utils.common import load_module_from_url


class VisionTransformer(nn.Module):
@@ -148,6 +149,7 @@ def vision_transformer(
drop_path_rate: Optional[float] = None,
patch_drop_rate: Optional[Union[float, Tuple[float, float]]] = None,
pooler: Optional[nn.Module] = None,
ckpt_path: Optional[str] = None,
) -> VisionTransformer:
"""
Args:
@@ -198,6 +200,8 @@
vit = VisionTransformer(
embeddings=image_embedding, encoder=transformer_encoder, pooler=pooler
)
if ckpt_path:
load_module_from_url(vit, ckpt_path)
return vit
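
A hedged usage sketch of the new ckpt_path argument: building a ViT and loading MAE-pretrained weights in one call. The builder's required hyperparameters are not visible in this diff, so the ViT-B/16 values below are assumptions:

from torchmultimodal.models.masked_auto_encoder.model import MAE_MODEL_MAPPING
from torchmultimodal.modules.encoders.vision_transformer import vision_transformer

# Assumed ViT-B/16 settings; the exact required arguments of
# vision_transformer() are not shown in this diff.
vit = vision_transformer(
    patch_size=16,
    hidden_dim=768,
    dim_feedforward=3072,
    n_layer=12,
    n_head=12,
    ckpt_path=MAE_MODEL_MAPPING["vit_b16_image"],
)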

