Enable mypy checking on V1 code (#11105)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Mark McLoughlin 2024-12-14 17:54:04 +00:00 committed by GitHub
parent 93abf23a64
commit 6d917d0eeb
21 changed files with 160 additions and 121 deletions

View File

@ -29,3 +29,4 @@ run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/spec_decode
run_mypy vllm/worker
run_mypy vllm/v1
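The same check the script now runs for the newly covered package can also be driven from Python through mypy's programmatic API; a minimal sketch, assuming mypy is installed and the command is run from the repository root (the extra flags a real invocation passes are omitted here):

from mypy import api

# Run mypy over the package added above and surface its report.
stdout, stderr, exit_status = api.run(["vllm/v1"])
print(stdout, end="")
if exit_status != 0:
    raise SystemExit(exit_status)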

View File

@ -135,6 +135,8 @@ class FlashAttentionImpl(AttentionImpl):
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# Profiling run.
return output
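The early return above is how an Optional parameter gets narrowed for mypy during the profiling run. A minimal, self-contained sketch of the pattern (names here are illustrative, not vLLM APIs):

from typing import Optional

def forward_step(metadata: Optional[dict], output: list) -> list:
    if metadata is None:
        # Profiling / warm-up run: nothing to compute.
        return output
    # After the guard, mypy narrows `metadata` to `dict`.
    output.append(len(metadata))
    return output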

View File

@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Dict, List, Optional
from typing import Dict, Iterable, List, Optional
from vllm.logger import init_logger
from vllm.utils import cdiv
@ -263,12 +263,13 @@ class KVCacheManager:
"""
# Default to [] in case a request is freed (aborted) before alloc.
blocks = self.req_to_blocks.pop(request.request_id, [])
ordered_blocks: Iterable[KVCacheBlock] = blocks
if self.enable_caching:
# Free blocks in reverse order so that the tail blocks are
# freed first.
blocks = reversed(blocks)
ordered_blocks = reversed(blocks)
for block in blocks:
for block in ordered_blocks:
block.decr_ref()
if block.ref_cnt == 0:
self.free_block_queue.append(block)
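The new `ordered_blocks` name exists because reassigning a `List[...]` variable to the result of `reversed()` changes its type, which mypy rejects; a second variable annotated with the wider `Iterable` type accepts both. A small sketch of the same pattern with plain ints (illustrative only):

from typing import Iterable, List

def release_blocks(blocks: List[int], reverse: bool) -> None:
    # Annotate with the widest type that will be assigned: both the list
    # and the reversed() iterator are Iterable[int], so both assignments type-check.
    ordered_blocks: Iterable[int] = blocks
    if reverse:
        ordered_blocks = reversed(blocks)
    for block in ordered_blocks:
        print(block)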
@ -396,8 +397,7 @@ class KVCacheManager:
f"{request.request_id}({request})")
# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value,
tuple(block_tokens))
block_hash = hash_block_tokens(prev_block_hash_value, block_tokens)
# Update and added the full block to the cache.
blk.block_hash = block_hash

View File

@ -1,4 +1,5 @@
"""KV-Cache Utilities."""
from collections.abc import Sequence
from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Tuple
@ -13,7 +14,7 @@ class BlockHashType(NamedTuple):
collision happens when the hash value is the same.
"""
hash_value: int
token_ids: Tuple[int]
token_ids: Tuple[int, ...]
@dataclass
@ -79,8 +80,8 @@ class FreeKVCacheBlockQueue:
self.num_free_blocks = len(blocks)
# Initialize the doubly linked list of free blocks.
self.free_list_head = blocks[0]
self.free_list_tail = blocks[-1]
self.free_list_head: Optional[KVCacheBlock] = blocks[0]
self.free_list_tail: Optional[KVCacheBlock] = blocks[-1]
for i in range(self.num_free_blocks):
if i > 0:
blocks[i].prev_free_block = blocks[i - 1]
@ -159,7 +160,7 @@ class FreeKVCacheBlockQueue:
def hash_block_tokens(parent_block_hash: Optional[int],
curr_block_token_ids: Tuple[int]) -> BlockHashType:
curr_block_token_ids: Sequence[int]) -> BlockHashType:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
@ -171,7 +172,7 @@ def hash_block_tokens(parent_block_hash: Optional[int],
Args:
parent_block_hash: The hash of the parent block. None
if this is the first block.
curr_block_token_ids: A tuple of token ids in the current
curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full.
Returns:
@ -179,11 +180,11 @@ def hash_block_tokens(parent_block_hash: Optional[int],
The entire tuple is used as the hash key of the block.
"""
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
curr_block_token_ids)
tuple(curr_block_token_ids))
def hash_request_tokens(block_size: int,
token_ids: List[int]) -> List[BlockHashType]:
token_ids: Sequence[int]) -> List[BlockHashType]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
@ -198,7 +199,7 @@ def hash_request_tokens(block_size: int,
parent_block_hash_value = None
for start in range(0, len(token_ids), block_size):
end = start + block_size
block_token_ids = tuple(token_ids[start:end])
block_token_ids = token_ids[start:end]
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break
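The annotation changes above hinge on a subtle distinction: `Tuple[int]` is a tuple of exactly one int, `Tuple[int, ...]` is a variable-length tuple, and `Sequence[int]` accepts lists, tuples, and slices alike. A short sketch (hypothetical helper, not the vLLM function):

from typing import Sequence, Tuple

def block_key(token_ids: Sequence[int]) -> Tuple[int, ...]:
    # Sequence[int] lets callers pass a list or a tuple slice;
    # Tuple[int, ...] describes the variable-length tuple returned.
    return tuple(token_ids)

block_key([1, 2, 3])   # a list is accepted
block_key((1, 2, 3))   # so is a tuple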

View File

@ -152,6 +152,7 @@ class Scheduler:
break
if not can_schedule:
break
assert new_blocks is not None
# Schedule the request.
scheduled_running_reqs.append(request)

View File

@ -36,7 +36,7 @@ class EngineCoreRequest:
prompt: Optional[str]
prompt_token_ids: List[int]
mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
mm_hashes: Optional[List[Optional[str]]]
mm_hashes: Optional[List[str]]
mm_placeholders: Optional[MultiModalPlaceholderDict]
sampling_params: SamplingParams
eos_token_id: Optional[int]
@ -44,10 +44,11 @@ class EngineCoreRequest:
lora_request: Optional[LoRARequest]
class EngineCoreOutput(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False):
class EngineCoreOutput(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
request_id: str
new_token_ids: List[int]
@ -56,10 +57,11 @@ class EngineCoreOutput(msgspec.Struct,
stop_reason: Union[int, str, None] = None
class EngineCoreOutputs(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False):
class EngineCoreOutputs(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
#NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout and using an int enum for finish/stop reason
@ -81,3 +83,6 @@ class EngineCoreRequestType(enum.Enum):
ADD = b'\x00'
ABORT = b'\x01'
PROFILE = b'\x02'
EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]]
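Naming the union once keeps the long `Union[...]` out of every signature that handles client requests. A hedged sketch of the idea with placeholder types:

from typing import List, Union

class AddRequest: ...
class ProfileRequest: ...

# One alias instead of repeating the Union in every handler signature.
RequestUnion = Union[AddRequest, ProfileRequest, List[str]]

def handle(request: RequestUnion) -> str:
    if isinstance(request, AddRequest):
        return "add"
    if isinstance(request, ProfileRequest):
        return "profile"
    return f"abort {len(request)} requests"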

View File

@ -81,7 +81,7 @@ class AsyncLLM(EngineClient):
asyncio_mode=True,
)
self.output_handler = None
self.output_handler: Optional[asyncio.Task] = None
def __del__(self):
self.shutdown()
@ -126,7 +126,8 @@ class AsyncLLM(EngineClient):
handler.cancel()
@classmethod
def _get_executor_cls(cls, vllm_config: VllmConfig):
def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
executor_class: Type[Executor]
distributed_executor_backend = (
vllm_config.parallel_config.distributed_executor_backend)
if distributed_executor_backend == "mp":
@ -361,10 +362,10 @@ class AsyncLLM(EngineClient):
logger.debug("Called check_health.")
async def start_profile(self) -> None:
await self.engine_core.profile(True)
await self.engine_core.profile_async(True)
async def stop_profile(self) -> None:
await self.engine_core.profile(False)
await self.engine_core.profile_async(False)
@property
def is_running(self) -> bool:
@ -380,7 +381,7 @@ class AsyncLLM(EngineClient):
@property
def dead_error(self) -> BaseException:
return Exception
return Exception() # TODO: implement
# Retain V0 name for backwards compatibility.
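Two recurring fixes show up in this file: an attribute that starts out as None gets an `Optional[...]` annotation at its first assignment, and a local that is filled in by an if/elif chain is declared up front so each branch is checked against `Type[Executor]`. A minimal sketch of both patterns (illustrative names only):

import asyncio
from typing import Optional, Type

class Base: ...
class MpImpl(Base): ...
class UniImpl(Base): ...

class Client:
    def __init__(self) -> None:
        # Annotating here keeps a later `self.task = asyncio.create_task(...)`
        # from being flagged as assigning a Task to a None-typed attribute.
        self.task: Optional[asyncio.Task] = None

def pick_impl(backend: Optional[str]) -> Type[Base]:
    impl: Type[Base]  # declared once; every branch must assign a Type[Base]
    if backend == "mp":
        impl = MpImpl
    else:
        impl = UniImpl
    return impl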

View File

@ -5,7 +5,7 @@ import threading
import time
from dataclasses import dataclass
from multiprocessing.process import BaseProcess
from typing import List, Tuple, Type, Union
from typing import List, Tuple, Type
import zmq
import zmq.asyncio
@ -20,7 +20,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType)
EngineCoreRequestType, EngineCoreRequestUnion)
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
@ -97,8 +97,10 @@ class EngineCore:
# Note that the cache here is mirrored with the client side of the
# MM mapper, so anything that has a hash must have a HIT cache
# entry here as well.
request.mm_inputs = self.mm_input_mapper_server.process_inputs(
request.mm_inputs, request.mm_hashes)
assert request.mm_inputs is not None
request.mm_inputs, request.mm_hashes = (
self.mm_input_mapper_server.process_inputs(
request.mm_inputs, request.mm_hashes))
req = Request.from_engine_core_request(request)
@ -128,7 +130,7 @@ class EngineCore:
def shutdown(self):
self.model_executor.shutdown()
def profile(self, is_start=True):
def profile(self, is_start: bool = True):
self.model_executor.profile(is_start)
@ -161,8 +163,8 @@ class EngineCoreProc(EngineCore):
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue = queue.Queue()
self.output_queue = queue.Queue()
self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
self.output_queue: queue.Queue[List[EngineCoreOutput]] = queue.Queue()
threading.Thread(target=self.process_input_socket,
args=(input_path, ),
daemon=True).start()
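Annotating the queues gives mypy the element types flowing between the socket threads and the busy loop. A standalone sketch of the pattern (not the engine classes themselves):

import queue
from typing import List

# The subscripted annotation carries the element type; quoting it keeps the
# line valid even on interpreters where queue.Queue is not subscriptable.
input_queue: "queue.Queue[str]" = queue.Queue()
output_queue: "queue.Queue[List[int]]" = queue.Queue()

input_queue.put("request-1")
output_queue.put([1, 2, 3])
item: str = input_queue.get()  # mypy knows this is a str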
@ -318,9 +320,7 @@ class EngineCoreProc(EngineCore):
self._last_logging_time = now
def _handle_client_request(
self, request: Union[EngineCoreRequest, EngineCoreProfile,
List[str]]) -> None:
def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""
if isinstance(request, EngineCoreRequest):

View File

@ -1,6 +1,6 @@
import atexit
import os
from typing import List, Union
from typing import List, Optional
import msgspec
import zmq
@ -10,8 +10,9 @@ from vllm.logger import init_logger
from vllm.utils import get_open_zmq_ipc_path, kill_process_tree
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
EngineCoreRequestType, EngineCoreRequestUnion)
from vllm.v1.engine.core import (EngineCore, EngineCoreProc,
EngineCoreProcHandle)
from vllm.v1.serial_utils import PickleEncoder
logger = init_logger(__name__)
@ -59,7 +60,7 @@ class EngineCoreClient:
def add_request(self, request: EngineCoreRequest) -> None:
raise NotImplementedError
async def profile(self, is_start=True) -> None:
def profile(self, is_start: bool = True) -> None:
raise NotImplementedError
def abort_requests(self, request_ids: List[str]) -> None:
@ -71,6 +72,9 @@ class EngineCoreClient:
async def add_request_async(self, request: EngineCoreRequest) -> None:
raise NotImplementedError
async def profile_async(self, is_start: bool = True) -> None:
raise NotImplementedError
async def abort_requests_async(self, request_ids: List[str]) -> None:
raise NotImplementedError
@ -105,7 +109,7 @@ class InprocClient(EngineCoreClient):
def __del__(self):
self.shutdown()
def profile(self, is_start=True) -> None:
def profile(self, is_start: bool = True) -> None:
self.engine_core.profile(is_start)
@ -133,7 +137,10 @@ class MPClient(EngineCoreClient):
self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)
# ZMQ setup.
self.ctx = (zmq.asyncio.Context() if asyncio_mode else zmq.Context())
if asyncio_mode:
self.ctx = zmq.asyncio.Context()
else:
self.ctx = zmq.Context() # type: ignore[attr-defined]
# Path for IPC.
ready_path = get_open_zmq_ipc_path()
@ -149,11 +156,13 @@ class MPClient(EngineCoreClient):
self.input_socket.bind(input_path)
# Start EngineCore in background process.
self.proc_handle: Optional[EngineCoreProcHandle]
self.proc_handle = EngineCoreProc.make_engine_core_process(
*args,
input_path=input_path,
output_path=output_path,
ready_path=ready_path,
input_path=
input_path, # type: ignore[misc] # MyPy incorrectly flags duplicate keywords
output_path=output_path, # type: ignore[misc]
ready_path=ready_path, # type: ignore[misc]
**kwargs,
)
atexit.register(self.shutdown)
@ -204,10 +213,8 @@ class SyncMPClient(MPClient):
engine_core_outputs = self.decoder.decode(frame.buffer).outputs
return engine_core_outputs
def _send_input(
self, request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, EngineCoreProfile,
List[str]]) -> None:
def _send_input(self, request_type: EngineCoreRequestType,
request: EngineCoreRequestUnion) -> None:
# (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request))
@ -219,7 +226,7 @@ class SyncMPClient(MPClient):
def abort_requests(self, request_ids: List[str]) -> None:
self._send_input(EngineCoreRequestType.ABORT, request_ids)
def profile(self, is_start=True) -> None:
def profile(self, is_start: bool = True) -> None:
self._send_input(EngineCoreRequestType.PROFILE,
EngineCoreProfile(is_start))
@ -237,10 +244,8 @@ class AsyncMPClient(MPClient):
return engine_core_outputs
async def _send_input(
self, request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, EngineCoreProfile,
List[str]]) -> None:
async def _send_input(self, request_type: EngineCoreRequestType,
request: EngineCoreRequestUnion) -> None:
msg = (request_type.value, self.encoder.encode(request))
await self.input_socket.send_multipart(msg, copy=False)
@ -252,6 +257,6 @@ class AsyncMPClient(MPClient):
if len(request_ids) > 0:
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
async def profile(self, is_start=True) -> None:
async def profile_async(self, is_start: bool = True) -> None:
await self._send_input(EngineCoreRequestType.PROFILE,
EngineCoreProfile(is_start))
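Splitting `profile` and `profile_async` avoids overriding an async base method with a sync one (or vice versa), which mypy reports as an incompatible override. A toy sketch of why the split is needed (not the real client classes):

class Client:
    # Separate sync and async entry points keep override signatures compatible.
    def profile(self, is_start: bool = True) -> None:
        raise NotImplementedError

    async def profile_async(self, is_start: bool = True) -> None:
        raise NotImplementedError

class SyncClient(Client):
    def profile(self, is_start: bool = True) -> None:
        print("sync profile:", is_start)

class AsyncClient(Client):
    async def profile_async(self, is_start: bool = True) -> None:
        print("async profile:", is_start)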

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple, Union
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
@ -97,7 +97,7 @@ class IncrementalDetokenizer:
self,
new_token_ids: List[int],
finish_reason: Optional[str],
stop_reason: Optional[str],
stop_reason: Optional[Union[int, str, None]],
) -> Optional[RequestOutput]:
"""
Update RequestState for the request_id by:

View File

@ -103,7 +103,8 @@ class LLMEngine:
multiprocess_mode=enable_multiprocessing)
@classmethod
def _get_executor_cls(cls, vllm_config: VllmConfig):
def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
executor_class: Type[Executor]
distributed_executor_backend = (
vllm_config.parallel_config.distributed_executor_backend)
if distributed_executor_backend == "mp":

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
import PIL
from blake3 import blake3
@ -42,14 +42,14 @@ class MMInputMapperClient:
model_config)
self.mm_registry.init_mm_limits_per_prompt(model_config)
self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
# DEBUG: Set to None to disable
self.mm_debug_cache_hit_ratio_steps = None
self.mm_cache_hits = 0
self.mm_cache_total = 0
def cache_hit_ratio(self, steps) -> float:
def cache_hit_ratio(self, steps):
if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0:
logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
self.mm_cache_hits / self.mm_cache_total)
@ -60,7 +60,7 @@ class MMInputMapperClient:
mm_hashes: Optional[List[str]],
mm_processor_kwargs: Optional[Dict[str, Any]],
precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
) -> List[MultiModalKwargs]:
) -> Tuple[List[MultiModalKwargs], Optional[List[str]]]:
if precomputed_mm_inputs is None:
image_inputs = mm_data["image"]
if not isinstance(image_inputs, list):
@ -72,6 +72,7 @@ class MMInputMapperClient:
# Check if hash is enabled
use_hash = mm_hashes is not None
if use_hash:
assert mm_hashes is not None
assert num_inputs == len(
mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format(
num_inputs, len(mm_hashes))
@ -79,7 +80,7 @@ class MMInputMapperClient:
# Process each image input separately, so that later we can schedule
# them in a fine-grained manner.
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
ret_hashes = [] if use_hash else None
ret_hashes: Optional[List[str]] = [] if use_hash else None
ret_inputs: List[MultiModalKwargs] = []
for input_id in range(num_inputs):
if self.mm_debug_cache_hit_ratio_steps is not None:
@ -88,6 +89,7 @@ class MMInputMapperClient:
mm_hash = None
mm_input = None
if use_hash:
assert mm_hashes is not None
mm_hash = mm_hashes[input_id]
mm_input = self.mm_cache.get(mm_hash)
@ -105,12 +107,15 @@ class MMInputMapperClient:
if use_hash:
# Add to cache
assert mm_hash is not None
self.mm_cache.put(mm_hash, mm_input)
else:
self.mm_cache_hits += 1
mm_input = None # Avoids sending mm_input to Server
if use_hash:
assert mm_hash is not None
assert ret_hashes is not None
ret_hashes.append(mm_hash)
ret_inputs.append(mm_input)
@ -120,17 +125,18 @@ class MMInputMapperClient:
class MMInputMapperServer:
def __init__(self, ):
self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
def process_inputs(
self,
mm_inputs: List[Optional[MultiModalKwargs]],
mm_hashes: List[Optional[str]],
mm_hashes: List[str],
) -> List[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes)
full_mm_inputs = []
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
assert mm_hash is not None
if mm_input is None:
mm_input = self.mm_cache.get(mm_hash)
assert mm_input is not None

View File

@ -56,7 +56,7 @@ class Processor:
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional, Tuple
from typing import Tuple
from vllm.config import VllmConfig
from vllm.v1.outputs import ModelRunnerOutput
@ -28,7 +28,7 @@ class Executor(ABC):
raise NotImplementedError
@abstractmethod
def profile(self, is_start=True):
def profile(self, is_start: bool = True):
raise NotImplementedError
@abstractmethod
@ -38,11 +38,3 @@ class Executor(ABC):
@abstractmethod
def check_health(self) -> None:
raise NotImplementedError
@abstractmethod
def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> []:
raise NotImplementedError

View File

@ -7,7 +7,7 @@ import time
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing.process import BaseProcess
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import zmq
@ -21,6 +21,7 @@ from vllm.executor.multiproc_worker_utils import (
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_open_port,
get_open_zmq_ipc_path)
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import make_zmq_socket
from vllm.worker.worker_base import WorkerWrapperBase
@ -31,7 +32,7 @@ POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
class MultiprocExecutor:
class MultiprocExecutor(Executor):
def __init__(self, vllm_config: VllmConfig) -> None:
# Call self.shutdown at exit to clean up
@ -103,7 +104,7 @@ class MultiprocExecutor:
method: str,
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> []:
kwargs: Optional[Dict] = None) -> List[Any]:
"""
Execute an RPC call on workers.
@ -125,7 +126,7 @@ class MultiprocExecutor:
responses = [None] * self.world_size
for w in self.workers:
dequeue_timeout = timeout - (time.monotonic() - start_time()
dequeue_timeout = timeout - (time.monotonic() - start_time
) if timeout is not None else None
status, result = w.worker_response_mq.dequeue(
timeout=dequeue_timeout)
@ -153,7 +154,7 @@ class MultiprocExecutor:
args=(scheduler_output, ))[0]
return model_output
def profile(self, is_start=True):
def profile(self, is_start: bool = True):
self.collective_rpc("profile", args=(is_start, ))
return
@ -185,7 +186,6 @@ class MultiprocExecutor:
p.kill()
self._cleanup_sockets()
self.workers = None
def _cleanup_sockets(self):
for w in self.workers:
@ -200,7 +200,8 @@ class MultiprocExecutor:
# again
atexit.unregister(self.shutdown)
"""Properly shut down the executor and its workers"""
if (hasattr(self, 'workers') and self.workers is not None):
if getattr(self, 'shutting_down', False):
self.shutting_down = True
for w in self.workers: #TODO: not sure if needed
w.worker_response_mq = None
self._ensure_worker_termination()

View File

@ -4,13 +4,14 @@ from typing import Optional, Tuple
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_worker import Worker
logger = init_logger(__name__)
class UniprocExecutor:
class UniprocExecutor(Executor):
def __init__(self, vllm_config: VllmConfig) -> None:
self.vllm_config = vllm_config
@ -25,7 +26,7 @@ class UniprocExecutor:
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.worker = self._create_worker()
self.worker: Worker = self._create_worker()
self.worker.initialize()
self.worker.load_model()
@ -75,7 +76,7 @@ class UniprocExecutor:
self.worker.profile(is_start)
def shutdown(self):
self.worker = None
pass
def check_health(self) -> None:
# UniprocExecutor will always be healthy as long as

View File

@ -52,10 +52,9 @@ class Request:
else:
self.mm_positions = []
# Output of the mm input mapper (e.g., image tensors).
self.mm_inputs: List[MultiModalKwargs] = []
if self.inputs.multi_modal_inputs:
self.mm_inputs = self.inputs.multi_modal_inputs
else:
self.mm_inputs: List[MultiModalKwargs] = []
@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":

View File

@ -1,6 +1,8 @@
from collections import OrderedDict
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any, Generic, Iterator, List, TypeVar, overload
from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,
overload)
import zmq
@ -11,7 +13,7 @@ logger = init_logger(__name__)
T = TypeVar("T")
class ConstantList(Generic[T]):
class ConstantList(Generic[T], Sequence):
def __init__(self, x: List[T]) -> None:
self._x = x
@ -34,29 +36,33 @@ class ConstantList(Generic[T]):
def clear(self):
raise Exception("Cannot clear a constant list")
def index(self, item):
return self._x.index(item)
def index(self,
item: T,
start: int = 0,
stop: Optional[int] = None) -> int:
return self._x.index(item, start,
stop if stop is not None else len(self._x))
@overload
def __getitem__(self, item) -> T:
def __getitem__(self, item: int) -> T:
...
@overload
def __getitem__(self, s: slice, /) -> List[T]:
...
def __getitem__(self, item):
def __getitem__(self, item: Union[int, slice]) -> Union[T, List[T]]:
return self._x[item]
@overload
def __setitem__(self, item, value):
def __setitem__(self, item: int, value: T):
...
@overload
def __setitem__(self, s: slice, value, /):
def __setitem__(self, s: slice, value: T, /):
...
def __setitem__(self, item, value):
def __setitem__(self, item: Union[int, slice], value: Union[T, List[T]]):
raise Exception("Cannot set item in a constant list")
def __delitem__(self, item):
@ -73,10 +79,12 @@ class ConstantList(Generic[T]):
@contextmanager
def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
def make_zmq_socket(
path: str,
type: Any) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
"""Context manager for a ZMQ socket"""
ctx = zmq.Context()
ctx = zmq.Context() # type: ignore[attr-defined]
try:
socket = ctx.socket(type)
@ -96,20 +104,24 @@ def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
ctx.destroy(linger=0)
class LRUDictCache:
K = TypeVar('K')
V = TypeVar('V')
class LRUDictCache(Generic[K, V]):
def __init__(self, size: int):
self.cache = OrderedDict()
self.cache: OrderedDict[K, V] = OrderedDict()
self.size = size
def get(self, key, default=None):
def get(self, key: K, default=None) -> V:
if key not in self.cache:
return default
self.cache.move_to_end(key)
return self.cache[key]
def put(self, key, value):
def put(self, key: K, value: V):
self.cache[key] = value
self.cache.move_to_end(key)
if len(self.cache) > self.size:
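Parameterizing the cache lets call sites such as `LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)` get typed `get`/`put`. A condensed sketch of the same `Generic[K, V]` shape, independent of the vLLM class:

from collections import OrderedDict
from typing import Generic, Optional, TypeVar

K = TypeVar("K")
V = TypeVar("V")

class LRUCache(Generic[K, V]):
    def __init__(self, size: int) -> None:
        self.cache: "OrderedDict[K, V]" = OrderedDict()
        self.size = size

    def get(self, key: K) -> Optional[V]:
        if key not in self.cache:
            return None
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key: K, value: V) -> None:
        self.cache[key] = value
        self.cache.move_to_end(key)
        if len(self.cache) > self.size:
            self.cache.popitem(last=False)

cache = LRUCache[str, int](2)
cache.put("a", 1)
hit: Optional[int] = cache.get("a")  # typed as Optional[int]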

View File

@ -215,6 +215,7 @@ class InputBatch:
# Swap the states.
req_id = self.req_ids[last_req_index]
assert req_id is not None
self.req_ids[empty_index] = req_id
self.req_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index

View File

@ -1,6 +1,6 @@
import gc
import time
from typing import TYPE_CHECKING, Dict, List, Tuple
from typing import TYPE_CHECKING, Dict, List, Tuple, cast
import numpy as np
import torch
@ -193,9 +193,9 @@ class GPUModelRunner:
req_ids_to_add: List[str] = []
# Add new requests to the cached states.
for req_data in scheduler_output.scheduled_new_reqs:
req_id = req_data.req_id
sampling_params = req_data.sampling_params
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
@ -204,25 +204,25 @@ class GPUModelRunner:
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=req_data.prompt_token_ids,
prompt=req_data.prompt,
mm_inputs=req_data.mm_inputs,
mm_positions=req_data.mm_positions,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt=new_req_data.prompt,
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
generator=generator,
block_ids=req_data.block_ids,
num_computed_tokens=req_data.num_computed_tokens,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
)
req_ids_to_add.append(req_id)
# Update the cached states of the resumed requests.
for req_data in scheduler_output.scheduled_resumed_reqs:
req_id = req_data.req_id
for res_req_data in scheduler_output.scheduled_resumed_reqs:
req_id = res_req_data.req_id
req_state = self.requests[req_id]
req_state.block_ids = req_data.block_ids
req_state.num_computed_tokens = req_data.num_computed_tokens
req_state.block_ids = res_req_data.block_ids
req_state.num_computed_tokens = res_req_data.num_computed_tokens
req_ids_to_add.append(req_id)
# Add the new or resumed requests to the persistent batch.
@ -259,6 +259,7 @@ class GPUModelRunner:
num_scheduled_tokens = []
max_num_scheduled_tokens = 0
for req_id in self.input_batch.req_ids[:num_reqs]:
assert req_id is not None
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens.append(num_tokens)
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
@ -373,7 +374,7 @@ class GPUModelRunner:
# Batch the multi-modal inputs.
mm_inputs: List[MultiModalKwargs] = []
req_input_ids: List[Tuple[int, int]] = []
req_input_ids: List[Tuple[str, int]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for input_id in encoder_input_ids:
@ -406,6 +407,7 @@ class GPUModelRunner:
encoder_outputs: List[torch.Tensor] = []
num_reqs = self.input_batch.num_reqs
for req_id in self.input_batch.req_ids[:num_reqs]:
assert req_id is not None
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
@ -514,6 +516,7 @@ class GPUModelRunner:
# the requests one by one. Optimize.
num_reqs = self.input_batch.num_reqs
for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
assert req_id is not None
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
@ -539,8 +542,15 @@ class GPUModelRunner:
logprobs = None
else:
logprobs = sampler_output.logprobs.cpu()
# num_reqs entries should be non-None
assert all(
req_id is not None for req_id in
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids[:num_reqs],
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=sampled_token_ids,
logprob_token_ids_cpu=logprob_token_ids,
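The `cast(List[str], ...)` above is needed because the persistent batch stores `Optional[str]` slots; the assert checks the first `num_reqs` entries at runtime, and `cast` records that fact for mypy. A short sketch of the idiom with toy data:

from typing import List, Optional, cast

req_ids: List[Optional[str]] = ["a", "b", None, None]
num_reqs = 2

active = req_ids[:num_reqs]
assert all(r is not None for r in active), "req_ids contains None"
# mypy cannot narrow the element type of a slice from the assert above,
# so cast() states what was just verified.
typed_ids: List[str] = cast(List[str], active)
print(typed_ids)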

View File

@ -204,7 +204,7 @@ class Worker:
return output if self.rank == 0 else None
return output
def profile(self, is_start=True):
def profile(self, is_start: bool = True):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
if is_start: