[mypy] Add mypy type annotation part 1 (#4006)

Authored by SangBin Cho on 2024-04-13 06:35:50 +09:00; committed by GitHub
parent d4ec9ffb95
commit 09473ee41c
25 changed files with 171 additions and 72 deletions

.github/workflows/mypy.yaml (new workflow file, 50 lines)

@@ -0,0 +1,50 @@
name: mypy

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  mypy:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install mypy==1.9.0
        pip install types-PyYAML
        pip install types-requests
        pip install types-setuptools
    - name: Mypy
      run: |
        mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
        # TODO(sang): Follow up
        # mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
        # mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
        # mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
        # mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
        # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
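Note: the workflow shells out to mypy directly; the same per-directory checks can also be driven from Python through mypy's programmatic entry point. A minimal local-use sketch (not part of the diff; the directory subset is illustrative):

from mypy import api

# Subset of the targets checked in CI above, kept short for illustration.
TARGETS = ["vllm/attention", "vllm/core", "vllm/distributed"]

for target in TARGETS:
    # api.run returns (stdout, stderr, exit_status), like running `mypy` itself.
    stdout, stderr, exit_status = api.run(
        [target, "--follow-imports=skip", "--config-file", "pyproject.toml"])
    print(f"{target}: exit={exit_status}")
    if stdout:
        print(stdout, end="")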

View File

@@ -93,9 +93,23 @@ fi
 echo 'vLLM yapf: Done'
 
 # Run mypy
-# TODO(zhuohan): Enable mypy
-# echo 'vLLM mypy:'
-# mypy
+echo 'vLLM mypy:'
+mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
+# TODO(sang): Follow up
+# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
 
 CODESPELL_EXCLUDES=(
     '--skip' '*docs/source/_build/**'
@@ -228,5 +242,3 @@ if ! git diff --quiet &>/dev/null; then
     exit 1
 fi

View File

@@ -46,10 +46,13 @@ ignore = [
 python_version = "3.8"
 
 ignore_missing_imports = true
+check_untyped_defs = true
 
 files = "vllm"
 # TODO(woosuk): Include the code from Megatron and HuggingFace.
-exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/"
+exclude = [
+    "vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
+]
 
 [tool.codespell]
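Note: `check_untyped_defs = true` is the setting that drives most of the changes below; mypy now type-checks the bodies of functions that have no annotations instead of skipping them. A small hypothetical illustration:

def add_one(x: int) -> int:
    return x + 1


def untyped_caller():
    # This body used to be skipped because the function is unannotated.
    # With check_untyped_defs = true, mypy reports the bad argument type here.
    return add_one("3")  # error: incompatible type "str"; expected "int"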

View File

@@ -11,4 +11,5 @@ uvicorn[standard]
 pydantic >= 2.0  # Required for OpenAI server.
 prometheus_client >= 0.18.0
 tiktoken == 0.6.0  # Required for DBRX tokenizer
 outlines == 0.0.34  # Requires torch >= 2.1.0
+typing_extensions

View File

@@ -7,7 +7,7 @@ codespell==2.2.6
 isort==5.13.2
 
 # type checking
-mypy==0.991
+mypy==1.9.0
 types-PyYAML
 types-requests
 types-setuptools

View File

@@ -2,7 +2,7 @@ import enum
 import json
 import os
 from dataclasses import dataclass, fields
-from typing import TYPE_CHECKING, ClassVar, Optional, Union
+from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
 
 import torch
 from packaging.version import Version
@@ -141,7 +141,7 @@ class ModelConfig:
         supported_load_format = [
             "auto", "pt", "safetensors", "npcache", "dummy"
         ]
-        rocm_not_supported_load_format = []
+        rocm_not_supported_load_format: List[str] = []
         if load_format not in supported_load_format:
             raise ValueError(
                 f"Unknown load format: {self.load_format}. Must be one of "
@@ -679,6 +679,9 @@ class SpeculativeConfig:
                 "num_speculative_tokens to be provided, but found "
                 f"{speculative_model=} and {num_speculative_tokens=}.")
 
+        assert (speculative_model is not None
+                and num_speculative_tokens is not None)
+
         # TODO: The user should be able to specify revision/quantization/max
         # model len for the draft model. It is not currently supported.
         draft_revision = None
@@ -993,7 +996,7 @@ def _get_and_verify_max_len(
         derived_max_model_len *= scaling_factor
 
     if max_model_len is None:
-        max_model_len = derived_max_model_len
+        max_model_len = int(derived_max_model_len)
    elif max_model_len > derived_max_model_len:
         # Some models might have a separate key for specifying model_max_length
         # that will be bigger than derived_max_model_len. We compare user input
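Note: two recurring fixes in this file are giving empty literals an explicit element type and casting a possibly-float length back to int. A small standalone sketch of both patterns (names are illustrative, not vLLM code):

from typing import List

# A bare [] gives mypy nothing to infer from, so the element type is declared.
rocm_not_supported_formats: List[str] = []

# Rope scaling can turn the derived length into a float; the int() cast keeps
# the variable's declared integer type honest before it is used as a length.
derived_max_model_len = 2048 * 1.5
max_model_len: int = int(derived_max_model_len)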

View File

@@ -1,5 +1,6 @@
 """A block manager that manages token blocks."""
 from abc import ABC, abstractmethod
+from collections.abc import Sequence as GenericSequence
 from itertools import count, takewhile
 from os.path import commonprefix
 from typing import Dict, List, Optional, Set
@@ -231,10 +232,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         if self.enable_caching:
             logger.info("Automatic prefix caching is enabled.")
-            self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
-                                                      num_gpu_blocks)
-            self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
-                                                      num_cpu_blocks)
+            self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
+                Device.GPU, block_size, num_gpu_blocks)
+            self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
+                Device.CPU, block_size, num_cpu_blocks)
         else:
             self.gpu_allocator = UncachedBlockAllocator(
                 Device.GPU, block_size, num_gpu_blocks)
@@ -588,7 +589,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             for b in takewhile(lambda b: b.computed, block_table[:-1])
         ]
 
-    def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
+    def get_common_computed_block_ids(
+            self, seqs: List[Sequence]) -> GenericSequence[int]:
         """Return the block ids that are common for a given sequence group.
 
         Used in prefill (can skip prefill of some blocks).
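Note: the allocator attribute is annotated with the common base class on its first assignment, and the return type is widened to `collections.abc.Sequence`. A self-contained sketch of both patterns with placeholder classes (not the vLLM block manager):

from collections.abc import Sequence as GenericSequence
from typing import List


class AllocatorBase:
    pass


class CachedAllocator(AllocatorBase):
    pass


class UncachedAllocator(AllocatorBase):
    pass


class Manager:
    def __init__(self, enable_caching: bool) -> None:
        # Annotating with the base class keeps mypy from pinning the attribute
        # to whichever concrete subclass the first branch happens to assign.
        self.allocator: AllocatorBase = (
            CachedAllocator() if enable_caching else UncachedAllocator())

    def common_ids(self, ids: List[int]) -> GenericSequence[int]:
        # Returning abc.Sequence lets an implementation hand back a list,
        # a tuple, or a range without changing the interface.
        return tuple(ids)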

View File

@@ -1,4 +1,5 @@
 """A block manager that manages token blocks."""
+from collections.abc import Sequence as GenericSequence
 from typing import Dict, List, Optional
 
 from vllm.core.block.block_table import BlockTable
@@ -205,7 +206,8 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         # as computed.
         self.block_allocator.mark_blocks_as_computed()
 
-    def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
+    def get_common_computed_block_ids(
+            self, seqs: List[Sequence]) -> GenericSequence[int]:
         """Determine which blocks for which we skip prefill.
 
         With prefix caching we can skip prefill for previously-generated blocks.

View File

@@ -1,5 +1,6 @@
 import enum
 from abc import ABC, abstractmethod
+from collections.abc import Sequence as GenericSequence
 from typing import Dict, List
 
 from vllm.sequence import Sequence, SequenceGroup
@@ -103,7 +104,8 @@ class BlockSpaceManager(ABC):
         pass
 
     @abstractmethod
-    def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
+    def get_common_computed_block_ids(
+            self, seqs: List[Sequence]) -> GenericSequence[int]:
         pass
 
     @abstractmethod

View File

@@ -42,8 +42,8 @@ class SchedulingBudget:
     """
     token_budget: int
    max_num_seqs: int
-    _requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set)
-    _requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set)
+    _requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
+    _requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
     _num_batched_tokens: int = 0
     _num_curr_seqs: int = 0
@@ -133,7 +133,7 @@ class SchedulerOutputs:
         return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                 and not self.blocks_to_swap_out and not self.blocks_to_copy)
 
-    def _sort_by_lora_ids(self) -> bool:
+    def _sort_by_lora_ids(self):
         self.scheduled_seq_groups = sorted(
             self.scheduled_seq_groups,
             key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
@@ -337,7 +337,8 @@ class Scheduler:
         self.free_seq(seq)
 
     def has_unfinished_seqs(self) -> bool:
-        return self.waiting or self.running or self.swapped
+        return len(self.waiting) != 0 or len(self.running) != 0 or len(
+            self.swapped) != 0
 
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
@@ -404,7 +405,7 @@ class Scheduler:
                 budget.subtract_num_seqs(seq_group.request_id,
                                          num_running_seqs)
             if curr_loras is not None and seq_group.lora_int_id > 0:
-                curr_loras.pop(seq_group.lora_int_id)
+                curr_loras.remove(seq_group.lora_int_id)
 
             if running_queue:
                 # Preempt the lowest-priority sequence groups.
@@ -496,7 +497,7 @@ class Scheduler:
         now = time.time()
         swapped_queue = policy.sort_by_priority(now, swapped_queue)
 
-        leftover_swapped = deque()
+        leftover_swapped: Deque[SequenceGroup] = deque()
         while swapped_queue:
             seq_group = swapped_queue[0]
@@ -507,7 +508,9 @@ class Scheduler:
             lora_int_id = 0
             if self.lora_enabled:
                 lora_int_id = seq_group.lora_int_id
-                if (lora_int_id > 0 and lora_int_id not in curr_loras
+                assert curr_loras is not None
+                assert self.lora_config is not None
+                if (lora_int_id > 0 and (lora_int_id not in curr_loras)
                         and len(curr_loras) >= self.lora_config.max_loras):
                     # We don't have a space for another LoRA, so
                     # we ignore this request for now.
@@ -593,7 +596,7 @@ class Scheduler:
         # Copy the queue so that the input queue is not modified.
         waiting_queue = deque([s for s in waiting_queue])
 
-        leftover_waiting_sequences = deque()
+        leftover_waiting_sequences: Deque[SequenceGroup] = deque()
         while self._passed_delay(time.time()) and waiting_queue:
             seq_group = waiting_queue[0]
@@ -635,6 +638,8 @@ class Scheduler:
             lora_int_id = 0
             if self.lora_enabled:
                 lora_int_id = seq_group.lora_int_id
+                assert curr_loras is not None
+                assert self.lora_config is not None
                 if (self.lora_enabled and lora_int_id > 0
                         and lora_int_id not in curr_loras
                         and len(curr_loras) >= self.lora_config.max_loras):
@@ -780,7 +785,7 @@ class Scheduler:
             token_budget=self.scheduler_config.max_num_batched_tokens,
            max_num_seqs=self.scheduler_config.max_num_seqs,
         )
-        curr_loras = set()
+        curr_loras: Set[int] = set()
 
         remaining_waiting, prefills = (self.waiting,
                                        SchedulerPrefillOutputs.create_empty())
@@ -1087,7 +1092,7 @@ class Scheduler:
     def _get_num_new_tokens(self, seq_group: SequenceGroup,
                             status: SequenceStatus, enable_chunking: bool,
-                            budget: SchedulingBudget) -> Tuple[int, bool]:
+                            budget: SchedulingBudget) -> int:
         """Get the next new tokens to compute for a given sequence group
         that's in a given `status`.
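Note: the scheduler changes are mostly of two kinds: empty containers get explicit element types, and Optionals are narrowed with asserts before use. A small hypothetical sketch of both patterns (not the scheduler itself):

from collections import deque
from typing import Deque, Optional, Set


def schedule(max_loras: Optional[int]) -> Deque[str]:
    # Without the annotations, deque() and set() give mypy nothing to infer.
    leftover: Deque[str] = deque()
    curr_loras: Set[int] = set()

    # assert narrows Optional[int] to int for the comparison below.
    assert max_loras is not None
    if len(curr_loras) >= max_loras:
        leftover.append("ignored-for-now")
    return leftover


print(schedule(max_loras=4))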

View File

@@ -1,5 +1,5 @@
 from collections import namedtuple
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch.distributed import ProcessGroup
@@ -144,7 +144,7 @@ def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
    src: int = 0,
    group: Optional[ProcessGroup] = None,
-) -> Dict[Any, Union[torch.Tensor, Any]]:
+) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
    """Broadcast the input tensor dictionary."""
    group = group or torch.distributed.group.WORLD
    ranks = torch.distributed.get_process_group_ranks(group)
@@ -157,10 +157,10 @@ def broadcast_tensor_dict(
    rank = torch.distributed.get_rank()
    if rank == src:
+        metadata_list: List[Tuple[Any, Any]] = []
        assert isinstance(
            tensor_dict,
            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
-        metadata_list = []
        for key, value in tensor_dict.items():
            if isinstance(value, torch.Tensor):
                assert value.is_cuda, (
@@ -190,10 +190,10 @@ def broadcast_tensor_dict(
        torch.distributed.broadcast_object_list(recv_metadata_list,
                                                src=src,
                                                group=group)
-        metadata_list = recv_metadata_list[0]
+        assert recv_metadata_list[0] is not None
        tensor_dict = {}
        async_handles = []
-        for key, value in metadata_list:
+        for key, value in recv_metadata_list[0]:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
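Note: declaring the return type as Optional makes the "this call may hand back None" case explicit, and callers then narrow it before use. A generic sketch of the pattern (not the vLLM broadcast function):

from typing import Dict, Optional


def maybe_lookup(key: str, table: Dict[str, int]) -> Optional[int]:
    # The Optional return type forces callers to handle the miss case.
    return table.get(key)


value = maybe_lookup("k", {"k": 1})
assert value is not None  # narrows Optional[int] to int
print(value + 1)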

View File

@@ -1,9 +1,10 @@
 import pickle
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
 from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
+from vllm.worker.worker import Worker
 
 logger = init_logger(__name__)
@@ -18,15 +19,20 @@ try:
             if init_cached_hf_modules:
                 from transformers.dynamic_module_utils import init_hf_modules
                 init_hf_modules()
-            self.worker = None
+            self._worker: Optional[Worker] = None
             # Since the compiled DAG runs a main execution
             # in a different thread that calls cuda.set_device.
             # The flag indicates is set_device is called on
             # that thread.
             self.compiled_dag_cuda_device_set = False
 
-        def init_worker(self, worker_init_fn):
-            self.worker = worker_init_fn()
+        def init_worker(self, worker_init_fn: Callable[[], Worker]):
+            self._worker = worker_init_fn()
+
+        @property
+        def worker(self) -> Worker:
+            assert self._worker is not None
+            return self._worker
 
         def __getattr__(self, name):
             return getattr(self.worker, name)
@@ -70,8 +76,8 @@ except ImportError as e:
     logger.warning(f"Failed to import Ray with {e!r}. "
                    "For distributed inference, please install Ray with "
                    "`pip install ray`.")
-    ray = None
-    RayWorkerVllm = None
+    ray = None  # type: ignore
+    RayWorkerVllm = None  # type: ignore
 
 def initialize_ray_cluster(
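Note: the wrapper now stores the worker in a private Optional attribute and exposes it through a property that asserts it has been initialized, so every caller sees a non-Optional type. A self-contained sketch of the pattern with placeholder names (Engine, Wrapper):

from typing import Callable, Optional


class Engine:
    pass


class Wrapper:
    def __init__(self) -> None:
        # Unset until init_engine() is called, hence Optional.
        self._engine: Optional[Engine] = None

    def init_engine(self, factory: Callable[[], Engine]) -> None:
        self._engine = factory()

    @property
    def engine(self) -> Engine:
        # The assert documents the "initialized before use" invariant and
        # narrows Optional[Engine] to Engine for all call sites.
        assert self._engine is not None
        return self._engine


w = Wrapper()
w.init_engine(Engine)
print(type(w.engine).__name__)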

View File

@@ -47,6 +47,7 @@ async def generate(request: Request) -> Response:
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
+    assert engine is not None
     results_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Streaming case

View File

@@ -170,8 +170,12 @@ class LLM:
             multi_modal_data.data = multi_modal_data.data.to(torch.float16)
 
         # Add requests to the engine.
-        num_requests = len(prompts) if prompts is not None else len(
-            prompt_token_ids)
+        if prompts is not None:
+            num_requests = len(prompts)
+        else:
+            assert prompt_token_ids is not None
+            num_requests = len(prompt_token_ids)
+
         for i in range(num_requests):
             prompt = prompts[i] if prompts is not None else None
             token_ids = None if prompt_token_ids is None else prompt_token_ids[

View File

@@ -1,5 +1,5 @@
 import os
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 import torch
@@ -61,7 +61,7 @@ class CPUExecutor(ExecutorBase):
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
-    def determine_num_available_blocks(self) -> tuple[int, int]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks by invoking the
         underlying worker.
         """

View File

@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -66,7 +66,7 @@ class GPUExecutor(ExecutorBase):
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
-    def determine_num_available_blocks(self) -> tuple[int, int]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks by invoking the
         underlying worker.
         """

View File

@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -47,7 +47,7 @@ class NeuronExecutor(ExecutorBase):
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
-    def determine_num_available_blocks(self) -> tuple[int, int]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks by invoking the
         underlying worker.
         """

View File

@@ -3,7 +3,7 @@ import copy
 import os
 import pickle
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -197,7 +197,7 @@ class RayGPUExecutor(ExecutorBase):
             max_parallel_loading_workers,
         )
 
-    def determine_num_available_blocks(self) -> tuple[int, int]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks.
 
         This invokes `determine_num_available_blocks` on each worker and takes
@@ -205,7 +205,7 @@ class RayGPUExecutor(ExecutorBase):
         compatible with all workers.
 
         Returns:
-            - tuple[num_gpu_blocks, num_cpu_blocks]
+            - Tuple[num_gpu_blocks, num_cpu_blocks]
         """
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
         num_blocks = self._run_workers("determine_num_available_blocks", )
@@ -276,7 +276,7 @@ class RayGPUExecutor(ExecutorBase):
         self,
         method: str,
         *args,
-        driver_args: Optional[List[Any]] = None,
+        driver_args: Optional[Tuple[Any, ...]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
         max_concurrent_workers: Optional[int] = None,
         use_ray_compiled_dag: bool = False,
@@ -291,6 +291,7 @@ class RayGPUExecutor(ExecutorBase):
         if use_ray_compiled_dag:
             # Right now, compiled DAG can only accept a single
             # input. TODO(sang): Fix it.
+            assert self.forward_dag is not None
             output_channels = self.forward_dag.execute(1)
         else:
             # Start the ray workers first.
@@ -369,7 +370,7 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
         self,
         method: str,
         *args,
-        driver_args: Optional[List[Any]] = None,
+        driver_args: Optional[Tuple[Any, ...]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> Any:

View File

@@ -5,7 +5,8 @@ from functools import cached_property
 from typing import Callable, List, Optional, Union
 
 import torch
-from pydantic import conint
+from pydantic import Field
+from typing_extensions import Annotated
 
 _SAMPLING_EPS = 1e-5
@@ -127,7 +128,7 @@ class SamplingParams:
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessor]] = None,
-        truncate_prompt_tokens: Optional[conint(ge=1)] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
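Note: `conint(ge=1)` builds a type at runtime via a function call, which mypy rejects inside annotations; `Annotated[int, Field(ge=1)]` reads as a plain int to mypy while pydantic still enforces the constraint. A minimal pydantic-flavored sketch (not the SamplingParams class):

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class TruncationParams(BaseModel):
    # Statically this is just int; the ge=1 constraint is applied by pydantic
    # at validation time.
    truncate_prompt_tokens: Annotated[int, Field(ge=1)] = 1


print(TruncationParams(truncate_prompt_tokens=4).truncate_prompt_tokens)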

View File

@@ -171,10 +171,10 @@ class SequenceData:
             return self.prompt_token_ids[-1]
         return self.output_token_ids[-1]
 
-    def get_prompt_token_ids(self) -> int:
+    def get_prompt_token_ids(self) -> List[int]:
         return self.prompt_token_ids
 
-    def get_output_token_ids(self) -> int:
+    def get_output_token_ids(self) -> List[int]:
         return self.output_token_ids
 
     @property
@@ -370,7 +370,7 @@ class SequenceGroupState:
     """Mutable state tied to a specific sequence group"""
 
     # torch.Generator used in seeded sampling
-    generator: Optional = None
+    generator: Optional = None  # type: ignore
 
 
 class MultiModalData:
@@ -599,7 +599,7 @@ class SequenceGroupMetadata:
         return self.lora_request.lora_int_id if self.lora_request else 0
 
     @property
-    def token_chunk_size(self) -> int:
+    def token_chunk_size(self) -> Optional[int]:
         """Return the number of tokens to be processed (chunk size)."""
         return self._token_chunk_size

View File

@@ -2,7 +2,8 @@ from typing import Dict, Optional
 from transformers import AutoConfig, PretrainedConfig
 
-from vllm.transformers_utils.configs import *
+from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
+                                             JAISConfig, MPTConfig, RWConfig)
 
 _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
     "chatglm": ChatGLMConfig,

View File

@@ -168,8 +168,8 @@ def _convert_tokens_to_string_with_added_encoders(
     # NOTE(woosuk): The following code is slow because it runs a for loop over
     # the output_tokens. In Python, running a for loop over a list can be slow
     # even when the loop body is very simple.
-    sub_texts = []
-    current_sub_text = []
+    sub_texts: List[str] = []
+    current_sub_text: List[str] = []
     all_special_tokens = set(tokenizer.all_special_tokens)
     for token in output_tokens:
         if skip_special_tokens and token in all_special_tokens:
@@ -263,6 +263,7 @@ def detokenize_incrementally(
             tokenizer,
             all_input_ids[:-1],
             skip_special_tokens=skip_special_tokens)
+    assert prev_tokens is not None
 
     # If the new token id is out of bounds, return an empty string.
     if new_token_id >= len(tokenizer):
@@ -271,6 +272,8 @@ def detokenize_incrementally(
         # Put new_token_id in a list so skip_special_tokens is respected
         new_tokens = tokenizer.convert_ids_to_tokens(
             [new_token_id], skip_special_tokens=skip_special_tokens)
+        if isinstance(new_tokens, str):
+            new_tokens = [new_tokens]
     output_tokens = prev_tokens + new_tokens
 
     # If this is the first iteration, return all tokens.
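Note: the added isinstance check normalizes a value that may be either a single string or a list of strings, which also narrows the union for mypy so the later list concatenation type-checks. A standalone sketch with a hypothetical stand-in function:

from typing import List, Union


def to_tokens(ids: Union[int, List[int]]) -> Union[str, List[str]]:
    # Stand-in for an API that returns str for one id and a list otherwise.
    if isinstance(ids, int):
        return f"tok{ids}"
    return [f"tok{i}" for i in ids]


new_tokens = to_tokens(7)
if isinstance(new_tokens, str):
    # Narrows Union[str, List[str]] to List[str] for the concatenation below.
    new_tokens = [new_tokens]
output_tokens: List[str] = ["prev"] + new_tokens
print(output_tokens)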

View File

@@ -5,7 +5,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.transformers_utils.tokenizers import *
+from vllm.transformers_utils.tokenizers import BaichuanTokenizer
 from vllm.utils import make_async
 
 logger = init_logger(__name__)
@@ -28,7 +28,7 @@ def get_cached_tokenizer(
     tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
     tokenizer_len = len(tokenizer)
 
-    class CachedTokenizer(tokenizer.__class__):
+    class CachedTokenizer(tokenizer.__class__):  # type: ignore
 
         @property
         def all_special_ids(self):

View File

@@ -7,7 +7,7 @@ import time
 from enum import Enum
 from pathlib import Path
 from threading import Thread
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 from uuid import uuid4
 
 import cpuinfo
@@ -124,7 +124,7 @@ class UsageMessage:
     def report_usage(self,
                      model_architecture: str,
                      usage_context: UsageContext,
-                     extra_kvs: Dict[str, any] = None) -> None:
+                     extra_kvs: Optional[Dict[str, Any]] = None) -> None:
         t = Thread(target=self._report_usage_worker,
                    args=(model_architecture, usage_context, extra_kvs or {}),
                    daemon=True)
@@ -132,13 +132,13 @@ class UsageMessage:
     def _report_usage_worker(self, model_architecture: str,
                              usage_context: UsageContext,
-                             extra_kvs: Dict[str, any]) -> None:
+                             extra_kvs: Dict[str, Any]) -> None:
         self._report_usage_once(model_architecture, usage_context, extra_kvs)
         self._report_continous_usage()
 
     def _report_usage_once(self, model_architecture: str,
                            usage_context: UsageContext,
-                           extra_kvs: Dict[str, any]) -> None:
+                           extra_kvs: Dict[str, Any]) -> None:
         # Platform information
         if torch.cuda.is_available():
             device_property = torch.cuda.get_device_properties(0)

View File

@@ -60,7 +60,7 @@ class LRUCache(Generic[T]):
     def __len__(self) -> int:
         return len(self.cache)
 
-    def __getitem__(self, key: Hashable) -> T:
+    def __getitem__(self, key: Hashable) -> Optional[T]:
         return self.get(key)
 
     def __setitem__(self, key: Hashable, value: T) -> None:
@@ -76,7 +76,7 @@
             key: Hashable,
             default_value: Optional[T] = None) -> Optional[T]:
         if key in self.cache:
-            value = self.cache[key]
+            value: Optional[T] = self.cache[key]
             self.cache.move_to_end(key)
         else:
             value = default_value
@@ -87,7 +87,7 @@
         self.cache.move_to_end(key)
         self._remove_old_if_needed()
 
-    def _on_remove(self, key: Hashable, value: T):
+    def _on_remove(self, key: Hashable, value: Optional[T]):
         pass
 
     def remove_oldest(self):
@@ -100,9 +100,11 @@
         while len(self.cache) > self.capacity:
             self.remove_oldest()
 
-    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
+    def pop(self,
+            key: Hashable,
+            default_value: Optional[T] = None) -> Optional[T]:
         run_on_remove = key in self.cache
-        value = self.cache.pop(key, default_value)
+        value: Optional[T] = self.cache.pop(key, default_value)
         if run_on_remove:
             self._on_remove(key, value)
         return value
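Note: once a generic cache can miss, its lookup paths return Optional[T] throughout. A compact, hypothetical version of the same typing pattern (not vLLM's LRUCache):

from collections import OrderedDict
from typing import Generic, Hashable, Optional, TypeVar

T = TypeVar("T")


class TinyLRUCache(Generic[T]):

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.cache: "OrderedDict[Hashable, T]" = OrderedDict()

    def put(self, key: Hashable, value: T) -> None:
        self.cache[key] = value
        self.cache.move_to_end(key)
        while len(self.cache) > self.capacity:
            self.cache.popitem(last=False)

    def get(self, key: Hashable, default: Optional[T] = None) -> Optional[T]:
        # A miss returns the default, so the return type must be Optional[T].
        if key in self.cache:
            self.cache.move_to_end(key)
            return self.cache[key]
        return default


cache: TinyLRUCache[str] = TinyLRUCache(capacity=2)
cache.put("a", "x")
print(cache.get("a"), cache.get("missing"))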