From ecb818fd25c65e1b543ca41516f1e96f2b0e8e80 Mon Sep 17 00:00:00 2001
From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
Date: Tue, 31 Dec 2024 15:31:12 -0700
Subject: [PATCH 1/6] :arrow_up: Unpin vllm

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 330d01a..51f7702 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,7 @@ classifiers = [
 ]
 
 dependencies = [
-    "vllm>=0.6.2,<0.6.5"
+    "vllm>=0.6.2"
 ]
 
 [project.optional-dependencies]

From 0b2c9be562f88444064b219c5b733cce9a9d4e82 Mon Sep 17 00:00:00 2001
From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
Date: Tue, 31 Dec 2024 15:31:42 -0700
Subject: [PATCH 2/6] :white_check_mark::wrench: Update mock model configs

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
---
 tests/generative_detectors/test_base.py             | 13 +++++++++++++
 tests/generative_detectors/test_granite_guardian.py | 13 +++++++++++++
 tests/generative_detectors/test_llama_guard.py      | 13 +++++++++++++
 3 files changed, 39 insertions(+)

diff --git a/tests/generative_detectors/test_base.py b/tests/generative_detectors/test_base.py
index 60b10dd..c777bfd 100644
--- a/tests/generative_detectors/test_base.py
+++ b/tests/generative_detectors/test_base.py
@@ -1,4 +1,5 @@
 # Standard
+from typing import Optional
 from dataclasses import dataclass
 import asyncio
 
@@ -15,6 +16,9 @@
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
 BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 
+@dataclass
+class MockTokenizer:
+    type: Optional[str] = None
 
 @dataclass
 class MockHFConfig:
@@ -23,6 +27,7 @@ class MockModelConfig:
+    task = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
@@ -30,7 +35,13 @@ class MockModelConfig:
     tokenizer_revision = None
     embedding_mode = False
     multimodal_config = MultiModalConfig()
+    diff_sampling_param: Optional[dict] = None
     hf_config = MockHFConfig()
+    logits_processor_pattern = None
+    allowed_local_media_path: str = ""
+
+    def get_diff_sampling_param(self):
+        return self.diff_sampling_param or {}
 
 
 @dataclass
@@ -42,6 +53,7 @@ async def get_model_config(self):
 async def _async_serving_detection_completion_init():
     """Initialize a chat completion base with string templates"""
     engine = MockEngine()
+    engine.errored = False
     model_config = await engine.get_model_config()
 
     detection_completion = ChatCompletionDetectionBase(
@@ -52,6 +64,7 @@ async def _async_serving_detection_completion_init():
         base_model_paths=BASE_MODEL_PATHS,
         response_role="assistant",
         chat_template=CHAT_TEMPLATE,
+        chat_template_content_format="auto",
         lora_modules=None,
         prompt_adapters=None,
         request_logger=None,
diff --git a/tests/generative_detectors/test_granite_guardian.py b/tests/generative_detectors/test_granite_guardian.py
index 7afb02a..fadcdf9 100644
--- a/tests/generative_detectors/test_granite_guardian.py
+++ b/tests/generative_detectors/test_granite_guardian.py
@@ -1,4 +1,5 @@
 # Standard
+from typing import Optional
 from dataclasses import dataclass
 from http import HTTPStatus
 from unittest.mock import patch
@@ -32,6 +33,9 @@
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
 BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 
+@dataclass
+class MockTokenizer:
+    type: Optional[str] = None
 
 @dataclass
 class MockHFConfig:
@@ -40,6 +44,7 @@ class MockModelConfig:
+    task = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
@@ -47,7 +52,13 @@ class MockModelConfig:
     tokenizer_revision = None
     embedding_mode = False
     multimodal_config = MultiModalConfig()
+    diff_sampling_param: Optional[dict] = None
     hf_config = MockHFConfig()
+    logits_processor_pattern = None
+    allowed_local_media_path: str = ""
+
+    def get_diff_sampling_param(self):
+        return self.diff_sampling_param or {}
 
 
 @dataclass
@@ -59,6 +70,7 @@ async def get_model_config(self):
 async def _granite_guardian_init():
     """Initialize a granite guardian"""
     engine = MockEngine()
+    engine.errored = False
     model_config = await engine.get_model_config()
 
     granite_guardian = GraniteGuardian(
@@ -69,6 +81,7 @@ async def _granite_guardian_init():
         base_model_paths=BASE_MODEL_PATHS,
         response_role="assistant",
         chat_template=CHAT_TEMPLATE,
+        chat_template_content_format="auto",
         lora_modules=None,
         prompt_adapters=None,
         request_logger=None,
diff --git a/tests/generative_detectors/test_llama_guard.py b/tests/generative_detectors/test_llama_guard.py
index 3d59705..2bc8f13 100644
--- a/tests/generative_detectors/test_llama_guard.py
+++ b/tests/generative_detectors/test_llama_guard.py
@@ -1,4 +1,5 @@
 # Standard
+from typing import Optional
 from dataclasses import dataclass
 from http import HTTPStatus
 from unittest.mock import patch
@@ -32,6 +33,9 @@
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
 BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 
+@dataclass
+class MockTokenizer:
+    type: Optional[str] = None
 
 @dataclass
 class MockHFConfig:
@@ -40,6 +44,7 @@ class MockModelConfig:
+    task = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
@@ -47,7 +52,13 @@ class MockModelConfig:
     tokenizer_revision = None
     embedding_mode = False
     multimodal_config = MultiModalConfig()
+    diff_sampling_param: Optional[dict] = None
     hf_config = MockHFConfig()
+    logits_processor_pattern = None
+    allowed_local_media_path: str = ""
+
+    def get_diff_sampling_param(self):
+        return self.diff_sampling_param or {}
 
 
 @dataclass
@@ -59,6 +70,7 @@ async def get_model_config(self):
 async def _llama_guard_init():
     """Initialize a llama guard"""
     engine = MockEngine()
+    engine.errored = False
     model_config = await engine.get_model_config()
 
     llama_guard_detection = LlamaGuard(
@@ -69,6 +81,7 @@ async def _llama_guard_init():
         base_model_paths=BASE_MODEL_PATHS,
         response_role="assistant",
         chat_template=CHAT_TEMPLATE,
+        chat_template_content_format="auto",
         lora_modules=None,
         prompt_adapters=None,
         request_logger=None,

From f0a70b51d05b3a8928cb7a4c1a3bdde6998bc438 Mon Sep 17 00:00:00 2001
From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
Date: Thu, 2 Jan 2025 15:30:50 -0700
Subject: [PATCH 3/6] :white_check_mark: Update test for extra fields

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
---
 tests/test_protocol.py            | 9 +++++++--
 vllm_detector_adapter/protocol.py | 5 ++---
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/tests/test_protocol.py b/tests/test_protocol.py
index 93ae9e8..9004962 100644
--- a/tests/test_protocol.py
+++ b/tests/test_protocol.py
@@ -54,8 +54,13 @@ def test_detection_to_completion_request_unknown_params():
         detector_params={"moo": 2},
     )
     request = chat_request.to_chat_completion_request(MODEL_NAME)
-    assert type(request) == ErrorResponse
-    assert request.code == HTTPStatus.BAD_REQUEST.value
+    from importlib import metadata
+    if metadata.version("vllm") >= "0.6.5":
+        # As of vllm >= 0.6.5, extra fields are allowed
+        assert type(request) == ChatCompletionRequest
+    else:
+        assert type(request) == ErrorResponse
+        assert request.code == HTTPStatus.BAD_REQUEST.value
 
 
 def test_response_from_completion_response():
diff --git a/vllm_detector_adapter/protocol.py b/vllm_detector_adapter/protocol.py
index c952b21..2c6bb92 100644
--- a/vllm_detector_adapter/protocol.py
+++ b/vllm_detector_adapter/protocol.py
@@ -68,9 +68,8 @@ def to_chat_completion_request(self, model_name: str):
         ]
 
         # Try to pass all detector_params through as additional parameters to chat completions.
-        # This will error if extra unallowed parameters are included. We do not try to provide
-        # validation or changing of parameters here to not be dependent on chat completion API
-        # changes
+        # We do not try to provide validation or changing of parameters here to not be dependent
+        # on chat completion API changes. As of vllm >= 0.6.5, extra fields are allowed
         try:
             return ChatCompletionRequest(
                 messages=messages,

From 1964f07eee4a107141320dbe9b3c02bdccc3c192 Mon Sep 17 00:00:00 2001
From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
Date: Thu, 2 Jan 2025 15:44:16 -0700
Subject: [PATCH 4/6] :arrow_up: Upgrade lower bound of vllm

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
---
 pyproject.toml         | 2 +-
 tests/test_protocol.py | 9 ++-------
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 51f7702..9ce45d3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,7 @@ classifiers = [
 ]
 
 dependencies = [
-    "vllm>=0.6.2"
+    "vllm>=0.6.5"
 ]
 
 [project.optional-dependencies]
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
index 9004962..8480691 100644
--- a/tests/test_protocol.py
+++ b/tests/test_protocol.py
@@ -54,13 +54,8 @@ def test_detection_to_completion_request_unknown_params():
         detector_params={"moo": 2},
     )
     request = chat_request.to_chat_completion_request(MODEL_NAME)
-    from importlib import metadata
-    if metadata.version("vllm") >= "0.6.5":
-        # As of vllm >= 0.6.5, extra fields are allowed
-        assert type(request) == ChatCompletionRequest
-    else:
-        assert type(request) == ErrorResponse
-        assert request.code == HTTPStatus.BAD_REQUEST.value
+    # As of vllm >= 0.6.5, extra fields are allowed
+    assert type(request) == ChatCompletionRequest
 
 
 def test_response_from_completion_response():

From 66b9acf87a0ac5442e602e5e5e12ca0bb5b16998 Mon Sep 17 00:00:00 2001
From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
Date: Thu, 2 Jan 2025 15:53:21 -0700
Subject: [PATCH 5/6] :fire: Remove error on extra params tests

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
---
 tests/generative_detectors/test_base.py      |  4 +++-
 .../test_granite_guardian.py                 | 19 +++----------------
 .../generative_detectors/test_llama_guard.py | 17 +++--------------
 3 files changed, 9 insertions(+), 31 deletions(-)

diff --git a/tests/generative_detectors/test_base.py b/tests/generative_detectors/test_base.py
index c777bfd..0805714 100644
--- a/tests/generative_detectors/test_base.py
+++ b/tests/generative_detectors/test_base.py
@@ -1,6 +1,6 @@
 # Standard
-from typing import Optional
 from dataclasses import dataclass
+from typing import Optional
 import asyncio
 
 # Third Party
@@ -16,10 +16,12 @@
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
 BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 
+
 @dataclass
 class MockTokenizer:
     type: Optional[str] = None
 
+
 @dataclass
 class MockHFConfig:
     model_type: str = "any"
diff --git a/tests/generative_detectors/test_granite_guardian.py b/tests/generative_detectors/test_granite_guardian.py
index fadcdf9..7862684 100644
--- a/tests/generative_detectors/test_granite_guardian.py
+++ b/tests/generative_detectors/test_granite_guardian.py
@@ -1,7 +1,7 @@
 # Standard
-from typing import Optional
 from dataclasses import dataclass
 from http import HTTPStatus
+from typing import Optional
 from unittest.mock import patch
 import asyncio
 
@@ -33,10 +33,12 @@
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
 BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 
+
 @dataclass
 class MockTokenizer:
     type: Optional[str] = None
 
+
 @dataclass
 class MockHFConfig:
     model_type: str = "any"
@@ -242,18 +244,3 @@ def test_chat_detection_errors_on_stream(granite_guardian_detection):
     assert type(detection_response) == ErrorResponse
     assert detection_response.code == HTTPStatus.BAD_REQUEST.value
     assert "streaming is not supported" in detection_response.message
-
-
-def test_chat_detection_with_extra_unallowed_params(granite_guardian_detection):
-    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
-    chat_request = ChatDetectionRequest(
-        messages=[
-            DetectionChatMessageParam(role="user", content="How do I pick a lock?")
-        ],
-        detector_params={"boo": 3},  # unallowed param
-    )
-    detection_response = asyncio.run(
-        granite_guardian_detection_instance.chat(chat_request)
-    )
-    assert type(detection_response) == ErrorResponse
-    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
diff --git a/tests/generative_detectors/test_llama_guard.py b/tests/generative_detectors/test_llama_guard.py
index 2bc8f13..92e9954 100644
--- a/tests/generative_detectors/test_llama_guard.py
+++ b/tests/generative_detectors/test_llama_guard.py
@@ -1,7 +1,7 @@
 # Standard
-from typing import Optional
 from dataclasses import dataclass
 from http import HTTPStatus
+from typing import Optional
 from unittest.mock import patch
 import asyncio
 
@@ -33,10 +33,12 @@
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
 BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 
+
 @dataclass
 class MockTokenizer:
     type: Optional[str] = None
 
+
 @dataclass
 class MockHFConfig:
     model_type: str = "any"
@@ -190,16 +192,3 @@ def test_chat_detection(llama_guard_detection, llama_guard_completion_response):
     assert detection_0["detection"] == "safe"
     assert detection_0["detection_type"] == "risk"
     assert pytest.approx(detection_0["score"]) == 0.001346767
-
-
-def test_chat_detection_with_extra_unallowed_params(llama_guard_detection):
-    llama_guard_detection_instance = asyncio.run(llama_guard_detection)
-    chat_request = ChatDetectionRequest(
-        messages=[
-            DetectionChatMessageParam(role="user", content="How do I search for moose?")
-        ],
-        detector_params={"moo": "unallowed"},  # unallowed param
-    )
-    detection_response = asyncio.run(llama_guard_detection_instance.chat(chat_request))
-    assert type(detection_response) == ErrorResponse
-    assert detection_response.code == HTTPStatus.BAD_REQUEST.value

From 15bf20ca14470a8c6a68279aa8f847bf6d4e7506 Mon Sep 17 00:00:00 2001
From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
Date: Fri, 3 Jan 2025 14:07:48 -0700
Subject: [PATCH 6/6] :recycle: API server updates

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
---
 vllm_detector_adapter/api_server.py | 22 ++++++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/vllm_detector_adapter/api_server.py b/vllm_detector_adapter/api_server.py
index bff5ffb..51ce754 100644
--- a/vllm_detector_adapter/api_server.py
+++ b/vllm_detector_adapter/api_server.py
@@ -10,6 +10,7 @@
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import nullable_str
 from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import load_chat_template
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai import api_server
@@ -61,6 +62,16 @@ def init_app_state_with_detectors(
         BaseModelPath(name=name, model_path=args.model) for name in served_model_names
     ]
 
+    resolved_chat_template = load_chat_template(args.chat_template)
+    # Post-0.6.6 incoming change for vllm - ref. https://github.com/vllm-project/vllm/pull/11660
+    # Will be included after an official release includes this refactor
+    # state.openai_serving_models = OpenAIServingModels(
+    #     model_config=model_config,
+    #     base_model_paths=base_model_paths,
+    #     lora_modules=args.lora_modules,
+    #     prompt_adapters=args.prompt_adapters,
+    # )
+
     # Use vllm app state init
     api_server.init_app_state(engine_client, model_config, state, args)
 
@@ -72,15 +83,18 @@ def init_app_state_with_detectors(
         args.output_template,
         engine_client,
         model_config,
-        base_model_paths,
+        base_model_paths,  # Not present in post-0.6.6 incoming change
+        # state.openai_serving_models,  # Post-0.6.6 incoming change
         args.response_role,
-        lora_modules=args.lora_modules,
-        prompt_adapters=args.prompt_adapters,
+        lora_modules=args.lora_modules,  # Not present in post-0.6.6 incoming change
+        prompt_adapters=args.prompt_adapters,  # Not present in post-0.6.6 incoming change
         request_logger=request_logger,
-        chat_template=args.chat_template,
+        chat_template=resolved_chat_template,
+        chat_template_content_format=args.chat_template_content_format,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     )
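
Note on the behavior change exercised in PATCH 3 and PATCH 4: the switch from expecting an ErrorResponse to expecting a ChatCompletionRequest when unknown detector_params such as {"moo": 2} are supplied comes down to whether the request model tolerates extra fields, which the updated comment in vllm_detector_adapter/protocol.py attributes to vllm >= 0.6.5. The snippet below is a minimal, hypothetical pydantic sketch of that distinction; StrictRequest and LenientRequest and their fields are invented for illustration and are not vllm's actual ChatCompletionRequest implementation.

# Hypothetical stand-ins for the request model; not vllm's actual ChatCompletionRequest.
from pydantic import BaseModel, ConfigDict, ValidationError


class StrictRequest(BaseModel):
    """Rejects unknown fields, mirroring the old ErrorResponse code path."""

    model_config = ConfigDict(extra="forbid")

    model: str
    messages: list[dict]


class LenientRequest(BaseModel):
    """Accepts unknown fields, mirroring the passthrough the updated tests expect."""

    model_config = ConfigDict(extra="allow")

    model: str
    messages: list[dict]


messages = [{"role": "user", "content": "How do I pick a lock?"}]

try:
    StrictRequest(model="dummy-model", messages=messages, moo=2)
except ValidationError as err:
    # The unknown "moo" parameter is rejected outright
    print("strict:", err.errors()[0]["type"])  # extra_forbidden

# The unknown "moo" parameter is instead carried along on the request object
lenient = LenientRequest(model="dummy-model", messages=messages, moo=2)
print("lenient extras:", lenient.model_extra)  # {'moo': 2}

Under the old pin, the strict-style behavior is what turned unknown parameters into an ErrorResponse with HTTPStatus.BAD_REQUEST; with the new lower bound the extra parameters simply ride along to chat completions, which is why PATCH 5 removes the unallowed-param tests.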