Add prefill 8B f16 torch sdpa test, update tests with compile flags and tp flags, with nightly iree #456

Open · wants to merge 10 commits into main
13 changes: 7 additions & 6 deletions .github/workflows/ci-llama-large-tests.yaml
@@ -64,17 +64,18 @@ jobs:
pip install --no-compile -r pytorch-cpu-requirements.txt
pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/

# Install latest iree-tubrine.
# Install latest iree-turbine.
pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
-e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"

# Test with pinned nightly releases, not what iree-turbine uses.
pip install -f https://iree.dev/pip-release-links.html --upgrade \
iree-base-compiler==3.0.0rc20241115 \
iree-base-runtime==3.0.0rc20241115

# Test with nightly releases, not what iree-turbine uses.
pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
iree-base-compiler \
iree-base-runtime

- name: Run llama tests
run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-all-llama --iree-hip-target=gfx942 --html=out/index.html
run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-nightly-llama-tests --iree-hip-target=gfx942 --html=out/index.html

- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
15 changes: 8 additions & 7 deletions .github/workflows/ci-llama-quick-tests.yaml
@@ -65,17 +65,18 @@ jobs:
pip install --no-compile -r pytorch-cpu-requirements.txt
pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/

# Install latest iree-tubrine.
# Install latest iree-turbine.
pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
-e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"

# Test with pinned nightly releases, not what iree-turbine uses.
pip install -f https://iree.dev/pip-release-links.html --upgrade \
iree-base-compiler==3.0.0rc20241115 \
iree-base-runtime==3.0.0rc20241115

- name: Run llama 8b tests
run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --iree-hip-target=gfx942 --run-8b-llama
# Test with nightly releases, not what iree-turbine uses.
pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
iree-base-compiler \
iree-base-runtime

- name: Run llama 8b f16 decomposed test
run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --iree-hip-target=gfx942 --run-quick-llama-test

- name: Upload llama executable files
uses: actions/upload-artifact@v4
10 changes: 5 additions & 5 deletions sharktank/conftest.py
@@ -73,17 +73,17 @@ def pytest_addoption(parser):
)

parser.addoption(
"--run-8b-llama",
"--run-quick-llama-test",
action="store_true",
dest="run-8b-llama",
dest="run-quick-llama-test",
default=False,
help="Enable llama 8b benchmarking tests",
help="Enable llama 8b f16 decomposed benchmarking test",
)

parser.addoption(
"--run-all-llama",
"--run-nightly-llama-tests",
action="store_true",
dest="run-all-llama",
dest="run-nightly-llama-tests",
default=False,
help="Enable all llama benchmarking tests",
)
4 changes: 3 additions & 1 deletion sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -54,6 +54,7 @@ def main():
help="Enables strictness during export",
action="store_true",
)

cli.add_quantization_options(parser)
cli.add_model_options(parser)
args = cli.parse(parser)
@@ -312,7 +313,8 @@ def _(
bsizes = []
for bs in args.bs:
generate_batch_prefill(bs)
generate_batch_decode(bs)
if not args.skip_decode:
generate_batch_decode(bs)
bsizes.append(bs)
config = generate_params_json(hp, bsizes, bsizes)
print("GENERATED!")
5 changes: 5 additions & 0 deletions sharktank/sharktank/utils/cli.py
@@ -69,6 +69,11 @@ def add_model_options(parser: argparse.ArgumentParser):
default="decomposed",
choices=["decomposed", "torch"],
)
parser.add_argument(
"--skip-decode",
help="Enables prefill only, skips decode",
action="store_true",
)


def add_quantization_options(parser: argparse.ArgumentParser):
14 changes: 10 additions & 4 deletions sharktank/sharktank/utils/export_artifacts.py
@@ -123,15 +123,15 @@ def wrapper(*args, **kwargs):
def shard_irpa_file(
self,
*,
gguf_file: str,
irpa_file: str,
output_irpa: str,
):
shard_irpa_args = [
"python3",
"-m",
"sharktank.examples.sharding.shard_llm_dataset",
"--gguf-file",
gguf_file,
"--irpa-file",
irpa_file,
"--output-irpa-file",
output_irpa,
"--tensor-parallelism-size",
@@ -160,6 +160,7 @@ def export_to_mlir(
*,
mlir_path: str,
json_path: str,
skip_decode: Optional[bool] = None,
):
export_args = [
"python3",
@@ -170,6 +171,8 @@
f"--output-config={json_path}",
f"--bs={str(self.batch_size)}",
]
if skip_decode:
export_args.append("--skip-decode")
if self.attention_kernel in ["decomposed", "torch"]:
export_args.append("--attention-kernel")
export_args.append(self.attention_kernel)
@@ -195,6 +198,7 @@ def compile_to_vmfb(
vmfb_path,
cwd,
hal_dump_path: Optional[Path] = None,
args: Optional[List[str]] = None,
):
# TODO: Control flag to enable multiple backends
compile_args = [
@@ -214,7 +218,9 @@
compile_args += [
f"--iree-hal-dump-executable-files-to={hal_dump_path}/files"
]

# Append optional arguments if provided
if args:
compile_args += args
cmd = subprocess.list2cmdline(compile_args)

logging.getLogger().info(f"Launching compile command:\n" f"cd {cwd} && {cmd}")
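
Taken together, the new `skip_decode` and `args` hooks let a test export a prefill-only module and compile it with extra `iree-compile` flags. A rough sketch under stated assumptions — `artifacts` stands in for an export-artifacts helper instance like the ones the benchmark tests below construct, and the paths are placeholders:

```python
# Sketch only: `artifacts` is assumed to expose the export_to_mlir and
# compile_to_vmfb methods shown above; all paths are placeholders.
from pathlib import Path


def export_and_compile_prefill_only(artifacts, repo_root: Path) -> Path:
    """Export prefill-only MLIR, then compile it with extra iree-compile flags."""
    mlir = Path("/tmp/llama8b_prefill.mlir")
    vmfb = Path("/tmp/llama8b_prefill.vmfb")
    artifacts.export_to_mlir(
        mlir_path=str(mlir),
        json_path="/tmp/llama8b_prefill.json",
        skip_decode=True,  # forwards --skip-decode to the exporter
    )
    artifacts.compile_to_vmfb(
        mlir_path=str(mlir),
        vmfb_path=str(vmfb),
        cwd=repo_root,
        args=[  # appended to the base iree-compile invocation
            "--iree-opt-data-tiling=false",
            "--iree-global-opt-propagate-transposes=true",
        ],
    )
    return vmfb
```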
14 changes: 14 additions & 0 deletions sharktank/tests/models/llama/README.md
@@ -0,0 +1,14 @@
# How to run Llama 3.1 Benchmarking Tests
To run the Llama 3.1 8B F16 Decomposed test:
```
pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s \
  --run-quick-llama-test --iree-hip-target=gfx942
```

To filter by test, use the `-k` option. For example, to run only
the Llama 3.1 70B F16 Decomposed test:
```
pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s \
--run-nightly-llama-tests --iree-hip-target=gfx942 \
-k 'testBenchmark70B_f16_TP8_Decomposed'
```
72 changes: 64 additions & 8 deletions sharktank/tests/models/llama/benchmark_amdgpu_test.py
@@ -21,9 +21,9 @@
)

is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'")
skipif_run_8b_llama = pytest.mark.skipif(
'config.getoption("run-8b-llama") and not config.getoption("run-all-llama")',
reason="Skipping largs tests when --run-8b is set.",
skipif_run_quick_llama_test = pytest.mark.skipif(
'config.getoption("run-quick-llama-test") and not config.getoption("run-nightly-llama-tests")',
reason="Skipping largs tests when --run-quick-llama-test is set.",
)


@@ -49,6 +49,13 @@ def setUpClass(cls):

def setUp(self):
self.hip_device_id = os.getenv("HIP_DEVICE_ID", default="0")
self.compile_args = [
"--iree-dispatch-creation-enable-aggressive-fusion=true",
"--iree-global-opt-propagate-transposes=true",
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-opt-data-tiling=false",
"--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))'",
]


@is_mi300x
@@ -57,7 +64,6 @@ def setUp(self):
super().setUp()
# TODO: add numpy files to Azure and download from it
self.artifacts_dir = Path("/data/llama-3.1/weights/8b")
self.gguf_path = self.artifacts_dir / "fp16/llama3.1_8b_fp16.gguf"
self.irpa_path = self.artifacts_dir / "fp16/llama3.1_8b_fp16.irpa"
self.irpa_path_fp8 = self.artifacts_dir / "f8/llama8b_fp8.irpa"
self.tensor_parallelism_size = 1
@@ -155,6 +161,7 @@ def testBenchmark8B_f16_Decomposed(self):
vmfb_path=output_vmfb,
hal_dump_path=output_file_name,
cwd=self.repo_root,
args=self.compile_args,
)
# benchmark prefill
self.llama8b_f16_decomposed_artifacts.iree_benchmark_vmfb(
@@ -173,6 +180,41 @@ def testBenchmark8B_f16_Decomposed(self):
cwd=self.repo_root,
)

@skipif_run_quick_llama_test
def testBenchmark8B_f16_Non_Decomposed_Prefill(self):
output_file_name = self.dir_path_8b / "f16_torch_prefill"
output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file(
suffix=".mlir", prefix=output_file_name
)
output_json = self.llama8b_f16_torch_sdpa_artifacts.create_file(
suffix=".json", prefix=output_file_name
)
output_vmfb = self.llama8b_f16_torch_sdpa_artifacts.create_file(
suffix=".vmfb", prefix=output_file_name
)
self.llama8b_f16_torch_sdpa_artifacts.attention_kernel = "torch"
export_return_code = self.llama8b_f16_torch_sdpa_artifacts.export_to_mlir(
mlir_path=output_mlir,
json_path=output_json,
skip_decode=True,
)
self.llama8b_f16_torch_sdpa_artifacts.compile_to_vmfb(
mlir_path=str(output_mlir),
vmfb_path=output_vmfb,
hal_dump_path=output_file_name,
cwd=self.repo_root,
args=self.compile_args,
)
# benchmark prefill
self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
hip_device_id=self.hip_device_id,
vmfb_name=output_vmfb,
irpa_path=self.irpa_path,
args=self.iree_run_prefill_args,
cwd=self.repo_root,
)

@skipif_run_quick_llama_test
@pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
def testBenchmark8B_f16_Non_Decomposed(self):
output_file_name = self.dir_path_8b / "f16_torch"
@@ -195,6 +237,7 @@ def testBenchmark8B_f16_Non_Decomposed(self):
vmfb_path=output_vmfb,
hal_dump_path=output_file_name,
cwd=self.repo_root,
args=self.compile_args,
)
# benchmark prefill
self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
@@ -236,6 +279,7 @@ def testBenchmark8B_fp8_Decomposed(self):
vmfb_path=output_vmfb,
hal_dump_path=output_file_name,
cwd=self.repo_root,
args=self.compile_args,
)
# benchmark prefill
self.llama8b_fp8_decomposed_artifacts.iree_benchmark_vmfb(
@@ -277,6 +321,7 @@ def testBenchmark8B_fp8_Non_Decomposed(self):
vmfb_path=output_vmfb,
hal_dump_path=output_file_name,
cwd=self.repo_root,
args=self.compile_args,
)
# benchmark prefill
self.llama8b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb(
@@ -297,13 +342,12 @@ def testBenchmark8B_fp8_Non_Decomposed(self):


@is_mi300x
@skipif_run_8b_llama
@skipif_run_quick_llama_test
class BenchmarkLlama3_1_70B(BaseBenchmarkTest):
def setUp(self):
super().setUp()
# TODO: add numpy files to Azure and download from it
self.artifacts_dir = Path("/data/llama-3.1/weights/70b")
self.gguf_path = self.artifacts_dir / "fp16/llama3.1_70b_f16.gguf"
self.irpa_path = self.artifacts_dir / "fp16/llama3.1_70b_f16.irpa"
self.irpa_path_fp8 = self.artifacts_dir / "f8/llama70b_fp8.irpa"
self.tensor_parallelism_size = 8
@@ -380,6 +424,11 @@ def setUp(self):
f"--input=@{self.decode_args_fp8}/cache_state_f16.npy",
"--benchmark_repetitions=3",
]
self.compile_args += [
"--iree-hal-force-indirect-command-buffers=true",
"--iree-stream-resource-memory-model=discrete",
"--iree-hip-legacy-sync=false",
]

@pytest.mark.xfail(
reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException
@@ -410,6 +459,7 @@ def testBenchmark70B_f16_TP8_Decomposed(self):
vmfb_path=output_vmfb,
hal_dump_path=output_file_name,
cwd=self.repo_root,
args=self.compile_args,
)
# benchmark prefill
self.llama70b_f16_decomposed_artifacts.iree_benchmark_vmfb(
@@ -455,6 +505,7 @@ def testBenchmark70B_f16_TP8_Non_Decomposed(self):
vmfb_path=output_vmfb,
hal_dump_path=output_file_name,
cwd=self.repo_root,
args=self.compile_args,
)
# benchmark prefill
self.llama70b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
@@ -502,6 +553,7 @@ def testBenchmark70B_fp8_TP8_Decomposed(self):
vmfb_path=output_vmfb,
hal_dump_path=output_file_name,
cwd=self.repo_root,
args=self.compile_args,
)
# benchmark prefill
self.llama70b_fp8_decomposed_artifacts.iree_benchmark_vmfb(
@@ -549,6 +601,7 @@ def testBenchmark70B_fp8_TP8_Non_Decomposed(self):
vmfb_path=output_vmfb,
hal_dump_path=output_file_name,
cwd=self.repo_root,
args=self.compile_args,
)
# benchmark prefill
self.llama70b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb(
@@ -569,14 +622,13 @@ def testBenchmark70B_fp8_TP8_Non_Decomposed(self):


@is_mi300x
@skipif_run_8b_llama
@skipif_run_quick_llama_test
class BenchmarkLlama3_1_405B(BaseBenchmarkTest):
def setUp(self):
super().setUp()
# TODO: add numpy files to Azure and download from it
self.artifacts_dir = Path("/data/llama-3.1/weights/405b")
self.irpa_path = self.artifacts_dir / "fp16/llama3.1_405b_fp16.irpa"
self.gguf_path = self.artifacts_dir / "fp16/llama3_405b_f16.gguf"
self.irpa_path_fp8 = self.artifacts_dir / "f8/llama405b_fp8.irpa"
self.tensor_parallelism_size = 8
self.dir_path_405b = self.dir_path / "llama-405b"
@@ -682,6 +734,7 @@ def testBenchmark405B_f16_TP8_Decomposed(self):
vmfb_path=output_vmfb,
hal_dump_path=output_file_name,
cwd=self.repo_root,
args=self.compile_args,
)
# benchmark prefill
self.llama405b_f16_decomposed_artifacts.iree_benchmark_vmfb(
@@ -727,6 +780,7 @@ def testBenchmark405B_f16_TP8_Non_Decomposed(self):
vmfb_path=output_vmfb,
hal_dump_path=output_file_name,
cwd=self.repo_root,
args=self.compile_args,
)
# benchmark prefill
self.llama405b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
@@ -774,6 +828,7 @@ def testBenchmark405B_fp8_TP8_Decomposed(self):
vmfb_path=output_vmfb,
hal_dump_path=output_file_name,
cwd=self.repo_root,
args=self.compile_args,
)
# benchmark prefill
self.llama405b_fp8_decomposed_artifacts.iree_benchmark_vmfb(
@@ -821,6 +876,7 @@ def testBenchmark405B_fp8_TP8_Non_Decomposed(self):
vmfb_path=output_vmfb,
hal_dump_path=output_file_name,
cwd=self.repo_root,
args=self.compile_args,
)
# benchmark prefill
self.llama405b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb(