
Integrate flex attention #1193

Merged · 40 commits into pytorch:main · Sep 11, 2024
Conversation

@RdoubleA (Contributor) commented on Jul 17, 2024

TLDR: We can nearly 7x (+560%) our throughput by turning on compile + sample packing + flex attention (based on Alpaca). Compared to packing with compiled SDPA, compiled flex attention boosts throughput by over 2x (+140%) for Llama3 8B at a max sequence length of 8192. This advantage grows with sequence length, since flex attention scales more efficiently.

Context

This is a proposal for what integration with the prototype flex_attention API, available in PyTorch core nightlies, would look like in torchtune. It requires replacing the nn.scaled_dot_product_attention call in CausalSelfAttention with flex_attention and using a BlockMask instead of a standard tensor mask. The primary challenge is handling versioning, since flex_attention is not available in the latest stable release of torch. I've tried to minimize changes to the attention/transformer modules and the recipes as much as possible beyond the version-gating logic. Most of the core updates occur in PackedDataset and in a new collate function for sample packing that constructs the BlockMask.

What is FlexAttention?

FlexAttention is an alternative to SDPA that can enable better performance for users who work with different attention variants and are comfortable using torch.compile. It allows users to specify custom modifications to attention scores within the fused scaled dot product attention kernel. This enables various attention patterns and biases to be implemented efficiently, with potential runtime and memory savings. It includes support for arbitrary masks with flash attention, which enables performant attention kernels for sample packing (which requires a block-causal mask) and for vision cross-attention masks.
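As a concrete illustration of the score-modification hook (not part of this PR), a score_mod callable receives each attention score along with its batch, head, query, and key indices. The toy bias below is a hedged sketch with made-up values, not torchtune code, and assumes a torch build that ships flex_attention:

import torch
from torch.nn.attention.flex_attention import flex_attention

def distance_bias(score, b, h, q_idx, kv_idx):
    # Penalize each attention score in proportion to query/key distance
    # (an ALiBi-style bias with a single toy slope).
    return score - 0.05 * (q_idx - kv_idx)

# q, k, v: [batch, num_heads, seq_len, head_dim]
q = k = v = torch.randn(1, 8, 256, 64, device="cuda", dtype=torch.float16)
out = flex_attention(q, k, v, score_mod=distance_bias)

For sample packing, the relevant argument is block_mask rather than score_mod; the call itself looks like: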

torch.nn.attention.flex_attention.flex_attention(
    q,
    k,
    v,
    block_mask=mask,
)

The signature is largely similar to SDPA's, except the expected mask is now a BlockMask. A BlockMask takes a mask_mod function that materializes the mask at the kernel level on the fly, so we avoid holding large 2D attention masks in memory that scale quadratically with sequence length.

Thus, we expect FlexAttention with sample packing to 1) be more memory-efficient and 2) increase throughput, thanks to faster mask construction and the ability to use flash attention.
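As a minimal, hedged sketch of how this applies to sample packing (the sample lengths, variable names, and exact mask_mod below are illustrative, not the torchtune implementation), a block-causal mask can be expressed as a mask_mod over per-token document ids and lowered into a BlockMask with create_block_mask:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Hypothetical pack: three samples of lengths 64, 64, and 128 packed into one
# 256-token sequence. document_ids maps each token position to its sample.
seq_lens = torch.tensor([64, 64, 128])
document_ids = torch.repeat_interleave(torch.arange(len(seq_lens)), seq_lens).to("cuda")

def block_causal_mask_mod(b, h, q_idx, kv_idx):
    # A query token attends to a key token only if the pair is causal and
    # both tokens come from the same packed sample.
    return (q_idx >= kv_idx) & (document_ids[q_idx] == document_ids[kv_idx])

seq_len = document_ids.numel()
block_mask = create_block_mask(
    block_causal_mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len
)

# In practice flex_attention is wrapped in torch.compile to get the speedups
# quoted in this PR; the eager call below is just for illustration.
q = k = v = torch.randn(1, 8, seq_len, 64, device="cuda", dtype=torch.float16)
out = flex_attention(q, k, v, block_mask=block_mask)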

How does sample packing with SDPA compare to Flex?

This gist covers the details of this experiment: https://gist.github.com/RdoubleA/012409f7919973d6ba7e9ca3efd5c237. Losses were observed to be equivalent before and after this change.

TLDR is sample packing with FlexAttention scales significantly better with sequence length:

  • Throughput actually decreased for SDPA as we increased max sequence length from 2048 -> 8192, likely due to the quadratic increase in compute for assembling a [b x 8192 x 8192] block-causal mask.
  • At a sequence length of 2048, we observed a +50% increase in WPS when switching to FlexAttention.
  • At a sequence length of 8192, WPS dropped by 23% for SDPA compared to 2048.
  • At a sequence length of 8192, FlexAttention gave a +140% increase in WPS compared to SDPA.
  • Without packing, we achieve a baseline of ~500 WPS per GPU, meaning we can 7x our throughput by turning on packing + compile + flex attention.
  • We also reduce peak memory reserved by 1 GB by using flex attention.

From these observations, we can extrapolate that flex attention will be a dramatic improvement over SDPA with sample packing at context lengths of 100k - 1M. For these reasons, it is worth integrating this in our core modules even if it adds some slight complexity.


Changelog

  • Instead of creating a 2D mask on the fly in PackedDataset's __getitem__, we create a list of document_ids that indicates which tokens belong to which samples in a packed sequence
  • We use document_ids to create a BlockMask in the batch collator. A new collate function for packing is added, and mask utilities are added to utils/attention_bias.py
  • CausalSelfAttention swaps which attention implementation to use depending on whether a normal tensor mask or a BlockMask is passed (see the sketch after this list)
  • All flex attention logic is gated by torch version. I was on torch==2.5.0.dev20240717+cu124 when testing.
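A rough sketch of the version gating and mask-type dispatch described above (hedged: the names and the import-based gate here are illustrative, not the exact torchtune code):

import torch
import torch.nn.functional as F

# Gate on availability of the prototype API rather than a hard version string.
try:
    from torch.nn.attention.flex_attention import BlockMask, flex_attention
    _SUPPORTS_FLEX_ATTENTION = True
except ImportError:
    _SUPPORTS_FLEX_ATTENTION = False

def attention_op(q, k, v, mask=None):
    # A BlockMask (built from document_ids in the packed collate function)
    # routes to flex_attention; a regular boolean tensor mask, or no mask at
    # all, routes to SDPA as before.
    if _SUPPORTS_FLEX_ATTENTION and isinstance(mask, BlockMask):
        return flex_attention(q, k, v, block_mask=mask)
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, is_causal=mask is None
    )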

Test plan

  • Ensure loss curves are identical
  • Add test for packed collate
  • Update attention tests for flex attention
  • Update PackedDataset tests
  • Test with compile
  • Test affected recipes

cc @drisspg @Chillee @kartikayk

pytorch-bot (bot) commented on Jul 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1193

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4f1eaa4 with merge base eb92658:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jul 17, 2024
@RdoubleA changed the title from "[DO NOT LAND] Integrate flex attention" to "[RFC] Integrate flex attention" on Jul 22, 2024
@RdoubleA marked this pull request as ready for review on July 22, 2024 22:40
@drisspg commented on Jul 22, 2024

One nit: Flex is not intended to be a replacement for SDPA; rather, it is designed to enable better performance for users who are utilizing different attention variants and are comfortable with using torch.compile

@RdoubleA changed the title from "[RFC] Integrate flex attention" to "Integrate flex attention" on Aug 1, 2024
@felipemello1 (Contributor) left a comment


hey, thanks for the PR! I am excited to see it land!

The examples in the docstring are very nice. I left some comments, some are nits, but others are more structural, like:

  • Where should padding be done
  • When to send to device
  • Will we ever fallback to efficient attention?

) -> Dict[str, torch.Tensor]:
"""Collate packed sequences into a batch. Only convert the seq lens into
a block mask for use with attention. Tokens, labels, and input_pos are
already padded to the same length within :class:`~torchtune.datasets.PackedDataset`.
Contributor:

> Tokens, labels, and input_pos are already padded to the same length within :class:

I believe that packing should be moved to this function.

  1. All collators would have the same signature. The other collator takes ignore_idx=self._loss_fn.ignore_index.

  2. If we pad to the batch max_seq_len, instead of the defined max_seq_len, we can save tokens.

  3. The collator logic is currently fragmented (partially within the dataset, partially here).

@ebsmothers (Contributor):

Yeah I am also confused by the padding logic, but have a slightly different question. Iiuc:

(1) we pad each pack to max_seq_len in the packed dataset. We also append an element to the pack's seq_len field telling us how many padding tokens there are
(2) in the collator we then pad the seq_len field for each pack in the batch to be the same length
(3) we construct the mask using these padded seq_len fields

Is there something in (3) that necessitates the padding done in (2)? Maybe I'm being dense but I don't see it.. we either use flex attention, in which case we call _get_document_ids_from_seq_lens, or we don't, in which case we call create_block_causal_mask. In both cases we are just iterating over seq_lens[batch_idx], right? So if I've already padded everything to a fixed length inside of the packed dataset, what is the value of also padding the seq lengths?

Concretely, if I have pack 0 with (unpadded) sequence lengths [1, 3, 5] and pack 1 with (unpadded) sequence lengths [1, 2, 3, 4, 5] and my max_seq_len is 20, then in the packed dataset I will pad each pack and have new seq lengths [1, 3, 5, 11] and [1, 2, 3, 4, 5, 5], respectively. But then iiuc the collator will also pad [1, 3, 5, 11] -> [1, 3, 5, 11, 0, 0]. Why?

Contributor:

Also @felipemello1's comment on moving all the padding into collate is interesting. I'm kinda inclined to agree; the one drawback could be that we lose that nice fixed-sequence-length property of our packed datasets that makes compile's life easier (though it's done fine with our unpacked variable-length sequences, so maybe it's not a big deal). I do wonder if we should decouple those changes from this PR, since they involve more nontrivial changes to the packed dataset (unless of course it winds up being straightforward to do on top of the other stuff here).

Contributor:

If we want to keep this property, we can still have it in the collator

@RdoubleA (Author):

cc @gau-nernst, who mentioned that keeping seq len fixed in PackedDataset might still be more efficient for compile even if variable-length packs do compile successfully

@RdoubleA (Author) commented on Sep 6, 2024

> Concretely, if I have pack 0 with (unpadded) sequence lengths [1, 3, 5] and pack 1 with (unpadded) sequence lengths [1, 2, 3, 4, 5] and my max_seq_len is 20, then in the packed dataset I will pad each pack and have new seq lengths [1, 3, 5, 11] and [1, 2, 3, 4, 5, 5], respectively. But then iiuc the collator will also pad [1, 3, 5, 11] -> [1, 3, 5, 11, 0, 0]. Why?

@ebsmothers I think your assessment is correct. We shouldn't need to pad seq lens as long as they each sum to max seq len

I'm also inclined to agree that updating PackedDataset to have variable length packs would be non-trivial and needs some experimentation to understand memory/compute impact, so it should be a separate PR.

Additional review threads (outdated, resolved) on: torchtune/utils/collate.py, torchtune/utils/attention_bias.py, torchtune/modules/transformer.py, torchtune/modules/attention.py, tests/torchtune/utils/test_collate.py
with shape [batch_size x seq_length x seq_length]. This is applied after
the query-key multiplication and before the softmax. A value of True in row i
and column j means token i attends to token j. A value of False means token i
does not attend to token j. If no mask is specified, a causal mask
- is used by default. Default is None.
+ is used by default. If a BlockMask is passed for document masking in a packed
+ sequence, we use :func:`~torch.nn.attention.flex_attention.flex_attention` when
Contributor:

The link to flex_attention doesn't seem to render in the live docs, fyi

@codecov-commenter commented on Sep 7, 2024

Codecov Report

Attention: Patch coverage is 75.34247% with 54 lines in your changes missing coverage. Please review.

Project coverage is 71.21%. Comparing base (66590b4) to head (4f1eaa4).
Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
torchtune/modules/attention_utils.py 62.50% 21 Missing ⚠️
tests/torchtune/modules/test_attention_utils.py 69.69% 20 Missing ⚠️
tests/torchtune/data/test_collate.py 70.00% 6 Missing ⚠️
recipes/full_finetune_distributed.py 0.00% 2 Missing ⚠️
recipes/full_finetune_single_device.py 0.00% 1 Missing ⚠️
recipes/lora_finetune_distributed.py 0.00% 1 Missing ⚠️
recipes/lora_finetune_single_device.py 0.00% 1 Missing ⚠️
recipes/qat_distributed.py 0.00% 1 Missing ⚠️
torchtune/utils/logging.py 88.88% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1193       +/-   ##
===========================================
+ Coverage   27.22%   71.21%   +43.99%     
===========================================
  Files         286      287        +1     
  Lines       13828    14058      +230     
===========================================
+ Hits         3764    10011     +6247     
+ Misses      10064     4047     -6017     


if not packed
else partial(
    padded_collate_packed,
    device=self._device,
Contributor:

no longer need device here?

@ebsmothers (Contributor) left a comment

You really had to jump through a lot of hoops on this one.. one of these days I swear you will have a PR that is nice and easy where everything goes smoothly. This was not that PR. Honestly great work though, the perf wins on this are gonna be incredible

@RdoubleA merged commit 8451b0d into pytorch:main on Sep 11, 2024
17 checks passed
Labels: CLA Signed
8 participants