[Ray] Improve documentation on batch inference (#16609)
Signed-off-by: Richard Liaw <rliaw@berkeley.edu>
parent 9dbf7a2dc1
commit 8cac35ba43
90
examples/offline_inference/batch_llm_inference.py
Normal file
@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use Ray Data for data-parallel batch inference.

Ray Data is a data processing framework that can handle large datasets
and integrates tightly with vLLM for data-parallel inference.

As of Ray 2.44, Ray Data has a native integration with
vLLM (under ray.data.llm).

Ray Data provides functionality for:
* Reading and writing to cloud storage (S3, GCS, etc.)
* Automatic sharding and load balancing across a cluster
* Optimized configuration of vLLM using continuous batching
* Compatibility with tensor/pipeline parallel inference

Learn more about Ray Data's LLM integration:
https://docs.ray.io/en/latest/data/working-with-llms.html
"""
import ray
from packaging.version import Version
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

assert Version(ray.__version__) >= Version(
    "2.44.1"), "Ray version must be at least 2.44.1"

# Uncomment to reduce clutter in stdout
# ray.init(log_to_driver=False)
# ray.data.DataContext.get_current().enable_progress_bars = False
# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (such as JSONL, Parquet, CSV, and binary formats).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
print(ds.schema())

size = ds.count()
print(f"Size of dataset: {size} prompts")
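# Each row produced by read_text carries a single "text" column, which the
# preprocessor below turns into a chat-style prompt.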
# Configure vLLM engine.
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4096,
        "max_model_len": 16384,
    },
    concurrency=1,  # set the number of parallel vLLM replicas
    batch_size=64,
)
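# NOTE: engine_kwargs are forwarded to the underlying vLLM engine, so a model
# that needs multiple GPUs could, for instance, add "tensor_parallel_size"
# there; the values above are a reasonable starting point, not a tuned
# configuration.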
# Create a Processor object, which will be used to
# do batch inference on the dataset.
vllm_processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[{
            "role": "system",
            "content": "You are a bot that responds with haikus."
        }, {
            "role": "user",
            "content": row["text"]
        }],
        sampling_params=dict(
            temperature=0.3,
            max_tokens=250,
        )),
    postprocess=lambda row: dict(
        answer=row["generated_text"],
        **row  # This will return all the original columns in the dataset.
    ),
)
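# The preprocess hook maps each input row to the fields the processor
# consumes ("messages" and "sampling_params"); the postprocess hook then sees
# the output rows, including the "generated_text" column added by the engine.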
ds = vllm_processor(ds)
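# Like other Ray Data transformations, applying the processor above is lazy;
# work is typically only triggered when results are consumed, e.g. by the
# take() and write_parquet() calls below.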
# Peek at the first 10 results.
# NOTE: This is for local testing and debugging. For production use cases,
# the full result should be written out as shown below.
outputs = ds.take(limit=10)

for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}")
    print(f"Generated text: {generated_text!r}")
# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
@@ -1,109 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use Ray Data to run offline batch inference
distributed across a multi-node cluster.

Learn more about Ray Data at https://docs.ray.io/en/latest/data/data.html
"""
from typing import Any

import numpy as np
import ray
from packaging.version import Version
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from vllm import LLM, SamplingParams

assert Version(ray.__version__) >= Version(
    "2.22.0"), "Ray version must be at least 2.22.0"

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Set tensor parallelism per instance.
tensor_parallel_size = 1

# Set number of instances. Each instance will use tensor_parallel_size GPUs.
num_instances = 1
# Create a class to do batch inference.
class LLMPredictor:

    def __init__(self):
        # Create an LLM.
        self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
                       tensor_parallel_size=tensor_parallel_size)

    def __call__(self, batch: dict[str, np.ndarray]) -> dict[str, list]:
        # Generate texts from the prompts.
        # The output is a list of RequestOutput objects that contain the
        # prompt, generated text, and other information.
        outputs = self.llm.generate(batch["text"], sampling_params)
        prompt: list[str] = []
        generated_text: list[str] = []
        for output in outputs:
            prompt.append(output.prompt)
            generated_text.append(' '.join([o.text for o in output.outputs]))
        return {
            "prompt": prompt,
            "generated_text": generated_text,
        }
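# NOTE: Because a class (rather than a function) is passed to map_batches
# below, Ray Data runs it in long-lived actor processes, so each replica
# loads the model once in __init__ and reuses it for every batch.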
# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (such as JSONL, Parquet, CSV, and binary formats).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
# For tensor_parallel_size > 1, we need to create placement groups for vLLM
# to use. Every actor has to have its own placement group.
def scheduling_strategy_fn():
    # One bundle per tensor parallel worker
    pg = ray.util.placement_group(
        [{
            "GPU": 1,
            "CPU": 1
        }] * tensor_parallel_size,
        strategy="STRICT_PACK",
    )
    return dict(scheduling_strategy=PlacementGroupSchedulingStrategy(
        pg, placement_group_capture_child_tasks=True))
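# The STRICT_PACK strategy keeps all bundles of a placement group on one
# node, so the tensor-parallel workers of a replica are co-located, and
# placement_group_capture_child_tasks=True schedules the workers vLLM spawns
# into that same group.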
resources_kwarg: dict[str, Any] = {}
if tensor_parallel_size == 1:
    # For tensor_parallel_size == 1, we simply set num_gpus=1.
    resources_kwarg["num_gpus"] = 1
else:
    # Otherwise, we have to set num_gpus=0 and provide
    # a function that will create a placement group for
    # each instance.
    resources_kwarg["num_gpus"] = 0
    resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn
# Apply batch inference for all input data.
ds = ds.map_batches(
    LLMPredictor,
    # Set the concurrency to the number of LLM instances.
    concurrency=num_instances,
    # Specify the batch size for inference.
    batch_size=32,
    **resources_kwarg,
)
# Peek at the first 10 results.
# NOTE: This is for local testing and debugging. For production use cases,
# the full result should be written out as shown below.
outputs = ds.take(limit=10)
for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")