From 1fd96dcbc3eac99860b0de4b63dcb08a616fa57e Mon Sep 17 00:00:00 2001 From: Peng Chen Date: Thu, 12 Oct 2023 15:45:45 -0700 Subject: [PATCH] Port BLIP-2 entire model to OSS folder (#487) Summary: Pull Request resolved: https://github.com/facebookresearch/multimodal/pull/487 as title Reviewed By: ebsmothers Differential Revision: D50236403 fbshipit-source-id: 06fab6bb9324d0425679394f85935c6cd4ca373f --- tests/models/blip2/__init__.py | 5 + tests/models/blip2/test_blip2.py | 137 ++++++ tests/models/blip2/test_qformer_layers.py | 452 ++++++++++++++++++ tests/models/blip2/test_qformer_model.py | 414 ++++++++++++++++ tests/models/blip2/test_qformer_utils.py | 77 +++ tests/modules/losses/test_blip2_loss.py | 331 +++++++++++++ torchmultimodal/models/blip2/__init__.py | 5 + torchmultimodal/models/blip2/blip2.py | 157 ++++++ .../models/blip2/qformer_layers.py | 387 +++++++++++++++ torchmultimodal/models/blip2/qformer_model.py | 294 ++++++++++++ torchmultimodal/models/blip2/qformer_utils.py | 71 +++ .../modules/losses/blip2_losses.py | 360 ++++++++++++++ 12 files changed, 2690 insertions(+) create mode 100644 tests/models/blip2/__init__.py create mode 100644 tests/models/blip2/test_blip2.py create mode 100644 tests/models/blip2/test_qformer_layers.py create mode 100644 tests/models/blip2/test_qformer_model.py create mode 100644 tests/models/blip2/test_qformer_utils.py create mode 100644 tests/modules/losses/test_blip2_loss.py create mode 100644 torchmultimodal/models/blip2/__init__.py create mode 100644 torchmultimodal/models/blip2/blip2.py create mode 100644 torchmultimodal/models/blip2/qformer_layers.py create mode 100644 torchmultimodal/models/blip2/qformer_model.py create mode 100644 torchmultimodal/models/blip2/qformer_utils.py create mode 100644 torchmultimodal/modules/losses/blip2_losses.py diff --git a/tests/models/blip2/__init__.py b/tests/models/blip2/__init__.py new file mode 100644 index 000000000..2e41cd717 --- /dev/null +++ b/tests/models/blip2/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/models/blip2/test_blip2.py b/tests/models/blip2/test_blip2.py new file mode 100644 index 000000000..e9a294c44 --- /dev/null +++ b/tests/models/blip2/test_blip2.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import pytest +import torch +import torch.nn as nn +from tests.test_utils import assert_expected, init_weights_with_constant +from torchmultimodal.models.blip2.blip2 import BLIP2 +from torchmultimodal.models.blip2.qformer_model import QformerForCLM +from torchmultimodal.modules.encoders.vision_transformer import VisionTransformer +from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings +from torchmultimodal.modules.layers.transformer import TransformerEncoder + + +@pytest.fixture +def dim_q(): + return 4 + + +@pytest.fixture +def dim_kv(): + return 2 + + +@pytest.fixture +def dim_feedforward(): + return 6 + + +@pytest.fixture +def num_hidden_layers(): + return 2 + + +@pytest.fixture +def num_heads(): + return 2 + + +@pytest.fixture +def vocab_size(): + return 20 + + +@pytest.fixture +def qformer_model_for_clm( + dim_q, + dim_kv, + dim_feedforward, + num_hidden_layers, + num_heads, + vocab_size, +): + qformer_for_clm = QformerForCLM( + dim_q=dim_q, + dim_kv=dim_kv, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=0.0, + dropout=0.0, + num_hidden_layers=num_hidden_layers, + max_position_embeddings=512, + vocab_size=vocab_size, + ) + return qformer_for_clm + + +@pytest.fixture +def vit(): + embedding = PatchEmbeddings(image_size=2, patch_size=1, hidden_size=2) + encoder = TransformerEncoder( + n_layer=1, + d_model=2, + n_head=1, + dim_feedforward=1, + activation=nn.GELU, + norm_first=True, + final_layer_norm_eps=1e-5, + ) + image_encoder = VisionTransformer( + embeddings=embedding, + encoder=encoder, + ) + init_weights_with_constant(image_encoder) + image_encoder.eval() + return image_encoder + + +@pytest.fixture +def blip2(dim_q, dim_kv, qformer_model_for_clm, vit): + blip2 = BLIP2( + dim_q=dim_q, + image_encoder_embedding_dim=dim_kv, + qformer=qformer_model_for_clm, + vision_encoder=vit, + embedding_dim=4, + decoder_bos_token_id=19, + ) + init_weights_with_constant(blip2) + blip2.eval() + return blip2 + + +@pytest.fixture +def attn_mask(): + return torch.Tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0]]) + + +class TestBLIP2: + def test_blip2(self, blip2, attn_mask): + image = torch.ones(2, 3, 2, 2) + input_ids = torch.ones(2, 4).long() + output = blip2(image, input_ids, attn_mask) + assert_expected( + output.image_features, torch.ones([2, 32, 4]) * 0.5, rtol=0, atol=1e-4 + ) + assert_expected( + output.text_features, torch.ones([2, 4]) * 0.5, rtol=0, atol=1e-4 + ) + assert_expected( + output.image_embeddings, torch.ones([2, 5, 2]), rtol=0, atol=1e-4 + ) + assert_expected( + output.prediction_scores, torch.ones([2, 4, 20]) * 5, rtol=0, atol=1e-4 + ) + + def test_blip2_scripting(self, blip2, attn_mask): + image = torch.ones(2, 3, 2, 2) + input_ids = torch.ones(2, 4).long() + scripted_model = torch.jit.script(blip2) + actual = scripted_model(image, input_ids, attn_mask) + expected = blip2(image, input_ids, attn_mask) + assert_expected(actual, expected) diff --git a/tests/models/blip2/test_qformer_layers.py b/tests/models/blip2/test_qformer_layers.py new file mode 100644 index 000000000..adbb0e017 --- /dev/null +++ b/tests/models/blip2/test_qformer_layers.py @@ -0,0 +1,452 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import assert_expected, init_weights_with_constant, set_rng_seed +from torch import nn +from torchmultimodal.models.blip2.qformer_layers import ( + QformerEmbedding, + QformerEncoder, + QformerLayer, +) + + +@pytest.fixture(autouse=True) +def random(): + set_rng_seed(0) + + +class TestQformerWithMHA: + @pytest.fixture + def dim_q(self): + return 4 + + @pytest.fixture + def dim_kv(self): + return 2 + + @pytest.fixture + def dim_feedforward(self): + return 6 + + @pytest.fixture + def cross_attention_freq(self): + return 2 + + @pytest.fixture + def num_hidden_layers(self): + return 2 + + @pytest.fixture + def num_heads(self): + return 2 + + @pytest.fixture() + def input_ids(self): + return torch.LongTensor([[0, 1], [2, 3]]) + + @pytest.fixture() + def query_embeddings(self): + return torch.Tensor( + [ + [ + [0.6424, 0.6182, 0.5110, 0.7867], + [0.3907, 0.2057, 0.6909, 0.6334], + ], + [ + [0.6904, 0.4445, 0.4336, 0.4603], + [0.6318, 0.1163, 0.0340, 0.6871], + ], + ] + ) + + @pytest.fixture + def q(self): + return torch.Tensor([[[1, 2, 3, 1], [4, 3, 2, 1], [1, 1, 1, 1]]]) + + @pytest.fixture + def kv(self): + return torch.Tensor([[[3, 2], [1, 1]]]) + + @pytest.fixture + def current_key_value(self): + return torch.Tensor( + [ + [ + [[8.0, 8.0], [11.0, 11.0], [5.0, 5.0]], + [[8.0, 8.0], [11.0, 11.0], [5.0, 5.0]], + ] + ] + ) + + @pytest.fixture + def past_key_value(self): + return torch.Tensor( + [ + [ + [[7.0, 7.0], [9.0, 9.0], [4.0, 4.0]], + [[7.0, 7.0], [9.0, 9.0], [4.0, 4.0]], + ] + ] + ) + + @pytest.fixture + def past_key_values(self, past_key_value, num_hidden_layers): + past_key_values = [] + for i in range(num_hidden_layers): + past_key_values.append((past_key_value, past_key_value)) + return past_key_values + + @pytest.fixture + def qformer_layer_self_attention_only(self, dim_q, dim_feedforward, num_heads): + qformer_layer = QformerLayer( + dim_q=dim_q, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=0.0, + dropout=0.0, + has_cross_attention=False, + ) + init_weights_with_constant(qformer_layer) + qformer_layer.eval() + return qformer_layer + + @pytest.fixture + def qformer_layer_with_cross_attention( + self, + dim_q, + dim_kv, + dim_feedforward, + num_heads, + ): + qformer_layer = QformerLayer( + dim_q=dim_q, + dim_kv=dim_kv, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=0.0, + dropout=0.0, + activation=nn.ReLU, + has_cross_attention=True, + ) + init_weights_with_constant(qformer_layer) + # modify query feedforward params to test cross attention case with different query lengths + init_weights_with_constant(qformer_layer.feedforward_query, 2.0) + init_weights_with_constant(qformer_layer.feedforward_layernorm_query, 2.0) + qformer_layer.eval() + return qformer_layer + + @pytest.fixture + def qformer_encoder( + self, + dim_q, + dim_kv, + dim_feedforward, + cross_attention_freq, + num_hidden_layers, + num_heads, + ): + qformer_encoder = QformerEncoder( + dim_q=dim_q, + dim_kv=dim_kv, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=0.0, + dropout=0.0, + cross_attention_freq=cross_attention_freq, + num_hidden_layers=num_hidden_layers, + ) + init_weights_with_constant(qformer_encoder) + qformer_encoder.eval() + return qformer_encoder + + def test_qformer_layer_self_attention_only( + self, qformer_layer_self_attention_only, current_key_value, past_key_value, q + ): + actual = qformer_layer_self_attention_only( + q, past_key_value=(past_key_value, past_key_value), use_cache=True + ) + expected = torch.Tensor( + [ + [ + [0.0955, 1.3015, 2.5076, 0.0955], + [2.3416, 1.4472, 0.5528, -0.3416], + [1.0000, 1.0000, 1.0000, 1.0000], + ] + ] + ) + assert_expected(actual[0], expected, rtol=0, atol=1e-4) + assert_expected( + actual[1][0], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][1], + torch.cat([past_key_value, current_key_value], dim=2), + ) + + def test_qformer_layer_with_cross_attention_only_query( + self, + qformer_layer_with_cross_attention, + current_key_value, + past_key_value, + q, + kv, + ): + # test with query length < attn_residual.shape[1] + actual = qformer_layer_with_cross_attention( + q, + kv, + past_key_value=(past_key_value, past_key_value), + query_length=2, + use_cache=True, + ) + expected = torch.Tensor( + [ + [ + [0.1909, 2.6030, 5.0151, 0.1909], + [4.6833, 2.8944, 1.1056, -0.6833], + [1.0000, 1.0000, 1.0000, 1.0000], + ] + ] + ) + assert_expected(actual[0], expected, rtol=0, atol=1e-4) + assert_expected( + actual[1][0], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][1], + torch.cat([past_key_value, current_key_value], dim=2), + ) + + def test_qformer_layer_with_cross_attention_query_and_text( + self, + qformer_layer_with_cross_attention, + current_key_value, + past_key_value, + q, + kv, + ): + # test with query length >= attn_residual.shape[1] + actual = qformer_layer_with_cross_attention( + q, + kv, + past_key_value=(past_key_value, past_key_value), + query_length=3, + use_cache=True, + ) + expected = torch.Tensor( + [ + [ + [0.1909, 2.6030, 5.0151, 0.1909], + [4.6833, 2.8944, 1.1056, -0.6833], + [2.0000, 2.0000, 2.0000, 2.0000], + ] + ] + ) + assert_expected(actual[0], expected, rtol=0, atol=1e-4) + assert_expected( + actual[1][0], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][1], + torch.cat([past_key_value, current_key_value], dim=2), + ) + + def test_qformer_encoder( + self, + qformer_encoder, + past_key_values, + current_key_value, + past_key_value, + q, + kv, + ): + actual = qformer_encoder( + q, kv, past_key_values=past_key_values, query_length=2, use_cache=True + ) + expected_hidden_state = torch.Tensor( + [ + [ + [0.0955, 1.3015, 2.5076, 0.0955], + [2.3416, 1.4472, 0.5528, -0.3416], + [1.0000, 1.0000, 1.0000, 1.0000], + ] + ] + ) + expected_key_value = torch.Tensor( + [ + [ + [[5.0, 5.0], [5.0, 5.0], [5.0, 5.0]], + [[5.0, 5.0], [5.0, 5.0], [5.0, 5.0]], + ] + ] + ) + assert_expected(actual[0], expected_hidden_state, rtol=0, atol=1e-4) + assert_expected( + actual[1][0][0], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][0][1], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][1][0], + torch.cat([past_key_value, expected_key_value], dim=2), + ) + assert_expected( + actual[1][1][1], + torch.cat([past_key_value, expected_key_value], dim=2), + ) + + def test_layer_scripting( + self, + qformer_layer_with_cross_attention, + current_key_value, + past_key_value, + q, + kv, + ): + scripted_model = torch.jit.script(qformer_layer_with_cross_attention) + actual = scripted_model( + q, + kv, + past_key_value=(past_key_value, past_key_value), + query_length=3, + use_cache=True, + ) + expected = torch.Tensor( + [ + [ + [0.1909, 2.6030, 5.0151, 0.1909], + [4.6833, 2.8944, 1.1056, -0.6833], + [2.0000, 2.0000, 2.0000, 2.0000], + ] + ] + ) + assert_expected(actual[0], expected, rtol=0, atol=1e-4) + assert_expected( + actual[1][0], + torch.cat([past_key_value, current_key_value], dim=2), + ) + assert_expected( + actual[1][1], + torch.cat([past_key_value, current_key_value], dim=2), + ) + + def test_encoder_scripting( + self, + qformer_encoder, + past_key_values, + current_key_value, + past_key_value, + q, + kv, + ): + scripted_encoder = torch.jit.script(qformer_encoder) + actual = scripted_encoder( + q, kv, past_key_values=past_key_values, query_length=2, use_cache=True + ) + expected = qformer_encoder( + q, kv, past_key_values=past_key_values, query_length=2, use_cache=True + ) + assert_expected(actual[0], expected[0]) + assert_expected(actual[1], expected[1]) + assert len(actual) == len(expected) + + @pytest.fixture + def qformer_emb(self, dim_q): + return QformerEmbedding( + embedding_dim=dim_q, + max_position_embeddings=512, + vocab_size=20, + ) + + def test_qformer_embedding(self, input_ids, query_embeddings, qformer_emb): + actual = qformer_emb( + input_ids=input_ids, + query_embeddings=query_embeddings, + ) + expected_value = torch.Tensor( + [ + [ + [0.0287, -0.2175, -1.3081, 1.4969], + [-0.4602, -1.4116, 1.0838, 0.7880], + [-0.0600, 1.3838, -1.4382, 0.1144], + [1.1554, 0.0435, 0.3865, -1.5855], + ], + [ + [1.7251, -0.5904, -0.6931, -0.4416], + [0.8989, -0.8530, -1.1327, 1.0868], + [0.8951, -1.1037, -0.8854, 1.0940], + [-0.0748, -0.2439, 1.5529, -1.2342], + ], + ] + ) + # expected dim [bsz, num_token, embed_dim] + assert_expected(actual, expected_value, atol=1e-4, rtol=1e-4) + + def test_qformer_embedding_empty_input_ids( + self, + query_embeddings, + qformer_emb, + ): + actual = qformer_emb( + query_embeddings=query_embeddings, + ) + expected_value = torch.Tensor( + [ + [ + [0.0287, -0.2175, -1.3081, 1.4969], + [-0.4602, -1.4116, 1.0838, 0.7880], + ], + [ + [1.7251, -0.5904, -0.6931, -0.4416], + [0.8989, -0.8530, -1.1327, 1.0868], + ], + ] + ) + assert_expected(actual, expected_value, atol=1e-4, rtol=1e-4) + + def test_qformer_embedding_empty_query_embs( + self, + input_ids, + qformer_emb, + ): + actual = qformer_emb( + input_ids=input_ids, + ) + expected_value = torch.Tensor( + [ + [ + [-0.0600, 1.3838, -1.4382, 0.1144], + [1.1554, 0.0435, 0.3865, -1.5855], + ], + [ + [0.8951, -1.1037, -0.8854, 1.0940], + [-0.0748, -0.2439, 1.5529, -1.2342], + ], + ] + ) + assert_expected(actual, expected_value, atol=1e-4, rtol=1e-4) + + def test_qformer_embedding_invalid_input( + self, + qformer_emb, + ): + with pytest.raises(ValueError): + qformer_emb() + + def test_embedding_scripting(self, input_ids, qformer_emb, query_embeddings): + scripted_emb = torch.jit.script(qformer_emb) + actual = scripted_emb(input_ids=input_ids, query_embeddings=query_embeddings) + assert_expected( + actual, qformer_emb(input_ids=input_ids, query_embeddings=query_embeddings) + ) diff --git a/tests/models/blip2/test_qformer_model.py b/tests/models/blip2/test_qformer_model.py new file mode 100644 index 000000000..ef5479d2a --- /dev/null +++ b/tests/models/blip2/test_qformer_model.py @@ -0,0 +1,414 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import assert_expected, init_weights_with_constant, set_rng_seed +from torch.nn import CrossEntropyLoss +from torchmultimodal.models.blip2.qformer_model import QformerForCLM, QformerModel + + +@pytest.fixture(autouse=True) +def random(): + set_rng_seed(0) + + +class TestQformerModel: + @pytest.fixture + def attn_mask(self): + return torch.Tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0]]) + + @pytest.fixture + def dim_q(self): + return 4 + + @pytest.fixture + def dim_kv(self): + return 2 + + @pytest.fixture + def dim_feedforward(self): + return 6 + + @pytest.fixture + def cross_attention_freq(self): + return 2 + + @pytest.fixture + def num_hidden_layers(self): + return 2 + + @pytest.fixture + def num_heads(self): + return 2 + + @pytest.fixture() + def input_ids(self): + return torch.LongTensor([[0, 1], [2, 3]]) + + @pytest.fixture + def vocab_size(self): + return 20 + + @pytest.fixture() + def query_embeddings(self): + return torch.Tensor( + [ + [ + [0.6424, 0.6182, 0.5110, 0.7867], + [0.3907, 0.2057, 0.6909, 0.6334], + ], + [ + [0.6904, 0.4445, 0.4336, 0.4603], + [0.6318, 0.1163, 0.0340, 0.6871], + ], + ] + ) + + @pytest.fixture + def past_key_value(self): + return torch.Tensor( + [ + [ + [[7.0, 7.0], [9.0, 9.0], [4.0, 4.0]], + [[7.0, 7.0], [9.0, 9.0], [4.0, 4.0]], + ], + [ + [[7.0, 7.0], [9.0, 9.0], [4.0, 4.0]], + [[7.0, 7.0], [9.0, 9.0], [4.0, 4.0]], + ], + ] + ) + + @pytest.fixture + def past_key_values(self, past_key_value, num_hidden_layers): + past_key_values = [] + for i in range(num_hidden_layers): + past_key_values.append((past_key_value, past_key_value)) + return past_key_values + + @pytest.fixture + def kv(self): + return torch.Tensor([[[3, 2], [1, 1]], [[3, 2], [1, 1]]]) + + @pytest.fixture + def labels(self): + labels = torch.ones([2, 2]).long() + return labels[:, 1:].contiguous() + + @pytest.fixture + def loss_fct(self): + return CrossEntropyLoss(reduction="mean", label_smoothing=0.1) + + @pytest.fixture + def qformer_model( + self, + dim_q, + dim_kv, + dim_feedforward, + cross_attention_freq, + num_hidden_layers, + num_heads, + vocab_size, + ): + qformer_model = QformerModel( + dim_q=dim_q, + dim_kv=dim_kv, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=0.0, + dropout=0.0, + num_hidden_layers=num_hidden_layers, + max_position_embeddings=512, + vocab_size=vocab_size, + query_length=2, + ) + init_weights_with_constant(qformer_model) + qformer_model.eval() + return qformer_model + + @pytest.fixture + def qformer_model_for_clm( + self, + dim_q, + dim_kv, + dim_feedforward, + cross_attention_freq, + num_hidden_layers, + num_heads, + vocab_size, + ): + qformer_for_clm = QformerForCLM( + dim_q=dim_q, + dim_kv=dim_kv, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=0.0, + dropout=0.0, + num_hidden_layers=num_hidden_layers, + max_position_embeddings=512, + vocab_size=vocab_size, + ) + init_weights_with_constant(qformer_for_clm) + qformer_for_clm.eval() + return qformer_for_clm + + def test_qformer_model_with_attn_mask( + self, + input_ids, + attn_mask, + qformer_model, + query_embeddings, + num_hidden_layers, + kv, + ): + actual = qformer_model( + input_ids=input_ids, + encoder_hidden_states=kv, + attention_mask=attn_mask, + query_embeds=query_embeddings, + use_cache=True, + ) + expected_hidden_states = torch.Tensor( + [ + [ + [1.0287, 0.7825, -0.3081, 2.4969], + [0.5398, -0.4116, 2.0838, 1.7880], + [1.0000, 1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000, 1.0000], + ], + [ + [2.7251, 0.4096, 0.3069, 0.5584], + [1.8989, 0.1470, -0.1327, 2.0868], + [1.0000, 1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000, 1.0000], + ], + ] + ) + assert_expected(actual[0], expected_hidden_states, atol=1e-4, rtol=1e-4) + + assert_expected(len(actual[1]), num_hidden_layers) + assert_expected(len(actual[1][0]), 2) # 2-element tuple includes key and value + assert_expected( + actual[1][0][0].shape, torch.Size([2, 2, 4, 2]) + ) # bsz x num_heads x seq_len x head_dim + expected_cached_values = torch.Tensor( + [ + [ + [ + [5.0000, 5.0000], + [5.0000, 5.0000], + [5.0000, 5.0000], + [5.0000, 5.0000], + ], + [ + [5.0000, 5.0000], + [5.0000, 5.0000], + [5.0000, 5.0000], + [5.0000, 5.0000], + ], + ], + [ + [ + [5.0000, 5.0000], + [5.0000, 5.0000], + [5.0000, 5.0000], + [5.0000, 5.0000], + ], + [ + [5.0000, 5.0000], + [5.0000, 5.0000], + [5.0000, 5.0000], + [5.0000, 5.0000], + ], + ], + ] + ) + assert_expected(actual[1][0][0], expected_cached_values, atol=1e-4, rtol=1e-4) + + def test_qformer_model_with_past_key_values( + self, + input_ids, + qformer_model, + query_embeddings, + num_hidden_layers, + kv, + past_key_values, + ): + actual = qformer_model( + input_ids=input_ids, + encoder_hidden_states=kv, + query_embeds=query_embeddings, + past_key_values=past_key_values, + use_cache=True, + ) + expected_hidden_states = torch.Tensor( + [ + [ + [1.0287, 0.7825, -0.3081, 2.4969], + [0.5398, -0.4116, 2.0838, 1.7880], + [1.0000, 1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000, 1.0000], + ], + [ + [2.7251, 0.4096, 0.3069, 0.5584], + [1.8989, 0.1470, -0.1327, 2.0868], + [1.0000, 1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000, 1.0000], + ], + ] + ) + assert_expected(actual[0], expected_hidden_states, atol=1e-4, rtol=1e-4) + + assert_expected(len(actual[1]), num_hidden_layers) + assert_expected(len(actual[1][0]), 2) # 2-element tuple includes key and value + assert_expected( + actual[1][0][0].shape, torch.Size([2, 2, 7, 2]) + ) # bsz x num_heads x seq_len x head_dim + expected_cached_values = torch.Tensor( + [ + [ + [ + [7.0, 7.0], + [9.0, 9.0], + [4.0, 4.0], + [5.0, 5.0], + [5.0, 5.0], + [5.0, 5.0], + [5.0, 5.0], + ], + [ + [7.0, 7.0], + [9.0, 9.0], + [4.0, 4.0], + [5.0, 5.0], + [5.0, 5.0], + [5.0, 5.0], + [5.0, 5.0], + ], + ], + [ + [ + [7.0, 7.0], + [9.0, 9.0], + [4.0, 4.0], + [5.0, 5.0], + [5.0, 5.0], + [5.0, 5.0], + [5.0, 5.0], + ], + [ + [7.0, 7.0], + [9.0, 9.0], + [4.0, 4.0], + [5.0, 5.0], + [5.0, 5.0], + [5.0, 5.0], + [5.0, 5.0], + ], + ], + ] + ) + assert_expected(actual[1][0][0], expected_cached_values, atol=1e-4, rtol=1e-4) + + def test_qformer_model_with_causal_mask( + self, + input_ids, + attn_mask, + kv, + qformer_model, + query_embeddings, + num_hidden_layers, + ): + actual = qformer_model( + input_ids=input_ids, + encoder_hidden_states=kv, + attention_mask=attn_mask, + query_embeds=query_embeddings, + use_cache=True, + use_causal_mask=True, + ) + expected_hidden_states = torch.Tensor( + [ + [ + [1.0287, 0.7825, -0.3081, 2.4969], + [0.5398, -0.4116, 2.0838, 1.7880], + [1.0000, 1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000, 1.0000], + ], + [ + [2.7251, 0.4096, 0.3069, 0.5584], + [1.8989, 0.1470, -0.1327, 2.0868], + [1.0000, 1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000, 1.0000], + ], + ] + ) + assert_expected(actual[0], expected_hidden_states, atol=1e-4, rtol=1e-4) + + def test_qformer_model_scripting( + self, qformer_model, input_ids, attn_mask, query_embeddings, kv + ): + scripted_model = torch.jit.script(qformer_model) + scripted_output = scripted_model( + input_ids=input_ids, + encoder_hidden_states=kv, + attention_mask=attn_mask, + query_embeds=query_embeddings, + use_cache=True, + ) + actual = qformer_model( + input_ids=input_ids, + encoder_hidden_states=kv, + attention_mask=attn_mask, + query_embeds=query_embeddings, + use_cache=True, + ) + assert_expected(scripted_output[0], actual[0], atol=1e-4, rtol=1e-4) + assert_expected(scripted_output[1], actual[1], atol=1e-4, rtol=1e-4) + + def test_qformer_for_clm( + self, + qformer_model_for_clm, + query_embeddings, + input_ids, + kv, + attn_mask, + labels, + loss_fct, + vocab_size, + ): + actual_pred = qformer_model_for_clm( + input_ids=input_ids, + encoder_hidden_states=kv, + attention_mask=attn_mask, + query_embeds=query_embeddings, + use_cache=False, + ) + expected = torch.ones([2, 2, 20]) * 5 + assert_expected(actual_pred, expected, atol=1e-4, rtol=1e-4) + + def test_qformer_for_clm_scripting( + self, + qformer_model_for_clm, + query_embeddings, + input_ids, + kv, + attn_mask, + labels, + loss_fct, + vocab_size, + ): + scripted_model = torch.jit.script(qformer_model_for_clm) + actual_pred = scripted_model( + input_ids=input_ids, + encoder_hidden_states=kv, + attention_mask=attn_mask, + query_embeds=query_embeddings, + use_cache=False, + ) + expected = torch.ones([2, 2, 20]) * 5 + assert_expected(actual_pred, expected, atol=1e-4, rtol=1e-4) diff --git a/tests/models/blip2/test_qformer_utils.py b/tests/models/blip2/test_qformer_utils.py new file mode 100644 index 000000000..0fb074455 --- /dev/null +++ b/tests/models/blip2/test_qformer_utils.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +from tests.test_utils import assert_expected +from torch import Tensor +from torchmultimodal.models.blip2.qformer_utils import get_causal_mask + + +class TestExtendedAttnMaskForDecoder: + @pytest.fixture + def attention_mask(self): + return Tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 1.0]]) + + @pytest.fixture + def input_shape(self): + return (2, 2) + + def test_extended_attention_mask(self, attention_mask): + actual_mask = get_causal_mask(attention_mask, attention_mask.shape) + expected = Tensor( + [ + [ + [1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0], + ], + [ + [1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0], + ], + ] + ) + assert_expected(actual_mask, expected, rtol=0, atol=1e-4) + + def test_extended_attention_mask_diff_input_size(self, attention_mask, input_shape): + actual_mask = get_causal_mask( + attention_mask, + input_shape, + ) + expected = Tensor( + Tensor( + [ + [[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]], + ] + ) + ) + assert_expected(actual_mask, expected, rtol=0, atol=1e-4) + + def test_extended_attention_mask_with_query_embs(self, attention_mask, input_shape): + actual_mask = get_causal_mask(attention_mask, input_shape, has_query=True) + expected = Tensor( + Tensor( + [ + [ + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0], + ], + [ + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0], + ], + ] + ) + ) + assert_expected(actual_mask, expected, rtol=0, atol=1e-4) diff --git a/tests/modules/losses/test_blip2_loss.py b/tests/modules/losses/test_blip2_loss.py new file mode 100644 index 000000000..3c8dd9daf --- /dev/null +++ b/tests/modules/losses/test_blip2_loss.py @@ -0,0 +1,331 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import chain + +import pytest +import torch +from tests.test_utils import ( + assert_expected, + gpu_test, + init_distributed_on_file, + init_weights_with_constant, + with_temp_files, +) +from torch import distributed as dist, multiprocessing as mp, nn, optim +from torchmultimodal.models.blip2.blip2 import BLIP2, Blip2Output +from torchmultimodal.models.blip2.qformer_model import QformerForCLM +from torchmultimodal.modules.encoders.vision_transformer import VisionTransformer +from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings +from torchmultimodal.modules.layers.transformer import TransformerEncoder +from torchmultimodal.modules.losses.blip2_losses import Blip2Phase1Loss + + +@pytest.fixture +def dim_q(): + return 4 + + +@pytest.fixture +def dim_kv(): + return 2 + + +@pytest.fixture +def dim_feedforward(): + return 6 + + +@pytest.fixture +def num_hidden_layers(): + return 2 + + +@pytest.fixture +def num_heads(): + return 2 + + +@pytest.fixture +def vocab_size(): + return 20 + + +@pytest.fixture +def vit(): + embedding = PatchEmbeddings(image_size=2, patch_size=1, hidden_size=2) + encoder = TransformerEncoder( + n_layer=1, + d_model=2, + n_head=1, + dim_feedforward=1, + activation=nn.GELU, + norm_first=True, + final_layer_norm_eps=1e-5, + ) + image_encoder = VisionTransformer( + embeddings=embedding, + encoder=encoder, + ) + init_weights_with_constant(image_encoder) + image_encoder.eval() + return image_encoder + + +class TestBLIP2Stage1Loss: + @pytest.fixture + def images(self): + return torch.ones(4, 3, 2, 2) + + @pytest.fixture + def input_ids(self): + return torch.ones(4, 4).long() + + @pytest.fixture + def all_attn_mask(self): + return torch.ones([4, 4]) + + @pytest.fixture + def global_batch_size(self): + return 4 + + @pytest.fixture + def qformer_model_for_clm( + self, + dim_q, + dim_kv, + dim_feedforward, + num_hidden_layers, + num_heads, + vocab_size, + ): + qformer_for_clm = QformerForCLM( + dim_q=dim_q, + dim_kv=dim_kv, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=0.0, + dropout=0.0, + num_hidden_layers=num_hidden_layers, + max_position_embeddings=512, + vocab_size=vocab_size, + ) + return qformer_for_clm + + @pytest.fixture + def blip2_output(self): + return Blip2Output( + image_embeddings=torch.ones([4, 5, 2]), + image_features=torch.ones([4, 32, 4]) * 0.5, + image_qformer_output=torch.ones([4, 32, 4]) * 0.5, + text_features=torch.ones([4, 4]) * 0.5, + prediction_scores=torch.ones([4, 4, 20]) * 5, + ) + + @pytest.fixture + def blip2(self, dim_q, dim_kv, qformer_model_for_clm, vit): + blip2 = BLIP2( + dim_q=dim_q, + image_encoder_embedding_dim=dim_kv, + qformer=qformer_model_for_clm, + vision_encoder=vit, + embedding_dim=4, + decoder_bos_token_id=19, + ) + init_weights_with_constant(blip2) + blip2.eval() + return blip2 + + def test_local_loss(self, all_attn_mask, blip2_output, blip2, dim_q, input_ids): + blip2_loss = Blip2Phase1Loss(dim_q=dim_q) + init_weights_with_constant(blip2_loss) + local_loss = blip2_loss( + model_output=blip2_output, + blip2=blip2, + input_ids=input_ids, + attention_mask=all_attn_mask, + ) + assert_expected(local_loss.total_loss.item(), 5.07517, rtol=0, atol=1e-4) + + def test_local_itc_only_loss( + self, all_attn_mask, blip2_output, blip2, dim_q, input_ids + ): + blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itm=False, enable_itg=False) + init_weights_with_constant(blip2_loss) + local_loss = blip2_loss( + model_output=blip2_output, + blip2=blip2, + input_ids=input_ids, + attention_mask=all_attn_mask, + ) + assert_expected(local_loss.total_loss.item(), 1.38629, rtol=0, atol=1e-4) + + def test_local_itm_only_loss( + self, all_attn_mask, blip2_output, blip2, dim_q, input_ids + ): + blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itc=False, enable_itg=False) + init_weights_with_constant(blip2_loss) + local_loss = blip2_loss( + model_output=blip2_output, + blip2=blip2, + input_ids=input_ids, + attention_mask=all_attn_mask, + ) + assert_expected(local_loss.total_loss.item(), 0.69315, rtol=0, atol=1e-4) + + def test_local_itg_only_loss( + self, all_attn_mask, blip2_output, blip2, dim_q, input_ids + ): + blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itc=False, enable_itm=False) + init_weights_with_constant(blip2_loss) + local_loss = blip2_loss( + model_output=blip2_output, + blip2=blip2, + input_ids=input_ids, + attention_mask=all_attn_mask, + ) + assert_expected(local_loss.total_loss.item(), 2.9957, rtol=0, atol=1e-4) + + def test_invalid_loss_input(self): + with pytest.raises(ValueError): + Blip2Phase1Loss( + dim_q=dim_q, enable_itc=False, enable_itm=False, enable_itg=False + ) + + @staticmethod + def _model_worker( + gpu_id: int, + sync_file: str, + world_size: int, + global_batch_size: int, + all_images: torch.Tensor, + all_input_ids: torch.Tensor, + all_attn_mask: torch.Tensor, + blip2_output: Blip2Output, + blip2: nn.Module, + dim_q=dim_q, + ): + init_distributed_on_file( + world_size=world_size, gpu_id=gpu_id, sync_file=sync_file + ) + assert global_batch_size % world_size == 0 + local_batch_size = global_batch_size // world_size + all_attn_mask = torch.ones([4, 4]) + + # Split inputs across GPUs + local_images = torch.split(all_images, local_batch_size)[gpu_id].cuda(gpu_id) + local_input_ids = torch.split(all_input_ids, local_batch_size)[gpu_id].cuda( + gpu_id + ) + local_attn_mask = torch.split(all_attn_mask, local_batch_size)[gpu_id].cuda( + gpu_id + ) + assert blip2_output.text_features is not None + assert blip2_output.prediction_scores is not None + local_blip2_output = Blip2Output( + image_embeddings=torch.split( + blip2_output.image_embeddings, local_batch_size + )[gpu_id].cuda(gpu_id), + image_features=torch.split(blip2_output.image_features, local_batch_size)[ + gpu_id + ].cuda(gpu_id), + image_qformer_output=torch.split( + blip2_output.image_qformer_output, local_batch_size + )[gpu_id].cuda(gpu_id), + text_features=torch.split(blip2_output.text_features, local_batch_size)[ + gpu_id + ].cuda(gpu_id), + prediction_scores=torch.split( + blip2_output.prediction_scores, local_batch_size + )[gpu_id].cuda(gpu_id), + ) + + blip2 = blip2.cuda(gpu_id) + loss_fn = Blip2Phase1Loss(dim_q=dim_q) + init_weights_with_constant(loss_fn) + loss_fn = loss_fn.cuda(gpu_id) + + all_params = chain(blip2.parameters(), loss_fn.parameters()) + + optimizer = optim.SGD(all_params, lr=1e-4) + + # Forward pass + loss = loss_fn( + model_output=local_blip2_output, + blip2=blip2, + images=local_images, + input_ids=local_input_ids, + attention_mask=local_attn_mask, + ).total_loss + + # Compute gradients + optimizer.zero_grad() + loss.backward() + + # Gather gradients from all devices + def gather_grads(x: torch.Tensor) -> torch.Tensor: + grads = [torch.zeros_like(x).cuda(gpu_id) for i in range(world_size)] + dist.all_gather(grads, x) + grad = torch.stack(grads).mean() + return grad + + # Gather losses from all devices + gathered_loss = gather_grads(torch.Tensor([loss]).cuda(gpu_id)) + assert_expected(gathered_loss.item(), 5.07517, rtol=0, atol=1e-4) + + @gpu_test(gpu_count=1) + def test_single_gpu_loss( + self, + global_batch_size, + input_ids, + blip2_output, + blip2, + attn_mask, + dim_q, + ): + with with_temp_files(count=1) as sync_file: + world_size = 1 + mp.spawn( + TestBLIP2Stage1Loss._model_worker, + ( + sync_file, + world_size, + global_batch_size, + input_ids, + attn_mask, + blip2_output, + blip2, + dim_q, + ), + nprocs=world_size, + ) + + @gpu_test(gpu_count=2) + def test_multi_gpu_loss( + self, + global_batch_size, + input_ids, + blip2_output, + blip2, + attn_mask, + dim_q, + ): + with with_temp_files(count=1) as sync_file: + world_size = 2 + mp.spawn( + TestBLIP2Stage1Loss._model_worker, + ( + sync_file, + world_size, + global_batch_size, + input_ids, + attn_mask, + blip2_output, + blip2, + dim_q, + ), + nprocs=world_size, + ) diff --git a/torchmultimodal/models/blip2/__init__.py b/torchmultimodal/models/blip2/__init__.py new file mode 100644 index 000000000..2e41cd717 --- /dev/null +++ b/torchmultimodal/models/blip2/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/models/blip2/blip2.py b/torchmultimodal/models/blip2/blip2.py new file mode 100644 index 000000000..a3dfde050 --- /dev/null +++ b/torchmultimodal/models/blip2/blip2.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import NamedTuple, Optional + +import torch + +from torch import nn, Tensor +from torch.nn import functional as F +from torchmultimodal.modules.layers.transformer import TransformerOutput + + +class Blip2Output(NamedTuple): + """ + BLIP2 model output for loss computation. + + image_embeddings(Tensor): normalized image embeddings returned by the visual encoder + with shape [bsz x seq_len x embed_dim]. + image_features(Tensor): Image features after qformer and projection (for stage 1 training) + with shape [bsz, num_query_tokens, embed_dim] + image_qformer_output(Tensor) : last hidden state for qformer output by given image input + text_features(Optional[Tensor]): Text features after qformer and projection if text input is provided + with shape [bsz, embed_dim] + prediction_scores (Optional[Tensor]): computed for next word prediction + with shape of [bsz, seq_len, vocab_size] + """ + + image_embeddings: Tensor + image_features: Tensor + image_qformer_output: Tensor + text_features: Optional[Tensor] = None + prediction_scores: Optional[Tensor] = None + + +class BLIP2(nn.Module): + """ + BLIP2(https://arxiv.org/pdf/2301.12597.pdf) provides a pre-training strategy to bootstrap vision-language + pre-training from frozen image encoders and frozen large language models(LLM). BLIP-2 bridges the modality gap + and facilitates cross-modal alignment via Querying Transformer (Q-former). Q-former is a lightweight transformer + which has a set of learnable query vectors to extract visual features from the frozen image encoder. + + Args: + qformer(nn.Module): Querying Transformer (Q-former) + visual_encoder(nn.Module): Frozen image encoder + dim_q(int) : Dimension of query tensor, this value should be the same as dim_q in qformer. + image_encoder_embedding_dim(int): Embedding dimension for image encoder, + this value should be the same as dim_kv in qformer. + freeze_visual_encoder(bool): Whether to freeze the visual encoder, default to True + cross_attention_freq(int): Frequency of adding cross-attention block in Qformer, default to 2 + embedding_dim(int): Embedding dimension + num_query_token(int): Number of query tokens in Qformer, default to 32 + init_query_tokens(bool): whether init query token params, default to True + decoder_bos_token_id(Optional[int]): bos_token_id used in decoder, default to None + """ + + def __init__( + self, + qformer: nn.Module, + vision_encoder: nn.Module, + dim_q: int, + image_encoder_embedding_dim: int, + freeze_vision_encoder: bool = True, + cross_attention_freq: int = 2, + embedding_dim: int = 256, + num_query_token: int = 32, + init_query_tokens: bool = True, + decoder_bos_token_id: Optional[int] = None, + ): + super().__init__() + self.vision_encoder = vision_encoder + if freeze_vision_encoder: + for param in self.vision_encoder.parameters(): + param.requires_grad = False + self.vision_encoder = self.vision_encoder.eval() + + self.qformer = qformer + self.decoder_bos_token_id = decoder_bos_token_id + self.dim_q = dim_q + self.query_tokens = nn.Parameter(torch.zeros(1, num_query_token, self.dim_q)) + if init_query_tokens: + self.query_tokens.data.normal_(mean=0.0, std=0.02) + + self.vision_proj = nn.Linear(self.dim_q, embedding_dim) + self.text_proj = nn.Linear(self.dim_q, embedding_dim) + self.ln_vision = nn.LayerNorm(image_encoder_embedding_dim) + + def forward( + self, + image: Tensor, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + ) -> Blip2Output: + """ + Args: + image(Tensor): Image input tensor with shape [B, C, H, W] + input_ids(Optional[Tensor]): Text input tensor with shape [bsz, seq_len] + attention_mask(Optional[Tensor]): Attention mask tensor with shape [bsz, seq_len] + + Returns: + return BLIP2 model output(Blip2Output). + """ + vision_encoder_output = self.vision_encoder(image) + if isinstance(vision_encoder_output, TransformerOutput): + vision_encoder_output = vision_encoder_output.last_hidden_state + assert vision_encoder_output is not None + image_embeds = self.ln_vision(vision_encoder_output) + # query tokens: [batch_size, num_query_token, encoder_hidden_size] + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.qformer.model( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + use_cache=True, + ) + + # image_feats: [batch_size, num_query_token, embedding_dim] + image_feats = F.normalize(self.vision_proj(query_output[0]), dim=-1) + + text_feats: Optional[Tensor] = None + prediction_scores: Optional[Tensor] = None + if input_ids is not None: + text_output = self.qformer.model( + input_ids, + attention_mask=attention_mask, + use_cache=False, + ) + text_feats = F.normalize(self.text_proj(text_output[0][:, 0, :]), dim=-1) + + decoder_input_ids = input_ids.clone() + if self.decoder_bos_token_id is not None: + # pyre-ignore + decoder_input_ids[:, 0] = self.decoder_bos_token_id + + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to( + input_ids.device + ) + if attention_mask is not None: + attention_mask = torch.cat([query_atts, attention_mask], dim=1) + + # set use_cache = False since past_key_values should be cached in previous steps. + prediction_scores = self.qformer( + input_ids=decoder_input_ids, + attention_mask=attention_mask, + past_key_values=query_output[1], + use_cache=False, + ) + + return Blip2Output( + image_embeddings=image_embeds, + image_features=image_feats, + image_qformer_output=query_output[0], + text_features=text_feats, + prediction_scores=prediction_scores, + ) diff --git a/torchmultimodal/models/blip2/qformer_layers.py b/torchmultimodal/models/blip2/qformer_layers.py new file mode 100644 index 000000000..94e1d30c8 --- /dev/null +++ b/torchmultimodal/models/blip2/qformer_layers.py @@ -0,0 +1,387 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, List, Optional, Tuple + +import torch + +from torch import nn, Tensor + +from torchmultimodal.modules.layers.mlp import MLP +from torchmultimodal.modules.layers.multi_head_attention import ( + MHAWithCacheOutput, + MultiHeadAttentionWithCache, +) +from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm + + +class QformerLayer(nn.Module): + """ + Qformer layer module. + + This module is designed with a self-attention (SA) block and optionally includes a cross-attention (CA) block for queries. + The inputs for this module, referred to as hidden_states, can consist of either a query, text, or a combination of both. + Cross-attention is exclusively activated for queries (query_length > 0) with encoder_hidden_states derived from image inputs. + + The feedforward(ff) block will project the hidden states output by the layer before, + query output and text output are concatenated as overall output after separated handling for CA and ff. + + Args: + dim_q (int): dimensionality of the query tensor + dim_feedforward (int): dimensionality of the feedforward layer + num_heads (int): number of attention heads + attn_dropout (float): dropout probability for attention weights + dropout (float): dropout probability for the densen layer after attention and feedforward layer + layer_norm_eps (float): the epsilon used by the layer normalization layers + activation (Callable[..., nn.Module]): the activation function applied to the feedforward layer + has_cross_attention (bool): whether a cross-attention layer is included + dim_kv (Optional[int]): dimensionality of the key and value tensors, this value is only used in CA. + + """ + + def __init__( + self, + dim_q: int, + dim_feedforward: int, + num_heads: int, + attn_dropout: float = 0.0, + dropout: float = 0.0, + layer_norm_eps: float = 1e-12, + activation: Callable[..., nn.Module] = nn.ReLU, + has_cross_attention: bool = False, + dim_kv: Optional[int] = None, + ): + super().__init__() + self.self_attention = MultiHeadAttentionWithCache( + dim_q, dim_q, num_heads, attn_dropout + ) + self.self_attn_layernorm = Fp32LayerNorm(dim_q, eps=layer_norm_eps) + self.dropout = nn.Dropout(dropout) + self.has_cross_attention = has_cross_attention + self.cross_attention: Optional[MultiHeadAttentionWithCache] = None + + if has_cross_attention: + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder, key and value caching should be disabled. + if dim_kv is None: + raise ValueError( + "key and value dim should be provided for cross attention." + ) + self.cross_attention = MultiHeadAttentionWithCache( + dim_q=dim_q, dim_kv=dim_kv, num_heads=num_heads, dropout=attn_dropout + ) + self.cross_attn_layernorm = Fp32LayerNorm(dim_q, eps=layer_norm_eps) + self.cross_attn_dropout = nn.Dropout(dropout) + + # feedforward block + self.feedforward = MLP( + dim_q, dim_q, dim_feedforward, dropout=0.0, activation=activation + ) + self.feedforward_layernorm = Fp32LayerNorm(dim_q, eps=layer_norm_eps) + self.feedforward_dropout = nn.Dropout(dropout) + + # query feedforward block + self.feedforward_query = MLP( + dim_q, dim_q, dim_feedforward, dropout=0.0, activation=activation + ) + self.feedforward_layernorm_query = Fp32LayerNorm(dim_q, eps=layer_norm_eps) + self.feedforward_dropout_query = nn.Dropout(dropout) + + def _self_attention_block( + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, + past_key_value: Optional[Tuple[Tensor, Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]: + x = hidden_states + attn_output = self.self_attention( + x, + x, + x, + attn_mask=attention_mask, + past_key_value=past_key_value, + use_cache=use_cache, + ) + present_key_value: Optional[Tuple[Tensor, Tensor]] = None + if use_cache: + assert isinstance(attn_output, MHAWithCacheOutput) + attn_output_value = attn_output.attn_output + present_key_value = attn_output.past_key_value + else: + assert isinstance(attn_output, Tensor) + attn_output_value = attn_output + attn_output = self.dropout(attn_output_value) + + attn_residual = attn_output + x + attn_residual = self.self_attn_layernorm(attn_residual) + return attn_residual, present_key_value + + def _cross_attention_block( + self, + hidden_states: Tensor, + encoder_hidden_states: Tensor, + ) -> Tensor: + x = hidden_states + assert self.cross_attention is not None + # turn off cache for cross attention + cross_attn_output = self.cross_attention( + query=x, + key=encoder_hidden_states, + value=encoder_hidden_states, + use_cache=False, + ) + + if not torch.jit.isinstance(cross_attn_output, Tensor): + raise ValueError("cross-attention output must be Tensor.") + cross_attn_output = self.cross_attn_dropout(cross_attn_output) + cross_attn_residual = cross_attn_output + x + cross_attn_residual = self.cross_attn_layernorm(cross_attn_residual) + return cross_attn_residual + + def _feedforward_block(self, hidden_states: Tensor) -> Tensor: + h = self.feedforward(hidden_states) + h = self.feedforward_dropout(h) + h = self.feedforward_layernorm(h + hidden_states) + return h + + def _feedforward_query_block(self, hidden_states: Tensor) -> Tensor: + h = self.feedforward_query(hidden_states) + h = self.feedforward_dropout_query(h) + h = self.feedforward_layernorm_query(h + hidden_states) + return h + + def forward( + self, + hidden_states: Tensor, + encoder_hidden_states: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + past_key_value: Optional[Tuple[Tensor, Tensor]] = None, + query_length: int = 0, + use_cache: bool = False, + ) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]: + """ + Inputs: + hidden_states (Tensor): input query of shape bsz x seq_len x embed_dim + encoder_hidden_states (Optional[Tensor]): input key/values of shape bsz x seq_len x embed_dim, only used in CA case + attention_mask (Optional[Tensor]): attention mask, supported mask type is described in MultiHeadAttentionWithCache class + past_key_value (Optional[Tuple[Tensor, Tensor]]): cached key/value tuple for self-attention + query_length (Optional[int]): length of query embedding, used as condition + to determine query attention output and check text existance. + use_cache (bool): whether to use cache for key and value tensors + + Return: + A tuple includes: + layer_output (Tensor): layer output of shape bsz x seq_len x embed_dim + present_key_value (Optional[Tuple[Tensor, Tensor]]): key/value tuple for self-attention + """ + if past_key_value is not None and len(past_key_value) != 2: + raise ValueError( + "past_key_value should be 2-element tuple to represent self-attention cached key/values." + ) + attn_residual, present_key_value = self._self_attention_block( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + use_cache=use_cache, + ) + + if query_length > 0: + query_attn_output = attn_residual[:, :query_length, :] + if self.has_cross_attention: + if encoder_hidden_states is None: + raise ValueError( + "encoder_hidden_states must be given for cross-attention layers" + ) + cross_attn_output = self._cross_attention_block( + hidden_states=query_attn_output, + encoder_hidden_states=encoder_hidden_states, + ) + query_attn_output = cross_attn_output + + # add query feedforward block for self-attention or cross-attention + layer_output = self._feedforward_query_block(query_attn_output) + + # handle text input if present + if attn_residual.shape[1] > query_length: + layer_output_text = self._feedforward_block( + attn_residual[:, query_length:, :] + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + + else: + layer_output = self._feedforward_block(attn_residual) + + return (layer_output, present_key_value) + + +class QformerEncoder(nn.Module): + """ + Qformer encoder module including multiple Qformer layers. + + Args: + num_hidden_layers (int): number of Qformer layers inside encoder + dim_q (int): dimensionality of the query tensor + dim_feedforward (int): dimensionality of the feedforward layer + num_heads (int): number of attention heads + attn_dropout (float): dropout probability for attention weights + dropout (float): dropout probability for the densen layer after attention and feedforward layer in each Qformer layer + layer_norm_eps (float): the epsilon used by the layer normalization layers + activation (Callable[..., nn.Module]): the activation function applied to the feedforward layer + cross_attention_freq (int): frequency of adding cross attention in QFormer layers, default to 2. + dim_kv (Optional[int]): dimensionality of the key and value tensors, this value is only used in CA. + + """ + + def __init__( + self, + num_hidden_layers: int, + dim_q: int, + dim_feedforward: int, + num_heads: int, + attn_dropout: float = 0.0, + dropout: float = 0.0, + layer_norm_eps: float = 1e-12, + activation: Callable[..., nn.Module] = nn.ReLU, + cross_attention_freq: int = 2, + dim_kv: Optional[int] = None, + ): + super().__init__() + layers = [] + for i in range(num_hidden_layers): + has_cross_attention = i % cross_attention_freq == 0 + layers.append( + QformerLayer( + dim_q=dim_q, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=attn_dropout, + dropout=dropout, + layer_norm_eps=layer_norm_eps, + activation=activation, + has_cross_attention=has_cross_attention, + dim_kv=dim_kv, + ) + ) + self.layers = nn.ModuleList(layers) + + def forward( + self, + hidden_states: Tensor, + encoder_hidden_states: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + past_key_values: Optional[List[Tuple[Tensor, Tensor]]] = None, + query_length: int = 0, + use_cache: bool = False, + ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: + """ + Inputs: + hidden_states (Tensor): input query of shape bsz x seq_len x embed_dim + encoder_hidden_states (Optional[Tensor]): input key/values of shape bsz x seq_len x embed_dim, only used in CA case + attention_mask (Optional[Tensor]): attention mask, supported mask type is described in MultiHeadAttentionWithCache class + past_key_values (Optional[List[Tuple[Tensor, Tensor]]]): cached key/value tuple for self-attention + query_length (int): the length of input query, used for cross-attention + use_cache (bool): whether to use cache for key and value tensors + + Return: + A tuple includes: + the last hidden state: Tensor of shape bsz x seq_len x embed_dim + past_key_values (List[Optional[Tuple[Tensor, Tensor]]]]): cached key/values from Qformer layers + """ + current_key_values = torch.jit.annotate(List[Tuple[Tensor, Tensor]], []) + for i, layer_module in enumerate(self.layers): + past_key_value = past_key_values[i] if past_key_values is not None else None + hidden_states, current_key_value = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + query_length=query_length, + use_cache=use_cache, + ) + if use_cache: + assert isinstance(current_key_value, tuple) + current_key_values.append(current_key_value) + + return (hidden_states, current_key_values) + + +class QformerEmbedding(nn.Module): + """ + Qformer embedding module. + + Args: + embedding_dim (int): dim of embedding space + max_position_embeddings (int): max sequence length allowed for positional embeddings + vocab_size (int): size of vocabulary + pad_token_id (int): id used for padding token, default is 0. + dropout (float): dropout probability after embedding layers and layernorm. + layer_norm_eps (float): the epsilon used by the layer normalization layers. + """ + + def __init__( + self, + embedding_dim: int, + max_position_embeddings: int, + vocab_size: int, + pad_token_id: int = 0, + layer_norm_eps: float = 1e-12, + dropout: float = 0.0, + ): + super().__init__() + self.token_embeddings = nn.Embedding( + vocab_size, embedding_dim, padding_idx=pad_token_id + ) + self.position_embeddings = nn.Embedding(max_position_embeddings, embedding_dim) + self.layernorm = Fp32LayerNorm(embedding_dim, eps=layer_norm_eps) + self.dropout = nn.Dropout(dropout) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(max_position_embeddings).expand((1, -1)) + ) + + def forward( + self, + input_ids: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + query_embeddings: Optional[Tensor] = None, + past_seq_length: int = 0, + ) -> Tensor: + """ + Inputs: + input_ids (Optional[Tensor]): input token ids + position_ids (Optional[Tensor]): batches of of 1D integer tensors used to identify each token's position, + if no position_ids is provided, the IDs are automatically created as absolute positional embeddings. + query_embeddings (Optional[Tensor]): query embeddings for QFormer + past_seq_length (Optional[int]): sequence length cached by past_key_values. + + Returns: + embeddings (Tensor): concatenated embeddings of shape (bsz, num tokens, embedding dim), concatenation is along + the token dimension. + """ + if input_ids is None and query_embeddings is None: + raise ValueError("Either input_ids or query_embeddings must be passed.") + + seq_length = input_ids.size(1) if input_ids is not None else 0 + + embeddings = query_embeddings + + if input_ids is not None: + if position_ids is None: + position_ids = self.position_ids[ + :, past_seq_length : seq_length + past_seq_length + ].clone() + word_embeddings = self.token_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids.long()) + embeddings = word_embeddings + position_embeddings + + if query_embeddings is not None: + embeddings = torch.cat((query_embeddings, embeddings), dim=1) + + assert isinstance(embeddings, Tensor) + embeddings = self.layernorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings diff --git a/torchmultimodal/models/blip2/qformer_model.py b/torchmultimodal/models/blip2/qformer_model.py new file mode 100644 index 000000000..1ad32bf28 --- /dev/null +++ b/torchmultimodal/models/blip2/qformer_model.py @@ -0,0 +1,294 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, List, Optional, Tuple + +from torch import nn, Tensor +from torchmultimodal.models.blip2.qformer_layers import QformerEmbedding, QformerEncoder + +from torchmultimodal.models.blip2.qformer_utils import get_causal_mask + + +class QformerModel(nn.Module): + """ + Qformer model including Qformer embedding and Qformer encoder. + + Args: + num_hidden_layers (int): number of Qformer layers inside encoder + dim_q (int): dimensionality of the query tensor + dim_feedforward (int): dimensionality of the feedforward layer + num_heads (int): number of attention heads + max_position_embeddings (int): max sequence length allowed for positional embeddings + vocab_size (int): size of vocabulary + pad_token_id (int): id used for padding token, default is 0. + query_length(int): query length in Qformer, used to compute cached query length. + default value is the same as num_query_token for Blip2 case (https://fburl.com/316803mo). + dim_kv (Optional[int]): dimensionality of the key and value tensors, this value is only used in CA, default is None. + layer_norm_eps (float): the epsilon used by the layer normalization layers + activation (Callable[..., nn.Module]): the activation function applied to the feedforward layer + attn_dropout (float): dropout probability for attention weights + dropout (float): dropout probability for the densen layer after attention and feedforward layer in each Qformer layer + cross_attention_freq (int): frequency of adding cross attention in QFormer layers, default to 2. + """ + + def __init__( + self, + num_hidden_layers: int, + dim_q: int, + dim_feedforward: int, + num_heads: int, + max_position_embeddings: int, + vocab_size: int, + pad_token_id: int = 0, + query_length: int = 32, + dim_kv: Optional[int] = None, + layer_norm_eps: float = 1e-12, + activation: Callable[..., nn.Module] = nn.ReLU, + attn_dropout: float = 0.0, + dropout: float = 0.0, + cross_attention_freq: int = 2, + ) -> None: + super().__init__() + self.query_length = query_length + self.embeddings = QformerEmbedding( + embedding_dim=dim_q, + max_position_embeddings=max_position_embeddings, + vocab_size=vocab_size, + pad_token_id=pad_token_id, + layer_norm_eps=layer_norm_eps, + dropout=dropout, + ) + self.encoder = QformerEncoder( + num_hidden_layers=num_hidden_layers, + dim_q=dim_q, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + attn_dropout=attn_dropout, + dropout=dropout, + layer_norm_eps=layer_norm_eps, + activation=activation, + cross_attention_freq=cross_attention_freq, + dim_kv=dim_kv, + ) + + def forward( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + query_embeds: Optional[Tensor] = None, + encoder_hidden_states: Optional[Tensor] = None, + past_key_values: Optional[List[Tuple[Tensor, Tensor]]] = None, + use_cache: bool = False, + use_causal_mask: bool = False, + ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: + """ + Inputs: + input_ids (Optional[Tensor]): input token ids for QFormer + attention_mask (Optional[Tensor]): attention mask for QFormer + position_ids (Optional[Tensor]): position ids for QFormer + query_embeds (Optional[Tensor]): query embeddings for QFormer + encoder_hidden_states (Optional[Tensor]): input key/values of shape bsz x seq_len x embed_dim, only used in CA case + past_key_values: (Optional[List[Tuple[Tensor, Tensor]]]): a list of num_layers elements, + each element is a 2-element tuple for cached key/value. + key/value is tensor with shape of (bsz x source_seq_len x embed_dim). + use_cache (bool): whether to use cache for key and value tensors + use_causal_mask (bool): apply causal mask if true, default to False + + Returns: + Qformer encoder output with a tuple of last hidden states and past_key_values if use_cache. + """ + past_seq_length = ( + # overall_seq_length - query_length + past_key_values[0][0].shape[2] - self.query_length + if past_key_values is not None + else 0 + ) + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeddings=query_embeds, + past_seq_length=past_seq_length, + ) + bsz, seq_len = embedding_output.size()[:-1] + + if attention_mask is not None: + if use_causal_mask: + # Apply a causal mask in addition to the padding mask and make attention mask broadcastable. + causal_mask = get_causal_mask( + attention_mask, + (bsz, seq_len), + has_query=(query_embeds is not None), + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + attention_mask = extended_attention_mask.to(dtype=attention_mask.dtype) + else: + attention_mask = attention_mask[:, None, None, :] + # create a tensor which is 0.0 for positions to attend and -10000.0 for masked position. + # use float mask to ensure mask values will be added to the attention weight + attention_mask = (1.0 - attention_mask) * -10000.0 + + return self.encoder( + hidden_states=embedding_output, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + query_length=query_length, + ) + + +class QformerPredictionHead(nn.Module): + """ + MLP head for computinng prediction score from QformerModel output + + Args: + dim_q (int): dimensionality of the query tensor + vocab_size (int): the size of vocabulary used by QFormer + layer_norm_eps (float): the epsilon used by the layer normalization layers, default is 1e-12 + activation (Callable[..., nn.Module]): the activation function applied to the feedforward layer + """ + + def __init__( + self, + dim_q: int, + vocab_size: int, + layer_norm_eps: float = 1e-12, + activation: Callable[..., nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.linear_1 = nn.Linear(dim_q, dim_q) + self.activation = activation() + self.layernorm = nn.LayerNorm(dim_q, eps=layer_norm_eps) + self.linear_2 = nn.Linear(dim_q, vocab_size) + + def forward(self, sequence_output: Tensor) -> Tensor: + """ + Inputs (Tensor): + sequence_output of shape bsz x seq_len x embed_dim + Returns: + prediction scores (Tensor) of shape: bsz x seq_len x vocab_size + """ + hidden_states = self.linear_1(sequence_output) + hidden_states = self.activation(hidden_states) + hidden_states = self.layernorm(hidden_states) + predictions = self.linear_2(hidden_states) + return predictions + + +class QformerForCLM(nn.Module): + """ + A QformerModel wrapper class for causal language modeling(clm). + + Args: + num_hidden_layers (int): number of Qformer layers inside encoder + dim_q (int): dimensionality of the query tensor + dim_feedforward (int): dimensionality of the feedforward layer + num_heads (int): number of attention heads + max_position_embeddings (int): max sequence length allowed for positional embeddings + vocab_size (int): size of vocabulary + pad_token_id (int): id used for padding token, default is 0. + query_length(int): query length in Qformer, details see QformerModel class. + dim_kv (Optional[int]): dim_kv (Optional[int]): dimensions of the key and value tensors, this value is only used in CA. + Default is None. + layer_norm_eps (float): the epsilon used by the layer normalization layers + activation (Callable[..., nn.Module]): the activation function applied to the feedforward layer + attn_dropout (float): dropout probability for attention weights + dropout (float): dropout probability for the densen layer after attention and feedforward layer in each Qformer layer + cross_attention_freq (int): frequency of adding cross attention in QFormer layers, default to 2 + """ + + def __init__( + self, + num_hidden_layers: int, + dim_q: int, + dim_feedforward: int, + num_heads: int, + max_position_embeddings: int, + vocab_size: int, + pad_token_id: int = 0, + query_length: int = 32, + dim_kv: Optional[int] = None, + layer_norm_eps: float = 1e-12, + activation: Callable[..., nn.Module] = nn.GELU, + attn_dropout: float = 0.0, + dropout: float = 0.0, + cross_attention_freq: int = 2, + ) -> None: + super().__init__() + self.pad_token_id = pad_token_id + self.vocab_size = vocab_size + self.head = QformerPredictionHead( + dim_q=dim_q, + activation=activation, + layer_norm_eps=layer_norm_eps, + vocab_size=vocab_size, + ) + self.model = QformerModel( + num_hidden_layers=num_hidden_layers, + dim_q=dim_q, + dim_feedforward=dim_feedforward, + num_heads=num_heads, + max_position_embeddings=max_position_embeddings, + vocab_size=vocab_size, + pad_token_id=pad_token_id, + query_length=query_length, + dim_kv=dim_kv, + layer_norm_eps=layer_norm_eps, + activation=activation, + attn_dropout=attn_dropout, + dropout=dropout, + cross_attention_freq=cross_attention_freq, + ) + + def forward( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + query_embeds: Optional[Tensor] = None, + encoder_hidden_states: Optional[Tensor] = None, + past_key_values: Optional[List[Tuple[Tensor, Tensor]]] = None, + use_cache: bool = False, + ) -> Tensor: + """ + Inputs: + input_ids (Optional[Tensor]): input token ids for QFormer + attention_mask (Optional[Tensor]): attention mask for QFormer + position_ids (Optional[Tensor]): position ids for QFormer + query_embeds (Optional[Tensor]): query embeddings for QFormer + encoder_hidden_states (Optional[Tensor]): input key/values of shape bsz x seq_len x embed_dim, only used in CA case + past_key_values: (Optional[List[Tuple[Tensor, Tensor]]]): cached key/value tuple for self-attention + use_cache (bool): whether to use cache for key and value tensors, + default to False for generation as cached values should be computed in previous training tasks. + + Returns: + prediction score (Tensor) computed for next word prediction of shape + bsz x seq_len x vocab_size + """ + # TODO: revisit if it's required for edge cases after BLIP-2 impl. + if past_key_values is not None: + assert query_embeds is None + + sequence_output, _ = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + use_cache=use_cache, + use_causal_mask=True, # set causal mask for clm + ) + if query_embeds is not None: + sequence_output = sequence_output[:, query_embeds.shape[1] :, :] + + prediction_scores = self.head(sequence_output) + return prediction_scores diff --git a/torchmultimodal/models/blip2/qformer_utils.py b/torchmultimodal/models/blip2/qformer_utils.py new file mode 100644 index 000000000..3b6022f34 --- /dev/null +++ b/torchmultimodal/models/blip2/qformer_utils.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + +from torch import Tensor +from torchmultimodal.utils.attention import get_causal_attention_mask + + +def get_causal_mask( + attention_mask: Tensor, + input_shape: Tuple[int, int], + has_query: bool = False, +) -> Tensor: + """A causal mask in addition to the padding mask for Q-Former for generation task. + when input seq_len is shorter than attn_mask, increasing causal_mask by prefix_seq_len with 1s; + if query is available, apply causal self-attention mask to control query-text interaction; + + Arguments: + attention_mask (Tensor) is a binary mask with 1 for unmasked and 0 for masked positions. + Attention_mask has size of [batch_size, attn_seq_len]. attn_seq_len can be only seq_len for text_token + or query_len + seq_len. + input_shape (tuple[int, int]): indicates input shape of (batch_size, input_seq_len) from embedding output. + If query_emb is used, input_seq_len is query_len + seq_len. + Input shape can be different from attention_mask shape for image caption and text generation tasks. + has_query (bool) indicating whether query is available in qformer input. + + Returns: + causal_mask (Tensor): mask size of [bsz, attn_seq_len, attn_seq_len] with query, + [bsz, input_seq_len, attn_seq_len] without query + + """ + device = attention_mask.device + batch_size, seq_len = input_shape + causal_mask = get_causal_attention_mask(seq_len).to(device) + causal_mask = causal_mask.repeat(batch_size, 1).view(batch_size, seq_len, seq_len) + # compare seq_len in input and attention mask + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: + # if query is available, apply causal self-attention mask to control query-text interaction. + # Allow queries attending each other but not the text tokens. + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + dim=1, + ) # mask size [bsz, attn_seq_len, input_seq_len] + # increase causal_mask by prefix_seq_len with 1s to attend self-attention + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + dim=-1, + ) # size of [bsz, attn_seq_len, attn_seq_len] with query, [bsz, input_seq_len, attn_seq_len] without query + return causal_mask diff --git a/torchmultimodal/modules/losses/blip2_losses.py b/torchmultimodal/modules/losses/blip2_losses.py new file mode 100644 index 000000000..3bf0ecf10 --- /dev/null +++ b/torchmultimodal/modules/losses/blip2_losses.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import Optional, OrderedDict, Tuple + +import torch +from torch import nn +from torch.nn import functional as F +from torchmultimodal.models.blip2.blip2 import Blip2Output +from torchmultimodal.utils.distributed import ( + BackpropType, + concat_gather_all_gpu, + get_rank, +) + + +@dataclass +class Blip2Stage1Losses(OrderedDict): + "Blip-2 stage 1 losses" + image_text_contrastive_loss: torch.Tensor + image_text_matching_loss: torch.Tensor + image_captioning_loss: torch.Tensor + total_loss: torch.Tensor + + +def compute_image_text_similarity( + image_features: torch.Tensor, text_features: torch.Tensor, temp: nn.Parameter +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute image-text similarity across all the devices for itc and itm usage. + + Inputs: + image_features (torch.Tensor): Blip2 image output of shape [bsz, num_query_tokens, embed_dim] + text_features (torch.Tensor): Blip2 text output of shape [bsz, embed_dim] + temp (nn.Parameter): Temperature parameter + + Returns: + a tuple of tensor contains image-to-text similarity and text-to-image similarity. + """ + image_features_all = concat_gather_all_gpu( + image_features, backprop_type=BackpropType.NONE + ) # [bsz x num_gpu, num_query_tokens, embed_dim] + text_features_all = concat_gather_all_gpu( + text_features, backprop_type=BackpropType.NONE + ) # [bsz x num_gpu, embed_dim] + sim_q2t = torch.matmul( + image_features.unsqueeze(1), text_features_all.unsqueeze(-1) + ).squeeze() + # [bsz, bsz x num_gpu, num_query_tokens] + + # image-text similarity: aggregate across all query tokens + sim_i2t, _ = sim_q2t.max(-1) + sim_i2t = sim_i2t / temp + + # text-query similarity: [bsz, bsz x num_gpu, num_query_tokens] + sim_t2q = torch.matmul( + text_features.unsqueeze(1).unsqueeze(1), image_features_all.permute(0, 2, 1) + ).squeeze() + + # text-image similarity: aggregate across all query tokens + sim_t2i, _ = sim_t2q.max(-1) + sim_t2i = sim_t2i / temp # [bsz, bsz x num_gpu] + + return sim_i2t, sim_t2i + + +def itc_loss( + sim_i2t: torch.Tensor, + sim_t2i: torch.Tensor, + label_smoothing: float = 0.1, +) -> torch.Tensor: + """Compute image-text contrastive loss by given similarity between image and text. + + Inputs: + sim_i2t(torch.Tensor): image-to-text similarity, shape [bsz, bsz x num_gpu] + sim_t2i (torch.Tensor): text-to-image similarity, shape [bsz, bsz x num_gpu] + label_smoothing (Optional[float]): Label smoothing for cross-entropy. Default: 0.1. + + Returns: + itc_loss (torch.Tensor) + """ + rank = get_rank() + + local_batch_size = sim_i2t.size(0) + targets = local_batch_size * rank + torch.arange( + local_batch_size, device=sim_i2t.device + ) + + loss = ( + F.cross_entropy(sim_i2t, targets, label_smoothing=label_smoothing) + + F.cross_entropy(sim_t2i, targets, label_smoothing=label_smoothing) + ) / 2 + return loss + + +def itg_loss( + input_ids: torch.Tensor, + prediction_scores: torch.Tensor, + decoder_bos_token_id: int, + pad_token_id: int, + vocab_size: int, + reduction: str = "mean", + label_smoothing: float = 0.1, +) -> torch.Tensor: + """Compute image caption loss from BLIP2 predictions. + + Inputs: + input_ids (torch.Tensor): text input ids of shape (bsz, seq_len). + prediction_scores (torch.Tensor): BLIP2 prediction scores, shape of (bsz, seq_len, vocab_size) + decoder_bos_token_id (int): bos_token_id for decoder, which is used to replace CLS token. + pad_token_id (int): pad_token_id for decoder + vocab_size (int): vocab size of BLIP2 model + reduction (str): reduction for loss computation, default is "mean". + label_smoothing (float): label smoothing value for cross-entropy loss, default is 0.1. + + Returns: + itg_loss (torch.Tensor): image caption loss. + """ + decoder_input_ids = input_ids.clone() + # Replace CLS token to signal the decoding task as mentioned in paper https://arxiv.org/pdf/2301.12597.pdf + decoder_input_ids[:, 0] = decoder_bos_token_id + labels = decoder_input_ids.masked_fill(decoder_input_ids == pad_token_id, -100) + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + itg_loss = F.cross_entropy( + shifted_prediction_scores.view(-1, vocab_size), + labels.view(-1), + reduction=reduction, + label_smoothing=label_smoothing, + ) + + return itg_loss + + +# TODO: upstream itm_loss for other model usage +def itm_loss( + input_ids: torch.Tensor, + image_embeds: torch.Tensor, + sim_i2t: torch.Tensor, + sim_t2i: torch.Tensor, + model_query_tokens: nn.Parameter, + qformer: nn.Module, + itm_head: nn.Module, + attention_mask: torch.Tensor, +) -> torch.Tensor: + """Compute image-text matching loss + ITM loss computation uses hard negative mining strategy. Negative text and image examples + are selected based on their corresponding similarities. + + The concatenated image-text pairs are constructed as a size of 3 x bsz batch (pos, neg, neg) + with text concatenated inputs (pos, pos, neg) and image inputs (pos, neg, pos). + + Query embedding output are fed into a two-class linear classifier to obtain a logit, + and average the logits across all queries as the output matching score. + + Inputs: + input_ids (torch.Tensor): text input ids of shape [bsz, seq_len]. + image_embeds (torch.Tensor): image embeddings returned by vision encoder + with shape [bsz, image_embedding_dim] + sim_i2t (torch.Tensor): image-to-text similarity, shape [bsz, bsz x num_gpu] + sim_t2i (torch.Tensor): text-to-image similarity, shape [bsz, bsz x num_gpu] + model_query_tokens(nn.Parameter): Blip2 query tokens + qformer (nn.Module): Q-Former module + itm_head (nn.Module): ITM head defined in blip2 stage1 loss + attention_mask (torch.Tensor): attention mask for text input, shape [bsz, seq_len]. + + Returns: + itm_loss (torch.Tensor): image-text matching loss + """ + local_batch_size = image_embeds.size(0) + device = image_embeds.device + text_input_ids_all_gpus = concat_gather_all_gpu( + input_ids, + backprop_type=BackpropType.NONE, + ) + + text_attention_mask_all_gpus = concat_gather_all_gpu( + attention_mask, + backprop_type=BackpropType.NONE, + ) + image_embeds_all_gpus = concat_gather_all_gpu( + image_embeds, backprop_type=BackpropType.GLOBAL + ) + rank = get_rank() + # compute weights for negative sample selection + with torch.no_grad(): + weights_t2i_for_neg_sampling = F.softmax(sim_t2i, dim=1) + 1e-4 + weights_t2i_for_neg_sampling[ + :, rank * local_batch_size : rank * local_batch_size + local_batch_size + ].fill_diagonal_(0) + weights_i2t_for_neg_sampling = F.softmax(sim_i2t, dim=1) + 1e-4 + weights_i2t_for_neg_sampling[ + :, rank * local_batch_size : rank * local_batch_size + local_batch_size + ].fill_diagonal_(0) + + # select a negative image for each text + image_embeds_neg = [] + for b in range(local_batch_size): + neg_idx = int(torch.multinomial(weights_t2i_for_neg_sampling[b], 1).item()) + image_embeds_neg.append(image_embeds_all_gpus[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text for each image + text_ids_neg = [] + text_atts_neg = [] + for b in range(local_batch_size): + neg_idx = int(torch.multinomial(weights_i2t_for_neg_sampling[b], 1).item()) + text_ids_neg.append(text_input_ids_all_gpus[neg_idx]) + text_atts_neg.append(text_attention_mask_all_gpus[neg_idx]) + + text_ids_neg = torch.stack(text_ids_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_ids_all = torch.cat( + [input_ids, input_ids, text_ids_neg], dim=0 + ) # pos, pos, neg + text_atts_all = torch.cat( + [attention_mask, attention_mask, text_atts_neg], + dim=0, + ) + + query_tokens_itm = model_query_tokens.expand(text_ids_all.shape[0], -1, -1) + query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to( + device + ) + attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1) + + image_embeds_all = torch.cat( + [image_embeds, image_embeds_neg, image_embeds], dim=0 + ) # pos, neg, pos + output_itm = qformer( + input_ids=text_ids_all, + query_embeds=query_tokens_itm, + attention_mask=attention_mask_all, + encoder_hidden_states=image_embeds_all, + ) + vl_embeddings = output_itm[0][ + :, : query_tokens_itm.size(1), : + ] # [bsz x 3, query_token_len, dim_q] + vl_output = itm_head(vl_embeddings) # [bsz x 3, query_token_len, 2] + itm_logits = vl_output.mean(dim=1) + + itm_labels = torch.cat( + [ + torch.ones(local_batch_size, dtype=torch.long), + torch.zeros(2 * local_batch_size, dtype=torch.long), + ], + dim=0, + ).to(device) + + return F.cross_entropy(itm_logits, itm_labels, reduction="mean") + + +class Blip2Phase1Loss(nn.Module): + """ + Blip2 phase 1 loss module + + Args: + dim_q (int): Dimension of query tensor, this value should be the same as dim_q in qformer. + default value is 768 as in the paper. + enable_itc (bool): enable image-text contrastive loss, default is True + enable_itm (bool): enable image-text matching, default is True + enable_itg (bool): enable image caption loss, default is True + temp (float): temperature for image-text similarity computation, default is 0.07 + label_smoothing (float): label smoothing value, default is 0.1 + """ + + def __init__( + self, + dim_q: int = 768, + enable_itc: bool = True, + enable_itm: bool = True, + enable_itg: bool = True, + temp: float = 0.07, + label_smoothing: float = 0.1, + ) -> None: + super().__init__() + if not enable_itc and not enable_itm and not enable_itg: + raise ValueError( + "All the loss tasks are disabled, please set at least one of them." + ) + self.label_smoothing = label_smoothing + self.enable_itc = enable_itc + self.enable_itm = enable_itm + self.enable_itg = enable_itg + self.itm_head = nn.Linear(dim_q, 2) + self.temp = nn.Parameter(temp * torch.ones([])) + + def forward( + self, + model_output: Blip2Output, + blip2: nn.Module, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor], + ) -> Blip2Stage1Losses: + """ + Inputs: + model_output (Blip2Output): model output from BLIP2 (see blip2.py) + blip2 (nn.Module): BLIP2 model with updated params + input_ids (Optional[torch.Tensor]): text input ids of shape [bsz, seq_len]. + attention_mask (Optional[torch.Tensor]): text input attention mask of shape [bsz, seq_len]. + + Returns: + loss (Blip2Stage1Losses): computed loss for phase 1 tasks. + """ + + # calculate similarities + assert model_output.text_features is not None + (sim_i2t, sim_t2i,) = compute_image_text_similarity( + model_output.image_features, + model_output.text_features, + temp=self.temp, + ) + + # calculate image-text matching loss + loss_itm = torch.tensor(0.0) + if self.enable_itm: + assert input_ids is not None and attention_mask is not None + loss_itm = itm_loss( + input_ids=input_ids, + attention_mask=attention_mask, + image_embeds=model_output.image_embeddings, + sim_i2t=sim_i2t, + sim_t2i=sim_t2i, + model_query_tokens=blip2.query_tokens, + qformer=blip2.qformer.model, + itm_head=self.itm_head, + ) + + # calculate image captioning loss (aka image-text generation) + loss_itg = torch.tensor(0.0) + if self.enable_itg: + assert input_ids is not None and model_output.prediction_scores is not None + loss_itg = itg_loss( + input_ids=input_ids, + prediction_scores=model_output.prediction_scores, + decoder_bos_token_id=blip2.decoder_bos_token_id, + pad_token_id=blip2.qformer.pad_token_id, + vocab_size=blip2.qformer.vocab_size, + label_smoothing=self.label_smoothing, + ) + + # calculate image-text contrastive loss + loss_itc = torch.tensor(0.0) + if self.enable_itc: + loss_itc = itc_loss( + sim_i2t=sim_i2t, + sim_t2i=sim_t2i, + label_smoothing=self.label_smoothing, + ) + + return Blip2Stage1Losses( + image_text_contrastive_loss=loss_itc, + image_captioning_loss=loss_itg, + image_text_matching_loss=loss_itm, + total_loss=loss_itc + loss_itm + loss_itg, + )