[CI/Build] mypy: Resolve some errors from checking vllm/engine (#9267)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Russell Bryant 2024-10-16 18:55:59 -04:00 committed by GitHub
parent 8345045833
commit 776dbd74f1
GPG Key ID: B5690EEEBB952194
20 changed files with 109 additions and 74 deletions

View File

@@ -13,24 +13,14 @@ run_mypy() {
 run_mypy # Note that this is less strict than CI
 run_mypy tests
-run_mypy vllm/assets
 run_mypy vllm/attention
-#run_mypy vllm/compilation
+run_mypy vllm/compilation
-#run_mypy vllm/core
 run_mypy vllm/distributed
 run_mypy vllm/engine
-run_mypy vllm/entrypoints
 run_mypy vllm/executor
-#run_mypy vllm/inputs
-run_mypy vllm/logging
 run_mypy vllm/lora
 run_mypy vllm/model_executor
-run_mypy vllm/multimodal
-run_mypy vllm/platforms
 run_mypy vllm/plugins
 run_mypy vllm/prompt_adapter
 run_mypy vllm/spec_decode
-run_mypy vllm/transformers_utils
-run_mypy vllm/usage
-#run_mypy vllm/vllm_flash_attn
 run_mypy vllm/worker

View File

@@ -92,7 +92,7 @@ class Attention(nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: Optional[torch.Tensor],
+        kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
         attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:

View File

@@ -244,8 +244,8 @@ def vllm_backend(
 def select_default_backend(level: int) -> Union[str, Callable]:
     if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
-        backend = "eager"
-        return backend
+        backend_str = "eager"
+        return backend_str
     assert level in [
         CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
     ], f"Invalid level {level}"

View File

@@ -35,6 +35,8 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
     def cls_decorator_helper(cls: type):
         # helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
         # to avoid too much indentation for `_support_torch_compile``
+        if not hasattr(cls, 'forward'):
+            raise TypeError("decorated class should have a forward method.")
         sig = inspect.signature(cls.forward)
         for k in dynamic_arg_dims:
             if k not in sig.parameters:
@@ -63,13 +65,13 @@ def _support_torch_compile(cls: type,
     # other than TorchCompileWrapperWithCustomDispatcher
     cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
-    old_init = cls.__init__
+    old_init = cls.__init__  # type: ignore
     def __init__(self, *args, **kwargs):
         old_init(self, *args, **kwargs)
         TorchCompileWrapperWithCustomDispatcher.__init__(self)
-    cls.__init__ = __init__
+    cls.__init__ = __init__  # type: ignore
     def __call__(self, *args, **kwargs):
         # torch.compiler.is_compiling() means we are inside the compilation
@@ -109,5 +111,5 @@ def _support_torch_compile(cls: type,
         model_output = self.forward(*args, **kwargs)
         return model_output
-    cls.__call__ = __call__
+    cls.__call__ = __call__  # type: ignore
     return cls
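
Note: the decorator hunks above show two recurring tricks for making monkey-patched classes pass mypy: guard for the attribute you are about to touch (hasattr(cls, 'forward')) and mark the method reassignments with `# type: ignore`, since mypy cannot model rewriting __init__ and __call__ on an arbitrary type. A stripped-down sketch of the same kind of patching, using a made-up log_calls decorator rather than vLLM's real one:

    from typing import Any, Callable, TypeVar

    _T = TypeVar("_T", bound=type)


    def log_calls(cls: _T) -> _T:
        """Illustrative only: patches a class the way the hunk above does."""
        if not hasattr(cls, "forward"):
            raise TypeError("decorated class should have a forward method.")

        # mypy sees `cls` as a bare `type`, so attribute access needs an ignore.
        old_forward: Callable[..., Any] = cls.forward  # type: ignore

        def forward(self: Any, *args: Any, **kwargs: Any) -> Any:
            print(f"calling {cls.__name__}.forward")
            return old_forward(self, *args, **kwargs)

        # Assigning to a method attribute is flagged by mypy; the real
        # decorator carries the same ignore comments for its patched methods.
        cls.forward = forward  # type: ignore
        return cls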

View File

@@ -73,7 +73,7 @@ class TorchCompileWrapperWithCustomDispatcher:
             return
         # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
         frame = sys._getframe()
-        while True:
+        while frame and frame.f_back:
             frame = frame.f_back
             code_name = frame.f_code.co_name
             file_name = frame.f_code.co_filename.split(os.path.sep)[-1]

View File

@@ -626,13 +626,14 @@ class CacheConfig:
         self.sliding_window = sliding_window
         self.enable_prefix_caching = enable_prefix_caching
         self.cpu_offload_gb = cpu_offload_gb
         self._verify_args()
         self._verify_cache_dtype()
         self._verify_prefix_caching()
         # Will be set after profiling.
-        self.num_gpu_blocks = None
-        self.num_cpu_blocks = None
+        self.num_gpu_blocks: Optional[int] = None
+        self.num_cpu_blocks: Optional[int] = None
     def metrics_info(self):
         # convert cache_config to dict(key: str, value: str) for prometheus
@@ -709,7 +710,8 @@ class TokenizerPoolConfig:
     @classmethod
     def create_config(
-        cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
+        cls, tokenizer_pool_size: int,
+        tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]],
         tokenizer_pool_extra_config: Optional[Union[str, dict]]
     ) -> Optional["TokenizerPoolConfig"]:
         """Create a TokenizerPoolConfig from the given parameters.
@@ -1544,7 +1546,7 @@ class LoRAConfig:
     max_loras: int
     fully_sharded_loras: bool = False
     max_cpu_loras: Optional[int] = None
-    lora_dtype: Optional[torch.dtype] = None
+    lora_dtype: Optional[Union[torch.dtype, str]] = None
     lora_extra_vocab_size: int = 256
     # This is a constant.
     lora_vocab_padding_size: ClassVar[int] = 256
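
Note: the CacheConfig change is the standard fix for attributes that start out as None and are filled in later. Without an explicit Optional[int] annotation, mypy infers the attribute type from the initial None and rejects the later integer assignment. A minimal sketch of the pattern, with an invented ProfiledConfig class:

    from typing import Optional


    class ProfiledConfig:
        """Illustrative only: mirrors the Optional-attribute pattern above."""

        def __init__(self) -> None:
            # Without the Optional[int] annotation, mypy infers the type as
            # None and rejects the later integer assignment.
            self.num_blocks: Optional[int] = None

        def set_after_profiling(self, blocks: int) -> None:
            self.num_blocks = blocks

        def require_blocks(self) -> int:
            # Narrow the Optional before using it as an int.
            if self.num_blocks is None:
                raise ValueError("profiling has not run yet")
            return self.num_blocks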

View File

@@ -4,8 +4,9 @@ import random
 import time
 from collections import deque
 from dataclasses import dataclass, field
-from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set,
-                    Tuple, Union)
+from typing import Callable, Deque, Dict, Iterable, List, Optional
+from typing import Sequence as GenericSequence
+from typing import Set, Tuple, Union
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.core.interfaces import AllocStatus, BlockSpaceManager
@@ -115,7 +116,7 @@ class ScheduledSequenceGroup:
 class SchedulerOutputs:
     """The scheduling decision made from a scheduler."""
     # Scheduled sequence groups.
-    scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
+    scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup]
     # Number of prefill groups scheduled.
     num_prefill_groups: int
     # Total number of batched tokens.
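
Note: aliasing typing.Sequence as GenericSequence keeps the import from clashing with vLLM's own Sequence class, and switching the field from Iterable to Sequence gives callers len() and indexing, which a bare Iterable does not guarantee. Roughly, with toy classes:

    from dataclasses import dataclass
    from typing import Sequence as GenericSequence


    @dataclass
    class ScheduledItem:
        """Illustrative stand-in for a scheduled sequence group."""
        request_id: str


    @dataclass
    class SchedulerDecision:
        # A Sequence (unlike a bare Iterable) can be indexed, sized, and
        # iterated more than once, which is what downstream code expects.
        scheduled: GenericSequence[ScheduledItem]

        def first(self) -> ScheduledItem:
            assert len(self.scheduled) > 0
            return self.scheduled[0]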

View File

@@ -3,7 +3,7 @@ import dataclasses
 import json
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
-                    Tuple, Type, Union)
+                    Tuple, Type, Union, cast)
 import torch
@@ -89,7 +89,7 @@ class EngineArgs:
     trust_remote_code: bool = False
     download_dir: Optional[str] = None
     load_format: str = 'auto'
-    config_format: str = 'auto'
+    config_format: ConfigFormat = ConfigFormat.AUTO
     dtype: str = 'auto'
     kv_cache_dtype: str = 'auto'
     quantization_param_path: Optional[str] = None
@@ -181,7 +181,7 @@ class EngineArgs:
     scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
     def __post_init__(self):
-        if self.tokenizer is None:
+        if not self.tokenizer:
             self.tokenizer = self.model
         # Setup plugins
@@ -837,7 +837,8 @@ class EngineArgs:
     def create_model_config(self) -> ModelConfig:
         return ModelConfig(
             model=self.model,
-            tokenizer=self.tokenizer,
+            # We know this is not None because we set it in __post_init__
+            tokenizer=cast(str, self.tokenizer),
             tokenizer_mode=self.tokenizer_mode,
             trust_remote_code=self.trust_remote_code,
             dtype=self.dtype,
@@ -908,8 +909,9 @@ class EngineArgs:
             self.enable_prefix_caching = False
         cache_config = CacheConfig(
+            # neuron needs block_size = max_model_len
             block_size=self.block_size if self.device != "neuron" else
-            self.max_model_len,  # neuron needs block_size = max_model_len
+            (self.max_model_len if self.max_model_len is not None else 0),
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
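
Note: the create_model_config change relies on typing.cast to record an invariant that mypy cannot infer on its own: __post_init__ always fills in the tokenizer, so the Optional[str] field is a plain str by the time it is used. A hedged, standalone sketch of that pattern:

    from dataclasses import dataclass
    from typing import Optional, cast


    @dataclass
    class Args:
        """Illustrative only: mirrors the tokenizer-defaulting pattern above."""
        model: str
        tokenizer: Optional[str] = None

        def __post_init__(self) -> None:
            # Falls back to the model name, so tokenizer is never None after this.
            if not self.tokenizer:
                self.tokenizer = self.model

        def tokenizer_name(self) -> str:
            # mypy still sees Optional[str]; cast() records the invariant above.
            return cast(str, self.tokenizer)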

View File

@@ -6,7 +6,7 @@ from functools import partial
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
                     Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
-from typing import Set, Type, Union, overload
+from typing import Set, Type, Union, cast, overload
 import torch
 from typing_extensions import TypeVar
@@ -44,7 +44,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
                            Sequence, SequenceGroup, SequenceGroupMetadata,
-                           SequenceStatus)
+                           SequenceGroupOutput, SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
@@ -188,7 +188,7 @@ class LLMEngine:
             raise TypeError(f"Expected output of type {output_type}, "
                             f"but found type {type(output)}")
-        return output
+        return cast(_O, output)
     @classmethod
     def validate_outputs(
@@ -1039,6 +1039,7 @@ class LLMEngine:
             scheduler_outputs.scheduled_seq_groups)
         has_multiple_outputs: bool = len(outputs) > 1
+        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
         if has_multiple_outputs:
             assert self.scheduler_config.is_multi_step or \
                 self.speculative_config
@@ -1084,6 +1085,7 @@ class LLMEngine:
                 finished_before.append(i)
                 continue
+            output: List[SequenceGroupOutput]
             if has_multiple_outputs:
                 output = outputs_by_sequence_group[i]
             else:
@@ -1096,7 +1098,7 @@ class LLMEngine:
                     seq_group, seq_group_meta, is_first_step_output)
             else:
                 seq_group.update_num_computed_tokens(
-                    seq_group_meta.token_chunk_size)
+                    seq_group_meta.token_chunk_size or 0)
             if outputs:
                 for o in outputs:
@@ -1104,13 +1106,13 @@ class LLMEngine:
                         and seq_group.metrics is not None):
                     if seq_group.metrics.model_forward_time is not None:
                         seq_group.metrics.model_forward_time += (
-                            o.model_forward_time)
+                            o.model_forward_time or 0)
                     else:
                         seq_group.metrics.model_forward_time = (
                             o.model_forward_time)
                     if seq_group.metrics.model_execute_time is not None:
                         seq_group.metrics.model_execute_time += (
-                            o.model_execute_time)
+                            o.model_execute_time or 0)
                     else:
                         seq_group.metrics.model_execute_time = (
                             o.model_execute_time)
@@ -1236,8 +1238,10 @@ class LLMEngine:
                     seq_group, seq_group_metadata,
                     seq_group.state.num_steps == 1)
             else:
-                seq_group.update_num_computed_tokens(
-                    seq_group_metadata.token_chunk_size)
+                token_chunk_size = (seq_group_metadata.token_chunk_size
+                                    if seq_group_metadata.token_chunk_size
+                                    is not None else 0)
+                seq_group.update_num_computed_tokens(token_chunk_size)
             if seq_group_metadata.do_sample:
                 assert len(sequence_group_outputs.samples) == 1, (
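
Note: several of the llm_engine hunks use `value or 0` so that arithmetic on Optional[int]/Optional[float] fields type-checks; adding a possibly-None value directly is exactly the kind of error this PR cleans up. A small illustration with an invented StepMetrics class:

    from typing import Optional


    class StepMetrics:
        """Illustrative only: mirrors the Optional-arithmetic pattern above."""

        def __init__(self) -> None:
            self.model_forward_time: Optional[float] = None

        def accumulate(self, forward_time: Optional[float]) -> None:
            # `forward_time` may be None; adding it directly would be an
            # Optional[float] + float error under mypy, so fall back to 0.
            if self.model_forward_time is not None:
                self.model_forward_time += forward_time or 0
            else:
                self.model_forward_time = forward_time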

View File

@@ -1,6 +1,6 @@
 from typing import TYPE_CHECKING
 from typing import Counter as CollectionsCounter
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Type, Union, cast
 import numpy as np
 import prometheus_client
@@ -249,10 +249,11 @@ class _RayHistogramWrapper:
                  labelnames: Optional[List[str]] = None,
                  buckets: Optional[List[float]] = None):
         labelnames_tuple = tuple(labelnames) if labelnames else None
+        boundaries = buckets if buckets else []
         self._histogram = ray_metrics.Histogram(name=name,
                                                 description=documentation,
                                                 tag_keys=labelnames_tuple,
-                                                boundaries=buckets)
+                                                boundaries=boundaries)
     def labels(self, **labels):
         self._histogram.set_default_tags(labels)
@@ -267,9 +268,12 @@ class RayMetrics(Metrics):
     RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
     Provides the same metrics as Metrics but uses Ray's util.metrics library.
     """
-    _gauge_cls = _RayGaugeWrapper
-    _counter_cls = _RayCounterWrapper
-    _histogram_cls = _RayHistogramWrapper
+    _gauge_cls: Type[prometheus_client.Gauge] = cast(
+        Type[prometheus_client.Gauge], _RayGaugeWrapper)
+    _counter_cls: Type[prometheus_client.Counter] = cast(
+        Type[prometheus_client.Counter], _RayCounterWrapper)
+    _histogram_cls: Type[prometheus_client.Histogram] = cast(
+        Type[prometheus_client.Histogram], _RayHistogramWrapper)
     def __init__(self, labelnames: List[str], max_model_len: int):
         if ray_metrics is None:
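
Note: the RayMetrics change deals with wrapper classes that only duck-type the prometheus_client metrics, so cast() is used to make them satisfy the Type[...] annotations of the base class. The same idea in isolation, with toy classes standing in for the real Gauge and wrapper:

    from typing import Type, cast


    class Gauge:
        """Stands in for a third-party metric class."""


    class _GaugeWrapper:
        """Duck-typed replacement that does not inherit from Gauge."""


    class Metrics:
        _gauge_cls: Type[Gauge] = Gauge


    class WrappedMetrics(Metrics):
        # cast() tells mypy to treat the wrapper as a Type[Gauge] even though
        # there is no real inheritance relationship between the two classes.
        _gauge_cls: Type[Gauge] = cast(Type[Gauge], _GaugeWrapper)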

View File

@@ -3,7 +3,7 @@ import copy
 import pickle
 from contextlib import contextmanager, suppress
 from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
-                    Optional, Union, overload)
+                    Optional, Union, cast, overload)
 import cloudpickle
 import zmq
@@ -513,9 +513,14 @@ class MQLLMEngineClient(EngineClient):
         assert (prompt is not None and pooling_params is not None
                 and request_id is not None)
-        return self._process_request(prompt, pooling_params, request_id,
-                                     lora_request, trace_headers, None,
-                                     priority)
+        return cast(
+            AsyncGenerator[EmbeddingRequestOutput, None],
+            self._process_request(prompt,
+                                  pooling_params,
+                                  request_id,
+                                  lora_request,
+                                  trace_headers,
+                                  priority=priority))
     async def _process_request(
         self,
@@ -543,7 +548,9 @@ class MQLLMEngineClient(EngineClient):
                 build_guided_decoding_logits_processor_async(
                     sampling_params=params,
                     tokenizer=await self.get_tokenizer(lora_request),
-                    default_guided_backend=self.decoding_config.guided_decoding_backend
+                    default_guided_backend=(self.decoding_config.guided_decoding_backend
+                                            if self.decoding_config
+                                            else DecodingConfig.guided_decoding_backend),
                 )
     # 1) Create output queue for this requests.

View File

@@ -73,11 +73,9 @@ class MQLLMEngine:
         # For MQLLMEngine, we can use cached outputs, since each new request
         # output is immediately pickled and send over the socket, which frees
         # the python object to be reused again.
-        use_cached_outputs = True
+        kwargs['use_cached_outputs'] = True
-        self.engine = LLMEngine(*args,
-                                **kwargs,
-                                use_cached_outputs=use_cached_outputs)
+        self.engine = LLMEngine(*args, **kwargs)
         self.log_requests = log_requests
         self.use_async_sockets = use_async_sockets

View File

@@ -1,5 +1,5 @@
 import functools
-from typing import Callable, List
+from typing import Callable, List, cast
 from vllm.core.scheduler import Scheduler
 from vllm.engine.output_processor.interfaces import (
@@ -9,8 +9,10 @@ from vllm.engine.output_processor.single_step import (
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup,
-                           SequenceGroupOutput, SequenceOutput, SequenceStatus)
+from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
+                           CompletionSequenceGroupOutput, Sequence,
+                           SequenceGroup, SequenceGroupOutput, SequenceOutput,
+                           SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import Counter
@@ -57,6 +59,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
         """
         for output in outputs:
             # Concatenate single-step prompt logprob processing results.
+            assert isinstance(output, CompletionSequenceGroupOutput)
             single_step_process_prompt_logprob(self, seq_group, output)
     @staticmethod
@@ -100,8 +103,18 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
             "Beam search not supported in multi-step decoding.")
         seq = seqs[0]
         seq_id = seq.seq_id
-        assert all(
-            [seq_id == output.samples[0].parent_seq_id for output in outputs])
+        # This method is defined in the more generic
+        # SequenceGroupOutputProcessor, but here we assume that the outputs are
+        # of a more specific type.
+        assert all([
+            isinstance(output, CompletionSequenceGroupOutput)
+            for output in outputs
+        ])
+        compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs)
+        assert all([
+            seq_id == output.samples[0].parent_seq_id
+            for output in compl_outputs
+        ])
         if is_async:
             # Async case: We process tokens one by one. Here, we know the token
@@ -113,7 +126,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
             # Since there's only one sequence per sequence group,
             # we can take the first sample.
-            samples = [output.samples[0] for output in outputs]
+            samples = [output.samples[0] for output in compl_outputs]
             # entries in sample tokens may be invalid (eg. due to spec decode
             # rejecting tokens).
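
Note: the multi-step processor receives outputs typed as the generic SequenceGroupOutput, but it only ever handles the completion subclass, so the fix asserts the runtime type and then narrows the whole list with cast(). A reduced version of that narrowing, using invented class names:

    from typing import List, Sequence, cast


    class GroupOutput:
        """Generic base class, standing in for SequenceGroupOutput."""


    class CompletionGroupOutput(GroupOutput):
        def __init__(self, token_id: int) -> None:
            self.token_id = token_id


    def token_ids(outputs: Sequence[GroupOutput]) -> List[int]:
        # Check the runtime type first, then narrow the whole list for mypy.
        assert all(isinstance(o, CompletionGroupOutput) for o in outputs)
        completions = cast(Sequence[CompletionGroupOutput], outputs)
        return [o.token_id for o in completions]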

View File

@@ -6,8 +6,9 @@ from vllm.engine.output_processor.interfaces import (
     SequenceGroupOutputProcessor)
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.logger import init_logger
-from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
-                           SequenceOutput, SequenceStatus)
+from vllm.sequence import (CompletionSequenceGroupOutput, Sequence,
+                           SequenceGroup, SequenceGroupOutput, SequenceOutput,
+                           SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.utils import Counter
@@ -16,7 +17,7 @@ logger = init_logger(__name__)
 def single_step_process_prompt_logprob(
         sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
-        output: SequenceGroupOutput) -> None:
+        output: CompletionSequenceGroupOutput) -> None:
     """Process prompt logprobs associated with the :class:`SequenceGroupOutput`
     for a given step.
@@ -106,6 +107,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         """
         assert len(outputs) == 1, ("Single step should only has 1 output.")
         output = outputs[0]
+        assert isinstance(output, CompletionSequenceGroupOutput)
         single_step_process_prompt_logprob(self, seq_group, output)
     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,

View File

@@ -57,7 +57,7 @@ class StopChecker:
         # Check if a stop token was encountered.
         # This assumes a single token produced per step.
         last_token_id = seq.get_last_token_id()
-        if last_token_id in sampling_params.stop_token_ids:
+        if last_token_id in (sampling_params.stop_token_ids or ()):
             if new_char_count and (
                     not sampling_params.include_stop_str_in_output):
                 # Remove last token
@@ -92,7 +92,7 @@ class StopChecker:
         Returns the stop string if matched or else None.
         """
-        if not new_char_count:
+        if not new_char_count or not sampling_params.stop:
             return None
         for stop_str in sampling_params.stop:
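
Note: both StopChecker fixes defend against Optional collections on SamplingParams: membership tests substitute an empty tuple with `or ()`, and iteration is skipped when the field is unset. For example, in a simplified form:

    from typing import Optional, Sequence


    def hit_stop_token(last_token_id: int,
                       stop_token_ids: Optional[Sequence[int]]) -> bool:
        # `x in None` would raise at runtime and fail type checking;
        # `or ()` substitutes an empty tuple when the field is unset.
        return last_token_id in (stop_token_ids or ())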

View File

@@ -1,22 +1,25 @@
 from typing import List
 from typing import Sequence as GenericSequence
-from typing import Union
+from typing import cast
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.sequence import PoolerOutput, SequenceGroupOutput
+from vllm.sequence import CompletionSequenceGroupOutput, SequenceGroupOutput
 def create_output_by_sequence_group(
-        outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
+        outputs: GenericSequence[SamplerOutput],
         num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
     """Helper method which transforms a 2d list organized by
     [step][sequence group] into [sequence group][step].
     """
-    output_by_sequence_group: List[List[SequenceGroupOutput]] = [
+    output_by_sequence_group: List[List[CompletionSequenceGroupOutput]] = [
         [] for _ in range(num_seq_groups)
     ]
     for step in outputs:
+        sequence_group_output: CompletionSequenceGroupOutput
         for i, sequence_group_output in enumerate(step):
             output_by_sequence_group[i].append(sequence_group_output)
-    return output_by_sequence_group
+    # Cast to the more generic type that CompletionSequenceGroupOutput
+    # inherits from.
+    return cast(List[List[SequenceGroupOutput]], output_by_sequence_group)

View File

@@ -1,4 +1,4 @@
-from typing import List, Literal, Sequence, TypedDict, Union, overload
+from typing import List, Literal, Sequence, TypedDict, Union, cast, overload
 from typing_extensions import TypeIs
@@ -44,13 +44,16 @@ def parse_and_batch_prompt(
     if is_list_of(prompt, str):
         # case 2: array of strings
+        prompt = cast(List[str], prompt)
         return [
             ParsedText(content=elem, is_tokens=False) for elem in prompt
         ]
     if is_list_of(prompt, int):
         # case 3: array of tokens
+        prompt = cast(List[int], prompt)
         return [ParsedTokens(content=prompt, is_tokens=True)]
     if is_list_of(prompt, list):
+        prompt = cast(List[List[int]], prompt)
         if len(prompt[0]) == 0:
             raise ValueError("please provide at least one prompt")
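
Note: parse_and_batch_prompt narrows a union-typed prompt with an explicit cast() after each element-type check, because the result of the check is not visible to mypy at that call site. A reduced sketch that swaps vLLM's is_list_of helper for plain isinstance checks (the normalize function here is illustrative only):

    from typing import List, Sequence, Union, cast


    def normalize(prompt: Union[str, Sequence[str], Sequence[int]]) -> List[str]:
        if isinstance(prompt, str):
            # case 1: a single string prompt
            return [prompt]
        if all(isinstance(elem, str) for elem in prompt):
            # The all() check is invisible to mypy, so cast() records the result.
            return list(cast(Sequence[str], prompt))
        # Remaining case: a sequence of token ids.
        tokens = cast(Sequence[int], prompt)
        return [str(tok) for tok in tokens]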

View File

@@ -4,7 +4,7 @@ import warnings
 from dataclasses import dataclass
 from importlib.util import find_spec
 from math import inf
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, Iterator, List, Optional, Tuple, Union
 import msgspec
 import torch
@@ -117,12 +117,15 @@ class SamplerOutput(
     # block/sync across workers, cpu-gpu sync time and sampling time.
     model_execute_time: Optional[float] = None
-    def __getitem__(self, idx: int):
+    def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
         return self.outputs[idx]
     def __setitem__(self, idx: int, value):
         self.outputs[idx] = value
+    def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]:
+        return iter(self.outputs)
     def __len__(self):
         return len(self.outputs)
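
Note: giving SamplerOutput a typed __getitem__ and __iter__ is what lets call sites such as `for o in outputs:` see CompletionSequenceGroupOutput elements instead of Any. A toy container showing the effect:

    from dataclasses import dataclass, field
    from typing import Iterator, List


    @dataclass
    class Output:
        """Illustrative element type."""
        value: int


    @dataclass
    class OutputBatch:
        outputs: List[Output] = field(default_factory=list)

        def __getitem__(self, idx: int) -> Output:
            return self.outputs[idx]

        def __iter__(self) -> Iterator[Output]:
            # With this annotation, `for o in batch:` is typed as Output.
            return iter(self.outputs)

        def __len__(self) -> int:
            return len(self.outputs)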

View File

@@ -4,6 +4,7 @@ from typing import List, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
+from vllm.inputs import PromptType
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -92,7 +93,7 @@ class RequestOutput:
     def __init__(
         self,
         request_id: str,
-        prompt: Optional[str],
+        prompt: Optional[PromptType],
         prompt_token_ids: Optional[List[int]],
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],

View File

@@ -788,7 +788,7 @@ class SequenceGroup:
         assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
         self.init_multi_step(num_steps=num_lookahead_slots + 1)
-    def get_last_latency(self, now: float) -> Optional[float]:
+    def get_last_latency(self, now: float) -> float:
         """Sets the last token time for Request level timings."""
         # If still in prefill phase, raise Error.
         if self.is_prefill():
@@ -1198,7 +1198,7 @@ class PoolerOutput(
     spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
-    def __getitem__(self, idx: int):
+    def __getitem__(self, idx: int) -> EmbeddingSequenceGroupOutput:
         return self.outputs[idx]
     def __setitem__(self, idx: int, value):