From b40cf6402e356a10415e969e648a32911fb9b8ec Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 15 Nov 2024 12:23:09 +0800 Subject: [PATCH] [Model] Support Qwen2 embeddings and use tags to select model tests (#10184) --- .buildkite/run-cpu-test-ppc64le.sh | 6 +- .buildkite/run-cpu-test.sh | 6 +- .buildkite/test-pipeline.yaml | 48 ++++---- docs/source/models/supported_models.rst | 13 +- .../decoder_only/language/test_jamba.py | 18 +-- .../decoder_only/language/test_mamba.py | 18 +-- .../decoder_only/language/test_models.py | 71 ++++++----- .../embedding/language/test_cls_models.py | 30 ++--- .../embedding/language/test_embedding.py | 42 +++---- .../vision_language/test_llava_next.py | 2 + .../embedding/vision_language/test_phi3v.py | 2 + .../encoder_decoder/language/test_bart.py | 11 +- .../vision_language/test_mllama.py | 3 + tests/models/registry.py | 4 + tests/models/test_registry.py | 4 +- vllm/model_executor/models/qwen2.py | 112 ++++++++++++++++-- vllm/model_executor/models/qwen2_cls.py | 15 +-- vllm/model_executor/models/qwen2_rm.py | 16 +-- vllm/model_executor/models/registry.py | 9 +- 19 files changed, 252 insertions(+), 178 deletions(-) diff --git a/.buildkite/run-cpu-test-ppc64le.sh b/.buildkite/run-cpu-test-ppc64le.sh index 79526adef2a79..5d7a0bff90963 100755 --- a/.buildkite/run-cpu-test-ppc64le.sh +++ b/.buildkite/run-cpu-test-ppc64le.sh @@ -27,9 +27,9 @@ function cpu_tests() { decord einops librosa peft Pillow sentence-transformers soundfile \ transformers_stream_generator matplotlib datamodel_code_generator pip install torchvision --index-url https://download.pytorch.org/whl/cpu - pytest -v -s tests/models/embedding/language - pytest -v -s tests/models/encoder_decoder/language - pytest -v -s tests/models/decoder_only/language/test_models.py + pytest -v -s tests/models/decoder_only/language -m cpu_model + pytest -v -s tests/models/embedding/language -m cpu_model + pytest -v -s tests/models/encoder_decoder/language -m cpu_model pytest -v -s tests/models/decoder_only/audio_language -m cpu_model pytest -v -s tests/models/decoder_only/vision_language -m cpu_model" diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index a00331abb7d03..14756b5964aaf 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -38,9 +38,9 @@ function cpu_tests() { decord einops librosa peft Pillow sentence-transformers soundfile \ transformers_stream_generator matplotlib datamodel_code_generator pip install torchvision --index-url https://download.pytorch.org/whl/cpu - pytest -v -s tests/models/embedding/language - pytest -v -s tests/models/encoder_decoder/language - pytest -v -s tests/models/decoder_only/language/test_models.py + pytest -v -s tests/models/decoder_only/language -m cpu_model + pytest -v -s tests/models/embedding/language -m cpu_model + pytest -v -s tests/models/encoder_decoder/language -m cpu_model pytest -v -s tests/models/decoder_only/audio_language -m cpu_model pytest -v -s tests/models/decoder_only/vision_language -m cpu_model" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index baad54eaf6a91..24bf223fb12c0 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -323,62 +323,60 @@ steps: - pytest -v -s models/test_registry.py - pytest -v -s models/test_initialization.py -- label: Decoder-only Language Models Test (Standard) # 18min +- label: Language Models Test (Standard) # 42min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ - tests/models/decoder_only/language + - 
tests/models/embedding/language + - tests/models/encoder_decoder/language commands: - - pytest -v -s models/decoder_only/language -m core_model - - pytest -v -s models/decoder_only/language -m quant_model + - pytest -v -s models/decoder_only/language -m 'core_model or quant_model' + - pytest -v -s models/embedding/language -m core_model + - pytest -v -s models/embedding/vision_language -m core_model -- label: Decoder-only Language Models Test (Extended) # 46min +- label: Language Models Test (Extended) # 50min nightly: true source_file_dependencies: - vllm/ - tests/models/decoder_only/language + - tests/models/embedding/language + - tests/models/encoder_decoder/language commands: - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' + - pytest -v -s models/embedding/language -m 'not core_model' + - pytest -v -s models/embedding/vision_language -m 'not core_model' -- label: Decoder-only Multi-Modal Models Test (Standard) # 22min +- label: Multi-Modal Models Test (Standard) # 26min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ - tests/models/decoder_only/audio_language - tests/models/decoder_only/vision_language + - tests/models/embedding/vision_language + - tests/models/encoder_decoder/vision_language commands: - - pytest -v -s models/decoder_only/audio_language -m core_model - - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m core_model - # No tests under this group for now - # - pytest -v -s models/decoder_only/audio_language -m quant_model - - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m quant_model + - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' + - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' + - pytest -v -s models/encoder_decoder/language -m core_model + - pytest -v -s models/encoder_decoder/vision_language -m core_model -- label: Decoder-only Multi-Modal Models Test (Extended) # 1h10m +- label: Multi-Modal Models Test (Extended) # 1h15m nightly: true source_file_dependencies: - vllm/ - tests/models/decoder_only/audio_language - tests/models/decoder_only/vision_language + - tests/models/embedding/vision_language + - tests/models/encoder_decoder/vision_language commands: - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model' # HACK - run phi3v tests separately to sidestep this transformers bug # https://github.com/huggingface/transformers/issues/34307 - pytest -v -s models/decoder_only/vision_language/test_phi3v.py - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model' - -- label: Other Models Test # 20min - #mirror_hardwares: [amd] - source_file_dependencies: - - vllm/ - - tests/models/embedding/language - - tests/models/embedding/vision_language - - tests/models/encoder_decoder/language - - tests/models/encoder_decoder/vision_language - commands: - - pytest -v -s models/embedding/language - - pytest -v -s models/embedding/vision_language - - pytest -v -s models/encoder_decoder/language - - pytest -v -s models/encoder_decoder/vision_language + - pytest -v -s models/encoder_decoder/language -m 'not core_model' + - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model' # This test is used only in PR development phase to test individual models and 
should never run on main - label: Custom Models Test diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 161733c049bbe..a76bb775c6ee6 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -330,11 +330,16 @@ Text Embedding - :code:`BAAI/bge-multilingual-gemma2`, etc. - - ✅︎ - * - :code:`MistralModel` - - Mistral-based + * - :code:`LlamaModel`, :code:`LlamaForCausalLM`, :code:`MistralModel`, etc. + - Llama-based - :code:`intfloat/e5-mistral-7b-instruct`, etc. - ✅︎ - ✅︎ + * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM` + - Qwen2-based + - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`, etc. + - ✅︎ + - ✅︎ .. important:: Some model architectures support both generation and embedding tasks. @@ -355,7 +360,7 @@ Reward Modeling * - :code:`Qwen2ForRewardModel` - Qwen2-based - :code:`Qwen/Qwen2.5-Math-RM-72B`, etc. - - + - ✅︎ - ✅︎ .. note:: @@ -376,7 +381,7 @@ Classification * - :code:`Qwen2ForSequenceClassification` - Qwen2-based - :code:`jason9693/Qwen2.5-1.5B-apeach`, etc. - - + - ✅︎ - ✅︎ .. note:: diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py index 384ec77e5455a..6542689c3f277 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -33,6 +33,10 @@ def test_models( with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] @@ -293,17 +297,3 @@ def test_jamba_distributed_produces_identical_generation( name_0="vllm_tp_1", name_1="vllm_tp_2", ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_model_print( - vllm_runner, - model: str, - dtype: str, -) -> None: - with vllm_runner(model, dtype=dtype) as vllm_model: - # This test is for verifying whether the model's extra_repr - # can be printed correctly. - print(vllm_model.model.llm_engine.model_executor.driver_worker. - model_runner.model) diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index 2dc231c595ffa..78eab8d5354fd 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -51,6 +51,10 @@ def test_models( with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] @@ -279,17 +283,3 @@ def test_state_cleanup( except ValueError: pytest.fail("Mamba inner state wasn't cleaned up between states, " "could be related to finished_requests_ids") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_model_print( - vllm_runner, - model: str, - dtype: str, -) -> None: - with vllm_runner(model, dtype=dtype) as vllm_model: - # This test is for verifying whether the model's extra_repr - # can be printed correctly. 
- print(vllm_model.model.llm_engine.model_executor.driver_worker. - model_runner.model) diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index beb1ffb18436e..2a7ed8826d2f3 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -4,37 +4,52 @@ """ import pytest -from vllm.platforms import current_platform - from ...utils import check_logprobs_close -MODELS = [ - "facebook/opt-125m", # opt - "openai-community/gpt2", # gpt2 - # "Milos/slovak-gpt-j-405M", # gptj - # "bigcode/tiny_starcoder_py", # gpt_bigcode - # "EleutherAI/pythia-70m", # gpt_neox - "bigscience/bloom-560m", # bloom - testing alibi slopes - "microsoft/phi-2", # phi - # "stabilityai/stablelm-3b-4e1t", # stablelm - # "bigcode/starcoder2-3b", # starcoder2 - "google/gemma-1.1-2b-it", # gemma - "Qwen/Qwen2.5-0.5B-Instruct", # qwen2 - "meta-llama/Llama-3.2-1B-Instruct", # llama -] - -if not current_platform.is_cpu(): - MODELS += [ - # fused_moe which not supported on CPU - "openbmb/MiniCPM3-4B", - ] - -target_dtype = "half" - -@pytest.mark.core_model -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize( + "model", + [ + pytest.param( + "bigscience/bloom-560m", # bloom - testing alibi slopes + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), + pytest.param( + "openai-community/gpt2", # gpt2 + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), + pytest.param("Milos/slovak-gpt-j-405M"), # gptj + pytest.param("bigcode/tiny_starcoder_py"), # gpt_bigcode + pytest.param("EleutherAI/pythia-70m"), # gpt_neox + pytest.param( + "google/gemma-1.1-2b-it", # gemma + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), + pytest.param( + "meta-llama/Llama-3.2-1B-Instruct", # llama + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), + pytest.param( + "openbmb/MiniCPM3-4B", + # fused_moe not supported on CPU + marks=[pytest.mark.core_model], + ), + pytest.param( + "facebook/opt-125m", # opt + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), + pytest.param( + "microsoft/phi-2", # phi + marks=[pytest.mark.core_model], + ), + pytest.param( + "Qwen/Qwen2.5-0.5B-Instruct", # qwen2 + marks=[pytest.mark.core_model], + ), + pytest.param("stabilityai/stablelm-3b-4e1t"), # stablelm + pytest.param("bigcode/starcoder2-3b"), # starcoder2 + ]) +@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models( diff --git a/tests/models/embedding/language/test_cls_models.py b/tests/models/embedding/language/test_cls_models.py index 40ee49cf60742..6321503e7b248 100644 --- a/tests/models/embedding/language/test_cls_models.py +++ b/tests/models/embedding/language/test_cls_models.py @@ -9,10 +9,14 @@ import torch from transformers import AutoModelForSequenceClassification -CLASSIFICATION_MODELS = ["jason9693/Qwen2.5-1.5B-apeach"] - -@pytest.mark.parametrize("model", CLASSIFICATION_MODELS) +@pytest.mark.parametrize( + "model", + [ + pytest.param("jason9693/Qwen2.5-1.5B-apeach", + marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + ], +) @pytest.mark.parametrize("dtype", ["float"]) def test_classification_models( hf_runner, @@ -23,31 +27,19 @@ def test_classification_models( ) -> None: with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.classify(example_prompts) + # This test is for verifying whether the model's 
extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSequenceClassification) as hf_model: hf_outputs = hf_model.classify(example_prompts) - print(hf_outputs, vllm_outputs) - # check logits difference for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): hf_output = torch.tensor(hf_output) vllm_output = torch.tensor(vllm_output) assert torch.allclose(hf_output, vllm_output, 1e-3) - - -@pytest.mark.parametrize("model", CLASSIFICATION_MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_classification_model_print( - vllm_runner, - model: str, - dtype: str, -) -> None: - with vllm_runner(model, dtype=dtype) as vllm_model: - # This test is for verifying whether the model's extra_repr - # can be printed correctly. - print(vllm_model.model.llm_engine.model_executor.driver_worker. - model_runner.model) diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py index fcdd684168d04..c3f351ef707be 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/embedding/language/test_embedding.py @@ -4,25 +4,25 @@ """ import pytest -from vllm.utils import current_platform - from ..utils import check_embeddings_close -# Model, Guard -MODELS = [ - "intfloat/e5-mistral-7b-instruct", - "BAAI/bge-base-en-v1.5", - "BAAI/bge-multilingual-gemma2", - "intfloat/multilingual-e5-large", -] - -ENCODER_ONLY = [ - "BAAI/bge-base-en-v1.5", - "intfloat/multilingual-e5-large", -] - -@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize( + "model", + [ + # [Encoder-only] + pytest.param("BAAI/bge-base-en-v1.5", + marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + pytest.param("intfloat/multilingual-e5-large"), + # [Encoder-decoder] + pytest.param("intfloat/e5-mistral-7b-instruct", + marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + pytest.param("BAAI/bge-multilingual-gemma2", + marks=[pytest.mark.core_model]), + pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"), + pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"), + ], +) @pytest.mark.parametrize("dtype", ["half"]) def test_models( hf_runner, @@ -31,9 +31,6 @@ def test_models( model, dtype: str, ) -> None: - if model not in ENCODER_ONLY and current_platform.is_cpu(): - pytest.skip("Skip large embedding models test on CPU.") - # The example_prompts has ending "\n", for example: # "Write a short story about a robot that dreams for the first time.\n" # sentence_transformers will strip the input texts, see: @@ -46,8 +43,13 @@ def test_models( is_sentence_transformer=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) - with vllm_runner(model, dtype=dtype, max_model_len=None) as vllm_model: + with vllm_runner(model, task="embedding", dtype=dtype, + max_model_len=None) as vllm_model: vllm_outputs = vllm_model.encode(example_prompts) + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. 
+ model_runner.model) check_embeddings_close( embeddings_0_lst=hf_outputs, diff --git a/tests/models/embedding/vision_language/test_llava_next.py b/tests/models/embedding/vision_language/test_llava_next.py index 9fab5898a06ba..329c6ba279f89 100644 --- a/tests/models/embedding/vision_language/test_llava_next.py +++ b/tests/models/embedding/vision_language/test_llava_next.py @@ -88,6 +88,7 @@ def _run_test( @pytest.mark.skipif(transformers.__version__.startswith("4.46"), reason="Model broken with changes in transformers 4.46") +@pytest.mark.core_model @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) def test_models_text( @@ -112,6 +113,7 @@ def test_models_text( @large_gpu_test(min_gb=48) +@pytest.mark.core_model @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) def test_models_image( diff --git a/tests/models/embedding/vision_language/test_phi3v.py b/tests/models/embedding/vision_language/test_phi3v.py index ee411472ba284..6145aff1a5ea2 100644 --- a/tests/models/embedding/vision_language/test_phi3v.py +++ b/tests/models/embedding/vision_language/test_phi3v.py @@ -74,6 +74,7 @@ def _run_test( ) +@pytest.mark.core_model @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) def test_models_text( @@ -98,6 +99,7 @@ def test_models_text( @large_gpu_test(min_gb=48) +@pytest.mark.core_model @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) def test_models_image( diff --git a/tests/models/encoder_decoder/language/test_bart.py b/tests/models/encoder_decoder/language/test_bart.py index 8e8862fadbf04..10aba8427944f 100644 --- a/tests/models/encoder_decoder/language/test_bart.py +++ b/tests/models/encoder_decoder/language/test_bart.py @@ -14,8 +14,6 @@ from ....utils import multi_gpu_test from ...utils import check_logprobs_close -MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"] - def vllm_to_hf_output( vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], @@ -170,7 +168,14 @@ def run_test( ) -@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize( + "model", + [ + pytest.param("facebook/bart-base", + marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + pytest.param("facebook/bart-large-cnn"), + ], +) @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py index a3b1c0950d9a2..77dd1d81f84d7 100644 --- a/tests/models/encoder_decoder/vision_language/test_mllama.py +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -233,6 +233,7 @@ def clear_cache(): @large_gpu_test(min_gb=48) +@pytest.mark.core_model @pytest.mark.parametrize("model", models) @pytest.mark.parametrize( "sizes", @@ -278,6 +279,7 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets, @large_gpu_test(min_gb=48) +@pytest.mark.core_model @pytest.mark.parametrize("model", models) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [128]) @@ -326,6 +328,7 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, @large_gpu_test(min_gb=48) +@pytest.mark.core_model @pytest.mark.parametrize("model", models) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [128]) diff --git a/tests/models/registry.py 
b/tests/models/registry.py index ec9ff52d112df..3848367b6126c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -129,9 +129,13 @@ class _HfExamplesInfo: # [Text-only] "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), + "LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), + "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"), "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 + "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501 + "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-large"), # [Multimodal] "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index dbc415796ee55..e462dae3dc688 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -77,8 +77,8 @@ def test_registry_is_pp(model_arch, is_pp, init_cuda): def test_hf_registry_coverage(): - untested_archs = (HF_EXAMPLE_MODELS.get_supported_archs() - - set(ModelRegistry.get_supported_archs())) + untested_archs = (ModelRegistry.get_supported_archs() - + HF_EXAMPLE_MODELS.get_supported_archs()) assert not untested_archs, ( "Please add the following architectures to " diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index b623c576bb673..431e397e1e10d 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -37,6 +37,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -44,8 +45,9 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors +from vllm.sequence import IntermediateTensors, PoolerOutput from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, @@ -247,6 +249,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + # TODO (@robertgshaw2): see if this can be moved out + if (cache_config.sliding_window is not None + and hasattr(config, "max_window_layers")): + raise ValueError("Sliding window for some but all layers is not " + "supported. This model uses sliding window " + "but `max_window_layers` = {} is less than " + "`num_hidden_layers` = {}. 
Please open an issue " + "to discuss this feature.".format( + config.max_window_layers, + config.num_hidden_layers, + )) + self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -405,20 +419,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config - # TODO (@robertgshaw2): see if this can be moved out - if (cache_config.sliding_window is not None - and hasattr(config, "max_window_layers")): - raise ValueError("Sliding window for some but all layers is not " - "supported. This model uses sliding window " - "but `max_window_layers` = {} is less than " - "`num_hidden_layers` = {}. Please open an issue " - "to discuss this feature.".format( - config.max_window_layers, - config.num_hidden_layers, - )) + pooler_config = vllm_config.model_config.pooler_config self.config = config self.lora_config = lora_config @@ -438,6 +441,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() + + # The same model class supports both language generation and embedding + # because the architecture name is the same + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=True, + softmax=False) + self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -475,6 +487,13 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loader = AutoWeightsLoader( self, @@ -482,3 +501,70 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if self.config.tie_word_embeddings else None), ) loader.load_weights(weights) + + +class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + pooler_config = vllm_config.model_config.pooler_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.MEAN, + normalize=True, + softmax=False) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + return self.model(input_ids, positions, kv_caches, attn_metadata, + intermediate_tensors) + + def pooler( + self, + hidden_states: 
torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self, + ignore_unexpected_prefixes=["lm_head."]) + loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index 27eb7e8a93975..120403e948686 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -17,10 +17,11 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput +from .interfaces import SupportsLoRA, SupportsPP from .utils import AutoWeightsLoader, maybe_prefix -class Qwen2ForSequenceClassification(nn.Module): +class Qwen2ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -46,21 +47,9 @@ class Qwen2ForSequenceClassification(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config pooler_config = vllm_config.model_config.pooler_config - # TODO (@robertgshaw2): see if this can be moved out - if (cache_config.sliding_window is not None - and hasattr(config, "max_window_layers")): - raise ValueError("Sliding window for some but all layers is not " - "supported. This model uses sliding window " - "but `max_window_layers` = {} is less than " - "`num_hidden_layers` = {}. Please open an issue " - "to discuss this feature.".format( - config.max_window_layers, - config.num_hidden_layers, - )) self.config = config self.lora_config = lora_config diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 89768ec9dff37..55843d8325348 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -16,7 +16,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput -from .interfaces import SupportsPP +from .interfaces import SupportsLoRA, SupportsPP from .qwen2 import Qwen2Model from .utils import AutoWeightsLoader, maybe_prefix @@ -32,7 +32,7 @@ def forward(self, input): return self.activation(input) -class Qwen2ForRewardModel(nn.Module, SupportsPP): +class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -58,21 +58,9 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config pooler_config = vllm_config.model_config.pooler_config - # TODO (@robertgshaw2): see if this can be moved out - if (cache_config.sliding_window is not None - and hasattr(config, "max_window_layers")): - raise ValueError("Sliding window for some but all layers is not " - "supported. This model uses sliding window " - "but `max_window_layers` = {} is less than " - "`num_hidden_layers` = {}. 
Please open an issue " - "to discuss this feature.".format( - config.max_window_layers, - config.num_hidden_layers, - )) self.config = config self.lora_config = lora_config diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c0d503a1c5ba2..22c2e328bfb65 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -11,7 +11,8 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from functools import lru_cache -from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type, + TypeVar, Union) import cloudpickle import torch.nn as nn @@ -110,6 +111,8 @@ }, "MistralModel": ("llama", "LlamaEmbeddingModel"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), + "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"), + "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), # noqa: E501 # [Multimodal] @@ -301,8 +304,8 @@ class _ModelRegistry: # Keyed by model_arch models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict) - def get_supported_archs(self) -> List[str]: - return list(self.models.keys()) + def get_supported_archs(self) -> AbstractSet[str]: + return self.models.keys() def register_model( self,
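
Usage note (not part of the diff): with the `Qwen2Model` / `Qwen2ForCausalLM` entries registered for embedding above, a Qwen2 checkpoint can be loaded as an embedding model through the same interface the new test exercises (tests/models/embedding/language/test_embedding.py passes task="embedding" to the runner and calls encode()). The snippet below is a minimal sketch, assuming vLLM's public LLM(task="embedding") argument and LLM.encode() API with EmbeddingRequestOutput results; the checkpoint name is one of the examples added to supported_models.rst.

    # Minimal sketch, not part of this patch: embed prompts with a Qwen2-based
    # embedding model using the support added above. Assumes the public
    # LLM(task="embedding") argument and LLM.encode() API.
    from vllm import LLM

    llm = LLM(model="ssmits/Qwen2-7B-Instruct-embed-base", task="embedding")
    outputs = llm.encode([
        "Write a short story about a robot that dreams for the first time.",
        "The capital of France is",
    ])
    for output in outputs:
        # output.outputs.embedding is assumed to be a list of floats
        print(len(output.outputs.embedding))

The re-tagged tests themselves are selected with the pytest markers introduced in this patch, e.g. `pytest -v -s tests/models/embedding/language -m core_model` or `-m cpu_model`, matching the commands in the updated Buildkite pipelines.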