From 00244e6eed5194239b271d580429109d31b78e95 Mon Sep 17 00:00:00 2001 From: Stephen Baione Date: Mon, 4 Nov 2024 17:53:37 -0600 Subject: [PATCH] Move setup funcs to `conftest.py`, Add logging throughout test, Move integration test to `build_tools/integration_tests/llm`, Rename ci file --- ...and-shortfin.yml => ci-shark-platform.yml} | 6 +- .../conftest.py} | 141 ++++++++---------- .../llm/cpu_llm_server_test.py | 85 +++++++++++ 3 files changed, 150 insertions(+), 82 deletions(-) rename .github/workflows/{ci-sharktank-and-shortfin.yml => ci-shark-platform.yml} (94%) rename build_tools/integration_tests/{cpu_llm_server_test.py => llm/conftest.py} (61%) create mode 100644 build_tools/integration_tests/llm/cpu_llm_server_test.py diff --git a/.github/workflows/ci-sharktank-and-shortfin.yml b/.github/workflows/ci-shark-platform.yml similarity index 94% rename from .github/workflows/ci-sharktank-and-shortfin.yml rename to .github/workflows/ci-shark-platform.yml index a798fdfea..d9f4a35da 100644 --- a/.github/workflows/ci-sharktank-and-shortfin.yml +++ b/.github/workflows/ci-shark-platform.yml @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -name: CI - sharktank and shortfin +name: CI - shark-platform on: workflow_dispatch: @@ -70,5 +70,5 @@ jobs: iree-runtime \ "numpy<2.0" - - name: Run shortfin LLM Server Integration Test - run: pytest -v build_tools/integration_tests/ + - name: Run LLM Integration Tests + run: pytest -v build_tools/integration_tests/llm --log-cli-level=INFO diff --git a/build_tools/integration_tests/cpu_llm_server_test.py b/build_tools/integration_tests/llm/conftest.py similarity index 61% rename from build_tools/integration_tests/cpu_llm_server_test.py rename to build_tools/integration_tests/llm/conftest.py index f0389f90c..1bc014e63 100644 --- a/build_tools/integration_tests/cpu_llm_server_test.py +++ b/build_tools/integration_tests/llm/conftest.py @@ -1,10 +1,5 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - import json +import logging import os from pathlib import Path import pytest @@ -12,30 +7,31 @@ import shutil import subprocess import time -import uuid pytest.importorskip("transformers") from transformers import AutoTokenizer -CPU_SETTINGS = { - "device_flags": [ - "-iree-hal-target-backends=llvm-cpu", - "--iree-llvmcpu-target-cpu=host", - ], - "device": "local-task", -} -IREE_HIP_TARGET = os.environ.get("IREE_HIP_TARGET", "gfx1100") -gpu_settings = { - "device_flags": [ - "-iree-hal-target-backends=rocm", - f"--iree-hip-target={IREE_HIP_TARGET}", - ], - "device": "hip", -} +logger = logging.getLogger(__name__) @pytest.fixture(scope="module") def model_test_dir(request, tmp_path_factory): + """Prepare model artifacts for starting the LLM server. + + Args: + request (FixtureRequest): The following params are accepted: + - repo_id (str): The Hugging Face repo ID. + - model_file (str): The model file to download. + - tokenizer_id (str): The tokenizer ID to download. + - settings (dict): The settings for sharktank export. + - batch_sizes (list): The batch sizes to use for the model. + tmp_path_factory (TempPathFactory): Temp dir to save artifacts to. + + Yields: + Tuple[Path, Path]: The paths to the Hugging Face home and the temp dir. 
+ """ + logger.info("Preparing model artifacts...") + repo_id = request.param["repo_id"] model_file = request.param["model_file"] tokenizer_id = request.param["tokenizer_id"] @@ -48,25 +44,43 @@ def model_test_dir(request, tmp_path_factory): try: # Download model if it doesn't exist model_path = hf_home / model_file + logger.info(f"Preparing model_path: {model_path}..") if not os.path.exists(model_path): + logger.info( + f"Downloading model {repo_id} {model_file} from Hugging Face..." + ) subprocess.run( f"huggingface-cli download --local-dir {hf_home} {repo_id} {model_file}", shell=True, check=True, ) + logger.info(f"Model downloaded to {model_path}") + else: + logger.info("Using cached model") # Set up tokenizer if it doesn't exist tokenizer_path = hf_home / "tokenizer.json" + logger.info(f"Preparing tokenizer_path: {tokenizer_path}...") if not os.path.exists(tokenizer_path): + logger.info(f"Downloading tokenizer {tokenizer_id} from Hugging Face...") tokenizer = AutoTokenizer.from_pretrained( tokenizer_id, ) tokenizer.save_pretrained(hf_home) + logger.info(f"Tokenizer saved to {tokenizer_path}") + else: + logger.info("Using cached tokenizer") - # Export model if it doesn't exist + # Export model mlir_path = tmp_dir / "model.mlir" config_path = tmp_dir / "config.json" bs_string = ",".join(map(str, batch_sizes)) + logger.info( + "Exporting model with following settings:\n" + f" MLIR Path: {mlir_path}\n" + f" Config Path: {config_path}\n" + f" Batch Sizes: {bs_string}" + ) subprocess.run( [ "python", @@ -79,9 +93,11 @@ def model_test_dir(request, tmp_path_factory): ], check=True, ) + logger.info(f"Model successfully exported to {mlir_path}") - # Compile model if it doesn't exist + # Compile model vmfb_path = tmp_dir / "model.vmfb" + logger.info(f"Compiling model to {vmfb_path}") subprocess.run( [ "iree-compile", @@ -92,6 +108,7 @@ def model_test_dir(request, tmp_path_factory): + settings["device_flags"], check=True, ) + logger.info(f"Model successfully compiled to {vmfb_path}") # Write config if it doesn't exist edited_config_path = tmp_dir / "edited_config.json" @@ -106,8 +123,11 @@ def model_test_dir(request, tmp_path_factory): "transformer_block_count": 26, "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, } + logger.info(f"Saving edited config to: {edited_config_path}\n") + logger.info(f"Config: {json.dumps(config, indent=2)}") with open(edited_config_path, "w") as f: json.dump(config, f) + logger.info("Model artifacts setup successfully") yield hf_home, tmp_dir finally: shutil.rmtree(tmp_dir) @@ -117,6 +137,8 @@ def model_test_dir(request, tmp_path_factory): def available_port(port=8000, max_port=8100): import socket + logger.info(f"Finding available port in range {port}-{max_port}...") + starting_port = port while port < max_port: @@ -124,6 +146,7 @@ def available_port(port=8000, max_port=8100): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("localhost", port)) s.close() + logger.info(f"Found available port: {port}") return port except socket.error: port += 1 @@ -132,10 +155,12 @@ def available_port(port=8000, max_port=8100): def wait_for_server(url, timeout=10): + logger.info(f"Waiting for server to start at {url}...") start = time.time() while time.time() - start < timeout: try: requests.get(f"{url}/health") + logger.info("Server successfully started") return except requests.exceptions.ConnectionError: time.sleep(1) @@ -144,6 +169,19 @@ def wait_for_server(url, timeout=10): @pytest.fixture(scope="module") def llm_server(request, 
model_test_dir, available_port): + """Start the LLM server. + + Args: + request (FixtureRequest): The following params are accepted: + - model_file (str): The model file to download. + - settings (dict): The settings for starting the server. + model_test_dir (Tuple[Path, Path]): The paths to the Hugging Face home and the temp dir. + available_port (int): The available port to start the server on. + + Yields: + subprocess.Popen: The server process that was started. + """ + logger.info("Starting LLM server...") # Start the server hf_home, tmp_dir = model_test_dir model_file = request.param["model_file"] @@ -166,58 +204,3 @@ def llm_server(request, model_test_dir, available_port): # Teardown: kill the server server_process.terminate() server_process.wait() - - -def do_generate(prompt, port): - headers = {"Content-Type": "application/json"} - # Create a GenerateReqInput-like structure - data = { - "text": prompt, - "sampling_params": {"max_tokens": 50, "temperature": 0.7}, - "rid": uuid.uuid4().hex, - "return_logprob": False, - "logprob_start_len": -1, - "top_logprobs_num": 0, - "return_text_in_logprobs": False, - "stream": False, - } - print("Prompt text:") - print(data["text"]) - BASE_URL = f"http://localhost:{port}" - response = requests.post(f"{BASE_URL}/generate", headers=headers, json=data) - print(f"Generate endpoint status code: {response.status_code}") - if response.status_code == 200: - print("Generated text:") - data = response.text - assert data.startswith("data: ") - data = data[6:] - assert data.endswith("\n\n") - data = data[:-2] - return data - else: - response.raise_for_status() - - -@pytest.mark.parametrize( - "model_test_dir,llm_server", - [ - ( - { - "repo_id": "SlyEcho/open_llama_3b_v2_gguf", - "model_file": "open-llama-3b-v2-f16.gguf", - "tokenizer_id": "openlm-research/open_llama_3b_v2", - "settings": CPU_SETTINGS, - "batch_sizes": [1, 4], - }, - {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS}, - ) - ], - indirect=True, -) -def test_llm_server(llm_server, available_port): - # Here you would typically make requests to your server - # and assert on the responses - assert llm_server.poll() is None - output = do_generate("1 2 3 4 5 ", available_port) - print(output) - assert output.startswith("6 7 8") diff --git a/build_tools/integration_tests/llm/cpu_llm_server_test.py b/build_tools/integration_tests/llm/cpu_llm_server_test.py new file mode 100644 index 000000000..1b27e12da --- /dev/null +++ b/build_tools/integration_tests/llm/cpu_llm_server_test.py @@ -0,0 +1,85 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import os +import pytest +import requests +import uuid + +logger = logging.getLogger(__name__) + +CPU_SETTINGS = { + "device_flags": [ + "-iree-hal-target-backends=llvm-cpu", + "--iree-llvmcpu-target-cpu=host", + ], + "device": "local-task", +} +IREE_HIP_TARGET = os.environ.get("IREE_HIP_TARGET", "gfx1100") +gpu_settings = { + "device_flags": [ + "-iree-hal-target-backends=rocm", + f"--iree-hip-target={IREE_HIP_TARGET}", + ], + "device": "hip", +} + + +def do_generate(prompt, port): + logger.info("Generating request...") + headers = {"Content-Type": "application/json"} + # Create a GenerateReqInput-like structure + data = { + "text": prompt, + "sampling_params": {"max_tokens": 50, "temperature": 0.7}, + "rid": uuid.uuid4().hex, + "return_logprob": False, + "logprob_start_len": -1, + "top_logprobs_num": 0, + "return_text_in_logprobs": False, + "stream": False, + } + logger.info("Prompt text:") + logger.info(data["text"]) + BASE_URL = f"http://localhost:{port}" + response = requests.post(f"{BASE_URL}/generate", headers=headers, json=data) + logger.info(f"Generate endpoint status code: {response.status_code}") + if response.status_code == 200: + logger.info("Generated text:") + data = response.text + assert data.startswith("data: ") + data = data[6:] + assert data.endswith("\n\n") + data = data[:-2] + return data + else: + response.raise_for_status() + + +@pytest.mark.parametrize( + "model_test_dir,llm_server", + [ + ( + { + "repo_id": "SlyEcho/open_llama_3b_v2_gguf", + "model_file": "open-llama-3b-v2-f16.gguf", + "tokenizer_id": "openlm-research/open_llama_3b_v2", + "settings": CPU_SETTINGS, + "batch_sizes": [1, 4], + }, + {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS}, + ) + ], + indirect=True, +) +def test_llm_server(llm_server, available_port): + # Here you would typically make requests to your server + # and assert on the responses + assert llm_server.poll() is None + output = do_generate("1 2 3 4 5 ", available_port) + logger.info(output) + assert output.startswith("6 7 8")
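
Note on the fixture wiring used by `test_llm_server` above: its parameter dictionaries are routed to the module-scoped `model_test_dir` and `llm_server` fixtures in conftest.py via pytest's indirect parametrization, so each fixture reads its configuration from `request.param` rather than the test receiving the raw dict. A minimal, self-contained sketch of that mechanism (the `server_config` fixture and `test_server_config` names below are illustrative only, not part of this patch):

    import pytest


    @pytest.fixture(scope="module")
    def server_config(request):
        # With indirect=True, the dict from the parametrize list arrives here
        # as request.param instead of being passed to the test function.
        return {"device": request.param["device"]}


    @pytest.mark.parametrize(
        "server_config",
        [{"device": "local-task"}],
        indirect=True,
    )
    def test_server_config(server_config):
        assert server_config["device"] == "local-task"

Locally, the suite can be run the same way the workflow runs it:

    pytest -v build_tools/integration_tests/llm --log-cli-level=INFO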