CoCa model implementation #517
Hi @seungkyuK thanks for creating the issue! Sorry for the delayed reply, I missed this one over the holidays.

You are right about (1): we need to change the input dim of the second pooler so that it matches the output dim of the first one.

On (2): this is an interesting case. Actually we went back and forth on whether to include the CLS token in CoCa's vision encoder at all, because it is not really clear from the paper that they use it. The open_clip implementation (which we compared against in #507) does use it, but they also only used global average pooling in their original implementation. However, our read of the pseudocode in Figure 2 of the paper was that they do not use CLS. As a result you'll see that most of our models default to not including the CLS embedding.

If you are setting it to be included, then you are correct that the CLS embedding is no longer used directly (and actually I think this is true regardless of whether we use cascaded or parallel attention poolers). One thing we could do is modify the vision encoder so that the CLS token is only added when it will actually be used.

For now I will at least make the change to fix (1) and set `n_queries=1` for the contrastive pooler.
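To sketch what I mean (the `SimpleAttentionPooler` class and its parameter names below are illustrative stand-ins, not our actual API; only the dim wiring and `n_queries=1` reflect the planned change):

```
import torch
from torch import nn

class SimpleAttentionPooler(nn.Module):
    """Illustrative stand-in for an attention pooler: learnable queries
    cross-attend over the input tokens (not the real torchmultimodal class)."""

    def __init__(self, input_dim: int, output_dim: int, n_head: int, n_queries: int):
        super().__init__()
        self.query = nn.Parameter(torch.randn(n_queries, output_dim))
        self.attn = nn.MultiheadAttention(
            output_dim, n_head, kdim=input_dim, vdim=input_dim, batch_first=True
        )
        self.ln = nn.LayerNorm(output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, seq_len, input_dim) -> (B, n_queries, output_dim)
        q = self.query.unsqueeze(0).expand(x.size(0), -1, -1)
        out, _ = self.attn(q, x, x)
        return self.ln(out)

dim, q_dim = 1024, 768  # e.g. coca_vit_l_14 vision dim and pooled dim

# Fix (1): the contrastive pooler consumes the captioning pooler's output,
# so its input dim must be q_dim rather than dim.
# Fix (2): n_queries=1 so the contrastive embedding can feed the loss directly.
captioning_pooler = SimpleAttentionPooler(dim, q_dim, n_head=8, n_queries=256)
contrastive_pooler = SimpleAttentionPooler(q_dim, q_dim, n_head=8, n_queries=1)

x = torch.randn(2, 1024, dim)  # (B, h*w, dim) vision features
cap = captioning_pooler(x)     # (B, 256, q_dim)
con = contrastive_pooler(cap)  # (B, 1, q_dim)
```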
Summary: A couple fixes to CoCa's attention pooling as pointed out in #517. Specifically, we need to change the input dim for the contrastive pooler to match the output dim from the captioning pooler in the case of cascaded attention pooling. We should also set `n_queries=1` for the contrastive pooler so that the pooled embeddings can be directly fed into contrastive loss (after appropriate normalization).

Pull Request resolved: #518

Test Plan:

```
import torch
from torchmultimodal.models.coca.coca_model import coca_vit_l_14

model = coca_vit_l_14()
bs, c, h, w, seq_len, vocab_size = 2, 3, 224, 224, 77, 49408
images = torch.randn(bs, c, h, w)
texts = torch.randint(0, vocab_size, (bs, seq_len))
out = model(images, texts)
print(out.image_pooled_output.shape, out.multimodal_embeddings.shape)
...
torch.Size([2, 1, 768]) torch.Size([2, 76, 49408])
```

Add a new unit test:

```
python -m pytest -v tests/models/coca/test_coca_model.py
...
===== 4 passed in 3.18s ======
```

Reviewed By: pbontrager

Differential Revision: D52523771

Pulled By: ebsmothers

fbshipit-source-id: 7c0197605e478ae6e3204f1ec0ab2e6adbf2377e
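As a usage note, with `n_queries=1` the pooled image embedding can be squeezed and normalized directly into a CLIP-style contrastive loss. A minimal sketch (the text-side embedding here is a random placeholder, not an actual field of the model output):

```
import torch
import torch.nn.functional as F

image_pooled = torch.randn(2, 1, 768)  # stands in for out.image_pooled_output
text_pooled = torch.randn(2, 768)      # placeholder for a text-side pooled embedding

img = F.normalize(image_pooled.squeeze(1), dim=-1)  # (B, 768), unit norm
txt = F.normalize(text_pooled, dim=-1)              # (B, 768), unit norm
logits = img @ txt.t()                              # (B, B) cosine similarities
labels = torch.arange(logits.size(0))               # matched pairs on the diagonal
loss = F.cross_entropy(logits, labels)              # image-to-text direction
```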
🚀 The feature, motivation and pitch
Thank you for your awesome work!
I have some questions about the CoCa model implementation.
In */multimodal/torchmultimodal/models/coca/coca_model.py, it seems we can choose between using CascadedAttentionPooler and a single AttentionPooler.
However, when using CascadedAttentionPooler, the dimensions do not match at the second pooler.
For example, the vision feature extracted from the VisionEncoder has shape (B, h*w, dim).
It then has to pass through the vision pooler layers (`pooled_outputs = self.vision_pooler(image_embeddings)`), and when using CascadedAttentionPooler, `self.vision_pooler` consists of two sequential AttentionPooler layers.
After passing through the first AttentionPooler layer, the feature has shape (B, 256, q_dim), which does not match the LayerNorm in the second iteration, which expects `dim` rather than `q_dim`.
Is it okay if I modify the input dimension of the second AttentionPooler layer accordingly?
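To make the mismatch concrete, here is a minimal sketch with the pooler internals reduced to just a LayerNorm and a Linear (so the names and shapes are illustrative, not the actual implementation):

```
import torch
from torch import nn

dim, q_dim = 1024, 768

def make_pooler(input_dim: int, output_dim: int) -> nn.Module:
    # stand-in for AttentionPooler: the real one cross-attends with learned
    # queries, but the LayerNorm over the input is the part that fails
    return nn.Sequential(nn.LayerNorm(input_dim), nn.Linear(input_dim, output_dim))

# buggy construction: both poolers are built with input dim = dim
poolers = nn.ModuleList([make_pooler(dim, q_dim), make_pooler(dim, q_dim)])

x = torch.randn(2, 1024, dim)  # (B, h*w, dim) vision features
x = poolers[0](x)              # OK: (B, h*w, q_dim)
x = poolers[1](x)              # RuntimeError: LayerNorm expects last dim 1024, got 768
```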
Similarly, when using `vision_cls_token` with CascadedAttentionPooler, the vision feature has shape (B, h*w + 1(cls), dim) (e.g., (B, 1025, 768)).
At the vision pooler layer, it returns learnable tokens after cross-attention with the vision feature, with shape (B, 256, q_dim) for each of `captioning_image_embeddings` and `contrastive_image_embeddings`, respectively.
If you intended not to use the visual features directly, is it necessary to add the `cls_token` at the initial stage?
In other words, what is the purpose of prepending the `cls_token` to the visual features if it is never used directly?
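For concreteness, a small sketch of the token layout I mean (shapes as in the example above; the CLS handling here is illustrative, not the actual encoder code):

```
import torch

B, hw, dim = 2, 1024, 768
patches = torch.randn(B, hw, dim)                   # patch embeddings from the ViT
cls_token = torch.randn(B, 1, dim)                  # learnable CLS embedding (illustrative)
features = torch.cat([cls_token, patches], dim=1)   # (B, 1025, dim)

# The attention poolers cross-attend over all 1025 tokens at once; nothing
# downstream indexes features[:, 0], so the CLS token is never read directly.
print(features.shape)  # torch.Size([2, 1025, 768])
```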
Thank you again!
Alternatives
No response
Additional context
No response