[Core] Factor out common code in SequenceData and Sequence (#8675)
parent d4bf085ad0
commit 0455c46ed4
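This change factors the low-level token-array handling out of every call site and into SequenceData itself: instead of constructing SequenceData from a hand-built array(VLLM_TOKEN_ID_ARRAY_TYPE, ...), callers now use the new SequenceData.from_seqs() and SequenceData.from_counts() factory methods, and Sequence drops its manual prompt caching in favor of functools.cached_property. A minimal before/after sketch of the call-site pattern (token ids are arbitrary, taken from the test hunks below):

    # Before: callers built the typed array themselves.
    from array import array
    from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

    seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))

    # After: the array construction lives inside SequenceData.
    seq_data = SequenceData.from_seqs([1, 2, 3])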
@@ -1,6 +1,5 @@
 import itertools
 import random
-from array import array
 from typing import Dict, List, Optional, Tuple
 from unittest.mock import Mock, patch
 
@@ -12,8 +11,7 @@ import vllm.envs as envs
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import Counter, is_pin_memory_available
 
 
@@ -59,9 +57,7 @@ def _do_sample(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
@@ -205,9 +201,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
         return sampling_params
 
     def create_sequence_data(num_input=3, num_generated=0):
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                  random.choices(range(0, VOCAB_SIZE), k=num_input)))
+        seq_data = SequenceData.from_seqs(
+            random.choices(range(0, VOCAB_SIZE), k=num_input))
         if num_generated > 0:
             seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                        k=num_generated)
@@ -511,9 +506,7 @@ def test_sampler_mixed(seed: int, device: str):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
@@ -613,9 +606,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=SamplingParams(
                     temperature=1,
                     top_k=top_k,
@@ -699,11 +690,7 @@ def test_sampler_repetition_penalty_mixed(device: str):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0:
-                    SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                       [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=sampling_params[i],
                 block_tables={0: [1]},
             ))

@@ -1,4 +1,3 @@
-from array import array
 from itertools import count
 from typing import Callable, Dict, List, Optional
 from typing import Sequence as GenericSequence
@@ -11,8 +10,7 @@ from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.utils import set_random_seed
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
-                           CompletionSequenceGroupOutput, Logprob,
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SequenceData, SequenceGroupMetadata, SequenceOutput)
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.worker.cache_engine import CacheEngine
@@ -138,12 +136,8 @@ def create_seq_group_metadata_from_prompts(
             request_id=str(i),
             is_prompt=len(cont_token_ids) == 0,
             seq_data={
-                i:
-                SequenceData(
-                    array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]),
-                    _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                            cont_token_ids[:]),
-                ),
+                i: SequenceData.from_seqs(prompt_token_ids[:],
+                                          cont_token_ids[:]),
             },
             sampling_params=SamplingParams(temperature=0.0, ),
             block_tables={i: block_allocations[i][:]},

@@ -1,5 +1,4 @@
 import random
-from array import array
 from typing import Tuple
 from unittest.mock import patch
 
@@ -9,8 +8,7 @@ import torch
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import is_pin_memory_available
 
 
@@ -71,9 +69,7 @@ def test_logits_processors(seed: int, device: str):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=SamplingParams(temperature=0,
                                                logits_processors=[pick_ith]),
                 block_tables={0: [1]},

@@ -1,10 +1,7 @@
-from array import array
-
 import pytest
 
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
-                           CompletionSequenceGroupOutput, SequenceData,
+from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData,
                            SequenceOutput)
 
 from .core.utils import create_dummy_prompt
@@ -58,7 +55,7 @@ def test_sampler_output_eq(sample_outputs):
 
 
 def test_sequence_data_prefill():
-    seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4]))
+    seq_data = SequenceData.from_seqs([1, 2, 3, 4])
     assert seq_data.get_num_uncomputed_tokens() == 4
     assert seq_data.get_num_computed_tokens() == 0
     # advance by 2

@@ -1,13 +1,11 @@
 import itertools
-from array import array
 from typing import List
 
 import pytest
 import torch
 
 from vllm.engine.arg_utils import EngineArgs
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import is_cpu, make_tensor_with_pad
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import _get_graph_batch_size
@@ -119,12 +117,10 @@ def test_prepare_prompt(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
         encoder_seq_lens.append(encoder_seq_len)
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len)))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -317,11 +313,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
 
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
@@ -523,11 +517,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,

@@ -1,4 +1,3 @@
-from array import array
 from typing import List
 
 import pytest
@@ -8,8 +7,7 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                              init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import get_open_port
 from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
 
@@ -48,8 +46,7 @@ def test_prepare_prompt(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -166,8 +163,7 @@ def test_prepare_decode_cuda_graph(batch_size):
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
         context_lens.append(context_len)
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)))
+        seq_data = SequenceData.from_seqs(range(context_len))
         seq_data.update_num_computed_tokens(context_len)
         # Append one token ID since prefill is finished.
         seq_data.append_token_id(1, 0)
@@ -326,8 +322,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -343,8 +338,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     for i in range(prefill_batch_size, batch_size):
         # make sure all tokens fit into one block
        context_len = i % (model_runner.block_size - 1) + 1
-        prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))
-        seq_data = SequenceData(prompt_toks)
+        seq_data = SequenceData.from_seqs(range(context_len))
         seq_data.append_token_id(1, 0)
         seq_data.update_num_computed_tokens(context_len)
         seq_group_metadata = SequenceGroupMetadata(

@@ -1,5 +1,4 @@
 import functools
-from array import array
 from collections import UserDict
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
@@ -22,10 +21,6 @@ logger = init_logger(__name__)
 
 C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
 
-# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
-# We cannot import it here because of circular dependencies.
-VLLM_TOKEN_ID_ARRAY_TYPE = "l"
-
 
 @dataclass(frozen=True)
 class InputContext:
@@ -130,8 +125,7 @@ class InputRegistry:
         # Avoid circular import
         from vllm.sequence import SequenceData
 
-        dummy_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)
+        dummy_seq_data = SequenceData.from_counts({0: seq_len})
         dummy_multi_modal_data = None
 
         return dummy_seq_data, dummy_multi_modal_data

@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
 from array import array
 from collections import defaultdict
 from dataclasses import dataclass
+from functools import cached_property, reduce
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
 from typing import Sequence as GenericSequence
 from typing import Set, Tuple, Union, cast
@@ -169,6 +170,35 @@ class SequenceData(msgspec.Struct,
     # It is used to compute mrope_position_ids.
     _mrope_position_delta: Optional[int] = None
 
+    @staticmethod
+    def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData":
+        if len(counts_by_token) == 0:
+            return SequenceData.from_seqs([])
+
+        arrs = [
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
+            for token_id, count in counts_by_token.items()
+        ]
+
+        return SequenceData(reduce(array.__add__, arrs))
+
+    @staticmethod
+    def from_seqs(
+        prompt_token_ids: GenericSequence[int],
+        output_token_ids: Optional[GenericSequence[int]] = None,
+    ) -> "SequenceData":
+        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                     prompt_token_ids)
+
+        if output_token_ids is None:
+            return SequenceData(prompt_token_ids_arr)
+
+        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                     output_token_ids)
+
+        return SequenceData(prompt_token_ids_arr,
+                            _output_token_ids=output_token_ids_arr)
+
     def __post_init__(self) -> None:
         assert self._prompt_token_ids.typecode == "l"
         assert self._output_token_ids.typecode == "l"
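The two factory methods added above are the heart of the refactor: from_seqs() converts plain Python sequences of token ids into the internal array("l", ...) representation, optionally including already-generated output tokens, while from_counts() builds a prompt out of repeated token ids. A short usage sketch (values are arbitrary, not taken from the commit):

    from vllm.sequence import SequenceData

    # Prompt-only data, e.g. for a prefill test case.
    prefill = SequenceData.from_seqs([1, 2, 3, 4])

    # Prompt plus already-generated output tokens.
    decode = SequenceData.from_seqs([1, 2, 3], output_token_ids=[7, 8])

    # 16 copies of token id 0 -- the pattern the InputRegistry hunk earlier
    # in this diff uses to create dummy profiling data.
    dummy = SequenceData.from_counts({0: 16})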
@@ -370,8 +400,6 @@ class Sequence:
         self.lora_request = lora_request
         self.prompt_adapter_request = prompt_adapter_request
         self.from_decoder_prompt = from_decoder_prompt
-        self._prompt: Optional[str] = None
-        self._prompt_token_ids: Optional[List[int]] = None
 
         # For decoder-only models, a Sequence is constructed
         # from an LLMInputs instance (the `inputs` arg.)
@@ -400,8 +428,7 @@ class Sequence:
                             f"invalid input {inputs}; did you forget the "
                             "encoder input prompt fields?")
 
-        self.data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids))
+        self.data = SequenceData.from_seqs(self.prompt_token_ids)
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
 
@@ -422,37 +449,23 @@ class Sequence:
     def n_blocks(self) -> int:
         return (self.get_len() + self.block_size - 1) // self.block_size
 
-    @property
+    @cached_property
     def prompt(self) -> Optional[str]:
-        if self._prompt is not None:
-            # Reuse precomputed prompt string
-            return self._prompt
-
-        # Select decoder or encoder input prompt str,
-        # as appropriate
+        # Select decoder or encoder input prompt str, as appropriate
         prompt_key: str = ("prompt"
                            if self.from_decoder_prompt else "encoder_prompt")
 
-        # Cache prompt
-        self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
-        return self._prompt
+        return cast(Optional[str], self.inputs.get(prompt_key))
 
-    @property
+    @cached_property
     def prompt_token_ids(self) -> List[int]:
-        if self._prompt_token_ids is not None:
-            # Reuse precomputed prompt token ids
-            return self._prompt_token_ids
-
-        # Select decoder or encoder input prompt
-        # token ids, as appropriate
+        # Select decoder or encoder input prompt token ids, as appropriate
         prompt_token_ids_key: str = ("prompt_token_ids"
                                      if self.from_decoder_prompt else
                                      "encoder_prompt_token_ids")
 
         # Cache computed prompt token ids
-        self._prompt_token_ids = cast(List[int],
-                                      self.inputs.get(prompt_token_ids_key))
-        return self._prompt_token_ids
+        return cast(List[int], self.inputs.get(prompt_token_ids_key))
 
     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
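Note on the last hunk: the hand-rolled _prompt / _prompt_token_ids caches removed earlier are replaced by functools.cached_property, which computes the value on first access and stores it on the instance so later accesses skip recomputation. An illustrative standard-library sketch (not vLLM code):

    from functools import cached_property

    class Example:
        @cached_property
        def value(self) -> int:
            print("computed once")
            return 42

    e = Example()
    e.value  # prints "computed once"; the result is stored in e.__dict__
    e.value  # served from the cache without re-running the method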