diff --git a/examples/offline_inference_distributed.py b/examples/offline_inference_distributed.py
new file mode 100644
index 0000000000000..0897045fd94ae
--- /dev/null
+++ b/examples/offline_inference_distributed.py
@@ -0,0 +1,70 @@
+"""
+This example shows how to use Ray Data for running offline batch inference
+in a distributed fashion on a multi-node cluster.
+
+Learn more about Ray Data at https://docs.ray.io/en/latest/data/data.html
+"""
+
+from vllm import LLM, SamplingParams
+from typing import Dict
+import numpy as np
+import ray
+
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+
+# Create a class to do batch inference.
+class LLMPredictor:
+
+    def __init__(self):
+        # Create an LLM.
+        self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")
+
+    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
+        # Generate texts from the prompts.
+        # The output is a list of RequestOutput objects that contain the
+        # prompt, generated text, and other information.
+        outputs = self.llm.generate(batch["text"], sampling_params)
+        prompt = []
+        generated_text = []
+        for output in outputs:
+            prompt.append(output.prompt)
+            generated_text.append(' '.join([o.text for o in output.outputs]))
+        return {
+            "prompt": prompt,
+            "generated_text": generated_text,
+        }
+
+
+# Read one text file from S3. Ray Data supports reading multiple files
+# from cloud storage (in formats such as JSONL, Parquet, CSV, and binary).
+ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
+
+# Apply batch inference for all input data.
+ds = ds.map_batches(
+    LLMPredictor,
+    # Set the concurrency to the number of LLM instances.
+    concurrency=10,
+    # Specify the number of GPUs required per LLM instance.
+    # NOTE: Do NOT set `num_gpus` when using vLLM with tensor-parallelism
+    # (i.e., `tensor_parallel_size`).
+    num_gpus=1,
+    # Specify the batch size for inference.
+    batch_size=32,
+)
+
+# Peek at the first 10 results.
+# NOTE: This is for local testing and debugging. For a production use case,
+# write the full results out as shown below.
+outputs = ds.take(limit=10)
+for output in outputs:
+    prompt = output["prompt"]
+    generated_text = output["generated_text"]
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+# Write inference output data as Parquet files to S3.
+# Multiple files will be written to the output destination,
+# and each task will write one or more files separately.
+#
+# ds.write_parquet("s3://")
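
Not part of the diff above, but to expand on the `num_gpus` note: when each LLM instance uses tensor parallelism (`tensor_parallel_size > 1`), the usual pattern is to reserve one GPU bundle per tensor-parallel worker through a Ray placement group rather than setting `num_gpus`. The sketch below is illustrative only; it assumes a Ray version whose `Dataset.map_batches` accepts `ray_remote_args_fn`, and `tensor_parallel_size` / `scheduling_strategy_fn` are placeholder names, not part of this PR.

# Illustrative sketch only (not part of this PR). Assumes Ray Data's
# map_batches supports ray_remote_args_fn; names below are placeholders.
import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

tensor_parallel_size = 2  # must match LLM(..., tensor_parallel_size=...)


def scheduling_strategy_fn():
    # Reserve one GPU (plus a CPU) per tensor-parallel worker, packed onto a
    # single node, and run the LLM actor and its workers inside that group.
    pg = ray.util.placement_group(
        [{"GPU": 1, "CPU": 1}] * tensor_parallel_size,
        strategy="STRICT_PACK",
    )
    return dict(scheduling_strategy=PlacementGroupSchedulingStrategy(
        pg, placement_group_capture_child_tasks=True))


# With tensor parallelism, leave num_gpus unset (as the NOTE in the example
# says) and let each LLM instance acquire its GPUs via the placement group:
# ds = ds.map_batches(
#     LLMPredictor,
#     concurrency=10,
#     batch_size=32,
#     ray_remote_args_fn=scheduling_strategy_fn,
# )

STRICT_PACK keeps all tensor-parallel workers of one instance on the same node; PACK would be the analogous choice if an instance were allowed to span nodes.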