Integrate flex attention #1193

Merged
merged 40 commits into main from flex_attention on Sep 11, 2024
Changes from 13 commits
Commits (40)
cc92391
initial commit
RdoubleA Jul 17, 2024
de151e3
update APIs
RdoubleA Jul 17, 2024
41df445
fix mask mod for padding
RdoubleA Jul 18, 2024
1d61ed4
benchmark script
RdoubleA Jul 18, 2024
151336e
move block mask to recipe level
RdoubleA Jul 18, 2024
69005b4
official implementation of flex, gated by version
RdoubleA Jul 22, 2024
076d5b9
Merge branch 'main' into flex_attention
RdoubleA Jul 25, 2024
759e882
add tests
RdoubleA Jul 26, 2024
065bcce
Merge branch 'main' into flex_attention
RdoubleA Jul 31, 2024
2017401
undo collate move to device, add logging
RdoubleA Aug 1, 2024
d590b63
remove device in collate
RdoubleA Aug 1, 2024
b8b4ede
Merge branch 'main' into flex_attention
RdoubleA Aug 20, 2024
94c2c76
fix merge
RdoubleA Aug 20, 2024
14b59e7
Merge branch 'main' into flex_attention
RdoubleA Aug 26, 2024
9fe98f1
Merge branch 'main' into flex_attention
RdoubleA Aug 27, 2024
628c233
Merge branch 'main' into flex_attention
RdoubleA Sep 4, 2024
b82c50c
Merge branch 'main' into flex_attention
RdoubleA Sep 4, 2024
f988424
Merge branch 'main' into flex_attention
RdoubleA Sep 5, 2024
5b64fb9
Merge branch 'main' into flex_attention
RdoubleA Sep 5, 2024
cd2291c
first round of comments
RdoubleA Sep 6, 2024
f59e3d7
second round of comments
RdoubleA Sep 6, 2024
b685bcb
Merge branch 'main' into flex_attention
RdoubleA Sep 6, 2024
2a1215c
fix attention utils tests
RdoubleA Sep 6, 2024
a8c0dc8
make seq_lens uneven in tests
RdoubleA Sep 6, 2024
99ea72a
Merge branch 'main' into flex_attention
RdoubleA Sep 6, 2024
5a0943a
Merge branch 'main' into flex_attention
RdoubleA Sep 6, 2024
97fbe4d
NO NEW CYCLES
RdoubleA Sep 7, 2024
6976635
Merge branch 'main' into flex_attention
RdoubleA Sep 9, 2024
563a8ef
remove device
RdoubleA Sep 9, 2024
c3cda75
log on rank zero
RdoubleA Sep 9, 2024
48e127c
fix tests
RdoubleA Sep 9, 2024
74732c1
do not compile flex by default
RdoubleA Sep 9, 2024
5770bdf
ugly compile logic
RdoubleA Sep 9, 2024
6497f37
revert for testing
RdoubleA Sep 9, 2024
40056ce
quick fix
RdoubleA Sep 9, 2024
6cca19a
the fix
RdoubleA Sep 10, 2024
926bdf2
stick with hacky way
RdoubleA Sep 10, 2024
2844cdb
Merge branch 'main' into flex_attention
RdoubleA Sep 10, 2024
1c31a83
comments
RdoubleA Sep 11, 2024
4f1eaa4
comments
RdoubleA Sep 11, 2024
5 changes: 4 additions & 1 deletion recipes/full_finetune_distributed.py
@@ -469,7 +469,10 @@ def _setup_data(
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else None,
else partial(
utils.padded_collate_packed,
device=self._device,
Contributor comment: no longer need device here?
),
)

if self._is_rank_zero:
5 changes: 4 additions & 1 deletion recipes/full_finetune_single_device.py
@@ -432,7 +432,10 @@ def _setup_data(
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else None,
else partial(
utils.padded_collate_packed,
device=self._device,
),
)

log.info("Dataset and Sampler are initialized.")
17 changes: 9 additions & 8 deletions recipes/lora_finetune_distributed.py
@@ -520,14 +520,15 @@ def _setup_data(
dataset=ds,
batch_size=batch_size,
sampler=sampler,
collate_fn=(
partial(
utils.padded_collate,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else None
collate_fn=partial(
utils.padded_collate,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else partial(
utils.padded_collate_packed,
device=self._device,
),
)

17 changes: 9 additions & 8 deletions recipes/lora_finetune_single_device.py
@@ -451,14 +451,15 @@ def _setup_data(
dataset=ds,
sampler=sampler,
batch_size=batch_size,
collate_fn=(
partial(
utils.padded_collate,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else None
collate_fn=partial(
utils.padded_collate,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else partial(
utils.padded_collate_packed,
device=self._device,
),
)

5 changes: 4 additions & 1 deletion recipes/qat_distributed.py
@@ -445,7 +445,10 @@ def _setup_data(
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else None,
else partial(
utils.padded_collate_packed,
device=self._device,
),
)

if self._is_rank_zero:
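Across the four recipes above, the packed branch replaces the previous `None` collate with `utils.padded_collate_packed` (passing a `device` argument that a reviewer questions above). The sketch below is an illustration only, not torchtune's implementation; it assumes each pack already arrives padded to `max_seq_len` with `tokens`, `labels`, `input_pos`, and `seq_lens` keys, which is why no `padding_idx` or `ignore_idx` is needed on this branch:

```python
from typing import Dict, List

import torch


def padded_collate_packed_sketch(
    batch: List[Dict[str, torch.Tensor]],
) -> Dict[str, torch.Tensor]:
    """Hypothetical stand-in for utils.padded_collate_packed.

    Because the packed dataset pads every pack to max_seq_len up front,
    collation reduces to stacking the pre-padded tensors.
    """
    return {
        "tokens": torch.stack([pack["tokens"] for pack in batch]),
        "labels": torch.stack([pack["labels"] for pack in batch]),
        "input_pos": torch.stack([pack["input_pos"] for pack in batch]),
        # seq_lens can differ in length across packs, so keep them as a list
        "seq_lens": [pack["seq_lens"] for pack in batch],
    }
```

With pre-padded packs the collate only has to batch; the block-causal attention mask can then be built from `seq_lens` at the recipe or model level (see the `packed_block_causal_mask` tests added below) instead of being materialized per sample.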
82 changes: 23 additions & 59 deletions tests/torchtune/datasets/test_packed_dataset.py
@@ -48,34 +48,27 @@ def __len__(self):


class TestPackedDataset:
def _get_expected_mask_and_input_pos(
def _get_expected_seq_lens_and_input_pos(
self, max_seq_len, sample_size, split_across_pack
):
"""
Generate expected integer mask and position ids for given max sequence
Generate expected seq lens and position ids for given max sequence
length and sample length
"""
num_samples, remainder = divmod(max_seq_len, sample_size)
seq_lens = [sample_size] * num_samples
if split_across_pack and remainder > 0:
num_samples += 1
mask = torch.block_diag(
*[
torch.tril(torch.ones(sample_size, sample_size, dtype=torch.bool))
for i in range(1, num_samples + 1)
]
)
input_pos = [list(range(sample_size)) for i in range(1, num_samples + 1)]
input_pos = list(itertools.chain(*input_pos))

# Emulate mask and position id padding
if not split_across_pack and remainder > 0:
mask = torch.block_diag(
mask,
torch.eye(remainder, dtype=torch.bool),
)
input_pos.extend(list(range(sample_size, sample_size + remainder)))
# Emulate seq len and position id padding
if remainder > 0:
if not split_across_pack:
input_pos.extend(list(range(sample_size, sample_size + remainder)))
seq_lens.extend([remainder])

return mask[:max_seq_len, :max_seq_len], torch.tensor(input_pos[:max_seq_len])
return torch.tensor(seq_lens), torch.tensor(input_pos[:max_seq_len])

def _calculate_num_packs(
self, dataset_size, max_seq_len, sample_size, split_across_pack, max_packs
@@ -122,7 +115,6 @@ def test_packed_dataset(
assert (
len(packed[0]["tokens"])
== len(packed[0]["labels"])
== len(packed[0]["mask"])
== len(packed[0]["input_pos"])
)
# Check that samples are packed correctly - very last individual sample
@@ -145,10 +137,13 @@

assert packed[-1]["tokens"][-1].item() == last_index

expected_mask, expected_input_pos = self._get_expected_mask_and_input_pos(
(
expected_seq_lens,
expected_input_pos,
) = self._get_expected_seq_lens_and_input_pos(
max_seq_len, sample_size, split_across_pack
)
torch.testing.assert_close(packed[0]["mask"], expected_mask)
torch.testing.assert_close(packed[0]["seq_lens"], expected_seq_lens)
torch.testing.assert_close(packed[0]["input_pos"], expected_input_pos)

def test_packed_dataset_real_data(self):
@@ -162,48 +157,15 @@ def test_packed_dataset_real_data(self):
torch.tensor([5, 2, 6, 4, 3, 8, -1, 0, 4, 3]),
torch.tensor([4, 3, 2, 5, 7, -1, -100, -100, -100, -100]),
]
expected_mask = [
expected_seq_lens = [
torch.tensor(
[
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
]
[7, 3],
),
torch.tensor(
[
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
]
[7, 3],
),
torch.tensor(
[
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
]
[6, 4],
),
]
expected_input_pos = [
@@ -219,16 +181,16 @@
)

for i in range(len(packed)):
prompt, label, mask, input_pos = (
prompt, label, seq_lens, input_pos = (
packed[i]["tokens"],
packed[i]["labels"],
packed[i]["mask"],
packed[i]["seq_lens"],
packed[i]["input_pos"],
)
torch.testing.assert_close(prompt, expected_tokenized_prompts[i])
torch.testing.assert_close(label, expected_tokenized_labels[i])
torch.testing.assert_close(input_pos, expected_input_pos[i])
torch.testing.assert_close(mask, expected_mask[i].to(dtype=torch.bool))
torch.testing.assert_close(seq_lens, expected_seq_lens[i])

def test_pad_pack(self):
padding_idx = -8
@@ -255,6 +217,7 @@ def test_pad_pack(self):
padded_input = padded["tokens"]
padded_label = padded["labels"]
padded_input_pos = padded["input_pos"]
padded_seq_lens = padded["seq_lens"]

torch.testing.assert_close(
padded_input, torch.tensor([2, 5, padding_idx, padding_idx])
@@ -263,6 +226,7 @@ def test_pad_pack(self):
padded_label, torch.tensor([3, 7, ignore_idx, ignore_idx])
)
torch.testing.assert_close(padded_input_pos, torch.tensor([8, 0, 1, 2]))
torch.testing.assert_close(padded_seq_lens, torch.tensor([1, 1, 2]))

def test_pack_errors_if_sample_too_long(self):
dataset = DummyDataset(8)
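These tests now assert on per-pack `seq_lens` rather than on a dense boolean `mask`, since the mask is derived from the sequence lengths downstream. As a small sanity sketch (not part of the PR), the dense block-causal mask that the removed assertions spelled out can be reconstructed from `seq_lens` alone, leaving aside how padded positions are treated:

```python
import torch


def block_causal_mask_from_seq_lens(seq_lens: torch.Tensor) -> torch.Tensor:
    """Rebuild the dense block-causal mask the removed assertions spelled out.

    seq_lens holds the length of each sample in one pack; a token may attend
    causally only within its own sample. (Padding handling is ignored here.)
    """
    blocks = [
        torch.tril(torch.ones(n, n, dtype=torch.bool)) for n in seq_lens.tolist()
    ]
    return torch.block_diag(*blocks)


# The first pack in test_packed_dataset_real_data has seq_lens [7, 3]:
mask = block_causal_mask_from_seq_lens(torch.tensor([7, 3]))
assert mask.shape == (10, 10)
assert not mask[7, 6].item()  # token 7 starts a new sample, so it cannot see token 6
assert mask[9, 7].item()  # tokens within the second sample still attend causally
```

Storing `seq_lens` keeps each packed sample O(seq_len) in memory instead of O(seq_len^2) for the dense mask.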
81 changes: 81 additions & 0 deletions tests/torchtune/utils/test_attention_bias.py
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

from unittest import mock

import pytest
import torch

from torchtune.utils.attention_bias import (
_get_document_ids_from_seq_lens,
create_block_causal_mask,
packed_block_causal_mask,
)


class TestBlockCausalMask:
@pytest.fixture
def seq_lens(self):
return torch.tensor([[2, 3, 1, 0], [2, 2, 2, 0]])

def test_get_document_ids_from_seq_lens(self, seq_lens):
actual = _get_document_ids_from_seq_lens(seq_lens)
expected = torch.tensor([[0, 0, 1, 1, 1, 2], [0, 0, 1, 1, 2, 2]])
torch.testing.assert_close(actual, expected)

def test_create_block_causal_mask(self, seq_lens):
actual = create_block_causal_mask(seq_lens)
expected = torch.tensor(
[
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1],
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 1, 1],
],
],
dtype=torch.bool,
)
torch.testing.assert_close(actual, expected)

@mock.patch("torchtune.utils.attention_bias.torch_version_ge")
def test_packed_block_causal_mask_sdpa(self, mock_version, seq_lens):
mock_version.return_value = False
actual = packed_block_causal_mask(seq_lens, device="cpu")
expected = torch.tensor(
[
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1],
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 1, 1],
],
],
dtype=torch.bool,
)
torch.testing.assert_close(actual, expected)
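The test above mocks the version check so that `packed_block_causal_mask` falls back to the dense SDPA-style mask. When flex attention is available, the same rule can instead be phrased as a `mask_mod` over document ids. The following is a sketch of that idea only; it assumes PyTorch's `torch.nn.attention.flex_attention.create_block_mask` API (2.5+) and is not necessarily how `packed_block_causal_mask` builds its flex path:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask  # PyTorch >= 2.5


def flex_block_causal_mask_sketch(seq_lens: torch.Tensor):
    """Express the packed block-causal rule as a flex attention BlockMask.

    seq_lens: [batch_size, max_packs] per-sample lengths (zero-padded), as in
    the fixture above. Each row must sum to the same packed sequence length.
    """
    # Expand each length n into n copies of its sample index, mirroring the
    # expected output of _get_document_ids_from_seq_lens in the test above.
    document_ids = torch.stack(
        [
            torch.repeat_interleave(torch.arange(row.numel(), device=row.device), row)
            for row in seq_lens
        ]
    )
    batch_size, seq_len = document_ids.shape

    def mask_mod(b, h, q_idx, kv_idx):
        # Attend only causally and only within the same packed sample.
        same_doc = document_ids[b, q_idx] == document_ids[b, kv_idx]
        return same_doc & (q_idx >= kv_idx)

    return create_block_mask(
        mask_mod, batch_size, None, seq_len, seq_len, device=str(seq_lens.device)
    )
```

The payoff of the flex path is that the full `[seq_len, seq_len]` mask is never materialized: the `BlockMask` only records which tiles contain valid positions, and the `mask_mod` is evaluated inside the attention kernel.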