[Core] Optimize SPMD architecture with delta + serialization optimization (#7109)
parent 200a2ffa6b · commit ff7ec82c4d
@ -21,6 +21,7 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
msgspec
librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1
@ -8,6 +8,7 @@ pytest tests/basic_correctness/test_preemption.py`.
|
||||
import pytest
|
||||
from prometheus_client import REGISTRY
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import SamplingParams
|
||||
from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
|
||||
ENABLE_ARTIFICIAL_PREEMPT)
|
||||
@ -24,6 +25,13 @@ assert ENABLE_ARTIFICIAL_PREEMPT is True, (
|
||||
"tests/basic_correctness/test_preemption.py`")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def worker_use_ray() -> bool:
|
||||
# When the SPMD worker is used, set worker_use_ray=True
# to test that the delta input optimization works with preemption.
|
||||
return envs.VLLM_USE_RAY_SPMD_WORKER
|
||||
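As a rough sketch of how the fixture above picks up the SPMD setting: the Ray SPMD worker (and the compiled-DAG path it relies on) is toggled through environment variables before the engine is built. The exact setup step is an assumption, but both variable names appear later in this diff.

import os

# Assumed setup: both variables must be set before the engine is constructed
# so that envs.VLLM_USE_RAY_SPMD_WORKER evaluates to True in the fixture.
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"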
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [96])
|
||||
@ -36,6 +44,7 @@ def test_chunked_prefill_recompute(
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
chunked_prefill_token_size: int,
|
||||
worker_use_ray: bool,
|
||||
) -> None:
|
||||
"""Ensure that chunked prefill works with preemption."""
|
||||
max_num_seqs = min(chunked_prefill_token_size, 256)
|
||||
@ -54,6 +63,7 @@ def test_chunked_prefill_recompute(
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_seqs=max_num_seqs,
|
||||
worker_use_ray=worker_use_ray,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
|
||||
@ -79,6 +89,7 @@ def test_preemption(
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
worker_use_ray: bool,
|
||||
) -> None:
|
||||
"""By default, recompute preemption is enabled"""
|
||||
|
||||
@ -89,6 +100,7 @@ def test_preemption(
|
||||
model,
|
||||
dtype=dtype,
|
||||
disable_log_stats=False,
|
||||
worker_use_ray=worker_use_ray,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
|
||||
@ -132,6 +144,7 @@ def test_swap(
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
beam_width: int,
|
||||
worker_use_ray: bool,
|
||||
) -> None:
|
||||
"""Use beam search enables swapping."""
|
||||
example_prompts = example_prompts[:1]
|
||||
@ -144,6 +157,7 @@ def test_swap(
|
||||
dtype=dtype,
|
||||
swap_space=10,
|
||||
disable_log_stats=False,
|
||||
worker_use_ray=worker_use_ray,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
|
||||
beam_width, max_tokens)
|
||||
@ -188,6 +202,7 @@ def test_swap_infeasible(
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
beam_width: int,
|
||||
worker_use_ray: bool,
|
||||
) -> None:
|
||||
"""Verify infeasible swap request will be ignored."""
|
||||
BLOCK_SIZE = 16
|
||||
@ -204,6 +219,7 @@ def test_swap_infeasible(
|
||||
# decode blocks are not enough to finish.
|
||||
num_gpu_blocks_override=prefill_blocks + decode_blocks,
|
||||
max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
|
||||
worker_use_ray=worker_use_ray,
|
||||
) as vllm_model:
|
||||
sampling_params = SamplingParams(n=beam_width,
|
||||
use_beam_search=True,
|
||||
@ -230,6 +246,7 @@ def test_preemption_infeasible(
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
worker_use_ray: bool,
|
||||
) -> None:
|
||||
"""Verify infeasible preemption request will be ignored."""
|
||||
BLOCK_SIZE = 16
|
||||
@ -244,6 +261,7 @@ def test_preemption_infeasible(
|
||||
# ignored instead of hanging forever.
|
||||
num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
|
||||
max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
|
||||
worker_use_ray=worker_use_ray,
|
||||
) as vllm_model:
|
||||
sampling_params = SamplingParams(max_tokens=max_tokens,
|
||||
ignore_eos=True)
|
||||
|
tests/core/test_serialization.py (new file)
@ -0,0 +1,33 @@
import msgspec

from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.sequence import ExecuteModelRequest

from ..spec_decode.utils import create_batch


def test_msgspec_serialization():
    num_lookahead_slots = 4
    seq_group_metadata_list, _, _ = create_batch(16, num_lookahead_slots)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=num_lookahead_slots,
        running_queue_size=4)

    encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
    decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
                                      dec_hook=decode_hook)
    req = decoder.decode(encoder.encode(execute_model_req))
    expected = execute_model_req.seq_group_metadata_list
    actual = req.seq_group_metadata_list
    assert (len(expected) == len(actual))
    expected = expected[0]
    actual = actual[0]

    assert expected.block_tables == actual.block_tables
    assert expected.is_prompt == actual.is_prompt
    assert expected.request_id == actual.request_id
    assert (expected.seq_data[0].prompt_token_ids ==
            actual.seq_data[0].prompt_token_ids)
    assert (expected.seq_data[0].output_token_ids ==
            actual.seq_data[0].output_token_ids)
@ -22,7 +22,8 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
|
||||
@pytest.mark.skipif(cuda_device_count_stateless() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.parametrize(
|
||||
"model, distributed_executor_backend, attention_backend, test_suite", [
|
||||
"model, distributed_executor_backend, attention_backend, "
|
||||
"test_suite", [
|
||||
("facebook/opt-125m", "ray", "", "L4"),
|
||||
("facebook/opt-125m", "mp", "", "L4"),
|
||||
("meta-llama/Llama-2-7b-hf", "ray", "", "L4"),
|
||||
|
@ -6,6 +6,8 @@ pytest test_chunked_prefill_distributed.py
|
||||
```
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.utils import cuda_device_count_stateless
|
||||
@ -30,6 +32,11 @@ def test_models(
|
||||
model: str,
|
||||
distributed_executor_backend: str,
|
||||
) -> None:
|
||||
if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray": # noqa
|
||||
assert distributed_executor_backend == "ray"
|
||||
# test ray adag
|
||||
os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1"
|
||||
os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1"
|
||||
|
||||
dtype = "half"
|
||||
max_tokens = 5
|
||||
|
@ -1,5 +1,6 @@
|
||||
import itertools
|
||||
import random
|
||||
from array import array
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@ -10,7 +11,8 @@ from transformers import GenerationConfig, GenerationMixin
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
|
||||
SequenceData, SequenceGroupMetadata)
|
||||
from vllm.utils import Counter, is_pin_memory_available
|
||||
|
||||
|
||||
@ -56,7 +58,9 @@ def _do_sample(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
seq_data={
|
||||
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
|
||||
},
|
||||
sampling_params=sampling_params,
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
@ -201,7 +205,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
|
||||
def create_sequence_data(num_input=3, num_generated=0):
|
||||
seq_data = SequenceData(
|
||||
random.choices(range(0, VOCAB_SIZE), k=num_input))
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
random.choices(range(0, VOCAB_SIZE), k=num_input)))
|
||||
if num_generated > 0:
|
||||
seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
|
||||
k=num_generated)
|
||||
@ -504,7 +509,9 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
seq_data={
|
||||
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
|
||||
},
|
||||
sampling_params=sampling_params,
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
@ -600,7 +607,9 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
seq_data={
|
||||
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
|
||||
},
|
||||
sampling_params=SamplingParams(
|
||||
temperature=1,
|
||||
top_k=top_k,
|
||||
@ -650,7 +659,11 @@ def test_sampler_repetition_penalty_mixed(device: str):
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
seq_data={
|
||||
0:
|
||||
SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[1, 2, 3]))
|
||||
},
|
||||
sampling_params=sampling_params[i],
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
|
@ -1,3 +1,4 @@
|
||||
from array import array
|
||||
from itertools import count
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
@ -9,7 +10,8 @@ import torch
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
CompletionSequenceGroupOutput, Logprob,
|
||||
SamplerOutput, SequenceData, SequenceGroupMetadata,
|
||||
SequenceOutput)
|
||||
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
@ -138,8 +140,9 @@ def create_seq_group_metadata_from_prompts(
|
||||
seq_data={
|
||||
i:
|
||||
SequenceData(
|
||||
prompt_token_ids=prompt_token_ids[:],
|
||||
output_token_ids=cont_token_ids[:],
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]),
|
||||
_output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
cont_token_ids[:]),
|
||||
),
|
||||
},
|
||||
sampling_params=SamplingParams(temperature=0.0, ),
|
||||
|
@ -1,4 +1,5 @@
|
||||
import random
|
||||
from array import array
|
||||
from typing import Tuple
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -8,7 +9,8 @@ import torch
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
|
||||
SequenceData, SequenceGroupMetadata)
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
@ -69,7 +71,9 @@ def test_logits_processors(seed: int, device: str):
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
seq_data={
|
||||
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
|
||||
},
|
||||
sampling_params=SamplingParams(temperature=0,
|
||||
logits_processors=[pick_ith]),
|
||||
block_tables={0: [1]},
|
||||
|
@ -1,6 +1,9 @@
|
||||
from array import array
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput,
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
CompletionSequenceGroupOutput, SamplerOutput,
|
||||
SequenceData, SequenceOutput)
|
||||
|
||||
from .core.utils import create_dummy_prompt
|
||||
@ -54,7 +57,7 @@ def test_sampler_output_eq(sample_outputs):
|
||||
|
||||
|
||||
def test_sequence_data_prefill():
|
||||
seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
|
||||
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4]))
|
||||
assert seq_data.get_num_uncomputed_tokens() == 4
|
||||
assert seq_data.get_num_computed_tokens() == 0
|
||||
# advance by 2
|
||||
|
@ -1,10 +1,12 @@
|
||||
from array import array
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
|
||||
SequenceData, SequenceGroupMetadata)
|
||||
from vllm.utils import is_cpu
|
||||
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
|
||||
|
||||
@ -125,10 +127,12 @@ def test_prepare_prompt(
|
||||
# make sure all tokens fit into one block
|
||||
seq_len = i % (model_runner.block_size - 1) + 1
|
||||
seq_lens.append(seq_len)
|
||||
seq_data = SequenceData(list(range(seq_len)))
|
||||
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
range(seq_len)))
|
||||
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
|
||||
encoder_seq_lens.append(encoder_seq_len)
|
||||
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
|
||||
encoder_seq_data = SequenceData(
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len)))
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
@ -319,10 +323,12 @@ def test_prepare_decode(
|
||||
# make sure all tokens fit into one block
|
||||
seq_len = i % (model_runner.block_size - 1) + 1
|
||||
seq_lens.append(seq_len)
|
||||
seq_data = SequenceData(list(range(seq_len)))
|
||||
seq_data = SequenceData(
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
|
||||
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
|
||||
encoder_seq_lens.append(encoder_seq_len)
|
||||
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
|
||||
encoder_seq_data = SequenceData(
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=False,
|
||||
|
@ -1,3 +1,4 @@
|
||||
from array import array
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
@ -7,7 +8,8 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
|
||||
SequenceData, SequenceGroupMetadata)
|
||||
from vllm.utils import get_open_port
|
||||
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
|
||||
|
||||
@ -46,7 +48,8 @@ def test_prepare_prompt(batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
seq_len = i % (model_runner.block_size - 1) + 1
|
||||
seq_lens.append(seq_len)
|
||||
seq_data = SequenceData(list(range(seq_len)))
|
||||
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
range(seq_len)))
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
@ -163,7 +166,8 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
context_len = i % (model_runner.block_size - 1) + 1
|
||||
context_lens.append(context_len)
|
||||
seq_data = SequenceData(list(range(context_len)))
|
||||
seq_data = SequenceData(
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)))
|
||||
seq_data.update_num_computed_tokens(context_len)
|
||||
# Append one token ID since prefill is finished.
|
||||
seq_data.append_token_id(1, 0)
|
||||
@ -324,7 +328,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
||||
# make sure all tokens fit into one block
|
||||
seq_len = i % (model_runner.block_size - 1) + 1
|
||||
seq_lens.append(seq_len)
|
||||
seq_data = SequenceData(list(range(seq_len)))
|
||||
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
range(seq_len)))
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
@ -340,7 +345,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
||||
for i in range(prefill_batch_size, batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
context_len = i % (model_runner.block_size - 1) + 1
|
||||
prompt_toks = list(range(context_len))
|
||||
prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))
|
||||
seq_data = SequenceData(prompt_toks)
|
||||
seq_data.append_token_id(1, 0)
|
||||
seq_data.update_num_computed_tokens(context_len)
|
||||
|
@ -1,8 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdapterRequest(ABC):
|
||||
"""
|
||||
Base class for adapter requests.
|
||||
|
@ -770,8 +770,8 @@ class ParallelConfig:
|
||||
self.tokenizer_pool_config = tokenizer_pool_config
|
||||
self.ray_workers_use_nsight = ray_workers_use_nsight
|
||||
self.placement_group = placement_group
|
||||
|
||||
self.world_size = pipeline_parallel_size * self.tensor_parallel_size
|
||||
|
||||
if worker_use_ray:
|
||||
if self.distributed_executor_backend is None:
|
||||
self.distributed_executor_backend = "ray"
|
||||
@ -867,6 +867,11 @@ class SchedulerConfig:
|
||||
swapping. However, when the sequence group has multiple sequences
|
||||
(e.g., beam search), recomputation is not currently supported. In
|
||||
such a case, we use swapping instead.
|
||||
send_delta_data: Private API. If used, the scheduler sends delta data to
workers instead of the entire data. It should be enabled only
when the SPMD worker architecture is enabled, i.e.,
VLLM_USE_RAY_SPMD_WORKER=1.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -879,7 +884,8 @@ class SchedulerConfig:
|
||||
enable_chunked_prefill: bool = False,
|
||||
embedding_mode: Optional[bool] = False,
|
||||
preemption_mode: Optional[str] = None,
|
||||
num_scheduler_steps: int = 1) -> None:
|
||||
num_scheduler_steps: int = 1,
|
||||
send_delta_data: bool = False) -> None:
|
||||
if max_num_batched_tokens is not None:
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
else:
|
||||
@ -909,6 +915,7 @@ class SchedulerConfig:
|
||||
self.embedding_mode = embedding_mode
|
||||
self.preemption_mode = preemption_mode
|
||||
self.num_scheduler_steps = num_scheduler_steps
|
||||
self.send_delta_data = send_delta_data
|
||||
self._verify_args()
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
|
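A minimal sketch of how the new flag might be set when constructing the config directly; the numeric values are hypothetical, and in the engine the flag is actually derived from VLLM_USE_RAY_SPMD_WORKER together with the Ray backend (see the EngineArgs change further below).

import vllm.envs as envs
from vllm.config import SchedulerConfig

scheduler_config = SchedulerConfig(
    max_num_batched_tokens=4096,  # hypothetical values
    max_num_seqs=256,
    max_model_len=4096,
    send_delta_data=bool(envs.VLLM_USE_RAY_SPMD_WORKER),
)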
@ -12,7 +12,8 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceStatus)
|
||||
SequenceGroupMetadata, SequenceGroupMetadataDelta,
|
||||
SequenceStatus)
|
||||
from vllm.utils import PyObjectCache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -363,8 +364,6 @@ class Scheduler:
|
||||
self.num_cumulative_preemption: int = 0
|
||||
|
||||
# Used to cache python objects
|
||||
self._seq_group_metadata_cache: PyObjectCache = PyObjectCache(
|
||||
seq_group_metadata_builder)
|
||||
self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache(
|
||||
scheduler_running_outputs_builder)
|
||||
self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache(
|
||||
@ -1048,15 +1047,10 @@ class Scheduler:
|
||||
token_chunk_size = scheduled_seq_group.token_chunk_size
|
||||
seq_group.maybe_set_first_scheduled_time(now)
|
||||
|
||||
seq_group_metadata = self._seq_group_metadata_cache.get_object()
|
||||
seq_group_metadata.seq_data.clear()
|
||||
seq_group_metadata.block_tables.clear()
|
||||
|
||||
# seq_id -> SequenceData
|
||||
seq_data: Dict[int, SequenceData] = seq_group_metadata.seq_data
|
||||
seq_data: Dict[int, SequenceData] = {}
|
||||
# seq_id -> physical block numbers
|
||||
block_tables: Dict[int,
|
||||
List[int]] = seq_group_metadata.block_tables
|
||||
block_tables: Dict[int, List[int]] = {}
|
||||
|
||||
if seq_group.is_encoder_decoder():
|
||||
# Encoder associated with SequenceGroup
|
||||
@ -1081,24 +1075,29 @@ class Scheduler:
|
||||
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
|
||||
|
||||
do_sample = True
|
||||
if seq_group.is_prefill():
|
||||
is_prompt = seq_group.is_prefill()
|
||||
# We should send the metadata to workers when the first prefill
|
||||
# is sent. Subsequent requests could be chunked prefill or decode.
|
||||
is_first_prefill = False
|
||||
if is_prompt:
|
||||
seqs = seq_group.get_seqs()
|
||||
# Prefill has only 1 sequence.
|
||||
assert len(seqs) == 1
|
||||
num_computed_tokens = seqs[0].data.get_num_computed_tokens()
|
||||
is_first_prefill = num_computed_tokens == 0
|
||||
# In the next iteration, all prompt tokens are not computed.
|
||||
# It means the prefill is chunked, and we don't need sampling.
|
||||
# NOTE: We use get_len instead of get_prompt_len because when
|
||||
# a sequence is preempted, prefill includes previous generated
|
||||
# output tokens.
|
||||
if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
|
||||
if (token_chunk_size + num_computed_tokens <
|
||||
seqs[0].data.get_len()):
|
||||
do_sample = False
|
||||
|
||||
# It assumes the scheduled_seq_groups is ordered by
|
||||
# prefill < decoding.
|
||||
is_prompt = seq_group.is_prefill()
|
||||
|
||||
seq_group_metadata.__init__(
|
||||
if is_first_prefill or not self.scheduler_config.send_delta_data:
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=seq_group.request_id,
|
||||
is_prompt=is_prompt,
|
||||
seq_data=seq_data,
|
||||
@ -1120,6 +1119,21 @@ class Scheduler:
|
||||
if scheduler_outputs.num_prefill_groups > 0 else None,
|
||||
prompt_adapter_request=seq_group.prompt_adapter_request,
|
||||
)
|
||||
else:
|
||||
# When SPMD mode is enabled, we only send delta data except for
|
||||
# the first request to reduce serialization cost.
|
||||
seq_data_delta = {}
|
||||
for id, data in seq_data.items():
|
||||
seq_data_delta[id] = data.get_delta_and_reset()
|
||||
seq_group_metadata = SequenceGroupMetadataDelta(
|
||||
seq_data_delta,
|
||||
seq_group.request_id,
|
||||
block_tables,
|
||||
is_prompt,
|
||||
do_sample=do_sample,
|
||||
token_chunk_size=token_chunk_size,
|
||||
computed_block_nums=common_computed_block_nums,
|
||||
)
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
# Now that the batch has been created, we can assume all blocks in the
|
||||
@ -1130,8 +1144,6 @@ class Scheduler:
|
||||
self.block_manager.mark_blocks_as_computed(
|
||||
scheduled_seq_group.seq_group)
|
||||
|
||||
self._seq_group_metadata_cache.reset()
|
||||
|
||||
scheduler_time = time.perf_counter() - scheduler_start_time
|
||||
# Add this to scheduler time to all the sequences that are currently
|
||||
# running. This will help estimate if the scheduler is a significant
|
||||
|
@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
|
||||
Union)
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
|
||||
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
|
||||
ObservabilityConfig, ParallelConfig,
|
||||
@ -905,6 +906,8 @@ class EngineArgs:
|
||||
embedding_mode=model_config.embedding_mode,
|
||||
preemption_mode=self.preemption_mode,
|
||||
num_scheduler_steps=self.num_scheduler_steps,
|
||||
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
|
||||
and parallel_config.use_ray),
|
||||
)
|
||||
lora_config = LoRAConfig(
|
||||
max_lora_rank=self.max_lora_rank,
|
||||
|
@ -224,7 +224,6 @@ class LLMEngine:
|
||||
cache_config.enable_prefix_caching,
|
||||
)
|
||||
# TODO(woosuk): Print more configs in debug mode.
|
||||
|
||||
from vllm.plugins import load_general_plugins
|
||||
load_general_plugins()
|
||||
|
||||
|
vllm/executor/msgspec_utils.py (new file)
@ -0,0 +1,27 @@
from array import array
from typing import Any, Type

from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE


def encode_hook(obj: Any) -> Any:
    """Custom msgspec enc hook that supports array types.

    See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
    """
    if isinstance(obj, array):
        assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, (
            f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
            f"Given array has a type code of {obj.typecode}.")
        return obj.tobytes()


def decode_hook(type: Type, obj: Any) -> Any:
    """Custom msgspec dec hook that supports array types.

    See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
    """
    if type is array:
        deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE)
        deserialized.frombytes(obj)
        return deserialized
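To illustrate the hooks in isolation, a token array can be round-tripped through msgpack on its own; this mirrors what happens to each SequenceData buffer inside a serialized ExecuteModelRequest (a minimal sketch, assuming a vLLM build that exposes these helpers).

from array import array

import msgspec

from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE

encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
decoder = msgspec.msgpack.Decoder(array, dec_hook=decode_hook)

token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4])
# encode_hook packs the array as raw bytes; decode_hook rebuilds it.
assert decoder.decode(encoder.encode(token_ids)) == token_ids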
@ -4,9 +4,12 @@ from collections import defaultdict
|
||||
from itertools import islice, repeat
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import msgspec
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
|
||||
DistributedGPUExecutor, DistributedGPUExecutorAsync)
|
||||
from vllm.executor.msgspec_utils import encode_hook
|
||||
from vllm.executor.ray_utils import RayWorkerWrapper, ray
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
@ -60,6 +63,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
# Create the parallel GPU workers.
|
||||
self._init_workers_ray(placement_group)
|
||||
|
||||
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
|
||||
self.output_decoder = msgspec.msgpack.Decoder(
|
||||
Optional[List[SamplerOutput]])
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if hasattr(self, "forward_dag") and self.forward_dag is not None:
|
||||
self.forward_dag.teardown()
|
||||
@ -123,6 +130,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
ray_remote_kwargs)
|
||||
|
||||
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
|
||||
|
||||
# Create the workers.
|
||||
driver_ip = get_ip()
|
||||
worker_wrapper_kwargs = self._get_worker_wrapper_args()
|
||||
@ -304,8 +312,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
if self.forward_dag is None:
|
||||
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
|
||||
|
||||
outputs = ray.get(self.forward_dag.execute(execute_model_req))
|
||||
return outputs[0]
|
||||
serialized_data = self.input_encoder.encode(execute_model_req)
|
||||
outputs = ray.get(self.forward_dag.execute(serialized_data))
|
||||
output = self.output_decoder.decode(outputs[0])
|
||||
return output
|
||||
|
||||
def _run_workers(
|
||||
self,
|
||||
@ -475,9 +485,10 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
|
||||
if self.forward_dag is None:
|
||||
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
|
||||
|
||||
dag_future = await self.forward_dag.execute_async(execute_model_req)
|
||||
serialized_data = self.input_encoder.encode(execute_model_req)
|
||||
dag_future = await self.forward_dag.execute_async(serialized_data)
|
||||
outputs = await dag_future
|
||||
return outputs[0]
|
||||
return self.output_decoder.decode(outputs[0])
|
||||
|
||||
async def _driver_execute_model_async(
|
||||
self,
|
||||
|
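For context, the encoder/decoder pair added above mirrors the one added to the Ray worker wrapper in the next hunk, so requests and sampler outputs cross the compiled-DAG boundary as msgpack bytes. A condensed sketch of the pairing, using only types that appear in this diff:

from typing import List, Optional

import msgspec

from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.sequence import ExecuteModelRequest, SamplerOutput

# Driver side (RayGPUExecutor): encode requests, decode sampler outputs.
input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
output_decoder = msgspec.msgpack.Decoder(Optional[List[SamplerOutput]])

# Worker side (RayWorkerWrapper): the mirror image.
input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
                                        dec_hook=decode_hook)
output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)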
@ -1,6 +1,9 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import msgspec
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.executor.msgspec_utils import decode_hook, encode_hook
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
||||
@ -24,6 +27,10 @@ try:
|
||||
# that thread.
|
||||
self.compiled_dag_cuda_device_set = False
|
||||
|
||||
self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
|
||||
dec_hook=decode_hook)
|
||||
self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
|
||||
|
||||
def get_node_ip(self) -> str:
|
||||
return get_ip()
|
||||
|
||||
@ -33,16 +40,26 @@ try:
|
||||
return node_id, gpu_ids
|
||||
|
||||
def execute_model_spmd(
|
||||
self, req_or_tuple: Union[ExecuteModelRequest,
|
||||
Tuple[ExecuteModelRequest,
|
||||
IntermediateTensors]]):
|
||||
self, req_or_tuple: Union[bytes,
|
||||
Tuple[bytes,
|
||||
Optional[IntermediateTensors]]]
|
||||
) -> bytes:
|
||||
"""Execute model in SPMD fashion: used only when SPMD worker and
|
||||
compiled DAG are both enabled.
|
||||
|
||||
Args:
|
||||
req_or_tuple: The request to execute the model, or a tuple
|
||||
containing the request and intermediate tensors.
|
||||
req_or_tuple: A serialized request, or a tuple containing the
serialized request and intermediate tensors. Intermediate tensors
are None unless they are provided by a previous (> 0) pipeline
stage. The request is serialized by msgspec.
|
||||
"""
|
||||
if isinstance(req_or_tuple, bytes):
|
||||
serialized_req, intermediate_tensors = req_or_tuple, None
|
||||
else:
|
||||
serialized_req, intermediate_tensors = req_or_tuple
|
||||
|
||||
execute_model_req = self.input_decoder.decode(serialized_req)
|
||||
|
||||
# TODO(swang): This is needed right now because Ray aDAG executes
|
||||
# on a background thread, so we need to reset torch's current
|
||||
# device.
|
||||
@ -51,16 +68,14 @@ try:
|
||||
torch.cuda.set_device(self.worker.device)
|
||||
self.compiled_dag_cuda_device_set = True
|
||||
|
||||
if isinstance(req_or_tuple, tuple):
|
||||
execute_model_req, intermediate_tensors = req_or_tuple
|
||||
else:
|
||||
execute_model_req = req_or_tuple
|
||||
intermediate_tensors = None
|
||||
|
||||
output = self.worker._execute_model_spmd(execute_model_req,
|
||||
intermediate_tensors)
|
||||
# Pipeline model request and output to the next pipeline stage.
|
||||
if isinstance(output, IntermediateTensors):
|
||||
return execute_model_req, output
|
||||
output = serialized_req, output
|
||||
else:
|
||||
output = self.output_encoder.encode(output)
|
||||
|
||||
return output
|
||||
|
||||
ray_import_err = None
|
||||
|
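The bytes-vs-tuple handling above can be summarized as a small helper. This function is hypothetical (it does not exist in the diff); it only restates the dispatch for clarity: a bare bytes payload means the first pipeline stage, while a tuple also carries intermediate tensors from the previous stage.

from typing import Optional, Tuple, Union

from vllm.sequence import IntermediateTensors


def split_spmd_input(
    req_or_tuple: Union[bytes, Tuple[bytes, Optional[IntermediateTensors]]]
) -> Tuple[bytes, Optional[IntermediateTensors]]:
    """Hypothetical helper mirroring the dispatch in execute_model_spmd."""
    if isinstance(req_or_tuple, bytes):
        return req_or_tuple, None
    return req_or_tuple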
@ -1,4 +1,5 @@
|
||||
import functools
|
||||
from array import array
|
||||
from collections import UserDict
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol,
|
||||
@ -21,6 +22,10 @@ logger = init_logger(__name__)
|
||||
|
||||
C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
|
||||
|
||||
# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
|
||||
# We cannot import it here because of circular dependencies.
|
||||
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InputContext:
|
||||
@ -118,7 +123,8 @@ class InputRegistry:
|
||||
# Avoid circular import
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
dummy_seq_data = SequenceData([0] * seq_len)
|
||||
dummy_seq_data = SequenceData(
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)
|
||||
dummy_multi_modal_data = None
|
||||
|
||||
return dummy_seq_data, dummy_multi_modal_data
|
||||
|
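The same pattern appears in each model's dummy-data helper changed below: token id buffers are now built as array("l") values rather than Python lists before being handed to SequenceData. A minimal sketch:

from array import array

from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

seq_len = 8  # hypothetical dummy length
dummy_seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)
assert dummy_seq_data.prompt_token_ids == (0,) * seq_len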
@ -1,12 +1,15 @@
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import msgspec
|
||||
|
||||
from vllm.adapter_commons.request import AdapterRequest
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRARequest(AdapterRequest):
|
||||
class LoRARequest(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
array_like=True): # type: ignore[call-arg]
|
||||
"""
|
||||
Request for a LoRA adapter.
|
||||
|
||||
@ -18,16 +21,17 @@ class LoRARequest(AdapterRequest):
|
||||
lora_int_id must be globally unique for a given adapter.
|
||||
This is currently not enforced in vLLM.
|
||||
"""
|
||||
__metaclass__ = AdapterRequest
|
||||
|
||||
lora_name: str
|
||||
lora_int_id: int
|
||||
lora_path: str = ""
|
||||
lora_local_path: Optional[str] = field(default=None, repr=False)
|
||||
lora_local_path: Optional[str] = msgspec.field(default=None)
|
||||
long_lora_max_len: Optional[int] = None
|
||||
__hash__ = AdapterRequest.__hash__
|
||||
|
||||
def __post_init__(self):
|
||||
if 'lora_local_path' in self.__dict__:
|
||||
if 'lora_local_path' in self.__struct_fields__:
|
||||
warnings.warn(
|
||||
"The 'lora_local_path' attribute is deprecated "
|
||||
"and will be removed in a future version. "
|
||||
|
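Because LoRARequest is now a msgspec.Struct, it can be shipped to workers without pickle. A minimal round-trip sketch with hypothetical adapter values:

import msgspec

from vllm.lora.request import LoRARequest

lora_req = LoRARequest(lora_name="my-adapter",  # hypothetical adapter
                       lora_int_id=1,
                       lora_path="/path/to/adapter")
data = msgspec.msgpack.encode(lora_req)
assert msgspec.msgpack.decode(data, type=LoRARequest) == lora_req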
@ -1,5 +1,6 @@
|
||||
"""Minimal implementation of BlipVisionModel intended to be only used
|
||||
within a vision language model."""
|
||||
from array import array
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
@ -16,7 +17,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.multimodal.image import (cached_get_tokenizer,
|
||||
repeat_and_pad_image_tokens)
|
||||
from vllm.sequence import SequenceData
|
||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||
|
||||
|
||||
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
@ -53,8 +54,10 @@ def dummy_seq_data_for_blip(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
token_ids = [image_token_id] * image_feature_size
|
||||
token_ids += [0] * (seq_len - image_feature_size)
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[image_token_id]) * image_feature_size
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0]) * (seq_len - image_feature_size)
|
||||
return SequenceData(token_ids)
|
||||
|
||||
|
||||
|
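The helper above relies on the fact that array repetition and concatenation behave like their list counterparts while staying a compact C buffer. A small self-contained check with illustrative values:

from array import array

from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE

image_token_id, image_feature_size, seq_len = 32000, 3, 8  # illustrative
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                  [image_token_id]) * image_feature_size
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
                   [0]) * (seq_len - image_feature_size)
assert list(token_ids) == [32000, 32000, 32000, 0, 0, 0, 0, 0]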
@ -1,3 +1,4 @@
|
||||
from array import array
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
@ -17,7 +18,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.opt import OPTModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SamplerOutput, SequenceData)
|
||||
|
||||
from .blip import (BlipVisionModel, dummy_image_for_blip,
|
||||
get_max_blip_image_tokens)
|
||||
@ -427,8 +429,10 @@ def dummy_seq_data_for_blip2(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
token_ids = [image_token_id] * image_feature_size * num_images
|
||||
token_ids += [0] * (seq_len - image_feature_size * num_images)
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[image_token_id]) * image_feature_size * num_images
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0]) * (seq_len - image_feature_size * num_images)
|
||||
return SequenceData(token_ids)
|
||||
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
from array import array
|
||||
from functools import cached_property
|
||||
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
|
||||
Tuple, TypedDict)
|
||||
@ -31,7 +32,8 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import (cached_get_tokenizer,
|
||||
repeat_and_pad_image_tokens)
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SamplerOutput, SequenceData)
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
from .interfaces import SupportsMultiModal
|
||||
@ -70,8 +72,10 @@ def dummy_seq_data_for_chameleon(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
token_ids = [image_token_id] * image_feature_size * num_images
|
||||
token_ids += [0] * (seq_len - image_feature_size * num_images)
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[image_token_id]) * image_feature_size * num_images
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0]) * (seq_len - image_feature_size * num_images)
|
||||
return SequenceData(token_ids)
|
||||
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Minimal implementation of CLIPVisionModel intended to be only used
|
||||
within a vision language model."""
|
||||
from array import array
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@ -17,7 +18,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal.image import (cached_get_tokenizer,
|
||||
repeat_and_pad_image_tokens)
|
||||
from vllm.sequence import SequenceData
|
||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||
|
||||
|
||||
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
@ -53,8 +54,10 @@ def dummy_seq_data_for_clip(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
token_ids = [image_token_id] * image_feature_size * num_images
|
||||
token_ids += [0] * (seq_len - image_feature_size * num_images)
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[image_token_id]) * image_feature_size * num_images
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0]) * (seq_len - image_feature_size * num_images)
|
||||
return SequenceData(token_ids)
|
||||
|
||||
|
||||
|
@ -16,6 +16,7 @@
|
||||
# limitations under the License.
|
||||
""" PyTorch Fuyu model."""
|
||||
import math
|
||||
from array import array
|
||||
from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict
|
||||
|
||||
import torch
|
||||
@ -37,7 +38,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import MultiModalInputs
|
||||
from vllm.multimodal.image import (cached_get_image_processor,
|
||||
cached_get_tokenizer)
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SamplerOutput, SequenceData)
|
||||
|
||||
from .interfaces import SupportsMultiModal
|
||||
from .utils import merge_multimodal_embeddings
|
||||
@ -97,9 +99,12 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
|
||||
ncol, nrow = get_max_fuyu_image_feature_size()
|
||||
image_feature_size = get_max_fuyu_image_tokens(ctx)
|
||||
|
||||
image_token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow
|
||||
token_ids = image_token_ids * num_images
|
||||
token_ids += [0] * (seq_len - image_feature_size * num_images)
|
||||
image_token_ids = (
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * ncol +
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])) * nrow
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0]) * (seq_len - image_feature_size * num_images)
|
||||
return SequenceData(token_ids)
|
||||
|
||||
|
||||
|
@ -23,6 +23,7 @@
|
||||
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
|
||||
import math
|
||||
import re
|
||||
from array import array
|
||||
from functools import partial
|
||||
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
|
||||
TypedDict, Union)
|
||||
@ -55,7 +56,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import (cached_get_image_processor,
|
||||
cached_get_tokenizer)
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SamplerOutput, SequenceData)
|
||||
|
||||
from .idefics2_vision_model import Idefics2VisionTransformer
|
||||
|
||||
@ -408,7 +410,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
|
||||
|
||||
|
||||
def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
|
||||
token_ids = [0] * seq_len
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len
|
||||
return SequenceData(token_ids)
|
||||
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
within a vision language model."""
|
||||
|
||||
import math
|
||||
from array import array
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@ -25,7 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal.image import (cached_get_tokenizer,
|
||||
repeat_and_pad_image_tokens)
|
||||
from vllm.sequence import SequenceData
|
||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||
|
||||
|
||||
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
@ -62,8 +63,10 @@ def dummy_seq_data_for_siglip(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
token_ids = [image_token_id] * image_feature_size * num_images
|
||||
token_ids += [0] * (seq_len - image_feature_size * num_images)
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[image_token_id]) * image_feature_size
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0]) * (seq_len - image_feature_size)
|
||||
return SequenceData(token_ids)
|
||||
|
||||
|
||||
|
@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.sequence import SequenceData, SequenceGroupMetadata
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.triton_utils.sample import get_num_triton_sampler_splits
|
||||
from vllm.utils import (PyObjectCache, async_tensor_h2d,
|
||||
is_pin_memory_available, make_tensor_with_pad,
|
||||
@ -505,9 +506,11 @@ class SamplingTensors:
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
prefill_len = len(seq_group.prompt_logprob_indices)
|
||||
prompt_tokens.extend(
|
||||
array('l') for _ in range(prefill_len))
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE)
|
||||
for _ in range(prefill_len))
|
||||
output_tokens.extend(
|
||||
array('l') for _ in range(prefill_len))
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE)
|
||||
for _ in range(prefill_len))
|
||||
if seq_group.do_sample:
|
||||
for seq_id in seq_ids:
|
||||
seq_data = seq_group.seq_data[seq_id]
|
||||
|
@ -1,15 +1,18 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
import msgspec
|
||||
|
||||
class PoolingParams:
|
||||
|
||||
class PoolingParams(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
array_like=True): # type: ignore[call-arg]
|
||||
"""Pooling parameters for pooling.
|
||||
|
||||
Attributes:
|
||||
additional_data: Any additional data needed for pooling.
|
||||
"""
|
||||
|
||||
def __init__(self, additional_data: Optional[Any] = None):
|
||||
self.additional_data = additional_data
|
||||
additional_data: Optional[Any] = None
|
||||
|
||||
def clone(self) -> "PoolingParams":
|
||||
"""Returns a deep copy of the PoolingParams instance."""
|
||||
|
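Like the other request-side classes in this commit, PoolingParams now serializes directly through msgpack; a minimal sketch:

import msgspec

from vllm.pooling_params import PoolingParams

params = PoolingParams()
restored = msgspec.msgpack.decode(msgspec.msgpack.encode(params),
                                  type=PoolingParams)
assert restored.additional_data is None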
@ -1,13 +1,17 @@
|
||||
from dataclasses import dataclass
|
||||
import msgspec
|
||||
|
||||
from vllm.adapter_commons.request import AdapterRequest
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptAdapterRequest(AdapterRequest):
|
||||
class PromptAdapterRequest(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
frozen=True): # type: ignore[call-arg]
|
||||
"""
|
||||
Request for a Prompt adapter.
|
||||
"""
|
||||
__metaclass__ = AdapterRequest
|
||||
|
||||
prompt_adapter_name: str
|
||||
prompt_adapter_id: int
|
||||
|
@ -2,10 +2,10 @@
|
||||
import copy
|
||||
from enum import IntEnum
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from vllm.logger import init_logger
|
||||
@ -33,7 +33,11 @@ first argument, and returns a modified tensor of logits
|
||||
to sample from."""
|
||||
|
||||
|
||||
class SamplingParams:
|
||||
class SamplingParams(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
# required for @cached_property.
|
||||
dict=True): # type: ignore[call-arg]
|
||||
"""Sampling parameters for text generation.
|
||||
|
||||
Overall, we follow the sampling parameters from the OpenAI text completion
|
||||
@ -112,87 +116,73 @@ class SamplingParams:
|
||||
(i.e., no truncation).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n: int = 1,
|
||||
best_of: Optional[int] = None,
|
||||
presence_penalty: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
repetition_penalty: float = 1.0,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
min_p: float = 0.0,
|
||||
seed: Optional[int] = None,
|
||||
use_beam_search: bool = False,
|
||||
length_penalty: float = 1.0,
|
||||
early_stopping: Union[bool, str] = False,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
include_stop_str_in_output: bool = False,
|
||||
ignore_eos: bool = False,
|
||||
max_tokens: Optional[int] = 16,
|
||||
min_tokens: int = 0,
|
||||
logprobs: Optional[int] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
detokenize: bool = True,
|
||||
skip_special_tokens: bool = True,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
logits_processors: Optional[List[LogitsProcessor]] = None,
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
|
||||
) -> None:
|
||||
self.n = n
|
||||
self.best_of = best_of if best_of is not None else n
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.repetition_penalty = repetition_penalty
|
||||
if 0 < temperature < _MAX_TEMP:
|
||||
logger.warning(
|
||||
"temperature %s is less than %s, which may cause numerical "
|
||||
"errors nan or inf in tensors. We have maxed it out to %s.",
|
||||
temperature, _MAX_TEMP, _MAX_TEMP)
|
||||
temperature = max(temperature, _MAX_TEMP)
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.min_p = min_p
|
||||
if seed == -1:
|
||||
self.seed = None
|
||||
else:
|
||||
self.seed = seed
|
||||
self.use_beam_search = use_beam_search
|
||||
self.length_penalty = length_penalty
|
||||
self.early_stopping = early_stopping
|
||||
if stop is None:
|
||||
self.stop = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
else:
|
||||
self.stop = list(stop)
|
||||
if stop_token_ids is None:
|
||||
self.stop_token_ids = []
|
||||
else:
|
||||
self.stop_token_ids = list(stop_token_ids)
|
||||
self.ignore_eos = ignore_eos
|
||||
self.max_tokens = max_tokens
|
||||
self.min_tokens = min_tokens
|
||||
self.logprobs = 1 if logprobs is True else logprobs
|
||||
self.prompt_logprobs = 1 if prompt_logprobs is True else prompt_logprobs
|
||||
n: int = 1
|
||||
best_of: Optional[int] = None
|
||||
presence_penalty: float = 0.0
|
||||
frequency_penalty: float = 0.0
|
||||
repetition_penalty: float = 1.0
|
||||
temperature: float = 1.0
|
||||
top_p: float = 1.0
|
||||
top_k: int = -1
|
||||
min_p: float = 0.0
|
||||
seed: Optional[int] = None
|
||||
use_beam_search: bool = False
|
||||
length_penalty: float = 1.0
|
||||
early_stopping: Union[bool, str] = False
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stop_token_ids: Optional[List[int]] = None
|
||||
ignore_eos: bool = False
|
||||
max_tokens: Optional[int] = 16
|
||||
min_tokens: int = 0
|
||||
logprobs: Optional[int] = None
|
||||
prompt_logprobs: Optional[int] = None
|
||||
# NOTE: This parameter is only exposed at the engine level for now.
|
||||
# It is not exposed in the OpenAI API server, as the OpenAI API does
|
||||
# not support returning only a list of token IDs.
|
||||
self.detokenize = detokenize
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||
self.logits_processors = logits_processors
|
||||
self.include_stop_str_in_output = include_stop_str_in_output
|
||||
self.truncate_prompt_tokens = truncate_prompt_tokens
|
||||
detokenize: bool = True
|
||||
skip_special_tokens: bool = True
|
||||
spaces_between_special_tokens: bool = True
|
||||
# Optional[List[LogitsProcessor]] type. We use Any here because
|
||||
# Optional[List[LogitsProcessor]] type is not supported by msgspec.
|
||||
logits_processors: Optional[Any] = None
|
||||
include_stop_str_in_output: bool = False
|
||||
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
|
||||
|
||||
# The below fields are not supposed to be used as an input.
|
||||
# They are set in post_init.
|
||||
output_text_buffer_length: int = 0
|
||||
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.best_of = self.best_of or self.n
|
||||
if 0 < self.temperature < _MAX_TEMP:
|
||||
logger.warning(
|
||||
"temperature %s is less than %s, which may cause numerical "
|
||||
"errors nan or inf in tensors. We have maxed it out to %s.",
|
||||
self.temperature, _MAX_TEMP, _MAX_TEMP)
|
||||
self.temperature = max(self.temperature, _MAX_TEMP)
|
||||
if self.seed == -1:
|
||||
self.seed = None
|
||||
else:
|
||||
self.seed = self.seed
|
||||
if self.stop is None:
|
||||
self.stop = []
|
||||
elif isinstance(self.stop, str):
|
||||
self.stop = [self.stop]
|
||||
else:
|
||||
self.stop = list(self.stop)
|
||||
if self.stop_token_ids is None:
|
||||
self.stop_token_ids = []
|
||||
else:
|
||||
self.stop_token_ids = list(self.stop_token_ids)
|
||||
self.logprobs = 1 if self.logprobs is True else self.logprobs
|
||||
self.prompt_logprobs = (1 if self.prompt_logprobs is True else
|
||||
self.prompt_logprobs)
|
||||
|
||||
# Number of characters to hold back for stop string evaluation
|
||||
# until sequence is finished.
|
||||
if self.stop and not include_stop_str_in_output:
|
||||
if self.stop and not self.include_stop_str_in_output:
|
||||
self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
|
||||
else:
|
||||
self.output_text_buffer_length = 0
|
||||
|
||||
self._verify_args()
|
||||
if self.use_beam_search:
|
||||
@ -206,11 +196,12 @@ class SamplingParams:
|
||||
self.min_p = 0.0
|
||||
self._verify_greedy_sampling()
|
||||
# eos_token_id is added to this by the engine
|
||||
self.all_stop_token_ids = set(self.stop_token_ids)
|
||||
self._all_stop_token_ids = set(self.stop_token_ids)
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
if self.n < 1:
|
||||
raise ValueError(f"n must be at least 1, got {self.n}.")
|
||||
assert isinstance(self.best_of, int)
|
||||
if self.best_of < self.n:
|
||||
raise ValueError(f"best_of must be greater than or equal to n, "
|
||||
f"got n={self.n} and best_of={self.best_of}.")
|
||||
@ -257,6 +248,7 @@ class SamplingParams:
|
||||
and self.truncate_prompt_tokens < 1):
|
||||
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
|
||||
f"got {self.truncate_prompt_tokens}")
|
||||
assert isinstance(self.stop, list)
|
||||
if any(not stop_str for stop_str in self.stop):
|
||||
raise ValueError("stop cannot contain an empty string.")
|
||||
if self.stop and not self.detokenize:
|
||||
@ -290,6 +282,7 @@ class SamplingParams:
|
||||
"default value of 1.0 when not using beam search.")
|
||||
|
||||
def _verify_greedy_sampling(self) -> None:
|
||||
assert isinstance(self.best_of, int)
|
||||
if self.best_of > 1:
|
||||
raise ValueError("best_of must be 1 when using greedy sampling."
|
||||
f"Got {self.best_of}.")
|
||||
@ -303,7 +296,7 @@ class SamplingParams:
|
||||
if model_eos_token_id is not None:
|
||||
# Add the eos token id into the sampling_params to support
|
||||
# min_tokens processing.
|
||||
self.all_stop_token_ids.add(model_eos_token_id)
|
||||
self._all_stop_token_ids.add(model_eos_token_id)
|
||||
|
||||
# Update eos_token_id for generation
|
||||
if (eos_ids := generation_config.get("eos_token_id")) is not None:
|
||||
@ -315,7 +308,7 @@ class SamplingParams:
|
||||
# purposes.
|
||||
eos_ids.discard(model_eos_token_id)
|
||||
if eos_ids:
|
||||
self.all_stop_token_ids.update(eos_ids)
|
||||
self._all_stop_token_ids.update(eos_ids)
|
||||
if not self.ignore_eos:
|
||||
eos_ids.update(self.stop_token_ids)
|
||||
self.stop_token_ids = list(eos_ids)
|
||||
@ -330,6 +323,10 @@ class SamplingParams:
|
||||
return SamplingType.RANDOM_SEED
|
||||
return SamplingType.RANDOM
|
||||
|
||||
@property
|
||||
def all_stop_token_ids(self) -> Set[int]:
|
||||
return self._all_stop_token_ids
|
||||
|
||||
def clone(self) -> "SamplingParams":
|
||||
"""Deep copy excluding LogitsProcessor objects.
|
||||
|
||||
|
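Converting SamplingParams into a msgspec.Struct is what lets the whole ExecuteModelRequest be msgpack-encoded. A quick sanity sketch with arbitrary field values:

import msgspec

from vllm.sampling_params import SamplingParams

params = SamplingParams(temperature=0.0, max_tokens=16)
data = msgspec.msgpack.encode(params)
restored = msgspec.msgpack.decode(data, type=SamplingParams)
assert restored.temperature == 0.0 and restored.max_tokens == 16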
vllm/sequence.py
@ -4,10 +4,11 @@ import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from array import array
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
|
||||
Union, cast)
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
|
||||
Tuple, Union, cast)
|
||||
|
||||
import msgspec
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
@ -16,13 +17,18 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.inputs import LLMInputs
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
||||
from vllm.multimodal.base import MultiModalDataDict
|
||||
|
||||
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
|
||||
|
||||
|
||||
# We use dataclass for now because it is used for
|
||||
# openai server output, and msgspec is not serializable.
|
||||
# TODO(sang): Fix it.
|
||||
@dataclass
|
||||
class Logprob:
|
||||
"""Infos for supporting OpenAI compatible logprobs and token ranks.
|
||||
@ -112,7 +118,23 @@ class RequestMetrics:
|
||||
model_execute_time: Optional[float] = None
|
||||
|
||||
|
||||
class SequenceData:
|
||||
class SequenceDataDelta(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
omit_defaults=True): # type: ignore[call-arg]
|
||||
"""Delta SequenceData to send to workers per step."""
|
||||
# New tokens to be appended to the existing SequenceData.
|
||||
new_output_token_ids: List[int]
|
||||
# Overwriting existing `cumulative_logprob`
|
||||
new_cumulative_logprob: float
|
||||
# Overwriting existing `num_computed_tokens`.
|
||||
new_num_computed_tokens: int
|
||||
# Overwriting existing `stage`.
|
||||
new_stage: SequenceStage
|
||||
|
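The delta struct is what the scheduler now ships per step instead of the full SequenceData once a worker has seen the first prefill. A hypothetical delta for a single decode step that produced one new token:

from vllm.sequence import SequenceDataDelta, SequenceStage

# Hypothetical values for one decode step.
delta = SequenceDataDelta(
    new_output_token_ids=[42],
    new_cumulative_logprob=-1.25,
    new_num_computed_tokens=17,
    new_stage=SequenceStage.DECODE,
)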
||||
|
||||
class SequenceData(msgspec.Struct,
omit_defaults=True): # type: ignore[call-arg]
"""Data associated with a sequence.

Args:
@ -125,40 +147,57 @@ class SequenceData:
output_token_ids: The token IDs of the output.
cumulative_logprob: The cumulative log probability of the output.
"""
# NOTE: we cannot use Union[List, array] because msgspec cannot support
# union of 2 list types.
_prompt_token_ids: array
_output_token_ids: array = msgspec.field(
default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

def __init__(
self,
prompt_token_ids: List[int],
output_token_ids: Optional[List[int]] = None,
) -> None:
self._prompt_token_ids = array('l', prompt_token_ids)
self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
self._output_token_ids = array(
'l', output_token_ids if output_token_ids is not None else [])

self.cumulative_logprob = 0.0
### The below fields should not be passed as an argument ###
_cumulative_logprob: float = 0.0
_prompt_token_ids_tuple: Tuple[int,
...] = msgspec.field(default_factory=tuple)
# The number of tokens that are computed (that run against the model).
self._num_computed_tokens = 0
self._stage: SequenceStage = SequenceStage.PREFILL
_num_computed_tokens: int = 0
_stage: SequenceStage = SequenceStage.PREFILL
_cached_all_token_ids: List[int] = msgspec.field(default_factory=list)

# It is used to get delta input. It is reset when `get_delta_and_reset`
# is called.
_new_appended_tokens: List[int] = msgspec.field(default_factory=list)

def __post_init__(self) -> None:
assert self._prompt_token_ids.typecode == "l"
assert self._output_token_ids.typecode == "l"
self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
self._prompt_token_ids)
self._update_cached_all_tokens()

def _update_cached_all_tokens(self):
assert isinstance(self._prompt_token_ids, array)
assert isinstance(self._output_token_ids, array)
self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
self._output_token_ids)

@property
def cumulative_logprob(self) -> float:
return self._cumulative_logprob

@property
def prompt_token_ids(self) -> Tuple[int, ...]:
return self._prompt_token_ids_tuple

@prompt_token_ids.setter
def prompt_token_ids(self, new_prompt_token_ids) -> None:
self._prompt_token_ids = array('l', new_prompt_token_ids)
self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
self._update_cached_all_tokens()
raise NotImplementedError

@property
def prompt_token_ids_array(self) -> array:
"""Return the prompt token ids in array type.

Note that the array is in "I" type, and it is not compatible
with torch.long (2 bytes vs 4 bytes). So beware of the usage.
"""
return self._prompt_token_ids

@property
@ -166,18 +205,26 @@ class SequenceData:
return tuple(self._output_token_ids)

@output_token_ids.setter
def output_token_ids(self, new_output_token_ids) -> None:
self._output_token_ids = array('l', new_output_token_ids)
def output_token_ids(self, new_output_token_ids: List[int]) -> None:
self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
new_output_token_ids)
self._update_cached_all_tokens()

@property
def output_token_ids_array(self) -> array:
"""Return the prompt token ids in array type.

Note that the array is in "I" type, and it is not compatible
with torch.long (2 bytes vs 4 bytes). So beware of the usage.
"""
assert isinstance(self._output_token_ids, array)
return self._output_token_ids

def append_token_id(self, token_id: int, logprob: float) -> None:
self._output_token_ids.append(token_id)
self._new_appended_tokens.append(token_id)
self._cached_all_token_ids.append(token_id)
self.cumulative_logprob += logprob
self._cumulative_logprob += logprob

def get_len(self) -> int:
return len(self._output_token_ids) + len(self._prompt_token_ids)
@ -222,6 +269,7 @@ class SequenceData:
"""
self._num_computed_tokens = 0
self._stage = SequenceStage.PREFILL
self._new_appended_tokens = []

def get_num_uncomputed_tokens(self) -> int:
"""Return the number of prefill tokens that are not computed."""
@ -241,6 +289,21 @@ class SequenceData:
def get_output_token_ids(self) -> Tuple[int, ...]:
return self.output_token_ids

def get_delta_and_reset(self) -> SequenceDataDelta:
delta = SequenceDataDelta(self._new_appended_tokens,
self._cumulative_logprob,
self.get_num_computed_tokens(), self.stage)
# Reset delta state.
self._new_appended_tokens = []
return delta

def apply_delta(self, delta: SequenceDataDelta):
self._num_computed_tokens = delta.new_num_computed_tokens
self._cumulative_logprob = delta.new_cumulative_logprob
self._stage = delta.new_stage
self._output_token_ids.extend(delta.new_output_token_ids)
self._cached_all_token_ids.extend(delta.new_output_token_ids)

@property
def stage(self) -> SequenceStage:
return self._stage
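An illustrative sketch (not part of this commit's diff) of the intended round trip between the driver-side SequenceData and a worker-side copy: the driver appends new tokens, get_delta_and_reset() packages only what changed, and apply_delta() brings the worker's copy up to date. The constructor form mirrors the Sequence.__init__ change further down; get_output_token_ids() is the accessor shown above.

from array import array

from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

driver_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
worker_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))

# Driver side: a decode step appends one token, then ships only the delta.
driver_data.append_token_id(7, logprob=-0.1)
delta = driver_data.get_delta_and_reset()

# Worker side: the cached copy catches up without re-receiving the prompt.
worker_data.apply_delta(delta)
assert worker_data.get_output_token_ids() == driver_data.get_output_token_ids()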
@ -248,8 +311,9 @@ class SequenceData:
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt_token_ids={self._prompt_token_ids}, "
f"output_token_ids={self._output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob})")
f"output_token_ids={self.output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob}, "
f"get_num_computed_tokens={self.get_num_computed_tokens()}")


class Sequence:
@ -325,7 +389,8 @@ class Sequence:
f"invalid input {inputs}; did you forget the "
"encoder input prompt fields?")

self.data = SequenceData(self.prompt_token_ids)
self.data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids))
self.output_logprobs: SampleLogprobs = []
self.output_text = ""

@ -490,8 +555,8 @@ class Sequence:
f"num_blocks={self.n_blocks}, ")


@dataclass
class SequenceGroupState:
class SequenceGroupState(msgspec.Struct,
omit_defaults=True): # type: ignore[call-arg]
"""Mutable state tied to a specific sequence group"""

# for multi-step decoding
@ -647,14 +712,19 @@ class SequenceGroup:
if self.sampling_params and self.sampling_params.use_beam_search:
# For beam search, maximally there will always be `best_of` beam
# candidates running in the future.
return self.sampling_params.best_of
best_of = self.sampling_params.best_of
assert isinstance(best_of, int)
return best_of
else:
if (self.sampling_params
and self.sampling_params.best_of > self.num_seqs()):
if self.sampling_params:
best_of = self.sampling_params.best_of
assert isinstance(best_of, int)
if best_of > self.num_seqs():
# At prompt stage, the sequence group is not yet filled up
# and only have one sequence running. However, in the
# generation stage, we will have `best_of` sequences running.
return self.sampling_params.best_of
# generation stage, we will have `best_of` sequences
# running.
return best_of
# At sampling stages, return the number of actual sequences
# that are not finished yet.
return self.num_unfinished_seqs()
@ -757,7 +827,32 @@ class SequenceGroup:
f"num_seqs={len(self.seqs)})")


class SequenceGroupMetadata:
class SequenceGroupMetadataDelta(
msgspec.Struct,
tag=True, # type: ignore[call-arg]
array_like=True, # type: ignore[call-arg]
omit_defaults=True): # type: ignore[call-arg]
"""Delta of SequenceGroupMetadata.

After sending the first SequenceGroupMetadata, vLLM scheduler
only sends delta to reduce the data payload size.
"""
seq_data_delta: Dict[int, SequenceDataDelta]
request_id: str
block_tables: Dict[int, List[int]]
is_prompt: bool
do_sample: bool = True
token_chunk_size: Optional[int] = None
computed_block_nums: Optional[List[int]] = None
state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState())

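An illustrative aside (not part of this commit's diff): tag=True adds a type tag to each encoded struct, which is how msgspec can later decode a union of SequenceGroupMetadata and SequenceGroupMetadataDelta (as used by ExecuteModelRequest further down) back into the right class. A minimal sketch with hypothetical stand-in structs FullMeta and DeltaMeta:

from typing import List, Union

import msgspec


class FullMeta(msgspec.Struct, tag=True, array_like=True):
    request_id: str
    token_ids: List[int]


class DeltaMeta(msgspec.Struct, tag=True, array_like=True):
    request_id: str
    new_token_ids: List[int]


wire = msgspec.msgpack.encode(DeltaMeta("req-0", [7]))
decoded = msgspec.msgpack.decode(wire, type=Union[FullMeta, DeltaMeta])
assert isinstance(decoded, DeltaMeta)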
class SequenceGroupMetadata(
msgspec.Struct,
tag=True, # type: ignore[call-arg]
array_like=True, # type: ignore[call-arg]
omit_defaults=True): # type: ignore[call-arg]
"""Metadata for a sequence group. Used to create `AttentionMetadata`.

Args:
@ -789,52 +884,39 @@ class SequenceGroupMetadata:
prompt_adapter_request: Prompt Adapter request.
"""

def __init__(
self,
request_id: str,
is_prompt: bool,
seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]],
do_sample: bool = True,
pooling_params: Optional[PoolingParams] = None,
token_chunk_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
encoder_seq_data: Optional[SequenceData] = None,
cross_block_table: Optional[List[int]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
self.request_id = request_id
self.is_prompt = is_prompt
self.seq_data = seq_data
self.sampling_params = sampling_params
self.block_tables = block_tables
self.pooling_params = pooling_params
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.computed_block_nums = computed_block_nums
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
self.encoder_seq_data = encoder_seq_data
self.cross_block_table = cross_block_table
self._token_chunk_size = token_chunk_size
self.do_sample = do_sample
request_id: str
is_prompt: bool
seq_data: Dict[int, SequenceData]
sampling_params: SamplingParams
block_tables: Dict[int, List[int]]
do_sample: bool = True
pooling_params: Optional[PoolingParams] = None
lora_request: Optional[LoRARequest] = None
computed_block_nums: Optional[List[int]] = None
state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState())
# "MultiModalDataDict" types. We have to use Any due to msgspec
# doesn't allow to have union of 2 different dicts.
multi_modal_data: Optional[Any] = None
encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[List[int]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
token_chunk_size: Optional[int] = None

### Stateful fields that are lazily defined. ###
# The number of speculative tokens adopted in this request.
# None means specuative decoding is not used.
# Zero means speculative decoding is disabled for some reasons.
# TODO: We should maintain this states out of the sequence group.
self.num_speculative_tokens = None
num_speculative_tokens: Optional[int] = None

if seq_data is not None and self._token_chunk_size is None:
if is_prompt:
self._token_chunk_size = next(iter(
seq_data.values())).get_len()
def __post_init__(self):
if self.seq_data is not None and self.token_chunk_size is None:
if self.is_prompt:
self.token_chunk_size = next(iter(
self.seq_data.values())).get_len()
else:
self._token_chunk_size = 1
self.token_chunk_size = 1

@property
def lora_int_id(self) -> int:
@ -850,18 +932,26 @@ class SequenceGroupMetadata:
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
if self.prompt_adapter_request else 0

@property
def token_chunk_size(self) -> int:
"""Return the number of tokens to be processed (chunk size)."""
assert self._token_chunk_size is not None
return self._token_chunk_size
def apply_delta(self,
sequence_group_metadata_delta: SequenceGroupMetadataDelta):
for id, delta in sequence_group_metadata_delta.seq_data_delta.items():
self.seq_data[id].apply_delta(delta)
assert self.request_id == sequence_group_metadata_delta.request_id
self.block_tables = sequence_group_metadata_delta.block_tables
self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size
self.do_sample = sequence_group_metadata_delta.do_sample
self.is_prompt = sequence_group_metadata_delta.is_prompt

def finish_step(self) -> None:
assert self.state is not None
assert self.state.current_step < self.state.num_steps
self.state.current_step += 1


class SequenceOutput:
class SequenceOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""The model output associated with a sequence.

Args:
@ -871,16 +961,9 @@ class SequenceOutput:
logprobs: The logprobs of the output token.
(Token id -> logP(x_i+1 | x_0, ..., x_i))
"""

def __init__(
self,
parent_seq_id: int,
output_token: int,
logprobs: Dict[int, Logprob],
) -> None:
self.parent_seq_id = parent_seq_id
self.output_token = output_token
self.logprobs = logprobs
parent_seq_id: int
output_token: int
logprobs: Dict[int, Logprob]

def __repr__(self) -> str:
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
@ -908,17 +991,15 @@ class SequenceGroupOutput(ABC):
pass


class CompletionSequenceGroupOutput(SequenceGroupOutput):
class CompletionSequenceGroupOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
__metaclass__ = SequenceGroupOutput
"""The model output associated with a completion sequence group."""

def __init__(
self,
samples: List[SequenceOutput],
prompt_logprobs: Optional[PromptLogprobs],
) -> None:
self.samples = samples
samples: List[SequenceOutput]
# Prompt logprob for each prompt query token.
self.prompt_logprobs = prompt_logprobs
prompt_logprobs: Optional[PromptLogprobs]

def __repr__(self) -> str:
return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
@ -931,14 +1012,14 @@ class CompletionSequenceGroupOutput(SequenceGroupOutput):
and self.prompt_logprobs == other.prompt_logprobs)

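An illustrative sketch (not part of this commit's diff): converted to msgspec Structs, the output classes keep a dataclass-like keyword constructor, so call sites that build them do not need to change. This assumes Logprob's first field is named logprob, as defined earlier in vllm/sequence.py.

from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                           SequenceOutput)

sample = SequenceOutput(parent_seq_id=0,
                        output_token=42,
                        logprobs={42: Logprob(logprob=-0.1)})
output = CompletionSequenceGroupOutput(samples=[sample], prompt_logprobs=None)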
class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
class EmbeddingSequenceGroupOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True, # type: ignore[call-arg]
):
"""The model output associated with an embedding sequence group."""

def __init__(
self,
embeddings: List[float],
) -> None:
self.embeddings = embeddings
__metaclass__ = SequenceGroupOutput
embeddings: List[int]

def __repr__(self) -> str:
return (f"EmbeddingSequenceGroupOutput("
@ -950,8 +1031,10 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
return self.embeddings == other.embeddings


@dataclass
class IntermediateTensors:
class IntermediateTensors(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request.
@ -978,8 +1061,10 @@ class IntermediateTensors:
return f"IntermediateTensors(tensors={self.tensors})"


@dataclass
class SamplerOutput:
class SamplerOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token.

@ -1000,7 +1085,7 @@ class SamplerOutput:
sampled_token_ids_numpy: Optional[numpy.ndarray] = None

# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

# Optional last hidden states from the model.
hidden_states: Optional[torch.Tensor] = None
@ -1039,12 +1124,14 @@ class SamplerOutput:
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")


@dataclass
class PoolerOutput:
class PoolerOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""The output from a pooling operation in the embedding model."""
outputs: List[EmbeddingSequenceGroupOutput]

spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

def __getitem__(self, idx: int):
return self.outputs[idx]
@ -1083,7 +1170,8 @@ def get_all_seq_ids_and_request_ids(
return seq_ids, request_id_seq_ids_mapping


class HiddenStates:
class HiddenStates(msgspec.Struct, array_like=True,
omit_defaults=True): # type: ignore[call-arg]
"""Hidden states corresponding to in-progress sequences.
Used in speculative decoding to pass hidden states from
the target model to the proposer model in the subsequent step.
@ -1091,42 +1179,53 @@ class HiddenStates:
seq_ids are the sequence ids of each entry of the batch
dimension of the hidden_states tensor"""

def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
hidden_states: torch.Tensor):
assert len(seq_group_metadata_list) == len(hidden_states)
self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
self.hidden_states: torch.Tensor = hidden_states
seq_group_metadata_list: List[SequenceGroupMetadata]
hidden_states: torch.Tensor
_seq_ids: List[int] = msgspec.field(default_factory=list)

def __post_init__(self):
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
assert len(self.seq_group_metadata_list) == len(self.hidden_states)

@property
def seq_ids(self) -> List[int]:
return self._seq_ids

def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
hidden_states: torch.Tensor) -> None:
"""Update hidden states from target model invocation."""
assert len(seq_group_metadata_list) == len(hidden_states)
self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
self.hidden_states = torch.cat([self.hidden_states, hidden_states])

def prune(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
"""Prune to provided list of sequence ids."""
seq_ids = get_all_seq_ids(seq_group_metadata_list)
if seq_ids != self.seq_ids:
if seq_ids != self._seq_ids:
# Batch contents changed - prune removed sequences.
index = [self.seq_ids.index(seq_id) for seq_id in seq_ids]
index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
self.hidden_states = self.hidden_states[index]
self.seq_ids = seq_ids
self._seq_ids = seq_ids


@dataclass
class ExecuteModelRequest:
class ExecuteModelRequest(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True): # type: ignore[call-arg]
"""The model execution request, containing CPU metadata only. The LLM
engine should create an instance of this class for each request batch."""
# The sequence group metadata list.
seq_group_metadata_list: List[SequenceGroupMetadata]
seq_group_metadata_list: List[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]]
# Blocks to swap in. List of CPU -> GPU block number.
blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
blocks_to_swap_in: List[Tuple[int,
int]] = msgspec.field(default_factory=list)
# Blocks to swap out. List of GPU -> CPU block number.
blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
blocks_to_swap_out: List[Tuple[int,
int]] = msgspec.field(default_factory=list)
# Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list)
# Virtual engine ID for pipeline parallel.
virtual_engine: int = 0
# The number of slots for lookahead decoding.
@ -1138,7 +1237,7 @@ class ExecuteModelRequest:
# The number of forward steps to run.
num_steps: int = 1
# Finished request ids since last step.
finished_requests_ids: List[str] = field(default_factory=list)
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding.
last_sampled_token_ids: Optional[torch.Tensor] = None

@ -1148,6 +1247,7 @@ class ExecuteModelRequest:
# steps
assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0]
assert first_seq_group.state is not None
return first_seq_group.state.current_step == 0

@property
@ -1156,6 +1256,7 @@ class ExecuteModelRequest:
# steps
assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0]
assert first_seq_group.state is not None
num_steps = first_seq_group.state.num_steps
current_step = first_seq_group.state.current_step
return num_steps - current_step == 1
@ -1165,10 +1266,13 @@ class ExecuteModelRequest:
# TODO(will) make this be able to handle batches with variable number of
# steps
assert len(self.seq_group_metadata_list) > 0
return self.seq_group_metadata_list[0].state.current_step
state = self.seq_group_metadata_list[0].state
assert state is not None
return state.current_step

def clone(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
self, seq_group_metadata_list: List[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]]
) -> "ExecuteModelRequest":
"""Clone the request with a new sequence group metadata list."""
return ExecuteModelRequest(
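An illustrative sketch (not part of this commit's diff): the metadata list of ExecuteModelRequest now accepts either full snapshots or deltas, and clone() preserves that union type. The empty dicts below are only there to keep the construction minimal.

from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadataDelta

delta = SequenceGroupMetadataDelta(seq_data_delta={},
                                   request_id="req-0",
                                   block_tables={},
                                   is_prompt=False)
req = ExecuteModelRequest(seq_group_metadata_list=[delta])
next_req = req.clone(seq_group_metadata_list=[delta])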
|
@ -1,11 +1,13 @@
from array import array
from itertools import chain, count
from typing import Iterator, List, Tuple

import torch

from vllm import SamplingParams
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest,
SamplerOutput, SequenceData, SequenceGroupMetadata,
get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
@ -293,14 +295,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
input sequence.
"""
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_token_ids = seq_data.get_prompt_token_ids()
prompt_token_ids = seq_data.prompt_token_ids_array
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]

new_seq_data_dict = {
target_seq_id:
SequenceData(
prompt_token_ids=prompt_token_ids,
output_token_ids=new_output_token_ids,
prompt_token_ids,
_output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
new_output_token_ids),
),
}
# This is a hack. Technically, spec decoding should compute
|
@ -1,7 +1,7 @@
import time
from dataclasses import dataclass
from typing import Callable, Optional

import msgspec
import torch

from vllm.model_executor.layers.spec_decode_base_sampler import (
@ -9,8 +9,10 @@ from vllm.model_executor.layers.spec_decode_base_sampler import (
from vllm.utils import is_pin_memory_available


@dataclass
class SpecDecodeWorkerMetrics:
class SpecDecodeWorkerMetrics(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""Dataclass holding metrics emitted from the spec decode worker.
"""

|
@ -1,7 +1,7 @@
"""A GPU worker class."""
import gc
import os
from typing import List, Optional, Set, Tuple, Type
from typing import Dict, List, Optional, Set, Tuple, Type, Union

import torch
import torch.distributed
@ -18,7 +18,9 @@ from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput, SequenceGroupMetadata,
SequenceGroupMetadataDelta)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
@ -109,6 +111,7 @@ class Worker(LocalOrDistributedWorkerBase):
self.cache_engine: List[CacheEngine]
# Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}

def _is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model
@ -303,6 +306,63 @@ class Worker(LocalOrDistributedWorkerBase):
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)

def _get_cached_seq_group_metadata(
self,
seq_group_metadata_list: List[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]],
finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
"""Return a list of cached Sequence Group Metadata after updating its
state.

It is used because scheduler only sends delta to workers to reduce
the data payload size. The function also cleans up cache based on
a given `finished_request_ids`.
"""
new_seq_group_metadata_list = []
for metadata_or_delta in seq_group_metadata_list:
request_id = metadata_or_delta.request_id
if request_id not in self._seq_group_metadata_cache:
# The first prefill.
assert isinstance(metadata_or_delta, SequenceGroupMetadata)
self._seq_group_metadata_cache[request_id] = metadata_or_delta
else:
# The first prefill is already cached.
if isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
self._seq_group_metadata_cache[request_id].apply_delta(
metadata_or_delta)
else:
# If metadata snapshot is sent again, it is
# preempted. Reset the cache because we need to start
# from scratch.
assert isinstance(metadata_or_delta, SequenceGroupMetadata)
self._seq_group_metadata_cache[
request_id] = metadata_or_delta

new_seq_group_metadata_list.append(
self._seq_group_metadata_cache[request_id])

# Clean up finished ids
for finished_id in finished_request_ids:
del self._seq_group_metadata_cache[finished_id]

return new_seq_group_metadata_list

def _execute_model_spmd(
self,
execute_model_req: ExecuteModelRequest,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Optional[List[SamplerOutput]]:
if execute_model_req is not None:
new_seq_group_metadata_list = self._get_cached_seq_group_metadata(
execute_model_req.seq_group_metadata_list,
execute_model_req.finished_requests_ids)

execute_model_req.seq_group_metadata_list = (
new_seq_group_metadata_list)
output = super()._execute_model_spmd(execute_model_req,
intermediate_tensors)
return output

def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)