Add docstrings for LLM (#137)
This commit is contained in:
parent
62ec38ea41
commit
8274ca23ac
@ -30,7 +30,6 @@ def main(args: argparse.Namespace):
|
|||||||
max_tokens=args.output_len,
|
max_tokens=args.output_len,
|
||||||
)
|
)
|
||||||
print(sampling_params)
|
print(sampling_params)
|
||||||
dummy_prompts = [""] * args.batch_size
|
|
||||||
dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
|
dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
|
||||||
|
|
||||||
def run_to_completion(profile: bool = False):
|
def run_to_completion(profile: bool = False):
|
||||||
@ -38,7 +37,8 @@ def main(args: argparse.Namespace):
|
|||||||
torch.cuda.cudart().cudaProfilerStart()
|
torch.cuda.cudart().cudaProfilerStart()
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
|
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
||||||
|
sampling_params=sampling_params,
|
||||||
use_tqdm=False)
|
use_tqdm=False)
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
@ -72,9 +72,9 @@ def main(args: argparse.Namespace):
|
|||||||
)
|
)
|
||||||
# FIXME(woosuk): Do not use internal method.
|
# FIXME(woosuk): Do not use internal method.
|
||||||
llm._add_request(
|
llm._add_request(
|
||||||
prompt="",
|
prompt=None,
|
||||||
sampling_params=sampling_params,
|
|
||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
sampling_params=sampling_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@ -85,7 +85,9 @@ def main(args: argparse.Namespace):
|
|||||||
len(prompt_token_ids) + output_len
|
len(prompt_token_ids) + output_len
|
||||||
for prompt_token_ids, output_len in requests
|
for prompt_token_ids, output_len in requests
|
||||||
)
|
)
|
||||||
print(f"Throughput: {total_num_tokens / (end - start):.2f} tokens/s")
|
elapsed_time = end - start
|
||||||
|
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||||
|
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -11,6 +11,28 @@ from cacheflow.utils import Counter
|
|||||||
|
|
||||||
|
|
||||||
class LLM:
|
class LLM:
|
||||||
|
"""An LLM for generating texts from given prompts and sampling parameters.
|
||||||
|
|
||||||
|
This class includes a tokenizer, a language model (possibly distributed
|
||||||
|
across multiple GPUs), and GPU memory space allocated for intermediate
|
||||||
|
states (aka KV cache). Given a batch of prompts and sampling parameters,
|
||||||
|
this class generates texts from the model, using an intelligent batching
|
||||||
|
mechanism and efficient memory management.
|
||||||
|
|
||||||
|
NOTE: This class is intended to be used for offline inference. For online
|
||||||
|
serving, use the `AsyncLLMServer` class instead.
|
||||||
|
NOTE: For the comprehensive list of arguments, see `ServerArgs`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The name or path of a HuggingFace Transformers model.
|
||||||
|
tensor_parallel_size: The number of GPUs to use for distributed
|
||||||
|
execution with tensor parallelism.
|
||||||
|
dtype: The data type for the model weights and activations. Currently,
|
||||||
|
we support `float16` and `bfloat16`. If `default`, we use the
|
||||||
|
`torch_dtype` attribute of the model config. If the `torch_dtype`
|
||||||
|
is `float32`, we use `float16` instead.
|
||||||
|
seed: The seed to initialize the random number generator for sampling.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -39,19 +61,50 @@ class LLM:
|
|||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompts: Union[str, List[str]],
|
prompts: Optional[Union[str, List[str]]] = None,
|
||||||
sampling_params: Optional[SamplingParams] = None,
|
sampling_params: Optional[SamplingParams] = None,
|
||||||
prompt_token_ids: Optional[List[List[int]]] = None,
|
prompt_token_ids: Optional[List[List[int]]] = None,
|
||||||
use_tqdm: bool = True,
|
use_tqdm: bool = True,
|
||||||
) -> List[RequestOutput]:
|
) -> List[RequestOutput]:
|
||||||
|
"""Generates the completions for the input prompts.
|
||||||
|
|
||||||
|
NOTE: This class automatically batches the given prompts, considering
|
||||||
|
the memory constraint. For the best performance, put all of your prompts
|
||||||
|
into a single list and pass it to this method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompts: A list of prompts to generate completions for.
|
||||||
|
sampling_params: The sampling parameters for text generation. If
|
||||||
|
None, we use the default sampling parameters.
|
||||||
|
prompt_token_ids: A list of token IDs for the prompts. If None, we
|
||||||
|
use the tokenizer to convert the prompts to token IDs.
|
||||||
|
use_tqdm: Whether to use tqdm to display the progress bar.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of `RequestOutput` objects containing the generated
|
||||||
|
completions in the same order as the input prompts.
|
||||||
|
"""
|
||||||
|
if prompts is None and prompt_token_ids is None:
|
||||||
|
raise ValueError("Either prompts or prompt_token_ids must be "
|
||||||
|
"provided.")
|
||||||
if isinstance(prompts, str):
|
if isinstance(prompts, str):
|
||||||
|
# Convert a single prompt to a list.
|
||||||
prompts = [prompts]
|
prompts = [prompts]
|
||||||
|
if prompts is not None and prompt_token_ids is not None:
|
||||||
|
if len(prompts) != len(prompt_token_ids):
|
||||||
|
raise ValueError("The lengths of prompts and prompt_token_ids "
|
||||||
|
"must be the same.")
|
||||||
if sampling_params is None:
|
if sampling_params is None:
|
||||||
# Use default sampling params.
|
# Use default sampling params.
|
||||||
sampling_params = SamplingParams()
|
sampling_params = SamplingParams()
|
||||||
|
|
||||||
# Add requests to the server.
|
# Add requests to the server.
|
||||||
for i in range(len(prompts)):
|
if prompts is not None:
|
||||||
prompt = prompts[i]
|
num_requests = len(prompts)
|
||||||
|
else:
|
||||||
|
num_requests = len(prompt_token_ids)
|
||||||
|
for i in range(num_requests):
|
||||||
|
prompt = prompts[i] if prompts is not None else None
|
||||||
if prompt_token_ids is None:
|
if prompt_token_ids is None:
|
||||||
token_ids = None
|
token_ids = None
|
||||||
else:
|
else:
|
||||||
@ -61,7 +114,7 @@ class LLM:
|
|||||||
|
|
||||||
def _add_request(
|
def _add_request(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: Optional[str],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
prompt_token_ids: Optional[List[int]],
|
prompt_token_ids: Optional[List[int]],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -126,7 +126,7 @@ class LLMServer:
|
|||||||
def add_request(
|
def add_request(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
prompt: str,
|
prompt: Optional[str],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
prompt_token_ids: Optional[List[int]] = None,
|
prompt_token_ids: Optional[List[int]] = None,
|
||||||
arrival_time: Optional[float] = None,
|
arrival_time: Optional[float] = None,
|
||||||
@ -134,6 +134,7 @@ class LLMServer:
|
|||||||
if arrival_time is None:
|
if arrival_time is None:
|
||||||
arrival_time = time.time()
|
arrival_time = time.time()
|
||||||
if prompt_token_ids is None:
|
if prompt_token_ids is None:
|
||||||
|
assert prompt is not None
|
||||||
prompt_token_ids = self.tokenizer.encode(prompt)
|
prompt_token_ids = self.tokenizer.encode(prompt)
|
||||||
|
|
||||||
# Create the sequences.
|
# Create the sequences.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user