[CI/Build] mypy: Resolve some errors from checking vllm/engine (#9267)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Russell Bryant 2024-10-16 18:55:59 -04:00 committed by GitHub
parent 8345045833
commit 776dbd74f1
20 changed files with 109 additions and 74 deletions

View File

@ -13,24 +13,14 @@ run_mypy() {
run_mypy # Note that this is less strict than CI
run_mypy tests
run_mypy vllm/assets
run_mypy vllm/attention
#run_mypy vllm/compilation
#run_mypy vllm/core
run_mypy vllm/compilation
run_mypy vllm/distributed
run_mypy vllm/engine
run_mypy vllm/entrypoints
run_mypy vllm/executor
#run_mypy vllm/inputs
run_mypy vllm/logging
run_mypy vllm/lora
run_mypy vllm/model_executor
run_mypy vllm/multimodal
run_mypy vllm/platforms
run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/spec_decode
run_mypy vllm/transformers_utils
run_mypy vllm/usage
#run_mypy vllm/vllm_flash_attn
run_mypy vllm/worker

View File

@ -92,7 +92,7 @@ class Attention(nn.Module):
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
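
A minimal sketch of the typing trade-off behind dropping `Optional` from `kv_cache` above, using hypothetical functions rather than vLLM code: an `Optional` parameter forces a None check before every use, while a non-Optional annotation moves that obligation onto callers, which mypy then verifies.

```python
from typing import Optional


def scale_optional(x: Optional[float], factor: float) -> float:
    # With an Optional parameter, mypy requires handling None before any
    # arithmetic on x.
    if x is None:
        raise ValueError("x must not be None")
    return x * factor


def scale(x: float, factor: float) -> float:
    # With a non-Optional annotation the guard disappears and mypy instead
    # holds every caller to the stricter contract.
    return x * factor
```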

View File

@ -244,8 +244,8 @@ def vllm_backend(
def select_default_backend(level: int) -> Union[str, Callable]:
if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
backend = "eager"
return backend
backend_str = "eager"
return backend_str
assert level in [
CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
], f"Invalid level {level}"

View File

@ -35,6 +35,8 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
def cls_decorator_helper(cls: type):
# helper to pass `dynamic_arg_dims` to `_support_torch_compile`
# to avoid too much indentation for `_support_torch_compile`
if not hasattr(cls, 'forward'):
raise TypeError("decorated class should have a forward method.")
sig = inspect.signature(cls.forward)
for k in dynamic_arg_dims:
if k not in sig.parameters:
@ -63,13 +65,13 @@ def _support_torch_compile(cls: type,
# other than TorchCompileWrapperWithCustomDispatcher
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
old_init = cls.__init__
old_init = cls.__init__ # type: ignore
def __init__(self, *args, **kwargs):
old_init(self, *args, **kwargs)
TorchCompileWrapperWithCustomDispatcher.__init__(self)
cls.__init__ = __init__
cls.__init__ = __init__ # type: ignore
def __call__(self, *args, **kwargs):
# torch.compiler.is_compiling() means we are inside the compilation
@ -109,5 +111,5 @@ def _support_torch_compile(cls: type,
model_output = self.forward(*args, **kwargs)
return model_output
cls.__call__ = __call__
cls.__call__ = __call__ # type: ignore
return cls
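
A short sketch (hypothetical decorator, not vLLM code) of why the assignments above carry `# type: ignore`: mypy does not model monkey-patching methods onto a plain `type` object, so the `cls.__init__` and `cls.__call__` assignments are flagged even though the pattern is intentional.

```python
def add_greeting(cls: type) -> type:
    # The ignore comments mirror the hunk above; mypy flags method
    # reassignment on a value typed as `type`.
    old_init = cls.__init__  # type: ignore

    def __init__(self, *args, **kwargs):
        old_init(self, *args, **kwargs)
        self.greeting = "hello"

    cls.__init__ = __init__  # type: ignore
    return cls


@add_greeting
class Widget:
    pass
```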

View File

@ -73,7 +73,7 @@ class TorchCompileWrapperWithCustomDispatcher:
return
# code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
frame = sys._getframe()
while True:
while frame and frame.f_back:
frame = frame.f_back
code_name = frame.f_code.co_name
file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
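
A self-contained sketch of the narrowing the new loop condition gives mypy: `f_back` is `Optional[FrameType]`, so an unguarded `frame = frame.f_back` makes `frame` possibly None and every later attribute access an error. The helper name below is hypothetical.

```python
import sys
from typing import List


def caller_names() -> List[str]:
    names: List[str] = []
    frame = sys._getframe()
    # Checking f_back in the loop condition keeps `frame` non-None for the
    # attribute accesses that follow.
    while frame and frame.f_back:
        frame = frame.f_back
        names.append(frame.f_code.co_name)
    return names
```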

View File

@ -626,13 +626,14 @@ class CacheConfig:
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self.cpu_offload_gb = cpu_offload_gb
self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()
# Will be set after profiling.
self.num_gpu_blocks = None
self.num_cpu_blocks = None
self.num_gpu_blocks: Optional[int] = None
self.num_cpu_blocks: Optional[int] = None
def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus
@ -709,7 +710,8 @@ class TokenizerPoolConfig:
@classmethod
def create_config(
cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
cls, tokenizer_pool_size: int,
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]],
tokenizer_pool_extra_config: Optional[Union[str, dict]]
) -> Optional["TokenizerPoolConfig"]:
"""Create a TokenizerPoolConfig from the given parameters.
@ -1544,7 +1546,7 @@ class LoRAConfig:
max_loras: int
fully_sharded_loras: bool = False
max_cpu_loras: Optional[int] = None
lora_dtype: Optional[torch.dtype] = None
lora_dtype: Optional[Union[torch.dtype, str]] = None
lora_extra_vocab_size: int = 256
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
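
A minimal sketch (hypothetical class) of the annotation pattern used for `num_gpu_blocks` and `num_cpu_blocks` above: without an explicit `Optional[int]`, mypy infers the attribute type from the first assignment (`None`) and rejects the integer assigned later.

```python
from typing import Optional


class BlockCounts:
    def __init__(self) -> None:
        # Explicit annotations; the real values are filled in later.
        self.num_gpu_blocks: Optional[int] = None
        self.num_cpu_blocks: Optional[int] = None


counts = BlockCounts()
counts.num_gpu_blocks = 1024  # OK: int is compatible with Optional[int]
```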

View File

@ -4,8 +4,9 @@ import random
import time
from collections import deque
from dataclasses import dataclass, field
from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set,
Tuple, Union)
from typing import Callable, Deque, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
@ -115,7 +116,7 @@ class ScheduledSequenceGroup:
class SchedulerOutputs:
"""The scheduling decision made from a scheduler."""
# Scheduled sequence groups.
scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup]
# Number of prefill groups scheduled.
num_prefill_groups: int
# Total number of batched tokens.
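
A short sketch of why the field above moved from `Iterable` to `GenericSequence` (the repo's alias for `typing.Sequence`, chosen to avoid clashing with vLLM's own `Sequence` class): code that indexes the field or takes its length type-checks against `Sequence` but not `Iterable`. The function below is hypothetical.

```python
from typing import Sequence as GenericSequence
from typing import Tuple


def first_and_count(groups: GenericSequence[str]) -> Tuple[str, int]:
    # An Iterable[str] parameter would not type-check here: Iterable has
    # neither __getitem__ nor __len__.
    return groups[0], len(groups)
```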

View File

@ -3,7 +3,7 @@ import dataclasses
import json
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
Tuple, Type, Union)
Tuple, Type, Union, cast)
import torch
@ -89,7 +89,7 @@ class EngineArgs:
trust_remote_code: bool = False
download_dir: Optional[str] = None
load_format: str = 'auto'
config_format: str = 'auto'
config_format: ConfigFormat = ConfigFormat.AUTO
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
@ -181,7 +181,7 @@ class EngineArgs:
scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
def __post_init__(self):
if self.tokenizer is None:
if not self.tokenizer:
self.tokenizer = self.model
# Setup plugins
@ -837,7 +837,8 @@ class EngineArgs:
def create_model_config(self) -> ModelConfig:
return ModelConfig(
model=self.model,
tokenizer=self.tokenizer,
# We know this is not None because we set it in __post_init__
tokenizer=cast(str, self.tokenizer),
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code,
dtype=self.dtype,
@ -908,8 +909,9 @@ class EngineArgs:
self.enable_prefix_caching = False
cache_config = CacheConfig(
# neuron needs block_size = max_model_len
block_size=self.block_size if self.device != "neuron" else
self.max_model_len, # neuron needs block_size = max_model_len
(self.max_model_len if self.max_model_len is not None else 0),
gpu_memory_utilization=self.gpu_memory_utilization,
swap_space=self.swap_space,
cache_dtype=self.kv_cache_dtype,
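
A minimal sketch (hypothetical dataclass, not vLLM code) of the `cast(str, ...)` pattern above: mypy cannot follow the invariant that `__post_init__` fills in a default for an `Optional` field, so the cast records that invariant at the point of use.

```python
from dataclasses import dataclass
from typing import Optional, cast


@dataclass
class Args:
    model: str
    tokenizer: Optional[str] = None

    def __post_init__(self) -> None:
        if not self.tokenizer:
            self.tokenizer = self.model

    def tokenizer_name(self) -> str:
        # We know this is not None because __post_init__ set it.
        return cast(str, self.tokenizer)


assert Args(model="my-model").tokenizer_name() == "my-model"
```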

View File

@ -6,7 +6,7 @@ from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union, overload
from typing import Set, Type, Union, cast, overload
import torch
from typing_extensions import TypeVar
@ -44,7 +44,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
SequenceGroupOutput, SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
@ -188,7 +188,7 @@ class LLMEngine:
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")
return output
return cast(_O, output)
@classmethod
def validate_outputs(
@ -1039,6 +1039,7 @@ class LLMEngine:
scheduler_outputs.scheduled_seq_groups)
has_multiple_outputs: bool = len(outputs) > 1
outputs_by_sequence_group: List[List[SequenceGroupOutput]]
if has_multiple_outputs:
assert self.scheduler_config.is_multi_step or \
self.speculative_config
@ -1084,6 +1085,7 @@ class LLMEngine:
finished_before.append(i)
continue
output: List[SequenceGroupOutput]
if has_multiple_outputs:
output = outputs_by_sequence_group[i]
else:
@ -1096,7 +1098,7 @@ class LLMEngine:
seq_group, seq_group_meta, is_first_step_output)
else:
seq_group.update_num_computed_tokens(
seq_group_meta.token_chunk_size)
seq_group_meta.token_chunk_size or 0)
if outputs:
for o in outputs:
@ -1104,13 +1106,13 @@ class LLMEngine:
and seq_group.metrics is not None):
if seq_group.metrics.model_forward_time is not None:
seq_group.metrics.model_forward_time += (
o.model_forward_time)
o.model_forward_time or 0)
else:
seq_group.metrics.model_forward_time = (
o.model_forward_time)
if seq_group.metrics.model_execute_time is not None:
seq_group.metrics.model_execute_time += (
o.model_execute_time)
o.model_execute_time or 0)
else:
seq_group.metrics.model_execute_time = (
o.model_execute_time)
@ -1236,8 +1238,10 @@ class LLMEngine:
seq_group, seq_group_metadata,
seq_group.state.num_steps == 1)
else:
seq_group.update_num_computed_tokens(
seq_group_metadata.token_chunk_size)
token_chunk_size = (seq_group_metadata.token_chunk_size
if seq_group_metadata.token_chunk_size
is not None else 0)
seq_group.update_num_computed_tokens(token_chunk_size)
if seq_group_metadata.do_sample:
assert len(sequence_group_outputs.samples) == 1, (
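
Two of the patterns in the hunk above, restated in a minimal hypothetical sketch: `x or 0` collapses an `Optional` number to a plain number before arithmetic, and annotating a local variable before the branches gives mypy a single declared type for both assignments.

```python
from typing import List, Optional


def add_forward_time(total: float, forward_time: Optional[float]) -> float:
    # None becomes 0, so the addition type-checks.
    return total + (forward_time or 0)


def pick_output(outputs: List[List[int]], use_first: bool) -> List[int]:
    chosen: List[int]  # one declared type for both branches
    if use_first:
        chosen = outputs[0]
    else:
        chosen = []
    return chosen
```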

View File

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Type, Union, cast
import numpy as np
import prometheus_client
@ -249,10 +249,11 @@ class _RayHistogramWrapper:
labelnames: Optional[List[str]] = None,
buckets: Optional[List[float]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
boundaries = buckets if buckets else []
self._histogram = ray_metrics.Histogram(name=name,
description=documentation,
tag_keys=labelnames_tuple,
boundaries=buckets)
boundaries=boundaries)
def labels(self, **labels):
self._histogram.set_default_tags(labels)
@ -267,9 +268,12 @@ class RayMetrics(Metrics):
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
_gauge_cls = _RayGaugeWrapper
_counter_cls = _RayCounterWrapper
_histogram_cls = _RayHistogramWrapper
_gauge_cls: Type[prometheus_client.Gauge] = cast(
Type[prometheus_client.Gauge], _RayGaugeWrapper)
_counter_cls: Type[prometheus_client.Counter] = cast(
Type[prometheus_client.Counter], _RayCounterWrapper)
_histogram_cls: Type[prometheus_client.Histogram] = cast(
Type[prometheus_client.Histogram], _RayHistogramWrapper)
def __init__(self, labelnames: List[str], max_model_len: int):
if ray_metrics is None:
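
A sketch (hypothetical classes, not the Ray or prometheus_client APIs) of the `cast(Type[...], ...)` pattern above: when a subclass swaps duck-typed wrappers into class attributes declared as `Type[...]` on the base class, the cast tells mypy to accept the substitution.

```python
from typing import Type, cast


class Gauge:
    def set(self, value: float) -> None:
        print(f"real gauge: {value}")


class GaugeWrapper:  # duck-typed stand-in, not a Gauge subclass
    def set(self, value: float) -> None:
        print(f"wrapped gauge: {value}")


class Metrics:
    _gauge_cls: Type[Gauge] = Gauge


class WrappedMetrics(Metrics):
    # The wrapper satisfies the declared Type[Gauge] only via cast().
    _gauge_cls: Type[Gauge] = cast(Type[Gauge], GaugeWrapper)
```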

View File

@ -3,7 +3,7 @@ import copy
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
Optional, Union, overload)
Optional, Union, cast, overload)
import cloudpickle
import zmq
@ -513,9 +513,14 @@ class MQLLMEngineClient(EngineClient):
assert (prompt is not None and pooling_params is not None
and request_id is not None)
return self._process_request(prompt, pooling_params, request_id,
lora_request, trace_headers, None,
priority)
return cast(
AsyncGenerator[EmbeddingRequestOutput, None],
self._process_request(prompt,
pooling_params,
request_id,
lora_request,
trace_headers,
priority=priority))
async def _process_request(
self,
@ -543,7 +548,9 @@ class MQLLMEngineClient(EngineClient):
build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer(lora_request),
default_guided_backend=self.decoding_config.guided_decoding_backend
default_guided_backend=(self.decoding_config.guided_decoding_backend
if self.decoding_config
else DecodingConfig.guided_decoding_backend),
)
# 1) Create output queue for this requests.

View File

@ -73,11 +73,9 @@ class MQLLMEngine:
# For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and send over the socket, which frees
# the python object to be reused again.
use_cached_outputs = True
kwargs['use_cached_outputs'] = True
self.engine = LLMEngine(*args,
**kwargs,
use_cached_outputs=use_cached_outputs)
self.engine = LLMEngine(*args, **kwargs)
self.log_requests = log_requests
self.use_async_sockets = use_async_sockets

View File

@ -1,5 +1,5 @@
import functools
from typing import Callable, List
from typing import Callable, List, cast
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
@ -9,8 +9,10 @@ from vllm.engine.output_processor.single_step import (
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Sequence,
SequenceGroup, SequenceGroupOutput, SequenceOutput,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
@ -57,6 +59,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"""
for output in outputs:
# Concatenate single-step prompt logprob processing results.
assert isinstance(output, CompletionSequenceGroupOutput)
single_step_process_prompt_logprob(self, seq_group, output)
@staticmethod
@ -100,8 +103,18 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"Beam search not supported in multi-step decoding.")
seq = seqs[0]
seq_id = seq.seq_id
assert all(
[seq_id == output.samples[0].parent_seq_id for output in outputs])
# This method is defined in the more generic
# SequenceGroupOutputProcessor, but here we assume that the outputs are
# of a more specific type.
assert all([
isinstance(output, CompletionSequenceGroupOutput)
for output in outputs
])
compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs)
assert all([
seq_id == output.samples[0].parent_seq_id
for output in compl_outputs
])
if is_async:
# Async case: We process tokens one by one. Here, we know the token
@ -113,7 +126,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# Since there's only one sequence per sequence group,
# we can take the first sample.
samples = [output.samples[0] for output in outputs]
samples = [output.samples[0] for output in compl_outputs]
# entries in sample tokens may be invalid (eg. due to spec decode
# rejecting tokens).
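
A minimal sketch (hypothetical types) of the narrowing used above: `isinstance()` checks narrow individual elements, but narrowing the list as a whole still requires a `cast()` to the more specific element type.

```python
from typing import List, cast


class GroupOutput:
    pass


class CompletionGroupOutput(GroupOutput):
    def __init__(self, parent_seq_id: int) -> None:
        self.parent_seq_id = parent_seq_id


def parent_ids(outputs: List[GroupOutput]) -> List[int]:
    assert all(isinstance(o, CompletionGroupOutput) for o in outputs)
    completions = cast(List[CompletionGroupOutput], outputs)
    return [o.parent_seq_id for o in completions]
```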

View File

@ -6,8 +6,9 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.sequence import (CompletionSequenceGroupOutput, Sequence,
SequenceGroup, SequenceGroupOutput, SequenceOutput,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
@ -16,7 +17,7 @@ logger = init_logger(__name__)
def single_step_process_prompt_logprob(
sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
output: SequenceGroupOutput) -> None:
output: CompletionSequenceGroupOutput) -> None:
"""Process prompt logprobs associated with the :class:`SequenceGroupOutput`
for a given step.
@ -106,6 +107,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
"""
assert len(outputs) == 1, ("Single step should only has 1 output.")
output = outputs[0]
assert isinstance(output, CompletionSequenceGroupOutput)
single_step_process_prompt_logprob(self, seq_group, output)
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,

View File

@ -57,7 +57,7 @@ class StopChecker:
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id()
if last_token_id in sampling_params.stop_token_ids:
if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
@ -92,7 +92,7 @@ class StopChecker:
Returns the stop string if matched or else None.
"""
if not new_char_count:
if not new_char_count or not sampling_params.stop:
return None
for stop_str in sampling_params.stop:
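
The two guards above, restated as a minimal hypothetical sketch: `x or ()` substitutes an empty tuple when an `Optional` collection is None, and an early `if not x: return` narrows it away before the loop.

```python
from typing import Optional, Sequence


def hit_stop_token(token_id: int,
                   stop_token_ids: Optional[Sequence[int]]) -> bool:
    return token_id in (stop_token_ids or ())


def first_stop_match(text: str,
                     stop: Optional[Sequence[str]]) -> Optional[str]:
    if not stop:
        return None
    for stop_str in stop:
        if stop_str in text:
            return stop_str
    return None
```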

View File

@ -1,22 +1,25 @@
from typing import List
from typing import Sequence as GenericSequence
from typing import Union
from typing import cast
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import PoolerOutput, SequenceGroupOutput
from vllm.sequence import CompletionSequenceGroupOutput, SequenceGroupOutput
def create_output_by_sequence_group(
outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
outputs: GenericSequence[SamplerOutput],
num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
"""
output_by_sequence_group: List[List[SequenceGroupOutput]] = [
output_by_sequence_group: List[List[CompletionSequenceGroupOutput]] = [
[] for _ in range(num_seq_groups)
]
for step in outputs:
sequence_group_output: CompletionSequenceGroupOutput
for i, sequence_group_output in enumerate(step):
output_by_sequence_group[i].append(sequence_group_output)
return output_by_sequence_group
# Cast to the more generic type that CompletionSequenceGroupOutput
# inherits from.
return cast(List[List[SequenceGroupOutput]], output_by_sequence_group)

View File

@ -1,4 +1,4 @@
from typing import List, Literal, Sequence, TypedDict, Union, overload
from typing import List, Literal, Sequence, TypedDict, Union, cast, overload
from typing_extensions import TypeIs
@ -44,13 +44,16 @@ def parse_and_batch_prompt(
if is_list_of(prompt, str):
# case 2: array of strings
prompt = cast(List[str], prompt)
return [
ParsedText(content=elem, is_tokens=False) for elem in prompt
]
if is_list_of(prompt, int):
# case 3: array of tokens
prompt = cast(List[int], prompt)
return [ParsedTokens(content=prompt, is_tokens=True)]
if is_list_of(prompt, list):
prompt = cast(List[List[int]], prompt)
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")

View File

@ -4,7 +4,7 @@ import warnings
from dataclasses import dataclass
from importlib.util import find_spec
from math import inf
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, Iterator, List, Optional, Tuple, Union
import msgspec
import torch
@ -117,12 +117,15 @@ class SamplerOutput(
# block/sync across workers, cpu-gpu sync time and sampling time.
model_execute_time: Optional[float] = None
def __getitem__(self, idx: int):
def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
return self.outputs[idx]
def __setitem__(self, idx: int, value):
self.outputs[idx] = value
def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]:
return iter(self.outputs)
def __len__(self):
return len(self.outputs)
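
A small sketch (hypothetical container) of the benefit of typing the dunder methods as above: explicit return types on `__getitem__` and `__iter__` give mypy a precise element type at indexing sites and in for-loops instead of `Any`.

```python
from typing import Iterator, List


class Outputs:
    def __init__(self, items: List[int]) -> None:
        self.items = items

    def __getitem__(self, idx: int) -> int:
        return self.items[idx]

    def __iter__(self) -> Iterator[int]:
        return iter(self.items)

    def __len__(self) -> int:
        return len(self.items)
```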

View File

@ -4,6 +4,7 @@ from typing import List, Optional
from typing import Sequence as GenericSequence
from typing import Union
from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@ -92,7 +93,7 @@ class RequestOutput:
def __init__(
self,
request_id: str,
prompt: Optional[str],
prompt: Optional[PromptType],
prompt_token_ids: Optional[List[int]],
prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput],

View File

@ -788,7 +788,7 @@ class SequenceGroup:
assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
self.init_multi_step(num_steps=num_lookahead_slots + 1)
def get_last_latency(self, now: float) -> Optional[float]:
def get_last_latency(self, now: float) -> float:
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
if self.is_prefill():
@ -1198,7 +1198,7 @@ class PoolerOutput(
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
def __getitem__(self, idx: int):
def __getitem__(self, idx: int) -> EmbeddingSequenceGroupOutput:
return self.outputs[idx]
def __setitem__(self, idx: int, value):