diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index 930b34a0..fbb09331 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -30,7 +30,6 @@ def main(args: argparse.Namespace):
         max_tokens=args.output_len,
     )
     print(sampling_params)
-    dummy_prompts = [""] * args.batch_size
     dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
 
     def run_to_completion(profile: bool = False):
@@ -38,7 +37,8 @@ def main(args: argparse.Namespace):
             torch.cuda.cudart().cudaProfilerStart()
         start_time = time.time()
 
-        llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
+        llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                     sampling_params=sampling_params,
                      use_tqdm=False)
         end_time = time.time()
 
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 4b4e94ad..b2cdda2b 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -72,9 +72,9 @@ def main(args: argparse.Namespace):
         )
         # FIXME(woosuk): Do not use internal method.
         llm._add_request(
-            prompt="",
-            sampling_params=sampling_params,
+            prompt=None,
             prompt_token_ids=prompt_token_ids,
+            sampling_params=sampling_params,
         )
 
     start = time.time()
@@ -85,7 +85,9 @@ def main(args: argparse.Namespace):
         len(prompt_token_ids) + output_len
         for prompt_token_ids, output_len in requests
     )
-    print(f"Throughput: {total_num_tokens / (end - start):.2f} tokens/s")
+    elapsed_time = end - start
+    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
+          f"{total_num_tokens / elapsed_time:.2f} tokens/s")
 
 
 if __name__ == "__main__":
diff --git a/cacheflow/entrypoints/llm.py b/cacheflow/entrypoints/llm.py
index 7f2fdd7b..75a92fa9 100644
--- a/cacheflow/entrypoints/llm.py
+++ b/cacheflow/entrypoints/llm.py
@@ -11,6 +11,28 @@ from cacheflow.utils import Counter
 
 
 class LLM:
+    """An LLM for generating texts from given prompts and sampling parameters.
+
+    This class includes a tokenizer, a language model (possibly distributed
+    across multiple GPUs), and GPU memory space allocated for intermediate
+    states (aka KV cache). Given a batch of prompts and sampling parameters,
+    this class generates texts from the model, using an intelligent batching
+    mechanism and efficient memory management.
+
+    NOTE: This class is intended to be used for offline inference. For online
+    serving, use the `AsyncLLMServer` class instead.
+    NOTE: For the comprehensive list of arguments, see `ServerArgs`.
+
+    Args:
+        model: The name or path of a HuggingFace Transformers model.
+        tensor_parallel_size: The number of GPUs to use for distributed
+            execution with tensor parallelism.
+        dtype: The data type for the model weights and activations. Currently,
+            we support `float16` and `bfloat16`. If `default`, we use the
+            `torch_dtype` attribute of the model config. If the `torch_dtype`
+            is `float32`, we use `float16` instead.
+        seed: The seed to initialize the random number generator for sampling.
+    """
 
     def __init__(
         self,
@@ -39,19 +61,50 @@ class LLM:
 
     def generate(
         self,
-        prompts: Union[str, List[str]],
+        prompts: Optional[Union[str, List[str]]] = None,
         sampling_params: Optional[SamplingParams] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
     ) -> List[RequestOutput]:
+        """Generates the completions for the input prompts.
+
+        NOTE: This class automatically batches the given prompts, considering
+        the memory constraint. For the best performance, put all of your prompts
+        into a single list and pass it to this method.
+
+        Args:
+            prompts: A list of prompts to generate completions for.
+            sampling_params: The sampling parameters for text generation. If
+                None, we use the default sampling parameters.
+            prompt_token_ids: A list of token IDs for the prompts. If None, we
+                use the tokenizer to convert the prompts to token IDs.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+
+        Returns:
+            A list of `RequestOutput` objects containing the generated
+            completions in the same order as the input prompts.
+        """
+        if prompts is None and prompt_token_ids is None:
+            raise ValueError("Either prompts or prompt_token_ids must be "
+                             "provided.")
         if isinstance(prompts, str):
+            # Convert a single prompt to a list.
             prompts = [prompts]
+        if prompts is not None and prompt_token_ids is not None:
+            if len(prompts) != len(prompt_token_ids):
+                raise ValueError("The lengths of prompts and prompt_token_ids "
+                                 "must be the same.")
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()
+
         # Add requests to the server.
-        for i in range(len(prompts)):
-            prompt = prompts[i]
+        if prompts is not None:
+            num_requests = len(prompts)
+        else:
+            num_requests = len(prompt_token_ids)
+        for i in range(num_requests):
+            prompt = prompts[i] if prompts is not None else None
             if prompt_token_ids is None:
                 token_ids = None
             else:
@@ -61,7 +114,7 @@ class LLM:
 
     def _add_request(
         self,
-        prompt: str,
+        prompt: Optional[str],
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]],
     ) -> None:
diff --git a/cacheflow/server/llm_server.py b/cacheflow/server/llm_server.py
index 2397f7d4..0032a768 100644
--- a/cacheflow/server/llm_server.py
+++ b/cacheflow/server/llm_server.py
@@ -126,7 +126,7 @@ class LLMServer:
     def add_request(
         self,
         request_id: str,
-        prompt: str,
+        prompt: Optional[str],
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
@@ -134,6 +134,7 @@
         if arrival_time is None:
             arrival_time = time.time()
         if prompt_token_ids is None:
+            assert prompt is not None
             prompt_token_ids = self.tokenizer.encode(prompt)
 
         # Create the sequences.
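
Reviewer note (not part of the diff): a minimal usage sketch of the `LLM.generate()` API after this change. Only `LLM`, `SamplingParams`, the `model` argument, `max_tokens`, and the `prompts`/`prompt_token_ids`/`sampling_params` keywords appear in the code above; the import path, the model name, and the token IDs are assumptions chosen for illustration.

    # Hypothetical usage sketch; the import path and model name are assumptions.
    from cacheflow import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    sampling_params = SamplingParams(max_tokens=16)

    # Plain-text prompts: the tokenizer encodes them internally.
    outputs = llm.generate(["Hello, my name is"], sampling_params)

    # Pre-tokenized prompts: pass prompt_token_ids and omit prompts entirely,
    # which is what the updated benchmark scripts now rely on.
    outputs = llm.generate(prompt_token_ids=[[0, 1, 2, 3]],
                           sampling_params=sampling_params)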