import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union

from transformers import GenerationConfig, PreTrainedTokenizer

from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
                         LoRAConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, SpeculativeConfig,
                         VisionLanguageConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                 SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
                           PoolerOutput, SamplerOutput, Sequence,
                           SequenceGroup, SequenceGroupMetadata,
                           SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                     get_tokenizer_group)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
from vllm.utils import Counter
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5


def _load_generation_config_dict(model_config: ModelConfig):
    try:
        return GenerationConfig.from_pretrained(
            model_config.model,
            revision=model_config.revision,
        ).to_diff_dict()
    except OSError:
        # Not found.
        return {}
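

# Illustrative example (added for clarity; the values are assumptions): for a
# model whose generation_config.json overrides defaults, the diff dict returned
# above might look like {"temperature": 0.7, "top_k": 50}. These fields are
# later merged into each request's SamplingParams via
# SamplingParams.update_from_generation_config().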


_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)


class LLMEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The :class:`~vllm.LLM` class wraps this class for offline batched inference
    and the :class:`AsyncLLMEngine` class wraps this class for online serving.

    The config arguments are derived from :class:`~vllm.EngineArgs`. (See
    :ref:`engine_args`.)

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        device_config: The configuration related to the device.
        load_config: The configuration related to loading the model weights.
        lora_config (Optional): The configuration related to serving multi-LoRA.
        vision_language_config (Optional): The configuration related to vision
            language models.
        speculative_config (Optional): The configuration related to speculative
            decoding.
        decoding_config (Optional): The configuration related to decoding
            (e.g. guided decoding).
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
        usage_context: Specified entry point, used for usage info collection.
    """

    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""

    @classmethod
    @contextmanager
    def enable_output_validation(cls):
        cls.DO_VALIDATE_OUTPUT = True

        yield

        cls.DO_VALIDATE_OUTPUT = False
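
    # Illustrative usage (added for clarity, not part of the original source):
    # temporarily turn on runtime type-checking of engine outputs, e.g. in a
    # test:
    #
    #     with LLMEngine.enable_output_validation():
    #         checked = LLMEngine.validate_outputs(outputs, RequestOutput)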

    @classmethod
    def validate_output(
        cls,
        output: object,
        output_type: Type[_O],
    ) -> _O:
        do_validate = cls.DO_VALIDATE_OUTPUT

        if ((TYPE_CHECKING or do_validate)
                and not isinstance(output, output_type)):
            raise TypeError(f"Expected output of type {output_type}, "
                            f"but found type {type(output)}")

        return output

    @classmethod
    def validate_outputs(
        cls,
        outputs: GenericSequence[object],
        output_type: Type[_O],
    ) -> List[_O]:
        do_validate = cls.DO_VALIDATE_OUTPUT

        outputs_: List[_O]
        if TYPE_CHECKING or do_validate:
            outputs_ = []
            for output in outputs:
                if not isinstance(output, output_type):
                    raise TypeError(f"Expected output of type {output_type}, "
                                    f"but found type {type(output)}")

                outputs_.append(output)
        else:
            outputs_ = outputs

        return outputs_

    tokenizer: Optional[BaseTokenizerGroup]

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> None:
        logger.info(
            "Initializing an LLM engine (v%s) with config: "
            "model=%r, speculative_config=%r, tokenizer=%r, "
            "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
            "disable_custom_all_reduce=%s, quantization=%s, "
            "enforce_eager=%s, kv_cache_dtype=%s, "
            "quantization_param_path=%s, device_config=%s, "
            "decoding_config=%r, seed=%d, served_model_name=%s)",
            VLLM_VERSION,
            model_config.model,
            speculative_config,
            model_config.tokenizer,
            model_config.skip_tokenizer_init,
            model_config.tokenizer_mode,
            model_config.revision,
            model_config.rope_scaling,
            model_config.rope_theta,
            model_config.tokenizer_revision,
            model_config.trust_remote_code,
            model_config.dtype,
            model_config.max_model_len,
            load_config.download_dir,
            load_config.load_format,
            parallel_config.tensor_parallel_size,
            parallel_config.disable_custom_all_reduce,
            model_config.quantization,
            model_config.enforce_eager,
            cache_config.cache_dtype,
            model_config.quantization_param_path,
            device_config.device,
            decoding_config,
            model_config.seed,
            model_config.served_model_name,
        )
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        self.decoding_config = decoding_config or DecodingConfig()
        self.log_stats = log_stats

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
        else:
            self.tokenizer = None
            self.detokenizer = None

        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            model_config)

        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            vision_language_config=vision_language_config,
            speculative_config=speculative_config,
            load_config=load_config,
        )

        if not self.model_config.embedding_mode:
            self._initialize_kv_caches()

        # If usage stats collection is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(model_config.dtype),
                    "tensor_parallel_size":
                    parallel_config.tensor_parallel_size,
                    "block_size":
                    cache_config.block_size,
                    "gpu_memory_utilization":
                    cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    model_config.quantization,
                    "kv_cache_dtype":
                    cache_config.cache_dtype,

                    # Feature flags
                    "enable_lora":
                    bool(lora_config),
                    "enable_prefix_caching":
                    cache_config.enable_prefix_caching,
                    "enforce_eager":
                    model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    parallel_config.disable_custom_all_reduce,
                })

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # Create the scheduler.
        # NOTE: the cache_config here has been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)

        # Metric Logging.
        if self.log_stats:
            self.stat_logger = StatLogger(
                local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                labels=dict(model_name=model_config.served_model_name),
                max_model_len=self.model_config.max_model_len)
            self.stat_logger.info("cache_config", self.cache_config)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                self.get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    self.get_tokenizer_for_seq,
                ),
            ))

    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The workers will determine the number of blocks in both the GPU cache
        and the swap CPU cache.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        if self.cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
            num_gpu_blocks = num_gpu_blocks_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
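
    # Illustrative note (added for clarity, not in the original source): the KV
    # cache capacity in tokens is num_gpu_blocks * cache_config.block_size, so
    # with a hypothetical block_size of 16 and 2048 GPU blocks the cache can
    # hold key/value states for 16 * 2048 = 32768 tokens across all sequences.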

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)

        # Initialize the cluster and specify the executor class.
        if engine_config.device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutor
            executor_class = NeuronExecutor
        elif engine_config.device_config.device_type == "tpu":
            from vllm.executor.tpu_executor import TPUExecutor
            executor_class = TPUExecutor
        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
        elif engine_config.device_config.device_type == "xpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutor
                executor_class = RayXPUExecutor
            else:
                from vllm.executor.xpu_executor import XPUExecutor
                executor_class = XPUExecutor
        elif distributed_executor_backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            executor_class = RayGPUExecutor
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutor)
            executor_class = MultiprocessingGPUExecutor
        else:
            from vllm.executor.gpu_executor import GPUExecutor
            executor_class = GPUExecutor

        # Create the LLM engine.
        engine = cls(
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
        )
        return engine
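
    # Illustrative usage (a sketch added for clarity; the model name below is
    # an assumption, any Hugging Face model id works):
    #
    #     >>> from vllm.engine.arg_utils import EngineArgs
    #     >>> engine_args = EngineArgs(model="facebook/opt-125m")
    #     >>> engine = LLMEngine.from_engine_args(engine_args)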

    def __reduce__(self):
        # This is to ensure that the LLMEngine is not referenced in
        # the closure used to initialize Ray worker actors
        raise RuntimeError("LLMEngine should not be pickled!")

    def __del__(self):
        # Shutdown model executor when engine is garbage collected
        # Use getattr since __init__ can fail before the field is set
        if model_executor := getattr(self, "model_executor", None):
            model_executor.shutdown()

    MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
                                   "skip_tokenizer_init is True")

    def get_tokenizer_group(
            self,
            fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
        if self.tokenizer is None:
            raise ValueError(fail_msg)

        return self.tokenizer

    def get_tokenizer(self) -> "PreTrainedTokenizer":
        return self.get_tokenizer_group().get_lora_tokenizer(None)

    def get_tokenizer_for_seq(self,
                              sequence: Sequence) -> "PreTrainedTokenizer":
        return self.get_tokenizer_group().get_lora_tokenizer(
            sequence.lora_request)

    def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
        init_kwargs = dict(
            tokenizer_id=self.model_config.tokenizer,
            enable_lora=bool(self.lora_config),
            max_num_seqs=self.scheduler_config.max_num_seqs,
            max_input_length=None,
            tokenizer_mode=self.model_config.tokenizer_mode,
            trust_remote_code=self.model_config.trust_remote_code,
            revision=self.model_config.tokenizer_revision)
        init_kwargs.update(tokenizer_init_kwargs)

        return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
                                   **init_kwargs)

    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)

    def _get_eos_token_id(
            self, lora_request: Optional[LoRARequest]) -> Optional[int]:
        if self.tokenizer is None:
            logger.warning("Using None for EOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id

    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: LLMInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
    ) -> None:
        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self._get_eos_token_id(lora_request)

        seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
                       lora_request)

        # Create a SequenceGroup based on SamplingParams or PoolingParams.
        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
            )
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
            )
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

    def process_model_inputs(
        self,
        request_id: str,
        inputs: PromptInputs,
        lora_request: Optional[LoRARequest] = None,
    ) -> LLMInputs:
        if isinstance(inputs, str):
            inputs = {"prompt": inputs}

        if "prompt_token_ids" not in inputs:
            tokenizer = self.get_tokenizer_group("prompts must be None if "
                                                 "skip_tokenizer_init is True")

            prompt_token_ids = tokenizer.encode(request_id=request_id,
                                                prompt=inputs["prompt"],
                                                lora_request=lora_request)
        else:
            prompt_token_ids = inputs["prompt_token_ids"]

        return LLMInputs(prompt_token_ids=prompt_token_ids,
                         prompt=inputs.get("prompt"),
                         multi_modal_data=inputs.get("multi_modal_data"))
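
    # Illustrative usage (a sketch added for clarity, not in the original
    # source): `inputs` may be a plain prompt string or a dict that already
    # carries token ids; the ids below are placeholders.
    #
    #     >>> engine.process_model_inputs(request_id="0",
    #     ...                             inputs="Hello, world")
    #     >>> engine.process_model_inputs(request_id="1",
    #     ...                             inputs={"prompt_token_ids": [1, 2, 3]})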

    def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            params: Parameters for sampling or pooling.
                :class:`~vllm.SamplingParams` for text generation.
                :class:`~vllm.PoolingParams` for pooling.
            arrival_time: The arrival time of the request. If None, we use
                the current time.
            lora_request: The LoRA request to apply to this request, if any.

        Details:
            - Set arrival_time to the current time if it is None.
            - Set prompt_token_ids to the encoded prompt if it is None.
            - Create `best_of` number of :class:`~vllm.Sequence` objects.
            - Create a :class:`~vllm.SequenceGroup` object
              from the list of :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.

        Example:
            >>> # initialize engine
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            arrival_time = time.time()

        processed_inputs = self.process_model_inputs(request_id=request_id,
                                                     inputs=inputs,
                                                     lora_request=lora_request)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
        )

    def _create_sequence_group_with_sampling(
        self,
        request_id: str,
        seq: Sequence,
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
    ) -> SequenceGroup:
        """Creates a SequenceGroup with SamplingParams."""
        max_logprobs = self.get_model_config().max_logprobs
        if (sampling_params.logprobs
                and sampling_params.logprobs > max_logprobs) or (
                    sampling_params.prompt_logprobs
                    and sampling_params.prompt_logprobs > max_logprobs):
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")

        # Defensive copy of SamplingParams, which are used by the sampler;
        # this doesn't deep-copy LogitsProcessor objects.
        sampling_params = sampling_params.clone()
        # Add the eos token id into the sampling_params to support min_tokens
        # processing.
        if seq.eos_token_id is not None:
            sampling_params.all_stop_token_ids.add(seq.eos_token_id)
        sampling_params.update_from_generation_config(
            self.generation_config_fields)

        # Create the sequence group.
        seq_group = SequenceGroup(request_id=request_id,
                                  seqs=[seq],
                                  arrival_time=arrival_time,
                                  sampling_params=sampling_params,
                                  lora_request=lora_request)

        return seq_group

    def _create_sequence_group_with_pooling(
        self,
        request_id: str,
        seq: Sequence,
        pooling_params: PoolingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest],
    ) -> SequenceGroup:
        """Creates a SequenceGroup with PoolingParams."""
        # Defensive copy of PoolingParams, which are used by the pooler
        pooling_params = pooling_params.clone()
        # Create the sequence group.
        seq_group = SequenceGroup(request_id=request_id,
                                  seqs=[seq],
                                  arrival_time=arrival_time,
                                  lora_request=lora_request,
                                  pooling_params=pooling_params)
        return seq_group

    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts one or more requests with the given ID(s).

        Args:
            request_id: The ID(s) of the request(s) to abort.

        Details:
            - Refer to the
              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
              from class :class:`~vllm.core.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        self.scheduler.abort_seq_group(request_id)

    def get_model_config(self) -> ModelConfig:
        """Gets the model configuration."""
        return self.model_config

    def get_decoding_config(self) -> DecodingConfig:
        """Gets the decoding configuration."""
        return self.decoding_config

    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests."""
        return self.scheduler.get_num_unfinished_seq_groups()

    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests."""
        return self.scheduler.has_unfinished_seqs()

    def _process_sequence_group_outputs(
        self,
        seq_group: SequenceGroup,
        outputs: List[EmbeddingSequenceGroupOutput],
    ) -> None:
        seq_group.embeddings = outputs[0].embeddings

        for seq in seq_group.get_seqs():
            seq.status = SequenceStatus.FINISHED_STOPPED

        return

    def _process_model_outputs(
        self,
        output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
        scheduled_seq_groups: List[ScheduledSequenceGroup],
        ignored_seq_groups: List[SequenceGroup],
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Apply the model output to the sequences in the scheduled seq groups.

        Returns RequestOutputs that can be returned to the client.
        """

        now = time.time()

        # Organize outputs by [sequence group][step] instead of
        # [step][sequence group].
        output_by_sequence_group = create_output_by_sequence_group(
            output, num_seq_groups=len(scheduled_seq_groups))

        # Update the scheduled sequence groups with the model outputs.
        for scheduled_seq_group, outputs, seq_group_meta in zip(
                scheduled_seq_groups, output_by_sequence_group,
                seq_group_metadata_list):
            seq_group = scheduled_seq_group.seq_group
            seq_group.update_num_computed_tokens(
                scheduled_seq_group.token_chunk_size)
            if self.model_config.embedding_mode:
                self._process_sequence_group_outputs(seq_group, outputs)
                continue

            self.output_processor.process_prompt_logprob(seq_group, outputs)
            if seq_group_meta.do_sample:
                self.output_processor.process_outputs(seq_group, outputs)

        # Free the finished sequence groups.
        self.scheduler.free_finished_seq_groups()

        # Create the outputs.
        request_outputs: List[Union[RequestOutput,
                                    EmbeddingRequestOutput]] = []
        for scheduled_seq_group in scheduled_seq_groups:
            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutputFactory.create(seq_group)
            request_outputs.append(request_output)
        for seq_group in ignored_seq_groups:
            request_output = RequestOutputFactory.create(seq_group)
            request_outputs.append(request_output)
        return request_outputs

    def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copied.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refers to a group of sequences
                  that are generated from the same prompt.

            - Step 2: Calls the distributed executor to execute the model.
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on its `sampling parameters` (`use_beam_search` or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = LLMEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id), prompt, sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

        if not scheduler_outputs.is_empty():
            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
            )
            output = self.model_executor.execute_model(
                execute_model_req=execute_model_req)
        else:
            output = []

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        # Log stats.
        self.do_log_stats(scheduler_outputs, output)

        if not request_outputs:
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise time out, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            self.model_executor.stop_remote_worker_execution_loop()

        return request_outputs

    def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Forced log when no requests active."""
        if self.log_stats:
            self.stat_logger.log(
                self._get_stats(scheduler_outputs, model_output))

    def _get_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs],
            model_output: Optional[List[SamplerOutput]] = None) -> Stats:
        """Get Stats to be Logged to Prometheus.

        Args:
            scheduler_outputs: Optional, used to populate metrics related to
                the scheduled batch.
            model_output: Optional, used to emit speculative decoding metrics
                which are created by the workers.
        """
        now = time.time()

        # System State
        #   Scheduler State
        num_running_sys = len(self.scheduler.running)
        num_swapped_sys = len(self.scheduler.swapped)
        num_waiting_sys = len(self.scheduler.waiting)

        # KV Cache Usage in %
        num_total_gpu = self.cache_config.num_gpu_blocks
        gpu_cache_usage_sys = 0.
        if num_total_gpu is not None:
            num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks(
            )
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
        if num_total_cpu is not None and num_total_cpu > 0:
            num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
            )
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []
        num_preemption_iter = (0 if scheduler_outputs is None else
                               scheduler_outputs.preempted)

        # Request stats
        #   Latency
        time_e2e_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        best_of_requests: List[int] = []
        n_requests: List[int] = []
        finished_reason_requests: List[str] = []

        # NOTE: This loop assumes prefill seq_groups are before
        # decode seq_groups in scheduled_seq_groups.
        if scheduler_outputs is not None:
            num_generation_tokens_from_prefill_groups = 0.
            # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
            # the len of scheduler_outputs.scheduled_seq_groups is !=
            # scheduler_outputs.num_prefill_groups, this means that
            # chunked prefills have been detected.

            for idx, scheduled_seq_group in enumerate(
                    scheduler_outputs.scheduled_seq_groups):
                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                seq_group = scheduled_seq_group.seq_group

                # NOTE: a seq_group that completed all of its prefill tokens
                # in the last iteration will have seq_group.is_prefill() = False
                # with group_was_prefill = True.
                if group_was_prefill:
                    # Number of prompt tokens.
                    num_prompt_tokens_iter += (
                        scheduled_seq_group.token_chunk_size)

                    # If the seq_group just finished the prefill state,
                    # get TTFT.
                    if not seq_group.is_prefill():
                        latency = seq_group.get_last_latency(now)
                        time_to_first_tokens_iter.append(latency)

                        # One generation token per finished prefill.
                        num_generation_tokens_from_prefill_groups += (
                            seq_group.num_seqs())
                else:
                    # TPOTs.
                    latency = seq_group.get_last_latency(now)
                    time_per_output_tokens_iter.append(latency)

                # Because of chunked prefill, we can have a single sequence
                # group that does multiple prompt runs. To prevent logging
                # the same metadata more than once per request, we standardize
                # on logging request-level information for finished requests,
                # which can only happen once.
                if seq_group.is_finished():
                    # Latency timings
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)

                    # Metadata
                    num_prompt_tokens_requests.append(
                        len(seq_group.prompt_token_ids))
                    num_generation_tokens_requests.extend([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ])
                    if seq_group.sampling_params is not None:
                        best_of_requests.append(
                            seq_group.sampling_params.best_of)
                        n_requests.append(seq_group.sampling_params.n)
                    finished_reason_requests.extend([
                        SequenceStatus.get_finished_reason(seq.status)
                        for seq in seq_group.get_finished_seqs()
                    ])

            # Number of generation tokens.
            # num_batched_tokens equals the number of prompt_tokens plus the
            # number of decode_tokens in a single iteration. So,
            # num_generation_tokens = num_batched_tokens - num_prompt_tokens
            # + num_generation_tokens_from_prefill_groups (since we generate
            # one token on prefills on iters where the prefill finishes).
            num_generation_tokens_iter = (
                scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
                num_generation_tokens_from_prefill_groups)

        # Spec decode, if enabled, emits specialized metrics from the worker in
        # sampler output.
        if model_output and (model_output[0].spec_decode_worker_metrics
                             is not None):
            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
        else:
            spec_decode_metrics = None

        return Stats(
            now=now,
            # System stats
            #   Scheduler State
            num_running_sys=num_running_sys,
            num_swapped_sys=num_swapped_sys,
            num_waiting_sys=num_waiting_sys,
            #   KV Cache Usage in %
            gpu_cache_usage_sys=gpu_cache_usage_sys,
            cpu_cache_usage_sys=cpu_cache_usage_sys,

            # Iteration stats
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            spec_decode_metrics=spec_decode_metrics,
            num_preemption_iter=num_preemption_iter,

            # Request stats
            #   Latency
            time_e2e_requests=time_e2e_requests,
            #   Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            best_of_requests=best_of_requests,
            n_requests=n_requests,
            finished_reason_requests=finished_reason_requests,
        )

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.model_executor.list_loras()

    def check_health(self) -> None:
        self.model_executor.check_health()