From 8cac35ba435906fb7eb07e44fe1a8c26e8744f4e Mon Sep 17 00:00:00 2001
From: Richard Liaw
Date: Wed, 16 Apr 2025 22:19:26 -0700
Subject: [PATCH] [Ray] Improve documentation on batch inference (#16609)

Signed-off-by: Richard Liaw
---
 .../offline_inference/batch_llm_inference.py |  90 +++++++++++++++
 examples/offline_inference/distributed.py    | 109 ------------------
 2 files changed, 90 insertions(+), 109 deletions(-)
 create mode 100644 examples/offline_inference/batch_llm_inference.py
 delete mode 100644 examples/offline_inference/distributed.py

diff --git a/examples/offline_inference/batch_llm_inference.py b/examples/offline_inference/batch_llm_inference.py
new file mode 100644
index 00000000..6548857b
--- /dev/null
+++ b/examples/offline_inference/batch_llm_inference.py
@@ -0,0 +1,90 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+This example shows how to use Ray Data for data parallel batch inference.
+
+Ray Data is a data processing framework that can handle large datasets
+and integrates tightly with vLLM for data-parallel inference.
+
+As of Ray 2.44, Ray Data has a native integration with
+vLLM (under ray.data.llm).
+
+Ray Data provides functionality for:
+* Reading and writing to cloud storage (S3, GCS, etc.)
+* Automatic sharding and load-balancing across a cluster
+* Optimized configuration of vLLM using continuous batching
+* Compatible with tensor/pipeline parallel inference as well.
+
+Learn more about Ray Data's LLM integration:
+https://docs.ray.io/en/latest/data/working-with-llms.html
+"""
+import ray
+from packaging.version import Version
+from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig
+
+assert Version(ray.__version__) >= Version(
+    "2.44.1"), "Ray version must be at least 2.44.1"
+
+# Uncomment to reduce clutter in stdout
+# ray.init(log_to_driver=False)
+# ray.data.DataContext.get_current().enable_progress_bars = False
+
+# Read one text file from S3. Ray Data supports reading multiple files
+# from cloud storage (such as JSONL, Parquet, CSV, binary format).
+ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
+print(ds.schema())
+
+size = ds.count()
+print(f"Size of dataset: {size} prompts")
+
+# Configure vLLM engine.
+config = vLLMEngineProcessorConfig(
+    model_source="unsloth/Llama-3.1-8B-Instruct",
+    engine_kwargs={
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4096,
+        "max_model_len": 16384,
+    },
+    concurrency=1,  # set the number of parallel vLLM replicas
+    batch_size=64,
+)
+
+# Create a Processor object, which will be used to
+# do batch inference on the dataset
+vllm_processor = build_llm_processor(
+    config,
+    preprocess=lambda row: dict(
+        messages=[{
+            "role": "system",
+            "content": "You are a bot that responds with haikus."
+        }, {
+            "role": "user",
+            "content": row["text"]
+        }],
+        sampling_params=dict(
+            temperature=0.3,
+            max_tokens=250,
+        )),
+    postprocess=lambda row: dict(
+        answer=row["generated_text"],
+        **row  # This will return all the original columns in the dataset.
+    ),
+)
+
+ds = vllm_processor(ds)
+
+# Peek first 10 results.
+# NOTE: This is for local testing and debugging. For production use case,
+# one should write full result out as shown below.
+outputs = ds.take(limit=10)
+
+for output in outputs:
+    prompt = output["prompt"]
+    generated_text = output["generated_text"]
+    print(f"Prompt: {prompt!r}")
+    print(f"Generated text: {generated_text!r}")
+
+# Write inference output data out as Parquet files to S3.
+# Multiple files would be written to the output destination,
+# and each task would write one or more files separately.
+#
+# ds.write_parquet("s3://")
diff --git a/examples/offline_inference/distributed.py b/examples/offline_inference/distributed.py
deleted file mode 100644
index e890c6da..00000000
--- a/examples/offline_inference/distributed.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-"""
-This example shows how to use Ray Data for running offline batch inference
-distributively on a multi-nodes cluster.
-
-Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
-"""
-
-from typing import Any
-
-import numpy as np
-import ray
-from packaging.version import Version
-from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
-
-from vllm import LLM, SamplingParams
-
-assert Version(ray.__version__) >= Version(
-    "2.22.0"), "Ray version must be at least 2.22.0"
-
-# Create a sampling params object.
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
-
-# Set tensor parallelism per instance.
-tensor_parallel_size = 1
-
-# Set number of instances. Each instance will use tensor_parallel_size GPUs.
-num_instances = 1
-
-
-# Create a class to do batch inference.
-class LLMPredictor:
-
-    def __init__(self):
-        # Create an LLM.
-        self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
-                       tensor_parallel_size=tensor_parallel_size)
-
-    def __call__(self, batch: dict[str, np.ndarray]) -> dict[str, list]:
-        # Generate texts from the prompts.
-        # The output is a list of RequestOutput objects that contain the prompt,
-        # generated text, and other information.
-        outputs = self.llm.generate(batch["text"], sampling_params)
-        prompt: list[str] = []
-        generated_text: list[str] = []
-        for output in outputs:
-            prompt.append(output.prompt)
-            generated_text.append(' '.join([o.text for o in output.outputs]))
-        return {
-            "prompt": prompt,
-            "generated_text": generated_text,
-        }
-
-
-# Read one text file from S3. Ray Data supports reading multiple files
-# from cloud storage (such as JSONL, Parquet, CSV, binary format).
-ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
-
-
-# For tensor_parallel_size > 1, we need to create placement groups for vLLM
-# to use. Every actor has to have its own placement group.
-def scheduling_strategy_fn():
-    # One bundle per tensor parallel worker
-    pg = ray.util.placement_group(
-        [{
-            "GPU": 1,
-            "CPU": 1
-        }] * tensor_parallel_size,
-        strategy="STRICT_PACK",
-    )
-    return dict(scheduling_strategy=PlacementGroupSchedulingStrategy(
-        pg, placement_group_capture_child_tasks=True))
-
-
-resources_kwarg: dict[str, Any] = {}
-if tensor_parallel_size == 1:
-    # For tensor_parallel_size == 1, we simply set num_gpus=1.
-    resources_kwarg["num_gpus"] = 1
-else:
-    # Otherwise, we have to set num_gpus=0 and provide
-    # a function that will create a placement group for
-    # each instance.
-    resources_kwarg["num_gpus"] = 0
-    resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn
-
-# Apply batch inference for all input data.
-ds = ds.map_batches(
-    LLMPredictor,
-    # Set the concurrency to the number of LLM instances.
-    concurrency=num_instances,
-    # Specify the batch size for inference.
-    batch_size=32,
-    **resources_kwarg,
-)
-
-# Peek first 10 results.
-# NOTE: This is for local testing and debugging. For production use case,
-# one should write full result out as shown below.
-outputs = ds.take(limit=10)
-for output in outputs:
-    prompt = output["prompt"]
-    generated_text = output["generated_text"]
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-
-# Write inference output data out as Parquet files to S3.
-# Multiple files would be written to the output destination,
-# and each task would write one or more files separately.
-#
-# ds.write_parquet("s3://")
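The docstring in the new example notes that the ray.data.llm integration is also compatible with tensor/pipeline-parallel inference, but the committed script only runs single-GPU replicas (concurrency=1 and no parallelism settings in engine_kwargs). Below is a minimal sketch, not part of the patch above, of how a tensor-parallel variant might be configured. It assumes a Ray cluster with at least four GPUs and a Ray version that, as in the example, passes engine_kwargs such as tensor_parallel_size straight through to the vLLM engine; treat the model choice and the specific sizes as illustrative assumptions rather than a tested configuration.

# Sketch: data-parallel batch inference with tensor-parallel vLLM replicas.
# Assumes 2 GPUs per replica and 2 replicas (4 GPUs total) are available.
import ray
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "tensor_parallel_size": 2,  # shard each replica's weights over 2 GPUs
        "max_model_len": 16384,
    },
    concurrency=2,  # two data-parallel replicas of the 2-GPU engine
    batch_size=64,
)

processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[{"role": "user", "content": row["text"]}],
        sampling_params=dict(temperature=0.3, max_tokens=250),
    ),
    postprocess=lambda row: dict(answer=row["generated_text"], **row),
)

ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
ds = processor(ds)
ds.show(3)  # peek a few rows; write out with ds.write_parquet(...) in production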