Enable mypy checking on V1 code (#11105)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
parent 93abf23a64
commit 6d917d0eeb
@@ -29,3 +29,4 @@ run_mypy vllm/plugins
 run_mypy vllm/prompt_adapter
 run_mypy vllm/spec_decode
 run_mypy vllm/worker
+run_mypy vllm/v1
@@ -135,6 +135,8 @@ class FlashAttentionImpl(AttentionImpl):
 assert k_scale == 1.0 and v_scale == 1.0, (
 "key/v_scale is not supported in FlashAttention.")
 
+assert output is not None, "Output tensor must be provided."
+
 if attn_metadata is None:
 # Profiling run.
 return output
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Dict, List, Optional
+from typing import Dict, Iterable, List, Optional
 
 from vllm.logger import init_logger
 from vllm.utils import cdiv
@@ -263,12 +263,13 @@ class KVCacheManager:
 """
 # Default to [] in case a request is freed (aborted) before alloc.
 blocks = self.req_to_blocks.pop(request.request_id, [])
+ordered_blocks: Iterable[KVCacheBlock] = blocks
 if self.enable_caching:
 # Free blocks in reverse order so that the tail blocks are
 # freed first.
-blocks = reversed(blocks)
+ordered_blocks = reversed(blocks)
 
-for block in blocks:
+for block in ordered_blocks:
 block.decr_ref()
 if block.ref_cnt == 0:
 self.free_block_queue.append(block)
@@ -396,8 +397,7 @@ class KVCacheManager:
 f"{request.request_id}({request})")
 
 # Compute the hash of the current block.
-block_hash = hash_block_tokens(prev_block_hash_value,
-tuple(block_tokens))
+block_hash = hash_block_tokens(prev_block_hash_value, block_tokens)
 
 # Update and added the full block to the cache.
 blk.block_hash = block_hash
@@ -1,4 +1,5 @@
 """KV-Cache Utilities."""
+from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import List, NamedTuple, Optional, Tuple
 
@@ -13,7 +14,7 @@ class BlockHashType(NamedTuple):
 collision happens when the hash value is the same.
 """
 hash_value: int
-token_ids: Tuple[int]
+token_ids: Tuple[int, ...]
 
 
 @dataclass
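Aside (illustration only, not from this commit; the function names are hypothetical): the BlockHashType change matters because Tuple[int] describes a tuple of exactly one int, whereas Tuple[int, ...] describes a homogeneous tuple of any length, which is what a block of token ids actually is.

    from typing import Tuple

    def one_int(t: Tuple[int]) -> int:
        # Accepts only a 1-tuple such as (7,).
        return t[0]

    def any_ints(t: Tuple[int, ...]) -> int:
        # Accepts (), (1,), (1, 2, 3), ... -- any length, all ints.
        return sum(t)

    print(any_ints((1, 2, 3)))   # fine for mypy
    # one_int((1, 2, 3))         # mypy: incompatible argument type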
@@ -79,8 +80,8 @@ class FreeKVCacheBlockQueue:
 self.num_free_blocks = len(blocks)
 
 # Initialize the doubly linked list of free blocks.
-self.free_list_head = blocks[0]
-self.free_list_tail = blocks[-1]
+self.free_list_head: Optional[KVCacheBlock] = blocks[0]
+self.free_list_tail: Optional[KVCacheBlock] = blocks[-1]
 for i in range(self.num_free_blocks):
 if i > 0:
 blocks[i].prev_free_block = blocks[i - 1]
@@ -159,7 +160,7 @@ class FreeKVCacheBlockQueue:
 
 
 def hash_block_tokens(parent_block_hash: Optional[int],
-curr_block_token_ids: Tuple[int]) -> BlockHashType:
+curr_block_token_ids: Sequence[int]) -> BlockHashType:
 """Computes a hash value corresponding to the contents of a block and
 the contents of the preceding block(s). The hash value is used for
 prefix caching. We use LRU cache for this function to avoid recomputing
@@ -171,7 +172,7 @@ def hash_block_tokens(parent_block_hash: Optional[int],
 Args:
 parent_block_hash: The hash of the parent block. None
 if this is the first block.
-curr_block_token_ids: A tuple of token ids in the current
+curr_block_token_ids: A list of token ids in the current
 block. The current block is assumed to be full.
 
 Returns:
@@ -179,11 +180,11 @@ def hash_block_tokens(parent_block_hash: Optional[int],
 The entire tuple is used as the hash key of the block.
 """
 return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
-curr_block_token_ids)
+tuple(curr_block_token_ids))
 
 
 def hash_request_tokens(block_size: int,
-token_ids: List[int]) -> List[BlockHashType]:
+token_ids: Sequence[int]) -> List[BlockHashType]:
 """Computes hash values of a chain of blocks given a sequence of
 token IDs. The hash value is used for prefix caching.
 
@@ -198,7 +199,7 @@ def hash_request_tokens(block_size: int,
 parent_block_hash_value = None
 for start in range(0, len(token_ids), block_size):
 end = start + block_size
-block_token_ids = tuple(token_ids[start:end])
+block_token_ids = token_ids[start:end]
 # Do not hash the block if it is not full.
 if len(block_token_ids) < block_size:
 break
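Aside (illustration only, not from this commit; hash_tokens is a hypothetical stand-in): annotating the token-id parameter as Sequence[int] lets callers pass a list slice or a tuple unchanged, since both satisfy the read-only Sequence protocol, which is why block_token_ids no longer needs a tuple() conversion at the call site.

    from typing import List, Sequence

    def hash_tokens(token_ids: Sequence[int]) -> int:
        # Works identically for lists, list slices, and tuples.
        return hash(tuple(token_ids))

    ids: List[int] = [1, 2, 3, 4]
    print(hash_tokens(ids[0:2]))   # list slice
    print(hash_tokens((1, 2)))     # tuple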
@@ -152,6 +152,7 @@ class Scheduler:
 break
 if not can_schedule:
 break
+assert new_blocks is not None
 
 # Schedule the request.
 scheduled_running_reqs.append(request)
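Aside (illustration only, not from this commit; allocate is a hypothetical stand-in): the added assert is the standard way to let mypy narrow an Optional value; after `assert x is not None`, the checker treats x as the non-None type on the following lines. The same pattern recurs throughout this commit.

    from typing import List, Optional

    def allocate(num_blocks: int) -> Optional[List[int]]:
        # Returns None when allocation fails (illustrative stub).
        return list(range(num_blocks)) if num_blocks <= 4 else None

    new_blocks = allocate(2)
    assert new_blocks is not None
    # mypy now treats new_blocks as List[int], so len() is accepted.
    print(len(new_blocks))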
@@ -36,7 +36,7 @@ class EngineCoreRequest:
 prompt: Optional[str]
 prompt_token_ids: List[int]
 mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
-mm_hashes: Optional[List[Optional[str]]]
+mm_hashes: Optional[List[str]]
 mm_placeholders: Optional[MultiModalPlaceholderDict]
 sampling_params: SamplingParams
 eos_token_id: Optional[int]
@@ -44,10 +44,11 @@ class EngineCoreRequest:
 lora_request: Optional[LoRARequest]
 
 
-class EngineCoreOutput(msgspec.Struct,
-array_like=True,
-omit_defaults=True,
-gc=False):
+class EngineCoreOutput(
+msgspec.Struct,
+array_like=True, # type: ignore[call-arg]
+omit_defaults=True, # type: ignore[call-arg]
+gc=False): # type: ignore[call-arg]
 
 request_id: str
 new_token_ids: List[int]
@@ -56,10 +57,11 @@ class EngineCoreOutput(msgspec.Struct,
 stop_reason: Union[int, str, None] = None
 
 
-class EngineCoreOutputs(msgspec.Struct,
-array_like=True,
-omit_defaults=True,
-gc=False):
+class EngineCoreOutputs(
+msgspec.Struct,
+array_like=True, # type: ignore[call-arg]
+omit_defaults=True, # type: ignore[call-arg]
+gc=False): # type: ignore[call-arg]
 
 #NOTE(Nick): We could consider ways to make this more compact,
 # e.g. columnwise layout and using an int enum for finish/stop reason
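Aside (illustration only, not from this commit): the keyword arguments that msgspec.Struct takes in the class statement trip mypy's call-arg check, so each line carries an error-code-scoped ignore. A generic sketch of how `# type: ignore[call-arg]` suppresses exactly one error code on exactly one line (msgspec not required):

    def add(x: int, y: int) -> int:
        return x + y

    try:
        # mypy reports "Too many arguments" [call-arg] here; the scoped ignore
        # silences only that code on this line, so other errors still surface.
        add(1, 2, 3)  # type: ignore[call-arg]
    except TypeError:
        pass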
@@ -81,3 +83,6 @@ class EngineCoreRequestType(enum.Enum):
 ADD = b'\x00'
 ABORT = b'\x01'
 PROFILE = b'\x02'
+
+
+EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]]
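Aside (illustration only, not from this commit; the stand-in types are hypothetical): a Union alias such as EngineCoreRequestUnion keeps long annotations readable and still supports isinstance narrowing where the request is handled.

    from typing import List, Union

    class AddRequest:
        pass

    class ProfileRequest:
        pass

    RequestUnion = Union[AddRequest, ProfileRequest, List[str]]

    def handle(request: RequestUnion) -> str:
        if isinstance(request, AddRequest):
            return "add"
        if isinstance(request, ProfileRequest):
            return "profile"
        # Only the abort payload (a list of request ids) remains here.
        return "abort " + ", ".join(request)

    print(handle(["req-1", "req-2"]))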
@@ -81,7 +81,7 @@ class AsyncLLM(EngineClient):
 asyncio_mode=True,
 )
 
-self.output_handler = None
+self.output_handler: Optional[asyncio.Task] = None
 
 def __del__(self):
 self.shutdown()
@@ -126,7 +126,8 @@ class AsyncLLM(EngineClient):
 handler.cancel()
 
 @classmethod
-def _get_executor_cls(cls, vllm_config: VllmConfig):
+def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
+executor_class: Type[Executor]
 distributed_executor_backend = (
 vllm_config.parallel_config.distributed_executor_backend)
 if distributed_executor_backend == "mp":
@@ -361,10 +362,10 @@ class AsyncLLM(EngineClient):
 logger.debug("Called check_health.")
 
 async def start_profile(self) -> None:
-await self.engine_core.profile(True)
+await self.engine_core.profile_async(True)
 
 async def stop_profile(self) -> None:
-await self.engine_core.profile(False)
+await self.engine_core.profile_async(False)
 
 @property
 def is_running(self) -> bool:
@@ -380,7 +381,7 @@ class AsyncLLM(EngineClient):
 
 @property
 def dead_error(self) -> BaseException:
-return Exception
+return Exception() # TODO: implement
 
 
 # Retain V0 name for backwards compatibility.
@@ -5,7 +5,7 @@ import threading
 import time
 from dataclasses import dataclass
 from multiprocessing.process import BaseProcess
-from typing import List, Tuple, Type, Union
+from typing import List, Tuple, Type
 
 import zmq
 import zmq.asyncio
@@ -20,7 +20,7 @@ from vllm.usage.usage_lib import UsageContext
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
 EngineCoreProfile, EngineCoreRequest,
-EngineCoreRequestType)
+EngineCoreRequestType, EngineCoreRequestUnion)
 from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.request import Request, RequestStatus
@@ -97,8 +97,10 @@ class EngineCore:
 # Note that the cache here is mirrored with the client side of the
 # MM mapper, so anything that has a hash must have a HIT cache
 # entry here as well.
-request.mm_inputs = self.mm_input_mapper_server.process_inputs(
-request.mm_inputs, request.mm_hashes)
+assert request.mm_inputs is not None
+request.mm_inputs, request.mm_hashes = (
+self.mm_input_mapper_server.process_inputs(
+request.mm_inputs, request.mm_hashes))
 
 req = Request.from_engine_core_request(request)
 
@@ -128,7 +130,7 @@ class EngineCore:
 def shutdown(self):
 self.model_executor.shutdown()
 
-def profile(self, is_start=True):
+def profile(self, is_start: bool = True):
 self.model_executor.profile(is_start)
 
 
@@ -161,8 +163,8 @@ class EngineCoreProc(EngineCore):
 # and to overlap some serialization/deserialization with the
 # model forward pass.
 # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
-self.input_queue = queue.Queue()
-self.output_queue = queue.Queue()
+self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
+self.output_queue: queue.Queue[List[EngineCoreOutput]] = queue.Queue()
 threading.Thread(target=self.process_input_socket,
 args=(input_path, ),
 daemon=True).start()
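Aside (illustration only, not from this commit): parameterizing queue.Queue in the annotation, as done for input_queue and output_queue above, tells mypy what element type flows through put() and get(). A minimal standalone sketch:

    from __future__ import annotations  # keeps the subscripted annotation lazy

    import queue
    from typing import List

    pending: queue.Queue[List[int]] = queue.Queue()
    pending.put([1, 2, 3])      # putting a str here would be flagged by mypy
    batch = pending.get()       # inferred as List[int]
    print(sum(batch))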
@@ -318,9 +320,7 @@ class EngineCoreProc(EngineCore):
 
 self._last_logging_time = now
 
-def _handle_client_request(
-self, request: Union[EngineCoreRequest, EngineCoreProfile,
-List[str]]) -> None:
+def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
 """Handle EngineCoreRequest or EngineCoreABORT from Client."""
 
 if isinstance(request, EngineCoreRequest):
@@ -1,6 +1,6 @@
 import atexit
 import os
-from typing import List, Union
+from typing import List, Optional
 
 import msgspec
 import zmq
@@ -10,8 +10,9 @@ from vllm.logger import init_logger
 from vllm.utils import get_open_zmq_ipc_path, kill_process_tree
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
 EngineCoreProfile, EngineCoreRequest,
-EngineCoreRequestType)
-from vllm.v1.engine.core import EngineCore, EngineCoreProc
+EngineCoreRequestType, EngineCoreRequestUnion)
+from vllm.v1.engine.core import (EngineCore, EngineCoreProc,
+EngineCoreProcHandle)
 from vllm.v1.serial_utils import PickleEncoder
 
 logger = init_logger(__name__)
@@ -59,7 +60,7 @@ class EngineCoreClient:
 def add_request(self, request: EngineCoreRequest) -> None:
 raise NotImplementedError
 
-async def profile(self, is_start=True) -> None:
+def profile(self, is_start: bool = True) -> None:
 raise NotImplementedError
 
 def abort_requests(self, request_ids: List[str]) -> None:
@@ -71,6 +72,9 @@ class EngineCoreClient:
 async def add_request_async(self, request: EngineCoreRequest) -> None:
 raise NotImplementedError
 
+async def profile_async(self, is_start: bool = True) -> None:
+raise NotImplementedError
+
 async def abort_requests_async(self, request_ids: List[str]) -> None:
 raise NotImplementedError
 
@@ -105,7 +109,7 @@ class InprocClient(EngineCoreClient):
 def __del__(self):
 self.shutdown()
 
-def profile(self, is_start=True) -> None:
+def profile(self, is_start: bool = True) -> None:
 self.engine_core.profile(is_start)
 
 
@@ -133,7 +137,10 @@ class MPClient(EngineCoreClient):
 self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)
 
 # ZMQ setup.
-self.ctx = (zmq.asyncio.Context() if asyncio_mode else zmq.Context())
+if asyncio_mode:
+self.ctx = zmq.asyncio.Context()
+else:
+self.ctx = zmq.Context() # type: ignore[attr-defined]
 
 # Path for IPC.
 ready_path = get_open_zmq_ipc_path()
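Aside (illustration only, not from this commit; the backend classes are hypothetical): one reason to split a ternary into an explicit if/else, as done for the ZMQ context above, is that mypy types the whole conditional expression as a single joined type, while separate assignments keep each branch precisely typed and give each line room for its own targeted ignore.

    from typing import Union

    class FastBackend:
        name = "fast"

    class SafeBackend:
        name = "safe"

    def pick_backend(fast: bool) -> Union[FastBackend, SafeBackend]:
        backend: Union[FastBackend, SafeBackend]
        if fast:
            backend = FastBackend()   # this line could carry its own ignore
        else:
            backend = SafeBackend()
        return backend

    print(pick_backend(True).name)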
@@ -149,11 +156,13 @@ class MPClient(EngineCoreClient):
 self.input_socket.bind(input_path)
 
 # Start EngineCore in background process.
+self.proc_handle: Optional[EngineCoreProcHandle]
 self.proc_handle = EngineCoreProc.make_engine_core_process(
 *args,
-input_path=input_path,
-output_path=output_path,
-ready_path=ready_path,
+input_path=
+input_path, # type: ignore[misc] # MyPy incorrectly flags duplicate keywords
+output_path=output_path, # type: ignore[misc]
+ready_path=ready_path, # type: ignore[misc]
 **kwargs,
 )
 atexit.register(self.shutdown)
@@ -204,10 +213,8 @@ class SyncMPClient(MPClient):
 engine_core_outputs = self.decoder.decode(frame.buffer).outputs
 return engine_core_outputs
 
-def _send_input(
-self, request_type: EngineCoreRequestType,
-request: Union[EngineCoreRequest, EngineCoreProfile,
-List[str]]) -> None:
+def _send_input(self, request_type: EngineCoreRequestType,
+request: EngineCoreRequestUnion) -> None:
 
 # (RequestType, SerializedRequest)
 msg = (request_type.value, self.encoder.encode(request))
@@ -219,7 +226,7 @@ class SyncMPClient(MPClient):
 def abort_requests(self, request_ids: List[str]) -> None:
 self._send_input(EngineCoreRequestType.ABORT, request_ids)
 
-def profile(self, is_start=True) -> None:
+def profile(self, is_start: bool = True) -> None:
 self._send_input(EngineCoreRequestType.PROFILE,
 EngineCoreProfile(is_start))
 
@@ -237,10 +244,8 @@ class AsyncMPClient(MPClient):
 
 return engine_core_outputs
 
-async def _send_input(
-self, request_type: EngineCoreRequestType,
-request: Union[EngineCoreRequest, EngineCoreProfile,
-List[str]]) -> None:
+async def _send_input(self, request_type: EngineCoreRequestType,
+request: EngineCoreRequestUnion) -> None:
 
 msg = (request_type.value, self.encoder.encode(request))
 await self.input_socket.send_multipart(msg, copy=False)
@@ -252,6 +257,6 @@ class AsyncMPClient(MPClient):
 if len(request_ids) > 0:
 await self._send_input(EngineCoreRequestType.ABORT, request_ids)
 
-async def profile(self, is_start=True) -> None:
+async def profile_async(self, is_start: bool = True) -> None:
 await self._send_input(EngineCoreRequestType.PROFILE,
 EngineCoreProfile(is_start))
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.logger import init_logger
@@ -97,7 +97,7 @@ class IncrementalDetokenizer:
 self,
 new_token_ids: List[int],
 finish_reason: Optional[str],
-stop_reason: Optional[str],
+stop_reason: Optional[Union[int, str, None]],
 ) -> Optional[RequestOutput]:
 """
 Update RequestState for the request_id by:
@@ -103,7 +103,8 @@ class LLMEngine:
 multiprocess_mode=enable_multiprocessing)
 
 @classmethod
-def _get_executor_cls(cls, vllm_config: VllmConfig):
+def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
+executor_class: Type[Executor]
 distributed_executor_backend = (
 vllm_config.parallel_config.distributed_executor_backend)
 if distributed_executor_backend == "mp":
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 import PIL
 from blake3 import blake3
@@ -42,14 +42,14 @@ class MMInputMapperClient:
 model_config)
 self.mm_registry.init_mm_limits_per_prompt(model_config)
 
-self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
+self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
 
 # DEBUG: Set to None to disable
 self.mm_debug_cache_hit_ratio_steps = None
 self.mm_cache_hits = 0
 self.mm_cache_total = 0
 
-def cache_hit_ratio(self, steps) -> float:
+def cache_hit_ratio(self, steps):
 if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0:
 logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
 self.mm_cache_hits / self.mm_cache_total)
@@ -60,7 +60,7 @@ class MMInputMapperClient:
 mm_hashes: Optional[List[str]],
 mm_processor_kwargs: Optional[Dict[str, Any]],
 precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
-) -> List[MultiModalKwargs]:
+) -> Tuple[List[MultiModalKwargs], Optional[List[str]]]:
 if precomputed_mm_inputs is None:
 image_inputs = mm_data["image"]
 if not isinstance(image_inputs, list):
@@ -72,6 +72,7 @@ class MMInputMapperClient:
 # Check if hash is enabled
 use_hash = mm_hashes is not None
 if use_hash:
+assert mm_hashes is not None
 assert num_inputs == len(
 mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format(
 num_inputs, len(mm_hashes))
@@ -79,7 +80,7 @@ class MMInputMapperClient:
 # Process each image input separately, so that later we can schedule
 # them in a fine-grained manner.
 # Apply caching (if enabled) and reuse precomputed inputs (if provided)
-ret_hashes = [] if use_hash else None
+ret_hashes: Optional[List[str]] = [] if use_hash else None
 ret_inputs: List[MultiModalKwargs] = []
 for input_id in range(num_inputs):
 if self.mm_debug_cache_hit_ratio_steps is not None:
@@ -88,6 +89,7 @@ class MMInputMapperClient:
 mm_hash = None
 mm_input = None
 if use_hash:
+assert mm_hashes is not None
 mm_hash = mm_hashes[input_id]
 mm_input = self.mm_cache.get(mm_hash)
 
@@ -105,12 +107,15 @@ class MMInputMapperClient:
 
 if use_hash:
 # Add to cache
+assert mm_hash is not None
 self.mm_cache.put(mm_hash, mm_input)
 else:
 self.mm_cache_hits += 1
 mm_input = None # Avoids sending mm_input to Server
 
 if use_hash:
+assert mm_hash is not None
+assert ret_hashes is not None
 ret_hashes.append(mm_hash)
 ret_inputs.append(mm_input)
 
@@ -120,17 +125,18 @@ class MMInputMapperClient:
 class MMInputMapperServer:
 
 def __init__(self, ):
-self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
+self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
 
 def process_inputs(
 self,
 mm_inputs: List[Optional[MultiModalKwargs]],
-mm_hashes: List[Optional[str]],
+mm_hashes: List[str],
 ) -> List[MultiModalKwargs]:
 assert len(mm_inputs) == len(mm_hashes)
 
 full_mm_inputs = []
 for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
+assert mm_hash is not None
 if mm_input is None:
 mm_input = self.mm_cache.get(mm_hash)
 assert mm_input is not None
@@ -56,7 +56,7 @@ class Processor:
 request_id: str,
 prompt: PromptType,
 params: Union[SamplingParams, PoolingParams],
-arrival_time: float,
+arrival_time: Optional[float] = None,
 lora_request: Optional[LoRARequest] = None,
 trace_headers: Optional[Mapping[str, str]] = None,
 prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Tuple
+from typing import Tuple
 
 from vllm.config import VllmConfig
 from vllm.v1.outputs import ModelRunnerOutput
@@ -28,7 +28,7 @@ class Executor(ABC):
 raise NotImplementedError
 
 @abstractmethod
-def profile(self, is_start=True):
+def profile(self, is_start: bool = True):
 raise NotImplementedError
 
 @abstractmethod
@@ -38,11 +38,3 @@ class Executor(ABC):
 @abstractmethod
 def check_health(self) -> None:
 raise NotImplementedError
-
-@abstractmethod
-def collective_rpc(self,
-method: str,
-timeout: Optional[float] = None,
-args: Tuple = (),
-kwargs: Optional[Dict] = None) -> []:
-raise NotImplementedError
@@ -7,7 +7,7 @@ import time
 from dataclasses import dataclass
 from enum import Enum, auto
 from multiprocessing.process import BaseProcess
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import zmq
 
@@ -21,6 +21,7 @@ from vllm.executor.multiproc_worker_utils import (
 from vllm.logger import init_logger
 from vllm.utils import (get_distributed_init_method, get_open_port,
 get_open_zmq_ipc_path)
+from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import make_zmq_socket
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -31,7 +32,7 @@ POLLING_TIMEOUT_MS = 5000
 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
 
 
-class MultiprocExecutor:
+class MultiprocExecutor(Executor):
 
 def __init__(self, vllm_config: VllmConfig) -> None:
 # Call self.shutdown at exit to clean up
@@ -103,7 +104,7 @@ class MultiprocExecutor:
 method: str,
 timeout: Optional[float] = None,
 args: Tuple = (),
-kwargs: Optional[Dict] = None) -> []:
+kwargs: Optional[Dict] = None) -> List[Any]:
 """
 Execute an RPC call on workers.
 
@@ -125,7 +126,7 @@ class MultiprocExecutor:
 
 responses = [None] * self.world_size
 for w in self.workers:
-dequeue_timeout = timeout - (time.monotonic() - start_time()
+dequeue_timeout = timeout - (time.monotonic() - start_time
 ) if timeout is not None else None
 status, result = w.worker_response_mq.dequeue(
 timeout=dequeue_timeout)
@@ -153,7 +154,7 @@ class MultiprocExecutor:
 args=(scheduler_output, ))[0]
 return model_output
 
-def profile(self, is_start=True):
+def profile(self, is_start: bool = True):
 self.collective_rpc("profile", args=(is_start, ))
 return
 
@@ -185,7 +186,6 @@ class MultiprocExecutor:
 p.kill()
 
 self._cleanup_sockets()
-self.workers = None
 
 def _cleanup_sockets(self):
 for w in self.workers:
@@ -200,7 +200,8 @@ class MultiprocExecutor:
 # again
 atexit.unregister(self.shutdown)
 """Properly shut down the executor and its workers"""
-if (hasattr(self, 'workers') and self.workers is not None):
+if not getattr(self, 'shutting_down', False):
+self.shutting_down = True
 for w in self.workers: #TODO: not sure if needed
 w.worker_response_mq = None
 self._ensure_worker_termination()
@@ -4,13 +4,14 @@ from typing import Optional, Tuple
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_worker import Worker
 
 logger = init_logger(__name__)
 
 
-class UniprocExecutor:
+class UniprocExecutor(Executor):
 
 def __init__(self, vllm_config: VllmConfig) -> None:
 self.vllm_config = vllm_config
@@ -25,7 +26,7 @@ class UniprocExecutor:
 self.prompt_adapter_config = vllm_config.prompt_adapter_config
 self.observability_config = vllm_config.observability_config
 
-self.worker = self._create_worker()
+self.worker: Worker = self._create_worker()
 self.worker.initialize()
 self.worker.load_model()
 
@@ -75,7 +76,7 @@ class UniprocExecutor:
 self.worker.profile(is_start)
 
 def shutdown(self):
-self.worker = None
+pass
 
 def check_health(self) -> None:
 # UniprocExecutor will always be healthy as long as
@@ -52,10 +52,9 @@ class Request:
 else:
 self.mm_positions = []
 # Output of the mm input mapper (e.g., image tensors).
+self.mm_inputs: List[MultiModalKwargs] = []
 if self.inputs.multi_modal_inputs:
 self.mm_inputs = self.inputs.multi_modal_inputs
-else:
-self.mm_inputs: List[MultiModalKwargs] = []
 
 @classmethod
 def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
@@ -1,6 +1,8 @@
 from collections import OrderedDict
+from collections.abc import Sequence
 from contextlib import contextmanager
-from typing import Any, Generic, Iterator, List, TypeVar, overload
+from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,
+overload)
 
 import zmq
 
@@ -11,7 +13,7 @@ logger = init_logger(__name__)
 T = TypeVar("T")
 
 
-class ConstantList(Generic[T]):
+class ConstantList(Generic[T], Sequence):
 
 def __init__(self, x: List[T]) -> None:
 self._x = x
@@ -34,29 +36,33 @@ class ConstantList(Generic[T]):
 def clear(self):
 raise Exception("Cannot clear a constant list")
 
-def index(self, item):
-return self._x.index(item)
+def index(self,
+item: T,
+start: int = 0,
+stop: Optional[int] = None) -> int:
+return self._x.index(item, start,
+stop if stop is not None else len(self._x))
 
 @overload
-def __getitem__(self, item) -> T:
+def __getitem__(self, item: int) -> T:
 ...
 
 @overload
 def __getitem__(self, s: slice, /) -> List[T]:
 ...
 
-def __getitem__(self, item):
+def __getitem__(self, item: Union[int, slice]) -> Union[T, List[T]]:
 return self._x[item]
 
 @overload
-def __setitem__(self, item, value):
+def __setitem__(self, item: int, value: T):
 ...
 
 @overload
-def __setitem__(self, s: slice, value, /):
+def __setitem__(self, s: slice, value: T, /):
 ...
 
-def __setitem__(self, item, value):
+def __setitem__(self, item: Union[int, slice], value: Union[T, List[T]]):
 raise Exception("Cannot set item in a constant list")
 
 def __delitem__(self, item):
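Aside (illustration only, not from this commit; ReadOnlyList is a hypothetical stand-in): the @overload stubs added to ConstantList tell mypy that an int index yields one element while a slice yields a list; the single real implementation below the stubs handles both cases.

    from typing import List, Union, overload

    class ReadOnlyList:
        def __init__(self, items: List[int]) -> None:
            self._items = items

        @overload
        def __getitem__(self, index: int) -> int: ...

        @overload
        def __getitem__(self, index: slice) -> List[int]: ...

        def __getitem__(self, index: Union[int, slice]) -> Union[int, List[int]]:
            return self._items[index]

    data = ReadOnlyList([10, 20, 30])
    print(data[0])      # mypy sees int
    print(data[0:2])    # mypy sees List[int]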
@@ -73,10 +79,12 @@ class ConstantList(Generic[T]):
 
 
 @contextmanager
-def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
+def make_zmq_socket(
+path: str,
+type: Any) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
 """Context manager for a ZMQ socket"""
 
-ctx = zmq.Context()
+ctx = zmq.Context() # type: ignore[attr-defined]
 try:
 socket = ctx.socket(type)
 
@@ -96,20 +104,24 @@ def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
 ctx.destroy(linger=0)
 
 
-class LRUDictCache:
+K = TypeVar('K')
+V = TypeVar('V')
+
+
+class LRUDictCache(Generic[K, V]):
 
 def __init__(self, size: int):
-self.cache = OrderedDict()
+self.cache: OrderedDict[K, V] = OrderedDict()
 self.size = size
 
-def get(self, key, default=None):
+def get(self, key: K, default=None) -> V:
 if key not in self.cache:
 return default
 
 self.cache.move_to_end(key)
 return self.cache[key]
 
-def put(self, key, value):
+def put(self, key: K, value: V):
 self.cache[key] = value
 self.cache.move_to_end(key)
 if len(self.cache) > self.size:
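Aside (illustration only, not from this commit; a simplified mirror of LRUDictCache in which get() returns Optional instead of taking a default): making the cache Generic[K, V] is what allows the call sites above to pin the types, e.g. LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE), so mypy checks every get()/put().

    from __future__ import annotations

    from collections import OrderedDict
    from typing import Generic, Optional, TypeVar

    K = TypeVar("K")
    V = TypeVar("V")

    class LRUCache(Generic[K, V]):
        def __init__(self, size: int) -> None:
            self.size = size
            self.cache: OrderedDict[K, V] = OrderedDict()

        def get(self, key: K) -> Optional[V]:
            if key not in self.cache:
                return None
            self.cache.move_to_end(key)
            return self.cache[key]

        def put(self, key: K, value: V) -> None:
            self.cache[key] = value
            self.cache.move_to_end(key)
            if len(self.cache) > self.size:
                self.cache.popitem(last=False)

    cache: LRUCache[str, int] = LRUCache(2)
    cache.put("a", 1)
    print(cache.get("a"))   # typed as Optional[int]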
@@ -215,6 +215,7 @@ class InputBatch:
 
 # Swap the states.
 req_id = self.req_ids[last_req_index]
+assert req_id is not None
 self.req_ids[empty_index] = req_id
 self.req_ids[last_req_index] = None
 self.req_id_to_index[req_id] = empty_index
@@ -1,6 +1,6 @@
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple, cast
 
 import numpy as np
 import torch
@@ -193,9 +193,9 @@ class GPUModelRunner:
 
 req_ids_to_add: List[str] = []
 # Add new requests to the cached states.
-for req_data in scheduler_output.scheduled_new_reqs:
-req_id = req_data.req_id
-sampling_params = req_data.sampling_params
+for new_req_data in scheduler_output.scheduled_new_reqs:
+req_id = new_req_data.req_id
+sampling_params = new_req_data.sampling_params
 if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
 generator = torch.Generator(device=self.device)
 generator.manual_seed(sampling_params.seed)
@@ -204,25 +204,25 @@ class GPUModelRunner:
 
 self.requests[req_id] = CachedRequestState(
 req_id=req_id,
-prompt_token_ids=req_data.prompt_token_ids,
-prompt=req_data.prompt,
-mm_inputs=req_data.mm_inputs,
-mm_positions=req_data.mm_positions,
+prompt_token_ids=new_req_data.prompt_token_ids,
+prompt=new_req_data.prompt,
+mm_inputs=new_req_data.mm_inputs,
+mm_positions=new_req_data.mm_positions,
 sampling_params=sampling_params,
 generator=generator,
-block_ids=req_data.block_ids,
-num_computed_tokens=req_data.num_computed_tokens,
+block_ids=new_req_data.block_ids,
+num_computed_tokens=new_req_data.num_computed_tokens,
 output_token_ids=[],
 )
 req_ids_to_add.append(req_id)
 
 # Update the cached states of the resumed requests.
-for req_data in scheduler_output.scheduled_resumed_reqs:
-req_id = req_data.req_id
+for res_req_data in scheduler_output.scheduled_resumed_reqs:
+req_id = res_req_data.req_id
 req_state = self.requests[req_id]
 
-req_state.block_ids = req_data.block_ids
-req_state.num_computed_tokens = req_data.num_computed_tokens
+req_state.block_ids = res_req_data.block_ids
+req_state.num_computed_tokens = res_req_data.num_computed_tokens
 req_ids_to_add.append(req_id)
 
 # Add the new or resumed requests to the persistent batch.
@@ -259,6 +259,7 @@ class GPUModelRunner:
 num_scheduled_tokens = []
 max_num_scheduled_tokens = 0
 for req_id in self.input_batch.req_ids[:num_reqs]:
+assert req_id is not None
 num_tokens = scheduler_output.num_scheduled_tokens[req_id]
 num_scheduled_tokens.append(num_tokens)
 max_num_scheduled_tokens = max(max_num_scheduled_tokens,
@@ -373,7 +374,7 @@ class GPUModelRunner:
 
 # Batch the multi-modal inputs.
 mm_inputs: List[MultiModalKwargs] = []
-req_input_ids: List[Tuple[int, int]] = []
+req_input_ids: List[Tuple[str, int]] = []
 for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
 req_state = self.requests[req_id]
 for input_id in encoder_input_ids:
@@ -406,6 +407,7 @@ class GPUModelRunner:
 encoder_outputs: List[torch.Tensor] = []
 num_reqs = self.input_batch.num_reqs
 for req_id in self.input_batch.req_ids[:num_reqs]:
+assert req_id is not None
 num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
 req_id]
 req_state = self.requests[req_id]
@@ -514,6 +516,7 @@ class GPUModelRunner:
 # the requests one by one. Optimize.
 num_reqs = self.input_batch.num_reqs
 for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
+assert req_id is not None
 req_state = self.requests[req_id]
 seq_len = (req_state.num_computed_tokens +
 scheduler_output.num_scheduled_tokens[req_id])
@@ -539,8 +542,15 @@ class GPUModelRunner:
 logprobs = None
 else:
 logprobs = sampler_output.logprobs.cpu()
 
+# num_reqs entries should be non-None
+assert all(
+req_id is not None for req_id in
+self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
+req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
+
 model_runner_output = ModelRunnerOutput(
-req_ids=self.input_batch.req_ids[:num_reqs],
+req_ids=req_ids,
 req_id_to_index=self.input_batch.req_id_to_index,
 sampled_token_ids=sampled_token_ids,
 logprob_token_ids_cpu=logprob_token_ids,
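Aside (illustration only, not from this commit; the variable names are hypothetical): the assert-then-cast pattern above is needed because mypy cannot infer that a runtime check over a whole list strips Optional from its element type; typing.cast records that conclusion for the checker and has no runtime cost.

    from typing import List, Optional, cast

    maybe_ids: List[Optional[str]] = ["req-0", "req-1"]

    assert all(req_id is not None for req_id in maybe_ids), "ids contain None"
    req_ids = cast(List[str], maybe_ids)   # purely a type-level assertion

    print(", ".join(req_ids))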
@@ -204,7 +204,7 @@ class Worker:
 return output if self.rank == 0 else None
 return output
 
-def profile(self, is_start=True):
+def profile(self, is_start: bool = True):
 if self.profiler is None:
 raise RuntimeError("Profiler is not enabled.")
 if is_start: