From 226778e76561fa942651f8053bf7d473d368f838 Mon Sep 17 00:00:00 2001 From: Qiang Zhang Date: Wed, 7 Feb 2024 17:21:46 -0800 Subject: [PATCH] implement MaMMUT (#520) Summary: Implement MaMMUT, mostly based on current CoCa code as well as https://github.com/lucidrains/MaMMUT-pytorch. Differential Revision: D52823194 Privacy Context Container: 303860477774201 --- torchmultimodal/modules/layers/transformer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchmultimodal/modules/layers/transformer.py b/torchmultimodal/modules/layers/transformer.py index 47f42d57c..06237474f 100644 --- a/torchmultimodal/modules/layers/transformer.py +++ b/torchmultimodal/modules/layers/transformer.py @@ -554,6 +554,8 @@ class TransformerDecoder(nn.Module): If None, K and V are assumed to have dimension d_model. Defaults to None. final_layer_norm_eps (Optional[float]): epsilon used in final layer norm. Defaults to None (no final layer norm). + cross_attention_interval: interval layers to apply cross attention. Not used if + use_cross_attention = False """ def __init__( @@ -569,6 +571,7 @@ def __init__( use_cross_attention: bool = True, dim_kv: Optional[int] = None, final_layer_norm_eps: Optional[float] = None, + cross_attention_interval: int = 1, ): super().__init__() self.layer = nn.ModuleList( @@ -581,7 +584,7 @@ def __init__( activation, layer_norm_eps, norm_first, - use_cross_attention, + use_cross_attention and (i % cross_attention_interval == 0), dim_kv, ) for i in range(n_layer)