Integrate flex attention #1193
Merged
Changes shown below are from the first 13 of 40 commits (all by RdoubleA):

cc92391  initial commit
de151e3  update APIs
41df445  fix mask mod for padding
1d61ed4  benchmark script
151336e  move block mask to recipe level
69005b4  official implementation of flex, gated by version
076d5b9  Merge branch 'main' into flex_attention
759e882  add tests
065bcce  Merge branch 'main' into flex_attention
2017401  undo collate move to device, add logging
d590b63  remove device in collate
b8b4ede  Merge branch 'main' into flex_attention
94c2c76  fix merge
14b59e7  Merge branch 'main' into flex_attention
9fe98f1  Merge branch 'main' into flex_attention
628c233  Merge branch 'main' into flex_attention
b82c50c  Merge branch 'main' into flex_attention
f988424  Merge branch 'main' into flex_attention
5b64fb9  Merge branch 'main' into flex_attention
cd2291c  first round of comments
f59e3d7  second round of comments
b685bcb  Merge branch 'main' into flex_attention
2a1215c  fix attention utils tests
a8c0dc8  make seq_lens uneven in tests
99ea72a  Merge branch 'main' into flex_attention
5a0943a  Merge branch 'main' into flex_attention
97fbe4d  NO NEW CYCLES
6976635  Merge branch 'main' into flex_attention
563a8ef  remove device
c3cda75  log on rank zero
48e127c  fix tests
74732c1  do not compile flex by default
5770bdf  ugly compile logic
6497f37  revert for testing
40056ce  quick fix
6cca19a  the fix
926bdf2  stick with hacky way
2844cdb  Merge branch 'main' into flex_attention
1c31a83  comments
4f1eaa4  comments
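For context on what commits like "fix mask mod for padding" and "official implementation of flex, gated by version" implement: samples are packed so that one sequence holds several documents, and each token should attend only to earlier tokens of its own document. The sketch below is a minimal, hypothetical illustration of that mask and of the version gate between the SDPA fallback and the flex-attention path. It is not torchtune's implementation (the real utilities, _get_document_ids_from_seq_lens, create_block_causal_mask, and packed_block_causal_mask, appear in the test diff further down), it uses a 1D seq_lens for a single packed sequence rather than the batched form torchtune handles, and it assumes the torch.nn.attention.flex_attention API from PyTorch >= 2.5.

import torch

def document_ids(seq_lens: torch.Tensor) -> torch.Tensor:
    # e.g. tensor([2, 3, 1]) -> tensor([0, 0, 1, 1, 1, 2]):
    # which packed document each token position belongs to.
    return torch.repeat_interleave(torch.arange(seq_lens.numel()), seq_lens)

def dense_block_causal_mask(seq_lens: torch.Tensor) -> torch.Tensor:
    # Boolean [seq, seq] mask: attend only to earlier tokens in the same document.
    # A mask of this form can be passed to F.scaled_dot_product_attention as attn_mask.
    doc = document_ids(seq_lens)
    n = doc.numel()
    causal = torch.tril(torch.ones(n, n, dtype=torch.bool))
    same_doc = doc.unsqueeze(0) == doc.unsqueeze(1)
    return causal & same_doc

def packed_mask(seq_lens: torch.Tensor, flex_available: bool):
    # Version gate (hypothetical): build a flex-attention BlockMask when the
    # installed PyTorch supports it, otherwise fall back to the dense mask.
    if not flex_available:
        return dense_block_causal_mask(seq_lens)

    from torch.nn.attention.flex_attention import create_block_mask

    doc = document_ids(seq_lens)
    n = doc.numel()

    def mask_mod(b, h, q_idx, kv_idx):
        # Causal within a document: same document id, and q at or after kv.
        return (q_idx >= kv_idx) & (doc[q_idx] == doc[kv_idx])

    # Device hard-coded to CPU for this sketch.
    return create_block_mask(mask_mod, None, None, n, n, device="cpu")

The dense fallback is what the expected tensors in the test below encode; on the flex path the mask does not need to be materialized densely.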
New file, 81 lines added (@@ -0,0 +1,81 @@), tests for the block-causal mask utilities:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

from unittest import mock

import pytest
import torch

from torchtune.utils.attention_bias import (
    _get_document_ids_from_seq_lens,
    create_block_causal_mask,
    packed_block_causal_mask,
)


class TestBlockCausalMask:
    @pytest.fixture
    def seq_lens(self):
        return torch.tensor([[2, 3, 1, 0], [2, 2, 2, 0]])

    def test_get_document_ids_from_seq_lens(self, seq_lens):
        actual = _get_document_ids_from_seq_lens(seq_lens)
        expected = torch.tensor([[0, 0, 1, 1, 1, 2], [0, 0, 1, 1, 2, 2]])
        torch.testing.assert_close(actual, expected)

    def test_create_block_causal_mask(self, seq_lens):
        actual = create_block_causal_mask(seq_lens)
        expected = torch.tensor(
            [
                [
                    [1, 0, 0, 0, 0, 0],
                    [1, 1, 0, 0, 0, 0],
                    [0, 0, 1, 0, 0, 0],
                    [0, 0, 1, 1, 0, 0],
                    [0, 0, 1, 1, 1, 0],
                    [0, 0, 0, 0, 0, 1],
                ],
                [
                    [1, 0, 0, 0, 0, 0],
                    [1, 1, 0, 0, 0, 0],
                    [0, 0, 1, 0, 0, 0],
                    [0, 0, 1, 1, 0, 0],
                    [0, 0, 0, 0, 1, 0],
                    [0, 0, 0, 0, 1, 1],
                ],
            ],
            dtype=torch.bool,
        )
        torch.testing.assert_close(actual, expected)

    @mock.patch("torchtune.utils.attention_bias.torch_version_ge")
    def test_packed_block_causal_mask_sdpa(self, mock_version, seq_lens):
        mock_version.return_value = False
        actual = packed_block_causal_mask(seq_lens, device="cpu")
        expected = torch.tensor(
            [
                [
                    [1, 0, 0, 0, 0, 0],
                    [1, 1, 0, 0, 0, 0],
                    [0, 0, 1, 0, 0, 0],
                    [0, 0, 1, 1, 0, 0],
                    [0, 0, 1, 1, 1, 0],
                    [0, 0, 0, 0, 0, 1],
                ],
                [
                    [1, 0, 0, 0, 0, 0],
                    [1, 1, 0, 0, 0, 0],
                    [0, 0, 1, 0, 0, 0],
                    [0, 0, 1, 1, 0, 0],
                    [0, 0, 0, 0, 1, 0],
                    [0, 0, 0, 0, 1, 1],
                ],
            ],
            dtype=torch.bool,
        )
        torch.testing.assert_close(actual, expected)
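The last test pins torch_version_ge to False, so packed_block_causal_mask returns the dense boolean mask used on the SDPA fallback path. As a rough usage sketch, based only on how the API looks at this revision (where the function still takes a device argument; the flex path would instead hand back a BlockMask for flex_attention):

import torch
import torch.nn.functional as F
from torchtune.utils.attention_bias import packed_block_causal_mask

seq_lens = torch.tensor([[2, 3, 1, 0], [2, 2, 2, 0]])
# Boolean [batch, seq, seq] mask on the SDPA path; True means "may attend".
mask = packed_block_causal_mask(seq_lens, device="cpu")

q = k = v = torch.randn(2, 4, 6, 16)  # [batch, num_heads, seq_len, head_dim]
# Add a head dim so SDPA broadcasts the mask across heads.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.unsqueeze(1))

# On PyTorch >= 2.5 the flex path would be consumed roughly as:
# from torch.nn.attention.flex_attention import flex_attention
# out = flex_attention(q, k, v, block_mask=mask)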
Reviewer comment on test_packed_block_causal_mask_sdpa (conversation marked resolved): no longer need device here?