Integrate flex attention #1193

Merged: 40 commits, Sep 11, 2024
Changes from 27 commits
Commits (40)
cc92391  initial commit (RdoubleA, Jul 17, 2024)
de151e3  update APIs (RdoubleA, Jul 17, 2024)
41df445  fix mask mod for padding (RdoubleA, Jul 18, 2024)
1d61ed4  benchmark script (RdoubleA, Jul 18, 2024)
151336e  move block mask to recipe level (RdoubleA, Jul 18, 2024)
69005b4  official implementation of flex, gated by version (RdoubleA, Jul 22, 2024)
076d5b9  Merge branch 'main' into flex_attention (RdoubleA, Jul 25, 2024)
759e882  add tests (RdoubleA, Jul 26, 2024)
065bcce  Merge branch 'main' into flex_attention (RdoubleA, Jul 31, 2024)
2017401  undo collate move to device, add logging (RdoubleA, Aug 1, 2024)
d590b63  remove device in collate (RdoubleA, Aug 1, 2024)
b8b4ede  Merge branch 'main' into flex_attention (RdoubleA, Aug 20, 2024)
94c2c76  fix merge (RdoubleA, Aug 20, 2024)
14b59e7  Merge branch 'main' into flex_attention (RdoubleA, Aug 26, 2024)
9fe98f1  Merge branch 'main' into flex_attention (RdoubleA, Aug 27, 2024)
628c233  Merge branch 'main' into flex_attention (RdoubleA, Sep 4, 2024)
b82c50c  Merge branch 'main' into flex_attention (RdoubleA, Sep 4, 2024)
f988424  Merge branch 'main' into flex_attention (RdoubleA, Sep 5, 2024)
5b64fb9  Merge branch 'main' into flex_attention (RdoubleA, Sep 5, 2024)
cd2291c  first round of comments (RdoubleA, Sep 6, 2024)
f59e3d7  second round of comments (RdoubleA, Sep 6, 2024)
b685bcb  Merge branch 'main' into flex_attention (RdoubleA, Sep 6, 2024)
2a1215c  fix attention utils tests (RdoubleA, Sep 6, 2024)
a8c0dc8  make seq_lens uneven in tests (RdoubleA, Sep 6, 2024)
99ea72a  Merge branch 'main' into flex_attention (RdoubleA, Sep 6, 2024)
5a0943a  Merge branch 'main' into flex_attention (RdoubleA, Sep 6, 2024)
97fbe4d  NO NEW CYCLES (RdoubleA, Sep 7, 2024)
6976635  Merge branch 'main' into flex_attention (RdoubleA, Sep 9, 2024)
563a8ef  remove device (RdoubleA, Sep 9, 2024)
c3cda75  log on rank zero (RdoubleA, Sep 9, 2024)
48e127c  fix tests (RdoubleA, Sep 9, 2024)
74732c1  do not compile flex by default (RdoubleA, Sep 9, 2024)
5770bdf  ugly compile logic (RdoubleA, Sep 9, 2024)
6497f37  revert for testing (RdoubleA, Sep 9, 2024)
40056ce  quick fix (RdoubleA, Sep 9, 2024)
6cca19a  the fix (RdoubleA, Sep 10, 2024)
926bdf2  stick with hacky way (RdoubleA, Sep 10, 2024)
2844cdb  Merge branch 'main' into flex_attention (RdoubleA, Sep 10, 2024)
1c31a83  comments (RdoubleA, Sep 11, 2024)
4f1eaa4  comments (RdoubleA, Sep 11, 2024)
19 changes: 10 additions & 9 deletions recipes/full_finetune_distributed.py
@@ -21,7 +21,7 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_sft
+from torchtune.data import padded_collate_packed, padded_collate_sft
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
@@ -500,14 +500,15 @@ def _setup_data(
             dataset=ds,
             batch_size=batch_size,
             sampler=sampler,
-            collate_fn=(
-                partial(
-                    padded_collate_sft,
-                    padding_idx=self._tokenizer.pad_id,
-                    ignore_idx=self._loss_fn.ignore_index,
-                )
-                if not packed
-                else None
+            collate_fn=partial(
+                padded_collate_sft,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=self._loss_fn.ignore_index,
+            )
+            if not packed
+            else partial(
+                padded_collate_packed,
+                device=self._device,
             ),
         )

Contributor review comment on the added device=self._device line: no longer need device here?
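This collate_fn dispatch is repeated in all five recipes touched by the PR. As a point of reference, here is a minimal standalone sketch of the logic in the hunk above; the helper name build_collate_fn is hypothetical and not part of the PR, and the device argument simply mirrors this snapshot of the branch (the review comment above, and later commits such as "remove device", suggest it may be dropped):

from functools import partial

from torchtune.data import padded_collate_packed, padded_collate_sft


def build_collate_fn(packed: bool, pad_id: int, ignore_idx: int, device):
    # Unpacked SFT batches are padded to the longest sequence in the batch,
    # with labels padded using ignore_idx so padding does not affect the loss.
    if not packed:
        return partial(
            padded_collate_sft, padding_idx=pad_id, ignore_idx=ignore_idx
        )
    # Packed samples already share a fixed length; the packed collate also
    # builds the attention mask (a dense mask for SDPA, a block mask when
    # flex attention is available, as exercised by the tests below).
    return partial(padded_collate_packed, device=device)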
19 changes: 10 additions & 9 deletions recipes/full_finetune_single_device.py
@@ -19,7 +19,7 @@
 from torch.utils.data import DataLoader, DistributedSampler

 from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_sft
+from torchtune.data import padded_collate_packed, padded_collate_sft
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
@@ -464,14 +464,15 @@ def _setup_data(
             dataset=ds,
             batch_size=batch_size,
             sampler=sampler,
-            collate_fn=(
-                partial(
-                    padded_collate_sft,
-                    padding_idx=self._tokenizer.pad_id,
-                    ignore_idx=self._loss_fn.ignore_index,
-                )
-                if not packed
-                else None
+            collate_fn=partial(
+                padded_collate_sft,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=self._loss_fn.ignore_index,
+            )
+            if not packed
+            else partial(
+                padded_collate_packed,
+                device=self._device,
             ),
         )
19 changes: 10 additions & 9 deletions recipes/lora_finetune_distributed.py
@@ -26,7 +26,7 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_sft
+from torchtune.data import padded_collate_packed, padded_collate_sft
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft import (
     get_adapter_params,
@@ -551,14 +551,15 @@ def _setup_data(
             dataset=ds,
             batch_size=batch_size,
             sampler=sampler,
-            collate_fn=(
-                partial(
-                    padded_collate_sft,
-                    padding_idx=self._tokenizer.pad_id,
-                    ignore_idx=self._loss_fn.ignore_index,
-                )
-                if not packed
-                else None
+            collate_fn=partial(
+                padded_collate_sft,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=self._loss_fn.ignore_index,
+            )
+            if not packed
+            else partial(
+                padded_collate_packed,
+                device=self._device,
             ),
         )
19 changes: 10 additions & 9 deletions recipes/lora_finetune_single_device.py
@@ -19,7 +19,7 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_sft
+from torchtune.data import padded_collate_packed, padded_collate_sft
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft import (
     get_adapter_params,
@@ -495,14 +495,15 @@ def _setup_data(
             dataset=ds,
             sampler=sampler,
             batch_size=batch_size,
-            collate_fn=(
-                partial(
-                    padded_collate_sft,
-                    padding_idx=self._tokenizer.pad_id,
-                    ignore_idx=self._loss_fn.ignore_index,
-                )
-                if not packed
-                else None
+            collate_fn=partial(
+                padded_collate_sft,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=self._loss_fn.ignore_index,
+            )
+            if not packed
+            else partial(
+                padded_collate_packed,
+                device=self._device,
             ),
         )
19 changes: 10 additions & 9 deletions recipes/qat_distributed.py
@@ -21,7 +21,7 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_sft
+from torchtune.data import padded_collate_packed, padded_collate_sft
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
@@ -523,14 +523,15 @@ def _setup_data(
             dataset=ds,
             batch_size=batch_size,
             sampler=sampler,
-            collate_fn=(
-                partial(
-                    padded_collate_sft,
-                    padding_idx=self._tokenizer.pad_id,
-                    ignore_idx=self._loss_fn.ignore_index,
-                )
-                if not packed
-                else None
+            collate_fn=partial(
+                padded_collate_sft,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=self._loss_fn.ignore_index,
+            )
+            if not packed
+            else partial(
+                padded_collate_packed,
+                device=self._device,
             ),
         )
109 changes: 109 additions & 0 deletions tests/torchtune/data/test_collate.py
@@ -6,14 +6,19 @@

# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

from unittest import mock

import pytest
import torch
from tests.test_utils import gpu_test
from torchtune.data import (
    left_pad_sequence,
    padded_collate,
    padded_collate_dpo,
    padded_collate_packed,
    padded_collate_sft,
)
from torchtune.modules.attention_utils import _SUPPORTS_FLEX_ATTENTION


class TestPaddedCollateSFT:
@@ -47,6 +52,110 @@ def test_batch_pad_sequence(self):
            padded_label, torch.tensor([10, ignore_idx, ignore_idx])
        )

    @mock.patch("torchtune.modules.attention_utils._SUPPORTS_FLEX_ATTENTION", False)
    def test_padded_collate_packed_sdpa(self):
        token_pairs = [
            {
                "tokens": torch.tensor([1, 2, 3, 4, 5, 6]),
                "labels": torch.tensor([7, 8, 9, 10, 11, 12]),
                "input_pos": torch.tensor([0, 1, 2, 0, 1, 0]),
                "seq_lens": torch.tensor([3, 2, 1]),
            },
            {
                "tokens": torch.tensor([13, 14, 15, 16, 17, 18]),
                "labels": torch.tensor([19, 20, 21, 22, 23, 24]),
                "input_pos": torch.tensor([0, 1, 0, 1, 0, 1]),
                "seq_lens": torch.tensor([2, 2, 2]),
            },
        ]
        collated = padded_collate_packed(
            batch=token_pairs,
        )
        torch.testing.assert_close(
            collated["tokens"],
            torch.tensor([[1, 2, 3, 4, 5, 6], [13, 14, 15, 16, 17, 18]]),
        )
        torch.testing.assert_close(
            collated["labels"],
            torch.tensor([[7, 8, 9, 10, 11, 12], [19, 20, 21, 22, 23, 24]]),
        )
        torch.testing.assert_close(
            collated["input_pos"],
            torch.tensor([[0, 1, 2, 0, 1, 0], [0, 1, 0, 1, 0, 1]]),
        )
        torch.testing.assert_close(
            collated["mask"],
            torch.tensor(
                [
                    [
                        [1, 0, 0, 0, 0, 0],
                        [1, 1, 0, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0],
                        [0, 0, 0, 1, 0, 0],
                        [0, 0, 0, 1, 1, 0],
                        [0, 0, 0, 0, 0, 1],
                    ],
                    [
                        [1, 0, 0, 0, 0, 0],
                        [1, 1, 0, 0, 0, 0],
                        [0, 0, 1, 0, 0, 0],
                        [0, 0, 1, 1, 0, 0],
                        [0, 0, 0, 0, 1, 0],
                        [0, 0, 0, 0, 1, 1],
                    ],
                ],
                dtype=torch.bool,
            ),
        )

    @pytest.mark.skipif(
        not _SUPPORTS_FLEX_ATTENTION,
        reason="Please install a nightly build of torch to run this test.",
    )
    @gpu_test(gpu_count=1)
    def test_padded_collate_packed_flex(self):
        # create_block_mask requires that seq_len be divisible by 128, the default block size.
        # see https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py#L636
        batch = [
            {
                "tokens": torch.ones(128, dtype=torch.long),
                "labels": torch.ones(128, dtype=torch.long),
                "input_pos": torch.zeros(128, dtype=torch.long),
                "seq_lens": torch.ones(64, dtype=torch.long) * 2,
            },
            {
                "tokens": torch.ones(128, dtype=torch.long),
                "labels": torch.ones(128, dtype=torch.long),
                "input_pos": torch.zeros(128, dtype=torch.long),
                "seq_lens": torch.ones(32, dtype=torch.long) * 4,
            },
        ]
        collated = padded_collate_packed(
            batch=batch,
        )
        torch.testing.assert_close(
            collated["tokens"],
            torch.stack(
                [torch.ones(128, dtype=torch.long), torch.ones(128, dtype=torch.long)]
            ),
        )
        torch.testing.assert_close(
            collated["labels"],
            torch.stack(
                [torch.ones(128, dtype=torch.long), torch.ones(128, dtype=torch.long)]
            ),
        )
        torch.testing.assert_close(
            collated["input_pos"],
            torch.stack(
                [torch.zeros(128, dtype=torch.long), torch.zeros(128, dtype=torch.long)]
            ),
        )
        torch.testing.assert_close(
            collated["mask"].to_dense(),
            torch.tensor([[[[1]]], [[[1]]]], dtype=torch.int32, device="cuda"),
        )


class TestLeftPadSequence:
    def test_left_pad_sequence(self):
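For reference, the block-diagonal causal mask that test_padded_collate_packed_sdpa expects above can be reproduced with a few lines of plain PyTorch. This is an illustrative sketch rather than code from the PR, and the helper name block_causal_mask is invented here:

import torch


def block_causal_mask(seq_lens: torch.Tensor) -> torch.Tensor:
    # Each entry of seq_lens is the length of one document packed into the
    # sequence; a token may only attend to earlier tokens of the same document.
    total = int(seq_lens.sum())
    doc_ids = torch.repeat_interleave(torch.arange(len(seq_lens)), seq_lens)
    same_doc = doc_ids.unsqueeze(0) == doc_ids.unsqueeze(1)
    causal = torch.tril(torch.ones(total, total, dtype=torch.bool))
    return same_doc & causal


# Reproduces the first expected mask in the SDPA test (seq_lens = [3, 2, 1]).
print(block_causal_mask(torch.tensor([3, 2, 1])).int())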
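The flex-attention path stores the same document structure as a BlockMask rather than a dense tensor, which is why test_padded_collate_packed_flex calls .to_dense() on the collated mask. Below is a rough sketch of how such a mask can be built with PyTorch's flex attention API (create_block_mask and flex_attention from torch.nn.attention.flex_attention, available in recent or nightly builds). It assumes per-position document ids derived from seq_lens and a CUDA device; it is illustrative and not necessarily how torchtune's padded_collate_packed is implemented:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# One packed sample of length 128 holding 64 documents of length 2, mirroring
# the first element of the batch in the flex test above.
seq_lens = torch.full((64,), 2, dtype=torch.long, device="cuda")
doc_ids = torch.repeat_interleave(torch.arange(len(seq_lens), device="cuda"), seq_lens)


def document_causal_mask(b, h, q_idx, kv_idx):
    # Attend only within the same document and only to earlier (or equal) positions.
    return (doc_ids[q_idx] == doc_ids[kv_idx]) & (q_idx >= kv_idx)


# As the test comment notes, create_block_mask expects the sequence length to be
# divisible by the default 128-token block size, hence the length-128 samples.
block_mask = create_block_mask(
    document_causal_mask, B=1, H=None, Q_LEN=128, KV_LEN=128, device="cuda"
)

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
out = flex_attention(q, k, v, block_mask=block_mask)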