# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use Ray Data for data parallel batch inference.

Ray Data is a data processing framework that can handle large datasets
and integrates tightly with vLLM for data-parallel inference.

As of Ray 2.44, Ray Data has a native integration with
vLLM (under ray.data.llm).

Ray Data provides functionality for:
* Reading and writing to cloud storage (S3, GCS, etc.)
* Automatic sharding and load-balancing across a cluster
* Optimized configuration of vLLM using continuous batching
* Compatibility with tensor/pipeline parallel inference

Learn more about Ray Data's LLM integration:
https://docs.ray.io/en/latest/data/working-with-llms.html
"""

import ray
from packaging.version import Version
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

assert Version(ray.__version__) >= Version(
    "2.44.1"), "Ray version must be at least 2.44.1"

# Uncomment to reduce clutter in stdout
# ray.init(log_to_driver=False)
# ray.data.DataContext.get_current().enable_progress_bars = False

# Read one text file from S3. Ray Data also supports reading multiple files
# from cloud storage in formats such as JSONL, Parquet, CSV, and binary.
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
print(ds.schema())

size = ds.count()
print(f"Size of dataset: {size} prompts")
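
# A minimal sketch (not part of the original example): other Ray Data readers
# can be swapped in for other input formats. The bucket paths below are
# hypothetical placeholders.
# ds = ray.data.read_parquet("s3://<your-input-bucket>/<prefix>/")
# ds = ray.data.read_json("s3://<your-input-bucket>/prompts.jsonl")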

# Configure vLLM engine.
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4096,
        "max_model_len": 16384,
    },
    concurrency=1,  # set the number of parallel vLLM replicas
    batch_size=64,
)
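
# A minimal sketch (not part of the original example): engine_kwargs are
# passed through to the vLLM engine, so tensor-parallel inference can be
# requested the same way, assuming each replica has enough GPUs. The values
# below are illustrative.
# config = vLLMEngineProcessorConfig(
#     model_source="unsloth/Llama-3.1-8B-Instruct",
#     engine_kwargs={"tensor_parallel_size": 2, "max_model_len": 16384},
#     concurrency=1,
#     batch_size=64,
# )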

# Create a Processor object, which will be used to
# do batch inference on the dataset.
vllm_processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[{
            "role": "system",
            "content": "You are a bot that responds with haikus."
        }, {
            "role": "user",
            "content": row["text"]
        }],
        sampling_params=dict(
            temperature=0.3,
            max_tokens=250,
        )),
    postprocess=lambda row: dict(
        answer=row["generated_text"],
        **row,  # This will return all the original columns in the dataset.
    ),
)

ds = vllm_processor(ds)
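# Note: Ray Data pipelines execute lazily, so the line above only builds the
# pipeline; inference actually runs when results are consumed, e.g. by take()
# or write_parquet() below.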

# Peek at the first 10 results.
# NOTE: This is for local testing and debugging. For production use cases,
# write the full results out as shown below.
outputs = ds.take(limit=10)

for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}")
    print(f"Generated text: {generated_text!r}")

# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
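#
# As an alternative sketch (not in the original example), results could also
# be written out as JSON lines with Ray Data's JSON writer; the bucket name
# is a placeholder:
# ds.write_json("s3://<your-output-bucket>")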