diff --git a/build_tools/integration_tests/llm/conftest.py b/build_tools/integration_tests/llm/conftest.py
index 1103065bc..3b69b8a41 100644
--- a/build_tools/integration_tests/llm/conftest.py
+++ b/build_tools/integration_tests/llm/conftest.py
@@ -117,22 +117,6 @@ def model_test_dir(request, tmp_path_factory):
         logger.info(f"Model successfully compiled to {vmfb_path}")
 
         # Write config if it doesn't exist
-        edited_config_path = tmp_dir / "edited_config.json"
-        config = {
-            "module_name": "module",
-            "module_abi_version": 1,
-            "max_seq_len": 2048,
-            "attn_head_count": 32,
-            "attn_head_dim": 100,
-            "prefill_batch_sizes": batch_sizes,
-            "decode_batch_sizes": batch_sizes,
-            "transformer_block_count": 26,
-            "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
-        }
-        logger.info(f"Saving edited config to: {edited_config_path}\n")
-        logger.info(f"Config: {json.dumps(config, indent=2)}")
-        with open(edited_config_path, "w") as f:
-            json.dump(config, f)
         logger.info("Model artifacts setup successfully")
         yield hf_home, tmp_dir
     finally:
@@ -198,7 +182,8 @@ def llm_server(request, model_test_dir, available_port):
         "-m",
         "shortfin_apps.llm.server",
         f"--tokenizer_json={hf_home / 'tokenizer.json'}",
-        f"--model_config={tmp_dir / 'edited_config.json'}",
+        f"--tokenizer_config_json={hf_home / 'tokenizer_config.json'}",
+        f"--model_config={tmp_dir / 'config.json'}",
        f"--vmfb={tmp_dir / 'model.vmfb'}",
         f"--parameters={hf_home / model_file}",
         f"--device={settings['device']}",
diff --git a/build_tools/integration_tests/llm/cpu_llm_server_test.py b/build_tools/integration_tests/llm/cpu_llm_server_test.py
index 638bce7ee..9f9fd704f 100644
--- a/build_tools/integration_tests/llm/cpu_llm_server_test.py
+++ b/build_tools/integration_tests/llm/cpu_llm_server_test.py
@@ -37,7 +37,10 @@ def do_generate(prompt, port):
     # Create a GenerateReqInput-like structure
     data = {
         "text": prompt,
-        "sampling_params": {"max_completion_tokens": 50, "temperature": 0.7},
+        "sampling_params": {
+            "max_completion_tokens": 20,  # enough to span multiple pages
+            "temperature": 0.7,
+        },
         "rid": uuid.uuid4().hex,
         "return_logprob": False,
         "logprob_start_len": -1,
diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 7bf76a2ce..0eca86d3b 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -7,6 +7,7 @@
 """Export support for the PagedLLMV1 protocol of models."""
 
 import json
+from typing import Any, Dict
 import torch
 
 from iree.turbine.aot import *
@@ -89,17 +90,29 @@ def main():
     else:
         model = PagedLlamaModelV1(dataset.root_theta, llama_config)
 
-    def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):
+    def generate_params_json(
+        hp: LlamaHParams, prefill_bs: list[int], decode_bs: list[int]
+    ) -> Dict[str, Any]:
+        """
+        Generate config.json for shortfin.
+
+
+        For shortfin, we only write attention_head_count_kv because that's all shortfin needs.
+        Note that this is different from hp.attn_head_count when grouped attention shares kvcache between heads.
+        """
         return {
             "module_name": "module",
             "module_abi_version": 1,
             "max_seq_len": hp.context_length,
-            "attn_head_count": hp.attention_head_count,
             "attn_head_dim": hp.attn_head_dim,
             "prefill_batch_sizes": prefill_bs,
             "decode_batch_sizes": decode_bs,
             "transformer_block_count": hp.block_count,
-            "block_seq_stride": llama_config.block_seq_stride,
+            "paged_kv_cache": {
+                "attention_head_count_kv": hp.attention_head_count_kv,
+                "block_seq_stride": llama_config.block_seq_stride,
+                "device_block_count": 256,  # so that this makes its way into the config file & can be edited.
+            },
         }
 
     # Unrolling cache updates by batch row makes dynamo sad without an
diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py
index 141c7a7eb..8ff826d38 100644
--- a/shortfin/python/shortfin_apps/llm/components/config_struct.py
+++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py
@@ -11,20 +11,20 @@
 In a typical transformer model, the KV cache is organized similar to (mapped to
 our parameter names below):
     k = tensor.empty(transformer_block_count, batch_size, seq,
-                     attn_head_count, attn_head_dim)
+                     attn_head_count_kv, attn_head_dim)
     v = ...
 
 For context, a popular model has parameters of:
     attn_dtype_size = 2  # (fp16)
     max_seq_len = 2048
     transformer_block_count = 32
-    attn_head_count = 32
+    attn_head_count_kv = 32
     attn_head_dim = 128  # (dim / head_count)
 
 If paging, then we primarily care about the organization of a single block, where
 a block represents a single position in the sequence for a single item in the
 batch. Therefore, it will be organized like:
-    block = torch.empty(transformer_block_count, 2, attn_head_count, attn_head_dim)
+    block = torch.empty(transformer_block_count, 2, attn_head_count_kv, attn_head_dim)
 
 In this scenario, we declare that one block holds the KV cache for all transformer
 block layers because it reduces the accounting. As such, for the above example,
@@ -80,29 +80,38 @@ def _decode_dtype(name: str) -> sfnp.DType:
 class PagedKVCacheParams:
     """Parameters for the paged KV cache."""
 
-    # Position stride per attention block
+    # Tokens per page.
     block_seq_stride: int
 
+    # Number of attention heads per block. This can be different from the model's
+    # attention head count due to sharing.
+    attention_head_count_kv: int
+
     # Size of the cache on each device.
+    # Default: 256
     device_block_count: int
 
 
 @dataclass_json(undefined=Undefined.RAISE)
 @dataclass
 class ModelParams:
-    """Parameters for a specific compiled model, sufficient to do cache planning and
-    invocations."""
+    """
+    Parameters for a specific compiled model, sufficient to do cache planning and
+    invocations.
+
+    Compatibility should be maintained with function generate_params_json in
+
+    sharktank/sharktank/examples/export_paged_llm_v1.py
+    """
 
     # Maximum length of a sequence including prompt and output.
     max_seq_len: int
 
-    # Number of transformer blocks.
+    # Number of transformer layers (aka attention blocks / transformer blocks).
     transformer_block_count: int
 
-    # Number of attention heads per block.
-    attn_head_count: int
-
-    # Dimensionality of each attention head
+    # Dimensionality of each attention head. This is the dimensionality of the
+    # key and value vectors. AKA rope_dimension_count from the GGUF props.
     attn_head_dim: int
 
     # Batch sizes that the prefill stage is compiled for. These are expected to be
@@ -157,7 +166,7 @@ def paged_kv_unit_size_elements(self) -> int:
         size = 1
         size *= self.transformer_block_count
         size *= 2  # K and V cache line
-        size *= self.attn_head_count
+        size *= self.paged_kv_cache.attention_head_count_kv
         size *= self.attn_head_dim
         return size
 
diff --git a/shortfin/tests/apps/llm/components/cache_test.py b/shortfin/tests/apps/llm/components/cache_test.py
index 169d082b1..397b8d653 100644
--- a/shortfin/tests/apps/llm/components/cache_test.py
+++ b/shortfin/tests/apps/llm/components/cache_test.py
@@ -45,12 +45,15 @@ def model_params():
         "module_name": "module",
         "module_abi_version": 1,
         "max_seq_len": 2048,
-        "attn_head_count": 32,
         "attn_head_dim": 100,
         "prefill_batch_sizes": [4],
         "decode_batch_sizes": [4],
         "transformer_block_count": 26,
-        "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
+        "paged_kv_cache": {
+            "attention_head_count_kv": 32,
+            "block_seq_stride": 16,
+            "device_block_count": 256,
+        },
     }
 
     # Create a temporary file to store the JSON
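
Note (not part of the patch): the sketch below illustrates how the reorganized paged_kv_cache section of config.json feeds KV-cache sizing after this change. The example values and the helper names (example_config, kv_unit_size_elements, kv_page_size_bytes) are hypothetical; only the arithmetic mirrors the updated paged_kv_unit_size_elements property in config_struct.py and the fp16/block_seq_stride example from its module docstring.

# Illustrative sketch only; values are made up and helpers are not the
# shortfin implementation, they just reproduce the sizing arithmetic above.

example_config = {
    "module_name": "module",
    "module_abi_version": 1,
    "max_seq_len": 2048,
    "attn_head_dim": 128,
    "transformer_block_count": 32,
    "prefill_batch_sizes": [4],
    "decode_batch_sizes": [4],
    "paged_kv_cache": {
        # With grouped attention this can be smaller than the model's
        # attention head count, which is why it now lives in the config.
        "attention_head_count_kv": 8,
        "block_seq_stride": 16,  # tokens per page
        "device_block_count": 256,  # pages allocated per device
    },
}


def kv_unit_size_elements(cfg: dict) -> int:
    # Elements holding one token position of K and V across all layers,
    # following the paged_kv_unit_size_elements arithmetic in the diff.
    kv = cfg["paged_kv_cache"]
    return (
        cfg["transformer_block_count"]
        * 2  # K and V
        * kv["attention_head_count_kv"]
        * cfg["attn_head_dim"]
    )


def kv_page_size_bytes(cfg: dict, dtype_size: int = 2) -> int:
    # One page holds block_seq_stride token positions; dtype_size=2 assumes fp16.
    return kv_unit_size_elements(cfg) * cfg["paged_kv_cache"]["block_seq_stride"] * dtype_size


print(kv_unit_size_elements(example_config))  # 32 * 2 * 8 * 128 = 65536 elements
print(kv_page_size_bytes(example_config))  # 65536 * 16 * 2 = 2097152 bytes (2 MiB)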