""" This example shows how to use Ray Data for running offline batch inference distributively on a multi-nodes cluster. Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html """ from vllm import LLM, SamplingParams from typing import Dict import numpy as np import ray # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create a class to do batch inference. class LLMPredictor: def __init__(self): # Create an LLM. self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf") def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]: # Generate texts from the prompts. # The output is a list of RequestOutput objects that contain the prompt, # generated text, and other information. outputs = self.llm.generate(batch["text"], sampling_params) prompt = [] generated_text = [] for output in outputs: prompt.append(output.prompt) generated_text.append(' '.join([o.text for o in output.outputs])) return { "prompt": prompt, "generated_text": generated_text, } # Read one text file from S3. Ray Data supports reading multiple files # from cloud storage (such as JSONL, Parquet, CSV, binary format). ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt") # Apply batch inference for all input data. ds = ds.map_batches( LLMPredictor, # Set the concurrency to the number of LLM instances. concurrency=10, # Specify the number of GPUs required per LLM instance. # NOTE: Do NOT set `num_gpus` when using vLLM with tensor-parallelism # (i.e., `tensor_parallel_size`). num_gpus=1, # Specify the batch size for inference. batch_size=32, ) # Peek first 10 results. # NOTE: This is for local testing and debugging. For production use case, # one should write full result out as shown below. 
# Pull the first 10 rows back to the driver and show them. This is only
# for local testing/debugging; production jobs should persist the full
# dataset instead (see the commented-out write below).
outputs = ds.take(limit=10)
for output in outputs:
    prompt, generated_text = output["prompt"], output["generated_text"]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://")