Enable mypy checking on V1 code (#11105)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
parent 93abf23a64
commit 6d917d0eeb
@@ -29,3 +29,4 @@ run_mypy vllm/plugins
 run_mypy vllm/prompt_adapter
 run_mypy vllm/spec_decode
 run_mypy vllm/worker
+run_mypy vllm/v1
@@ -135,6 +135,8 @@ class FlashAttentionImpl(AttentionImpl):
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
 
+        assert output is not None, "Output tensor must be provided."
+
         if attn_metadata is None:
             # Profiling run.
             return output
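Note: the added assert is the standard way to narrow an Optional parameter for mypy; after `assert output is not None`, the checker treats `output` as a plain tensor for the rest of the function. A minimal, standalone illustration of the pattern (toy names, not vLLM's):

    from typing import List, Optional


    def fill(output: Optional[List[float]], value: float) -> List[float]:
        # Without the assert, mypy reports that item "None" of
        # "Optional[List[float]]" has no attribute "append".
        assert output is not None, "Output buffer must be provided."
        output.append(value)  # narrowed to List[float] from here on
        return output


    if __name__ == "__main__":
        print(fill([], 1.0))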
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Dict, List, Optional
+from typing import Dict, Iterable, List, Optional
 
 from vllm.logger import init_logger
 from vllm.utils import cdiv
@@ -263,12 +263,13 @@ class KVCacheManager:
         """
         # Default to [] in case a request is freed (aborted) before alloc.
         blocks = self.req_to_blocks.pop(request.request_id, [])
+        ordered_blocks: Iterable[KVCacheBlock] = blocks
         if self.enable_caching:
             # Free blocks in reverse order so that the tail blocks are
             # freed first.
-            blocks = reversed(blocks)
+            ordered_blocks = reversed(blocks)
 
-        for block in blocks:
+        for block in ordered_blocks:
             block.decr_ref()
             if block.ref_cnt == 0:
                 self.free_block_queue.append(block)
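Note: introducing `ordered_blocks` avoids rebinding a variable that mypy has inferred as a list to the iterator returned by `reversed()`. A small sketch of the same pattern with a toy block type (illustrative names only):

    from typing import Iterable, List


    class Block:

        def __init__(self, block_id: int) -> None:
            self.block_id = block_id


    def free_blocks(blocks: List[Block], reverse: bool) -> List[int]:
        # Keep the list variable untouched; give the iteration order its own,
        # wider Iterable type so `reversed(...)` is a valid assignment.
        ordered_blocks: Iterable[Block] = blocks
        if reverse:
            # `blocks = reversed(blocks)` would fail: reversed() returns an
            # iterator, not a List[Block].
            ordered_blocks = reversed(blocks)
        return [b.block_id for b in ordered_blocks]


    if __name__ == "__main__":
        print(free_blocks([Block(i) for i in range(3)], reverse=True))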
@@ -396,8 +397,7 @@ class KVCacheManager:
                     f"{request.request_id}({request})")
 
             # Compute the hash of the current block.
-            block_hash = hash_block_tokens(prev_block_hash_value,
-                                           tuple(block_tokens))
+            block_hash = hash_block_tokens(prev_block_hash_value, block_tokens)
 
             # Update and added the full block to the cache.
             blk.block_hash = block_hash
@@ -1,4 +1,5 @@
 """KV-Cache Utilities."""
+from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import List, NamedTuple, Optional, Tuple
 
@@ -13,7 +14,7 @@ class BlockHashType(NamedTuple):
     collision happens when the hash value is the same.
     """
     hash_value: int
-    token_ids: Tuple[int]
+    token_ids: Tuple[int, ...]
 
 
 @dataclass
@@ -79,8 +80,8 @@ class FreeKVCacheBlockQueue:
         self.num_free_blocks = len(blocks)
 
         # Initialize the doubly linked list of free blocks.
-        self.free_list_head = blocks[0]
-        self.free_list_tail = blocks[-1]
+        self.free_list_head: Optional[KVCacheBlock] = blocks[0]
+        self.free_list_tail: Optional[KVCacheBlock] = blocks[-1]
         for i in range(self.num_free_blocks):
             if i > 0:
                 blocks[i].prev_free_block = blocks[i - 1]
@@ -159,7 +160,7 @@ class FreeKVCacheBlockQueue:
 
 
 def hash_block_tokens(parent_block_hash: Optional[int],
-                      curr_block_token_ids: Tuple[int]) -> BlockHashType:
+                      curr_block_token_ids: Sequence[int]) -> BlockHashType:
     """Computes a hash value corresponding to the contents of a block and
     the contents of the preceding block(s). The hash value is used for
     prefix caching. We use LRU cache for this function to avoid recomputing
@@ -171,7 +172,7 @@ def hash_block_tokens(parent_block_hash: Optional[int],
     Args:
         parent_block_hash: The hash of the parent block. None
            if this is the first block.
-        curr_block_token_ids: A tuple of token ids in the current
+        curr_block_token_ids: A list of token ids in the current
            block. The current block is assumed to be full.
 
     Returns:
@@ -179,11 +180,11 @@ def hash_block_tokens(parent_block_hash: Optional[int],
         The entire tuple is used as the hash key of the block.
     """
     return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
-                         curr_block_token_ids)
+                         tuple(curr_block_token_ids))
 
 
 def hash_request_tokens(block_size: int,
-                        token_ids: List[int]) -> List[BlockHashType]:
+                        token_ids: Sequence[int]) -> List[BlockHashType]:
     """Computes hash values of a chain of blocks given a sequence of
     token IDs. The hash value is used for prefix caching.
 
@@ -198,7 +199,7 @@ def hash_request_tokens(block_size: int,
     parent_block_hash_value = None
     for start in range(0, len(token_ids), block_size):
         end = start + block_size
-        block_token_ids = tuple(token_ids[start:end])
+        block_token_ids = token_ids[start:end]
         # Do not hash the block if it is not full.
         if len(block_token_ids) < block_size:
             break
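Note: `Tuple[int]` describes a 1-tuple; the variable-length form is `Tuple[int, ...]`, and widening the parameters to `Sequence[int]` lets callers pass list slices without converting. Converting to a tuple only where a hashable value is stored keeps the cache key immutable. A self-contained sketch of the same idea (not vLLM's actual helpers):

    from typing import List, NamedTuple, Optional, Sequence, Tuple


    class BlockHash(NamedTuple):
        hash_value: int
        token_ids: Tuple[int, ...]  # variable-length tuple, not Tuple[int]


    def hash_block(parent_hash: Optional[int],
                   token_ids: Sequence[int]) -> BlockHash:
        # Accepts lists, tuples, or slices; only the stored key needs to be
        # an immutable, hashable tuple.
        return BlockHash(hash((parent_hash, *token_ids)), tuple(token_ids))


    def hash_blocks(block_size: int,
                    token_ids: Sequence[int]) -> List[BlockHash]:
        hashes: List[BlockHash] = []
        parent: Optional[int] = None
        for start in range(0, len(token_ids), block_size):
            block = token_ids[start:start + block_size]  # a slice, not a tuple
            if len(block) < block_size:
                break  # do not hash a partial block
            block_hash = hash_block(parent, block)
            hashes.append(block_hash)
            parent = block_hash.hash_value
        return hashes


    if __name__ == "__main__":
        print(hash_blocks(4, list(range(10))))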
@@ -152,6 +152,7 @@ class Scheduler:
                     break
             if not can_schedule:
                 break
+            assert new_blocks is not None
 
             # Schedule the request.
             scheduled_running_reqs.append(request)
@@ -36,7 +36,7 @@ class EngineCoreRequest:
     prompt: Optional[str]
     prompt_token_ids: List[int]
     mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
-    mm_hashes: Optional[List[Optional[str]]]
+    mm_hashes: Optional[List[str]]
     mm_placeholders: Optional[MultiModalPlaceholderDict]
     sampling_params: SamplingParams
     eos_token_id: Optional[int]
@@ -44,10 +44,11 @@ class EngineCoreRequest:
     lora_request: Optional[LoRARequest]
 
 
-class EngineCoreOutput(msgspec.Struct,
-                       array_like=True,
-                       omit_defaults=True,
-                       gc=False):
+class EngineCoreOutput(
+        msgspec.Struct,
+        array_like=True,  # type: ignore[call-arg]
+        omit_defaults=True,  # type: ignore[call-arg]
+        gc=False):  # type: ignore[call-arg]
 
     request_id: str
     new_token_ids: List[int]
@@ -56,10 +57,11 @@ class EngineCoreOutput(msgspec.Struct,
     stop_reason: Union[int, str, None] = None
 
 
-class EngineCoreOutputs(msgspec.Struct,
-                        array_like=True,
-                        omit_defaults=True,
-                        gc=False):
+class EngineCoreOutputs(
+        msgspec.Struct,
+        array_like=True,  # type: ignore[call-arg]
+        omit_defaults=True,  # type: ignore[call-arg]
+        gc=False):  # type: ignore[call-arg]
 
     #NOTE(Nick): We could consider ways to make this more compact,
     # e.g. columnwise layout and using an int enum for finish/stop reason
@@ -81,3 +83,6 @@ class EngineCoreRequestType(enum.Enum):
     ADD = b'\x00'
     ABORT = b'\x01'
     PROFILE = b'\x02'
+
+
+EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]]
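Note: msgspec Struct options such as `array_like`, `omit_defaults` and `gc` are passed as class keyword arguments, which the mypy run here flags as unexpected call arguments, hence the `# type: ignore[call-arg]` comments; `EngineCoreRequestUnion` simply names the set of payloads the request socket can carry. A hedged sketch of both patterns with toy structs (not the real vLLM types):

    from typing import List, Union

    import msgspec


    class ProfileMsg(msgspec.Struct,
                     array_like=True,  # type: ignore[call-arg]
                     omit_defaults=True):  # type: ignore[call-arg]
        is_start: bool = True


    class AddMsg(msgspec.Struct,
                 array_like=True,  # type: ignore[call-arg]
                 omit_defaults=True):  # type: ignore[call-arg]
        request_id: str
        prompt_token_ids: List[int]


    # One alias for everything the socket may carry: a new request, a profile
    # toggle, or a batch of request ids to abort.
    RequestUnion = Union[AddMsg, ProfileMsg, List[str]]


    def describe(req: RequestUnion) -> str:
        if isinstance(req, AddMsg):
            return f"add {req.request_id}"
        if isinstance(req, ProfileMsg):
            return "profile start" if req.is_start else "profile stop"
        return f"abort {len(req)} requests"


    if __name__ == "__main__":
        print(describe(AddMsg("req-1", [1, 2, 3])))
        print(describe(ProfileMsg()))
        print(describe(["req-1", "req-2"]))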
@@ -81,7 +81,7 @@ class AsyncLLM(EngineClient):
             asyncio_mode=True,
         )
 
-        self.output_handler = None
+        self.output_handler: Optional[asyncio.Task] = None
 
     def __del__(self):
         self.shutdown()
@@ -126,7 +126,8 @@ class AsyncLLM(EngineClient):
             handler.cancel()
 
     @classmethod
-    def _get_executor_cls(cls, vllm_config: VllmConfig):
+    def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
+        executor_class: Type[Executor]
         distributed_executor_backend = (
             vllm_config.parallel_config.distributed_executor_backend)
         if distributed_executor_backend == "mp":
@@ -361,10 +362,10 @@ class AsyncLLM(EngineClient):
         logger.debug("Called check_health.")
 
     async def start_profile(self) -> None:
-        await self.engine_core.profile(True)
+        await self.engine_core.profile_async(True)
 
     async def stop_profile(self) -> None:
-        await self.engine_core.profile(False)
+        await self.engine_core.profile_async(False)
 
     @property
     def is_running(self) -> bool:
@@ -380,7 +381,7 @@ class AsyncLLM(EngineClient):
 
     @property
     def dead_error(self) -> BaseException:
-        return Exception
+        return Exception()  # TODO: implement
 
 
 # Retain V0 name for backwards compatibility.
@@ -5,7 +5,7 @@ import threading
 import time
 from dataclasses import dataclass
 from multiprocessing.process import BaseProcess
-from typing import List, Tuple, Type, Union
+from typing import List, Tuple, Type
 
 import zmq
 import zmq.asyncio
@@ -20,7 +20,7 @@ from vllm.usage.usage_lib import UsageContext
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
                             EngineCoreProfile, EngineCoreRequest,
-                            EngineCoreRequestType)
+                            EngineCoreRequestType, EngineCoreRequestUnion)
 from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.request import Request, RequestStatus
@@ -97,8 +97,10 @@ class EngineCore:
             # Note that the cache here is mirrored with the client side of the
             # MM mapper, so anything that has a hash must have a HIT cache
             # entry here as well.
-            request.mm_inputs = self.mm_input_mapper_server.process_inputs(
-                request.mm_inputs, request.mm_hashes)
+            assert request.mm_inputs is not None
+            request.mm_inputs, request.mm_hashes = (
+                self.mm_input_mapper_server.process_inputs(
+                    request.mm_inputs, request.mm_hashes))
 
         req = Request.from_engine_core_request(request)
 
@@ -128,7 +130,7 @@ class EngineCore:
     def shutdown(self):
         self.model_executor.shutdown()
 
-    def profile(self, is_start=True):
+    def profile(self, is_start: bool = True):
         self.model_executor.profile(is_start)
 
 
@@ -161,8 +163,8 @@ class EngineCoreProc(EngineCore):
         # and to overlap some serialization/deserialization with the
         # model forward pass.
         # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
-        self.input_queue = queue.Queue()
-        self.output_queue = queue.Queue()
+        self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
+        self.output_queue: queue.Queue[List[EngineCoreOutput]] = queue.Queue()
         threading.Thread(target=self.process_input_socket,
                          args=(input_path, ),
                          daemon=True).start()
@@ -318,9 +320,7 @@ class EngineCoreProc(EngineCore):
 
         self._last_logging_time = now
 
-    def _handle_client_request(
-            self, request: Union[EngineCoreRequest, EngineCoreProfile,
-                                 List[str]]) -> None:
+    def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
         """Handle EngineCoreRequest or EngineCoreABORT from Client."""
 
         if isinstance(request, EngineCoreRequest):
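Note: annotating the queues with their element types (`queue.Queue[...]`) is what lets mypy check what the socket threads put on them and what the busy loop takes off; an un-annotated `queue.Queue()` is effectively `Queue[Any]`. A small standalone sketch (toy payload types):

    import queue
    from typing import List, Union

    # Stand-in for the union of payloads a core loop accepts.
    Request = Union[str, List[str]]


    def run_loop() -> List[str]:
        input_queue: "queue.Queue[Request]" = queue.Queue()
        output_queue: "queue.Queue[List[str]]" = queue.Queue()

        input_queue.put("add-req-1")
        input_queue.put(["req-1", "req-2"])  # an abort batch
        # input_queue.put(42) would now be flagged by mypy.

        handled: List[str] = []
        while not input_queue.empty():
            item = input_queue.get()
            handled.append(item if isinstance(item, str) else f"abort x{len(item)}")
        output_queue.put(handled)
        return output_queue.get()


    if __name__ == "__main__":
        print(run_loop())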
@@ -1,6 +1,6 @@
 import atexit
 import os
-from typing import List, Union
+from typing import List, Optional
 
 import msgspec
 import zmq
@@ -10,8 +10,9 @@ from vllm.logger import init_logger
 from vllm.utils import get_open_zmq_ipc_path, kill_process_tree
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
                             EngineCoreProfile, EngineCoreRequest,
-                            EngineCoreRequestType)
-from vllm.v1.engine.core import EngineCore, EngineCoreProc
+                            EngineCoreRequestType, EngineCoreRequestUnion)
+from vllm.v1.engine.core import (EngineCore, EngineCoreProc,
+                                 EngineCoreProcHandle)
 from vllm.v1.serial_utils import PickleEncoder
 
 logger = init_logger(__name__)
@@ -59,7 +60,7 @@ class EngineCoreClient:
     def add_request(self, request: EngineCoreRequest) -> None:
         raise NotImplementedError
 
-    async def profile(self, is_start=True) -> None:
+    def profile(self, is_start: bool = True) -> None:
         raise NotImplementedError
 
     def abort_requests(self, request_ids: List[str]) -> None:
@@ -71,6 +72,9 @@ class EngineCoreClient:
     async def add_request_async(self, request: EngineCoreRequest) -> None:
         raise NotImplementedError
 
+    async def profile_async(self, is_start: bool = True) -> None:
+        raise NotImplementedError
+
     async def abort_requests_async(self, request_ids: List[str]) -> None:
         raise NotImplementedError
 
@@ -105,7 +109,7 @@ class InprocClient(EngineCoreClient):
     def __del__(self):
         self.shutdown()
 
-    def profile(self, is_start=True) -> None:
+    def profile(self, is_start: bool = True) -> None:
         self.engine_core.profile(is_start)
 
 
@@ -133,7 +137,10 @@ class MPClient(EngineCoreClient):
         self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)
 
         # ZMQ setup.
-        self.ctx = (zmq.asyncio.Context() if asyncio_mode else zmq.Context())
+        if asyncio_mode:
+            self.ctx = zmq.asyncio.Context()
+        else:
+            self.ctx = zmq.Context()  # type: ignore[attr-defined]
 
         # Path for IPC.
         ready_path = get_open_zmq_ipc_path()
@@ -149,11 +156,13 @@ class MPClient(EngineCoreClient):
         self.input_socket.bind(input_path)
 
         # Start EngineCore in background process.
+        self.proc_handle: Optional[EngineCoreProcHandle]
         self.proc_handle = EngineCoreProc.make_engine_core_process(
             *args,
-            input_path=input_path,
-            output_path=output_path,
-            ready_path=ready_path,
+            input_path=
+            input_path,  # type: ignore[misc] # MyPy incorrectly flags duplicate keywords
+            output_path=output_path,  # type: ignore[misc]
+            ready_path=ready_path,  # type: ignore[misc]
             **kwargs,
         )
         atexit.register(self.shutdown)
@@ -204,10 +213,8 @@ class SyncMPClient(MPClient):
         engine_core_outputs = self.decoder.decode(frame.buffer).outputs
         return engine_core_outputs
 
-    def _send_input(
-            self, request_type: EngineCoreRequestType,
-            request: Union[EngineCoreRequest, EngineCoreProfile,
-                           List[str]]) -> None:
+    def _send_input(self, request_type: EngineCoreRequestType,
+                    request: EngineCoreRequestUnion) -> None:
 
         # (RequestType, SerializedRequest)
         msg = (request_type.value, self.encoder.encode(request))
@@ -219,7 +226,7 @@ class SyncMPClient(MPClient):
     def abort_requests(self, request_ids: List[str]) -> None:
         self._send_input(EngineCoreRequestType.ABORT, request_ids)
 
-    def profile(self, is_start=True) -> None:
+    def profile(self, is_start: bool = True) -> None:
         self._send_input(EngineCoreRequestType.PROFILE,
                          EngineCoreProfile(is_start))
 
@@ -237,10 +244,8 @@ class AsyncMPClient(MPClient):
 
         return engine_core_outputs
 
-    async def _send_input(
-            self, request_type: EngineCoreRequestType,
-            request: Union[EngineCoreRequest, EngineCoreProfile,
-                           List[str]]) -> None:
+    async def _send_input(self, request_type: EngineCoreRequestType,
+                          request: EngineCoreRequestUnion) -> None:
 
         msg = (request_type.value, self.encoder.encode(request))
         await self.input_socket.send_multipart(msg, copy=False)
@@ -252,6 +257,6 @@ class AsyncMPClient(MPClient):
         if len(request_ids) > 0:
             await self._send_input(EngineCoreRequestType.ABORT, request_ids)
 
-    async def profile(self, is_start=True) -> None:
+    async def profile_async(self, is_start: bool = True) -> None:
         await self._send_input(EngineCoreRequestType.PROFILE,
                                EngineCoreProfile(is_start))
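Note: splitting `profile` into a sync `profile` and an async `profile_async` keeps the overrides consistent under mypy, since an `async def` cannot cleanly override a plain `def` of the same name (the return type becomes a coroutine). A minimal sketch of the resulting interface shape (toy classes, not the real client hierarchy):

    import asyncio


    class Client:
        """Base interface exposing separate sync and async entry points."""

        def profile(self, is_start: bool = True) -> None:
            raise NotImplementedError

        async def profile_async(self, is_start: bool = True) -> None:
            raise NotImplementedError


    class SyncClient(Client):

        def profile(self, is_start: bool = True) -> None:
            print(f"profile (sync): is_start={is_start}")


    class AsyncClient(Client):

        async def profile_async(self, is_start: bool = True) -> None:
            # Naming this `profile` would make the override return a
            # coroutine where the base class promises None.
            print(f"profile (async): is_start={is_start}")


    if __name__ == "__main__":
        SyncClient().profile(True)
        asyncio.run(AsyncClient().profile_async(False))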
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.logger import init_logger
@@ -97,7 +97,7 @@ class IncrementalDetokenizer:
         self,
         new_token_ids: List[int],
         finish_reason: Optional[str],
-        stop_reason: Optional[str],
+        stop_reason: Optional[Union[int, str, None]],
     ) -> Optional[RequestOutput]:
         """
         Update RequestState for the request_id by:
@@ -103,7 +103,8 @@ class LLMEngine:
             multiprocess_mode=enable_multiprocessing)
 
     @classmethod
-    def _get_executor_cls(cls, vllm_config: VllmConfig):
+    def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
+        executor_class: Type[Executor]
         distributed_executor_backend = (
             vllm_config.parallel_config.distributed_executor_backend)
         if distributed_executor_backend == "mp":
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 import PIL
 from blake3 import blake3
@@ -42,14 +42,14 @@ class MMInputMapperClient:
             model_config)
         self.mm_registry.init_mm_limits_per_prompt(model_config)
 
-        self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
+        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
 
         # DEBUG: Set to None to disable
        self.mm_debug_cache_hit_ratio_steps = None
        self.mm_cache_hits = 0
        self.mm_cache_total = 0
 
-    def cache_hit_ratio(self, steps) -> float:
+    def cache_hit_ratio(self, steps):
         if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0:
             logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
                          self.mm_cache_hits / self.mm_cache_total)
@@ -60,7 +60,7 @@ class MMInputMapperClient:
         mm_hashes: Optional[List[str]],
         mm_processor_kwargs: Optional[Dict[str, Any]],
         precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
-    ) -> List[MultiModalKwargs]:
+    ) -> Tuple[List[MultiModalKwargs], Optional[List[str]]]:
         if precomputed_mm_inputs is None:
             image_inputs = mm_data["image"]
             if not isinstance(image_inputs, list):
@@ -72,6 +72,7 @@ class MMInputMapperClient:
         # Check if hash is enabled
         use_hash = mm_hashes is not None
         if use_hash:
+            assert mm_hashes is not None
             assert num_inputs == len(
                 mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format(
                     num_inputs, len(mm_hashes))
@@ -79,7 +80,7 @@ class MMInputMapperClient:
         # Process each image input separately, so that later we can schedule
         # them in a fine-grained manner.
         # Apply caching (if enabled) and reuse precomputed inputs (if provided)
-        ret_hashes = [] if use_hash else None
+        ret_hashes: Optional[List[str]] = [] if use_hash else None
         ret_inputs: List[MultiModalKwargs] = []
         for input_id in range(num_inputs):
             if self.mm_debug_cache_hit_ratio_steps is not None:
@@ -88,6 +89,7 @@ class MMInputMapperClient:
             mm_hash = None
             mm_input = None
             if use_hash:
+                assert mm_hashes is not None
                 mm_hash = mm_hashes[input_id]
                 mm_input = self.mm_cache.get(mm_hash)
 
@@ -105,12 +107,15 @@ class MMInputMapperClient:
 
                 if use_hash:
                     # Add to cache
+                    assert mm_hash is not None
                     self.mm_cache.put(mm_hash, mm_input)
             else:
                 self.mm_cache_hits += 1
                 mm_input = None  # Avoids sending mm_input to Server
 
             if use_hash:
+                assert mm_hash is not None
+                assert ret_hashes is not None
                 ret_hashes.append(mm_hash)
             ret_inputs.append(mm_input)
 
@@ -120,17 +125,18 @@ class MMInputMapperClient:
 class MMInputMapperServer:
 
     def __init__(self, ):
-        self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
+        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
 
     def process_inputs(
         self,
         mm_inputs: List[Optional[MultiModalKwargs]],
-        mm_hashes: List[Optional[str]],
+        mm_hashes: List[str],
     ) -> List[MultiModalKwargs]:
         assert len(mm_inputs) == len(mm_hashes)
 
         full_mm_inputs = []
         for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
+            assert mm_hash is not None
             if mm_input is None:
                 mm_input = self.mm_cache.get(mm_hash)
                 assert mm_input is not None
@@ -56,7 +56,7 @@ class Processor:
         request_id: str,
         prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
-        arrival_time: float,
+        arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Tuple
+from typing import Tuple
 
 from vllm.config import VllmConfig
 from vllm.v1.outputs import ModelRunnerOutput
@@ -28,7 +28,7 @@ class Executor(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def profile(self, is_start=True):
+    def profile(self, is_start: bool = True):
         raise NotImplementedError
 
     @abstractmethod
@@ -38,11 +38,3 @@ class Executor(ABC):
     @abstractmethod
     def check_health(self) -> None:
         raise NotImplementedError
-
-    @abstractmethod
-    def collective_rpc(self,
-                       method: str,
-                       timeout: Optional[float] = None,
-                       args: Tuple = (),
-                       kwargs: Optional[Dict] = None) -> []:
-        raise NotImplementedError
@@ -7,7 +7,7 @@ import time
 from dataclasses import dataclass
 from enum import Enum, auto
 from multiprocessing.process import BaseProcess
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import zmq
 
@@ -21,6 +21,7 @@ from vllm.executor.multiproc_worker_utils import (
 from vllm.logger import init_logger
 from vllm.utils import (get_distributed_init_method, get_open_port,
                         get_open_zmq_ipc_path)
+from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import make_zmq_socket
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -31,7 +32,7 @@ POLLING_TIMEOUT_MS = 5000
 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
 
 
-class MultiprocExecutor:
+class MultiprocExecutor(Executor):
 
     def __init__(self, vllm_config: VllmConfig) -> None:
         # Call self.shutdown at exit to clean up
@@ -103,7 +104,7 @@ class MultiprocExecutor:
                        method: str,
                        timeout: Optional[float] = None,
                        args: Tuple = (),
-                       kwargs: Optional[Dict] = None) -> []:
+                       kwargs: Optional[Dict] = None) -> List[Any]:
         """
         Execute an RPC call on workers.
 
@@ -125,7 +126,7 @@ class MultiprocExecutor:
 
             responses = [None] * self.world_size
             for w in self.workers:
-                dequeue_timeout = timeout - (time.monotonic() - start_time()
+                dequeue_timeout = timeout - (time.monotonic() - start_time
                                              ) if timeout is not None else None
                 status, result = w.worker_response_mq.dequeue(
                     timeout=dequeue_timeout)
@@ -153,7 +154,7 @@ class MultiprocExecutor:
                                     args=(scheduler_output, ))[0]
         return model_output
 
-    def profile(self, is_start=True):
+    def profile(self, is_start: bool = True):
         self.collective_rpc("profile", args=(is_start, ))
         return
 
@@ -185,7 +186,6 @@ class MultiprocExecutor:
                 p.kill()
 
         self._cleanup_sockets()
-        self.workers = None
 
     def _cleanup_sockets(self):
         for w in self.workers:
@@ -200,7 +200,8 @@ class MultiprocExecutor:
         # again
        atexit.unregister(self.shutdown)
        """Properly shut down the executor and its workers"""
-        if (hasattr(self, 'workers') and self.workers is not None):
+        if not getattr(self, 'shutting_down', False):
+            self.shutting_down = True
             for w in self.workers:  #TODO: not sure if needed
                 w.worker_response_mq = None
        self._ensure_worker_termination()
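Note: the old `-> []` annotation is a list-literal expression rather than a type, which mypy rejects; `List[Any]` states the intended return type. A sketch of the corrected signature shape (a hypothetical stand-in, not vLLM's executor):

    from typing import Any, Dict, List, Optional, Tuple


    def collective_rpc(method: str,
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        # "-> []" parses as an expression, not a type, so mypy flags it;
        # List[Any] says "one result per worker, of any type".
        kwargs = kwargs or {}
        # Pretend two workers each ran `method` and returned a result.
        return [f"worker{i}: {method}{args}" for i in range(2)]


    if __name__ == "__main__":
        print(collective_rpc("profile", args=(True, )))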
@@ -4,13 +4,14 @@ from typing import Optional, Tuple
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_worker import Worker
 
 logger = init_logger(__name__)
 
 
-class UniprocExecutor:
+class UniprocExecutor(Executor):
 
     def __init__(self, vllm_config: VllmConfig) -> None:
         self.vllm_config = vllm_config
@@ -25,7 +26,7 @@ class UniprocExecutor:
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
 
-        self.worker = self._create_worker()
+        self.worker: Worker = self._create_worker()
         self.worker.initialize()
         self.worker.load_model()
 
@@ -75,7 +76,7 @@ class UniprocExecutor:
         self.worker.profile(is_start)
 
     def shutdown(self):
-        self.worker = None
+        pass
 
     def check_health(self) -> None:
         # UniprocExecutor will always be healthy as long as
@@ -52,10 +52,9 @@ class Request:
         else:
             self.mm_positions = []
         # Output of the mm input mapper (e.g., image tensors).
+        self.mm_inputs: List[MultiModalKwargs] = []
         if self.inputs.multi_modal_inputs:
             self.mm_inputs = self.inputs.multi_modal_inputs
-        else:
-            self.mm_inputs: List[MultiModalKwargs] = []
 
     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
@@ -1,6 +1,8 @@
 from collections import OrderedDict
+from collections.abc import Sequence
 from contextlib import contextmanager
-from typing import Any, Generic, Iterator, List, TypeVar, overload
+from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,
+                    overload)
 
 import zmq
 
@@ -11,7 +13,7 @@ logger = init_logger(__name__)
 T = TypeVar("T")
 
 
-class ConstantList(Generic[T]):
+class ConstantList(Generic[T], Sequence):
 
     def __init__(self, x: List[T]) -> None:
         self._x = x
@@ -34,29 +36,33 @@ class ConstantList(Generic[T]):
     def clear(self):
         raise Exception("Cannot clear a constant list")
 
-    def index(self, item):
-        return self._x.index(item)
+    def index(self,
+              item: T,
+              start: int = 0,
+              stop: Optional[int] = None) -> int:
+        return self._x.index(item, start,
+                             stop if stop is not None else len(self._x))
 
     @overload
-    def __getitem__(self, item) -> T:
+    def __getitem__(self, item: int) -> T:
         ...
 
     @overload
     def __getitem__(self, s: slice, /) -> List[T]:
         ...
 
-    def __getitem__(self, item):
+    def __getitem__(self, item: Union[int, slice]) -> Union[T, List[T]]:
         return self._x[item]
 
     @overload
-    def __setitem__(self, item, value):
+    def __setitem__(self, item: int, value: T):
         ...
 
     @overload
-    def __setitem__(self, s: slice, value, /):
+    def __setitem__(self, s: slice, value: T, /):
         ...
 
-    def __setitem__(self, item, value):
+    def __setitem__(self, item: Union[int, slice], value: Union[T, List[T]]):
         raise Exception("Cannot set item in a constant list")
 
     def __delitem__(self, item):
@@ -73,10 +79,12 @@ class ConstantList(Generic[T]):
 
 
 @contextmanager
-def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
+def make_zmq_socket(
+        path: str,
+        type: Any) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
     """Context manager for a ZMQ socket"""
 
-    ctx = zmq.Context()
+    ctx = zmq.Context()  # type: ignore[attr-defined]
     try:
         socket = ctx.socket(type)
 
@@ -96,20 +104,24 @@ def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
         ctx.destroy(linger=0)
 
 
-class LRUDictCache:
+K = TypeVar('K')
+V = TypeVar('V')
+
+
+class LRUDictCache(Generic[K, V]):
 
     def __init__(self, size: int):
-        self.cache = OrderedDict()
+        self.cache: OrderedDict[K, V] = OrderedDict()
         self.size = size
 
-    def get(self, key, default=None):
+    def get(self, key: K, default=None) -> V:
         if key not in self.cache:
             return default
 
         self.cache.move_to_end(key)
         return self.cache[key]
 
-    def put(self, key, value):
+    def put(self, key: K, value: V):
         self.cache[key] = value
         self.cache.move_to_end(key)
         if len(self.cache) > self.size:
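Note: parameterizing `LRUDictCache` over key and value types means call sites such as `LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)` get `get`/`put` checked with concrete types instead of Any. A condensed, self-contained version of the same cache showing the typed call site:

    from collections import OrderedDict
    from typing import Generic, Optional, TypeVar

    K = TypeVar('K')
    V = TypeVar('V')


    class LRUCache(Generic[K, V]):
        """Tiny LRU dict, parameterized over key/value types for mypy."""

        def __init__(self, size: int) -> None:
            self.cache: "OrderedDict[K, V]" = OrderedDict()
            self.size = size

        def get(self, key: K) -> Optional[V]:
            if key not in self.cache:
                return None
            self.cache.move_to_end(key)
            return self.cache[key]

        def put(self, key: K, value: V) -> None:
            self.cache[key] = value
            self.cache.move_to_end(key)
            if len(self.cache) > self.size:
                self.cache.popitem(last=False)


    if __name__ == "__main__":
        cache = LRUCache[str, bytes](2)  # key/value types are now checked
        cache.put("a", b"1")
        cache.put("b", b"2")
        cache.put("c", b"3")  # evicts "a"
        print(cache.get("a"), cache.get("c"))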
@@ -215,6 +215,7 @@ class InputBatch:
 
         # Swap the states.
         req_id = self.req_ids[last_req_index]
+        assert req_id is not None
         self.req_ids[empty_index] = req_id
         self.req_ids[last_req_index] = None
         self.req_id_to_index[req_id] = empty_index
@@ -1,6 +1,6 @@
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple, cast
 
 import numpy as np
 import torch
@@ -193,9 +193,9 @@ class GPUModelRunner:
 
         req_ids_to_add: List[str] = []
         # Add new requests to the cached states.
-        for req_data in scheduler_output.scheduled_new_reqs:
-            req_id = req_data.req_id
-            sampling_params = req_data.sampling_params
+        for new_req_data in scheduler_output.scheduled_new_reqs:
+            req_id = new_req_data.req_id
+            sampling_params = new_req_data.sampling_params
             if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                 generator = torch.Generator(device=self.device)
                 generator.manual_seed(sampling_params.seed)
@@ -204,25 +204,25 @@ class GPUModelRunner:
 
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
-                prompt_token_ids=req_data.prompt_token_ids,
-                prompt=req_data.prompt,
-                mm_inputs=req_data.mm_inputs,
-                mm_positions=req_data.mm_positions,
+                prompt_token_ids=new_req_data.prompt_token_ids,
+                prompt=new_req_data.prompt,
+                mm_inputs=new_req_data.mm_inputs,
+                mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
                 generator=generator,
-                block_ids=req_data.block_ids,
-                num_computed_tokens=req_data.num_computed_tokens,
+                block_ids=new_req_data.block_ids,
+                num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],
             )
             req_ids_to_add.append(req_id)
 
         # Update the cached states of the resumed requests.
-        for req_data in scheduler_output.scheduled_resumed_reqs:
-            req_id = req_data.req_id
+        for res_req_data in scheduler_output.scheduled_resumed_reqs:
+            req_id = res_req_data.req_id
             req_state = self.requests[req_id]
 
-            req_state.block_ids = req_data.block_ids
-            req_state.num_computed_tokens = req_data.num_computed_tokens
+            req_state.block_ids = res_req_data.block_ids
+            req_state.num_computed_tokens = res_req_data.num_computed_tokens
             req_ids_to_add.append(req_id)
 
         # Add the new or resumed requests to the persistent batch.
@@ -259,6 +259,7 @@ class GPUModelRunner:
         num_scheduled_tokens = []
         max_num_scheduled_tokens = 0
         for req_id in self.input_batch.req_ids[:num_reqs]:
+            assert req_id is not None
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
             num_scheduled_tokens.append(num_tokens)
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
@@ -373,7 +374,7 @@ class GPUModelRunner:
 
         # Batch the multi-modal inputs.
         mm_inputs: List[MultiModalKwargs] = []
-        req_input_ids: List[Tuple[int, int]] = []
+        req_input_ids: List[Tuple[str, int]] = []
         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
             req_state = self.requests[req_id]
             for input_id in encoder_input_ids:
@@ -406,6 +407,7 @@ class GPUModelRunner:
         encoder_outputs: List[torch.Tensor] = []
         num_reqs = self.input_batch.num_reqs
         for req_id in self.input_batch.req_ids[:num_reqs]:
+            assert req_id is not None
             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                 req_id]
             req_state = self.requests[req_id]
@@ -514,6 +516,7 @@ class GPUModelRunner:
         # the requests one by one. Optimize.
         num_reqs = self.input_batch.num_reqs
         for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
+            assert req_id is not None
             req_state = self.requests[req_id]
             seq_len = (req_state.num_computed_tokens +
                        scheduler_output.num_scheduled_tokens[req_id])
@@ -539,8 +542,15 @@ class GPUModelRunner:
             logprobs = None
         else:
             logprobs = sampler_output.logprobs.cpu()
+
+        # num_reqs entries should be non-None
+        assert all(
+            req_id is not None for req_id in
+            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
+        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
+
         model_runner_output = ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids[:num_reqs],
+            req_ids=req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=sampled_token_ids,
             logprob_token_ids_cpu=logprob_token_ids,
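Note: `cast(List[str], ...)` is a no-op at runtime; it only tells mypy that the `Optional[str]` slots in the fixed-size batch are all filled, which the preceding `assert all(...)` verifies. A standalone sketch of the pattern (hypothetical batch layout):

    from typing import List, Optional, cast


    def active_ids(req_ids: List[Optional[str]], num_reqs: int) -> List[str]:
        # Only the first num_reqs slots of the fixed-size batch are populated;
        # check that before narrowing the element type.
        head = req_ids[:num_reqs]
        assert all(req_id is not None for req_id in head), "req_ids contains None"
        # cast() changes only the static type, never the value.
        return cast(List[str], head)


    if __name__ == "__main__":
        batch: List[Optional[str]] = ["req-0", "req-1", None, None]
        print(active_ids(batch, 2))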
@@ -204,7 +204,7 @@ class Worker:
         return output if self.rank == 0 else None
         return output
 
-    def profile(self, is_start=True):
+    def profile(self, is_start: bool = True):
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
         if is_start: