Update LLM and AsyncLLM to expose more functionality #90

Open · wants to merge 2 commits into MLPerf_4.1
3 changes: 3 additions & 0 deletions vllm/core/scheduler.py
@@ -365,6 +365,9 @@ def has_unfinished_seqs(self) -> bool:
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)

def get_num_waiting_seq_groups(self) -> int:
return len(self.waiting)

def _schedule_running(
self,
running_queue: deque,
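
Note (illustrative, not part of the diff): the new accessor only exposes the waiting queue, so it relates to the existing counter as in this minimal sketch, assuming a Scheduler instance named `scheduler`:

# waiting = not yet scheduled, running = currently executing, swapped = preempted to CPU
num_total = scheduler.get_num_unfinished_seq_groups()  # len(waiting) + len(running) + len(swapped)
num_queued = scheduler.get_num_waiting_seq_groups()    # len(waiting) only
num_in_flight = num_total - num_queued                 # running + swapped
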
32 changes: 26 additions & 6 deletions vllm/engine/async_llm_engine.py
@@ -50,18 +50,28 @@ def _raise_exception_on_finish(

class AsyncStream:
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
that can be iterated over asynchronously."""
that can be iterated over asynchronously.
Args:
first_item_only: Only emit the first item and the finished output to the queue.
"""

def __init__(self, request_id: str) -> None:
def __init__(self, request_id: str, first_item_only: bool = False) -> None:
self.request_id = request_id
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
self._first_item_only = first_item_only
self._first_item = first_item_only

def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
Exception]) -> None:
if self._finished:
return
self._queue.put_nowait(item)
if self._first_item_only:
if self._first_item or item.finished:
self._first_item = False
self._queue.put_nowait(item)
else:
self._queue.put_nowait(item)

def finish(self) -> None:
self._queue.put_nowait(StopAsyncIteration())
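
Note (illustrative, not part of the diff): with first_item_only=True the stream drops every intermediate output, so a consumer sees at most the first item and the finished one. A minimal sketch that drives the stream by hand; _Fake is a hypothetical stand-in for RequestOutput, of which only the `finished` flag matters here:

import asyncio

from vllm.engine.async_llm_engine import AsyncStream


class _Fake:
    """Hypothetical stand-in for RequestOutput; only `finished` is inspected."""

    def __init__(self, text: str, finished: bool = False) -> None:
        self.text = text
        self.finished = finished


async def main() -> None:
    stream = AsyncStream("req-0", first_item_only=True)
    stream.put(_Fake("first"))                # emitted: first item
    stream.put(_Fake("middle"))               # dropped: neither first nor finished
    stream.put(_Fake("last", finished=True))  # emitted: finished output
    stream.finish()
    async for out in stream:
        print(out.text)                       # prints "first", then "last"


asyncio.run(main())
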
@@ -135,14 +145,16 @@ def process_exception(self,
logger.info("Finished request %s.", request_id)
self.abort_request(request_id)

def add_request(self, request_id: str,
def add_request(self,
request_id: str,
first_token_only: bool = False,
**engine_add_request_kwargs) -> AsyncStream:
"""Add a request to be sent to the engine on the next background
loop iteration."""
if request_id in self._request_streams:
raise KeyError(f"Request {request_id} already exists.")

stream = AsyncStream(request_id)
stream = AsyncStream(request_id, first_token_only)
self._new_requests.put_nowait((stream, {
"request_id": request_id,
**engine_add_request_kwargs
@@ -223,7 +235,7 @@ async def step_async(
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
)
output = await self.model_executor.execute_model_async(
output = self.model_executor.execute_model(
execute_model_req)
else:
output = []
@@ -534,6 +546,7 @@ async def add_request(
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
first_token_only: bool = False,
) -> AsyncStream:
if self.log_requests:
if isinstance(inputs, str):
@@ -587,6 +600,7 @@ async def add_request(
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
first_token_only=first_token_only,
)

return stream
@@ -597,6 +611,7 @@ async def generate(
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
first_token_only: bool = False,
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.

@@ -611,6 +626,7 @@
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
first_token_only: Only return the first token(s) and the final output.

Yields:
The output `RequestOutput` objects from the LLMEngine
@@ -663,6 +679,7 @@ async def generate(
request_id,
inputs,
sampling_params,
first_token_only=first_token_only,
lora_request=lora_request,
):
yield LLMEngine.validate_output(output, RequestOutput)
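
Note (illustrative, not part of the diff): with first_token_only=True the caller receives only the first streamed output and the final one, which is convenient for measuring time to first token. A minimal sketch, assuming an AsyncLLMEngine built from AsyncEngineArgs; the model name is only an example:

import asyncio
import time

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))  # example model
    params = SamplingParams(max_tokens=64)

    start = time.perf_counter()
    ttft = None
    final = None
    # Only the first output and the finished output are yielded.
    async for out in engine.generate("Hello, my name is", params, "req-0",
                                     first_token_only=True):
        if ttft is None:
            ttft = time.perf_counter() - start
        final = out

    print(f"TTFT: {ttft:.3f}s, completion: {final.outputs[0].text!r}")


asyncio.run(main())
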
@@ -737,6 +754,7 @@ async def encode(
request_id,
inputs,
pooling_params,
first_token_only=False,
lora_request=lora_request,
):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
@@ -748,6 +766,7 @@ async def _process_request(
params: Union[SamplingParams, PoolingParams],
*,
lora_request: Optional[LoRARequest] = None,
first_token_only: bool = False,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or
PoolingParams."""
@@ -759,6 +778,7 @@ async def _process_request(
params,
arrival_time=arrival_time,
lora_request=lora_request,
first_token_only=first_token_only,
)

try:
4 changes: 4 additions & 0 deletions vllm/engine/llm_engine.py
@@ -644,6 +644,10 @@ def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups()

def get_num_waiting_requests(self) -> int:
"""Gets the number of unfinished requests."""
return self.scheduler.get_num_waiting_seq_groups()

def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs()
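
Note (illustrative, not part of the diff): the new counter distinguishes queued work from in-flight work, e.g. for deciding when to feed more requests into a running engine. A minimal sketch, assuming an existing LLMEngine named `engine`; submit_more_requests, handle and TARGET_QUEUE_DEPTH are hypothetical:

TARGET_QUEUE_DEPTH = 8  # hypothetical threshold

while engine.has_unfinished_requests():
    queued = engine.get_num_waiting_requests()                 # waiting only
    in_flight = engine.get_num_unfinished_requests() - queued  # running + swapped
    print(f"queued={queued} in_flight={in_flight}")
    if queued < TARGET_QUEUE_DEPTH:
        submit_more_requests(engine)  # hypothetical helper calling engine.add_request(...)
    for output in engine.step():
        if output.finished:
            handle(output)            # hypothetical result handler
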
27 changes: 27 additions & 0 deletions vllm/entrypoints/llm.py
@@ -117,6 +117,7 @@ def __init__(
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 32768,
disable_custom_all_reduce: bool = False,
sample_injection_cb = None,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
@@ -145,6 +146,9 @@ def __init__(
engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter()

self.sample_injection_cb = sample_injection_cb


def get_tokenizer(
self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer.tokenizer
@@ -549,6 +553,12 @@ def _run_engine(
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_toks = 0
while self.llm_engine.has_unfinished_requests():
if self.sample_injection_cb:
num_added_requests = self.sample_injection_cb()
if num_added_requests < 0:
self.sample_injection_cb = None
elif use_tqdm:
pbar.total += num_added_requests
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished:
@@ -567,3 +577,20 @@
# This is necessary because some requests may be finished earlier than
# its previous requests.
return sorted(outputs, key=lambda x: int(x.request_id))

def add_prompt_token_ids(
self,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,):
inputs = self._convert_v1_inputs(
prompts=None,
prompt_token_ids=prompt_token_ids,
multi_modal_data=None,
)

self._validate_and_add_requests(
inputs=inputs,
params=sampling_params,
lora_request=None,
)