Add one example to run batch inference distributed on Ray (#2696)
This commit is contained in:
parent
0e163fce18
commit
4abf6336ec
70
examples/offline_inference_distributed.py
Normal file
70
examples/offline_inference_distributed.py
Normal file
@ -0,0 +1,70 @@
|
||||
"""
|
||||
This example shows how to use Ray Data for running offline batch inference
|
||||
distributively on a multi-nodes cluster.
|
||||
|
||||
Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
|
||||
"""
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from typing import Dict
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
|
||||
# Create a class to do batch inference.
|
||||
class LLMPredictor:
|
||||
|
||||
def __init__(self):
|
||||
# Create an LLM.
|
||||
self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")
|
||||
|
||||
def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
|
||||
# Generate texts from the prompts.
|
||||
# The output is a list of RequestOutput objects that contain the prompt,
|
||||
# generated text, and other information.
|
||||
outputs = self.llm.generate(batch["text"], sampling_params)
|
||||
prompt = []
|
||||
generated_text = []
|
||||
for output in outputs:
|
||||
prompt.append(output.prompt)
|
||||
generated_text.append(' '.join([o.text for o in output.outputs]))
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"generated_text": generated_text,
|
||||
}
|
||||
|
||||
|
||||
# Read one text file from S3. Ray Data supports reading multiple files
|
||||
# from cloud storage (such as JSONL, Parquet, CSV, binary format).
|
||||
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
|
||||
|
||||
# Apply batch inference for all input data.
|
||||
ds = ds.map_batches(
|
||||
LLMPredictor,
|
||||
# Set the concurrency to the number of LLM instances.
|
||||
concurrency=10,
|
||||
# Specify the number of GPUs required per LLM instance.
|
||||
# NOTE: Do NOT set `num_gpus` when using vLLM with tensor-parallelism
|
||||
# (i.e., `tensor_parallel_size`).
|
||||
num_gpus=1,
|
||||
# Specify the batch size for inference.
|
||||
batch_size=32,
|
||||
)
|
||||
|
||||
# Peek first 10 results.
|
||||
# NOTE: This is for local testing and debugging. For production use case,
|
||||
# one should write full result out as shown below.
|
||||
outputs = ds.take(limit=10)
|
||||
for output in outputs:
|
||||
prompt = output["prompt"]
|
||||
generated_text = output["generated_text"]
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
# Write inference output data out as Parquet files to S3.
|
||||
# Multiple files would be written to the output destination,
|
||||
# and each task would write one or more files separately.
|
||||
#
|
||||
# ds.write_parquet("s3://<your-output-bucket>")
|
Loading…
x
Reference in New Issue
Block a user