⬆️✅ Support 0.6.5+ vllm #7

Open: wants to merge 6 commits into base: main
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -13,7 +13,7 @@ classifiers = [
 ]
 
 dependencies = [
-    "vllm>=0.6.2,<0.6.5"
+    "vllm>=0.6.5"
 ]
 
 [project.optional-dependencies]
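With the dependency floor raised, a quick way to confirm that a local environment satisfies the new constraint is to compare the installed vllm version against 0.6.5. This is a minimal sketch, not part of the PR, and assumes the packaging library is installed.

# Minimal environment check (illustrative, not part of this PR).
from importlib.metadata import version

from packaging.version import Version

installed = Version(version("vllm"))
assert installed >= Version("0.6.5"), f"vllm {installed} is older than the required 0.6.5"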
15 changes: 15 additions & 0 deletions tests/generative_detectors/test_base.py
@@ -1,5 +1,6 @@
# Standard
from dataclasses import dataclass
from typing import Optional
import asyncio

# Third Party
@@ -16,21 +17,33 @@
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]


@dataclass
class MockTokenizer:
    type: Optional[str] = None


@dataclass
class MockHFConfig:
    model_type: str = "any"


@dataclass
class MockModelConfig:
    task = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    embedding_mode = False
    multimodal_config = MultiModalConfig()
    diff_sampling_param: Optional[dict] = None
    hf_config = MockHFConfig()
    logits_processor_pattern = None
    allowed_local_media_path: str = ""

    def get_diff_sampling_param(self):
        return self.diff_sampling_param or {}


@dataclass
@@ -42,6 +55,7 @@ async def get_model_config(self):
async def _async_serving_detection_completion_init():
    """Initialize a chat completion base with string templates"""
    engine = MockEngine()
    engine.errored = False
    model_config = await engine.get_model_config()

    detection_completion = ChatCompletionDetectionBase(
@@ -52,6 +66,7 @@ async def _async_serving_detection_completion_init():
        base_model_paths=BASE_MODEL_PATHS,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        lora_modules=None,
        prompt_adapters=None,
        request_logger=None,
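For orientation, a brief sketch of how an async initializer like this is consumed by the synchronous tests elsewhere in this PR, following the asyncio.run(...) pattern used in the granite guardian and llama guard test modules; the fixture and test names here are illustrative, not taken from the diff.

# Illustrative usage only; names are hypothetical.
import asyncio

import pytest


@pytest.fixture
def detection_base():
    # The fixture returns the coroutine; each test resolves it with asyncio.run
    return _async_serving_detection_completion_init()


def test_detection_base_initializes(detection_base):
    detection_instance = asyncio.run(detection_base)
    assert detection_instance is not None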
30 changes: 15 additions & 15 deletions tests/generative_detectors/test_granite_guardian.py
@@ -1,6 +1,7 @@
# Standard
from dataclasses import dataclass
from http import HTTPStatus
from typing import Optional
from unittest.mock import patch
import asyncio

@@ -33,21 +34,33 @@
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]


@dataclass
class MockTokenizer:
    type: Optional[str] = None


@dataclass
class MockHFConfig:
    model_type: str = "any"


@dataclass
class MockModelConfig:
    task = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    embedding_mode = False
    multimodal_config = MultiModalConfig()
    diff_sampling_param: Optional[dict] = None
    hf_config = MockHFConfig()
    logits_processor_pattern = None
    allowed_local_media_path: str = ""

    def get_diff_sampling_param(self):
        return self.diff_sampling_param or {}


@dataclass
@@ -59,6 +72,7 @@ async def get_model_config(self):
async def _granite_guardian_init():
    """Initialize a granite guardian"""
    engine = MockEngine()
    engine.errored = False
    model_config = await engine.get_model_config()

    granite_guardian = GraniteGuardian(
@@ -69,6 +83,7 @@ async def _granite_guardian_init():
        base_model_paths=BASE_MODEL_PATHS,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        lora_modules=None,
        prompt_adapters=None,
        request_logger=None,
@@ -229,18 +244,3 @@ def test_chat_detection_errors_on_stream(granite_guardian_detection):
     assert type(detection_response) == ErrorResponse
     assert detection_response.code == HTTPStatus.BAD_REQUEST.value
     assert "streaming is not supported" in detection_response.message
-
-
-def test_chat_detection_with_extra_unallowed_params(granite_guardian_detection):
-    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
-    chat_request = ChatDetectionRequest(
-        messages=[
-            DetectionChatMessageParam(role="user", content="How do I pick a lock?")
-        ],
-        detector_params={"boo": 3},  # unallowed param
-    )
-    detection_response = asyncio.run(
-        granite_guardian_detection_instance.chat(chat_request)
-    )
-    assert type(detection_response) == ErrorResponse
-    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
28 changes: 15 additions & 13 deletions tests/generative_detectors/test_llama_guard.py
@@ -1,6 +1,7 @@
# Standard
from dataclasses import dataclass
from http import HTTPStatus
from typing import Optional
from unittest.mock import patch
import asyncio

@@ -33,21 +34,33 @@
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]


@dataclass
class MockTokenizer:
    type: Optional[str] = None


@dataclass
class MockHFConfig:
    model_type: str = "any"


@dataclass
class MockModelConfig:
    task = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    embedding_mode = False
    multimodal_config = MultiModalConfig()
    diff_sampling_param: Optional[dict] = None
    hf_config = MockHFConfig()
    logits_processor_pattern = None
    allowed_local_media_path: str = ""

    def get_diff_sampling_param(self):
        return self.diff_sampling_param or {}


@dataclass
@@ -59,6 +72,7 @@ async def get_model_config(self):
async def _llama_guard_init():
    """Initialize a llama guard"""
    engine = MockEngine()
    engine.errored = False
    model_config = await engine.get_model_config()

    llama_guard_detection = LlamaGuard(
@@ -69,6 +83,7 @@ async def _llama_guard_init():
        base_model_paths=BASE_MODEL_PATHS,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        lora_modules=None,
        prompt_adapters=None,
        request_logger=None,
@@ -177,16 +192,3 @@ def test_chat_detection(llama_guard_detection, llama_guard_completion_response):
     assert detection_0["detection"] == "safe"
     assert detection_0["detection_type"] == "risk"
     assert pytest.approx(detection_0["score"]) == 0.001346767
-
-
-def test_chat_detection_with_extra_unallowed_params(llama_guard_detection):
-    llama_guard_detection_instance = asyncio.run(llama_guard_detection)
-    chat_request = ChatDetectionRequest(
-        messages=[
-            DetectionChatMessageParam(role="user", content="How do I search for moose?")
-        ],
-        detector_params={"moo": "unallowed"},  # unallowed param
-    )
-    detection_response = asyncio.run(llama_guard_detection_instance.chat(chat_request))
-    assert type(detection_response) == ErrorResponse
-    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
4 changes: 2 additions & 2 deletions tests/test_protocol.py
@@ -54,8 +54,8 @@ def test_detection_to_completion_request_unknown_params():
detector_params={"moo": 2},
)
request = chat_request.to_chat_completion_request(MODEL_NAME)
assert type(request) == ErrorResponse
assert request.code == HTTPStatus.BAD_REQUEST.value
# As of vllm >= 0.6.5, extra fields are allowed
assert type(request) == ChatCompletionRequest

Review comment (Collaborator):
hmm, this will change the general API behavior from our side. Does the orchestrator expect a bad request in such a scenario, or a passthrough?

Reply (Collaborator Author, @evaline-ju, Jan 9, 2025):
This will just cause a passthrough of the variable, from my testing. My worry is that adding additional validation, when vllm and OpenAI allow passthrough, would tie us even more closely to small API changes (like tracking all expected fields).


def test_response_from_completion_response():
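A minimal sketch of the extra-fields behavior discussed in the review thread above, expressed in plain pydantic terms: a model that forbids extra fields rejects unknown detector params, while a model that allows them simply carries them through. The StrictRequest and LenientRequest names are hypothetical stand-ins, not vllm's actual classes; vllm's ChatCompletionRequest as of 0.6.5 behaves like the lenient case.

# Hypothetical illustration of pydantic's "extra" policies; not vllm code.
from pydantic import BaseModel, ConfigDict, ValidationError


class StrictRequest(BaseModel):
    model_config = ConfigDict(extra="forbid")  # unknown fields raise a ValidationError
    messages: list[dict]


class LenientRequest(BaseModel):
    model_config = ConfigDict(extra="allow")  # unknown fields are kept on the instance
    messages: list[dict]


payload = {"messages": [{"role": "user", "content": "hi"}], "moo": 2}

try:
    StrictRequest(**payload)
except ValidationError:
    print("strict model rejects the unknown 'moo' field")

lenient = LenientRequest(**payload)
print(lenient.model_extra)  # {'moo': 2}: the extra field is passed through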
22 changes: 18 additions & 4 deletions vllm_detector_adapter/api_server.py
@@ -10,6 +10,7 @@
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import nullable_str
 from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import load_chat_template
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai import api_server
@@ -61,6 +62,16 @@ def init_app_state_with_detectors(
         BaseModelPath(name=name, model_path=args.model) for name in served_model_names
     ]
 
+    resolved_chat_template = load_chat_template(args.chat_template)
+    # Post-0.6.6 incoming change for vllm - ref. https://github.com/vllm-project/vllm/pull/11660
+    # Will be included after an official release includes this refactor
+    # state.openai_serving_models = OpenAIServingModels(
+    #     model_config=model_config,
+    #     base_model_paths=base_model_paths,
+    #     lora_modules=args.lora_modules,
+    #     prompt_adapters=args.prompt_adapters,
+    # )
+
     # Use vllm app state init
     api_server.init_app_state(engine_client, model_config, state, args)

@@ -72,15 +83,18 @@
         args.output_template,
         engine_client,
         model_config,
-        base_model_paths,
+        base_model_paths,  # Not present in post-0.6.6 incoming change
+        # state.openai_serving_models,  # Post-0.6.6 incoming change
         args.response_role,
-        lora_modules=args.lora_modules,
-        prompt_adapters=args.prompt_adapters,
+        lora_modules=args.lora_modules,  # Not present in post-0.6.6 incoming change
+        prompt_adapters=args.prompt_adapters,  # Not present in post-0.6.6 incoming change
         request_logger=request_logger,
-        chat_template=args.chat_template,
+        chat_template=resolved_chat_template,
+        chat_template_content_format=args.chat_template_content_format,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     )


5 changes: 2 additions & 3 deletions vllm_detector_adapter/protocol.py
@@ -68,9 +68,8 @@ def to_chat_completion_request(self, model_name: str):
         ]
 
         # Try to pass all detector_params through as additional parameters to chat completions.
-        # This will error if extra unallowed parameters are included. We do not try to provide
-        # validation or changing of parameters here to not be dependent on chat completion API
-        # changes
+        # We do not try to provide validation or changing of parameters here to not be dependent
+        # on chat completion API changes. As of vllm >= 0.6.5, extra fields are allowed
         try:
             return ChatCompletionRequest(
                 messages=messages,
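To make the passthrough concrete, here is a simplified sketch of the pattern described in the comment above: detector_params are splatted into the request as extra keyword arguments, and construction errors are surfaced as an ErrorResponse rather than pre-validated. This is illustrative only and omits fields the adapter's real method sets; the function name is hypothetical.

# Simplified sketch of the detector_params passthrough; not the adapter's actual code.
from http import HTTPStatus

from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse


def build_chat_completion_request(messages, model_name, detector_params):
    try:
        # Forward detector_params untouched; with vllm >= 0.6.5, unknown fields
        # are accepted by ChatCompletionRequest instead of failing validation.
        return ChatCompletionRequest(
            messages=messages,
            model=model_name,
            **detector_params,
        )
    except ValueError as e:  # pydantic ValidationError subclasses ValueError
        return ErrorResponse(
            message=str(e),
            type="BadRequestError",
            code=HTTPStatus.BAD_REQUEST.value,
        )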