[Core] Factor out common code in SequenceData and Sequence (#8675)

Cyrus Leung 2024-09-21 10:30:39 +08:00 committed by GitHub
parent d4bf085ad0
commit 0455c46ed4
8 changed files with 64 additions and 97 deletions

View File

@@ -1,6 +1,5 @@
 import itertools
 import random
-from array import array
 from typing import Dict, List, Optional, Tuple
 from unittest.mock import Mock, patch
@@ -12,8 +11,7 @@ import vllm.envs as envs
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import Counter, is_pin_memory_available
@@ -59,9 +57,7 @@ def _do_sample(
         SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
-            seq_data={
-                0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-            },
+            seq_data={0: SequenceData.from_seqs([1, 2, 3])},
             sampling_params=sampling_params,
             block_tables={0: [1]},
         ))
@@ -205,9 +201,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
         return sampling_params
     def create_sequence_data(num_input=3, num_generated=0):
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                  random.choices(range(0, VOCAB_SIZE), k=num_input)))
+        seq_data = SequenceData.from_seqs(
+            random.choices(range(0, VOCAB_SIZE), k=num_input))
         if num_generated > 0:
             seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                        k=num_generated)
@@ -511,9 +506,7 @@ def test_sampler_mixed(seed: int, device: str):
         SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
-            seq_data={
-                0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-            },
+            seq_data={0: SequenceData.from_seqs([1, 2, 3])},
             sampling_params=sampling_params,
             block_tables={0: [1]},
         ))
@@ -613,9 +606,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
         SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
-            seq_data={
-                0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-            },
+            seq_data={0: SequenceData.from_seqs([1, 2, 3])},
             sampling_params=SamplingParams(
                 temperature=1,
                 top_k=top_k,
@@ -699,11 +690,7 @@ def test_sampler_repetition_penalty_mixed(device: str):
         SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
-            seq_data={
-                0:
-                SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                   [1, 2, 3]))
-            },
+            seq_data={0: SequenceData.from_seqs([1, 2, 3])},
             sampling_params=sampling_params[i],
             block_tables={0: [1]},
         ))

View File

@@ -1,4 +1,3 @@
-from array import array
 from itertools import count
 from typing import Callable, Dict, List, Optional
 from typing import Sequence as GenericSequence
@@ -11,8 +10,7 @@ from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.utils import set_random_seed
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
-                           CompletionSequenceGroupOutput, Logprob,
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SequenceData, SequenceGroupMetadata, SequenceOutput)
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.worker.cache_engine import CacheEngine
@@ -138,12 +136,8 @@ def create_seq_group_metadata_from_prompts(
             request_id=str(i),
             is_prompt=len(cont_token_ids) == 0,
             seq_data={
-                i:
-                SequenceData(
-                    array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]),
-                    _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                            cont_token_ids[:]),
-                ),
+                i: SequenceData.from_seqs(prompt_token_ids[:],
+                                          cont_token_ids[:]),
             },
             sampling_params=SamplingParams(temperature=0.0, ),
             block_tables={i: block_allocations[i][:]},
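
The test helper above now passes both the prompt and the continuation tokens to from_seqs instead of assembling typed arrays itself. A minimal sketch of the equivalence, assuming vLLM at this commit is importable (the token values are illustrative only):

# Sketch: equivalent constructions of a SequenceData that already has output tokens.
from array import array
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

prompt_token_ids = [1, 2, 3]
cont_token_ids = [4, 5]

# Old form: build the "l"-typed arrays by hand.
old = SequenceData(
    array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids),
    _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, cont_token_ids),
)

# New form: the factory builds the same arrays internally.
new = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)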

View File

@@ -1,5 +1,4 @@
 import random
-from array import array
 from typing import Tuple
 from unittest.mock import patch
@@ -9,8 +8,7 @@ import torch
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import is_pin_memory_available
@@ -71,9 +69,7 @@ def test_logits_processors(seed: int, device: str):
         SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
-            seq_data={
-                0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-            },
+            seq_data={0: SequenceData.from_seqs([1, 2, 3])},
             sampling_params=SamplingParams(temperature=0,
                                            logits_processors=[pick_ith]),
             block_tables={0: [1]},

View File

@@ -1,10 +1,7 @@
-from array import array
 import pytest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
-                           CompletionSequenceGroupOutput, SequenceData,
+from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData,
                            SequenceOutput)
 from .core.utils import create_dummy_prompt
@@ -58,7 +55,7 @@ def test_sampler_output_eq(sample_outputs):
 def test_sequence_data_prefill():
-    seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4]))
+    seq_data = SequenceData.from_seqs([1, 2, 3, 4])
     assert seq_data.get_num_uncomputed_tokens() == 4
     assert seq_data.get_num_computed_tokens() == 0
     # advance by 2

View File

@@ -1,13 +1,11 @@
 import itertools
-from array import array
 from typing import List
 import pytest
 import torch
 from vllm.engine.arg_utils import EngineArgs
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import is_cpu, make_tensor_with_pad
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import _get_graph_batch_size
@@ -119,12 +117,10 @@ def test_prepare_prompt(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
         encoder_seq_lens.append(encoder_seq_len)
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len)))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -317,11 +313,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
@@ -523,11 +517,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
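
As the hunks above show, from_seqs accepts any sequence of ints, so the tests pass range(seq_len) straight through; a one-line sketch, assuming this commit:

# Sketch: a range works because array() consumes any iterable of ints.
from vllm.sequence import SequenceData

seq_data = SequenceData.from_seqs(range(8))  # prompt tokens 0..7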

View File

@@ -1,4 +1,3 @@
-from array import array
 from typing import List
 import pytest
@@ -8,8 +7,7 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                              init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import get_open_port
 from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
@@ -48,8 +46,7 @@ def test_prepare_prompt(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -166,8 +163,7 @@ def test_prepare_decode_cuda_graph(batch_size):
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
         context_lens.append(context_len)
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)))
+        seq_data = SequenceData.from_seqs(range(context_len))
         seq_data.update_num_computed_tokens(context_len)
         # Append one token ID since prefill is finished.
         seq_data.append_token_id(1, 0)
@@ -326,8 +322,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -343,8 +338,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     for i in range(prefill_batch_size, batch_size):
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
-        prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))
-        seq_data = SequenceData(prompt_toks)
+        seq_data = SequenceData.from_seqs(range(context_len))
         seq_data.append_token_id(1, 0)
         seq_data.update_num_computed_tokens(context_len)
         seq_group_metadata = SequenceGroupMetadata(

View File

@@ -1,5 +1,4 @@
 import functools
-from array import array
 from collections import UserDict
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
@@ -22,10 +21,6 @@ logger = init_logger(__name__)
 C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
-# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
-# We cannot import it here because of circular dependencies.
-VLLM_TOKEN_ID_ARRAY_TYPE = "l"
 @dataclass(frozen=True)
 class InputContext:
@@ -130,8 +125,7 @@ class InputRegistry:
         # Avoid circular import
         from vllm.sequence import SequenceData
-        dummy_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)
+        dummy_seq_data = SequenceData.from_counts({0: seq_len})
         dummy_multi_modal_data = None
         return dummy_seq_data, dummy_multi_modal_data
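
The dummy-data path swaps the array-multiplication trick for from_counts; a hedged sketch of the equivalence, with seq_len chosen arbitrarily:

# Sketch: old vs. new construction of dummy profiling data.
from array import array
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

seq_len = 16

# Old form: seq_len copies of token 0 via array multiplication.
old = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)

# New form: the {token_id: count} mapping expresses the same prompt.
new = SequenceData.from_counts({0: seq_len})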

View File

@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
 from array import array
 from collections import defaultdict
 from dataclasses import dataclass
+from functools import cached_property, reduce
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
 from typing import Sequence as GenericSequence
 from typing import Set, Tuple, Union, cast
@@ -169,6 +170,35 @@ class SequenceData(msgspec.Struct,
     # It is used to compute mrope_position_ids.
     _mrope_position_delta: Optional[int] = None
+    @staticmethod
+    def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData":
+        if len(counts_by_token) == 0:
+            return SequenceData.from_seqs([])
+
+        arrs = [
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
+            for token_id, count in counts_by_token.items()
+        ]
+        return SequenceData(reduce(array.__add__, arrs))
+
+    @staticmethod
+    def from_seqs(
+        prompt_token_ids: GenericSequence[int],
+        output_token_ids: Optional[GenericSequence[int]] = None,
+    ) -> "SequenceData":
+        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                     prompt_token_ids)
+
+        if output_token_ids is None:
+            return SequenceData(prompt_token_ids_arr)
+
+        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                     output_token_ids)
+        return SequenceData(prompt_token_ids_arr,
+                            _output_token_ids=output_token_ids_arr)
+
     def __post_init__(self) -> None:
         assert self._prompt_token_ids.typecode == "l"
         assert self._output_token_ids.typecode == "l"
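
Two behaviours follow from the implementation above, shown here as a sketch (token values are illustrative): an empty mapping degenerates to an empty prompt, and multiple entries expand per token and concatenate in dict insertion order via reduce.

# Sketch of from_counts behaviour as implemented above.
from vllm.sequence import SequenceData

empty = SequenceData.from_counts({})             # falls back to from_seqs([])
assert empty.get_len() == 0

mixed = SequenceData.from_counts({7: 2, 9: 1})   # prompt tokens 7, 7, 9
assert mixed.get_len() == 3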
@@ -370,8 +400,6 @@ class Sequence:
         self.lora_request = lora_request
         self.prompt_adapter_request = prompt_adapter_request
         self.from_decoder_prompt = from_decoder_prompt
-        self._prompt: Optional[str] = None
-        self._prompt_token_ids: Optional[List[int]] = None
         # For decoder-only models, a Sequence is constructed
         # from an LLMInputs instance (the `inputs` arg.)
@@ -400,8 +428,7 @@ class Sequence:
                 f"invalid input {inputs}; did you forget the "
                 "encoder input prompt fields?")
-        self.data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids))
+        self.data = SequenceData.from_seqs(self.prompt_token_ids)
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
@@ -422,37 +449,23 @@ class Sequence:
     def n_blocks(self) -> int:
         return (self.get_len() + self.block_size - 1) // self.block_size
-    @property
+    @cached_property
     def prompt(self) -> Optional[str]:
-        if self._prompt is not None:
-            # Reuse precomputed prompt string
-            return self._prompt
-        # Select decoder or encoder input prompt str,
-        # as appropriate
+        # Select decoder or encoder input prompt str, as appropriate
         prompt_key: str = ("prompt"
                            if self.from_decoder_prompt else "encoder_prompt")
-        # Cache prompt
-        self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
-        return self._prompt
+        return cast(Optional[str], self.inputs.get(prompt_key))
-    @property
+    @cached_property
     def prompt_token_ids(self) -> List[int]:
-        if self._prompt_token_ids is not None:
-            # Reuse precomputed prompt token ids
-            return self._prompt_token_ids
-        # Select decoder or encoder input prompt
-        # token ids, as appropriate
+        # Select decoder or encoder input prompt token ids, as appropriate
         prompt_token_ids_key: str = ("prompt_token_ids"
                                      if self.from_decoder_prompt else
                                      "encoder_prompt_token_ids")
         # Cache computed prompt token ids
-        self._prompt_token_ids = cast(List[int],
-                                      self.inputs.get(prompt_token_ids_key))
-        return self._prompt_token_ids
+        return cast(List[int], self.inputs.get(prompt_token_ids_key))
     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
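
The prompt and prompt_token_ids properties now rely on functools.cached_property instead of the hand-rolled _prompt / _prompt_token_ids caching: the getter runs once and the result is stored on the instance. A standalone sketch of that pattern (illustrative only, not vLLM code):

from functools import cached_property


class Example:
    """Illustrative only: the caching pattern adopted above."""

    def __init__(self, inputs: dict):
        self.inputs = inputs

    @cached_property
    def prompt(self):
        # Runs once; the result is stored in the instance __dict__ and
        # returned directly on later accesses, replacing the manual
        # `if self._prompt is not None: return self._prompt` check.
        print("computed")
        return self.inputs.get("prompt")


ex = Example({"prompt": "hello"})
ex.prompt  # prints "computed"
ex.prompt  # cached; no recomputation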