From 9e2467cd9fe48d3634acd3320f807b70eb50e3f5 Mon Sep 17 00:00:00 2001 From: Qiang Zhang Date: Mon, 22 Jan 2024 18:14:17 -0800 Subject: [PATCH] create configuration file for MaMMUT training (#521) Summary: Mostly based on the original CoCa and https://github.com/lucidrains/MaMMUT-pytorch Also update the checkpoint-loading logic for the MaMMUT text decoder. Differential Revision: D52891614 Privacy Context Container: 303860477774201 --- torchmultimodal/modules/layers/transformer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchmultimodal/modules/layers/transformer.py b/torchmultimodal/modules/layers/transformer.py index 06237474..3330eabf 100644 --- a/torchmultimodal/modules/layers/transformer.py +++ b/torchmultimodal/modules/layers/transformer.py @@ -413,10 +413,7 @@ def _forward_prenorm( self_attn_output = attn_output + hidden_states # Optional cross-attention - if self.use_cross_attention: - assert ( - encoder_hidden_states is not None - ), "encoder_hidden_states must be provided for cross attention" + if self.use_cross_attention and encoder_hidden_states is not None: assert hasattr( self, "cross_attention_layernorm" ), "Cross-attention layernorm not initialized"