Enable mypy checking on V1 code (#11105)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Author: Mark McLoughlin, 2024-12-14 17:54:04 +00:00 (committed by GitHub)
parent 93abf23a64
commit 6d917d0eeb
21 changed files with 160 additions and 121 deletions

View File

@@ -29,3 +29,4 @@ run_mypy vllm/plugins
 run_mypy vllm/prompt_adapter
 run_mypy vllm/spec_decode
 run_mypy vllm/worker
+run_mypy vllm/v1

View File

@@ -135,6 +135,8 @@ class FlashAttentionImpl(AttentionImpl):
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
+        assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
             # Profiling run.
             return output

View File

@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Dict, List, Optional
+from typing import Dict, Iterable, List, Optional

 from vllm.logger import init_logger
 from vllm.utils import cdiv
@@ -263,12 +263,13 @@ class KVCacheManager:
         """
         # Default to [] in case a request is freed (aborted) before alloc.
         blocks = self.req_to_blocks.pop(request.request_id, [])
+        ordered_blocks: Iterable[KVCacheBlock] = blocks
         if self.enable_caching:
             # Free blocks in reverse order so that the tail blocks are
             # freed first.
-            blocks = reversed(blocks)
+            ordered_blocks = reversed(blocks)

-        for block in blocks:
+        for block in ordered_blocks:
             block.decr_ref()
             if block.ref_cnt == 0:
                 self.free_block_queue.append(block)
@@ -396,8 +397,7 @@ class KVCacheManager:
                 f"{request.request_id}({request})")
             # Compute the hash of the current block.
-            block_hash = hash_block_tokens(prev_block_hash_value,
-                                           tuple(block_tokens))
+            block_hash = hash_block_tokens(prev_block_hash_value, block_tokens)

             # Update and added the full block to the cache.
             blk.block_hash = block_hash
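The ordered_blocks change above illustrates a recurring mypy pattern in this PR: a name first bound to a List cannot later be rebound to the iterator returned by reversed(), so a separate name with the wider Iterable type is introduced. A minimal standalone sketch of the same pattern (the names here are illustrative, not from vLLM):

    from typing import Iterable, List


    def free_blocks(blocks: List[int], reverse: bool) -> None:
        # Annotating with the wider Iterable type lets the name hold either
        # the original list or the reversed() iterator without a mypy error.
        ordered_blocks: Iterable[int] = blocks
        if reverse:
            # reversed() returns an iterator, not a List, so rebinding the
            # List-typed name directly would fail type checking.
            ordered_blocks = reversed(blocks)
        for block in ordered_blocks:
            print(block)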

View File

@@ -1,4 +1,5 @@
 """KV-Cache Utilities."""
+from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import List, NamedTuple, Optional, Tuple
@@ -13,7 +14,7 @@ class BlockHashType(NamedTuple):
     collision happens when the hash value is the same.
     """
     hash_value: int
-    token_ids: Tuple[int]
+    token_ids: Tuple[int, ...]


 @dataclass
@@ -79,8 +80,8 @@ class FreeKVCacheBlockQueue:
         self.num_free_blocks = len(blocks)

         # Initialize the doubly linked list of free blocks.
-        self.free_list_head = blocks[0]
-        self.free_list_tail = blocks[-1]
+        self.free_list_head: Optional[KVCacheBlock] = blocks[0]
+        self.free_list_tail: Optional[KVCacheBlock] = blocks[-1]
         for i in range(self.num_free_blocks):
             if i > 0:
                 blocks[i].prev_free_block = blocks[i - 1]
@@ -159,7 +160,7 @@ class FreeKVCacheBlockQueue:
 def hash_block_tokens(parent_block_hash: Optional[int],
-                      curr_block_token_ids: Tuple[int]) -> BlockHashType:
+                      curr_block_token_ids: Sequence[int]) -> BlockHashType:
     """Computes a hash value corresponding to the contents of a block and
     the contents of the preceding block(s). The hash value is used for
     prefix caching. We use LRU cache for this function to avoid recomputing
@@ -171,7 +172,7 @@ def hash_block_tokens(parent_block_hash: Optional[int],
     Args:
         parent_block_hash: The hash of the parent block. None
             if this is the first block.
-        curr_block_token_ids: A tuple of token ids in the current
+        curr_block_token_ids: A list of token ids in the current
             block. The current block is assumed to be full.

     Returns:
@@ -179,11 +180,11 @@ def hash_block_tokens(parent_block_hash: Optional[int],
         The entire tuple is used as the hash key of the block.
     """
     return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
-                         curr_block_token_ids)
+                         tuple(curr_block_token_ids))


 def hash_request_tokens(block_size: int,
-                        token_ids: List[int]) -> List[BlockHashType]:
+                        token_ids: Sequence[int]) -> List[BlockHashType]:
     """Computes hash values of a chain of blocks given a sequence of
     token IDs. The hash value is used for prefix caching.
@@ -198,7 +199,7 @@ def hash_request_tokens(block_size: int,
     parent_block_hash_value = None
     for start in range(0, len(token_ids), block_size):
         end = start + block_size
-        block_token_ids = tuple(token_ids[start:end])
+        block_token_ids = token_ids[start:end]
         # Do not hash the block if it is not full.
         if len(block_token_ids) < block_size:
             break
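The signature changes in this file hinge on a detail of the typing module: Tuple[int] means exactly a 1-tuple, Tuple[int, ...] is a variable-length tuple, and Sequence[int] accepts lists and tuples alike. A small illustrative sketch of the distinction (not part of the commit):

    from typing import Sequence, Tuple


    def head(token_ids: Tuple[int]) -> int:
        # Tuple[int] is exactly a 1-tuple; passing (1, 2, 3) is a type error.
        return token_ids[0]


    def checksum(token_ids: Sequence[int]) -> int:
        # Sequence[int] accepts a list, a tuple, or any other read-only sequence.
        return sum(token_ids) % 257


    fixed: Tuple[int, ...] = (1, 2, 3)  # variable-length tuple of ints
    print(checksum(fixed), checksum([4, 5, 6]))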

View File

@@ -152,6 +152,7 @@ class Scheduler:
                     break
                 if not can_schedule:
                     break
+                assert new_blocks is not None

                 # Schedule the request.
                 scheduled_running_reqs.append(request)
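The added assert is the Optional-narrowing idiom used throughout this PR: mypy treats an Optional value as possibly None until an assert (or explicit check) rules it out. A generic sketch of the idiom, with hypothetical names:

    from typing import List, Optional


    def allocate(num_blocks: int) -> Optional[List[int]]:
        # Returns None when allocation is not possible (hypothetical helper).
        return list(range(num_blocks)) if num_blocks <= 4 else None


    def schedule(num_blocks: int) -> List[int]:
        new_blocks = allocate(num_blocks)
        # Without this assert, mypy reports that new_blocks may be None below.
        assert new_blocks is not None
        return new_blocks


    print(schedule(3))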

View File

@@ -36,7 +36,7 @@ class EngineCoreRequest:
     prompt: Optional[str]
     prompt_token_ids: List[int]
     mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
-    mm_hashes: Optional[List[Optional[str]]]
+    mm_hashes: Optional[List[str]]
     mm_placeholders: Optional[MultiModalPlaceholderDict]
     sampling_params: SamplingParams
     eos_token_id: Optional[int]
@@ -44,10 +44,11 @@ class EngineCoreRequest:
     lora_request: Optional[LoRARequest]


-class EngineCoreOutput(msgspec.Struct,
-                       array_like=True,
-                       omit_defaults=True,
-                       gc=False):
+class EngineCoreOutput(
+        msgspec.Struct,
+        array_like=True,  # type: ignore[call-arg]
+        omit_defaults=True,  # type: ignore[call-arg]
+        gc=False):  # type: ignore[call-arg]

     request_id: str
     new_token_ids: List[int]
@@ -56,10 +57,11 @@ class EngineCoreOutput(msgspec.Struct,
     stop_reason: Union[int, str, None] = None


-class EngineCoreOutputs(msgspec.Struct,
-                        array_like=True,
-                        omit_defaults=True,
-                        gc=False):
+class EngineCoreOutputs(
+        msgspec.Struct,
+        array_like=True,  # type: ignore[call-arg]
+        omit_defaults=True,  # type: ignore[call-arg]
+        gc=False):  # type: ignore[call-arg]

     #NOTE(Nick): We could consider ways to make this more compact,
     # e.g. columnwise layout and using an int enum for finish/stop reason
@@ -81,3 +83,6 @@ class EngineCoreRequestType(enum.Enum):
     ADD = b'\x00'
     ABORT = b'\x01'
     PROFILE = b'\x02'
+
+
+EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]]
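Naming the union as EngineCoreRequestUnion keeps the client and server signatures in sync and gives mypy a single definition to check against. A reduced sketch of how such an alias is typically declared and dispatched on (types simplified, not the real engine classes):

    from dataclasses import dataclass
    from typing import List, Union


    @dataclass
    class AddRequest:
        request_id: str


    @dataclass
    class ProfileRequest:
        is_start: bool


    # One alias shared by every sender and handler of these messages.
    RequestUnion = Union[AddRequest, ProfileRequest, List[str]]


    def handle(request: RequestUnion) -> str:
        # isinstance checks narrow the union member by member.
        if isinstance(request, AddRequest):
            return f"add {request.request_id}"
        if isinstance(request, ProfileRequest):
            return "profile start" if request.is_start else "profile stop"
        return f"abort {len(request)} requests"


    print(handle(AddRequest("r1")), handle(["r1", "r2"]))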

View File

@@ -81,7 +81,7 @@ class AsyncLLM(EngineClient):
             asyncio_mode=True,
         )

-        self.output_handler = None
+        self.output_handler: Optional[asyncio.Task] = None

     def __del__(self):
         self.shutdown()
@@ -126,7 +126,8 @@ class AsyncLLM(EngineClient):
             handler.cancel()

     @classmethod
-    def _get_executor_cls(cls, vllm_config: VllmConfig):
+    def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
+        executor_class: Type[Executor]
         distributed_executor_backend = (
             vllm_config.parallel_config.distributed_executor_backend)
         if distributed_executor_backend == "mp":
@@ -361,10 +362,10 @@ class AsyncLLM(EngineClient):
         logger.debug("Called check_health.")

     async def start_profile(self) -> None:
-        await self.engine_core.profile(True)
+        await self.engine_core.profile_async(True)

     async def stop_profile(self) -> None:
-        await self.engine_core.profile(False)
+        await self.engine_core.profile_async(False)

     @property
     def is_running(self) -> bool:
@@ -380,7 +381,7 @@ class AsyncLLM(EngineClient):
     @property
     def dead_error(self) -> BaseException:
-        return Exception
+        return Exception()  # TODO: implement


 # Retain V0 name for backwards compatibility.

View File

@@ -5,7 +5,7 @@ import threading
 import time
 from dataclasses import dataclass
 from multiprocessing.process import BaseProcess
-from typing import List, Tuple, Type, Union
+from typing import List, Tuple, Type

 import zmq
 import zmq.asyncio
@@ -20,7 +20,7 @@ from vllm.usage.usage_lib import UsageContext
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
                             EngineCoreProfile, EngineCoreRequest,
-                            EngineCoreRequestType)
+                            EngineCoreRequestType, EngineCoreRequestUnion)
 from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.request import Request, RequestStatus
@@ -97,8 +97,10 @@ class EngineCore:
             # Note that the cache here is mirrored with the client side of the
             # MM mapper, so anything that has a hash must have a HIT cache
             # entry here as well.
-            request.mm_inputs = self.mm_input_mapper_server.process_inputs(
-                request.mm_inputs, request.mm_hashes)
+            assert request.mm_inputs is not None
+            request.mm_inputs, request.mm_hashes = (
+                self.mm_input_mapper_server.process_inputs(
+                    request.mm_inputs, request.mm_hashes))

         req = Request.from_engine_core_request(request)
@@ -128,7 +130,7 @@ class EngineCore:
     def shutdown(self):
         self.model_executor.shutdown()

-    def profile(self, is_start=True):
+    def profile(self, is_start: bool = True):
         self.model_executor.profile(is_start)
@@ -161,8 +163,8 @@ class EngineCoreProc(EngineCore):
         # and to overlap some serialization/deserialization with the
         # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
-        self.input_queue = queue.Queue()
-        self.output_queue = queue.Queue()
+        self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
+        self.output_queue: queue.Queue[List[EngineCoreOutput]] = queue.Queue()
         threading.Thread(target=self.process_input_socket,
                          args=(input_path, ),
                          daemon=True).start()
@@ -318,9 +320,7 @@ class EngineCoreProc(EngineCore):
         self._last_logging_time = now

-    def _handle_client_request(
-            self, request: Union[EngineCoreRequest, EngineCoreProfile,
-                                 List[str]]) -> None:
+    def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
         """Handle EngineCoreRequest or EngineCoreABORT from Client."""

         if isinstance(request, EngineCoreRequest):

View File

@@ -1,6 +1,6 @@
 import atexit
 import os
-from typing import List, Union
+from typing import List, Optional

 import msgspec
 import zmq
@@ -10,8 +10,9 @@ from vllm.logger import init_logger
 from vllm.utils import get_open_zmq_ipc_path, kill_process_tree
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
                             EngineCoreProfile, EngineCoreRequest,
-                            EngineCoreRequestType)
-from vllm.v1.engine.core import EngineCore, EngineCoreProc
+                            EngineCoreRequestType, EngineCoreRequestUnion)
+from vllm.v1.engine.core import (EngineCore, EngineCoreProc,
+                                 EngineCoreProcHandle)
 from vllm.v1.serial_utils import PickleEncoder

 logger = init_logger(__name__)
@@ -59,7 +60,7 @@ class EngineCoreClient:
     def add_request(self, request: EngineCoreRequest) -> None:
         raise NotImplementedError

-    async def profile(self, is_start=True) -> None:
+    def profile(self, is_start: bool = True) -> None:
         raise NotImplementedError

     def abort_requests(self, request_ids: List[str]) -> None:
@@ -71,6 +72,9 @@ class EngineCoreClient:
     async def add_request_async(self, request: EngineCoreRequest) -> None:
         raise NotImplementedError

+    async def profile_async(self, is_start: bool = True) -> None:
+        raise NotImplementedError
+
     async def abort_requests_async(self, request_ids: List[str]) -> None:
         raise NotImplementedError
@@ -105,7 +109,7 @@ class InprocClient(EngineCoreClient):
     def __del__(self):
         self.shutdown()

-    def profile(self, is_start=True) -> None:
+    def profile(self, is_start: bool = True) -> None:
         self.engine_core.profile(is_start)
@@ -133,7 +137,10 @@ class MPClient(EngineCoreClient):
         self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)

         # ZMQ setup.
-        self.ctx = (zmq.asyncio.Context() if asyncio_mode else zmq.Context())
+        if asyncio_mode:
+            self.ctx = zmq.asyncio.Context()
+        else:
+            self.ctx = zmq.Context()  # type: ignore[attr-defined]

         # Path for IPC.
         ready_path = get_open_zmq_ipc_path()
@@ -149,11 +156,13 @@ class MPClient(EngineCoreClient):
         self.input_socket.bind(input_path)

         # Start EngineCore in background process.
+        self.proc_handle: Optional[EngineCoreProcHandle]
         self.proc_handle = EngineCoreProc.make_engine_core_process(
             *args,
-            input_path=input_path,
-            output_path=output_path,
-            ready_path=ready_path,
+            input_path=
+            input_path,  # type: ignore[misc] # MyPy incorrectly flags duplicate keywords
+            output_path=output_path,  # type: ignore[misc]
+            ready_path=ready_path,  # type: ignore[misc]
             **kwargs,
         )
         atexit.register(self.shutdown)
@@ -204,10 +213,8 @@ class SyncMPClient(MPClient):
         engine_core_outputs = self.decoder.decode(frame.buffer).outputs
         return engine_core_outputs

-    def _send_input(
-            self, request_type: EngineCoreRequestType,
-            request: Union[EngineCoreRequest, EngineCoreProfile,
-                           List[str]]) -> None:
+    def _send_input(self, request_type: EngineCoreRequestType,
+                    request: EngineCoreRequestUnion) -> None:

         # (RequestType, SerializedRequest)
         msg = (request_type.value, self.encoder.encode(request))
@@ -219,7 +226,7 @@ class SyncMPClient(MPClient):
     def abort_requests(self, request_ids: List[str]) -> None:
         self._send_input(EngineCoreRequestType.ABORT, request_ids)

-    def profile(self, is_start=True) -> None:
+    def profile(self, is_start: bool = True) -> None:
         self._send_input(EngineCoreRequestType.PROFILE,
                          EngineCoreProfile(is_start))
@@ -237,10 +244,8 @@ class AsyncMPClient(MPClient):
         return engine_core_outputs

-    async def _send_input(
-            self, request_type: EngineCoreRequestType,
-            request: Union[EngineCoreRequest, EngineCoreProfile,
-                           List[str]]) -> None:
+    async def _send_input(self, request_type: EngineCoreRequestType,
+                          request: EngineCoreRequestUnion) -> None:

         msg = (request_type.value, self.encoder.encode(request))
         await self.input_socket.send_multipart(msg, copy=False)
@@ -252,6 +257,6 @@ class AsyncMPClient(MPClient):
         if len(request_ids) > 0:
             await self._send_input(EngineCoreRequestType.ABORT, request_ids)

-    async def profile(self, is_start=True) -> None:
+    async def profile_async(self, is_start: bool = True) -> None:
         await self._send_input(EngineCoreRequestType.PROFILE,
                                EngineCoreProfile(is_start))
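Splitting the coroutine out as profile_async appears intended to avoid mixing synchronous and asynchronous overrides of the same method, which mypy flags as incompatible because an async def changes the declared return type to a coroutine. A compact sketch of the split interface (simplified, not the real client classes):

    import asyncio


    class Client:

        def profile(self, is_start: bool = True) -> None:
            raise NotImplementedError

        async def profile_async(self, is_start: bool = True) -> None:
            raise NotImplementedError


    class SyncClient(Client):

        def profile(self, is_start: bool = True) -> None:
            print("sync profile", is_start)


    class AsyncClient(Client):

        async def profile_async(self, is_start: bool = True) -> None:
            # An async override of the sync profile() would change its return
            # type to a coroutine, so the async variant gets its own name.
            print("async profile", is_start)


    SyncClient().profile(True)
    asyncio.run(AsyncClient().profile_async(False))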

View File

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Union

 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.logger import init_logger
@@ -97,7 +97,7 @@ class IncrementalDetokenizer:
         self,
         new_token_ids: List[int],
         finish_reason: Optional[str],
-        stop_reason: Optional[str],
+        stop_reason: Optional[Union[int, str, None]],
     ) -> Optional[RequestOutput]:
         """
         Update RequestState for the request_id by:

View File

@@ -103,7 +103,8 @@ class LLMEngine:
             multiprocess_mode=enable_multiprocessing)

     @classmethod
-    def _get_executor_cls(cls, vllm_config: VllmConfig):
+    def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
+        executor_class: Type[Executor]
         distributed_executor_backend = (
             vllm_config.parallel_config.distributed_executor_backend)
         if distributed_executor_backend == "mp":

View File

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple

 import PIL
 from blake3 import blake3
@@ -42,14 +42,14 @@ class MMInputMapperClient:
                                               model_config)
         self.mm_registry.init_mm_limits_per_prompt(model_config)

-        self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
+        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)

         # DEBUG: Set to None to disable
         self.mm_debug_cache_hit_ratio_steps = None
         self.mm_cache_hits = 0
         self.mm_cache_total = 0

-    def cache_hit_ratio(self, steps) -> float:
+    def cache_hit_ratio(self, steps):
         if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0:
             logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
                          self.mm_cache_hits / self.mm_cache_total)
@@ -60,7 +60,7 @@ class MMInputMapperClient:
         mm_hashes: Optional[List[str]],
         mm_processor_kwargs: Optional[Dict[str, Any]],
         precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
-    ) -> List[MultiModalKwargs]:
+    ) -> Tuple[List[MultiModalKwargs], Optional[List[str]]]:
         if precomputed_mm_inputs is None:
             image_inputs = mm_data["image"]
             if not isinstance(image_inputs, list):
@@ -72,6 +72,7 @@ class MMInputMapperClient:
         # Check if hash is enabled
         use_hash = mm_hashes is not None
         if use_hash:
+            assert mm_hashes is not None
             assert num_inputs == len(
                 mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format(
                     num_inputs, len(mm_hashes))
@@ -79,7 +80,7 @@ class MMInputMapperClient:
         # Process each image input separately, so that later we can schedule
         # them in a fine-grained manner.
         # Apply caching (if enabled) and reuse precomputed inputs (if provided)
-        ret_hashes = [] if use_hash else None
+        ret_hashes: Optional[List[str]] = [] if use_hash else None
         ret_inputs: List[MultiModalKwargs] = []
         for input_id in range(num_inputs):
             if self.mm_debug_cache_hit_ratio_steps is not None:
@@ -88,6 +89,7 @@ class MMInputMapperClient:
             mm_hash = None
             mm_input = None
             if use_hash:
+                assert mm_hashes is not None
                 mm_hash = mm_hashes[input_id]
                 mm_input = self.mm_cache.get(mm_hash)
@@ -105,12 +107,15 @@ class MMInputMapperClient:
                 if use_hash:
                     # Add to cache
+                    assert mm_hash is not None
                     self.mm_cache.put(mm_hash, mm_input)
             else:
                 self.mm_cache_hits += 1
                 mm_input = None  # Avoids sending mm_input to Server

             if use_hash:
+                assert mm_hash is not None
+                assert ret_hashes is not None
                 ret_hashes.append(mm_hash)
             ret_inputs.append(mm_input)
@@ -120,17 +125,18 @@ class MMInputMapperClient:
 class MMInputMapperServer:

     def __init__(self, ):
-        self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
+        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)

     def process_inputs(
         self,
         mm_inputs: List[Optional[MultiModalKwargs]],
-        mm_hashes: List[Optional[str]],
+        mm_hashes: List[str],
     ) -> List[MultiModalKwargs]:
         assert len(mm_inputs) == len(mm_hashes)

         full_mm_inputs = []
         for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
+            assert mm_hash is not None
             if mm_input is None:
                 mm_input = self.mm_cache.get(mm_hash)
                 assert mm_input is not None

View File

@@ -56,7 +56,7 @@ class Processor:
         request_id: str,
         prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
-        arrival_time: float,
+        arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,

View File

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Tuple
+from typing import Tuple

 from vllm.config import VllmConfig
 from vllm.v1.outputs import ModelRunnerOutput
@@ -28,7 +28,7 @@ class Executor(ABC):
         raise NotImplementedError

     @abstractmethod
-    def profile(self, is_start=True):
+    def profile(self, is_start: bool = True):
         raise NotImplementedError

     @abstractmethod
@@ -38,11 +38,3 @@ class Executor(ABC):
     @abstractmethod
     def check_health(self) -> None:
         raise NotImplementedError
-
-    @abstractmethod
-    def collective_rpc(self,
-                       method: str,
-                       timeout: Optional[float] = None,
-                       args: Tuple = (),
-                       kwargs: Optional[Dict] = None) -> []:
-        raise NotImplementedError

View File

@@ -7,7 +7,7 @@ import time
 from dataclasses import dataclass
 from enum import Enum, auto
 from multiprocessing.process import BaseProcess
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import zmq
@@ -21,6 +21,7 @@ from vllm.executor.multiproc_worker_utils import (
 from vllm.logger import init_logger
 from vllm.utils import (get_distributed_init_method, get_open_port,
                         get_open_zmq_ipc_path)
+from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import make_zmq_socket
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -31,7 +32,7 @@ POLLING_TIMEOUT_MS = 5000
 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000


-class MultiprocExecutor:
+class MultiprocExecutor(Executor):

     def __init__(self, vllm_config: VllmConfig) -> None:
         # Call self.shutdown at exit to clean up
@@ -103,7 +104,7 @@ class MultiprocExecutor:
                        method: str,
                        timeout: Optional[float] = None,
                        args: Tuple = (),
-                       kwargs: Optional[Dict] = None) -> []:
+                       kwargs: Optional[Dict] = None) -> List[Any]:
         """
         Execute an RPC call on workers.
@@ -125,7 +126,7 @@ class MultiprocExecutor:
             responses = [None] * self.world_size

             for w in self.workers:
-                dequeue_timeout = timeout - (time.monotonic() - start_time()
+                dequeue_timeout = timeout - (time.monotonic() - start_time
                                              ) if timeout is not None else None
                 status, result = w.worker_response_mq.dequeue(
                     timeout=dequeue_timeout)
@@ -153,7 +154,7 @@ class MultiprocExecutor:
                                       args=(scheduler_output, ))[0]
         return model_output

-    def profile(self, is_start=True):
+    def profile(self, is_start: bool = True):
         self.collective_rpc("profile", args=(is_start, ))
         return
@@ -185,7 +186,6 @@ class MultiprocExecutor:
                     p.kill()

         self._cleanup_sockets()
-        self.workers = None

     def _cleanup_sockets(self):
         for w in self.workers:
@@ -200,7 +200,8 @@ class MultiprocExecutor:
         # again
         atexit.unregister(self.shutdown)
         """Properly shut down the executor and its workers"""
-        if (hasattr(self, 'workers') and self.workers is not None):
+        if not getattr(self, 'shutting_down', False):
+            self.shutting_down = True
             for w in self.workers:  #TODO: not sure if needed
                 w.worker_response_mq = None
             self._ensure_worker_termination()
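The collective_rpc change earlier in this file replaces -> [] with -> List[Any]: an empty list literal is not a valid type expression, so mypy rejects it, whereas List[Any] describes "a list of one result per worker". A tiny standalone sketch of the corrected signature shape (illustrative only, with a placeholder body rather than the real RPC machinery):

    from typing import Any, Dict, List, Optional, Tuple


    def collective_rpc(method: str,
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        # Placeholder body: a real executor would forward the call to every
        # worker and gather one result per worker into the returned list.
        del timeout, args, kwargs
        return [f"{method} result"]


    print(collective_rpc("profile"))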

View File

@@ -4,13 +4,14 @@ from typing import Optional, Tuple
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_worker import Worker

 logger = init_logger(__name__)


-class UniprocExecutor:
+class UniprocExecutor(Executor):

     def __init__(self, vllm_config: VllmConfig) -> None:
         self.vllm_config = vllm_config
@@ -25,7 +26,7 @@ class UniprocExecutor:
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config

-        self.worker = self._create_worker()
+        self.worker: Worker = self._create_worker()
         self.worker.initialize()
         self.worker.load_model()
@@ -75,7 +76,7 @@ class UniprocExecutor:
         self.worker.profile(is_start)

     def shutdown(self):
-        self.worker = None
+        pass

     def check_health(self) -> None:
         # UniprocExecutor will always be healthy as long as

View File

@@ -52,10 +52,9 @@ class Request:
         else:
             self.mm_positions = []
         # Output of the mm input mapper (e.g., image tensors).
+        self.mm_inputs: List[MultiModalKwargs] = []
         if self.inputs.multi_modal_inputs:
             self.mm_inputs = self.inputs.multi_modal_inputs
-        else:
-            self.mm_inputs: List[MultiModalKwargs] = []

     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":

View File

@@ -1,6 +1,8 @@
 from collections import OrderedDict
+from collections.abc import Sequence
 from contextlib import contextmanager
-from typing import Any, Generic, Iterator, List, TypeVar, overload
+from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,
+                    overload)

 import zmq
@@ -11,7 +13,7 @@ logger = init_logger(__name__)
 T = TypeVar("T")


-class ConstantList(Generic[T]):
+class ConstantList(Generic[T], Sequence):

     def __init__(self, x: List[T]) -> None:
         self._x = x
@@ -34,29 +36,33 @@ class ConstantList(Generic[T]):
     def clear(self):
         raise Exception("Cannot clear a constant list")

-    def index(self, item):
-        return self._x.index(item)
+    def index(self,
+              item: T,
+              start: int = 0,
+              stop: Optional[int] = None) -> int:
+        return self._x.index(item, start,
+                             stop if stop is not None else len(self._x))

     @overload
-    def __getitem__(self, item) -> T:
+    def __getitem__(self, item: int) -> T:
         ...

     @overload
     def __getitem__(self, s: slice, /) -> List[T]:
         ...

-    def __getitem__(self, item):
+    def __getitem__(self, item: Union[int, slice]) -> Union[T, List[T]]:
         return self._x[item]

     @overload
-    def __setitem__(self, item, value):
+    def __setitem__(self, item: int, value: T):
         ...

     @overload
-    def __setitem__(self, s: slice, value, /):
+    def __setitem__(self, s: slice, value: T, /):
         ...

-    def __setitem__(self, item, value):
+    def __setitem__(self, item: Union[int, slice], value: Union[T, List[T]]):
         raise Exception("Cannot set item in a constant list")

     def __delitem__(self, item):
@@ -73,10 +79,12 @@ class ConstantList(Generic[T]):
 @contextmanager
-def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
+def make_zmq_socket(
+        path: str,
+        type: Any) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
     """Context manager for a ZMQ socket"""

-    ctx = zmq.Context()
+    ctx = zmq.Context()  # type: ignore[attr-defined]
     try:
         socket = ctx.socket(type)
@@ -96,20 +104,24 @@ def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
         ctx.destroy(linger=0)


-class LRUDictCache:
+K = TypeVar('K')
+V = TypeVar('V')
+
+
+class LRUDictCache(Generic[K, V]):

     def __init__(self, size: int):
-        self.cache = OrderedDict()
+        self.cache: OrderedDict[K, V] = OrderedDict()
         self.size = size

-    def get(self, key, default=None):
+    def get(self, key: K, default=None) -> V:
         if key not in self.cache:
             return default

         self.cache.move_to_end(key)
         return self.cache[key]

-    def put(self, key, value):
+    def put(self, key: K, value: V):
         self.cache[key] = value
         self.cache.move_to_end(key)
         if len(self.cache) > self.size:
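Making LRUDictCache generic over K and V lets call sites such as LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) declare the key and value types once, so get and put are checked against them. A self-contained sketch of the same Generic pattern (a simplified cache, not the vLLM class):

    from collections import OrderedDict
    from typing import Generic, Optional, TypeVar

    K = TypeVar("K")
    V = TypeVar("V")


    class TinyLRUCache(Generic[K, V]):

        def __init__(self, size: int):
            self.cache: "OrderedDict[K, V]" = OrderedDict()
            self.size = size

        def get(self, key: K) -> Optional[V]:
            if key not in self.cache:
                return None
            self.cache.move_to_end(key)
            return self.cache[key]

        def put(self, key: K, value: V) -> None:
            self.cache[key] = value
            self.cache.move_to_end(key)
            if len(self.cache) > self.size:
                self.cache.popitem(last=False)


    # The subscript fixes K=str and V=int for this instance.
    cache: TinyLRUCache[str, int] = TinyLRUCache(2)
    cache.put("a", 1)
    cache.put("b", 2)
    cache.put("c", 3)  # evicts "a", the least recently used entry
    print(cache.get("a"), cache.get("c"))  # None 3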

View File

@@ -215,6 +215,7 @@ class InputBatch:
             # Swap the states.
             req_id = self.req_ids[last_req_index]
+            assert req_id is not None
             self.req_ids[empty_index] = req_id
             self.req_ids[last_req_index] = None
             self.req_id_to_index[req_id] = empty_index

View File

@@ -1,6 +1,6 @@
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple, cast

 import numpy as np
 import torch
@@ -193,9 +193,9 @@ class GPUModelRunner:
         req_ids_to_add: List[str] = []
         # Add new requests to the cached states.
-        for req_data in scheduler_output.scheduled_new_reqs:
-            req_id = req_data.req_id
-            sampling_params = req_data.sampling_params
+        for new_req_data in scheduler_output.scheduled_new_reqs:
+            req_id = new_req_data.req_id
+            sampling_params = new_req_data.sampling_params
             if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                 generator = torch.Generator(device=self.device)
                 generator.manual_seed(sampling_params.seed)
@@ -204,25 +204,25 @@
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
-                prompt_token_ids=req_data.prompt_token_ids,
-                prompt=req_data.prompt,
-                mm_inputs=req_data.mm_inputs,
-                mm_positions=req_data.mm_positions,
+                prompt_token_ids=new_req_data.prompt_token_ids,
+                prompt=new_req_data.prompt,
+                mm_inputs=new_req_data.mm_inputs,
+                mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
                 generator=generator,
-                block_ids=req_data.block_ids,
-                num_computed_tokens=req_data.num_computed_tokens,
+                block_ids=new_req_data.block_ids,
+                num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],
             )
             req_ids_to_add.append(req_id)

         # Update the cached states of the resumed requests.
-        for req_data in scheduler_output.scheduled_resumed_reqs:
-            req_id = req_data.req_id
+        for res_req_data in scheduler_output.scheduled_resumed_reqs:
+            req_id = res_req_data.req_id
             req_state = self.requests[req_id]

-            req_state.block_ids = req_data.block_ids
-            req_state.num_computed_tokens = req_data.num_computed_tokens
+            req_state.block_ids = res_req_data.block_ids
+            req_state.num_computed_tokens = res_req_data.num_computed_tokens
             req_ids_to_add.append(req_id)

         # Add the new or resumed requests to the persistent batch.
@@ -259,6 +259,7 @@
         num_scheduled_tokens = []
         max_num_scheduled_tokens = 0
         for req_id in self.input_batch.req_ids[:num_reqs]:
+            assert req_id is not None
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
             num_scheduled_tokens.append(num_tokens)
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
@@ -373,7 +374,7 @@
         # Batch the multi-modal inputs.
         mm_inputs: List[MultiModalKwargs] = []
-        req_input_ids: List[Tuple[int, int]] = []
+        req_input_ids: List[Tuple[str, int]] = []
         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
             req_state = self.requests[req_id]
             for input_id in encoder_input_ids:
@@ -406,6 +407,7 @@
         encoder_outputs: List[torch.Tensor] = []
         num_reqs = self.input_batch.num_reqs
         for req_id in self.input_batch.req_ids[:num_reqs]:
+            assert req_id is not None
             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                 req_id]
             req_state = self.requests[req_id]
@@ -514,6 +516,7 @@
         # the requests one by one. Optimize.
         num_reqs = self.input_batch.num_reqs
         for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
+            assert req_id is not None
             req_state = self.requests[req_id]
             seq_len = (req_state.num_computed_tokens +
                        scheduler_output.num_scheduled_tokens[req_id])
@@ -539,8 +542,15 @@
             logprobs = None
         else:
             logprobs = sampler_output.logprobs.cpu()

+        # num_reqs entries should be non-None
+        assert all(
+            req_id is not None for req_id in
+            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
+        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
+
         model_runner_output = ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids[:num_reqs],
+            req_ids=req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=sampled_token_ids,
             logprob_token_ids_cpu=logprob_token_ids,
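The cast(List[str], ...) call in the last hunk records a fact mypy cannot infer on its own: after the assert, the first num_reqs entries of an Optional[str] list are known to be non-None. typing.cast performs no runtime conversion; it only changes the static type. A small sketch of the pattern (hypothetical data, not the model runner):

    from typing import List, Optional, cast

    # A fixed-size batch where unused slots are padded with None.
    req_ids: List[Optional[str]] = ["req-0", "req-1", None, None]
    num_reqs = 2

    # Runtime check that the claim below actually holds for this slice.
    assert all(req_id is not None for req_id in req_ids[:num_reqs])

    # cast() is a no-op at runtime; it only tells mypy the narrowed type.
    active_ids: List[str] = cast(List[str], req_ids[:num_reqs])
    print(active_ids)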

View File

@@ -204,7 +204,7 @@ class Worker:
             return output if self.rank == 0 else None
         return output

-    def profile(self, is_start=True):
+    def profile(self, is_start: bool = True):
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
         if is_start: