forked from vllm-project/vllm
Commit
Add one example to run batch inference distributed on Ray (vllm-proje…
Showing 1 changed file with 70 additions and 0 deletions.
@@ -0,0 +1,70 @@
""" | ||
This example shows how to use Ray Data for running offline batch inference | ||
distributively on a multi-nodes cluster. | ||
Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html | ||
""" | ||
|
||
from vllm import LLM, SamplingParams | ||
from typing import Dict | ||
import numpy as np | ||
import ray | ||
|
||
# Create a sampling params object. | ||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) | ||
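# SamplingParams takes further options as well; for example, generation length
# can be capped with `max_tokens`. The value below is only an illustrative
# assumption, not part of the original example:
# sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)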
# Create a class to do batch inference.
class LLMPredictor:

    def __init__(self):
        # Create an LLM.
        self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        # Generate texts from the prompts.
        # The output is a list of RequestOutput objects that contain the
        # prompt, generated text, and other information.
        outputs = self.llm.generate(batch["text"], sampling_params)
        prompt = []
        generated_text = []
        for output in outputs:
            prompt.append(output.prompt)
            generated_text.append(' '.join([o.text for o in output.outputs]))
        return {
            "prompt": prompt,
            "generated_text": generated_text,
        }
# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (such as JSONL, Parquet, CSV, and binary formats).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
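# As a sketch of the same pattern for another format, a JSONL dataset could be
# read instead (the bucket path below is only a placeholder):
# ds = ray.data.read_json("s3://<your-input-bucket>/prompts.jsonl")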
# Apply batch inference for all input data.
ds = ds.map_batches(
    LLMPredictor,
    # Set the concurrency to the number of LLM instances.
    concurrency=10,
    # Specify the number of GPUs required per LLM instance.
    # NOTE: Do NOT set `num_gpus` when using vLLM with tensor-parallelism
    # (i.e., `tensor_parallel_size`).
    num_gpus=1,
    # Specify the batch size for inference.
    batch_size=32,
)
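# If tensor parallelism were wanted instead, one sketch (assuming 2 GPUs per
# replica, which is an illustrative choice) would be to pass
# `tensor_parallel_size` to the LLM constructor above and let vLLM manage its
# own workers, leaving `num_gpus` unset per the NOTE:
#
#   self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
#                  tensor_parallel_size=2)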
# Peek at the first 10 results.
# NOTE: This is for local testing and debugging. For a production use case,
# the full result should be written out as shown below.
outputs = ds.take(limit=10)
for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Write inference output data out as Parquet files to S3.
# Multiple files will be written to the output destination,
# and each task writes one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
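# To spot-check the written results afterwards, they could be read back with
# Ray Data (again, the bucket path is a placeholder):
#
# ray.data.read_parquet("s3://<your-output-bucket>").show(limit=5)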