Integrate flex attention (#1193)
RdoubleA authored Sep 11, 2024
1 parent eb92658 commit 8451b0d
Showing 21 changed files with 817 additions and 201 deletions.
20 changes: 10 additions & 10 deletions recipes/full_finetune_distributed.py
@@ -20,7 +20,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_sft
+from torchtune.data import padded_collate_packed, padded_collate_sft
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
@@ -227,7 +227,7 @@ def setup(self, cfg: DictConfig) -> None:
self._loss_fn = config.instantiate(cfg.loss)

if self._compile:
-            training.compile_loss(self.loss_fn, verbose=self._is_rank_zero)
+            training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)

if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
# set num_output_chunks for model
@@ -491,14 +491,14 @@ def _setup_data(
dataset=ds,
batch_size=batch_size,
sampler=sampler,
-            collate_fn=(
-                partial(
-                    padded_collate_sft,
-                    padding_idx=self._tokenizer.pad_id,
-                    ignore_idx=self._loss_fn.ignore_index,
-                )
-                if not packed
-                else None
+            collate_fn=partial(
+                padded_collate_sft,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=self._loss_fn.ignore_index,
+            )
+            if not packed
+            else partial(
+                padded_collate_packed,
),
)

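This same change repeats across the recipe files below: when the dataset is packed, the dataloader's collate_fn switches from None (the default collate) to padded_collate_packed, which also produces the attention mask for the packed sequences. A minimal sketch of that wiring outside a recipe, assuming torchtune is installed; ToyPackedDataset, the pad/ignore ids, and the packed flag are made-up stand-ins for the recipe's dataset, tokenizer, and config:

from functools import partial

import torch
from torch.utils.data import DataLoader, Dataset

from torchtune.data import padded_collate_packed, padded_collate_sft


class ToyPackedDataset(Dataset):
    # Stand-in for a packed dataset: every item is a fixed-length pack of sub-sequences.
    def __len__(self):
        return 2

    def __getitem__(self, idx):
        return {
            "tokens": torch.arange(6),
            "labels": torch.arange(6),
            "input_pos": torch.tensor([0, 1, 2, 0, 1, 0]),
            "seq_lens": torch.tensor([3, 2, 1]),
        }


packed = True  # stands in for the recipe's dataset.packed config flag
collate_fn = (
    partial(padded_collate_sft, padding_idx=0, ignore_idx=-100)
    if not packed
    else partial(padded_collate_packed)
)
loader = DataLoader(ToyPackedDataset(), batch_size=2, collate_fn=collate_fn)
batch = next(iter(loader))
# batch holds "tokens", "labels", "input_pos", and "mask"; on builds where flex
# attention is available the mask may be a BlockMask built on GPU, otherwise a
# dense block-causal boolean tensor.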
18 changes: 9 additions & 9 deletions recipes/full_finetune_single_device.py
@@ -18,7 +18,7 @@
from torch.utils.data import DataLoader, DistributedSampler

from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_sft
+from torchtune.data import padded_collate_packed, padded_collate_sft
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
@@ -451,14 +451,14 @@ def _setup_data(
dataset=ds,
batch_size=batch_size,
sampler=sampler,
-            collate_fn=(
-                partial(
-                    padded_collate_sft,
-                    padding_idx=self._tokenizer.pad_id,
-                    ignore_idx=self._loss_fn.ignore_index,
-                )
-                if not packed
-                else None
+            collate_fn=partial(
+                padded_collate_sft,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=self._loss_fn.ignore_index,
+            )
+            if not packed
+            else partial(
+                padded_collate_packed,
),
)

18 changes: 9 additions & 9 deletions recipes/lora_finetune_distributed.py
@@ -20,7 +20,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_sft
+from torchtune.data import padded_collate_packed, padded_collate_sft
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft import (
DoRALinear,
@@ -559,14 +559,14 @@ def _setup_data(
dataset=ds,
batch_size=batch_size,
sampler=sampler,
-            collate_fn=(
-                partial(
-                    padded_collate_sft,
-                    padding_idx=self._tokenizer.pad_id,
-                    ignore_idx=self._loss_fn.ignore_index,
-                )
-                if not packed
-                else None
+            collate_fn=partial(
+                padded_collate_sft,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=self._loss_fn.ignore_index,
+            )
+            if not packed
+            else partial(
+                padded_collate_packed,
),
)

18 changes: 9 additions & 9 deletions recipes/lora_finetune_single_device.py
@@ -19,7 +19,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_sft
+from torchtune.data import padded_collate_packed, padded_collate_sft
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft import (
get_adapter_params,
@@ -486,14 +486,14 @@ def _setup_data(
dataset=ds,
sampler=sampler,
batch_size=batch_size,
-            collate_fn=(
-                partial(
-                    padded_collate_sft,
-                    padding_idx=self._tokenizer.pad_id,
-                    ignore_idx=self._loss_fn.ignore_index,
-                )
-                if not packed
-                else None
+            collate_fn=partial(
+                padded_collate_sft,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=self._loss_fn.ignore_index,
+            )
+            if not packed
+            else partial(
+                padded_collate_packed,
),
)

18 changes: 9 additions & 9 deletions recipes/qat_distributed.py
@@ -21,7 +21,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_sft
+from torchtune.data import padded_collate_packed, padded_collate_sft
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
@@ -523,14 +523,14 @@ def _setup_data(
dataset=ds,
batch_size=batch_size,
sampler=sampler,
-            collate_fn=(
-                partial(
-                    padded_collate_sft,
-                    padding_idx=self._tokenizer.pad_id,
-                    ignore_idx=self._loss_fn.ignore_index,
-                )
-                if not packed
-                else None
+            collate_fn=partial(
+                padded_collate_sft,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=self._loss_fn.ignore_index,
+            )
+            if not packed
+            else partial(
+                padded_collate_packed,
),
)

8 changes: 4 additions & 4 deletions tests/torchtune/config/test_config_utils.py
@@ -131,13 +131,13 @@ def test_log_config(self, capsys):
with mock.patch(
"torchtune.config._utils.get_logger", return_value=logger
), mock.patch(
"torchtune.config._utils.dist.is_available", return_value=True
"torchtune.utils.logging.dist.is_available", return_value=True
), mock.patch(
"torchtune.config._utils.dist.is_initialized", return_value=True
"torchtune.utils.logging.dist.is_initialized", return_value=True
):
# Make sure rank 0 logs as expected
with mock.patch(
"torchtune.config._utils.dist.get_rank",
"torchtune.utils.logging.dist.get_rank",
return_value=0,
):
log_config("test", cfg)
@@ -153,7 +153,7 @@ def test_log_config(self, capsys):

# Make sure all other ranks do not log anything
with mock.patch(
"torchtune.config._utils.dist.get_rank",
"torchtune.utils.logging.dist.get_rank",
return_value=1,
):
log_config("test", cfg)
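A note on the updated patch targets: mock.patch resolves its target string attribute by attribute at patch time, so "torchtune.config._utils.dist..." only resolves while that module itself imports torch.distributed as dist. The new targets suggest the rank check used by log_config now goes through torchtune.utils.logging instead. A small, self-contained illustration of the rule using made-up stand-in modules (helper and consumer are hypothetical, not torchtune code):

import sys
import types
from unittest import mock

# consumer mirrors a module that does `import torch.distributed as dist`
# and calls dist.get_rank() internally.
helper = types.ModuleType("helper")
helper.get_rank = lambda: 7
consumer = types.ModuleType("consumer")
consumer.dist = helper
consumer.current_rank = lambda: consumer.dist.get_rank()
sys.modules.update({"helper": helper, "consumer": consumer})

# The patch target names the module that actually holds the `dist` attribute;
# a dotted path through a module without that attribute fails to resolve.
with mock.patch("consumer.dist.get_rank", return_value=0):
    assert consumer.current_rank() == 0
assert consumer.current_rank() == 7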
118 changes: 118 additions & 0 deletions tests/torchtune/data/test_collate.py
@@ -6,14 +6,19 @@

# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

from unittest import mock

import pytest
import torch
from tests.test_utils import gpu_test
from torchtune.data import (
left_pad_sequence,
padded_collate,
padded_collate_dpo,
padded_collate_packed,
padded_collate_sft,
)
from torchtune.modules.attention_utils import _SUPPORTS_FLEX_ATTENTION


class TestPaddedCollateSFT:
@@ -47,6 +52,119 @@ def test_batch_pad_sequence(self):
padded_label, torch.tensor([10, ignore_idx, ignore_idx])
)

@mock.patch("torchtune.modules.attention_utils._SUPPORTS_FLEX_ATTENTION", False)
def test_padded_collate_packed_sdpa(self):
token_pairs = [
{
"tokens": torch.tensor([1, 2, 3, 4, 5, 6]),
"labels": torch.tensor([7, 8, 9, 10, 11, 12]),
"input_pos": torch.tensor([0, 1, 2, 0, 1, 0]),
"seq_lens": torch.tensor([3, 2, 1]),
},
{
"tokens": torch.tensor([13, 14, 15, 16, 17, 18]),
"labels": torch.tensor([19, 20, 21, 22, 23, 24]),
"input_pos": torch.tensor([0, 1, 0, 1, 0, 1]),
"seq_lens": torch.tensor([2, 2, 2]),
},
]
collated = padded_collate_packed(
batch=token_pairs,
)
torch.testing.assert_close(
collated["tokens"],
torch.tensor([[1, 2, 3, 4, 5, 6], [13, 14, 15, 16, 17, 18]]),
)
torch.testing.assert_close(
collated["labels"],
torch.tensor([[7, 8, 9, 10, 11, 12], [19, 20, 21, 22, 23, 24]]),
)
torch.testing.assert_close(
collated["input_pos"],
torch.tensor([[0, 1, 2, 0, 1, 0], [0, 1, 0, 1, 0, 1]]),
)
torch.testing.assert_close(
collated["mask"],
torch.tensor(
[
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 1],
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 1, 1],
],
],
dtype=torch.bool,
),
)

@pytest.mark.skipif(
not _SUPPORTS_FLEX_ATTENTION,
reason="Please install a nightly build of torch to run this test.",
)
@gpu_test(gpu_count=1)
def test_padded_collate_packed_flex(self):
# create_block_mask requires that seq_len be divisible by 128, the default block size.
# see https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py#L636
batch = [
{
"tokens": torch.arange(128, dtype=torch.long),
"labels": torch.arange(128, dtype=torch.long),
"input_pos": torch.arange(128, dtype=torch.long),
"seq_lens": torch.ones(64, dtype=torch.long) * 2,
},
{
"tokens": torch.arange(128, 256, dtype=torch.long),
"labels": torch.arange(128, 256, dtype=torch.long),
"input_pos": torch.arange(128, 256, dtype=torch.long),
"seq_lens": torch.ones(32, dtype=torch.long) * 4,
},
]
collated = padded_collate_packed(
batch=batch,
)
torch.testing.assert_close(
collated["tokens"],
torch.stack(
[
torch.arange(128, dtype=torch.long),
torch.arange(128, 256, dtype=torch.long),
]
),
)
torch.testing.assert_close(
collated["labels"],
torch.stack(
[
torch.arange(128, dtype=torch.long),
torch.arange(128, 256, dtype=torch.long),
]
),
)
torch.testing.assert_close(
collated["input_pos"],
torch.stack(
[
torch.arange(128, dtype=torch.long),
torch.arange(128, 256, dtype=torch.long),
]
),
)
torch.testing.assert_close(
collated["mask"].to_dense(),
torch.tensor([[[[1]]], [[[1]]]], dtype=torch.int32, device="cuda"),
)


class TestLeftPadSequence:
def test_left_pad_sequence(self):
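The expected mask in test_padded_collate_packed_sdpa above is block-causal: each packed sub-sequence attends causally within itself and never across pack boundaries. A standalone sketch of that semantics derived only from seq_lens; it illustrates what the collated mask encodes and is not claimed to be the collator's internal implementation:

import torch


def block_causal_mask(seq_lens: torch.Tensor) -> torch.Tensor:
    # One lower-triangular (causal) block per packed sub-sequence, False elsewhere.
    blocks = [
        torch.tril(torch.ones(n, n, dtype=torch.int64)) for n in seq_lens.tolist()
    ]
    return torch.block_diag(*blocks).bool()


# Matches the first sample's expected mask in the test above.
print(block_causal_mask(torch.tensor([3, 2, 1])).int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 0, 0, 1, 0, 0],
#         [0, 0, 0, 1, 1, 0],
#         [0, 0, 0, 0, 0, 1]])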

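test_padded_collate_packed_flex covers the flex-attention path, where the same packing information is expressed as a torch.nn.attention.flex_attention.BlockMask rather than a dense boolean tensor (hence the .to_dense() comparison). A hedged sketch of how a document-causal BlockMask is typically built with create_block_mask; it mirrors the first sample of that test but is illustrative only (the exact mask_mod used by padded_collate_packed may differ), and like the test it assumes a recent PyTorch build with flex attention and a CUDA device:

import torch
from torch.nn.attention.flex_attention import create_block_mask

# Mirror the first packed sample in the test: one sequence of length 128
# (a multiple of the default 128 block size) split into 64 documents of length 2.
seq_lens = torch.full((64,), 2, dtype=torch.long)
document_ids = (
    torch.repeat_interleave(torch.arange(seq_lens.numel()), seq_lens)
    .unsqueeze(0)
    .to("cuda")
)


def document_causal(b, h, q_idx, kv_idx):
    # Causal within a document, no attention across packed documents.
    return (q_idx >= kv_idx) & (document_ids[b, q_idx] == document_ids[b, kv_idx])


block_mask = create_block_mask(
    document_causal, B=1, H=None, Q_LEN=128, KV_LEN=128, device="cuda"
)
# The BlockMask can then be passed to
# torch.nn.attention.flex_attention.flex_attention(q, k, v, block_mask=block_mask).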