diff --git a/.github/workflows/cpu-test.yml b/.github/workflows/cpu-test.yml new file mode 100644 index 0000000000000..89a702f9751d9 --- /dev/null +++ b/.github/workflows/cpu-test.yml @@ -0,0 +1,34 @@ +name: cpu-test + +on: + # Trigger the workflow on push or pull request, + # but only for the habana_main branch + push: + branches: + - habana_main + pull_request: + branches: + - habana_main + + +jobs: + cputest: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install torch --extra-index-url https://download.pytorch.org/whl/cpu + pip install -r requirements-hpu.txt + VLLM_TARGET_DEVICE=hpu python setup.py develop + - name: cpu-test + run: | + VLLM_SKIP_WARMUP=true VLLM_PROMPT_SEQ_BUCKET_MAX=128 VLLM_USE_FAKE_HPU=1 python examples/offline_inference_fakehpu.py diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 6ebe512c5dbf6..42c141237fb15 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -32,15 +32,23 @@ jobs: pip install types-setuptools - name: Mypy run: | - mypy - mypy tests --follow-imports skip - mypy vllm/attention --follow-imports skip - mypy vllm/distributed --follow-imports skip - mypy vllm/engine --follow-imports skip - mypy vllm/executor --follow-imports skip - mypy vllm/lora --follow-imports skip - mypy vllm/model_executor --follow-imports skip - mypy vllm/prompt_adapter --follow-imports skip - mypy vllm/spec_decode --follow-imports skip - mypy vllm/worker --follow-imports skip + mypy tests --config-file pyproject.toml + mypy vllm/*.py --config-file pyproject.toml + mypy vllm/attention --config-file pyproject.toml + mypy vllm/core --config-file pyproject.toml + mypy vllm/distributed --config-file pyproject.toml + mypy vllm/engine --config-file pyproject.toml + mypy vllm/entrypoints --config-file pyproject.toml + mypy vllm/executor --config-file pyproject.toml + mypy vllm/inputs --config-file pyproject.toml + mypy vllm/logging --config-file pyproject.toml + mypy vllm/lora --config-file pyproject.toml + mypy vllm/model_executor --config-file pyproject.toml + mypy vllm/multimodal --config-file pyproject.toml + mypy vllm/platforms --config-file pyproject.toml + mypy vllm/spec_decode --config-file pyproject.toml + mypy vllm/transformers_utils --config-file pyproject.toml + mypy vllm/usage --config-file pyproject.toml + mypy vllm/worker --config-file pyproject.toml + diff --git a/Dockerfile.hpu b/Dockerfile.hpu new file mode 100644 index 0000000000000..ab714cdac4670 --- /dev/null +++ b/Dockerfile.hpu @@ -0,0 +1,18 @@ +FROM vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest + +COPY ./ /workspace/vllm + +WORKDIR /workspace/vllm + +RUN pip install -v -r requirements-hpu.txt + +ENV no_proxy=localhost,127.0.0.1 +ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true + +RUN VLLM_TARGET_DEVICE=hpu python3 setup.py install + +WORKDIR /workspace/ + +RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks + +ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/README_GAUDI.md b/README_GAUDI.md index 5109f7ddf9927..0ef30d5f96e64 100644 --- a/README_GAUDI.md +++ b/README_GAUDI.md @@ -81,6 +81,7 @@ Supported Features - Inference with [HPU 
Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html) for accelerating low-batch latency and throughput +- INC quantization Unsupported Features ==================== @@ -88,7 +89,7 @@ Unsupported Features - Beam search - LoRA adapters - Attention with Linear Biases (ALiBi) -- Quantization (AWQ, FP8 E5M2, FP8 E4M3) +- AWQ quantization - Prefill chunking (mixed-batch inferencing) Supported Configurations @@ -315,9 +316,9 @@ mark 90% of free device memory at that point as usable. Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. Environment variable `VLLM_GRAPH_RESERVED_MEM` defines the ratio of memory reserved for HPU Graphs capture. With its default value -(`VLLM_GRAPH_RESERVED_MEM=0.4`), 40% of usable memory will be reserved +(`VLLM_GRAPH_RESERVED_MEM=0.1`), 10% of usable memory will be reserved for graph capture (later referred to as \"usable graph memory\"), and -the remaining 60% will be utilized for KV cache. Environment variable +the remaining 90% will be utilized for KV cache. Environment variable `VLLM_GRAPH_PROMPT_RATIO` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (`VLLM_GRAPH_PROMPT_RATIO=0.5`), both stages have equal memory @@ -445,7 +446,7 @@ Environment variables - `VLLM_SKIP_WARMUP`: if `true`, warmup will be skipped, `false` by default - `VLLM_GRAPH_RESERVED_MEM`: percentage of memory dedicated for - HPUGraph capture, `0.4` by default + HPUGraph capture, `0.1` by default - `VLLM_GRAPH_PROMPT_RATIO`: percentage of reserved graph memory dedicated for prompt graphs, `0.5` by default - `VLLM_GRAPH_PROMPT_STRATEGY`: strategy determining order of prompt diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index ed3beabb2c8aa..4c094eaec842a 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -76,6 +76,7 @@ Supported Features - Tensor parallelism support for multi-card inference - Inference with `HPU Graphs `__ for accelerating low-batch latency and throughput +- INC quantization Unsupported Features ==================== @@ -83,7 +84,7 @@ Unsupported Features - Beam search - LoRA adapters - Attention with Linear Biases (ALiBi) -- Quantization (AWQ, FP8 E5M2, FP8 E4M3) +- AWQ quantization - Prefill chunking (mixed-batch inferencing) Supported Configurations @@ -243,7 +244,7 @@ Before KV cache gets allocated, model weights are loaded onto the device, and a Only after that, ``gpu_memory_utilization`` flag is utilized - at its default value, will mark 90% of free device memory at that point as usable. Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture. -With its default value (``VLLM_GRAPH_RESERVED_MEM=0.4``), 40% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 60% will be utilized for KV cache. +With its default value (``VLLM_GRAPH_RESERVED_MEM=0.1``), 10% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 90% will be utilized for KV cache. Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (``VLLM_GRAPH_PROMPT_RATIO=0.5``), both stages have equal memory constraints. 
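For illustration, here is a minimal sketch of the memory-split arithmetic described above (plain Python, not vLLM code; the function name and GiB units are illustrative):

# Memory split under the documented defaults:
# gpu_memory_utilization=0.9, VLLM_GRAPH_RESERVED_MEM=0.1, VLLM_GRAPH_PROMPT_RATIO=0.5
def hpu_memory_split(free_device_mem_gib: float,
                     gpu_memory_utilization: float = 0.9,
                     graph_reserved_mem: float = 0.1,
                     graph_prompt_ratio: float = 0.5):
    usable = free_device_mem_gib * gpu_memory_utilization
    usable_graph_mem = usable * graph_reserved_mem   # reserved for HPU Graph capture
    kv_cache_mem = usable - usable_graph_mem         # the remainder goes to KV cache
    prompt_graph_mem = usable_graph_mem * graph_prompt_ratio
    decode_graph_mem = usable_graph_mem - prompt_graph_mem
    return kv_cache_mem, prompt_graph_mem, decode_graph_mem

# Example: 100 GiB free -> 90 GiB usable, 9 GiB for graph capture
# (4.5 GiB prompt graphs / 4.5 GiB decode graphs), 81 GiB for KV cache.
print(hpu_memory_split(100.0))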
Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. ``VLLM_GRAPH_PROMPT_RATIO=0.2`` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. @@ -322,14 +323,14 @@ Environment variables **Performance tuning knobs:** - ``VLLM_SKIP_WARMUP``: if ``true``, warmup will be skipped, ``false`` by default -- ``VLLM_GRAPH_RESERVED_MEM``: percentage of memory dedicated for HPUGraph capture, ``0.4`` by default +- ``VLLM_GRAPH_RESERVED_MEM``: percentage of memory dedicated for HPUGraph capture, ``0.1`` by default - ``VLLM_GRAPH_PROMPT_RATIO``: percentage of reserved graph memory dedicated for prompt graphs, ``0.5`` by default - ``VLLM_GRAPH_PROMPT_STRATEGY``: strategy determining order of prompt graph capture, ``min_tokens`` or ``max_bs``, ``min_tokens`` by default - ``VLLM_GRAPH_DECODE_STRATEGY``: strategy determining order of decode graph capture, ``min_tokens`` or ``max_bs``, ``max_bs`` by default - ``VLLM_{phase}_{dim}_BUCKET_{param}`` - collection of 12 environment variables configuring ranges of bucketing mechanism - ``{phase}`` is either ``PROMPT`` or ``DECODE`` - - ``{dim}`` is either ``BS`` or ``SEQ`` + - ``{dim}`` is either ``BS``, ``SEQ`` or ``BLOCK`` - ``{param}`` is either ``MIN``, ``STEP`` or ``MAX`` - Default values: @@ -345,9 +346,9 @@ Environment variables - batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``min(max_num_seqs, 32)`` - batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)`` - batch size max (``VLLM_DECODE_BS_BUCKET_MAX``): ``max_num_seqs`` - - sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``128`` - - sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``128`` - - sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``max(128, (max_num_seqs*max_model_len)/block_size)`` + - sequence length min (``VLLM_DECODE_BLOCK_BUCKET_MIN``): ``128`` + - sequence length step (``VLLM_DECODE_BLOCK_BUCKET_STEP``): ``128`` + - sequence length max (``VLLM_DECODE_BLOCK_BUCKET_MAX``): ``max(128, (max_num_seqs*max_model_len)/block_size)`` Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: diff --git a/examples/offline_inference_fakehpu.py b/examples/offline_inference_fakehpu.py new file mode 100644 index 0000000000000..972d84b60b318 --- /dev/null +++ b/examples/offline_inference_fakehpu.py @@ -0,0 +1,38 @@ +import os + +from vllm import LLM, SamplingParams + +if os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0': + from vllm.utils import migrate_to_cpu + migrate_to_cpu() + +# Sample prompts. +prompts = [ + "Berlin is the capital city of ", + "Louvre is located in the city of ", + "Barack Obama was the 44th president of ", + "Warsaw is the capital city of ", + "Gniezno is a city in ", + "San Francisco is located in the state of ", + "Llanfairpwllgwyngyll is located in country of ", +] +ref_answers = [ + "Germany", "Paris", "United States", "Poland", "Poland", "California", + "Wales" +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0, n=1, use_beam_search=False) + +# Create an LLM. +llm = LLM(model="facebook/opt-125m", max_model_len=32, max_num_seqs=4) +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. 
+for output, answer in zip(outputs, ref_answers): + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + assert answer in generated_text, ( + f"The generated text does not contain the correct answer: {answer}") +print('PASSED') diff --git a/format.sh b/format.sh index 2204b3ba59498..623525e64bd84 100755 --- a/format.sh +++ b/format.sh @@ -96,18 +96,23 @@ echo 'vLLM yapf: Done' # Run mypy echo 'vLLM mypy:' -mypy --follow-imports skip # Note that this is less strict than CI -mypy tests --follow-imports skip -mypy vllm/attention --follow-imports skip -mypy vllm/distributed --follow-imports skip -mypy vllm/engine --follow-imports skip -mypy vllm/executor --follow-imports skip -mypy vllm/lora --follow-imports skip -mypy vllm/model_executor --follow-imports skip -mypy vllm/prompt_adapter --follow-imports skip -mypy vllm/spec_decode --follow-imports skip -mypy vllm/worker --follow-imports skip -echo 'vLLM mypy: Done' +mypy tests --config-file pyproject.toml +mypy vllm/*.py --config-file pyproject.toml +mypy vllm/attention --config-file pyproject.toml +mypy vllm/core --config-file pyproject.toml +mypy vllm/distributed --config-file pyproject.toml +mypy vllm/engine --config-file pyproject.toml +mypy vllm/entrypoints --config-file pyproject.toml +mypy vllm/executor --config-file pyproject.toml +mypy vllm/logging --config-file pyproject.toml +mypy vllm/lora --config-file pyproject.toml +mypy vllm/model_executor --config-file pyproject.toml +mypy vllm/multimodal --config-file pyproject.toml +mypy vllm/prompt_adapter --config-file pyproject.toml +mypy vllm/spec_decode --config-file pyproject.toml +mypy vllm/transformers_utils --config-file pyproject.toml +mypy vllm/usage --config-file pyproject.toml +mypy vllm/worker --config-file pyproject.toml # If git diff returns a file that is in the skip list, the file may be checked anyway: diff --git a/requirements-hpu.txt b/requirements-hpu.txt index e0f03c8464c7b..d451200aa1144 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -6,3 +6,4 @@ ray == 2.32.0 triton pandas tabulate +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@30ee2d1 diff --git a/tests/lora/test_lora_hpu.py b/tests/lora/test_lora_hpu.py index 57bc19b2170db..1e0e728ae7240 100644 --- a/tests/lora/test_lora_hpu.py +++ b/tests/lora/test_lora_hpu.py @@ -1,6 +1,6 @@ import pytest import torch -from vllm.hpu.ops import LoraMask +from vllm_hpu_extension.ops import LoraMask from vllm.hpu.punica_hpu import GaudiPunicaWrapper diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 19a5ca5e27502..3cb46dbc213d9 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -9,7 +9,7 @@ from transformers import GenerationConfig, GenerationMixin import vllm.envs as envs -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import ApplyToppTopkScalar, Sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, @@ -781,3 +781,62 @@ def test_sampler_include_gpu_probs_tensor(device: str): assert sampler_output.sampled_token_probs is not None assert sampler_output.logprobs is not None assert sampler_output.sampled_token_ids is not None + +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_topk_topk_scalar(): + obj1 = ApplyToppTopkScalar(2) + assert 
ApplyToppTopkScalar._padded_k == 0 + x = torch.tensor([[9, 9, 8, 8, 8, 8, 7, 7, 7.0], + [10, 10, 9, 9, 9, 8, 5, 5, 5]]) + + retval1 = obj1(x, p=0.9, k=5) + ninf = -float("inf") + expected1 = torch.tensor([[9., 9., 8., 8., 8., 8., ninf, ninf, ninf], + [10., 10., 9., 9., 9., ninf, ninf, ninf, ninf]]) + assert torch.all(retval1 == expected1).item() + assert ApplyToppTopkScalar._padded_k == 9 + + obj2 = ApplyToppTopkScalar(2) + assert obj2._padded_k == 9 + + x = torch.tensor([[2, 2, 9, 9, 2, 2, 1, 1, 1.0], + [10, 9, 9, 5, 9, 9, 5, 9, 10]]) + retval2 = obj2(x, p=0.9, k=5) + expected2 = torch.tensor( + [[ninf, ninf, 9., 9., ninf, ninf, ninf, ninf, ninf], + [10., ninf, 9., ninf, 9., 9., ninf, 9., 10.]]) + assert torch.all(retval2 == expected2).item() + assert obj2._padded_k == 9 + + retval3 = obj2(x, p=1.0, k=5) + expected3 = torch.tensor([[2., 2., 9., 9., 2., 2., ninf, ninf, ninf], + [10., 9., 9., ninf, 9., 9., ninf, 9., 10.]]) + + assert torch.all(retval3 == expected3).item() + + # this should not be done in general, doing it here for testing purposes + ApplyToppTopkScalar._padded_k = 0 + x = torch.tensor([[1, 1, 1, 9, 8, 1, 1, 1, 1.0], + [2, 1, 2, 2, 1, 1, 1, 1, 1]]) + obj3 = ApplyToppTopkScalar(2) + retval4 = obj3(x, p=0.9, k=2) + expected4 = torch.tensor( + [[ninf, ninf, ninf, 9., 8., ninf, ninf, ninf, ninf], + [2., ninf, 2., 2., ninf, ninf, ninf, ninf, ninf]]) + assert torch.all(retval4 == expected4).item() + assert obj3._padded_k == 4 + y = torch.tensor([[8, 8, 8, 9, 8, 1, 1, 1, 1.0], + [2, 1, 2, 2, 1, 1, 1, 1, 1]]) + retval5 = obj3(y, p=0.9, k=2) + assert obj3._padded_k == 8 + expected5 = torch.tensor([[8., 8., 8., 9., 8., ninf, ninf, ninf, ninf], + [2., ninf, 2., 2., ninf, ninf, ninf, ninf, + ninf]]) + assert torch.all(retval5 == expected5).item() + y = torch.tensor([[8, 8, 8, 9, 8, 8, 1, 1, 1.0], + [2, 1, 2, 2, 3, 1, 1, 1, 1]]) + retval6 = obj3(y, p=0.9, k=2) + expected6 = torch.tensor([[8., 8., 8., 9., 8., 8., ninf, ninf, ninf], + [2., ninf, 2., 2., 3., ninf, ninf, ninf, ninf]]) + assert torch.all(retval6 == expected6).item() + assert obj3._padded_k == 8 diff --git a/vllm/__init__.py b/vllm/__init__.py index 0895c571d1d89..29fc02ae3e96a 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -1,4 +1,8 @@ """vLLM: a high-throughput and memory-efficient inference engine for LLMs""" +from vllm.utils import is_fake_hpu, migrate_to_cpu + +if is_fake_hpu(): + migrate_to_cpu() from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index 20b0f2bc7630b..b7b8072de3fe5 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -7,14 +7,14 @@ from typing import Any, Dict, List, Optional, Tuple, Type import torch +import vllm_hpu_extension.ops as ops +from vllm_hpu_extension import cache_ops +from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache -import vllm.hpu.ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.ops.habana_paged_attn import (HabanaPagedAttention, HabanaPagedAttentionMetadata) -from vllm.hpu import cache_ops -from vllm.hpu.utils import Matmul, Softmax, VLLMKVCache from vllm.logger import init_logger logger = init_logger(__name__) @@ -108,17 +108,10 @@ def __init__( self.v_cache = VLLMKVCache() self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = 
sliding_window - self.position_bias = None self.alibi_slopes = alibi_slopes if alibi_slopes is not None: - # FIXME(kzawora): Need a general method to set max_seq_len on - # per-model basis. alibi_slopes_tensor = torch.tensor(alibi_slopes, dtype=torch.bfloat16) - self.position_bias = _make_alibi_bias(alibi_slopes_tensor, - num_kv_heads, - alibi_slopes_tensor.dtype, - max_seq_len) self.alibi_slopes = alibi_slopes_tensor assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -190,11 +183,13 @@ def forward( assert attn_metadata.attn_bias is not None, \ 'attn_bias must be set before calling model.forward!' attn_bias = attn_metadata.attn_bias - if self.alibi_slopes is not None and \ - self.position_bias is not None: - attn_bias.add_(self.position_bias[:, :, - -attn_bias.size(2):, - -attn_bias.size(3):]) + if self.alibi_slopes is not None: + position_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, + attn_bias.dtype, + attn_bias.shape[-1]) + attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) + attn_bias.add_(position_bias) else: attn_bias = None diff --git a/vllm/attention/ops/habana_paged_attn.py b/vllm/attention/ops/habana_paged_attn.py index cab8d7abe95fd..49a3e3f774d58 100644 --- a/vllm/attention/ops/habana_paged_attn.py +++ b/vllm/attention/ops/habana_paged_attn.py @@ -6,8 +6,7 @@ from typing import Dict, List, Optional, Tuple import torch - -from vllm.hpu import cache_ops, ops +from vllm_hpu_extension import cache_ops, ops # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. _PARTITION_SIZE = 512 diff --git a/vllm/executor/ray_habana_executor.py b/vllm/executor/ray_habana_executor.py index d69a85a816636..f0822283296dd 100644 --- a/vllm/executor/ray_habana_executor.py +++ b/vllm/executor/ray_habana_executor.py @@ -13,7 +13,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.utils import (_run_task_with_lock, get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, - make_async) + make_async, is_fake_hpu) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -87,18 +87,20 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", driver_ip = get_ip() worker_wrapper_kwargs = self._get_worker_wrapper_args() for bundle_id, bundle in enumerate(placement_group.bundle_specs): - if not bundle.get("HPU", 0): + resource_name = "HPU" if not is_fake_hpu() else "CPU" + if not bundle.get(resource_name, 0): continue scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=placement_group, placement_group_capture_child_tasks=True, placement_group_bundle_index=bundle_id, ) - + resources = {'HPU': num_gpus} if not is_fake_hpu() else {} + num_cpus = 0 if not is_fake_hpu() else num_gpus worker = ray.remote( - num_cpus=0, + num_cpus=num_cpus, num_gpus=0, - resources={'HPU': num_gpus}, + resources=resources, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, )(RayWorkerWrapper).remote(**worker_wrapper_kwargs) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index dff33b2c021b4..34b002514c27a 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, IntermediateTensors -from vllm.utils import get_ip, is_hip, is_xpu +from vllm.utils import get_ip, is_hip, is_xpu, hpu_device_string from vllm.worker.worker_base import 
WorkerWrapperBase logger = init_logger(__name__) @@ -241,7 +241,7 @@ def initialize_ray_cluster( if current_platform.is_tpu(): device_str = "TPU" elif current_platform.is_hpu(): - device_str = "HPU" + device_str = hpu_device_string().upper() # Create placement group for worker processes current_placement_group = ray.util.get_current_placement_group() if current_placement_group: diff --git a/vllm/hpu/__init__.py b/vllm/hpu/__init__.py deleted file mode 100644 index b8e4d3aac98a7..0000000000000 --- a/vllm/hpu/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -############################################################################### -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -############################################################################### diff --git a/vllm/hpu/cache_ops.py b/vllm/hpu/cache_ops.py deleted file mode 100644 index 9042924f68b3d..0000000000000 --- a/vllm/hpu/cache_ops.py +++ /dev/null @@ -1,107 +0,0 @@ -############################################################################### -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -############################################################################### - -import math - -import habana_frameworks.torch as htorch -import torch - - -def reshape_and_cache(key, - value, - key_cache, - value_cache, - slot_mapping, - dtype, - is_prompt=False): - num_blocks = key_cache.size(0) - block_size = key_cache.size(1) - slot_mapping = slot_mapping.flatten() - indices = torch.div(slot_mapping, block_size, rounding_mode="floor") - offsets = torch.fmod(slot_mapping, block_size) - num_slots_requested = slot_mapping.size(0) - num_slots_available = num_blocks * block_size - # NOTE(kzawora): HPU PT bridge crashes with - # RuntimeError: Invalid inputs for scatter_nd_onnx - # on index_put when num_slots_requested > num_slots_available. - # This case might occur when we have little kv cache blocks and - # lots of padding, or are doing warmup. - # This loop is a workaround for this issue. Please remove it - # once key_cache.index_put_(indices, offsets), key) works. - num_kv_cache_passes = math.ceil(num_slots_requested / num_slots_available) - for i in range(num_kv_cache_passes): - start_idx = i * num_slots_available - end_idx = (i + 1) * num_slots_available - key_cache.index_put_( - (indices[start_idx:end_idx], offsets[start_idx:end_idx]), - key[start_idx:end_idx]) - value_cache.index_put_( - (indices[start_idx:end_idx], offsets[start_idx:end_idx]), - value[start_idx:end_idx]) - - -def prepare_to_cache(cache, slot_mapping): - num_blocks = cache.size(0) - block_size = cache.size(1) - slot_mapping = slot_mapping.flatten() - indices = torch.div(slot_mapping, block_size, rounding_mode="floor") - offsets = torch.fmod(slot_mapping, block_size) - num_slots_requested = slot_mapping.size(0) - num_slots_available = num_blocks * block_size - # NOTE(kzawora): HPU PT bridge crashes with - # RuntimeError: Invalid inputs for scatter_nd_onnx - # on index_put when num_slots_requested > num_slots_available. - # This case might occur when we have little kv cache blocks and - # lots of padding, or are doing warmup. - # This loop is a workaround for this issue. Please remove it - # once key_cache.index_put_(indices, offsets), key) works. 
- num_kv_cache_passes = math.ceil(num_slots_requested / num_slots_available) - - return num_kv_cache_passes, num_slots_available, indices, offsets - - -def insert_or_update_cache(input, cache, num_kv_cache_passes, - num_slots_available, block_indices, block_offsets): - for i in range(num_kv_cache_passes): - start_idx = i * num_slots_available - end_idx = (i + 1) * num_slots_available - cache.index_put_((block_indices[start_idx:end_idx], - block_offsets[start_idx:end_idx]), - input[start_idx:end_idx]) - - -def swap_blocks(src, dst, block_mapping): - index_src = torch.zeros((1, ), dtype=torch.int32, device=src.device) - index_dst = torch.zeros((1, ), dtype=torch.int32, device=dst.device) - for src_idx, dst_idx in block_mapping.items(): - index_src[0] = src_idx - index_dst[0] = dst_idx - dst.index_put_([index_dst], src.index_select(0, index_src)) - if dst.device.type == 'hpu': - htorch.core.mark_step() - torch.hpu.synchronize() - - -def copy_blocks(key_caches, value_caches, block_mapping): - index_src = torch.zeros((1, ), - dtype=torch.int32, - device=key_caches[0].device) - index_dst = torch.zeros((1, ), - dtype=torch.int32, - device=key_caches[0].device) - for src, dsts in block_mapping.items(): - index_src[0] = src - for dst in dsts: - index_dst[0] = dst - for key_cache in key_caches: - key_cache.index_copy_(0, index_dst, - key_cache.index_select(0, index_src)) - for value_cache in value_caches: - value_cache.index_copy_(0, index_dst, - value_cache.index_select(0, index_src)) - if key_caches[0].device.type == 'hpu': - htorch.core.mark_step() diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py deleted file mode 100644 index aaf863aff0cad..0000000000000 --- a/vllm/hpu/ops.py +++ /dev/null @@ -1,252 +0,0 @@ -############################################################################### -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -############################################################################### -from typing import Optional - -import habana_frameworks.torch as htorch -import torch -import torch.nn.functional as F - -from vllm.logger import init_logger - -logger = init_logger(__name__) -HPUFusedRMSNorm = None -try: - from habana_frameworks.torch.hpex.normalization import FusedRMSNorm - HPUFusedRMSNorm = FusedRMSNorm -except ImportError: - logger.warning("Could not import HPU FusedRMSNorm kernel. " - "vLLM will use forward_native implementation of RMSNorm.") -HPUFusedSDPA = None -try: - from habana_frameworks.torch.hpex.kernels import FusedSDPA - HPUFusedSDPA = FusedSDPA -except ImportError: - logger.warning("Could not import HPU FusedSDPA kernel. 
" - "vLLM will use native implementation.") - - -def batch2block(tensor, block_mapping): - shape = tuple(tensor.shape) - return (block_mapping @ tensor.view(shape[0], -1)).view(-1, *shape[1:]) - - -def block2batch(tensor, block_mapping): - shape = tuple(tensor.shape) - return (block_mapping.t() @ tensor.view(shape[0], -1)).view(-1, *shape[1:]) - - -def block_softmax(batch_size, attn, block_mapping): - attn.sub_(10.0) - attn = attn.exp_() - sums = attn.sum(dim=-1).unsqueeze(-1) - sums = block2batch(sums, block_mapping) - sums = batch2block(sums, block_mapping) - sums.add_(1.0e-12) - attn.div_(sums) - return attn - - -def flat_pa(query, key_cache, value_cache, block_list, block_mapping, - block_bias, scale, matmul_qk_op, matmul_av_op, keys_fetch_func, - values_fetch_func): - batch_size = query.size(0) - q_heads = query.size(1) - kv_heads = key_cache.size(2) - - query = batch2block(scale * query, block_mapping).unsqueeze(-2) - key = keys_fetch_func(key_cache, block_list).transpose(1, 2) - value = values_fetch_func(value_cache, block_list).transpose(1, 2) - block_bias = block_bias.view(key.size(0), 1, 1, -1) - - if kv_heads != q_heads: - block_bias = block_bias.unsqueeze(1) - query = query.unflatten(1, (kv_heads, -1)) - key = key.unflatten(1, (kv_heads, 1)) - value = value.unflatten(1, (kv_heads, 1)) - key = key.transpose(3, 4) - else: - key = key.transpose(2, 3) - - attn = matmul_qk_op(query, key) + block_bias - attn = block_softmax(batch_size, attn, block_mapping) - attn = matmul_av_op(attn, value) - attn = block2batch(attn, block_mapping) - attn = attn.squeeze(-2) - if kv_heads != q_heads: - attn = attn.flatten(1, 2) - return attn - - -def silu_and_mul(x: torch.Tensor) -> torch.Tensor: - d = x.shape[-1] // 2 - return F.silu(x[..., :d]) * x[..., d:] - - -def static_fused_moe(hidden_states, w1, w2, score, topk): - B, D = hidden_states.shape - num_experts = w1.shape[0] - routing_weights = F.softmax(score, dim=1, dtype=torch.float32) - routing_weights, selected_experts = torch.topk(routing_weights, - topk, - dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(hidden_states.dtype) - final_hidden_states = torch.zeros((1, B, D), - dtype=hidden_states.dtype, - device=hidden_states.device) - padded_weights = torch.zeros((B, num_experts), - dtype=hidden_states.dtype, - device=hidden_states.device) - padded_weights.scatter_(-1, selected_experts, routing_weights) - padded_weights = padded_weights.reshape(-1, B, w1.shape[0]) - padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1) - - htorch.core.mark_step() - - for expert_idx in range(num_experts): - w_output = torch.matmul(hidden_states, w1[expert_idx].transpose(0, 1)) - w_output = silu_and_mul(w_output) - w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1)) - final_hidden_states += w_output * padded_weights[expert_idx] - - return final_hidden_states.view(-1, D) - - -#TODO: remove after fusedsdpa fix for query_head != kv_head -def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
- The kv go from (batch, num_key_value_heads, seqlen, head_dim) to - (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = kv.shape - if n_rep == 1: - return kv - kv = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, - head_dim) - return kv.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def prompt_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_bias: Optional[torch.Tensor] = None, - p: float = 0.0, - scale: Optional[float] = None, - matmul_qk_op=torch.matmul, - softmax_op=torch.softmax, - matmul_av_op=torch.matmul, - valid_seq_lengths: Optional[torch.Tensor] = None, -) -> torch.Tensor: - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - query_heads = query.size(1) - kv_heads = key.size(1) - if attn_bias is not None or HPUFusedSDPA is None: - if query_heads != kv_heads: - query = query.unflatten(1, (kv_heads, -1)) - key = key.unflatten(1, (kv_heads, 1)) - value = value.unflatten(1, (kv_heads, 1)) - if attn_bias is not None: - attn_bias = attn_bias.unsqueeze(2) - attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2)) - if attn_bias is not None: - attn_weights.add_(attn_bias) - attn_weights = softmax_op(attn_weights, dim=-1) - attn_weights = matmul_av_op(attn_weights, value) - if query_heads != kv_heads: - attn_weights = attn_weights.flatten(1, 2) - else: - #TODO: remove after fusedsdpa fix for query_heads != kv_heads - if query_heads != kv_heads: - key = repeat_kv(key, int(query_heads // kv_heads)) - value = repeat_kv(value, int(query_heads // kv_heads)) - softmax_mode = 'fast' - recompute_mode = True - attn_weights = FusedSDPA.apply(query, key, value, None, 0.0, True, - scale, softmax_mode, recompute_mode, - valid_seq_lengths, 'right') - attn_weights = attn_weights.transpose(1, 2) - return attn_weights - - -class LoraMask: - lora_mask = None - - @staticmethod - def setLoraMask(mask): - LoraMask.lora_mask = mask - - @staticmethod - def getLoraMask(): - return LoraMask.lora_mask - - -def dispatch_bgmv_linear( - y: torch.Tensor, - x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, - layer_idx: int, - scale: float, -): - """ - `wa_t_all` and `wb_t_all` contains all LoRA A and LoRA B weight matrices - stacked at dimension 0 into single tensors, assuming same rank. `wa` is the - reshaped and transposed version of `wa_t_all` of shape - (h_in, max_loras * lora_rank) and `wb` is the transposed and reshaped - version of `wb_t_all` of shape (max_loras * lora_rank, h_out). - - Matmul input `x` with `wa`. Multiply `x` with a mask to zero-out inputs of - inactive LoRA indices. Matmul masked output with `wb` and scale it to get - the final output. - """ - - assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}' - mask = LoraMask.getLoraMask() - - wa = wa_t_all[:, 0, :, :] - wb = wb_t_all[:, 0, :, :].transpose(1, 2) - wa = wa.reshape(wa.shape[0] * wa.shape[1], wa.shape[2]).transpose(0, 1) - wb = wb.reshape(wb.shape[0] * wb.shape[1], wb.shape[2]) - - out = x @ wa - assert (out.shape == mask.shape) - out = out * mask - out = out @ wb - y += out * scale - - -def dispatch_bgmv_embedding( - y: torch.Tensor, - x: torch.Tensor, - wb_t_all: torch.Tensor, - layer_idx: int, - scale: float, -): - """ - `wb_t_all` contains all LoRA-B weight matrices stacked at dimension 0 into - a single tensor, assuming same rank. 
`wb` is the transposed and reshaped - version of `wb_t_all` of shape (num_loras * lora_rank, embedding_dim). - - Output of LoRA-A embedding (tensor x) is repeated max_loras times to match - the shape of `wb`. Multiply `x` with a mask to zero-out inputs of inactive - LoRA indices. Matmul masked output with `wb` and scale it to get the final - output. - """ - - assert layer_idx == 0, f'layer_idx should be 0, but got {layer_idx}' - max_loras = wb_t_all.size(0) - - x = x.repeat(1, max_loras) - x = x * LoraMask.getLoraMask() - wb = wb_t_all[:, 0, :, :].transpose(1, 2) - wb = wb.reshape(wb.shape[0] * wb.shape[1], wb.shape[2]) - out = x @ wb - y += out * scale diff --git a/vllm/hpu/punica_hpu.py b/vllm/hpu/punica_hpu.py index aed015ac4ae06..3c37558831bb5 100644 --- a/vllm/hpu/punica_hpu.py +++ b/vllm/hpu/punica_hpu.py @@ -9,7 +9,7 @@ import torch from vllm.lora.punica import PunicaWrapper -from vllm.hpu.ops import dispatch_bgmv_linear, dispatch_bgmv_embedding +from vllm_hpu_extension.ops import dispatch_bgmv_linear, dispatch_bgmv_embedding class GaudiPunicaWrapper(PunicaWrapper): def __init__(self, max_num_batched_tokens: int, max_batches: int, diff --git a/vllm/hpu/rotary_embed.py b/vllm/hpu/rotary_embed.py deleted file mode 100644 index 30a88d68a24af..0000000000000 --- a/vllm/hpu/rotary_embed.py +++ /dev/null @@ -1,115 +0,0 @@ -############################################################################### -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -############################################################################### - -import torch -import torch.nn as nn - -from vllm.logger import init_logger -from vllm.utils import is_hpu - -logger = init_logger(__name__) - -if is_hpu(): - try: - from habana_frameworks.torch.hpex.kernels import ( - RotaryPosEmbeddingHelperV1 as FusedRoPE) - except ImportError: - logger.warning("Could not import HPU FusedRoPE kernel. " - "vLLM will use forward_native implementation of RoPE.") - FusedRoPE = None -else: - FusedRoPE = None - - -class HpuRotaryEmbedding(nn.Module): - - def __init__(self, - head_size, - rotary_dim, - max_position_embeddings=2048, - base=10000, - is_neox_style=None, - device='hpu', - RoPEFallback=None): - super().__init__() - - self.head_size = head_size - self.dim = rotary_dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base**( - torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache(seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype()) - if FusedRoPE is None: - assert RoPEFallback is not None, ( - "HPU FusedRoPE kernel could not be imported, and " - "fallback RoPE implementation was not provided!") - self.fallback_impl = RoPEFallback(head_size, - rotary_dim, - max_position_embeddings, - base, - is_neox_style, - dtype=torch.get_default_dtype()) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, - device=device, - dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order - # to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", - emb.cos().to(dtype), - persistent=False) - self.register_buffer("sin_cached", - emb.sin().to(dtype), - persistent=False) - - def forward(self, positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor): - if FusedRoPE is None: - return self.fallback_impl(positions, query, key) - if query.dim() == 2: - query = query.unsqueeze(0) - if key.dim() == 2: - key = key.unsqueeze(0) - if positions.dim() == 1: - positions = positions.unsqueeze(0) - seq_len = key.shape[-2] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, - device=query.device, - dtype=query.dtype) - - cos, sin = self.cos_cached[:seq_len].to( - dtype=query.dtype), self.sin_cached[:seq_len].to(dtype=query.dtype) - query = query.reshape( - (query.shape[0], query.shape[1], query.shape[2] // self.head_size, - self.head_size)) - key = key.reshape((key.shape[0], key.shape[1], - key.shape[2] // self.head_size, self.head_size)) - - if len(positions[0]) == 1: - cos = self.cos_cached[positions].unsqueeze(2).to(dtype=query.dtype) - sin = self.sin_cached[positions].unsqueeze(2).to(dtype=query.dtype) - else: - cos = cos[positions].unsqueeze(2) - sin = sin[positions].unsqueeze(2) - query, key = FusedRoPE.apply(query, cos, sin, - 0), FusedRoPE.apply(key, cos, sin, 0) - return query.reshape( - (query.shape[0], query.shape[1], - query.shape[2] * query.shape[3])), key.reshape( - (key.shape[0], key.shape[1], key.shape[2] * key.shape[3])) diff --git a/vllm/hpu/utils.py b/vllm/hpu/utils.py deleted file mode 100644 index 13204b83d5742..0000000000000 --- a/vllm/hpu/utils.py +++ /dev/null @@ -1,61 +0,0 @@ -############################################################################### -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
-############################################################################### - -from functools import wraps - -import habana_frameworks.torch as htorch -import torch - -from vllm.hpu.cache_ops import insert_or_update_cache - - -def with_mark_steps(fn): - - @wraps(fn) - def wrapped(*args, **kwargs): - htorch.core.mark_step() - result = fn(*args, **kwargs) - del args - del kwargs - htorch.core.mark_step() - return result - - return wrapped - - -class Matmul(torch.nn.Module): - - def __init__(self): - super(Matmul, self).__init__() - - def forward(self, x, y): - return torch.matmul(x, y) - - -class Softmax(torch.nn.Module): - - def __init__(self): - super().__init__() - - def forward(self, x, dim=None, inv_head=None): - return torch.softmax(x, dim) - - -class VLLMKVCache(torch.nn.Module): - - def __init__(self): - super(VLLMKVCache, self).__init__() - - def forward(self, input, cache, num_kv_cache_passes, num_slots_available, - block_indices, block_offset): - insert_or_update_cache(input, cache, num_kv_cache_passes, - num_slots_available, block_indices, - block_offset) - return cache - - def fetch_from_cache(self, cache, blocks): - return cache.index_select(0, blocks) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index e6be20edc8ce6..b6e7e6783a328 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -30,6 +30,10 @@ VocabParallelEmbedding) from vllm.platforms import current_platform +if current_platform.is_hpu(): + from vllm_hpu_extension.ops import (dispatch_bgmv_embedding, + dispatch_bgmv_linear) + if TYPE_CHECKING: pass diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 554dcc0ed43ed..fc2ee94e0e35a 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -80,7 +80,7 @@ def __call__(self, input_ids: List[int], -math.inf, device=scores.device) mask[allowed_tokens] = 0 - scores.add_(mask) + scores = scores.add(mask) return scores diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 10672d200d352..9e4e7233c1eba 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -14,9 +14,6 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -if current_platform.is_hpu(): - from vllm.hpu.ops import static_fused_moe - logger = init_logger(__name__) @@ -122,15 +119,25 @@ def forward_cuda( topk_ids=topk_ids, inplace=True) - def forward_hpu(self, x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - router_logits: torch.Tensor, top_k: int, renormalize: bool, - use_grouped_topk: bool, num_expert_group: Optional[int], - topk_group: Optional[int]): + def forward_hpu( + self, + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + num_expert_group: Optional[int], + topk_group: Optional[int], + layer: Optional[torch.nn.Module], + ): assert not use_grouped_topk, 'use_grouped_topk must be False on HPU' assert num_expert_group is None, ('num_expert_group is ' 'not supported on HPU') assert topk_group is None, 'topk_group is not supported on HPU' - return static_fused_moe(x, w1, w2, router_logits, top_k) + if layer is not None: + return layer.hpu_static_fused_moe(x, w1, w2, router_logits, top_k) def forward_cpu(self, *args, **kwargs): raise 
NotImplementedError( @@ -165,7 +172,7 @@ def forward_tpu( class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. - This layer contains both MergedColumnParallel weights (gate_up_proj / + This layer contains both MergedColumnParallel weights (gate_up_proj / w13) and RowParallelLinear weights (down_proj/ w2). Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We @@ -218,6 +225,9 @@ def __init__( self.num_expert_group = num_expert_group self.topk_group = topk_group self.custom_routing_function = custom_routing_function + if current_platform.is_hpu(): + from vllm_hpu_extension.ops import StaticFusedMOE + self.hpu_static_fused_moe = StaticFusedMOE(self.num_experts) if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( @@ -540,3 +550,90 @@ def _load_fp8_scale(self, param: torch.nn.Parameter, # If we are in the row parallel case (down_proj) else: param_data[expert_id] = loaded_weight + # Weights + else: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.intermediate_size_per_partition + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + + # w1, gate_proj case: Load into first shard of w13. + if shard_id == 0: + param_data[expert_id, + 0:shard_size, :] = loaded_weight[shard, :] + if current_platform.is_hpu(): + self.hpu_static_fused_moe.w13_list[expert_id].set_weight( + param_data[expert_id]) + # w3, up_proj case: Load into second shard of w13. + elif shard_id == 2: + param_data[expert_id, shard_size:2 * + shard_size, :] = loaded_weight[shard, :] + if current_platform.is_hpu(): + self.hpu_static_fused_moe.w13_list[expert_id].set_weight( + param_data[expert_id]) + # w2, down_proj case: Load into only shard of w2. + elif shard_id == 1: + param_data[expert_id, :, :] = loaded_weight[:, shard] + if current_platform.is_hpu(): + self.hpu_static_fused_moe.w2_list[expert_id].set_weight( + param_data[expert_id]) + else: + raise ValueError( + f"Shard id must be in [0,1,2] but got {shard_id}") + + def forward(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + assert self.quant_method is not None + + # Matrix multiply. 
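        # On HPU, this apply() call ends up in forward_hpu above, which runs the
        # MoE through layer.hpu_static_fused_moe (the StaticFusedMOE module from
        # vllm_hpu_extension created in __init__ when current_platform.is_hpu());
        # other platforms keep their existing fused-MoE code paths.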
+ final_hidden_states = self.quant_method.apply( + self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + num_expert_group=self.num_expert_group, + topk_group=self.topk_group) + + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states + + @classmethod + def make_expert_params_mapping( + cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int) -> List[Tuple[str, str, int, int]]: + + gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name] + gate_down_up = [ + ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name + ] + + return [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_scale" + if weight_name in gate_up else "experts.w2_scale", + f"experts.{expert_id}.{weight_name}.weight_scale", expert_id, + shard_id) for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + [ + # These are the weights for the experts + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_weight" + if weight_name in gate_up else "experts.w2_weight", + f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) + for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ("experts.a13_scale" + if weight_name in gate_up else "experts.a2_scale", + f"experts.{expert_id}.{weight_name}.input_scale", expert_id, + shard_id) for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index b7f6103f05580..049b371df269e 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -75,7 +75,7 @@ def forward_hpu( x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - from vllm.hpu.ops import HPUFusedRMSNorm + from vllm_hpu_extension.ops import HPUFusedRMSNorm if HPUFusedRMSNorm is None: return self.forward_native(x, residual) if residual is not None: diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py index f6718ec2ac9e7..ec0141b61f58f 100644 --- a/vllm/model_executor/layers/quantization/inc.py +++ b/vllm/model_executor/layers/quantization/inc.py @@ -5,6 +5,8 @@ from torch.nn.parameter import Parameter from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, UnquantizedFusedMoEMethod) from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -52,6 +54,8 @@ def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["INCLinearMethod"]: if isinstance(layer, LinearBase): return INCLinearMethod(self) + elif isinstance(layer, FusedMoE): + return UnquantizedFusedMoEMethod() return None def get_scaled_act_names(self) -> List[str]: @@ -78,7 +82,7 @@ class INCLinearMethod(LinearMethodBase): 1. Only support per-tensor quantization due to torch._scaled_mm support. 2. 
Only support float8_e4m3fn data type due to the limitation of torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) - + Args: quant_config: The quantization config. """ diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 260bb8cf28a0d..79bf086c7b24e 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -31,7 +31,7 @@ from vllm.platforms import current_platform if current_platform.is_hpu(): - from vllm.hpu.rotary_embed import HpuRotaryEmbedding + from vllm_hpu_extension.rotary_embed import HpuRotaryEmbedding def _rotate_neox(x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py old mode 100644 new mode 100755 index c00da106734ae..035453c92866f --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -3,6 +3,7 @@ import warnings from dataclasses import dataclass from importlib.util import find_spec +import math from math import inf from typing import Dict, List, Optional, Tuple, Union @@ -203,6 +204,13 @@ def _init_sampling_tensors( self._do_penalties = do_penalties self._do_top_p_top_k = do_top_p_top_k self._do_min_p = do_min_p + self._top_p_scalar = sampling_tensors.top_ps[0].item() + self._top_k_scalar = sampling_tensors.top_ks[0].item() + scalar_p = torch.all(sampling_tensors.top_ps == self._top_p_scalar) + scalar_k = torch.all(sampling_tensors.top_ks == self._top_k_scalar) + self._scalar_p_and_k = (scalar_p and scalar_k).item() + if self._scalar_p_and_k and self._do_top_p_top_k: + self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5) def forward( self, @@ -262,8 +270,13 @@ def forward( logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None: - logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, - sampling_tensors.top_ks) + if self._scalar_p_and_k: + logits = self._apply_top_k_top_p_opt(logits, + self._top_p_scalar, + self._top_k_scalar) + else: + logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, + sampling_tensors.top_ks) if do_min_p: logits = _apply_min_p(logits, sampling_tensors.min_ps) @@ -345,6 +358,101 @@ def _get_bin_counts_and_mask( return bin_counts, mask +class ApplyToppTopkScalar(): + """ + The original implementation of _apply_top_k_top_p is more general + as it uses vector topp, topk + However in a lot of cases, topp and topk is same for all batch elements + For such "scalar" topp, topk cases, we can use this class + + The main optimizations in this class is: + Use topk instead of sort, which is much faster especially for small k. 
+ However just using topk might not suffice in cases as shown below + Consider a tensor: 9 9 8 8 8 8 7 7 7 + Topk, with k=5, on this yields 9 9 8 8 8 + The value "8" is on the boundary, hence the last "8" gets snipped off + However the original implementation accepts all the "8"s, + so it should output: + 9 9 8 8 8 8 (6 values, even though k=5) + To ensure these semantics, we perform topk with _padded_k elements + If we find more boundary elements left over, + then we keep incrementing _padded_k + and in future calls use the expanded value of __padded_k + + The increments to _padded_k should be done + with value > 1 to prevent excessive recompilations + due to dynamic shapes (the output shape of the topk) + + The main logic of this is in __call__ + This is a class instead of a function, just to keep track of + the monotonic non-decreasing state _padded_k + """ + _padded_k = 0 + + def __init__(self, increment: int): + self._increment = increment + + def __call__(self, logits: torch.Tensor, p: float, k: int): + if k > ApplyToppTopkScalar._padded_k: + ApplyToppTopkScalar._padded_k = min(k + self._increment, + logits.shape[1]) + + vals, idx = torch.topk(logits, k=ApplyToppTopkScalar._padded_k, \ + dim=1, sorted=True) + + # this "if" checks if we have bucketed so much that + # we have padded k upto shape of logits + if ApplyToppTopkScalar._padded_k != logits.shape[1]: + smallest_of_top_k = vals[:, k - 1] + num_duplicates_of_smallest_of_topk = torch.sum( + logits == smallest_of_top_k.unsqueeze(1), 1) + max_num_duplicates_of_smallest_of_topk = torch.max( + num_duplicates_of_smallest_of_topk).item() + + # there are n repeats for a border + # (border meaning the smallest value of the top k). + # we do not know if only 1 or 2 or (n-1) + # of them lie outside the kth border, + # so we choose to conservatively increase by n-1 + # when num_duplicates > _padded_k - k + if max_num_duplicates_of_smallest_of_topk - 1 > ( + ApplyToppTopkScalar._padded_k - k): + incr = int( + math.ceil((max_num_duplicates_of_smallest_of_topk - 1) / + self._increment) * self._increment) + # this while loop should be traversed at most twice, + # because we dont increment by self._increment and retry + # instead we compute incr in one go + ApplyToppTopkScalar._padded_k = min( + ApplyToppTopkScalar._padded_k + incr, logits.shape[1]) + + # recompute topk with expanded padded_k + vals, idx = torch.topk(logits, \ + k=ApplyToppTopkScalar._padded_k, \ + dim=1, sorted=True) + + idx = torch.fliplr(idx) + vals = torch.fliplr(vals) + + top_k_smallest_val_idx = vals.size(1) - k + top_k_mask = vals[:, top_k_smallest_val_idx].unsqueeze(1) + top_k_mask = vals < top_k_mask + vals.masked_fill_(top_k_mask, -float("inf")) + + probs_sort = vals.softmax(dim=-1) + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum <= (1 - p) + top_p_mask[:, -1] = False + vals.masked_fill_(top_p_mask, -float("inf")) + + new_logits = torch.full(logits.shape, + -float("inf"), + device=logits.device) + new_logits.scatter_(1, idx, vals.to(new_logits.dtype)) + + return new_logits + + def _apply_min_tokens_penalty( logits: torch.Tensor, sampling_metadata: SamplingMetadata, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 0cb373441f869..51f0f9f5e06db 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -85,6 +85,7 @@ def device_loading_context(module: torch.nn.Module, p.data = p.data.to(original_device) # New parameters or parameters already on 
target device are untouched +from vllm.utils import is_fake_hpu logger = init_logger(__name__) @@ -353,7 +354,10 @@ def load_model(self, *, model_config: ModelConfig, cache_config: CacheConfig) -> nn.Module: target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): - with torch.device(self.load_config.device): + _device = torch.device( + device_config.device) if is_fake_hpu() else torch.device( + self.load_config.device) + with _device: model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 0052489d99dc4..506a0e197fb2b 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,7 +23,7 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. - mixtral_supported = ["fp8", "compressed-tensors"] + mixtral_supported = ["fp8", "compressed-tensors", "inc"] # for gptq_marlin, only run fused MoE for int4 if model_config.quantization == "gptq_marlin": hf_quant_config = getattr(model_config.hf_config, diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index efa044d0b5e92..3ae9003dfa3b7 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -131,13 +131,11 @@ def __init__(self, torch.empty(self.num_experts, 2 * self.intermediate_size, self.hidden_size, - device="cuda", dtype=self.params_dtype)) self.w2s = nn.Parameter( torch.empty(self.num_experts, self.hidden_size, self.intermediate_size, - device="cuda", dtype=self.params_dtype)) set_weight_attrs(self.ws, { "weight_loader": self.weight_loader, diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 6160197dc19de..e81de181815b3 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -82,21 +82,15 @@ def __init__( self.router = DbrxRouter(config, self.params_dtype) self.ws = nn.Parameter( - torch.empty( - self.num_total_experts, - 2 * self.intermediate_size, - self.d_model, - device="cuda", - dtype=self.params_dtype, - )) + torch.empty(self.num_total_experts, + 2 * self.intermediate_size, + self.d_model, + dtype=self.params_dtype)) self.w2s = nn.Parameter( - torch.empty( - self.num_total_experts, - self.d_model, - self.intermediate_size, - device="cuda", - dtype=self.params_dtype, - )) + torch.empty(self.num_total_experts, + self.d_model, + self.intermediate_size, + dtype=self.params_dtype)) set_weight_attrs( self.ws, diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 88d2bcb9f0c9d..47ec718a43420 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -254,7 +254,6 @@ def forward( if self.project_in is not None: inputs_embeds, _ = self.project_in(inputs_embeds) hidden_states = inputs_embeds + pos_embeds - for i in range(len(self.layers)): layer = self.layers[i] hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index a64e08c422bc3..649caba5d9424 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -47,7 +47,7 @@ default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import 
SamplingMetadata from vllm.sequence import IntermediateTensors - +from vllm.platform import current_platform from .interfaces import SupportsLoRA from .utils import is_pp_missing_parameter, make_layers @@ -272,6 +272,9 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + if current_platform.is_hpu(): + import habana_frameworks.torch as htorch + htorch.core.mark_step() for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( @@ -281,11 +284,15 @@ def forward( attn_metadata, residual, ) + if current_platform.is_hpu(): + htorch.core.mark_step() + if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, "residual": residual }) + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states diff --git a/vllm/utils.py b/vllm/utils.py index ed565d3244541..a8824890e3a29 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -334,6 +334,52 @@ def is_neuron() -> bool: return transformers_neuronx is not None +@lru_cache(maxsize=None) +def is_hpu() -> bool: + return _is_habana_frameworks_installed() or _is_built_for_hpu() + + +@lru_cache(maxsize=None) +def is_fake_hpu() -> bool: + return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0' + + +@lru_cache(maxsize=None) +def hpu_device_string(): + device_string = 'hpu' if not is_fake_hpu() else 'cpu' + return device_string + + +@lru_cache(maxsize=None) +def hpu_backend_string(): + backend_string = 'hccl' if not is_fake_hpu() else 'gloo' + return backend_string + + +@lru_cache(maxsize=None) +def _is_habana_frameworks_installed() -> bool: + from importlib import util + return util.find_spec('habana_frameworks') is not None + + +@lru_cache(maxsize=None) +def _is_built_for_hpu() -> bool: + from importlib.metadata import PackageNotFoundError, version + try: + return "gaudi" in version("vllm") + except PackageNotFoundError: + return False + + +@lru_cache(maxsize=None) +def is_tpu() -> bool: + try: + import libtpu + except ImportError: + libtpu = None + return libtpu is not None + + @lru_cache(maxsize=None) def is_xpu() -> bool: from importlib.metadata import PackageNotFoundError, version @@ -754,18 +800,24 @@ def __init__(self, device=None): @staticmethod def current_device_memory_usage() -> float: + if is_fake_hpu(): + return 0 # Return the device memory usage in bytes. free_hpu_memory, total_hpu_memory = torch.hpu.mem_get_info() return total_hpu_memory - free_hpu_memory @staticmethod def current_free_device_memory() -> float: + if is_fake_hpu(): + return 0 # Return the device memory usage in bytes. free_hpu_memory, _ = torch.hpu.mem_get_info() return free_hpu_memory @staticmethod def total_device_memory() -> float: + if is_fake_hpu(): + return 0 # Return the device memory usage in bytes. 
_, total_hpu_memory = torch.hpu.mem_get_info() return total_hpu_memory @@ -1353,3 +1405,28 @@ def dec(self, num=1): @property def value(self): return self._value + +def migrate_to_cpu(): + import importlib + from unittest.mock import MagicMock + + torch.hpu = MagicMock(name="torch.hpu") + + # Adding dummy submodules to habana_frameworks.torch for cpu-test, + # functions from dummy modules will do nothing by default + spec = importlib.util.spec_from_loader('habana_frameworks', loader=None) + sys.modules['habana_frameworks'] = MagicMock() + sys.modules['habana_frameworks'].__spec__ = spec + + builtin_import = __builtins__['__import__'] # type: ignore + + def import_wrapper(name, *args, **kwargs): + if 'habana_frameworks' in name: + sys.modules[name] = MagicMock() + return builtin_import(name, *args, **kwargs) + + __builtins__['__import__'] = import_wrapper + + # In case you want to mock a function to actually do something + import habana_frameworks.torch as htorch + htorch.utils.internal.is_lazy.return_value = False diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index ec0b8c2369210..f678d44f71dd3 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -6,7 +6,7 @@ from vllm.attention import get_attn_backend from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, is_fake_hpu, is_pin_memory_available) logger = init_logger(__name__) @@ -78,7 +78,7 @@ def _allocate_kv_cache( pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] for _ in range(self.num_attention_layers): - if device == 'hpu': + if device == 'hpu' or is_fake_hpu(): key_cache = torch.zeros(kv_cache_shape, dtype=self.dtype, device=device) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 5336ad3ed4da9..e8bf5dfb34628 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -20,13 +20,13 @@ import habana_frameworks.torch as htorch import habana_frameworks.torch.internal.bridge_config as bc import torch +from vllm_hpu_extension.ops import LoraMask as LoraMask from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.distributed.parallel_state import get_world_group -from vllm.hpu.ops import LoraMask as LoraMask from vllm.inputs.registry import InputRegistry from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping @@ -35,11 +35,13 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model from vllm.multimodal.registry import MultiModalRegistry +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, + MultiModalInputs) from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.utils import (HabanaMemoryProfiler, format_bytes, +from vllm.utils import (HabanaMemoryProfiler, format_bytes, is_fake_hpu, is_pin_memory_available, make_tensor_with_pad) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, @@ -94,7 +96,7 @@ def 
read_bucket_settings(phase: str, dim: str, **defaults): values = [ int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values) ] - for e, v, d in zip(env_vars, values, defaults): + for e, v, d in zip(env_vars, values, default_values): logger.info('%s=%s (default:%s)', e, v, d) return values @@ -175,11 +177,16 @@ def generate_prompt_buckets(bs_bucket_config, def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, max_blocks): buckets = [] - for bs in warmup_range(bs_bucket_config): - for blocks in warmup_range(blocks_bucket_config): + bs_buckets = warmup_range(bs_bucket_config) + block_buckets = warmup_range(blocks_bucket_config) + bmin, bstep, bmax = blocks_bucket_config + last_bucket = max_blocks if (max_blocks // bstep + == 0) else (max_blocks // bstep + 1) * bstep + for bs in bs_buckets: + for blocks in block_buckets: if blocks < bs: continue - if blocks > max_blocks: + if blocks > last_bucket: break buckets.append((bs, blocks)) return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) @@ -248,7 +255,8 @@ def __init__(self, model, block_size, dtype, enforce_eager): '0').lower() in ['1', 'true'] self.block_size = block_size self.dtype = dtype - if not htorch.utils.internal.is_lazy() and not enforce_eager: + if not is_fake_hpu() and not htorch.utils.internal.is_lazy( + ) and not enforce_eager: self.model = torch.compile(self.model, backend='hpu_backend', dynamic=False) @@ -332,7 +340,7 @@ class PreparePromptMetadata(NamedTuple): lora_index_mapping: List[List[int]] lora_prompt_mapping: List[List[int]] lora_requests: Set[LoRARequest] - multi_modal_input: Optional[torch.Tensor] + multi_modal_kwargs: Optional[Dict[str, BatchedTensorInputs]] slot_mapping: List[List[int]] lora_mask: Optional[torch.Tensor] lora_logits_mask: Optional[torch.Tensor] @@ -347,7 +355,7 @@ def empty(cls): lora_index_mapping=[], lora_prompt_mapping=[], lora_requests=set(), - multi_modal_input=None, + multi_modal_kwargs=None, slot_mapping=[], lora_mask=None, lora_logits_mask=None) @@ -517,7 +525,9 @@ def __init__( if model_config is not None else None) self.device_config = (device_config if device_config is not None else DeviceConfig()) - + if is_fake_hpu(): + device_config.device = torch.device('cpu') + device_config.device_type = 'cpu' self.device = self.device_config.device self.enforce_eager = self.model_config.enforce_eager self.max_num_seqs = self.scheduler_config.max_num_seqs @@ -548,6 +558,30 @@ def __init__( self.seen_configs: set = set() self._mem_margin: Optional[int] = None self._setup_buckets() + self._set_gc_threshold() + + def _set_gc_threshold(self) -> None: + # Read https://docs.python.org/3/library/gc.html#gc.set_threshold + # for comprehensive description of gc generations. + # We can either use VLLM_GC_THR_GEN[0-2] (this has higher priority) + # to set particular generation threshold or use simpler + # VLLM_GC_THR_MULTIPLIER to multiply default values. 
+ default_gc_thrs = list(gc.get_threshold()) + requested_gc_thrs = [0] * len(default_gc_thrs) + for i in range(len(default_gc_thrs)): + requested_gc_thrs[i] = int( + os.environ.get(f'VLLM_GC_THR_GEN{i}', default_gc_thrs[i])) + if requested_gc_thrs == default_gc_thrs: + gc_thr_multiplier = int(os.environ.get('VLLM_GC_THR_MULTIPLIER', + 2)) + requested_gc_thrs = [ + t * gc_thr_multiplier for t in default_gc_thrs + ] + gc.set_threshold(*requested_gc_thrs) + + # Multi-modal data support + self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ + .create_input_mapper(self.model_config) def load_model(self) -> None: import habana_frameworks.torch.core as htcore @@ -599,7 +633,7 @@ def load_model(self) -> None: mark_only_scales_as_const=True) logger.info("Preparing model with INC took %s", m_inc.get_summary_string()) - else: + elif not is_fake_hpu(): self.model = self.model.to("hpu") htcore.mark_step() torch.hpu.synchronize() @@ -685,7 +719,7 @@ def _prepare_prompt( context_lens: List[int] = [] query_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] - multi_modal_input_list: List[torch.Tensor] = [] + multi_modal_inputs_list: List[MultiModalInputs] = [] if len(seq_group_metadata_list) == 0: return PreparePromptMetadata.empty() @@ -743,9 +777,10 @@ def _prepare_prompt( # is always the first token in the sequence. input_positions.append(list(range(context_len, seq_len))) - if seq_group_metadata.multi_modal_data: - multi_modal_input_list.append( - seq_group_metadata.multi_modal_data.data) + mm_data = seq_group_metadata.multi_modal_data + if mm_data: + mm_kwargs = self.multi_modal_input_mapper(mm_data) + multi_modal_inputs_list.append(mm_kwargs) if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized @@ -783,12 +818,6 @@ def _prepare_prompt( real_num_seqs = len(query_lens) assert max_query_len > 0 - if multi_modal_input_list: - multi_modal_input = torch.cat(multi_modal_input_list, - dim=0).to(self.device) - else: - multi_modal_input = None - max_prompt_len = max( find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg), self.block_size) @@ -870,6 +899,8 @@ def _prepare_prompt( num_decode_tokens=0, slot_mapping=slot_mapping, ) + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + return PreparePromptMetadata( input_tokens=input_tokens, input_positions=input_positions, @@ -879,7 +910,7 @@ def _prepare_prompt( lora_index_mapping=lora_index_mapping, lora_prompt_mapping=lora_prompt_mapping, lora_requests=lora_requests, - multi_modal_input=multi_modal_input, + multi_modal_kwargs=multi_modal_kwargs, slot_mapping=slot_mapping, lora_mask=lora_mask, lora_logits_mask=lora_logits_mask, @@ -944,10 +975,12 @@ def _prepare_decode( seq_lens.append(seq_len) block_table = seq_group_metadata.block_tables[seq_id] - block_number = block_table[position // self.block_size] - if block_number == _PAD_BLOCK_ID: + if len(block_table) == 0: + block_number = _PAD_BLOCK_ID + block_table = [] slot = next(dummy_slots) else: + block_number = block_table[position // self.block_size] block_offset = position % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append([slot]) @@ -972,7 +1005,7 @@ def _prepare_decode( num_decode_tokens = sum(seq_lens) - blocks_used = [len(bt) for bt in block_tables] + blocks_used = [len(bt) for bt in block_tables if bt] block_list = list(itertools.chain(*block_tables)) block_mapping_nested: List[List[int]] = [ [i] * b_u for i, b_u in enumerate(blocks_used) @@ -1042,7 +1075,7 @@ def 
prepare_input_tensors( input_positions = None lora_mapping = None lora_requests = None - multi_modal_input = None + multi_modal_kwargs = None batch_type = None seq_lens = None query_lens = None @@ -1060,8 +1093,11 @@ def prepare_input_tensors( batch_size_padded = find_bucket(real_batch_size, bucket_cfg) batch_size_padding = batch_size_padded - real_batch_size seq_group_metadata_list = seq_group_metadata_list.copy() - seq_group_metadata_list.extend(seq_group_metadata_list[0] - for _ in range(batch_size_padding)) + if batch_size_padding > 0: + dummy_seq_group_metadata = self.create_dummy_seq_group_metadata( + 0, 0, is_prompt) + seq_group_metadata_list.extend(dummy_seq_group_metadata + for _ in range(batch_size_padding)) prefill_reqs = [] decode_reqs = [] @@ -1081,7 +1117,7 @@ def prepare_input_tensors( lora_index_mapping, lora_prompt_mapping, lora_requests, - multi_modal_input, + multi_modal_kwargs, slot_mapping, lora_mask, lora_logits_mask, @@ -1165,7 +1201,7 @@ def prepare_input_tensors( "selected_token_indices": sampling_metadata.selected_token_indices, "lora_requests": lora_requests, "lora_mapping": lora_mapping, - "multi_modal_input": multi_modal_input, + "multi_modal_kwargs": multi_modal_kwargs, "num_prefill_tokens": num_prefill_tokens, "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, @@ -1190,7 +1226,7 @@ def prepare_input_tensors( attn_metadata=attn_metadata, lora_requests=lora_requests, lora_mapping=lora_mapping, - multi_modal_kwargs=multi_modal_input, + multi_modal_kwargs=multi_modal_kwargs, real_batch_size=real_batch_size, batch_size_padded=batch_size_padded, lora_mask=lora_mask, @@ -1800,7 +1836,6 @@ def execute_model( input_positions = model_input.input_positions attn_metadata = model_input.attn_metadata sampling_metadata = model_input.sampling_metadata - multi_modal_input = model_input.multi_modal_kwargs real_batch_size = model_input.real_batch_size batch_size_padded = model_input.batch_size_padded assert input_tokens is not None @@ -1819,10 +1854,9 @@ def execute_model( "kv_caches": kv_caches, "attn_metadata": self.trim_attn_metadata(attn_metadata), "intermediate_tensors": intermediate_tensors, - "lora_mask": model_input.lora_mask + "lora_mask": model_input.lora_mask, + **(model_input.multi_modal_kwargs or {}), } - if multi_modal_input is not None: - execute_model_kwargs.update(multi_modal_input) if htorch.utils.internal.is_lazy(): execute_model_kwargs.update({"bypass_hpu_graphs": not use_graphs}) diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py index 407c618a9d597..89c796068bac4 100644 --- a/vllm/worker/habana_worker.py +++ b/vllm/worker/habana_worker.py @@ -21,7 +21,8 @@ from vllm.model_executor import set_random_seed from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest -from vllm.utils import HabanaMemoryProfiler, format_bytes +from vllm.utils import (HabanaMemoryProfiler, format_bytes, hpu_backend_string, + hpu_device_string, is_fake_hpu) from vllm.worker.cache_engine import CacheEngine from vllm.worker.habana_model_runner import HabanaModelRunner from vllm.worker.model_runner_base import ModelRunnerBase @@ -109,6 +110,8 @@ def init_device(self) -> None: if self.device_config.device.type == "hpu": self.device = torch.device("hpu") torch.hpu.set_device(self.device) + elif self.device_config.device_type == "cpu": + self.device = torch.device("cpu") else: raise RuntimeError( f"Not support device type: {self.device_config.device}") @@ -142,6 +145,10 @@ def 
determine_num_available_blocks(self) -> Tuple[int, int]: # Execute a forward pass with dummy inputs to profile the memory usage # of the model. + if is_fake_hpu(): + cache_block_size = self.get_cache_block_size_bytes() + fake_hpu_cache_alloc = 4 * 2**30 # take 4 GiB flat on fake hpu + return fake_hpu_cache_alloc // cache_block_size, 0 with HabanaMemoryProfiler() as m: self.model_runner.profile_run() torch.hpu.synchronize() @@ -154,7 +161,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: cache_block_size = self.get_cache_block_size_bytes() graph_reserved_mem = (float( - os.environ.get('VLLM_GRAPH_RESERVED_MEM', '0.4')) + os.environ.get('VLLM_GRAPH_RESERVED_MEM', '0.1')) if not self.model_config.enforce_eager else 0) graph_headroom = 1 - graph_reserved_mem available_hpu_memory = free_hpu_memory * \ @@ -339,11 +346,12 @@ def init_worker_distributed_environment( local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" + backend = hpu_backend_string() init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank, - backend='hccl') + backend=backend) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) @@ -360,15 +368,17 @@ def init_worker_distributed_environment( "distributed_init_method must be set if torch.distributed " "is not already initialized") else: + backend = hpu_backend_string() torch.distributed.init_process_group( - backend="hccl", + backend=backend, world_size=parallel_config.world_size, rank=rank, init_method=distributed_init_method, ) # A small all_reduce for warmup & checking conformance. - dummy_tensor_hpu = torch.ones(1).to('hpu') + device = hpu_device_string() + dummy_tensor_hpu = torch.ones(1).to(device) torch.distributed.all_reduce(dummy_tensor_hpu) assert dummy_tensor_hpu.item() == parallel_config.world_size ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
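
For readers skimming the patch, the padded top-k trick used by ApplyToppTopkScalar earlier in this diff can be seen in isolation in the sketch below. This is a standalone illustration, not code added by the patch: the class name PaddedTopKTopP, the example shapes, and the exact way values are scattered back are assumptions chosen for brevity; only the overall mechanism mirrors the diff, namely a class-level, monotonically growing _padded_k that is bumped in multiples of `increment` so the set of distinct torch.topk output shapes (and hence recompilations on graph-compiled backends) stays small while boundary ties with the k-th largest logit are still kept.

import math

import torch


class PaddedTopKTopP:
    """Standalone sketch of top-k/top-p filtering with a cached, growing k.

    torch.topk with a different k produces a different output shape, which
    forces recompilation on graph-compiled backends. A class-level _padded_k
    that only grows, and only in multiples of `increment`, bounds the number
    of distinct shapes while keeping every logit tied with the k-th largest
    value (the "boundary" ties described in the docstring in the diff above).
    """

    _padded_k = 0  # shared across instances on purpose: it must only grow

    def __init__(self, increment: int):
        self._increment = increment

    def __call__(self, logits: torch.Tensor, p: float, k: int) -> torch.Tensor:
        cls = PaddedTopKTopP
        vocab = logits.shape[1]
        if k > cls._padded_k:
            cls._padded_k = min(k + self._increment, vocab)

        vals, idx = torch.topk(logits, k=cls._padded_k, dim=1, sorted=True)

        if cls._padded_k != vocab:
            # How many entries tie with the k-th largest value per row?
            kth = vals[:, k - 1].unsqueeze(1)
            max_ties = torch.sum(logits == kth, dim=1).max().item()
            # If the ties might spill past the current padding, grow it
            # (rounded up to a multiple of increment) and redo topk once.
            if max_ties - 1 > cls._padded_k - k:
                grow = math.ceil(
                    (max_ties - 1) / self._increment) * self._increment
                cls._padded_k = min(cls._padded_k + grow, vocab)
                vals, idx = torch.topk(logits,
                                       k=cls._padded_k,
                                       dim=1,
                                       sorted=True)

        # Ascending order makes the cumulative-sum top-p mask simple.
        vals, idx = torch.fliplr(vals), torch.fliplr(idx)

        boundary = vals[:, vals.size(1) - k].unsqueeze(1)
        # Top-k mask: drop everything strictly below the boundary, keep ties.
        vals = vals.masked_fill(vals < boundary, -float("inf"))

        probs_sum = vals.softmax(dim=-1).cumsum(dim=-1)
        top_p_mask = probs_sum <= (1 - p)
        top_p_mask[:, -1] = False  # never drop the most likely token
        vals = vals.masked_fill(top_p_mask, -float("inf"))

        # Scatter the surviving values back into a full-vocabulary tensor.
        out = torch.full_like(logits, -float("inf"))
        return out.scatter(1, idx, vals)

Under these assumptions a single instance would be reused across sampling steps, e.g. sampler = PaddedTopKTopP(increment=32) followed by masked = sampler(torch.randn(4, 32000), p=0.9, k=50); because _padded_k never shrinks, later calls with the same or smaller k reuse the already-captured topk shape instead of triggering a new compilation.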