[Core] Factor out common code in SequenceData and Sequence (#8675)
parent d4bf085ad0
commit 0455c46ed4

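The pattern repeated at every call site below is mechanical: code that previously built the token array by hand and passed it to the SequenceData constructor now goes through the new factory method. A minimal before/after sketch, using the placeholder token ids [1, 2, 3] from the tests:

    # Before: callers had to build the array and know its typecode.
    from array import array
    from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

    seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))

    # After: the factory accepts any sequence of ints and builds the array internally.
    seq_data = SequenceData.from_seqs([1, 2, 3])
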
@@ -1,6 +1,5 @@
 import itertools
 import random
-from array import array
 from typing import Dict, List, Optional, Tuple
 from unittest.mock import Mock, patch
 
@@ -12,8 +11,7 @@ import vllm.envs as envs
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import Counter, is_pin_memory_available
 
 
@@ -59,9 +57,7 @@ def _do_sample(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
@@ -205,9 +201,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
         return sampling_params
 
     def create_sequence_data(num_input=3, num_generated=0):
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                  random.choices(range(0, VOCAB_SIZE), k=num_input)))
+        seq_data = SequenceData.from_seqs(
+            random.choices(range(0, VOCAB_SIZE), k=num_input))
         if num_generated > 0:
             seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                        k=num_generated)
@@ -511,9 +506,7 @@ def test_sampler_mixed(seed: int, device: str):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
@@ -613,9 +606,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=SamplingParams(
                     temperature=1,
                     top_k=top_k,
@@ -699,11 +690,7 @@ def test_sampler_repetition_penalty_mixed(device: str):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0:
-                    SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                       [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=sampling_params[i],
                 block_tables={0: [1]},
             ))

@@ -1,4 +1,3 @@
-from array import array
 from itertools import count
 from typing import Callable, Dict, List, Optional
 from typing import Sequence as GenericSequence
@@ -11,8 +10,7 @@ from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.utils import set_random_seed
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
-                           CompletionSequenceGroupOutput, Logprob,
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SequenceData, SequenceGroupMetadata, SequenceOutput)
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.worker.cache_engine import CacheEngine
@@ -138,12 +136,8 @@ def create_seq_group_metadata_from_prompts(
             request_id=str(i),
             is_prompt=len(cont_token_ids) == 0,
             seq_data={
-                i:
-                SequenceData(
-                    array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]),
-                    _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                            cont_token_ids[:]),
-                ),
+                i: SequenceData.from_seqs(prompt_token_ids[:],
+                                          cont_token_ids[:]),
             },
             sampling_params=SamplingParams(temperature=0.0, ),
             block_tables={i: block_allocations[i][:]},

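Call sites that also carried already-generated tokens, like the helper above, now pass them as the second argument of from_seqs instead of hand-building the keyword-only _output_token_ids array. A short sketch of the equivalence; the token values are illustrative:

    from array import array
    from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

    # Old form: two hand-built arrays, one through a private-looking keyword.
    old = SequenceData(
        array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]),
        _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, [4, 5]),
    )

    # New form: prompt tokens first, output tokens second.
    new = SequenceData.from_seqs([1, 2, 3], [4, 5])
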
@@ -1,5 +1,4 @@
 import random
-from array import array
 from typing import Tuple
 from unittest.mock import patch
 
@@ -9,8 +8,7 @@ import torch
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import is_pin_memory_available
 
 
@@ -71,9 +69,7 @@ def test_logits_processors(seed: int, device: str):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=SamplingParams(temperature=0,
                                                logits_processors=[pick_ith]),
                 block_tables={0: [1]},

@@ -1,10 +1,7 @@
-from array import array
-
 import pytest
 
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
-                           CompletionSequenceGroupOutput, SequenceData,
+from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData,
                            SequenceOutput)
 
 from .core.utils import create_dummy_prompt
@@ -58,7 +55,7 @@ def test_sampler_output_eq(sample_outputs):
 
 
 def test_sequence_data_prefill():
-    seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4]))
+    seq_data = SequenceData.from_seqs([1, 2, 3, 4])
     assert seq_data.get_num_uncomputed_tokens() == 4
     assert seq_data.get_num_computed_tokens() == 0
     # advance by 2

@@ -1,13 +1,11 @@
 import itertools
-from array import array
 from typing import List
 
 import pytest
 import torch
 
 from vllm.engine.arg_utils import EngineArgs
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import is_cpu, make_tensor_with_pad
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import _get_graph_batch_size
@@ -119,12 +117,10 @@ def test_prepare_prompt(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
         encoder_seq_lens.append(encoder_seq_len)
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len)))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -317,11 +313,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
 
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
@@ -523,11 +517,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,

@@ -1,4 +1,3 @@
-from array import array
 from typing import List
 
 import pytest
@@ -8,8 +7,7 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                              init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import get_open_port
 from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
 
@@ -48,8 +46,7 @@ def test_prepare_prompt(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -166,8 +163,7 @@ def test_prepare_decode_cuda_graph(batch_size):
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
         context_lens.append(context_len)
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)))
+        seq_data = SequenceData.from_seqs(range(context_len))
         seq_data.update_num_computed_tokens(context_len)
         # Append one token ID since prefill is finished.
         seq_data.append_token_id(1, 0)
@@ -326,8 +322,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -343,8 +338,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     for i in range(prefill_batch_size, batch_size):
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
-        prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))
-        seq_data = SequenceData(prompt_toks)
+        seq_data = SequenceData.from_seqs(range(context_len))
         seq_data.append_token_id(1, 0)
         seq_data.update_num_computed_tokens(context_len)
         seq_group_metadata = SequenceGroupMetadata(

@@ -1,5 +1,4 @@
 import functools
-from array import array
 from collections import UserDict
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
@@ -22,10 +21,6 @@ logger = init_logger(__name__)
 
 C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
 
-# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
-# We cannot import it here because of circular dependencies.
-VLLM_TOKEN_ID_ARRAY_TYPE = "l"
-
 
 @dataclass(frozen=True)
 class InputContext:
@@ -130,8 +125,7 @@ class InputRegistry:
         # Avoid circular import
         from vllm.sequence import SequenceData
 
-        dummy_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)
+        dummy_seq_data = SequenceData.from_counts({0: seq_len})
         dummy_multi_modal_data = None
 
         return dummy_seq_data, dummy_multi_modal_data

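The dummy data built here is just seq_len copies of token id 0, which is exactly what from_counts({0: seq_len}) produces, so the module-local copy of VLLM_TOKEN_ID_ARRAY_TYPE and its circular-import note can go away. A small sketch of the equivalence, assuming SequenceData.prompt_token_ids exposes the prompt tokens:

    from array import array

    from vllm.sequence import SequenceData

    seq_len = 8
    # Old: seq_len copies of token id 0, built with the duplicated "l" typecode.
    old_style = array("l", [0]) * seq_len
    # New: the same dummy prompt, built by the factory.
    new_style = SequenceData.from_counts({0: seq_len})
    assert list(new_style.prompt_token_ids) == list(old_style)
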
@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
 from array import array
 from collections import defaultdict
 from dataclasses import dataclass
+from functools import cached_property, reduce
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
 from typing import Sequence as GenericSequence
 from typing import Set, Tuple, Union, cast
@@ -169,6 +170,35 @@ class SequenceData(msgspec.Struct,
     # It is used to compute mrope_position_ids.
     _mrope_position_delta: Optional[int] = None
 
+    @staticmethod
+    def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData":
+        if len(counts_by_token) == 0:
+            return SequenceData.from_seqs([])
+
+        arrs = [
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
+            for token_id, count in counts_by_token.items()
+        ]
+
+        return SequenceData(reduce(array.__add__, arrs))
+
+    @staticmethod
+    def from_seqs(
+        prompt_token_ids: GenericSequence[int],
+        output_token_ids: Optional[GenericSequence[int]] = None,
+    ) -> "SequenceData":
+        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                     prompt_token_ids)
+
+        if output_token_ids is None:
+            return SequenceData(prompt_token_ids_arr)
+
+        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                     output_token_ids)
+
+        return SequenceData(prompt_token_ids_arr,
+                            _output_token_ids=output_token_ids_arr)
+
     def __post_init__(self) -> None:
         assert self._prompt_token_ids.typecode == "l"
         assert self._output_token_ids.typecode == "l"

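Together the two factories cover the construction patterns the old call sites needed: from_seqs wraps a plain Python sequence of token ids, optionally followed by already-generated output tokens, while from_counts expands a {token_id: count} mapping into a flat prompt (the dict's insertion order determines token order). A short usage sketch with illustrative values:

    # Prompt-only sequence, as the model-runner tests do with range(seq_len).
    prefill = SequenceData.from_seqs(range(16))

    # Prompt plus tokens that were already generated.
    decoding = SequenceData.from_seqs([1, 2, 3], output_token_ids=[4, 5])

    # Four copies of token id 0 followed by two copies of token id 7.
    padded = SequenceData.from_counts({0: 4, 7: 2})
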
@@ -370,8 +400,6 @@ class Sequence:
         self.lora_request = lora_request
         self.prompt_adapter_request = prompt_adapter_request
         self.from_decoder_prompt = from_decoder_prompt
-        self._prompt: Optional[str] = None
-        self._prompt_token_ids: Optional[List[int]] = None
 
         # For decoder-only models, a Sequence is constructed
         # from an LLMInputs instance (the `inputs` arg.)
@@ -400,8 +428,7 @@ class Sequence:
                 f"invalid input {inputs}; did you forget the "
                 "encoder input prompt fields?")
 
-        self.data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids))
+        self.data = SequenceData.from_seqs(self.prompt_token_ids)
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
 
@@ -422,37 +449,23 @@ class Sequence:
     def n_blocks(self) -> int:
         return (self.get_len() + self.block_size - 1) // self.block_size
 
-    @property
+    @cached_property
     def prompt(self) -> Optional[str]:
-        if self._prompt is not None:
-            # Reuse precomputed prompt string
-            return self._prompt
-
-        # Select decoder or encoder input prompt str,
-        # as appropriate
+        # Select decoder or encoder input prompt str, as appropriate
         prompt_key: str = ("prompt"
                            if self.from_decoder_prompt else "encoder_prompt")
 
-        # Cache prompt
-        self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
-        return self._prompt
+        return cast(Optional[str], self.inputs.get(prompt_key))
 
-    @property
+    @cached_property
     def prompt_token_ids(self) -> List[int]:
-        if self._prompt_token_ids is not None:
-            # Reuse precomputed prompt token ids
-            return self._prompt_token_ids
-
-        # Select decoder or encoder input prompt
-        # token ids, as appropriate
+        # Select decoder or encoder input prompt token ids, as appropriate
         prompt_token_ids_key: str = ("prompt_token_ids"
                                      if self.from_decoder_prompt else
                                      "encoder_prompt_token_ids")
 
-        # Cache computed prompt token ids
-        self._prompt_token_ids = cast(List[int],
-                                      self.inputs.get(prompt_token_ids_key))
-        return self._prompt_token_ids
+        return cast(List[int], self.inputs.get(prompt_token_ids_key))
 
     @property
     def multi_modal_data(self) -> "MultiModalDataDict":

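The final hunk replaces hand-rolled memoization (a sentinel attribute checked on every access and filled on the first one) with functools.cached_property, which computes the value once and stores it in the instance __dict__; later reads bypass the property entirely. A generic sketch of the pattern, not vLLM code:

    from functools import cached_property


    class Example:

        def __init__(self, inputs: dict):
            self.inputs = inputs

        @cached_property
        def prompt(self):
            # Runs once per instance; the result is cached afterwards.
            return self.inputs.get("prompt")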