Add one example to run batch inference distributed on Ray (vllm-proje…
c21 authored Feb 2, 2024
1 parent 0e163fc commit 4abf633
Showing 1 changed file with 70 additions and 0 deletions.
examples/offline_inference_distributed.py
@@ -0,0 +1,70 @@
"""
This example shows how to use Ray Data to run offline batch inference
distributed across a multi-node cluster.
Learn more about Ray Data at https://docs.ray.io/en/latest/data/data.html
"""

from typing import Dict

import numpy as np
import ray

from vllm import LLM, SamplingParams

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
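# (Editorial note, not part of this commit: other SamplingParams fields can
# also be set here if needed, e.g.
# SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256).)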


# Create a class to do batch inference.
class LLMPredictor:

    def __init__(self):
        # Create an LLM.
        self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")
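        # (Editorial note, not part of this commit: the model name above is
        # the one used by the example; a small model such as
        # "facebook/opt-125m" keeps local testing cheap, and
        # `tensor_parallel_size` can be passed to LLM() for models that need
        # multiple GPUs.)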

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        # Generate texts from the prompts.
        # The output is a list of RequestOutput objects that contain the
        # prompt, generated text, and other information.
        outputs = self.llm.generate(batch["text"], sampling_params)
        prompt = []
        generated_text = []
        for output in outputs:
            prompt.append(output.prompt)
            generated_text.append(' '.join([o.text for o in output.outputs]))
        return {
            "prompt": prompt,
            "generated_text": generated_text,
        }


# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage in other formats as well (e.g., JSONL, Parquet, CSV,
# and binary).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
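# (Editorial note, not part of this commit: other Ray Data readers could be
# used instead, depending on the input format, e.g.:
# ds = ray.data.read_json("s3://<your-bucket>/prompts.jsonl")
# ds = ray.data.read_parquet("s3://<your-bucket>/prompts.parquet")
# The bucket paths above are placeholders.)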

# Apply batch inference for all input data.
ds = ds.map_batches(
    LLMPredictor,
    # Set the concurrency to the number of LLM instances.
    concurrency=10,
    # Specify the number of GPUs required per LLM instance.
    # NOTE: Do NOT set `num_gpus` when using vLLM with tensor-parallelism
    # (i.e., `tensor_parallel_size`).
    num_gpus=1,
    # Specify the batch size for inference.
    batch_size=32,
)
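# (Editorial note, not part of this commit: Ray Data executes lazily, so the
# map_batches() transformation above only runs once results are consumed,
# e.g. by the take() or write_parquet() calls below.)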

# Peek at the first 10 results.
# NOTE: This is for local testing and debugging. For production use cases,
# write the full results out as shown below.
outputs = ds.take(limit=10)
for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
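#
# (Editorial note, not part of this commit: once written, the Parquet output
# could be read back with Ray Data for downstream processing, e.g.:
#
# results = ray.data.read_parquet("s3://<your-output-bucket>")
# print(results.schema())
# The bucket path is the same placeholder as above.)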
