Integrate flex attention #1193
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1193
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 4f1eaa4 with merge base eb92658. This comment was automatically generated by Dr. CI and updates every 15 minutes.
One nit: Flex is not intended to be a replacement for SDPA; rather, it is designed to enable better performance for users who work with different attention variants and are comfortable with using torch.compile.
Hey, thanks for the PR! I am excited to see it land!
The examples in the docstring are very nice. I left some comments; some are nits, but others are more structural, like:
- Where should padding be done?
- When should we send to device?
- Will we ever fall back to efficient attention?
torchtune/utils/collate.py (outdated diff):
```python
) -> Dict[str, torch.Tensor]:
    """Collate packed sequences into a batch. Only convert the seq lens into
    a block mask for use with attention. Tokens, labels, and input_pos are
    already padded to the same length within :class:`~torchtune.datasets.PackedDataset`.
```
> Tokens, labels, and input_pos are already padded to the same length within :class:`~torchtune.datasets.PackedDataset`.

I believe that packing should be moved to this function:
- All collators would then have the same signature. The other collator takes `ignore_idx=self._loss_fn.ignore_index`.
- If we pad to the batch max seq len instead of the defined max_seq_len, we can save tokens.
- The collator logic is currently fragmented (partially within the dataset, partially here).
Yeah, I am also confused by the padding logic, but have a slightly different question. IIUC:
(1) We pad each pack to max_seq_len in the packed dataset. We also append an element to the pack's seq_len field telling us how many padding tokens there are.
(2) In the collator, we then pad the seq_len field for each pack in the batch to be the same length.
(3) We construct the mask using these padded seq_len fields.
Is there something in (3) that necessitates the padding done in (2)? Maybe I'm being dense, but I don't see it: we either use flex attention, in which case we call `_get_document_ids_from_seq_lens`, or we don't, in which case we call `create_block_causal_mask`. In both cases we are just iterating over `seq_lens[batch_idx]`, right? So if I've already padded everything to a fixed length inside of the packed dataset, what is the value of also padding the seq lengths?
Concretely, if I have pack 0 with (unpadded) sequence lengths [1, 3, 5] and pack 1 with (unpadded) sequence lengths [1, 2, 3, 4, 5] and my max_seq_len is 20, then in the packed dataset I will pad each pack and have new seq lengths [1, 3, 5, 11] and [1, 2, 3, 4, 5, 5], respectively. But then iiuc the collator will also pad [1, 3, 5, 11] -> [1, 3, 5, 11, 0, 0]. Why?
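To make that concrete, here is a minimal sketch (a hypothetical helper for illustration, not the actual torchtune implementation) of how per-token document ids could be derived from a pack's seq lens; zero-length entries, such as the extra padding added in step (2), do not change the result:

```python
import torch

def document_ids_from_seq_lens(seq_lens) -> torch.Tensor:
    """Assign a document id to every token position in a single pack.

    Hypothetical helper: each entry in ``seq_lens`` becomes a run of the same
    id, so zero-length entries contribute nothing to the output.
    """
    return torch.cat(
        [torch.full((n,), doc_id, dtype=torch.long) for doc_id, n in enumerate(seq_lens)]
    )

# Pack 0 from the example above, already padded to max_seq_len=20 in the dataset:
ids_a = document_ids_from_seq_lens([1, 3, 5, 11])        # 20 ids
# Padding the seq-lens field further in the collator is a no-op on the result:
ids_b = document_ids_from_seq_lens([1, 3, 5, 11, 0, 0])  # identical 20 ids
assert torch.equal(ids_a, ids_b)
```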
Also @felipemello1's comment on moving all the padding into collate is interesting. I'm kinda inclined to agree; the one drawback could be that we lose that nice fixed-sequence-length property of our packed datasets that makes compile's life easier (though it's done fine with our unpacked variable-length sequences so maybe nbd). I do wonder if we should decouple those changes from this PR since it involves more nontrivial changes to the packed dataset (unless ofc it winds up being straightforward to do on top of the other stuff here)
If we want to keep this property, we can still have it in the collator
cc @gau-nernst, who mentioned that keeping the seq len fixed in PackedDataset might still be more efficient for compile even if it does compile successfully.
> Concretely, if I have pack 0 with (unpadded) sequence lengths [1, 3, 5] and pack 1 with (unpadded) sequence lengths [1, 2, 3, 4, 5] and my max_seq_len is 20, then in the packed dataset I will pad each pack and have new seq lengths [1, 3, 5, 11] and [1, 2, 3, 4, 5, 5], respectively. But then iiuc the collator will also pad [1, 3, 5, 11] -> [1, 3, 5, 11, 0, 0]. Why?
@ebsmothers I think your assessment is correct. We shouldn't need to pad seq lens as long as they each sum to max seq len
I'm also inclined to agree that updating PackedDataset to have variable length packs would be non-trivial and needs some experimentation to understand memory/compute impact, so it should be a separate PR.
torchtune/modules/transformer.py (outdated diff):
```diff
             with shape [batch_size x seq_length x seq_length]. This is applied after
             the query-key multiplication and before the softmax. A value of True in row i
             and column j means token i attends to token j. A value of False means token i
             does not attend to token j. If no mask is specified, a causal mask
-            is used by default. Default is None.
+            is used by default. If a BlockMask is passed for document masking in a packed
+            sequence, we use :func:`~torch.nn.attention.flex_attention.flex_attention` when
```
The link to flex_attention doesn't seem to render in the live docs, fyi
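For readers skimming the thread, here is a rough sketch (not the exact torchtune code) of the kind of dispatch the updated docstring describes, assuming the `BlockMask` type and `flex_attention` function from the nightly `torch.nn.attention.flex_attention` module:

```python
import torch.nn.functional as F
from torch.nn.attention.flex_attention import BlockMask, flex_attention

def _attention(q, k, v, mask=None):
    # Sketch only: if a BlockMask is supplied (e.g. document masking for a
    # packed sequence), route to flex_attention; otherwise fall back to SDPA,
    # using the provided boolean mask or a causal mask when none is given.
    if isinstance(mask, BlockMask):
        return flex_attention(q, k, v, block_mask=mask)
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, is_causal=mask is None
    )
```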
Codecov Report
Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1193       +/-   ##
===========================================
+ Coverage   27.22%   71.21%   +43.99%
===========================================
  Files         286      287        +1
  Lines       13828    14058      +230
===========================================
+ Hits         3764    10011     +6247
+ Misses      10064     4047     -6017
```

☔ View full report in Codecov by Sentry.
recipes/full_finetune_distributed.py (outdated diff):
```python
            if not packed
            else partial(
                padded_collate_packed,
                device=self._device,
```
no longer need device here?
You really had to jump through a lot of hoops on this one... one of these days I swear you will have a PR that is nice and easy where everything goes smoothly. This was not that PR. Honestly, great work though; the perf wins on this are gonna be incredible.
TLDR: We can nearly 7x (+560%) our throughput by turning compile + sample packing + flex attention on (based on Alpaca). Compared to packing with compiled SDPA, compiled flex attention boosts throughput by over 2x (140%) for Llama3 8B at a max sequence length of 8192. This effect becomes more pronounced as sequence length increases, since flex attention scales more efficiently.
Context
This is a proposal of what integration with the prototype API `flex_attention`, available in PyTorch Core's nightlies, would look like in torchtune. It requires replacing the `nn.scaled_dot_product_attention` call in `CausalSelfAttention` with `flex_attention`, and using a `BlockMask` instead of a standard tensor mask. The primary challenge is handling versioning, since this is not available in the latest stable release of torch. I've tried to minimize changes to the attention/transformer modules and the recipes as much as possible beyond the version-gating logic. Most of the core updates occur in `PackedDataset` and a new collate function for sample packing that constructs the `BlockMask`.

What is FlexAttention?
FlexAttention is an alternative to SDPA that can enable better performance for users that are utilizing different attention variants and are comfortable with using torch.compile. It allows users to specify custom modifications to attention scores within the Fused Scaled Dot Product Attention Kernel. This enables various attention patterns and biases to be implemented efficiently, with potential runtime and memory savings. This includes support for arbitrary masks with flash attention, which will enable performant attention kernels for sample packing (which requires a block causal mask) and vision cross attention masks.
The signature is largely similar to SDPA, except the expected mask is now a `BlockMask` (see here). This takes in a `mask_mod` function that materializes a kernel-level mask ad hoc, preventing us from holding large 2D attention masks in memory that scale quadratically with sequence length. Thus, we expect FlexAttention with sample packing to 1) be more memory-efficient and 2) increase throughput due to faster mask construction and the ability to use flash attention.
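For reference, a minimal usage sketch of the prototype API as exposed in recent nightlies (this assumes a CUDA device; the shapes and dtypes here are arbitrary, and the API may still change):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 2, 8, 1024, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))

def causal_mask(b, h, q_idx, kv_idx):
    # mask_mod returns True where attention is allowed; it is evaluated inside
    # the kernel, so the full S x S boolean mask is never materialized.
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)
out = flex_attention(q, k, v, block_mask=block_mask)
```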
How does sample packing with SDPA compare to Flex?
This gist covers the details of this experiment: https://gist.github.com/RdoubleA/012409f7919973d6ba7e9ca3efd5c237. Losses were observed to be equivalent before and after this change.
TLDR is sample packing with FlexAttention scales significantly better with sequence length:
From these observations, we can extrapolate that flex attention will be a dramatic improvement over SDPA with sample packing for context lengths of 100k to 1M. For these reasons, it is worth integrating this into our core modules even if it adds some slight complexity.
Changelog
- In `PackedDataset`'s `__getitem__`, we create a list of `document_ids` that indicate which tokens belong to which samples in a packed sequence.
- The `document_ids` are used to create a `BlockMask` in the batch collator (see the sketch after this changelog). A new collate function for packing is added, and mask utilities were added to `utils/attention_bias.py`.
- `CausalSelfAttention` will swap which attention to use depending on whether a normal tensor mask or a `BlockMask` is passed.
- Used `torch==2.5.0.dev20240717+cu124` when testing.
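As a rough illustration of the mask construction the changelog describes (names and signature here are illustrative, not necessarily the ones introduced in the PR), a collate-time helper could turn a batch of `document_ids` into a block-causal `BlockMask` roughly like this:

```python
import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask

def packed_block_causal_mask(document_ids: torch.Tensor) -> BlockMask:
    """Illustrative sketch: build a block-causal mask from [batch, seq_len]
    document ids so tokens attend causally and only within their own sample."""
    batch_size, seq_len = document_ids.shape

    def mask_mod(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        same_doc = document_ids[b, q_idx] == document_ids[b, kv_idx]
        return causal & same_doc

    return create_block_mask(
        mask_mod, B=batch_size, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
        device=document_ids.device,
    )
```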
Test plan
- `PackedDataset` tests

cc @drisspg @Chillee @kartikayk