This repository has been archived by the owner on May 28, 2024. It is now read-only.

Add timeouts and early stopping to backend (#87)
Signed-off-by: Antoni Baum <[email protected]>
Yard1 authored May 30, 2023
1 parent 0d65332 commit f4673f1
Showing 6 changed files with 145 additions and 53 deletions.
15 changes: 13 additions & 2 deletions aviary/backend/llm/initializers/_llama_impl.py
@@ -1,7 +1,7 @@
import sys
import time
import uuid
from typing import Dict, Iterator, List, Optional, Union
from typing import Dict, Iterator, List, Optional, Tuple, Union

from llama_cpp import Completion, CompletionChunk, CompletionLogprobs, Llama, llama_cpp

@@ -20,7 +20,8 @@ def __call__(
prompt: str,
suffix: Optional[str] = None,
max_tokens: int = 128,
min_tokens: int = 0,
min_tokens: int = 0, # new aviary argument
max_time_criteria: Optional[Tuple[float, float]] = None, # new aviary argument
temperature: float = 0.8,
top_p: float = 0.95,
logprobs: Optional[int] = None,
@@ -41,6 +42,7 @@ def __call__(
suffix=suffix,
max_tokens=max_tokens,
min_tokens=min_tokens,
max_time_criteria=max_time_criteria,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
@@ -63,6 +65,7 @@ def create_completion(
suffix: Optional[str] = None,
max_tokens: int = 128,
min_tokens: int = 0,
max_time_criteria: Optional[Tuple[float, float]] = None, # new aviary argument
temperature: float = 0.8,
top_p: float = 0.95,
logprobs: Optional[int] = None,
@@ -83,6 +86,7 @@ def create_completion(
suffix=suffix,
max_tokens=max_tokens,
min_tokens=min_tokens,
max_time_criteria=max_time_criteria,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
@@ -110,6 +114,7 @@ def _create_completion(
suffix: Optional[str] = None,
max_tokens: int = 16,
min_tokens: int = 0,
max_time_criteria: Optional[Tuple[float, float]] = None, # new aviary argument
temperature: float = 0.8,
top_p: float = 0.95,
logprobs: Optional[int] = None,
@@ -187,6 +192,12 @@ def _create_completion(
presence_penalty=presence_penalty,
repeat_penalty=repeat_penalty,
):
if max_time_criteria:
max_time, initial_timestamp = max_time_criteria
if time.time() - initial_timestamp > max_time:
finish_reason = "time"
break

if token == llama_cpp.llama_token_eos():
if len(completion_tokens) >= min_tokens:
text = self.detokenize(completion_tokens)
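
The pattern above threads a single (max_time, initial_timestamp) tuple from __call__ down into the token loop, where elapsed time is checked once per generated token and the completion is finished with finish_reason = "time". Below is a minimal, self-contained sketch of that check against a generic token generator; fake_token_stream and the dict return value are placeholders for illustration, not part of the commit.

import time
from typing import Iterator, Optional, Tuple

def fake_token_stream() -> Iterator[str]:
    """Placeholder for Llama.generate(): yields one token forever."""
    while True:
        time.sleep(0.05)
        yield "tok"

def create_completion(
    max_tokens: int = 128,
    max_time_criteria: Optional[Tuple[float, float]] = None,  # (max_time_s, initial_timestamp)
) -> dict:
    completion_tokens = []
    finish_reason = "length"
    for token in fake_token_stream():
        # Same check the commit adds inside _create_completion: stop decoding
        # once the request has been running longer than max_time.
        if max_time_criteria:
            max_time, initial_timestamp = max_time_criteria
            if time.time() - initial_timestamp > max_time:
                finish_reason = "time"
                break
        completion_tokens.append(token)
        if len(completion_tokens) >= max_tokens:
            break
    return {"text": " ".join(completion_tokens), "finish_reason": finish_reason}

if __name__ == "__main__":
    # Allow at most 0.5 s of decoding, measured from when the request was created.
    print(create_completion(max_time_criteria=(0.5, time.time())))

Passing the request's creation timestamp, rather than starting a timer inside the model, means time already spent waiting in the queue also counts against the budget, which is what lets the gateway-level deadline in app.py below hold end to end.
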
12 changes: 12 additions & 0 deletions aviary/backend/llm/pipelines/_base.py
@@ -6,6 +6,7 @@
import torch
from transformers import (
LogitsProcessorList,
MaxTimeCriteria,
MinNewTokensLengthLogitsProcessor,
PreTrainedModel,
PreTrainedTokenizer,
@@ -98,6 +99,10 @@ def _get_stopping_criteria(
stopping_sequences += [self.tokenizer.eos_token_id]
lst.append(StopOnTokens(stopping_sequences))

if generate_kwargs.get("max_time_criteria", None) is not None:
max_time, initial_time = generate_kwargs.pop("max_time_criteria")
lst.append(MaxTimeCriteria(max_time, initial_time))

return StoppingCriteriaList(lst)

def _get_logits_processors(
@@ -250,9 +255,12 @@ def _sanitize_parameters(
prefix=None,
handle_long_generation=None,
stop_sequence=None,
# New aviary arguments
return_token_type_ids=None,
stopping_sequences=None,
add_special_tokens=None,
timeout_s=None,
start_timestamp=None,
**generate_kwargs,
):
preprocess_params = {}
@@ -289,6 +297,10 @@

if stopping_sequences is not None:
generate_kwargs["stopping_sequences"] = stopping_sequences

if timeout_s is not None and start_timestamp is not None:
generate_kwargs["max_time_criteria"] = (timeout_s, start_timestamp)

forward_params = generate_kwargs

postprocess_params = {}
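
Here the new timeout_s/start_timestamp keyword arguments are folded into a max_time_criteria tuple, popped out of generate_kwargs, and turned into a transformers MaxTimeCriteria that rides along in the StoppingCriteriaList. The following is a hedged sketch of the same flow against a plain transformers model; the tiny sshleifer/tiny-gpt2 checkpoint is used only to keep the example quick and is not part of this commit.

import time
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    MaxTimeCriteria,
    StoppingCriteriaList,
)

def build_stopping_criteria(generate_kwargs: dict) -> StoppingCriteriaList:
    """Mirrors _get_stopping_criteria above: pop the aviary-specific tuple and
    convert it into a MaxTimeCriteria before calling generate()."""
    criteria = []
    if generate_kwargs.get("max_time_criteria") is not None:
        max_time, initial_time = generate_kwargs.pop("max_time_criteria")
        criteria.append(MaxTimeCriteria(max_time, initial_time))
    return StoppingCriteriaList(criteria)

if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

    start_timestamp = time.time()  # when the request/batch was created
    gen_kwargs = {"max_new_tokens": 200, "max_time_criteria": (2.0, start_timestamp)}
    stopping_criteria = build_stopping_criteria(gen_kwargs)

    inputs = tokenizer("Hello", return_tensors="pt")
    out = model.generate(**inputs, stopping_criteria=stopping_criteria, **gen_kwargs)
    print(tokenizer.decode(out[0]))

MaxTimeCriteria compares elapsed time against the initial_timestamp it was constructed with (falling back to its own construction time when that is None), so handing it the batch-creation timestamp makes queueing time count toward the deadline, consistent with the llama.cpp path above.
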
55 changes: 43 additions & 12 deletions aviary/backend/llm/predictor.py
@@ -179,13 +179,31 @@ def init_model(
max_batch_size=self.llm_config.generation.max_batch_size,
)

def generate(self, data: List[Prompt], **kwargs) -> List[str]:
def generate(
self,
data: List[Prompt],
*,
timeout_s: Optional[float] = None,
start_timestamp: Optional[float] = None,
**kwargs,
) -> List[str]:
"""Generate text from prompts.
Args:
data (List[Prompt]): List of prompts.
data (List[Prompt]): Batch of prompts.
timeout_s (Optional[float], optional): Timeout for the generation.
Ignored if start_timestamp is None.
start_timestamp (Optional[float], optional): Timestamp of when the
batch was created. Defaults to None. If set, will early stop
the generation. Ignored if timeout_s is None.
"""
return generate(data, self.generator, **kwargs)
return generate(
data,
self.generator,
timeout_s=timeout_s,
start_timestamp=start_timestamp,
**kwargs,
)

def __repr__(self) -> str:
return f"{self.__class__.__name__}:{self.llm_config.model_id}"
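
As the docstring spells out, timeout_s and start_timestamp only take effect as a pair; if either is None the generation falls back to its usual token limits. A small sketch of that contract follows; the predictor.generate call in the comment is illustrative and assumes an already initialized LLMPredictor.

import time
from typing import Optional, Tuple

def to_max_time_criteria(
    timeout_s: Optional[float], start_timestamp: Optional[float]
) -> Optional[Tuple[float, float]]:
    """Both values must be set, matching the guard in _sanitize_parameters above."""
    if timeout_s is not None and start_timestamp is not None:
        return (timeout_s, start_timestamp)
    return None

# Illustrative call site, assuming `predictor` is an initialized LLMPredictor:
# predictor.generate(prompts, timeout_s=55.0, start_timestamp=time.time())
print(to_max_time_criteria(55.0, time.time()))  # (55.0, <timestamp>)
print(to_max_time_criteria(55.0, None))         # None: the timeout is ignored
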
@@ -204,23 +222,23 @@ def __init__(self) -> None:
self._base_worker_group_lock = asyncio.Lock()
self._new_worker_group_lock = asyncio.Lock()

async def rollover(self, scaling_config: ScalingConfig, pg_timeout: float = 600):
async def rollover(self, scaling_config: ScalingConfig, pg_timeout_s: float = 600):
"""Roll over to a new worker group.
The new worker group is created asynchronously and the old worker group
is replaced with the new worker group once it is ready.
Args:
scaling_config (ScalingConfig): Scaling configuration for the new worker group.
pg_timeout (float, optional): Timeout for the new worker group to be ready. Defaults to 600.
pg_timeout_s (float, optional): Timeout for the new worker group to be ready. Defaults to 600.
"""
if self._new_worker_group_lock.locked():
logger.info("Rollover already in progress")
return
async with self._new_worker_group_lock:
logger.info(f"Initializing new worker group {scaling_config}")
self.new_worker_group = await self._create_worker_group(
scaling_config, pg_timeout=pg_timeout
scaling_config, pg_timeout_s=pg_timeout_s
)
async with self._base_worker_group_lock:
logger.info(f"Rolling over to new worker group {self.new_worker_group}")
@@ -229,13 +247,13 @@ async def rollover(self, scaling_config: ScalingConfig, pg_timeout: float = 600)
gc.collect()

async def _create_worker_group(
self, scaling_config: ScalingConfig, pg_timeout: float = 600
self, scaling_config: ScalingConfig, pg_timeout_s: float = 600
) -> List[ray.ObjectRef]:
"""Create a new worker group.
Args:
scaling_config (ScalingConfig): Scaling configuration for the new worker group.
pg_timeout (float, optional): Timeout for the new worker group to be ready. Defaults to 600.
pg_timeout_s (float, optional): Timeout for the new worker group to be ready. Defaults to 600.
"""
gc.collect()

@@ -262,7 +280,7 @@ async def _create_worker_group(

logger.info("Waiting for placement group to be ready...")
# This will raise a timeout error.
await asyncio.wait_for(self.pg.ready(), timeout=pg_timeout)
await asyncio.wait_for(self.pg.ready(), timeout=pg_timeout_s)

logger.info("Starting initialize_node tasks...")
await asyncio.gather(
@@ -303,11 +321,22 @@ async def _create_worker_group(

return worker_group

async def _predict_async(self, data: List[Prompt], **kwargs) -> List[str]:
async def _predict_async(
self,
prompts: List[Prompt],
*,
timeout_s: float = 60,
start_timestamp: Optional[float] = None,
) -> List[str]:
"""Generate text for a list of prompts.
Args:
data: A list of prompts.
prompts (List[Prompt]): Batch of prompts to generate text from.
timeout_s (float, optional): Timeout for the generation. Defaults
to 60. Ignored if start_timestamp is None.
start_timestamp (Optional[float], optional): Timestamp of when the
batch was created. Defaults to None. If set, will early stop
the generation.
Returns:
A list of generated texts.
@@ -317,7 +346,9 @@ async def _predict_async(self, data: List[Prompt], **kwargs) -> List[str]:
await asyncio.gather(
*[
worker.generate.remote(
data,
prompts,
timeout_s=timeout_s,
start_timestamp=start_timestamp,
**self.args.model_config.generation.all_generate_kwargs,
)
for worker in self.base_worker_group
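
_predict_async now forwards the per-request deadline to every worker in the group so each replica can cut its own generation short. Below is a minimal Ray sketch of that fan-out; the Worker actor is a hypothetical stand-in for the real prediction worker, which is not shown in this diff, and the config-driven all_generate_kwargs are omitted.

import asyncio
import time
from typing import List, Optional

import ray

@ray.remote
class Worker:
    """Hypothetical stand-in for the prediction worker actor."""

    def generate(
        self,
        prompts: List[str],
        *,
        timeout_s: Optional[float] = None,
        start_timestamp: Optional[float] = None,
    ) -> List[str]:
        # A real worker would fold these values into max_time_criteria for the
        # pipeline; here we just echo the prompts.
        return [f"completion for {p!r}" for p in prompts]

async def predict_async(
    workers,
    prompts: List[str],
    *,
    timeout_s: float = 60.0,
    start_timestamp: Optional[float] = None,
) -> List[str]:
    # Fan out to every worker with the same deadline; Ray ObjectRefs are
    # awaitable, so asyncio.gather can wait on the remote calls directly.
    results = await asyncio.gather(
        *[
            w.generate.remote(
                prompts, timeout_s=timeout_s, start_timestamp=start_timestamp
            )
            for w in workers
        ]
    )
    return results[0]

if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)
    workers = [Worker.remote() for _ in range(2)]
    out = asyncio.run(
        predict_async(workers, ["hello"], timeout_s=60.0, start_timestamp=time.time())
    )
    print(out)
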
108 changes: 71 additions & 37 deletions aviary/backend/server/app.py
@@ -1,4 +1,5 @@
import asyncio
import time
import traceback
from typing import Any, Dict, List, Optional, Union

@@ -19,13 +20,12 @@
DeepSpeed,
Prompt,
)
from aviary.common.constants import GATEWAY_TIMEOUT_S

logger = get_logger(__name__)

app = FastAPI()

GATEWAY_TIMEOUT_S = 90


@serve.deployment(
autoscaling_config={
@@ -96,7 +96,7 @@ async def reconfigure(
if should_reinit_worker_group:
await self.rollover(
self.args.air_scaling_config,
pg_timeout=self.args.scaling_config.pg_timeout,
pg_timeout_s=self.args.scaling_config.pg_timeout_s,
)
logger.info("Reconfigured.")

@@ -133,56 +133,90 @@ async def metadata(self) -> dict:
@app.post("/", include_in_schema=False)
async def generate_text(self, prompt: Prompt):
await self.validate_prompt(prompt)
text = await self.generate_text_batch(
prompt, priority=QueuePriority.GENERATE_TEXT
)
return text
start_timestamp = time.time()
with async_timeout.timeout(GATEWAY_TIMEOUT_S):
text = await self.generate_text_batch(
prompt,
priority=QueuePriority.GENERATE_TEXT,
start_timestamp=start_timestamp,
)
return text

@app.post("/batch", include_in_schema=False)
async def batch_generate_text(self, prompts: List[Prompt]):
for prompt in prompts:
await self.validate_prompt(prompt)
texts = await asyncio.gather(
*[
self.generate_text_batch(
prompt, priority=QueuePriority.BATCH_GENERATE_TEXT
)
for prompt in prompts
]
)
return texts
start_timestamp = time.time()
with async_timeout.timeout(GATEWAY_TIMEOUT_S):
texts = await asyncio.gather(
*[
self.generate_text_batch(
prompt,
priority=QueuePriority.BATCH_GENERATE_TEXT,
start_timestamp=start_timestamp,
)
for prompt in prompts
]
)
return texts

@batch(
max_batch_size=get_max_batch_size,
batch_wait_timeout_s=get_batch_wait_timeout_s,
batch_queue_cls=_PriorityBatchQueue,
)
async def generate_text_batch(self, prompts: List[Prompt]):
"""Generate text from the given prompts in batch"""
async def generate_text_batch(
self,
prompts: List[Prompt],
*,
start_timestamp: Optional[Union[float, List[float]]] = None,
timeout_s: Union[float, List[float]] = GATEWAY_TIMEOUT_S - 5,
):
"""Generate text from the given prompts in batch.
Args:
prompts (List[Prompt]): Batch of prompts to generate text from.
start_timestamp (Optional[float], optional): Timestamp of when the
batch was created. Defaults to None. If set, will early stop
the generation.
timeout_s (float, optional): Timeout for the generation. Defaults
to GATEWAY_TIMEOUT_S-5. Ignored if start_timestamp is None.
"""
if not prompts or prompts[0] is None:
return prompts
logger.info(f"Received {len(prompts)} prompts {prompts}")

if isinstance(start_timestamp, list):
start_timestamp = min(start_timestamp)
if isinstance(timeout_s, list):
timeout_s = min(timeout_s)

logger.info(
f"Received {len(prompts)} prompts {prompts}. start_timestamp {start_timestamp} timeout_s {timeout_s}"
)
data_ref = ray.put(prompts)

with async_timeout.timeout(GATEWAY_TIMEOUT_S):
while not self.base_worker_group:
logger.info("Waiting for worker group to be initialized...")
await asyncio.sleep(1)
while not self.base_worker_group:
logger.info("Waiting for worker group to be initialized...")
await asyncio.sleep(1)

try:
prediction = await self._predict_async(data_ref)
except RayActorError:
logger.warning(
f"Prediction failed due to RayActorError. "
f"Traceback:\n{traceback.print_exc()}"
)
await self.check_health()
prediction = await self._predict_async(data_ref)
try:
prediction = await self._predict_async(
data_ref, timeout_s=timeout_s, start_timestamp=start_timestamp
)
except RayActorError:
logger.warning(
f"Prediction failed due to RayActorError. "
f"Traceback:\n{traceback.print_exc()}"
)
await self.check_health()
prediction = await self._predict_async(
data_ref, timeout_s=timeout_s, start_timestamp=start_timestamp
)

logger.info(f"Predictions {prediction}")
if not isinstance(prediction, list):
return [prediction]
return prediction[: len(prompts)]
logger.info(f"Predictions {prediction}")
if not isinstance(prediction, list):
return [prediction]
return prediction[: len(prompts)]

# Called by Serve to check the replica's health.
async def check_health(self):
@@ -203,7 +237,7 @@ async def check_health(self):
self.base_worker_group = None
await self.rollover(
self.args.air_scaling_config,
pg_timeout=self.args.scaling_config.pg_timeout,
pg_timeout_s=self.args.scaling_config.pg_timeout_s,
)

def __repr__(self) -> str:
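
The endpoint changes wrap each request in an async_timeout guard of GATEWAY_TIMEOUT_S and hand the batcher a start timestamp plus a slightly smaller per-request budget (GATEWAY_TIMEOUT_S - 5), so generation stops and returns what it has before the HTTP gateway gives up. A condensed sketch of that flow follows, with a fake predictor standing in for _predict_async; the sketch uses async with, the form documented by current async-timeout releases.

import asyncio
import time
from typing import List, Optional, Union

import async_timeout  # third-party dependency already used by the commit

GATEWAY_TIMEOUT_S = 90  # old in-file value; the commit now imports it from aviary.common.constants

async def fake_predict(
    prompts: List[str], *, timeout_s: float, start_timestamp: float
) -> List[str]:
    """Stand-in for _predict_async: stops adding tokens once the per-request
    budget, measured from start_timestamp, is spent."""
    results = []
    for prompt in prompts:
        tokens = []
        while len(tokens) < 8 and time.time() - start_timestamp < timeout_s:
            await asyncio.sleep(0.05)  # pretend to decode one token
            tokens.append("tok")
        results.append(f"{prompt} -> {' '.join(tokens)}")
    return results

async def generate_text_batch(
    prompts: List[str],
    *,
    start_timestamp: Optional[Union[float, List[float]]] = None,
    timeout_s: Union[float, List[float]] = GATEWAY_TIMEOUT_S - 5,
) -> List[str]:
    # Serve's @batch decorator can hand over one value per batched request;
    # collapse each list to its tightest value, as the commit does with min().
    if isinstance(start_timestamp, list):
        start_timestamp = min(start_timestamp)
    if isinstance(timeout_s, list):
        timeout_s = min(timeout_s)
    if start_timestamp is None:
        start_timestamp = time.time()
    return await fake_predict(
        prompts, timeout_s=timeout_s, start_timestamp=start_timestamp
    )

async def generate_text(prompt: str) -> List[str]:
    # Outer guard: the gateway refuses to wait longer than GATEWAY_TIMEOUT_S,
    # while the inner budget of GATEWAY_TIMEOUT_S - 5 lets generation stop
    # early and return a partial result instead of hitting the hard timeout.
    start_timestamp = time.time()
    async with async_timeout.timeout(GATEWAY_TIMEOUT_S):
        return await generate_text_batch([prompt], start_timestamp=start_timestamp)

if __name__ == "__main__":
    print(asyncio.run(generate_text("hello")))

Putting the timeout at the endpoints means the whole round trip, queueing in the batcher plus generation, is bounded by one GATEWAY_TIMEOUT_S window instead of only the prediction step.
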
