[CI/Build] mypy: Resolve some errors from checking vllm/engine (#9267)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
parent 8345045833
commit 776dbd74f1
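The hunks below apply a few recurring patterns to satisfy mypy: annotating attributes that start out as None with Optional[...], narrowing types with typing.cast after a runtime check or assertion, and guarding Optional values used in arithmetic with `or 0`. A minimal, self-contained sketch of these patterns (illustrative names only, not vllm code):

from typing import List, Optional, cast


class BlockCounts:
    """Illustrative only: mirrors the Optional-annotation pattern in the diff."""

    def __init__(self) -> None:
        # Without the explicit Optional[int] annotation, mypy infers the type
        # from the initial None and rejects later integer assignments.
        self.num_gpu_blocks: Optional[int] = None
        self.num_cpu_blocks: Optional[int] = None


def total_tokens(chunk_sizes: List[Optional[int]]) -> int:
    # `x or 0` keeps the sum well-typed when entries may be None, matching the
    # `token_chunk_size or 0` style changes below.
    return sum(size or 0 for size in chunk_sizes)


def first_as_str(values: List[object]) -> str:
    # After a runtime check, cast() tells mypy what was already verified;
    # it has no runtime effect.
    assert isinstance(values[0], str)
    return cast(str, values[0])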
@@ -13,24 +13,14 @@ run_mypy() {

 run_mypy # Note that this is less strict than CI
 run_mypy tests
-run_mypy vllm/assets
 run_mypy vllm/attention
-#run_mypy vllm/compilation
-#run_mypy vllm/core
+run_mypy vllm/compilation
 run_mypy vllm/distributed
 run_mypy vllm/engine
-run_mypy vllm/entrypoints
 run_mypy vllm/executor
-#run_mypy vllm/inputs
-run_mypy vllm/logging
 run_mypy vllm/lora
 run_mypy vllm/model_executor
-run_mypy vllm/multimodal
-run_mypy vllm/platforms
 run_mypy vllm/plugins
 run_mypy vllm/prompt_adapter
 run_mypy vllm/spec_decode
-run_mypy vllm/transformers_utils
-run_mypy vllm/usage
-#run_mypy vllm/vllm_flash_attn
 run_mypy vllm/worker
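The hunk above trims the per-directory run_mypy calls in the local type-checking script. For reference, a similar per-directory, less-strict local pass could be driven from Python through mypy's api.run; the directory list and flags here are assumptions for illustration, not the project's configuration:

from mypy import api

# Directories assumed here for illustration; the real script maintains its own list.
PARTIALLY_CHECKED_DIRS = ["vllm/engine", "vllm/compilation"]


def run_checks() -> int:
    worst = 0
    for target in PARTIALLY_CHECKED_DIRS:
        # --follow-imports=skip keeps per-directory runs fast and less strict,
        # similar in spirit to the script's local mode.
        stdout, _stderr, status = api.run(["--follow-imports=skip", target])
        print(stdout, end="")
        worst = max(worst, status)
    return worst


if __name__ == "__main__":
    raise SystemExit(run_checks())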
@@ -92,7 +92,7 @@ class Attention(nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: Optional[torch.Tensor],
+        kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
         attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
@@ -244,8 +244,8 @@ def vllm_backend(

 def select_default_backend(level: int) -> Union[str, Callable]:
     if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
-        backend = "eager"
-        return backend
+        backend_str = "eager"
+        return backend_str
     assert level in [
         CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
     ], f"Invalid level {level}"
@@ -35,6 +35,8 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
     def cls_decorator_helper(cls: type):
         # helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
         # to avoid too much indentation for `_support_torch_compile``
+        if not hasattr(cls, 'forward'):
+            raise TypeError("decorated class should have a forward method.")
         sig = inspect.signature(cls.forward)
         for k in dynamic_arg_dims:
             if k not in sig.parameters:
@@ -63,13 +65,13 @@ def _support_torch_compile(cls: type,
     # other than TorchCompileWrapperWithCustomDispatcher
     cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

-    old_init = cls.__init__
+    old_init = cls.__init__  # type: ignore

     def __init__(self, *args, **kwargs):
         old_init(self, *args, **kwargs)
         TorchCompileWrapperWithCustomDispatcher.__init__(self)

-    cls.__init__ = __init__
+    cls.__init__ = __init__  # type: ignore

     def __call__(self, *args, **kwargs):
         # torch.compiler.is_compiling() means we are inside the compilation
@@ -109,5 +111,5 @@ def _support_torch_compile(cls: type,
         model_output = self.forward(*args, **kwargs)
         return model_output

-    cls.__call__ = __call__
+    cls.__call__ = __call__  # type: ignore
     return cls
@@ -73,7 +73,7 @@ class TorchCompileWrapperWithCustomDispatcher:
             return
         # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
         frame = sys._getframe()
-        while True:
+        while frame and frame.f_back:
             frame = frame.f_back
             code_name = frame.f_code.co_name
             file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
@@ -626,13 +626,14 @@ class CacheConfig:
         self.sliding_window = sliding_window
         self.enable_prefix_caching = enable_prefix_caching
         self.cpu_offload_gb = cpu_offload_gb

         self._verify_args()
         self._verify_cache_dtype()
         self._verify_prefix_caching()

         # Will be set after profiling.
-        self.num_gpu_blocks = None
-        self.num_cpu_blocks = None
+        self.num_gpu_blocks: Optional[int] = None
+        self.num_cpu_blocks: Optional[int] = None

     def metrics_info(self):
         # convert cache_config to dict(key: str, value: str) for prometheus
@@ -709,7 +710,8 @@ class TokenizerPoolConfig:

     @classmethod
     def create_config(
-        cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
+        cls, tokenizer_pool_size: int,
+        tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]],
         tokenizer_pool_extra_config: Optional[Union[str, dict]]
     ) -> Optional["TokenizerPoolConfig"]:
         """Create a TokenizerPoolConfig from the given parameters.
@@ -1544,7 +1546,7 @@ class LoRAConfig:
     max_loras: int
     fully_sharded_loras: bool = False
     max_cpu_loras: Optional[int] = None
-    lora_dtype: Optional[torch.dtype] = None
+    lora_dtype: Optional[Union[torch.dtype, str]] = None
     lora_extra_vocab_size: int = 256
     # This is a constant.
     lora_vocab_padding_size: ClassVar[int] = 256
@@ -4,8 +4,9 @@ import random
 import time
 from collections import deque
 from dataclasses import dataclass, field
-from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set,
-                    Tuple, Union)
+from typing import Callable, Deque, Dict, Iterable, List, Optional
+from typing import Sequence as GenericSequence
+from typing import Set, Tuple, Union

 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.core.interfaces import AllocStatus, BlockSpaceManager
@@ -115,7 +116,7 @@ class ScheduledSequenceGroup:
 class SchedulerOutputs:
     """The scheduling decision made from a scheduler."""
     # Scheduled sequence groups.
-    scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
+    scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup]
     # Number of prefill groups scheduled.
     num_prefill_groups: int
     # Total number of batched tokens.
@@ -3,7 +3,7 @@ import dataclasses
 import json
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
-                    Tuple, Type, Union)
+                    Tuple, Type, Union, cast)

 import torch

@@ -89,7 +89,7 @@ class EngineArgs:
     trust_remote_code: bool = False
     download_dir: Optional[str] = None
     load_format: str = 'auto'
-    config_format: str = 'auto'
+    config_format: ConfigFormat = ConfigFormat.AUTO
     dtype: str = 'auto'
     kv_cache_dtype: str = 'auto'
     quantization_param_path: Optional[str] = None
@@ -181,7 +181,7 @@ class EngineArgs:
     scheduling_policy: Literal["fcfs", "priority"] = "fcfs"

     def __post_init__(self):
-        if self.tokenizer is None:
+        if not self.tokenizer:
             self.tokenizer = self.model

         # Setup plugins
@@ -837,7 +837,8 @@ class EngineArgs:
     def create_model_config(self) -> ModelConfig:
         return ModelConfig(
             model=self.model,
-            tokenizer=self.tokenizer,
+            # We know this is not None because we set it in __post_init__
+            tokenizer=cast(str, self.tokenizer),
             tokenizer_mode=self.tokenizer_mode,
             trust_remote_code=self.trust_remote_code,
             dtype=self.dtype,
@@ -908,8 +909,9 @@ class EngineArgs:
             self.enable_prefix_caching = False

         cache_config = CacheConfig(
+            # neuron needs block_size = max_model_len
             block_size=self.block_size if self.device != "neuron" else
-            self.max_model_len,  # neuron needs block_size = max_model_len
+            (self.max_model_len if self.max_model_len is not None else 0),
             gpu_memory_utilization=self.gpu_memory_utilization,
             swap_space=self.swap_space,
             cache_dtype=self.kv_cache_dtype,
@@ -6,7 +6,7 @@ from functools import partial
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
                     Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
-from typing import Set, Type, Union, overload
+from typing import Set, Type, Union, cast, overload

 import torch
 from typing_extensions import TypeVar
@@ -44,7 +44,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
                            Sequence, SequenceGroup, SequenceGroupMetadata,
-                           SequenceStatus)
+                           SequenceGroupOutput, SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
@@ -188,7 +188,7 @@ class LLMEngine:
             raise TypeError(f"Expected output of type {output_type}, "
                             f"but found type {type(output)}")

-        return output
+        return cast(_O, output)

     @classmethod
     def validate_outputs(
@@ -1039,6 +1039,7 @@ class LLMEngine:
             scheduler_outputs.scheduled_seq_groups)

         has_multiple_outputs: bool = len(outputs) > 1
+        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
         if has_multiple_outputs:
             assert self.scheduler_config.is_multi_step or \
                    self.speculative_config
@@ -1084,6 +1085,7 @@ class LLMEngine:
                 finished_before.append(i)
                 continue

+            output: List[SequenceGroupOutput]
             if has_multiple_outputs:
                 output = outputs_by_sequence_group[i]
             else:
@@ -1096,7 +1098,7 @@ class LLMEngine:
                     seq_group, seq_group_meta, is_first_step_output)
             else:
                 seq_group.update_num_computed_tokens(
-                    seq_group_meta.token_chunk_size)
+                    seq_group_meta.token_chunk_size or 0)

             if outputs:
                 for o in outputs:
@@ -1104,13 +1106,13 @@ class LLMEngine:
                             and seq_group.metrics is not None):
                         if seq_group.metrics.model_forward_time is not None:
                             seq_group.metrics.model_forward_time += (
-                                o.model_forward_time)
+                                o.model_forward_time or 0)
                         else:
                             seq_group.metrics.model_forward_time = (
                                 o.model_forward_time)
                         if seq_group.metrics.model_execute_time is not None:
                             seq_group.metrics.model_execute_time += (
-                                o.model_execute_time)
+                                o.model_execute_time or 0)
                         else:
                             seq_group.metrics.model_execute_time = (
                                 o.model_execute_time)
@@ -1236,8 +1238,10 @@ class LLMEngine:
                     seq_group, seq_group_metadata,
                     seq_group.state.num_steps == 1)
             else:
-                seq_group.update_num_computed_tokens(
-                    seq_group_metadata.token_chunk_size)
+                token_chunk_size = (seq_group_metadata.token_chunk_size
+                                    if seq_group_metadata.token_chunk_size
+                                    is not None else 0)
+                seq_group.update_num_computed_tokens(token_chunk_size)

             if seq_group_metadata.do_sample:
                 assert len(sequence_group_outputs.samples) == 1, (
@@ -1,6 +1,6 @@
 from typing import TYPE_CHECKING
 from typing import Counter as CollectionsCounter
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Type, Union, cast

 import numpy as np
 import prometheus_client
@@ -249,10 +249,11 @@ class _RayHistogramWrapper:
                  labelnames: Optional[List[str]] = None,
                  buckets: Optional[List[float]] = None):
         labelnames_tuple = tuple(labelnames) if labelnames else None
+        boundaries = buckets if buckets else []
         self._histogram = ray_metrics.Histogram(name=name,
                                                 description=documentation,
                                                 tag_keys=labelnames_tuple,
-                                                boundaries=buckets)
+                                                boundaries=boundaries)

     def labels(self, **labels):
         self._histogram.set_default_tags(labels)
@@ -267,9 +268,12 @@ class RayMetrics(Metrics):
     RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
     Provides the same metrics as Metrics but uses Ray's util.metrics library.
     """
-    _gauge_cls = _RayGaugeWrapper
-    _counter_cls = _RayCounterWrapper
-    _histogram_cls = _RayHistogramWrapper
+    _gauge_cls: Type[prometheus_client.Gauge] = cast(
+        Type[prometheus_client.Gauge], _RayGaugeWrapper)
+    _counter_cls: Type[prometheus_client.Counter] = cast(
+        Type[prometheus_client.Counter], _RayCounterWrapper)
+    _histogram_cls: Type[prometheus_client.Histogram] = cast(
+        Type[prometheus_client.Histogram], _RayHistogramWrapper)

     def __init__(self, labelnames: List[str], max_model_len: int):
         if ray_metrics is None:
@@ -3,7 +3,7 @@ import copy
 import pickle
 from contextlib import contextmanager, suppress
 from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
-                    Optional, Union, overload)
+                    Optional, Union, cast, overload)

 import cloudpickle
 import zmq
@@ -513,9 +513,14 @@ class MQLLMEngineClient(EngineClient):
         assert (prompt is not None and pooling_params is not None
                 and request_id is not None)

-        return self._process_request(prompt, pooling_params, request_id,
-                                     lora_request, trace_headers, None,
-                                     priority)
+        return cast(
+            AsyncGenerator[EmbeddingRequestOutput, None],
+            self._process_request(prompt,
+                                  pooling_params,
+                                  request_id,
+                                  lora_request,
+                                  trace_headers,
+                                  priority=priority))

     async def _process_request(
         self,
@@ -543,7 +548,9 @@ class MQLLMEngineClient(EngineClient):
                 build_guided_decoding_logits_processor_async(
                     sampling_params=params,
                     tokenizer=await self.get_tokenizer(lora_request),
-                    default_guided_backend=self.decoding_config.guided_decoding_backend
+                    default_guided_backend=(self.decoding_config.guided_decoding_backend
+                        if self.decoding_config
+                        else DecodingConfig.guided_decoding_backend),
                 )

         # 1) Create output queue for this requests.
@@ -73,11 +73,9 @@ class MQLLMEngine:
         # For MQLLMEngine, we can use cached outputs, since each new request
         # output is immediately pickled and send over the socket, which frees
         # the python object to be reused again.
-        use_cached_outputs = True
+        kwargs['use_cached_outputs'] = True

-        self.engine = LLMEngine(*args,
-                                **kwargs,
-                                use_cached_outputs=use_cached_outputs)
+        self.engine = LLMEngine(*args, **kwargs)
         self.log_requests = log_requests

         self.use_async_sockets = use_async_sockets
@@ -1,5 +1,5 @@
 import functools
-from typing import Callable, List
+from typing import Callable, List, cast

 from vllm.core.scheduler import Scheduler
 from vllm.engine.output_processor.interfaces import (
@@ -9,8 +9,10 @@ from vllm.engine.output_processor.single_step import (
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup,
-                           SequenceGroupOutput, SequenceOutput, SequenceStatus)
+from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
+                           CompletionSequenceGroupOutput, Sequence,
+                           SequenceGroup, SequenceGroupOutput, SequenceOutput,
+                           SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import Counter
@@ -57,6 +59,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
         """
         for output in outputs:
             # Concatenate single-step prompt logprob processing results.
+            assert isinstance(output, CompletionSequenceGroupOutput)
             single_step_process_prompt_logprob(self, seq_group, output)

     @staticmethod
@@ -100,8 +103,18 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
             "Beam search not supported in multi-step decoding.")
         seq = seqs[0]
         seq_id = seq.seq_id
-        assert all(
-            [seq_id == output.samples[0].parent_seq_id for output in outputs])
+        # This method is defined in the more generic
+        # SequenceGroupOutputProcessor, but here we assume that the outputs are
+        # of a more specific type.
+        assert all([
+            isinstance(output, CompletionSequenceGroupOutput)
+            for output in outputs
+        ])
+        compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs)
+        assert all([
+            seq_id == output.samples[0].parent_seq_id
+            for output in compl_outputs
+        ])

         if is_async:
             # Async case: We process tokens one by one. Here, we know the token
@@ -113,7 +126,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):

             # Since there's only one sequence per sequence group,
             # we can take the first sample.
-            samples = [output.samples[0] for output in outputs]
+            samples = [output.samples[0] for output in compl_outputs]

             # entries in sample tokens may be invalid (eg. due to spec decode
             # rejecting tokens).
@@ -6,8 +6,9 @@ from vllm.engine.output_processor.interfaces import (
     SequenceGroupOutputProcessor)
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.logger import init_logger
-from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
-                           SequenceOutput, SequenceStatus)
+from vllm.sequence import (CompletionSequenceGroupOutput, Sequence,
+                           SequenceGroup, SequenceGroupOutput, SequenceOutput,
+                           SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.utils import Counter

@@ -16,7 +17,7 @@ logger = init_logger(__name__)

 def single_step_process_prompt_logprob(
         sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
-        output: SequenceGroupOutput) -> None:
+        output: CompletionSequenceGroupOutput) -> None:
     """Process prompt logprobs associated with the :class:`SequenceGroupOutput`
     for a given step.

@@ -106,6 +107,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         """
         assert len(outputs) == 1, ("Single step should only has 1 output.")
         output = outputs[0]
+        assert isinstance(output, CompletionSequenceGroupOutput)
         single_step_process_prompt_logprob(self, seq_group, output)

     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
@@ -57,7 +57,7 @@ class StopChecker:
         # Check if a stop token was encountered.
         # This assumes a single token produced per step.
         last_token_id = seq.get_last_token_id()
-        if last_token_id in sampling_params.stop_token_ids:
+        if last_token_id in (sampling_params.stop_token_ids or ()):
             if new_char_count and (
                     not sampling_params.include_stop_str_in_output):
                 # Remove last token
@@ -92,7 +92,7 @@ class StopChecker:

         Returns the stop string if matched or else None.
         """
-        if not new_char_count:
+        if not new_char_count or not sampling_params.stop:
             return None

         for stop_str in sampling_params.stop:
@@ -1,22 +1,25 @@
 from typing import List
 from typing import Sequence as GenericSequence
-from typing import Union
+from typing import cast

 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.sequence import PoolerOutput, SequenceGroupOutput
+from vllm.sequence import CompletionSequenceGroupOutput, SequenceGroupOutput


 def create_output_by_sequence_group(
-        outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
+        outputs: GenericSequence[SamplerOutput],
         num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
     """Helper method which transforms a 2d list organized by
     [step][sequence group] into [sequence group][step].
     """
-    output_by_sequence_group: List[List[SequenceGroupOutput]] = [
+    output_by_sequence_group: List[List[CompletionSequenceGroupOutput]] = [
         [] for _ in range(num_seq_groups)
     ]
     for step in outputs:
+        sequence_group_output: CompletionSequenceGroupOutput
         for i, sequence_group_output in enumerate(step):
             output_by_sequence_group[i].append(sequence_group_output)

-    return output_by_sequence_group
+    # Cast to the more generic type that CompletionSequenceGroupOutput
+    # inherits from.
+    return cast(List[List[SequenceGroupOutput]], output_by_sequence_group)
@@ -1,4 +1,4 @@
-from typing import List, Literal, Sequence, TypedDict, Union, overload
+from typing import List, Literal, Sequence, TypedDict, Union, cast, overload

 from typing_extensions import TypeIs

@@ -44,13 +44,16 @@ def parse_and_batch_prompt(

         if is_list_of(prompt, str):
             # case 2: array of strings
+            prompt = cast(List[str], prompt)
             return [
                 ParsedText(content=elem, is_tokens=False) for elem in prompt
             ]
         if is_list_of(prompt, int):
             # case 3: array of tokens
+            prompt = cast(List[int], prompt)
             return [ParsedTokens(content=prompt, is_tokens=True)]
         if is_list_of(prompt, list):
+            prompt = cast(List[List[int]], prompt)
             if len(prompt[0]) == 0:
                 raise ValueError("please provide at least one prompt")

@@ -4,7 +4,7 @@ import warnings
 from dataclasses import dataclass
 from importlib.util import find_spec
 from math import inf
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, Iterator, List, Optional, Tuple, Union

 import msgspec
 import torch
@@ -117,12 +117,15 @@ class SamplerOutput(
     # block/sync across workers, cpu-gpu sync time and sampling time.
     model_execute_time: Optional[float] = None

-    def __getitem__(self, idx: int):
+    def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
         return self.outputs[idx]

     def __setitem__(self, idx: int, value):
         self.outputs[idx] = value

+    def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]:
+        return iter(self.outputs)
+
     def __len__(self):
         return len(self.outputs)

@@ -4,6 +4,7 @@ from typing import List, Optional
 from typing import Sequence as GenericSequence
 from typing import Union

+from vllm.inputs import PromptType
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -92,7 +93,7 @@ class RequestOutput:
     def __init__(
         self,
         request_id: str,
-        prompt: Optional[str],
+        prompt: Optional[PromptType],
         prompt_token_ids: Optional[List[int]],
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
@@ -788,7 +788,7 @@ class SequenceGroup:
         assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
         self.init_multi_step(num_steps=num_lookahead_slots + 1)

-    def get_last_latency(self, now: float) -> Optional[float]:
+    def get_last_latency(self, now: float) -> float:
         """Sets the last token time for Request level timings."""
         # If still in prefill phase, raise Error.
         if self.is_prefill():
@@ -1198,7 +1198,7 @@ class PoolerOutput(

     spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

-    def __getitem__(self, idx: int):
+    def __getitem__(self, idx: int) -> EmbeddingSequenceGroupOutput:
         return self.outputs[idx]

     def __setitem__(self, idx: int, value):