fix: openai server (#19)
AlpinDale authored Sep 29, 2023
1 parent cbeeabe commit 69a4c32
Showing 7 changed files with 114 additions and 56 deletions.
44 changes: 44 additions & 0 deletions aphrodite/common/logits_processor.py
@@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
import torch
from typing import Dict


class LogitsProcessor(ABC):

@abstractmethod
def __call__(self, logits: torch.Tensor) -> torch.Tensor:
pass


class BiasLogitsProcessor(LogitsProcessor):
"""This is to enable logit_bias in the OpenAI server.
biases is a dict where each value is -100 to 100
according to the OpenAI API docs.
Args:
biases: Dict of values from -100 to 100 to scale the
probability of a token being generated.
Each key of the dict corresponds to the token id.
"""

def __init__(self, biases: Dict[int, float]):
self.biases = biases

if not biases:
return

self.keys = torch.tensor(list(self.biases.keys()), dtype=torch.long)
self.values = torch.tensor(list(self.biases.values()),
dtype=torch.long)

def __call__(self, logits):
if not self.biases:
return logits

values = self.values.to(logits.device)
keys = self.keys.to(logits.device)

update_factors = torch.where(values >= 0, 1 + (values / 100),
1 / (1 - (values / 100)))
logits[0, keys] *= update_factors

return logits
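
For reference, a minimal usage sketch of BiasLogitsProcessor as added above; the token ids, bias values, and vocabulary size are made up for illustration, not taken from any particular model:

import torch
from aphrodite.common.logits_processor import BiasLogitsProcessor

# Boost token 42 and suppress token 7, using the OpenAI logit_bias range of -100..100.
processor = BiasLogitsProcessor({42: 50.0, 7: -90.0})

logits = torch.randn(1, 32000)   # [batch, vocab] logits produced by the model
logits = processor(logits)       # token 42 is scaled by 1.5, token 7 by 1/1.9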
12 changes: 11 additions & 1 deletion aphrodite/common/sampling_params.py
@@ -2,6 +2,7 @@
from enum import IntEnum
from functools import cached_property
from typing import List, Optional, Union
from aphrodite.common.logits_processor import LogitsProcessor

_SAMPLING_EPS = 1e-5

@@ -60,6 +61,10 @@ class SamplingParams:
tokens after the EOS token is generated.
max_tokens: Maximum number of tokens to generate per output sequence.
logprobs: Number of log probabilities to return per output token.
skip_special_tokens: Whether to skip special tokens in the output.
Defaults to True.
logits_processors: List of LogitsProcessors to change the probability
of token prediction at runtime.
"""

def __init__(
@@ -79,6 +84,8 @@ def __init__(
ignore_eos: bool = False,
max_tokens: int = 16,
logprobs: Optional[int] = None,
skip_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
@@ -103,6 +110,8 @@ def __init__(
self.ignore_eos = ignore_eos
self.max_tokens = max_tokens
self.logprobs = logprobs
self.skip_special_tokens = skip_special_tokens
self.logits_processors = logits_processors

self._verify_args()
if self.use_beam_search:
@@ -196,4 +205,5 @@ def __repr__(self) -> str:
f"stop={self.stop}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "
f"logprobs={self.logprobs})")
f"logprobs={self.logprobs}, "
f"skip_special_tokens={self.skip_special_tokens})")
43 changes: 17 additions & 26 deletions aphrodite/endpoints/openai/api_server.py
@@ -131,6 +131,8 @@ async def check_length(
input_ids = tokenizer(prompt).input_ids
token_num = len(input_ids)

if request.max_tokens is None:
request.max_tokens = max_model_len - token_num
if token_num + request.max_tokens > max_model_len:
return input_ids, create_error_response(
HTTPStatus.BAD_REQUEST,
@@ -196,12 +198,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
if error_check_ret is not None:
return error_check_ret


prompt = await get_gen_prompt(request)
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret

if not request.logit_bias:
logit_processors = []
else:
@@ -210,6 +206,11 @@ async def create_chat_completion(request: ChatCompletionRequest,
request.logit_bias.items()))
logit_processors = [BiasLogitsProcessor(biases)]

prompt = await get_gen_prompt(request)
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret

model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
@@ -221,11 +222,13 @@ async def create_chat_completion(request: ChatCompletionRequest,
temperature=request.temperature,
top_p=request.top_p,
stop=request.stop,
stop_token_ids=request.stop_token_ids,
max_tokens=request.max_tokens,
best_of=request.best_of,
top_k=request.top_k,
ignore_eos=request.ignore_eos,
use_beam_search=request.use_beam_search,
skip_special_tokens=request.skip_special_tokens,
logits_processors=logit_processors,
)
except ValueError as e:
@@ -234,9 +237,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)

async def abort_request() -> None:
await engine.abort(request_id)

def create_stream_response_json(
index: int,
text: str,
@@ -296,19 +296,15 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:

# Streaming response
if request.stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream",
background=background_tasks)
media_type="text/event-stream")

# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await abort_request()
await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
@@ -361,7 +357,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following features:
- echo (since the engine does not currently support
- echo (since the Aphrodite engine does not currently support
getting the logprobs of prompt tokens)
- suffix (the language models we currently support do not support
suffix)
@@ -373,7 +369,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return error_check_ret

if request.echo:
# We do not support echo since the engine does not
# We do not support echo since the Aphrodite engine does not
# currently support getting the logprobs of prompt tokens.
return create_error_response(HTTPStatus.BAD_REQUEST,
"echo is not currently supported")
@@ -432,10 +428,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
stop_token_ids=request.stop_token_ids,
ignore_eos=request.ignore_eos,
max_tokens=request.max_tokens,
logprobs=request.logprobs,
use_beam_search=request.use_beam_search,
skip_special_tokens=request.skip_special_tokens,
logits_processors=logit_processors,
)
except ValueError as e:
@@ -456,9 +454,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)

async def abort_request() -> None:
await engine.abort(request_id)

def create_stream_response_json(
index: int,
text: str,
@@ -518,19 +513,15 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:

# Streaming response
if stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream",
background=background_tasks)
media_type="text/event-stream")

# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await abort_request()
await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
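To illustrate the new max_tokens defaulting in check_length above, a rough sketch of the resulting arithmetic, with made-up numbers:

# Assumed values, for illustration only.
max_model_len = 4096   # model context window
token_num = 1000       # prompt length in tokens

max_tokens = None      # client omitted max_tokens in the request
if max_tokens is None:
    max_tokens = max_model_len - token_num   # 3096 tokens of room left to generate
assert token_num + max_tokens <= max_model_len   # so the length check passes
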
46 changes: 21 additions & 25 deletions aphrodite/endpoints/openai/protocol.py
@@ -1,10 +1,13 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field

from aphrodite.common.utils import random_uuid


class ErrorResponse(BaseModel):
object: str = "error"
message: str
@@ -55,7 +58,7 @@ class ChatCompletionRequest(BaseModel):
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
max_tokens: Optional[int] = 16
max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
@@ -66,10 +69,14 @@ class ChatCompletionRequest(BaseModel):
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True


class CompletionRequest(BaseModel):
model: str
prompt: Union[str, List[str]]
# a string, array of strings, array of tokens, or array of token arrays
prompt: Union[List[int], List[List[int]], str, List[str]]
suffix: Optional[str] = None
max_tokens: Optional[int] = 16
temperature: Optional[float] = 1.0
@@ -84,17 +91,19 @@ class CompletionRequest(BaseModel):
best_of: Optional[int] = None
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Additional parameters supported by Aphrodite
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True


class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str,
float]]] = Field(default_factory=list)


class CompletionResponseChoice(BaseModel):
@@ -127,15 +136,18 @@ class CompletionStreamResponse(BaseModel):
model: str
choices: List[CompletionResponseStreamChoice]


class ChatMessage(BaseModel):
role: str
content: str


class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
@@ -144,37 +156,21 @@ class ChatCompletionResponse(BaseModel):
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo


class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl={random_uuid()}")
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]

class TokenCheckRequestItem(BaseModel):
model: str
prompt: str
max_tokens: int


class TokenCheckRequest(BaseModel):
prompts: List[TokenCheckRequestItem]


class TokenCheckResponseItem(BaseModel):
fits: bool
tokenCount: int
contextLength: int


class TokenCheckResponse(BaseModel):
prompts: List[TokenCheckResponseItem]
choices: List[ChatCompletionResponseStreamChoice]
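
For context, a hedged client-side sketch exercising the new CompletionRequest fields; the model name, host, and port are placeholders, and it assumes the server exposes the usual OpenAI-compatible /v1/completions route:

import requests

payload = {
    "model": "my-model",               # placeholder model name
    "prompt": "Once upon a time",      # lists of token ids are also accepted now
    "max_tokens": 64,
    "stop_token_ids": [2],             # Aphrodite-specific extension
    "skip_special_tokens": True,
    "logit_bias": {"50256": -100},     # keys are token ids as strings, per the OpenAI docs
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
print(resp.json())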
7 changes: 4 additions & 3 deletions aphrodite/engine/aphrodite_engine.py
@@ -389,7 +389,7 @@ def _process_sequence_group_samples(
child_seqs.append((parent, parent))

for seq, _ in child_seqs:
self._decode_sequence(seq)
self._decode_sequence(seq, seq_group.sampling_params)
self._check_stop(seq, seq_group.sampling_params)

# Non-beam search case
@@ -623,7 +623,8 @@ def _log_system_stats(
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
self.last_logging_time = now

def _decode_sequence(self, seq: Sequence) -> None:
def _decode_sequence(self, seq: Sequence,
sampling_params: SamplingParams) -> None:
"""Decodes the new token for a sequence."""
(new_tokens, new_output_text, prefix_offset,
read_offset) = detokenize_incrementally(
@@ -632,7 +633,7 @@ def _decode_sequence(self, seq: Sequence) -> None:
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=True,
skip_special_tokens=sampling_params.skip_special_tokens,
)
if seq.tokens is None:
seq.tokens = new_tokens
2 changes: 1 addition & 1 deletion aphrodite/engine/args_tools.py
@@ -190,7 +190,7 @@ def create_engine_configs(

@dataclass
class AsyncEngineArgs(EngineArgs):
"""Arguments for asynchronous vLLM engine."""
"""Arguments for asynchronous Aohrodite engine."""
engine_use_ray: bool = False
disable_log_requests: bool = False
max_log_len: Optional[int] = None
