[Core] Optimize SPMD architecture with delta + serialization optimization (#7109)

SangBin Cho 2024-08-18 17:57:20 -07:00 committed by GitHub
parent 200a2ffa6b
commit ff7ec82c4d
36 changed files with 722 additions and 346 deletions

View File

@@ -21,6 +21,7 @@ outlines >= 0.0.43, < 0.1  # Requires torch >= 2.1.0
typing_extensions >= 4.10
filelock >= 3.10.4  # filelock starts to support `mode` argument from 3.10.4
pyzmq
+msgspec
librosa  # Required for audio processing
soundfile  # Required for audio processing
gguf == 0.9.1

View File

@@ -8,6 +8,7 @@ pytest tests/basic_correctness/test_preemption.py`.
import pytest
from prometheus_client import REGISTRY

+import vllm.envs as envs
from vllm import SamplingParams
from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
                                 ENABLE_ARTIFICIAL_PREEMPT)

@@ -24,6 +25,13 @@ assert ENABLE_ARTIFICIAL_PREEMPT is True, (
    "tests/basic_correctness/test_preemption.py`")

+
+@pytest.fixture
+def worker_use_ray() -> bool:
+    # When the SPMD worker is used, pass worker_use_ray=True to verify
+    # that the delta input optimization works with preemption.
+    return envs.VLLM_USE_RAY_SPMD_WORKER
+
+
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])

@@ -36,6 +44,7 @@ def test_chunked_prefill_recompute(
    dtype: str,
    max_tokens: int,
    chunked_prefill_token_size: int,
+    worker_use_ray: bool,
) -> None:
    """Ensure that chunked prefill works with preemption."""
    max_num_seqs = min(chunked_prefill_token_size, 256)

@@ -54,6 +63,7 @@ def test_chunked_prefill_recompute(
            max_num_batched_tokens=max_num_batched_tokens,
            enable_chunked_prefill=enable_chunked_prefill,
            max_num_seqs=max_num_seqs,
+            worker_use_ray=worker_use_ray,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
    assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt

@@ -79,6 +89,7 @@ def test_preemption(
    model: str,
    dtype: str,
    max_tokens: int,
+    worker_use_ray: bool,
) -> None:
    """By default, recompute preemption is enabled"""

@@ -89,6 +100,7 @@ def test_preemption(
            model,
            dtype=dtype,
            disable_log_stats=False,
+            worker_use_ray=worker_use_ray,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
    assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt

@@ -132,6 +144,7 @@ def test_swap(
    dtype: str,
    max_tokens: int,
    beam_width: int,
+    worker_use_ray: bool,
) -> None:
    """Use beam search enables swapping."""
    example_prompts = example_prompts[:1]

@@ -144,6 +157,7 @@ def test_swap(
            dtype=dtype,
            swap_space=10,
            disable_log_stats=False,
+            worker_use_ray=worker_use_ray,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_beam_search(example_prompts,
                                                       beam_width, max_tokens)

@@ -188,6 +202,7 @@ def test_swap_infeasible(
    dtype: str,
    max_tokens: int,
    beam_width: int,
+    worker_use_ray: bool,
) -> None:
    """Verify infeasible swap request will be ignored."""
    BLOCK_SIZE = 16

@@ -204,6 +219,7 @@ def test_swap_infeasible(
            # decode blocks are not enough to finish.
            num_gpu_blocks_override=prefill_blocks + decode_blocks,
            max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
+            worker_use_ray=worker_use_ray,
    ) as vllm_model:
        sampling_params = SamplingParams(n=beam_width,
                                         use_beam_search=True,

@@ -230,6 +246,7 @@ def test_preemption_infeasible(
    model: str,
    dtype: str,
    max_tokens: int,
+    worker_use_ray: bool,
) -> None:
    """Verify infeasible preemption request will be ignored."""
    BLOCK_SIZE = 16

@@ -244,6 +261,7 @@ def test_preemption_infeasible(
            # ignored instead of hanging forever.
            num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
            max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
+            worker_use_ray=worker_use_ray,
    ) as vllm_model:
        sampling_params = SamplingParams(max_tokens=max_tokens,
                                         ignore_eos=True)

View File

@@ -0,0 +1,33 @@
+import msgspec
+
+from vllm.executor.msgspec_utils import decode_hook, encode_hook
+from vllm.sequence import ExecuteModelRequest
+
+from ..spec_decode.utils import create_batch
+
+
+def test_msgspec_serialization():
+    num_lookahead_slots = 4
+    seq_group_metadata_list, _, _ = create_batch(16, num_lookahead_slots)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=num_lookahead_slots,
+        running_queue_size=4)
+
+    encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
+    decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
+                                      dec_hook=decode_hook)
+    req = decoder.decode(encoder.encode(execute_model_req))
+    expected = execute_model_req.seq_group_metadata_list
+    actual = req.seq_group_metadata_list
+    assert (len(expected) == len(actual))
+
+    expected = expected[0]
+    actual = actual[0]
+    assert expected.block_tables == actual.block_tables
+    assert expected.is_prompt == actual.is_prompt
+    assert expected.request_id == actual.request_id
+    assert (expected.seq_data[0].prompt_token_ids ==
+            actual.seq_data[0].prompt_token_ids)
+    assert (expected.seq_data[0].output_token_ids ==
+            actual.seq_data[0].output_token_ids)

View File

@@ -22,7 +22,8 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
@pytest.mark.skipif(cuda_device_count_stateless() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
-    "model, distributed_executor_backend, attention_backend, test_suite", [
+    "model, distributed_executor_backend, attention_backend, "
+    "test_suite", [
        ("facebook/opt-125m", "ray", "", "L4"),
        ("facebook/opt-125m", "mp", "", "L4"),
        ("meta-llama/Llama-2-7b-hf", "ray", "", "L4"),

View File

@@ -6,6 +6,8 @@ pytest test_chunked_prefill_distributed.py
```
"""
+import os
+
import pytest

from vllm.utils import cuda_device_count_stateless

@@ -30,6 +32,11 @@ def test_models(
    model: str,
    distributed_executor_backend: str,
) -> None:
+    if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray":  # noqa
+        assert distributed_executor_backend == "ray"
+        # test ray adag
+        os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1"
+        os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1"
+
    dtype = "half"
    max_tokens = 5
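Side note: outside of this test, the same two flags are what turn on the optimized SPMD path. A minimal sketch (assuming, as the test above does, that the flags are set before the engine/executor is constructed so vllm.envs picks them up):

    import os

    # "1" enables the Ray SPMD worker and the compiled (accelerated) DAG path.
    os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
    os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"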

View File

@@ -1,5 +1,6 @@
import itertools
import random
+from array import array
from typing import Dict, List, Optional, Tuple
from unittest.mock import Mock, patch

@@ -10,7 +11,8 @@ from transformers import GenerationConfig, GenerationMixin
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
+                           SequenceData, SequenceGroupMetadata)
from vllm.utils import Counter, is_pin_memory_available

@@ -56,7 +58,9 @@ def _do_sample(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
+                seq_data={
+                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
+                },
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))

@@ -201,7 +205,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
    def create_sequence_data(num_input=3, num_generated=0):
        seq_data = SequenceData(
-            random.choices(range(0, VOCAB_SIZE), k=num_input))
+            array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                  random.choices(range(0, VOCAB_SIZE), k=num_input)))
        if num_generated > 0:
            seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                       k=num_generated)

@@ -504,7 +509,9 @@ def test_sampler_mixed(seed: int, device: str):
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
+                seq_data={
+                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
+                },
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))

@@ -600,7 +607,9 @@ def test_sampler_top_k_top_p(seed: int, device: str):
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
+                seq_data={
+                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
+                },
                sampling_params=SamplingParams(
                    temperature=1,
                    top_k=top_k,

@@ -650,7 +659,11 @@ def test_sampler_repetition_penalty_mixed(device: str):
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
+                seq_data={
+                    0:
+                    SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                       [1, 2, 3]))
+                },
                sampling_params=sampling_params[i],
                block_tables={0: [1]},
            ))

View File

@@ -1,3 +1,4 @@
+from array import array
from itertools import count
from typing import Callable, Dict, List, Optional
from typing import Sequence as GenericSequence

@@ -9,7 +10,8 @@ import torch
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams
-from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
+                           CompletionSequenceGroupOutput, Logprob,
                           SamplerOutput, SequenceData, SequenceGroupMetadata,
                           SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port

@@ -138,8 +140,9 @@ def create_seq_group_metadata_from_prompts(
            seq_data={
                i:
                SequenceData(
-                    prompt_token_ids=prompt_token_ids[:],
-                    output_token_ids=cont_token_ids[:],
+                    array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]),
+                    _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                            cont_token_ids[:]),
                ),
            },
            sampling_params=SamplingParams(temperature=0.0, ),

View File

@@ -1,4 +1,5 @@
import random
+from array import array
from typing import Tuple
from unittest.mock import patch

@@ -8,7 +9,8 @@ import torch
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
+                           SequenceData, SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available

@@ -69,7 +71,9 @@ def test_logits_processors(seed: int, device: str):
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
+                seq_data={
+                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
+                },
                sampling_params=SamplingParams(temperature=0,
                                               logits_processors=[pick_ith]),
                block_tables={0: [1]},

View File

@@ -1,6 +1,9 @@
+from array import array
+
import pytest

-from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput,
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
+                           CompletionSequenceGroupOutput, SamplerOutput,
                           SequenceData, SequenceOutput)

from .core.utils import create_dummy_prompt

@@ -54,7 +57,7 @@ def test_sampler_output_eq(sample_outputs):
def test_sequence_data_prefill():
-    seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
+    seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4]))
    assert seq_data.get_num_uncomputed_tokens() == 4
    assert seq_data.get_num_computed_tokens() == 0
    # advance by 2

View File

@@ -1,10 +1,12 @@
+from array import array
from typing import List

import pytest
import torch

from vllm.engine.arg_utils import EngineArgs
-from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
+                           SequenceData, SequenceGroupMetadata)
from vllm.utils import is_cpu
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner

@@ -125,10 +127,12 @@ def test_prepare_prompt(
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
-        seq_data = SequenceData(list(range(seq_len)))
+        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                      range(seq_len)))
        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
        encoder_seq_lens.append(encoder_seq_len)
-        encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
+        encoder_seq_data = SequenceData(
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,

@@ -319,10 +323,12 @@ def test_prepare_decode(
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
-        seq_data = SequenceData(list(range(seq_len)))
+        seq_data = SequenceData(
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
        encoder_seq_lens.append(encoder_seq_len)
-        encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
+        encoder_seq_data = SequenceData(
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,

View File

@@ -1,3 +1,4 @@
+from array import array
from typing import List

import pytest

@@ -7,7 +8,8 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
+                           SequenceData, SequenceGroupMetadata)
from vllm.utils import get_open_port
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size

@@ -46,7 +48,8 @@ def test_prepare_prompt(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
-        seq_data = SequenceData(list(range(seq_len)))
+        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                      range(seq_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,

@@ -163,7 +166,8 @@ def test_prepare_decode_cuda_graph(batch_size):
        # make sure all tokens fit into one block
        context_len = i % (model_runner.block_size - 1) + 1
        context_lens.append(context_len)
-        seq_data = SequenceData(list(range(context_len)))
+        seq_data = SequenceData(
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)))
        seq_data.update_num_computed_tokens(context_len)
        # Append one token ID since prefill is finished.
        seq_data.append_token_id(1, 0)

@@ -324,7 +328,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
-        seq_data = SequenceData(list(range(seq_len)))
+        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                      range(seq_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,

@@ -340,7 +345,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
    for i in range(prefill_batch_size, batch_size):
        # make sure all tokens fit into one block
        context_len = i % (model_runner.block_size - 1) + 1
-        prompt_toks = list(range(context_len))
+        prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))
        seq_data = SequenceData(prompt_toks)
        seq_data.append_token_id(1, 0)
        seq_data.update_num_computed_tokens(context_len)

View File

@@ -1,8 +1,6 @@
from abc import ABC, abstractmethod
-from dataclasses import dataclass


-@dataclass
class AdapterRequest(ABC):
    """
    Base class for adapter requests.

View File

@@ -770,8 +770,8 @@ class ParallelConfig:
        self.tokenizer_pool_config = tokenizer_pool_config
        self.ray_workers_use_nsight = ray_workers_use_nsight
        self.placement_group = placement_group
        self.world_size = pipeline_parallel_size * self.tensor_parallel_size

        if worker_use_ray:
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "ray"

@@ -867,6 +867,11 @@ class SchedulerConfig:
            swapping. However, when the sequence group has multiple sequences
            (e.g., beam search), recomputation is not currently supported. In
            such a case, we use swapping instead.
+        send_delta_data: Private API. If used, the scheduler sends delta data
+            to workers instead of the entire data. It should be enabled only
+            when the SPMD worker architecture is enabled, i.e.
+            VLLM_USE_RAY_SPMD_WORKER=1.
    """

    def __init__(self,

@@ -879,7 +884,8 @@ class SchedulerConfig:
                 enable_chunked_prefill: bool = False,
                 embedding_mode: Optional[bool] = False,
                 preemption_mode: Optional[str] = None,
-                 num_scheduler_steps: int = 1) -> None:
+                 num_scheduler_steps: int = 1,
+                 send_delta_data: bool = False) -> None:
        if max_num_batched_tokens is not None:
            self.max_num_batched_tokens = max_num_batched_tokens
        else:

@@ -909,6 +915,7 @@ class SchedulerConfig:
        self.embedding_mode = embedding_mode
        self.preemption_mode = preemption_mode
        self.num_scheduler_steps = num_scheduler_steps
+        self.send_delta_data = send_delta_data
        self._verify_args()

    def _verify_args(self) -> None:
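Side note: a minimal sketch of how the private flag is meant to be set, mirroring the EngineArgs wiring in a later hunk. Only send_delta_data is the point here; the other keyword values are illustrative placeholders, and the exact constructor signature is assumed from context.

    import vllm.envs as envs
    from vllm.config import SchedulerConfig

    scheduler_config = SchedulerConfig(
        max_num_batched_tokens=None,   # illustrative defaults
        max_num_seqs=256,
        max_model_len=4096,
        # Only enable delta inputs when the Ray SPMD worker is in use.
        send_delta_data=bool(envs.VLLM_USE_RAY_SPMD_WORKER),
    )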

View File

@@ -12,7 +12,8 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
-                           SequenceGroupMetadata, SequenceStatus)
+                           SequenceGroupMetadata, SequenceGroupMetadataDelta,
+                           SequenceStatus)
from vllm.utils import PyObjectCache

logger = init_logger(__name__)

@@ -363,8 +364,6 @@ class Scheduler:
        self.num_cumulative_preemption: int = 0

        # Used to cache python objects
-        self._seq_group_metadata_cache: PyObjectCache = PyObjectCache(
-            seq_group_metadata_builder)
        self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache(
            scheduler_running_outputs_builder)
        self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache(

@@ -1048,15 +1047,10 @@ class Scheduler:
            token_chunk_size = scheduled_seq_group.token_chunk_size
            seq_group.maybe_set_first_scheduled_time(now)

-            seq_group_metadata = self._seq_group_metadata_cache.get_object()
-            seq_group_metadata.seq_data.clear()
-            seq_group_metadata.block_tables.clear()
            # seq_id -> SequenceData
-            seq_data: Dict[int, SequenceData] = seq_group_metadata.seq_data
+            seq_data: Dict[int, SequenceData] = {}
            # seq_id -> physical block numbers
-            block_tables: Dict[int,
-                               List[int]] = seq_group_metadata.block_tables
+            block_tables: Dict[int, List[int]] = {}

            if seq_group.is_encoder_decoder():
                # Encoder associated with SequenceGroup

@@ -1081,45 +1075,65 @@ class Scheduler:
                    seq_group.get_seqs(status=SequenceStatus.RUNNING)))

            do_sample = True
-            if seq_group.is_prefill():
+            is_prompt = seq_group.is_prefill()
+            # We should send the metadata to workers when the first prefill
+            # is sent. Subsequent requests could be chunked prefill or decode.
+            is_first_prefill = False
+            if is_prompt:
                seqs = seq_group.get_seqs()
                # Prefill has only 1 sequence.
                assert len(seqs) == 1
+                num_computed_tokens = seqs[0].data.get_num_computed_tokens()
+                is_first_prefill = num_computed_tokens == 0
                # In the next iteration, all prompt tokens are not computed.
                # It means the prefill is chunked, and we don't need sampling.
                # NOTE: We use get_len instead of get_prompt_len because when
                # a sequence is preempted, prefill includes previous generated
                # output tokens.
-                if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
-                        seqs[0].data.get_len()):
+                if (token_chunk_size + num_computed_tokens <
                        seqs[0].data.get_len()):
                    do_sample = False

            # It assumes the scheduled_seq_groups is ordered by
            # prefill < decoding.
-            is_prompt = seq_group.is_prefill()
-            seq_group_metadata.__init__(
-                request_id=seq_group.request_id,
-                is_prompt=is_prompt,
-                seq_data=seq_data,
-                sampling_params=seq_group.sampling_params,
-                block_tables=block_tables,
-                do_sample=do_sample,
-                pooling_params=seq_group.pooling_params,
-                token_chunk_size=token_chunk_size,
-                lora_request=seq_group.lora_request,
-                computed_block_nums=common_computed_block_nums,
-                encoder_seq_data=encoder_seq_data,
-                cross_block_table=cross_block_table,
-                state=seq_group.state,
-                # `multi_modal_data` will only be present for the 1st comm
-                # between engine and worker.
-                # the subsequent comms can still use delta, but
-                # `multi_modal_data` will be None.
-                multi_modal_data=seq_group.multi_modal_data
-                if scheduler_outputs.num_prefill_groups > 0 else None,
-                prompt_adapter_request=seq_group.prompt_adapter_request,
-            )
+            if is_first_prefill or not self.scheduler_config.send_delta_data:
+                seq_group_metadata = SequenceGroupMetadata(
+                    request_id=seq_group.request_id,
+                    is_prompt=is_prompt,
+                    seq_data=seq_data,
+                    sampling_params=seq_group.sampling_params,
+                    block_tables=block_tables,
+                    do_sample=do_sample,
+                    pooling_params=seq_group.pooling_params,
+                    token_chunk_size=token_chunk_size,
+                    lora_request=seq_group.lora_request,
+                    computed_block_nums=common_computed_block_nums,
+                    encoder_seq_data=encoder_seq_data,
+                    cross_block_table=cross_block_table,
+                    state=seq_group.state,
+                    # `multi_modal_data` will only be present for the 1st comm
+                    # between engine and worker.
+                    # the subsequent comms can still use delta, but
+                    # `multi_modal_data` will be None.
+                    multi_modal_data=seq_group.multi_modal_data
+                    if scheduler_outputs.num_prefill_groups > 0 else None,
+                    prompt_adapter_request=seq_group.prompt_adapter_request,
+                )
+            else:
+                # When SPMD mode is enabled, we only send delta data except for
+                # the first request to reduce serialization cost.
+                seq_data_delta = {}
+                for id, data in seq_data.items():
+                    seq_data_delta[id] = data.get_delta_and_reset()
+                seq_group_metadata = SequenceGroupMetadataDelta(
+                    seq_data_delta,
+                    seq_group.request_id,
+                    block_tables,
+                    is_prompt,
+                    do_sample=do_sample,
+                    token_chunk_size=token_chunk_size,
+                    computed_block_nums=common_computed_block_nums,
+                )
            seq_group_metadata_list.append(seq_group_metadata)

        # Now that the batch has been created, we can assume all blocks in the

@@ -1130,8 +1144,6 @@ class Scheduler:
            self.block_manager.mark_blocks_as_computed(
                scheduled_seq_group.seq_group)

-        self._seq_group_metadata_cache.reset()
-
        scheduler_time = time.perf_counter() - scheduler_start_time
        # Add this to scheduler time to all the sequences that are currently
        # running. This will help estimate if the scheduler is a significant
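Side note: the delta path hinges on SequenceData.get_delta_and_reset(), which hands back only what changed since the last call. A toy illustration of that contract (not vLLM code, names invented; the real delta also carries bookkeeping such as computed-token counts):

    from typing import List

    class ToySeqData:
        """Illustrative only: tracks which tokens have already been shipped."""

        def __init__(self, prompt_token_ids: List[int]) -> None:
            self._token_ids = list(prompt_token_ids)
            self._num_sent = 0

        def append_token_id(self, token_id: int) -> None:
            self._token_ids.append(token_id)

        def get_delta_and_reset(self) -> List[int]:
            delta = self._token_ids[self._num_sent:]
            self._num_sent = len(self._token_ids)
            return delta

    seq = ToySeqData([10, 11, 12])
    assert seq.get_delta_and_reset() == [10, 11, 12]  # first send: full data
    seq.append_token_id(13)
    assert seq.get_delta_and_reset() == [13]          # later sends: new tokens only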

View File

@@ -5,6 +5,7 @@ from dataclasses import dataclass
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
                    Union)

+import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                         ObservabilityConfig, ParallelConfig,

@@ -905,6 +906,8 @@ class EngineArgs:
            embedding_mode=model_config.embedding_mode,
            preemption_mode=self.preemption_mode,
            num_scheduler_steps=self.num_scheduler_steps,
+            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
+                             and parallel_config.use_ray),
        )
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,

View File

@@ -224,7 +224,6 @@ class LLMEngine:
            cache_config.enable_prefix_caching,
        )
        # TODO(woosuk): Print more configs in debug mode.
-
        from vllm.plugins import load_general_plugins
        load_general_plugins()

View File

@@ -0,0 +1,27 @@
+from array import array
+from typing import Any, Type
+
+from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
+
+
+def encode_hook(obj: Any) -> Any:
+    """Custom msgspec enc hook that supports array types.
+
+    See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
+    """
+    if isinstance(obj, array):
+        assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, (
+            f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
+            f"Given array has a type code of {obj.typecode}.")
+        return obj.tobytes()
+
+
+def decode_hook(type: Type, obj: Any) -> Any:
+    """Custom msgspec dec hook that supports array types.
+
+    See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
+    """
+    if type is array:
+        deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE)
+        deserialized.frombytes(obj)
+        return deserialized
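Side note: a minimal round-trip sketch of these hooks, assuming they are importable as in the serialization test earlier in this commit. The typecode "l" matches VLLM_TOKEN_ID_ARRAY_TYPE; the array is flattened to raw bytes on encode and rebuilt on decode.

    from array import array

    import msgspec

    from vllm.executor.msgspec_utils import decode_hook, encode_hook

    token_ids = array("l", [1, 2, 3, 4])

    encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
    decoder = msgspec.msgpack.Decoder(array, dec_hook=decode_hook)

    assert decoder.decode(encoder.encode(token_ids)) == token_ids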

View File

@@ -4,9 +4,12 @@ from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

+import msgspec
+
import vllm.envs as envs
from vllm.executor.distributed_gpu_executor import (  # yapf: disable
    DistributedGPUExecutor, DistributedGPUExecutorAsync)
+from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput

@@ -60,6 +63,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

+        self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
+        self.output_decoder = msgspec.msgpack.Decoder(
+            Optional[List[SamplerOutput]])
+
    def shutdown(self) -> None:
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            self.forward_dag.teardown()

@@ -123,6 +130,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
                ray_remote_kwargs)

        logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)

        # Create the workers.
        driver_ip = get_ip()
        worker_wrapper_kwargs = self._get_worker_wrapper_args()

@@ -304,8 +312,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

-        outputs = ray.get(self.forward_dag.execute(execute_model_req))
-        return outputs[0]
+        serialized_data = self.input_encoder.encode(execute_model_req)
+        outputs = ray.get(self.forward_dag.execute(serialized_data))
+        output = self.output_decoder.decode(outputs[0])
+        return output

    def _run_workers(
        self,

@@ -475,9 +485,10 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)

-        dag_future = await self.forward_dag.execute_async(execute_model_req)
+        serialized_data = self.input_encoder.encode(execute_model_req)
+        dag_future = await self.forward_dag.execute_async(serialized_data)
        outputs = await dag_future
-        return outputs[0]
+        return self.output_decoder.decode(outputs[0])

    async def _driver_execute_model_async(
        self,

View File

@@ -1,6 +1,9 @@
from typing import List, Optional, Tuple, Union

+import msgspec
+
from vllm.config import ParallelConfig
+from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors

@@ -24,6 +27,10 @@ try:
            # that thread.
            self.compiled_dag_cuda_device_set = False

+            self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
+                                                         dec_hook=decode_hook)
+            self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
+
        def get_node_ip(self) -> str:
            return get_ip()

@@ -33,16 +40,26 @@ try:
            return node_id, gpu_ids

        def execute_model_spmd(
-            self, req_or_tuple: Union[ExecuteModelRequest,
-                                      Tuple[ExecuteModelRequest,
-                                            IntermediateTensors]]):
+            self, req_or_tuple: Union[bytes,
+                                      Tuple[bytes,
+                                            Optional[IntermediateTensors]]]
+        ) -> bytes:
            """Execute model in SPMD fashion: used only when SPMD worker and
            compiled DAG are both enabled.

            Args:
-                req_or_tuple: The request to execute the model, or a tuple
-                    containing the request and intermediate tensors.
+                req_or_tuple: The msgspec-serialized request (bytes), or a
+                    tuple of the serialized request and intermediate tensors.
+                    Intermediate tensors are only provided to pipeline stages
+                    after the first one; otherwise they are None.
            """
+            if isinstance(req_or_tuple, bytes):
+                serialized_req, intermediate_tensors = req_or_tuple, None
+            else:
+                serialized_req, intermediate_tensors = req_or_tuple
+
+            execute_model_req = self.input_decoder.decode(serialized_req)
+
            # TODO(swang): This is needed right now because Ray aDAG executes
            # on a background thread, so we need to reset torch's current
            # device.

@@ -51,16 +68,14 @@ try:
                torch.cuda.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True

-            if isinstance(req_or_tuple, tuple):
-                execute_model_req, intermediate_tensors = req_or_tuple
-            else:
-                execute_model_req = req_or_tuple
-                intermediate_tensors = None
-
            output = self.worker._execute_model_spmd(execute_model_req,
                                                     intermediate_tensors)
+            # Pipeline model request and output to the next pipeline stage.
            if isinstance(output, IntermediateTensors):
-                return execute_model_req, output
+                output = serialized_req, output
+            else:
+                output = self.output_encoder.encode(output)
            return output

    ray_import_err = None

View File

@@ -1,4 +1,5 @@
import functools
+from array import array
from collections import UserDict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol,

@@ -21,6 +22,10 @@ logger = init_logger(__name__)

C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)

+# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
+# We cannot import it here because of circular dependencies.
+VLLM_TOKEN_ID_ARRAY_TYPE = "l"
+

@dataclass(frozen=True)
class InputContext:

@@ -118,7 +123,8 @@ class InputRegistry:
        # Avoid circular import
        from vllm.sequence import SequenceData

-        dummy_seq_data = SequenceData([0] * seq_len)
+        dummy_seq_data = SequenceData(
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)
        dummy_multi_modal_data = None

        return dummy_seq_data, dummy_multi_modal_data
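Side note: the dummy-data paths in this commit (here and in the model files below) build their placeholder token IDs as a typed array instead of a Python list. A quick illustration of the array repetition they rely on (values are arbitrary):

    from array import array

    VLLM_TOKEN_ID_ARRAY_TYPE = "l"  # must match vllm/sequence.py

    seq_len = 8
    dummy_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len

    assert len(dummy_token_ids) == seq_len
    assert list(dummy_token_ids) == [0] * seq_len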

View File

@@ -1,12 +1,15 @@
import warnings
-from dataclasses import dataclass, field
from typing import Optional

+import msgspec
+
from vllm.adapter_commons.request import AdapterRequest


-@dataclass
-class LoRARequest(AdapterRequest):
+class LoRARequest(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True):  # type: ignore[call-arg]
    """
    Request for a LoRA adapter.

@@ -18,16 +21,17 @@ class LoRARequest(AdapterRequest):
    lora_int_id must be globally unique for a given adapter.
    This is currently not enforced in vLLM.
    """
+    __metaclass__ = AdapterRequest

    lora_name: str
    lora_int_id: int
    lora_path: str = ""
-    lora_local_path: Optional[str] = field(default=None, repr=False)
+    lora_local_path: Optional[str] = msgspec.field(default=None)
    long_lora_max_len: Optional[int] = None
    __hash__ = AdapterRequest.__hash__

    def __post_init__(self):
-        if 'lora_local_path' in self.__dict__:
+        if 'lora_local_path' in self.__struct_fields__:
            warnings.warn(
                "The 'lora_local_path' attribute is deprecated "
                "and will be removed in a future version. "

View File

@@ -1,5 +1,6 @@
"""Minimal implementation of BlipVisionModel intended to be only used
within a vision language model."""
+from array import array
from typing import Optional, Union

import torch

@@ -16,7 +17,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.image import (cached_get_tokenizer,
                                   repeat_and_pad_image_tokens)
-from vllm.sequence import SequenceData
+from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData


def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:

@@ -53,8 +54,10 @@ def dummy_seq_data_for_blip(
    else:
        image_feature_size = image_feature_size_override

-    token_ids = [image_token_id] * image_feature_size
-    token_ids += [0] * (seq_len - image_feature_size)
+    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                      [image_token_id]) * image_feature_size
+    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                       [0]) * (seq_len - image_feature_size)
    return SequenceData(token_ids)

View File

@@ -1,3 +1,4 @@
+from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

@@ -17,7 +18,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.opt import OPTModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
+                           SamplerOutput, SequenceData)

from .blip import (BlipVisionModel, dummy_image_for_blip,
                   get_max_blip_image_tokens)

@@ -427,8 +429,10 @@ def dummy_seq_data_for_blip2(
    else:
        image_feature_size = image_feature_size_override

-    token_ids = [image_token_id] * image_feature_size * num_images
-    token_ids += [0] * (seq_len - image_feature_size * num_images)
+    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                      [image_token_id]) * image_feature_size * num_images
+    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                       [0]) * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)

View File

@@ -1,3 +1,4 @@
+from array import array
from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
                    Tuple, TypedDict)

@@ -31,7 +32,8 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_tokenizer,
                                   repeat_and_pad_image_tokens)
-from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
+                           SamplerOutput, SequenceData)
from vllm.utils import print_warning_once

from .interfaces import SupportsMultiModal

@@ -70,8 +72,10 @@ def dummy_seq_data_for_chameleon(
    else:
        image_feature_size = image_feature_size_override

-    token_ids = [image_token_id] * image_feature_size * num_images
-    token_ids += [0] * (seq_len - image_feature_size * num_images)
+    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                      [image_token_id]) * image_feature_size * num_images
+    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                       [0]) * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)

View File

@@ -1,5 +1,6 @@
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
+from array import array
from typing import Iterable, Optional, Tuple

import torch

@@ -17,7 +18,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.image import (cached_get_tokenizer,
                                   repeat_and_pad_image_tokens)
-from vllm.sequence import SequenceData
+from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData


def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:

@@ -53,8 +54,10 @@ def dummy_seq_data_for_clip(
    else:
        image_feature_size = image_feature_size_override

-    token_ids = [image_token_id] * image_feature_size * num_images
-    token_ids += [0] * (seq_len - image_feature_size * num_images)
+    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                      [image_token_id]) * image_feature_size * num_images
+    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                       [0]) * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)

View File

@@ -16,6 +16,7 @@
# limitations under the License.
""" PyTorch Fuyu model."""
import math
+from array import array
from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict

import torch

@@ -37,7 +38,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import (cached_get_image_processor,
                                   cached_get_tokenizer)
-from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
+                           SamplerOutput, SequenceData)

from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings

@@ -97,9 +99,12 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
    ncol, nrow = get_max_fuyu_image_feature_size()
    image_feature_size = get_max_fuyu_image_tokens(ctx)

-    image_token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow
-    token_ids = image_token_ids * num_images
-    token_ids += [0] * (seq_len - image_feature_size * num_images)
+    image_token_ids = (
+        array(VLLM_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * ncol +
+        array(VLLM_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])) * nrow
+    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images
+    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                       [0]) * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)

View File

@@ -23,6 +23,7 @@
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math
import re
+from array import array
from functools import partial
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
                    TypedDict, Union)

@@ -55,7 +56,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_image_processor,
                                   cached_get_tokenizer)
-from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
+                           SamplerOutput, SequenceData)

from .idefics2_vision_model import Idefics2VisionTransformer

@@ -408,7 +410,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
-    token_ids = [0] * seq_len
+    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len
    return SequenceData(token_ids)

View File

@@ -2,6 +2,7 @@
within a vision language model."""

import math
+from array import array
from typing import Iterable, Optional, Tuple

import torch

@@ -25,7 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.image import (cached_get_tokenizer,
                                   repeat_and_pad_image_tokens)
-from vllm.sequence import SequenceData
+from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData


def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:

@@ -62,8 +63,10 @@ def dummy_seq_data_for_siglip(
    else:
        image_feature_size = image_feature_size_override

-    token_ids = [image_token_id] * image_feature_size * num_images
-    token_ids += [0] * (seq_len - image_feature_size * num_images)
+    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                      [image_token_id]) * image_feature_size
+    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                       [0]) * (seq_len - image_feature_size)
    return SequenceData(token_ids)

View File

@@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple
import torch

from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import SequenceData, SequenceGroupMetadata
+from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
+                           SequenceGroupMetadata)
from vllm.triton_utils.sample import get_num_triton_sampler_splits
from vllm.utils import (PyObjectCache, async_tensor_h2d,
                        is_pin_memory_available, make_tensor_with_pad,

@@ -505,9 +506,11 @@ class SamplingTensors:
                    and sampling_params.prompt_logprobs is not None):
                prefill_len = len(seq_group.prompt_logprob_indices)
                prompt_tokens.extend(
-                    array('l') for _ in range(prefill_len))
+                    array(VLLM_TOKEN_ID_ARRAY_TYPE)
+                    for _ in range(prefill_len))
                output_tokens.extend(
-                    array('l') for _ in range(prefill_len))
+                    array(VLLM_TOKEN_ID_ARRAY_TYPE)
+                    for _ in range(prefill_len))
            if seq_group.do_sample:
                for seq_id in seq_ids:
                    seq_data = seq_group.seq_data[seq_id]

View File

@@ -1,15 +1,18 @@
from typing import Any, Optional

+import msgspec

-class PoolingParams:
+
+class PoolingParams(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True):  # type: ignore[call-arg]
    """Pooling parameters for pooling.

    Attributes:
        additional_data: Any additional data needed for pooling.
    """
+    additional_data: Optional[Any] = None

-    def __init__(self, additional_data: Optional[Any] = None):
-        self.additional_data = additional_data

    def clone(self) -> "PoolingParams":
        """Returns a deep copy of the PoolingParams instance."""

View File

@@ -1,13 +1,17 @@
-from dataclasses import dataclass
+import msgspec

from vllm.adapter_commons.request import AdapterRequest


-@dataclass
-class PromptAdapterRequest(AdapterRequest):
+class PromptAdapterRequest(
+        msgspec.Struct,
+        array_like=True,  # type: ignore[call-arg]
+        omit_defaults=True,  # type: ignore[call-arg]
+        frozen=True):  # type: ignore[call-arg]
    """
    Request for a Prompt adapter.
    """
+    __metaclass__ = AdapterRequest

    prompt_adapter_name: str
    prompt_adapter_id: int
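Side note: LoRARequest, PromptAdapterRequest, and PoolingParams above all move from dataclasses to msgspec.Struct so they can ride the same msgpack path as ExecuteModelRequest. A small self-contained sketch of the Struct options used here (the Point type is hypothetical, not vLLM code):

    import msgspec

    class Point(msgspec.Struct,
                omit_defaults=True,   # trailing fields still at their defaults are dropped
                array_like=True,      # encode as a compact positional array, not a map
                frozen=True):         # instances are immutable and hashable
        x: int
        y: int = 0

    encoded = msgspec.msgpack.encode(Point(1))
    decoded = msgspec.msgpack.decode(encoded, type=Point)
    assert decoded == Point(x=1, y=0)

The array_like encoding is what keeps the per-request serialization overhead small once these objects are shipped to workers on every step.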

View File

@@ -2,10 +2,10 @@
import copy
from enum import IntEnum
from functools import cached_property
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Union

+import msgspec
import torch
-from pydantic import Field
from typing_extensions import Annotated

from vllm.logger import init_logger

@@ -33,7 +33,11 @@ first argument, and returns a modified tensor of logits
to sample from."""


-class SamplingParams:
+class SamplingParams(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        # required for @cached_property.
+        dict=True):  # type: ignore[call-arg]
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion

@@ -112,87 +116,73 @@ class SamplingParams:
            (i.e., no truncation).
    """

-    def __init__(
-        self,
-        n: int = 1,
-        best_of: Optional[int] = None,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repetition_penalty: float = 1.0,
-        temperature: float = 1.0,
-        top_p: float = 1.0,
-        top_k: int = -1,
-        min_p: float = 0.0,
-        seed: Optional[int] = None,
-        use_beam_search: bool = False,
-        length_penalty: float = 1.0,
-        early_stopping: Union[bool, str] = False,
-        stop: Optional[Union[str, List[str]]] = None,
-        stop_token_ids: Optional[List[int]] = None,
-        include_stop_str_in_output: bool = False,
-        ignore_eos: bool = False,
-        max_tokens: Optional[int] = 16,
-        min_tokens: int = 0,
-        logprobs: Optional[int] = None,
-        prompt_logprobs: Optional[int] = None,
-        detokenize: bool = True,
-        skip_special_tokens: bool = True,
-        spaces_between_special_tokens: bool = True,
-        logits_processors: Optional[List[LogitsProcessor]] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
-    ) -> None:
-        self.n = n
-        self.best_of = best_of if best_of is not None else n
-        self.presence_penalty = presence_penalty
-        self.frequency_penalty = frequency_penalty
-        self.repetition_penalty = repetition_penalty
-        if 0 < temperature < _MAX_TEMP:
+    n: int = 1
+    best_of: Optional[int] = None
+    presence_penalty: float = 0.0
+    frequency_penalty: float = 0.0
+    repetition_penalty: float = 1.0
+    temperature: float = 1.0
+    top_p: float = 1.0
+    top_k: int = -1
+    min_p: float = 0.0
+    seed: Optional[int] = None
+    use_beam_search: bool = False
+    length_penalty: float = 1.0
+    early_stopping: Union[bool, str] = False
+    stop: Optional[Union[str, List[str]]] = None
+    stop_token_ids: Optional[List[int]] = None
+    ignore_eos: bool = False
+    max_tokens: Optional[int] = 16
+    min_tokens: int = 0
+    logprobs: Optional[int] = None
+    prompt_logprobs: Optional[int] = None
+    # NOTE: This parameter is only exposed at the engine level for now.
+    # It is not exposed in the OpenAI API server, as the OpenAI API does
+    # not support returning only a list of token IDs.
+    detokenize: bool = True
+    skip_special_tokens: bool = True
+    spaces_between_special_tokens: bool = True
+    # Optional[List[LogitsProcessor]] type. We use Any here because
+    # Optional[List[LogitsProcessor]] type is not supported by msgspec.
+    logits_processors: Optional[Any] = None
+    include_stop_str_in_output: bool = False
+    truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
+
+    # The below fields are not supposed to be used as an input.
+    # They are set in post_init.
+    output_text_buffer_length: int = 0
+    _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
+
+    def __post_init__(self) -> None:
+        self.best_of = self.best_of or self.n
+        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors nan or inf in tensors. We have maxed it out to %s.",
-                temperature, _MAX_TEMP, _MAX_TEMP)
-            temperature = max(temperature, _MAX_TEMP)
+                self.temperature, _MAX_TEMP, _MAX_TEMP)
+            self.temperature = max(self.temperature, _MAX_TEMP)
self.temperature = temperature if self.seed == -1:
self.top_p = top_p
self.top_k = top_k
self.min_p = min_p
if seed == -1:
self.seed = None self.seed = None
else: else:
self.seed = seed self.seed = self.seed
self.use_beam_search = use_beam_search if self.stop is None:
self.length_penalty = length_penalty
self.early_stopping = early_stopping
if stop is None:
self.stop = [] self.stop = []
elif isinstance(stop, str): elif isinstance(self.stop, str):
self.stop = [stop] self.stop = [self.stop]
else: else:
self.stop = list(stop) self.stop = list(self.stop)
if stop_token_ids is None: if self.stop_token_ids is None:
self.stop_token_ids = [] self.stop_token_ids = []
else: else:
self.stop_token_ids = list(stop_token_ids) self.stop_token_ids = list(self.stop_token_ids)
self.ignore_eos = ignore_eos self.logprobs = 1 if self.logprobs is True else self.logprobs
self.max_tokens = max_tokens self.prompt_logprobs = (1 if self.prompt_logprobs is True else
self.min_tokens = min_tokens self.prompt_logprobs)
self.logprobs = 1 if logprobs is True else logprobs
self.prompt_logprobs = 1 if prompt_logprobs is True else prompt_logprobs
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
self.detokenize = detokenize
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens
self.logits_processors = logits_processors
self.include_stop_str_in_output = include_stop_str_in_output
self.truncate_prompt_tokens = truncate_prompt_tokens
# Number of characters to hold back for stop string evaluation # Number of characters to hold back for stop string evaluation
# until sequence is finished. # until sequence is finished.
if self.stop and not include_stop_str_in_output: if self.stop and not self.include_stop_str_in_output:
self.output_text_buffer_length = max(len(s) for s in self.stop) - 1 self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
else:
self.output_text_buffer_length = 0
self._verify_args() self._verify_args()
if self.use_beam_search: if self.use_beam_search:
@ -206,11 +196,12 @@ class SamplingParams:
self.min_p = 0.0 self.min_p = 0.0
self._verify_greedy_sampling() self._verify_greedy_sampling()
# eos_token_id is added to this by the engine # eos_token_id is added to this by the engine
self.all_stop_token_ids = set(self.stop_token_ids) self._all_stop_token_ids = set(self.stop_token_ids)
def _verify_args(self) -> None: def _verify_args(self) -> None:
if self.n < 1: if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.") raise ValueError(f"n must be at least 1, got {self.n}.")
assert isinstance(self.best_of, int)
if self.best_of < self.n: if self.best_of < self.n:
raise ValueError(f"best_of must be greater than or equal to n, " raise ValueError(f"best_of must be greater than or equal to n, "
f"got n={self.n} and best_of={self.best_of}.") f"got n={self.n} and best_of={self.best_of}.")
@ -257,6 +248,7 @@ class SamplingParams:
and self.truncate_prompt_tokens < 1): and self.truncate_prompt_tokens < 1):
raise ValueError(f"truncate_prompt_tokens must be >= 1, " raise ValueError(f"truncate_prompt_tokens must be >= 1, "
f"got {self.truncate_prompt_tokens}") f"got {self.truncate_prompt_tokens}")
assert isinstance(self.stop, list)
if any(not stop_str for stop_str in self.stop): if any(not stop_str for stop_str in self.stop):
raise ValueError("stop cannot contain an empty string.") raise ValueError("stop cannot contain an empty string.")
if self.stop and not self.detokenize: if self.stop and not self.detokenize:
@ -290,6 +282,7 @@ class SamplingParams:
"default value of 1.0 when not using beam search.") "default value of 1.0 when not using beam search.")
def _verify_greedy_sampling(self) -> None: def _verify_greedy_sampling(self) -> None:
assert isinstance(self.best_of, int)
if self.best_of > 1: if self.best_of > 1:
raise ValueError("best_of must be 1 when using greedy sampling." raise ValueError("best_of must be 1 when using greedy sampling."
f"Got {self.best_of}.") f"Got {self.best_of}.")
@ -303,7 +296,7 @@ class SamplingParams:
if model_eos_token_id is not None: if model_eos_token_id is not None:
# Add the eos token id into the sampling_params to support # Add the eos token id into the sampling_params to support
# min_tokens processing. # min_tokens processing.
self.all_stop_token_ids.add(model_eos_token_id) self._all_stop_token_ids.add(model_eos_token_id)
# Update eos_token_id for generation # Update eos_token_id for generation
if (eos_ids := generation_config.get("eos_token_id")) is not None: if (eos_ids := generation_config.get("eos_token_id")) is not None:
@ -315,7 +308,7 @@ class SamplingParams:
# purposes. # purposes.
eos_ids.discard(model_eos_token_id) eos_ids.discard(model_eos_token_id)
if eos_ids: if eos_ids:
self.all_stop_token_ids.update(eos_ids) self._all_stop_token_ids.update(eos_ids)
if not self.ignore_eos: if not self.ignore_eos:
eos_ids.update(self.stop_token_ids) eos_ids.update(self.stop_token_ids)
self.stop_token_ids = list(eos_ids) self.stop_token_ids = list(eos_ids)
@ -330,6 +323,10 @@ class SamplingParams:
return SamplingType.RANDOM_SEED return SamplingType.RANDOM_SEED
return SamplingType.RANDOM return SamplingType.RANDOM
@property
def all_stop_token_ids(self) -> Set[int]:
return self._all_stop_token_ids
def clone(self) -> "SamplingParams": def clone(self) -> "SamplingParams":
"""Deep copy excluding LogitsProcessor objects. """Deep copy excluding LogitsProcessor objects.

View File

@ -4,10 +4,11 @@ import enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from array import array from array import array
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Union, cast) Tuple, Union, cast)
import msgspec
import numpy import numpy
import torch import torch
@ -16,13 +17,18 @@ from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.inputs import LLMInputs from vllm.inputs import LLMInputs
from vllm.multimodal import MultiModalDataDict from vllm.multimodal.base import MultiModalDataDict
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
# We keep a dataclass for now because Logprob is embedded in the
# OpenAI server output, where msgspec Structs are not serializable.
# TODO(sang): Fix it.
@dataclass @dataclass
class Logprob: class Logprob:
"""Infos for supporting OpenAI compatible logprobs and token ranks. """Infos for supporting OpenAI compatible logprobs and token ranks.
@ -112,7 +118,23 @@ class RequestMetrics:
model_execute_time: Optional[float] = None model_execute_time: Optional[float] = None
class SequenceData: class SequenceDataDelta(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True): # type: ignore[call-arg]
"""Delta SequenceData to send to workers per step."""
# New output tokens to be appended to the existing SequenceData.
new_output_token_ids: List[int]
# Overwriting existing `cumulative_logprob`
new_cumulative_logprob: float
# Overwriting existing `num_computed_tokens`.
new_num_computed_tokens: int
# Overwriting existing `stage`.
new_stage: SequenceStage
class SequenceData(msgspec.Struct,
omit_defaults=True): # type: ignore[call-arg]
"""Data associated with a sequence. """Data associated with a sequence.
Args: Args:
@ -125,40 +147,57 @@ class SequenceData:
output_token_ids: The token IDs of the output. output_token_ids: The token IDs of the output.
cumulative_logprob: The cumulative log probability of the output. cumulative_logprob: The cumulative log probability of the output.
""" """
# NOTE: we cannot use Union[List, array] because msgspec does not
# support a union of two list types.
_prompt_token_ids: array
_output_token_ids: array = msgspec.field(
default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))
def __init__( ### The below fields should not be passed as an argument ###
self, _cumulative_logprob: float = 0.0
prompt_token_ids: List[int], _prompt_token_ids_tuple: Tuple[int,
output_token_ids: Optional[List[int]] = None, ...] = msgspec.field(default_factory=tuple)
) -> None: # The number of tokens that are computed (that run against the model).
self._prompt_token_ids = array('l', prompt_token_ids) _num_computed_tokens: int = 0
self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids) _stage: SequenceStage = SequenceStage.PREFILL
self._output_token_ids = array( _cached_all_token_ids: List[int] = msgspec.field(default_factory=list)
'l', output_token_ids if output_token_ids is not None else [])
self.cumulative_logprob = 0.0 # It is used to get delta input. It is reset when `get_delta_and_reset`
# The number of tokens that are computed (that run against the model). # is called.
self._num_computed_tokens = 0 _new_appended_tokens: List[int] = msgspec.field(default_factory=list)
self._stage: SequenceStage = SequenceStage.PREFILL
def __post_init__(self) -> None:
assert self._prompt_token_ids.typecode == "l"
assert self._output_token_ids.typecode == "l"
self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
self._prompt_token_ids)
self._update_cached_all_tokens() self._update_cached_all_tokens()
def _update_cached_all_tokens(self): def _update_cached_all_tokens(self):
assert isinstance(self._prompt_token_ids, array)
assert isinstance(self._output_token_ids, array)
self._cached_all_token_ids: List[int] = list(self._prompt_token_ids + self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
self._output_token_ids) self._output_token_ids)
@property
def cumulative_logprob(self) -> float:
return self._cumulative_logprob
@property @property
def prompt_token_ids(self) -> Tuple[int, ...]: def prompt_token_ids(self) -> Tuple[int, ...]:
return self._prompt_token_ids_tuple return self._prompt_token_ids_tuple
@prompt_token_ids.setter @prompt_token_ids.setter
def prompt_token_ids(self, new_prompt_token_ids) -> None: def prompt_token_ids(self, new_prompt_token_ids) -> None:
self._prompt_token_ids = array('l', new_prompt_token_ids) raise NotImplementedError
self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
self._update_cached_all_tokens()
@property @property
def prompt_token_ids_array(self) -> array: def prompt_token_ids_array(self) -> array:
"""Return the prompt token ids in array type.
Note that the array is in "I" type, and it is not compatible
with torch.long (2 bytes vs 4 bytes). So beware of the usage.
"""
return self._prompt_token_ids return self._prompt_token_ids
@property @property
@ -166,18 +205,26 @@ class SequenceData:
return tuple(self._output_token_ids) return tuple(self._output_token_ids)
@output_token_ids.setter @output_token_ids.setter
def output_token_ids(self, new_output_token_ids) -> None: def output_token_ids(self, new_output_token_ids: List[int]) -> None:
self._output_token_ids = array('l', new_output_token_ids) self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
new_output_token_ids)
self._update_cached_all_tokens() self._update_cached_all_tokens()
@property @property
def output_token_ids_array(self) -> array: def output_token_ids_array(self) -> array:
"""Return the prompt token ids in array type.
Note that the array is in "I" type, and it is not compatible
with torch.long (2 bytes vs 4 bytes). So beware of the usage.
"""
assert isinstance(self._output_token_ids, array)
return self._output_token_ids return self._output_token_ids
def append_token_id(self, token_id: int, logprob: float) -> None: def append_token_id(self, token_id: int, logprob: float) -> None:
self._output_token_ids.append(token_id) self._output_token_ids.append(token_id)
self._new_appended_tokens.append(token_id)
self._cached_all_token_ids.append(token_id) self._cached_all_token_ids.append(token_id)
self.cumulative_logprob += logprob self._cumulative_logprob += logprob
def get_len(self) -> int: def get_len(self) -> int:
return len(self._output_token_ids) + len(self._prompt_token_ids) return len(self._output_token_ids) + len(self._prompt_token_ids)
@ -222,6 +269,7 @@ class SequenceData:
""" """
self._num_computed_tokens = 0 self._num_computed_tokens = 0
self._stage = SequenceStage.PREFILL self._stage = SequenceStage.PREFILL
self._new_appended_tokens = []
def get_num_uncomputed_tokens(self) -> int: def get_num_uncomputed_tokens(self) -> int:
"""Return the number of prefill tokens that are not computed.""" """Return the number of prefill tokens that are not computed."""
@ -241,6 +289,21 @@ class SequenceData:
def get_output_token_ids(self) -> Tuple[int, ...]: def get_output_token_ids(self) -> Tuple[int, ...]:
return self.output_token_ids return self.output_token_ids
def get_delta_and_reset(self) -> SequenceDataDelta:
delta = SequenceDataDelta(self._new_appended_tokens,
self._cumulative_logprob,
self.get_num_computed_tokens(), self.stage)
# Reset delta state.
self._new_appended_tokens = []
return delta
def apply_delta(self, delta: SequenceDataDelta):
self._num_computed_tokens = delta.new_num_computed_tokens
self._cumulative_logprob = delta.new_cumulative_logprob
self._stage = delta.new_stage
self._output_token_ids.extend(delta.new_output_token_ids)
self._cached_all_token_ids.extend(delta.new_output_token_ids)
@property @property
def stage(self) -> SequenceStage: def stage(self) -> SequenceStage:
return self._stage return self._stage
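The get_delta_and_reset and apply_delta methods added above are the core of the delta optimization at the per-sequence level: the scheduler-side copy accumulates newly appended tokens, ships only a SequenceDataDelta, and the worker-side copy patches itself in place. A minimal sketch, assuming a vLLM build with this change and leaving out the msgspec encoding that happens in between (token ids and logprobs are made up):

from array import array

from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

# Both sides start from the same prompt snapshot.
scheduler_side = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
worker_side = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))

scheduler_side.append_token_id(token_id=42, logprob=-0.5)
delta = scheduler_side.get_delta_and_reset()  # only the new token + counters

worker_side.apply_delta(delta)
assert worker_side.get_output_token_ids() == (42,)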
@ -248,8 +311,9 @@ class SequenceData:
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"SequenceData(" return (f"SequenceData("
f"prompt_token_ids={self._prompt_token_ids}, " f"prompt_token_ids={self._prompt_token_ids}, "
f"output_token_ids={self._output_token_ids}, " f"output_token_ids={self.output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob})") f"cumulative_logprob={self.cumulative_logprob}, "
f"get_num_computed_tokens={self.get_num_computed_tokens()}")
class Sequence: class Sequence:
@ -325,7 +389,8 @@ class Sequence:
f"invalid input {inputs}; did you forget the " f"invalid input {inputs}; did you forget the "
"encoder input prompt fields?") "encoder input prompt fields?")
self.data = SequenceData(self.prompt_token_ids) self.data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids))
self.output_logprobs: SampleLogprobs = [] self.output_logprobs: SampleLogprobs = []
self.output_text = "" self.output_text = ""
@ -490,8 +555,8 @@ class Sequence:
f"num_blocks={self.n_blocks}, ") f"num_blocks={self.n_blocks}, ")
@dataclass class SequenceGroupState(msgspec.Struct,
class SequenceGroupState: omit_defaults=True): # type: ignore[call-arg]
"""Mutable state tied to a specific sequence group""" """Mutable state tied to a specific sequence group"""
# for multi-step decoding # for multi-step decoding
@ -647,14 +712,19 @@ class SequenceGroup:
if self.sampling_params and self.sampling_params.use_beam_search: if self.sampling_params and self.sampling_params.use_beam_search:
# For beam search, maximally there will always be `best_of` beam # For beam search, maximally there will always be `best_of` beam
# candidates running in the future. # candidates running in the future.
return self.sampling_params.best_of best_of = self.sampling_params.best_of
assert isinstance(best_of, int)
return best_of
else: else:
if (self.sampling_params if self.sampling_params:
and self.sampling_params.best_of > self.num_seqs()): best_of = self.sampling_params.best_of
# At prompt stage, the sequence group is not yet filled up assert isinstance(best_of, int)
# and only have one sequence running. However, in the if best_of > self.num_seqs():
# generation stage, we will have `best_of` sequences running. # At prompt stage, the sequence group is not yet filled up
return self.sampling_params.best_of # and only have one sequence running. However, in the
# generation stage, we will have `best_of` sequences
# running.
return best_of
# At sampling stages, return the number of actual sequences # At sampling stages, return the number of actual sequences
# that are not finished yet. # that are not finished yet.
return self.num_unfinished_seqs() return self.num_unfinished_seqs()
@ -757,7 +827,32 @@ class SequenceGroup:
f"num_seqs={len(self.seqs)})") f"num_seqs={len(self.seqs)})")
class SequenceGroupMetadata: class SequenceGroupMetadataDelta(
msgspec.Struct,
tag=True, # type: ignore[call-arg]
array_like=True, # type: ignore[call-arg]
omit_defaults=True): # type: ignore[call-arg]
"""Delta of SequenceGroupMetadata.
After sending the first SequenceGroupMetadata, the vLLM scheduler
only sends deltas to reduce the data payload size.
"""
seq_data_delta: Dict[int, SequenceDataDelta]
request_id: str
block_tables: Dict[int, List[int]]
is_prompt: bool
do_sample: bool = True
token_chunk_size: Optional[int] = None
computed_block_nums: Optional[List[int]] = None
state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState())
class SequenceGroupMetadata(
msgspec.Struct,
tag=True, # type: ignore[call-arg]
array_like=True, # type: ignore[call-arg]
omit_defaults=True): # type: ignore[call-arg]
"""Metadata for a sequence group. Used to create `AttentionMetadata`. """Metadata for a sequence group. Used to create `AttentionMetadata`.
Args: Args:
@ -789,52 +884,39 @@ class SequenceGroupMetadata:
prompt_adapter_request: Prompt Adapter request. prompt_adapter_request: Prompt Adapter request.
""" """
def __init__( request_id: str
self, is_prompt: bool
request_id: str, seq_data: Dict[int, SequenceData]
is_prompt: bool, sampling_params: SamplingParams
seq_data: Dict[int, SequenceData], block_tables: Dict[int, List[int]]
sampling_params: SamplingParams, do_sample: bool = True
block_tables: Dict[int, List[int]], pooling_params: Optional[PoolingParams] = None
do_sample: bool = True, lora_request: Optional[LoRARequest] = None
pooling_params: Optional[PoolingParams] = None, computed_block_nums: Optional[List[int]] = None
token_chunk_size: Optional[int] = None, state: Optional[SequenceGroupState] = msgspec.field(
lora_request: Optional[LoRARequest] = None, default_factory=lambda: SequenceGroupState())
computed_block_nums: Optional[List[int]] = None, # "MultiModalDataDict" types. We have to use Any because msgspec
state: Optional[SequenceGroupState] = None, # does not allow a union of two different dicts.
multi_modal_data: Optional["MultiModalDataDict"] = None, multi_modal_data: Optional[Any] = None
encoder_seq_data: Optional[SequenceData] = None, encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[List[int]] = None, cross_block_table: Optional[List[int]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> None: token_chunk_size: Optional[int] = None
self.request_id = request_id
self.is_prompt = is_prompt
self.seq_data = seq_data
self.sampling_params = sampling_params
self.block_tables = block_tables
self.pooling_params = pooling_params
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.computed_block_nums = computed_block_nums
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
self.encoder_seq_data = encoder_seq_data
self.cross_block_table = cross_block_table
self._token_chunk_size = token_chunk_size
self.do_sample = do_sample
# The number of speculative tokens adopted in this request. ### Stateful fields that are lazily defined. ###
# None means speculative decoding is not used. # The number of speculative tokens adopted in this request.
# Zero means speculative decoding is disabled for some reason. # None means speculative decoding is not used.
# TODO: We should maintain this state out of the sequence group. # Zero means speculative decoding is disabled for some reason.
self.num_speculative_tokens = None # TODO: We should maintain this state out of the sequence group.
num_speculative_tokens: Optional[int] = None
if seq_data is not None and self._token_chunk_size is None: def __post_init__(self):
if is_prompt: if self.seq_data is not None and self.token_chunk_size is None:
self._token_chunk_size = next(iter( if self.is_prompt:
seq_data.values())).get_len() self.token_chunk_size = next(iter(
self.seq_data.values())).get_len()
else: else:
self._token_chunk_size = 1 self.token_chunk_size = 1
@property @property
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
@ -850,18 +932,26 @@ class SequenceGroupMetadata:
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \ return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
if self.prompt_adapter_request else 0 if self.prompt_adapter_request else 0
@property def apply_delta(self,
def token_chunk_size(self) -> int: sequence_group_metadata_delta: SequenceGroupMetadataDelta):
"""Return the number of tokens to be processed (chunk size).""" for id, delta in sequence_group_metadata_delta.seq_data_delta.items():
assert self._token_chunk_size is not None self.seq_data[id].apply_delta(delta)
return self._token_chunk_size assert self.request_id == sequence_group_metadata_delta.request_id
self.block_tables = sequence_group_metadata_delta.block_tables
self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size
self.do_sample = sequence_group_metadata_delta.do_sample
self.is_prompt = sequence_group_metadata_delta.is_prompt
def finish_step(self) -> None: def finish_step(self) -> None:
assert self.state is not None
assert self.state.current_step < self.state.num_steps assert self.state.current_step < self.state.num_steps
self.state.current_step += 1 self.state.current_step += 1
class SequenceOutput: class SequenceOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""The model output associated with a sequence. """The model output associated with a sequence.
Args: Args:
@ -871,16 +961,9 @@ class SequenceOutput:
logprobs: The logprobs of the output token. logprobs: The logprobs of the output token.
(Token id -> logP(x_i+1 | x_0, ..., x_i)) (Token id -> logP(x_i+1 | x_0, ..., x_i))
""" """
parent_seq_id: int
def __init__( output_token: int
self, logprobs: Dict[int, Logprob]
parent_seq_id: int,
output_token: int,
logprobs: Dict[int, Logprob],
) -> None:
self.parent_seq_id = parent_seq_id
self.output_token = output_token
self.logprobs = logprobs
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
@ -908,17 +991,15 @@ class SequenceGroupOutput(ABC):
pass pass
class CompletionSequenceGroupOutput(SequenceGroupOutput): class CompletionSequenceGroupOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
__metaclass__ = SequenceGroupOutput
"""The model output associated with a completion sequence group.""" """The model output associated with a completion sequence group."""
samples: List[SequenceOutput]
def __init__( # Prompt logprob for each prompt query token.
self, prompt_logprobs: Optional[PromptLogprobs]
samples: List[SequenceOutput],
prompt_logprobs: Optional[PromptLogprobs],
) -> None:
self.samples = samples
# Prompt logprob for each prompt query token.
self.prompt_logprobs = prompt_logprobs
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"CompletionSequenceGroupOutput(samples={self.samples}, " return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
@ -931,14 +1012,14 @@ class CompletionSequenceGroupOutput(SequenceGroupOutput):
and self.prompt_logprobs == other.prompt_logprobs) and self.prompt_logprobs == other.prompt_logprobs)
class EmbeddingSequenceGroupOutput(SequenceGroupOutput): class EmbeddingSequenceGroupOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True, # type: ignore[call-arg]
):
"""The model output associated with an embedding sequence group.""" """The model output associated with an embedding sequence group."""
__metaclass__ = SequenceGroupOutput
def __init__( embeddings: List[int]
self,
embeddings: List[float],
) -> None:
self.embeddings = embeddings
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"EmbeddingSequenceGroupOutput(" return (f"EmbeddingSequenceGroupOutput("
@ -950,8 +1031,10 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
return self.embeddings == other.embeddings return self.embeddings == other.embeddings
@dataclass class IntermediateTensors(
class IntermediateTensors: msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""For all pipeline stages except the last, we need to return the hidden """For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request. contains the hidden states and residuals for a request.
@ -978,8 +1061,10 @@ class IntermediateTensors:
return f"IntermediateTensors(tensors={self.tensors})" return f"IntermediateTensors(tensors={self.tensors})"
@dataclass class SamplerOutput(
class SamplerOutput: msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""For each sequence group, we generate a list of SequenceOutput object, """For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token. each of which contains one possible candidate for the next token.
@ -1000,7 +1085,7 @@ class SamplerOutput:
sampled_token_ids_numpy: Optional[numpy.ndarray] = None sampled_token_ids_numpy: Optional[numpy.ndarray] = None
# Spec decode metrics populated by workers. # Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
# Optional last hidden states from the model. # Optional last hidden states from the model.
hidden_states: Optional[torch.Tensor] = None hidden_states: Optional[torch.Tensor] = None
@ -1039,12 +1124,14 @@ class SamplerOutput:
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
@dataclass class PoolerOutput(
class PoolerOutput: msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""The output from a pooling operation in the embedding model.""" """The output from a pooling operation in the embedding model."""
outputs: List[EmbeddingSequenceGroupOutput] outputs: List[EmbeddingSequenceGroupOutput]
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
def __getitem__(self, idx: int): def __getitem__(self, idx: int):
return self.outputs[idx] return self.outputs[idx]
@ -1083,7 +1170,8 @@ def get_all_seq_ids_and_request_ids(
return seq_ids, request_id_seq_ids_mapping return seq_ids, request_id_seq_ids_mapping
class HiddenStates: class HiddenStates(msgspec.Struct, array_like=True,
omit_defaults=True): # type: ignore[call-arg]
"""Hidden states corresponding to in-progress sequences. """Hidden states corresponding to in-progress sequences.
Used in speculative decoding to pass hidden states from Used in speculative decoding to pass hidden states from
the target model to the proposer model in the subsequent step. the target model to the proposer model in the subsequent step.
@ -1091,42 +1179,53 @@ class HiddenStates:
seq_ids are the sequence ids of each entry of the batch seq_ids are the sequence ids of each entry of the batch
dimension of the hidden_states tensor""" dimension of the hidden_states tensor"""
def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata]
hidden_states: torch.Tensor): hidden_states: torch.Tensor
assert len(seq_group_metadata_list) == len(hidden_states) _seq_ids: List[int] = msgspec.field(default_factory=list)
self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
self.hidden_states: torch.Tensor = hidden_states def __post_init__(self):
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
assert len(self.seq_group_metadata_list) == len(self.hidden_states)
@property
def seq_ids(self) -> List[int]:
return self._seq_ids
def update(self, seq_group_metadata_list: List[SequenceGroupMetadata], def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
hidden_states: torch.Tensor) -> None: hidden_states: torch.Tensor) -> None:
"""Update hidden states from target model invocation.""" """Update hidden states from target model invocation."""
assert len(seq_group_metadata_list) == len(hidden_states) assert len(seq_group_metadata_list) == len(hidden_states)
self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
self.hidden_states = torch.cat([self.hidden_states, hidden_states]) self.hidden_states = torch.cat([self.hidden_states, hidden_states])
def prune(self, def prune(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
"""Prune to provided list of sequence ids.""" """Prune to provided list of sequence ids."""
seq_ids = get_all_seq_ids(seq_group_metadata_list) seq_ids = get_all_seq_ids(seq_group_metadata_list)
if seq_ids != self.seq_ids: if seq_ids != self._seq_ids:
# Batch contents changed - prune removed sequences. # Batch contents changed - prune removed sequences.
index = [self.seq_ids.index(seq_id) for seq_id in seq_ids] index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
self.hidden_states = self.hidden_states[index] self.hidden_states = self.hidden_states[index]
self.seq_ids = seq_ids self._seq_ids = seq_ids
@dataclass class ExecuteModelRequest(
class ExecuteModelRequest: msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True): # type: ignore[call-arg]
"""The model execution request, containing CPU metadata only. The LLM """The model execution request, containing CPU metadata only. The LLM
engine should create an instance of this class for each request batch.""" engine should create an instance of this class for each request batch."""
# The sequence group metadata list. # The sequence group metadata list.
seq_group_metadata_list: List[SequenceGroupMetadata] seq_group_metadata_list: List[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]]
# Blocks to swap in. List of CPU -> GPU block number. # Blocks to swap in. List of CPU -> GPU block number.
blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list) blocks_to_swap_in: List[Tuple[int,
int]] = msgspec.field(default_factory=list)
# Blocks to swap out. List of GPU -> CPU block number. # Blocks to swap out. List of GPU -> CPU block number.
blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list) blocks_to_swap_out: List[Tuple[int,
int]] = msgspec.field(default_factory=list)
# Blocks to copy. Source to dest block. # Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list) blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list)
# Virtual engine ID for pipeline parallel. # Virtual engine ID for pipeline parallel.
virtual_engine: int = 0 virtual_engine: int = 0
# The number of slots for lookahead decoding. # The number of slots for lookahead decoding.
@ -1138,7 +1237,7 @@ class ExecuteModelRequest:
# The number of forward steps to run. # The number of forward steps to run.
num_steps: int = 1 num_steps: int = 1
# Finished request ids since last step. # Finished request ids since last step.
finished_requests_ids: List[str] = field(default_factory=list) finished_requests_ids: List[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding. # The last sampled token ids for multi step decoding.
last_sampled_token_ids: Optional[torch.Tensor] = None last_sampled_token_ids: Optional[torch.Tensor] = None
@ -1148,6 +1247,7 @@ class ExecuteModelRequest:
# steps # steps
assert len(self.seq_group_metadata_list) > 0 assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0] first_seq_group = self.seq_group_metadata_list[0]
assert first_seq_group.state is not None
return first_seq_group.state.current_step == 0 return first_seq_group.state.current_step == 0
@property @property
@ -1156,6 +1256,7 @@ class ExecuteModelRequest:
# steps # steps
assert len(self.seq_group_metadata_list) > 0 assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0] first_seq_group = self.seq_group_metadata_list[0]
assert first_seq_group.state is not None
num_steps = first_seq_group.state.num_steps num_steps = first_seq_group.state.num_steps
current_step = first_seq_group.state.current_step current_step = first_seq_group.state.current_step
return num_steps - current_step == 1 return num_steps - current_step == 1
@ -1165,10 +1266,13 @@ class ExecuteModelRequest:
# TODO(will) make this be able to handle batches with variable number of # TODO(will) make this be able to handle batches with variable number of
# steps # steps
assert len(self.seq_group_metadata_list) > 0 assert len(self.seq_group_metadata_list) > 0
return self.seq_group_metadata_list[0].state.current_step state = self.seq_group_metadata_list[0].state
assert state is not None
return state.current_step
def clone( def clone(
self, seq_group_metadata_list: List[SequenceGroupMetadata] self, seq_group_metadata_list: List[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]]
) -> "ExecuteModelRequest": ) -> "ExecuteModelRequest":
"""Clone the request with a new sequence group metadata list.""" """Clone the request with a new sequence group metadata list."""
return ExecuteModelRequest( return ExecuteModelRequest(

View File

@ -1,11 +1,13 @@
from array import array
from itertools import chain, count from itertools import chain, count
from typing import Iterator, List, Tuple from typing import Iterator, List, Tuple
import torch import torch
from vllm import SamplingParams from vllm import SamplingParams
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest,
SequenceGroupMetadata, get_all_seq_ids) SamplerOutput, SequenceData, SequenceGroupMetadata,
get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores) SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch, from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
@ -293,14 +295,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
input sequence. input sequence.
""" """
seq_data = seq_group_metadata.seq_data[seq_id] seq_data = seq_group_metadata.seq_data[seq_id]
prompt_token_ids = seq_data.get_prompt_token_ids() prompt_token_ids = seq_data.prompt_token_ids_array
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
new_seq_data_dict = { new_seq_data_dict = {
target_seq_id: target_seq_id:
SequenceData( SequenceData(
prompt_token_ids=prompt_token_ids, prompt_token_ids,
output_token_ids=new_output_token_ids, _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
new_output_token_ids),
), ),
} }
# This is a hack. Technically, spec decoding should compute # This is a hack. Technically, spec decoding should compute

View File

@ -1,7 +1,7 @@
import time import time
from dataclasses import dataclass
from typing import Callable, Optional from typing import Callable, Optional
import msgspec
import torch import torch
from vllm.model_executor.layers.spec_decode_base_sampler import ( from vllm.model_executor.layers.spec_decode_base_sampler import (
@ -9,8 +9,10 @@ from vllm.model_executor.layers.spec_decode_base_sampler import (
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
@dataclass class SpecDecodeWorkerMetrics(
class SpecDecodeWorkerMetrics: msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""Dataclass holding metrics emitted from the spec decode worker. """Dataclass holding metrics emitted from the spec decode worker.
""" """

View File

@ -1,7 +1,7 @@
"""A GPU worker class.""" """A GPU worker class."""
import gc import gc
import os import os
from typing import List, Optional, Set, Tuple, Type from typing import Dict, List, Optional, Set, Tuple, Type, Union
import torch import torch
import torch.distributed import torch.distributed
@ -18,7 +18,9 @@ from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput, SequenceGroupMetadata,
SequenceGroupMetadataDelta)
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
@ -109,6 +111,7 @@ class Worker(LocalOrDistributedWorkerBase):
self.cache_engine: List[CacheEngine] self.cache_engine: List[CacheEngine]
# Initialize gpu_cache as embedding models don't initialize kv_caches # Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
def _is_encoder_decoder_model(self): def _is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model return self.model_config.is_encoder_decoder_model
@ -303,6 +306,63 @@ class Worker(LocalOrDistributedWorkerBase):
and worker_input.blocks_to_copy.numel() > 0): and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
def _get_cached_seq_group_metadata(
self,
seq_group_metadata_list: List[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]],
finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
"""Return a list of cached Sequence Group Metadata after updating its
state.
It is used because scheduler only sends delta to workers to reduce
the data payload size. The function also cleans up cache based on
a given `finished_request_ids`.
"""
new_seq_group_metadata_list = []
for metadata_or_delta in seq_group_metadata_list:
request_id = metadata_or_delta.request_id
if request_id not in self._seq_group_metadata_cache:
# The first prefill.
assert isinstance(metadata_or_delta, SequenceGroupMetadata)
self._seq_group_metadata_cache[request_id] = metadata_or_delta
else:
# The first prefill is already cached.
if isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
self._seq_group_metadata_cache[request_id].apply_delta(
metadata_or_delta)
else:
# If a full metadata snapshot is sent again, the
# request was preempted. Reset the cache entry because
# we need to start from scratch.
assert isinstance(metadata_or_delta, SequenceGroupMetadata)
self._seq_group_metadata_cache[
request_id] = metadata_or_delta
new_seq_group_metadata_list.append(
self._seq_group_metadata_cache[request_id])
# Clean up finished ids
for finished_id in finished_request_ids:
del self._seq_group_metadata_cache[finished_id]
return new_seq_group_metadata_list
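To make the protocol concrete: the first snapshot for a request is a full SequenceGroupMetadata, every later step is a SequenceGroupMetadataDelta keyed by the same request_id, and preemption falls back to a fresh snapshot. A minimal sketch of one prefill-then-decode step, assuming a vLLM build that includes this PR (the request id, block numbers, and token values are invented):

from array import array

from vllm.sampling_params import SamplingParams
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
                           SequenceDataDelta, SequenceGroupMetadata,
                           SequenceGroupMetadataDelta, SequenceStage)

# Step 1 (prefill): the worker caches the full snapshot it receives.
cached = SequenceGroupMetadata(
    request_id="req-0",
    is_prompt=True,
    seq_data={0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))},
    sampling_params=SamplingParams(),
    block_tables={0: [0]},
)

# Step 2 (decode): the scheduler ships only a delta for the same request id.
delta = SequenceGroupMetadataDelta(
    seq_data_delta={
        0: SequenceDataDelta(new_output_token_ids=[7],
                             new_cumulative_logprob=-0.1,
                             new_num_computed_tokens=4,
                             new_stage=SequenceStage.DECODE)
    },
    request_id="req-0",
    block_tables={0: [0, 1]},
    is_prompt=False,
    token_chunk_size=1,
)

# The worker patches its cached snapshot in place instead of rebuilding it.
cached.apply_delta(delta)
assert cached.seq_data[0].get_output_token_ids() == (7,)
assert cached.is_prompt is False and cached.token_chunk_size == 1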
def _execute_model_spmd(
self,
execute_model_req: ExecuteModelRequest,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Optional[List[SamplerOutput]]:
if execute_model_req is not None:
new_seq_group_metadata_list = self._get_cached_seq_group_metadata(
execute_model_req.seq_group_metadata_list,
execute_model_req.finished_requests_ids)
execute_model_req.seq_group_metadata_list = (
new_seq_group_metadata_list)
output = super()._execute_model_spmd(execute_model_req,
intermediate_tensors)
return output
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request) return self.model_runner.add_lora(lora_request)