Add IREE numerics test for Llama 3.1 8B FP16 TP8 #394

Open · wants to merge 1 commit into base: main
13 changes: 5 additions & 8 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -14,15 +14,12 @@
from sharktank.layers import *
from sharktank.types import *

# TODO: Should be using a base class with the protocol supported.
from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
from ..models.llama.sharding import shard_theta
from ..models.mixtral.mixtral import *
from ..models.grok.grok import *
from .. import ops


def main():
def main(raw_args: list[str] | None = None):
from ..utils import cli

parser = cli.create_parser()
@@ -60,7 +57,7 @@ def main():
choices=["decomposed", "torch"],
)

args = cli.parse(parser)
args = cli.parse(parser, args=raw_args)
dataset_type = cli.get_input_data_files(args)
dataset_type = "irpa" if "irpa" in dataset_type else "gguf"
dataset = cli.get_input_dataset(args)
@@ -110,7 +107,7 @@ def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):

fxb = FxProgramsBuilder(model)

def setup_cache(model, shard_count):
def setup_cache(model):
if model.config.kv_cache_type == "paged":
cache_state = model.cache.allocate(
page_count=hp.context_length // llama_config.block_seq_stride
@@ -161,7 +158,7 @@ def generate_batch_prefill(bs: int):
sl_dim = llama_config.block_seq_stride * block_dim

cache, cache_shard_dim, cache_dynamic_shapes, arg_affinities = setup_cache(
model, llama_config.tensor_parallelism_size
model
)

# We need to offset the indices for the cache
@@ -234,7 +231,7 @@ def generate_batch_decode(bs: int):
cache_shard_dim,
cache_dynamic_shapes,
arg_affinities,
) = setup_cache(model, llama_config.tensor_parallelism_size)
) = setup_cache(model)

# We need to offset the indices for the cache
arg_affinities = {key + 4: arg_affinities[key] for key in arg_affinities}
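Because main now accepts raw_args and forwards it to cli.parse, the exporter can be driven in-process from a test instead of through a subprocess. A minimal sketch of that pattern, assuming pytest's tmp_path fixture; the flag names below are illustrative placeholders, not necessarily the exporter's real CLI options:

from sharktank.examples import export_paged_llm_v1

def test_export_in_process(tmp_path):
    # Drive the exporter in-process; raw_args replaces sys.argv parsing.
    # Both flags are hypothetical stand-ins for the real options.
    export_paged_llm_v1.main(
        raw_args=[
            "--dataset=/path/to/llama3.1_8b_fp16.irpa",  # hypothetical flag
            f"--output-mlir={tmp_path / 'model.mlir'}",  # hypothetical flag
        ]
    )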
5 changes: 4 additions & 1 deletion sharktank/sharktank/layers/kv_cache.py
@@ -300,7 +300,7 @@ def shard_state(
"""Shard an unsharded state.
We can't just split the slab on the sub page dims.
First it needs to be reinterpreted into the actual shape.
The split the head dimension, then flatten each shard.
Then split the head dimension, then flatten each shard.
This is a work-around for the lack of block-cyclic sharded tensor type."""
if self.shard_count == 1:
return state
@@ -324,6 +324,9 @@ def shard_state(
flat_sharded_page_table = SplitPrimitiveTensor(ts=shards, shard_dim=1)
return [flat_sharded_page_table]

def unshard_state(self, state: list[SplitPrimitiveTensor]) -> list[torch.Tensor]:
return [ops.unshard(self.unflatten_page_table(state)).flatten(start_dim=1)]

@property
def pad_sequence_stride(self) -> int:
return self.block_seq_stride
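The new unshard_state is the inverse of shard_state: it unflattens the sharded page table, gathers the shards, and flattens the result back into a single slab. A hedged round-trip sketch, assuming sharded_cache is a paged KV cache with shard_count > 1 and unsharded_state is a matching unsharded allocation (both names are assumptions for illustration):

import torch

def check_shard_roundtrip(sharded_cache, unsharded_state: list[torch.Tensor]):
    # shard_state reinterprets the flat slab, splits the head dimension,
    # and flattens each shard; unshard_state should undo all three steps.
    sharded = sharded_cache.shard_state([t.clone() for t in unsharded_state])
    recovered = sharded_cache.unshard_state(sharded)
    torch.testing.assert_close(recovered[0], unsharded_state[0])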
23 changes: 0 additions & 23 deletions sharktank/sharktank/models/llama/llama.py
@@ -186,29 +186,6 @@ def decode(
self._assert_device(start_positions)
self._assert_device(*cache_state, dtype=self.activation_dtype)

if self.config.tensor_parallelism_size > 1:
if not isinstance(tokens, ReplicatedTensor):
tokens = ops.replicate(
tokens, count=self.config.tensor_parallelism_size
)
if not isinstance(attention_mask, ReplicatedTensor):
attention_mask = ops.replicate(
attention_mask, count=self.config.tensor_parallelism_size
)
if not isinstance(start_positions, ReplicatedTensor):
start_positions = ops.replicate(
start_positions, count=self.config.tensor_parallelism_size
)
if not isinstance(seq_block_ids, ReplicatedTensor):
seq_block_ids = ops.replicate(
seq_block_ids, count=self.config.tensor_parallelism_size
)
# If the user provided unsharded arguments they probably want
# an unsharded result as well.
unshard_result = True
else:
unshard_result = False

bs, _ = tokens.shape
# Precompute a position based mask for computing rope embeddings
# as it is the same for all blocks.
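With the automatic replication removed from decode(), a tensor-parallel caller presumably replicates unsharded inputs explicitly before the call. A minimal sketch under that assumption, reusing the ops.replicate pattern from the removed block; the decode keyword names are assumed rather than taken from this diff:

from sharktank import ops

def replicated_decode(model, tokens, attention_mask, start_positions,
                      seq_block_ids, cache_state):
    # Replicate unsharded inputs across the tensor-parallel ranks up front;
    # the model no longer does this implicitly inside decode().
    tp = model.config.tensor_parallelism_size
    if tp > 1:
        tokens = ops.replicate(tokens, count=tp)
        attention_mask = ops.replicate(attention_mask, count=tp)
        start_positions = ops.replicate(start_positions, count=tp)
        seq_block_ids = ops.replicate(seq_block_ids, count=tp)
    return model.decode(
        tokens,
        attention_mask=attention_mask,
        start_positions=start_positions,
        seq_block_ids=seq_block_ids,
        cache_state=cache_state,
    )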
33 changes: 33 additions & 0 deletions sharktank/sharktank/utils/testing.py
@@ -14,9 +14,13 @@
from typing import Any, Callable
from operator import eq
from collections.abc import Iterable
import pytest
from sharktank.utils.tokenizer import InferenceTokenizer

from ..types import *

longrun = pytest.mark.skipif("not config.getoption('longrun')")

# Range of torch.rand() is [0,1)
# Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values
def make_rand_torch(shape, dtype=torch.float32):
@@ -31,6 +35,16 @@ def tearDown(self):
shutil.rmtree(self._temp_dir, ignore_errors=True)


@pytest.mark.usefixtures("path_prefix")
class PathPrefixTestBase(TempDirTestBase):
"""Creates a temporary directory and uses it if a path prefix is not given."""

def setUp(self):
super().setUp()
if self.path_prefix is None:
self.path_prefix = f"{self._temp_dir}/"


class MainRunnerTestBase(TempDirTestBase):
"""Performs an in-process test of a `main(args)` func."""

@@ -54,6 +68,25 @@ def assertFileWritten(self, p: Path):
self.assertGreater(p.stat().st_size, 0, msg=f"Expected file {p} had zero size")


class ModuloTokenizer(InferenceTokenizer):
"""A tokenizer used for testing where we take a modulo of each character.
Guarantees that we are producing tokens of up to the max token ID."""

def __init__(self, vocabulary_size: int):
self.vocabulary_size = vocabulary_size

def _encode(self, texts: list[str], add_start_token: bool) -> list[list[int]]:
return [
[ord(character) % self.vocabulary_size for character in text]
for text in texts
]

def _decode(self, tokens: list[list[int]]) -> list[str]:
return [
"".join([chr(token) for token in prompt_tokens]) for prompt_tokens in tokens
]


@contextlib.contextmanager
def temporary_directory(identifier: str):
"""Returns a context manager TemporaryDirectory suitable for testing.
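A hedged usage sketch for the new helpers in testing.py. It assumes a conftest that registers a --longrun pytest option (not part of this diff), and it exercises ModuloTokenizer through the _encode/_decode hooks shown above rather than a public wrapper, whose name this diff does not show:

from sharktank.utils.testing import ModuloTokenizer, longrun

@longrun
def test_modulo_tokenizer_roundtrip():
    # ASCII characters stay below 128, so the modulo is the identity and the
    # round trip through _encode/_decode reproduces the original text.
    tokenizer = ModuloTokenizer(vocabulary_size=128)
    token_ids = tokenizer._encode(["hello world"], add_start_token=False)
    assert token_ids == [[ord(c) % 128 for c in "hello world"]]
    assert tokenizer._decode(token_ids) == ["hello world"]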
2 changes: 1 addition & 1 deletion sharktank/sharktank/utils/tokenizer.py
@@ -75,7 +75,7 @@ def pad_tokens(
return token_ids, lengths

@abstractmethod
def _encode(self, texts: list[str]) -> list[list[int]]:
def _encode(self, texts: list[str], add_start_token: bool) -> list[list[int]]:
...

@abstractmethod