[2/N] Chunked prefill data update (#3538)

This commit is contained in:
SangBin Cho 2024-03-29 02:06:01 +09:00 committed by GitHub
parent ce567a2926
commit b51c1cc9d2
11 changed files with 272 additions and 76 deletions

View File

@@ -26,7 +26,9 @@ def main(args: argparse.Namespace):
kv_cache_dtype=args.kv_cache_dtype,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
download_dir=args.download_dir)
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
block_size=args.block_size)
sampling_params = SamplingParams(
n=args.n,
@@ -145,6 +147,16 @@ if __name__ == '__main__':
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument('--block-size',
type=int,
default=16,
help='block size of key/value cache')
parser.add_argument(
'--enable-chunked-prefill',
action='store_true',
help='If set, prefill requests can be chunked based on '
'max_num_batched_tokens.')
parser.add_argument(
"--ray-workers-use-nsight",
action='store_true',

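The two new flags are simply forwarded to the `LLM` constructor. A minimal programmatic equivalent is sketched below; the model name and prompt are placeholders, not part of this change.

```python
from vllm import LLM, SamplingParams

# "facebook/opt-125m" is a placeholder model; any supported model works.
llm = LLM(model="facebook/opt-125m",
          enable_chunked_prefill=True,  # forwarded through EngineArgs
          block_size=16)                # key/value cache block size
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=16))
```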
View File

@@ -256,6 +256,8 @@ class VllmRunner:
dtype: str = "half",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
block_size: int = 16,
enable_chunked_prefill: bool = False,
**kwargs,
) -> None:
self.model = LLM(
@@ -266,6 +268,8 @@
swap_space=0,
disable_log_stats=disable_log_stats,
tensor_parallel_size=tensor_parallel_size,
block_size=block_size,
enable_chunked_prefill=enable_chunked_prefill,
**kwargs,
)

View File

@@ -10,6 +10,10 @@ from vllm.sequence import Logprob, SequenceGroup
from .utils import create_dummy_prompt
def get_sequence_groups(scheduler_output):
return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
def test_scheduler_add_seq_group():
block_size = 4
scheduler_config = SchedulerConfig(100, 64, 1)
@@ -57,9 +61,9 @@ def test_scheduler_schedule_simple():
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []
# Add seq groups to scheduler.
running: List[SequenceGroup] = []
for i in range(num_seq_group):
_, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
scheduler.add_seq_group(seq_group)
@@ -68,7 +72,7 @@ def test_scheduler_schedule_simple():
# Schedule seq groups prompts.
num_tokens = block_size * num_seq_group
seq_group_meta, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set(running)
assert set(get_sequence_groups(out)) == set(running)
assert out.num_batched_tokens == num_tokens
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
@@ -76,7 +80,7 @@ def test_scheduler_schedule_simple():
# Schedule seq groups generation.
seq_group_meta, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set(running)
assert set(get_sequence_groups(out)) == set(running)
assert out.num_batched_tokens == num_seq_group
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
@@ -100,7 +104,7 @@ def test_scheduler_schedule_preempt_abort():
# Schedule seq groups prompts.
seq_group_meta, out = scheduler.schedule()
assert out.scheduled_seq_groups == [seq_group_a, seq_group_b]
assert get_sequence_groups(out) == [seq_group_a, seq_group_b]
assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
@@ -115,7 +119,7 @@ def test_scheduler_schedule_preempt_abort():
# Schedule seq groups generation and preempt seq group b.
seq_group_meta, out = scheduler.schedule()
assert out.scheduled_seq_groups == [seq_group_a]
assert get_sequence_groups(out) == [seq_group_a]
assert out.num_batched_tokens == 1
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
@@ -125,7 +129,7 @@ def test_scheduler_schedule_preempt_abort():
# Abort seq group a. Re-schedule seq group b prompt with recomputation.
scheduler.abort_seq_group("1")
seq_group_meta, out = scheduler.schedule()
assert out.scheduled_seq_groups == [seq_group_b]
assert get_sequence_groups(out) == [seq_group_b]
assert out.num_batched_tokens == 5 # 4 prompt + 1 generation.
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
@@ -155,11 +159,11 @@ def test_scheduler_max_seqs():
# Schedule seq groups prompts.
_, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]])
assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])
# Schedule seq groups generation.
_, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]])
assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])
# Append 2 more seq group
scheduler.add_seq_group(all_seq_groups[1])
@@ -169,7 +173,7 @@ def test_scheduler_max_seqs():
# Only 1 seq group should be scheduled since max_seq_group is 2
# and one is prompting.
_, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set([all_seq_groups[1]])
assert set(get_sequence_groups(out)) == set([all_seq_groups[1]])
def test_scheduler_delay_factor():

View File

@@ -1,6 +1,7 @@
import pytest
from vllm.sequence import SamplerOutput, SequenceGroupOutput, SequenceOutput
from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput,
SequenceOutput)
@pytest.fixture
@@ -48,3 +49,24 @@ def test_sampler_output_eq(sample_outputs):
sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1])
assert sampler_output1 == sampler_output2
assert sampler_output1 != sampler_output3
def test_sequence_data_prefill():
seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
assert seq_data.get_num_uncomputed_tokens() == 4
assert seq_data.get_num_computed_tokens() == 0
# advance by 2
seq_data.update_num_computed_tokens(2)
assert seq_data.get_num_uncomputed_tokens() == 2
assert seq_data.get_num_computed_tokens() == 2
# advance by 1
seq_data.update_num_computed_tokens(1)
assert seq_data.get_num_uncomputed_tokens() == 1
assert seq_data.get_num_computed_tokens() == 3
# append tokens and reset, simulating recompute
seq_data.append_token_id(1, logprob=0.0)
seq_data.reset_num_computed_tokens()
assert seq_data.get_num_uncomputed_tokens() == 5
assert seq_data.get_num_computed_tokens() == 0

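These counters are exactly what a chunked-prefill loop consumes: each step takes at most a fixed budget of uncomputed tokens and advances the counter by the amount actually run. A rough sketch using only the API exercised in the test above (the budget of 4 is arbitrary):

```python
from vllm.sequence import SequenceData

seq_data = SequenceData(prompt_token_ids=list(range(10)))
chunk_budget = 4  # arbitrary per-step token budget
while seq_data.get_num_uncomputed_tokens() > 0:
    chunk = min(chunk_budget, seq_data.get_num_uncomputed_tokens())
    # ... the model would run on `chunk` prompt tokens here ...
    seq_data.update_num_computed_tokens(chunk)
assert seq_data.get_num_computed_tokens() == 10
```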
View File

@@ -18,15 +18,16 @@ def test_prepare_prompt(batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = list(range(prompt_len))
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData(seq_data)},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
))
seq_data = SequenceData(list(range(prompt_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
)
assert seq_group_metadata.token_chunk_size == seq_data.get_len()
seq_group_metadata_list.append(seq_group_metadata)
expected_selected_token_indices = []
selected_token_start_idx = 0
@@ -131,14 +132,16 @@ def test_prepare_decode_cuda_graph(batch_size):
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = list(range(prompt_len))
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
seq_data={0: SequenceData(seq_data)},
sampling_params=SamplingParams(temperature=0),
block_tables={0: [1]},
))
seq_data = SequenceData(seq_data)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables={0: [1]},
)
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)
input_tokens, input_positions, attn_metadata, _, _, _ = (
model_runner._prepare_decode(seq_group_metadata_list))

View File

@@ -533,6 +533,8 @@ class SchedulerConfig:
delay_factor: Apply a delay (of delay factor multiplied by previous
prompt latency) before scheduling next prompt.
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
"""
def __init__(
@@ -542,6 +544,7 @@
max_model_len: int,
use_v2_block_manager: bool = False,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
@@ -553,6 +556,7 @@
self.max_model_len = max_model_len
self.delay_factor = delay_factor
self.use_v2_block_manager = use_v2_block_manager
self.chunked_prefill_enabled = enable_chunked_prefill
self._verify_args()
def _verify_args(self) -> None:

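For reference, a minimal sketch of how the new flag surfaces on the config object; the numeric values are arbitrary and only chosen so the existing argument checks pass.

```python
from vllm.config import SchedulerConfig

# Arbitrary values, chosen so max_num_batched_tokens covers both
# max_model_len and max_num_seqs.
scheduler_config = SchedulerConfig(max_num_batched_tokens=64,
                                   max_num_seqs=2,
                                   max_model_len=16,
                                   enable_chunked_prefill=True)
assert scheduler_config.chunked_prefill_enabled
```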
View File

@@ -1,6 +1,7 @@
import enum
import time
from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
@@ -27,11 +28,24 @@ class PreemptionMode(enum.Enum):
RECOMPUTE = enum.auto()
# seq_group: SequenceGroup to schedule.
# token_chunk_size: The number of prefill tokens to be processed in the next
# step.
@dataclass
class ScheduledSequenceGroup:
# A sequence group that's scheduled.
seq_group: SequenceGroup
# The total chunk size (number of tokens) to process in the next iteration.
# 1 for decoding. For prefill, it equals the number of prompt tokens, but it
# can be smaller than that when the prefill is chunked.
token_chunk_size: int
class SchedulerOutputs:
def __init__(
self,
scheduled_seq_groups: Iterable[SequenceGroup],
scheduled_seq_groups: Iterable[ScheduledSequenceGroup],
prompt_run: bool,
num_batched_tokens: int,
blocks_to_swap_in: Dict[int, int],
@@ -39,17 +53,41 @@
blocks_to_copy: Dict[int, List[int]],
ignored_seq_groups: List[SequenceGroup],
) -> None:
self.scheduled_seq_groups = scheduled_seq_groups
self.prompt_run = prompt_run
self.num_batched_tokens = num_batched_tokens
self.blocks_to_swap_in = blocks_to_swap_in
self.blocks_to_swap_out = blocks_to_swap_out
self.blocks_to_copy = blocks_to_copy
"""A list of sequence groups to be scheduled as a single batch.
Args:
scheduled_seq_groups: The scheduled sequence groups, each paired
with its token chunk size.
prompt_run: True if all sequence groups are in prefill phase.
If False, all sequence groups are in decoding phase.
num_batched_tokens: Total number of batched tokens.
blocks_to_swap_in: Blocks to swap in. Dict of CPU -> GPU block
number.
blocks_to_swap_out: Blocks to swap out. Dict of GPU -> CPU block
number.
blocks_to_copy: Blocks to copy. Source to a list of dest blocks.
ignored_seq_groups: Sequence groups that are going to be ignored.
"""
# Scheduled sequence groups, each paired with its token chunk size.
self.scheduled_seq_groups: Iterable[ScheduledSequenceGroup] = (
scheduled_seq_groups)
# True if all sequence groups are in prefill phase. If False, all
# sequence groups are in decoding phase.
self.prompt_run: bool = prompt_run
# Total number of batched tokens.
self.num_batched_tokens: int = num_batched_tokens
# Blocks to swap in. Dict of CPU -> GPU block number.
self.blocks_to_swap_in: Dict[int, int] = blocks_to_swap_in
# Blocks to swap out. Dict of GPU -> CPU block number.
self.blocks_to_swap_out: Dict[int, int] = blocks_to_swap_out
# Blocks to copy. Source to a list of dest blocks.
self.blocks_to_copy: Dict[int, List[int]] = blocks_to_copy
# Sequence groups that are going to be ignored.
self.ignored_seq_groups: List[SequenceGroup] = ignored_seq_groups
# Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out)
self.ignored_seq_groups = ignored_seq_groups
self.num_loras = len(self.lora_requests)
self.num_loras: int = len(self.lora_requests)
if self.num_loras > 0:
self._sort_by_lora_ids()
@@ -59,13 +97,13 @@
and not self.blocks_to_swap_out and not self.blocks_to_copy)
def _sort_by_lora_ids(self) -> bool:
self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
key=lambda g:
(g.lora_int_id, g.request_id))
self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups,
key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
@property
def lora_requests(self) -> Set[LoRARequest]:
return {g.lora_request for g in self.scheduled_seq_groups}
return {g.seq_group.lora_request for g in self.scheduled_seq_groups}
class Scheduler:
@@ -198,11 +236,13 @@ class Scheduler:
assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
num_prompt_tokens = waiting_seqs[0].get_len()
if num_prompt_tokens > self.prompt_limit:
# get_len includes output tokens if the request has been
# preempted.
num_prefill_tokens = waiting_seqs[0].get_len()
if num_prefill_tokens > self.prompt_limit:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds limit of {self.prompt_limit}")
f"Input prompt ({num_prefill_tokens} tokens) is too "
f"long and exceeds limit of {self.prompt_limit}")
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
@@ -215,8 +255,8 @@
break
elif can_allocate == AllocStatus.NEVER:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds the capacity of block_manager")
f"Input prompt ({num_prefill_tokens} tokens) is too "
f"long and exceeds the capacity of block_manager")
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
@@ -235,7 +275,7 @@
continue
# If the number of batched tokens exceeds the limit, stop.
num_batched_tokens += num_prompt_tokens
num_batched_tokens += num_prefill_tokens
if (num_batched_tokens >
self.scheduler_config.max_num_batched_tokens):
break
@@ -253,8 +293,10 @@
self._allocate(seq_group)
self.running.append(seq_group)
num_curr_seqs += num_new_seqs
scheduled.append(seq_group)
scheduled.append(
ScheduledSequenceGroup(
seq_group=seq_group,
token_chunk_size=num_prefill_tokens))
self.waiting.extendleft(leftover_waiting_sequences)
if scheduled or ignored_seq_groups:
@@ -352,7 +394,11 @@
for seq_group in self.running)
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=self.running,
scheduled_seq_groups=[
ScheduledSequenceGroup(seq_group=running_group,
token_chunk_size=1)
for running_group in self.running
],
prompt_run=False,
num_batched_tokens=num_batched_tokens,
blocks_to_swap_in=blocks_to_swap_in,
@@ -371,10 +417,14 @@
# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []
for seq_group in scheduler_outputs.scheduled_seq_groups:
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
seq_group = scheduled_seq_group.seq_group
token_chunk_size = scheduled_seq_group.token_chunk_size
seq_group.maybe_set_first_scheduled_time(now)
# seq_id -> SequenceData
seq_data: Dict[int, SequenceData] = {}
# seq_id -> physical block numbers
block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -393,6 +443,7 @@
seq_data=seq_data,
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums,
state=seq_group.state,
@@ -409,8 +460,9 @@
# batch will have been computed before the next scheduling invocation.
# This is because the engine assumes that a failure in model execution
# will crash the vLLM instance / will not retry.
for seq_group in scheduler_outputs.scheduled_seq_groups:
self.block_manager.mark_blocks_as_computed(seq_group)
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group)
return seq_group_metadata_list, scheduler_outputs
@@ -418,6 +470,7 @@
self.block_manager.fork(parent_seq, child_seq)
def free_seq(self, seq: Sequence) -> None:
"""Free a sequence from a block table."""
self.block_manager.free(seq)
def free_finished_seq_groups(self) -> None:
@@ -480,7 +533,8 @@
assert len(seqs) == 1
for seq in seqs:
seq.status = SequenceStatus.WAITING
self.block_manager.free(seq)
self.free_seq(seq)
seq.reset_state_for_recompute()
# NOTE: For FCFS, we insert the preempted sequence group to the front
# of the waiting queue.
self.waiting.appendleft(seq_group)

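Downstream consumers of `SchedulerOutputs` now unwrap each entry instead of treating it as a bare `SequenceGroup`. The new access pattern, sketched under the assumption that `scheduler` is an already-initialized `Scheduler`:

```python
# `scheduler` is assumed to be a fully constructed Scheduler instance.
seq_group_metadata_list, scheduler_outputs = scheduler.schedule()
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
    seq_group = scheduled_seq_group.seq_group          # the SequenceGroup itself
    chunk_size = scheduled_seq_group.token_chunk_size  # 1 for decode steps
    print(seq_group.request_id, chunk_size)
```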
View File

@@ -62,6 +62,7 @@ class EngineArgs:
image_input_shape: Optional[str] = None
image_feature_size: Optional[int] = None
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False
def __post_init__(self):
if self.tokenizer is None:
@@ -356,6 +357,12 @@ class EngineArgs:
default=EngineArgs.scheduler_delay_factor,
help='Apply a delay (of delay factor multiplied by previous'
'prompt latency) before scheduling next prompt.')
parser.add_argument(
'--enable-chunked-prefill',
action='store_true',
help='If set, prefill requests can be chunked based on '
'max_num_batched_tokens.')
return parser
@classmethod
@@ -394,11 +401,14 @@ class EngineArgs:
self.tokenizer_pool_type,
self.tokenizer_pool_extra_config,
), self.ray_workers_use_nsight)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
self.use_v2_block_manager,
self.scheduler_delay_factor)
scheduler_config = SchedulerConfig(
self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
self.use_v2_block_manager,
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,

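The same flag can also be set programmatically through `EngineArgs`. A rough sketch, assuming the existing `create_engine_configs()` helper and a placeholder model name:

```python
from vllm.engine.arg_utils import EngineArgs

# The model name is a placeholder; any supported model works.
engine_args = EngineArgs(model="facebook/opt-125m",
                         enable_chunked_prefill=True,
                         block_size=16)
# create_engine_configs() builds the SchedulerConfig shown above, with
# enable_chunked_prefill forwarded as a keyword argument.
configs = engine_args.create_engine_configs()
```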
View File

@@ -553,7 +553,10 @@ class LLMEngine:
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
for seq_group, outputs in zip(scheduled_seq_groups, output):
for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output):
seq_group = scheduled_seq_group.seq_group
token_chunk_size = scheduled_seq_group.token_chunk_size
seq_group.update_num_computed_tokens(token_chunk_size)
self._process_sequence_group_outputs(seq_group, outputs)
# Free the finished sequence groups.
@@ -561,7 +564,8 @@
# Create the outputs.
request_outputs: List[RequestOutput] = []
for seq_group in scheduled_seq_groups:
for scheduled_seq_group in scheduled_seq_groups:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
@@ -676,17 +680,20 @@
# Number of Tokens.
if prompt_run:
num_prompt_tokens = sum(
len(seq_group.prompt_token_ids)
for seq_group in scheduler_outputs.scheduled_seq_groups)
len(scheduled_seq_group.seq_group.prompt_token_ids)
for scheduled_seq_group in
scheduler_outputs.scheduled_seq_groups)
num_generation_tokens = sum(
seq_group.num_seqs()
for seq_group in scheduler_outputs.scheduled_seq_groups)
scheduled_seq_group.seq_group.num_seqs()
for scheduled_seq_group in
scheduler_outputs.scheduled_seq_groups)
else:
num_generation_tokens = scheduler_outputs.num_batched_tokens
# Latency Timings.
time_last_iters = []
for seq_group in scheduler_outputs.scheduled_seq_groups:
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
seq_group = scheduled_seq_group.seq_group
# Time since last token.
# (n.b. updates seq_group.metrics.last_token_time)
time_last_iters.append(seq_group.get_last_latency(now))

View File

@@ -113,6 +113,8 @@ class SequenceData:
self.prompt_token_ids = prompt_token_ids
self.output_token_ids = output_token_ids
self.cumulative_logprob = 0.0
# The number of tokens that have been computed (i.e., run through the
# model).
self._num_computed_tokens = 0
def append_token_id(self, token_id: int, logprob: float) -> None:
self.output_token_ids.append(token_id)
@@ -130,6 +132,28 @@ class SequenceData:
def get_token_ids(self) -> List[int]:
return self.prompt_token_ids + self.output_token_ids
def get_num_computed_tokens(self) -> int:
"""Return the number of prefill tokens that are already computed."""
return self._num_computed_tokens
def update_num_computed_tokens(self, num_new_computed_tokens: int) -> None:
"""Update number of tokens computed so far."""
self._num_computed_tokens += num_new_computed_tokens
def reset_num_computed_tokens(self) -> None:
"""Reset the number of computed tokens from this sequence. It is
supposed to be called when a sequence needs to be started from
the beginning again (e.g., sequence is preempted).
"""
self._num_computed_tokens = 0
def get_num_uncomputed_tokens(self) -> int:
"""Return the number of prefill tokens that are not computed."""
# We use `get_len()`, which includes prompt_len + output_len, instead
# of prompt_len here because during recompute we need to prefill both
# the prompt and the output tokens.
return self.get_len() - self.get_num_computed_tokens()
def get_last_token_id(self) -> int:
if not self.output_token_ids:
return self.prompt_token_ids[-1]
@@ -208,6 +232,10 @@ class Sequence:
def num_hashed_tokens_of_block(self, logical_idx: int):
return logical_idx * self.block_size + self.block_size
def reset_state_for_recompute(self):
"""Reset the sequence states for recomputation."""
self.data.reset_num_computed_tokens()
def _append_logical_block(self) -> None:
block = LogicalTokenBlock(
block_number=len(self.logical_token_blocks),
@@ -430,6 +458,18 @@ class SequenceGroup:
def get_finished_seqs(self) -> List[Sequence]:
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
def update_num_computed_tokens(self, num_new_computed_tokens: int):
"""Update number of tokens computed so far."""
for seq in self.seqs_dict.values():
seq.data.update_num_computed_tokens(num_new_computed_tokens)
def get_num_uncomputed_tokens(self) -> int:
# All sequences in the group should have the same prompt, so the
# number of unfinished prefill tokens is the same across all
# sequences.
return list(
self.seqs_dict.values())[0].data.get_num_uncomputed_tokens()
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
return len(self.get_seqs(status))
@@ -473,6 +513,8 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
numbers)
token_chunk_size: The number of tokens to be processed. None if
chunking is not required.
state: Internal state tied to this sequence group.
lora_request: LoRA request.
multi_modal_data: Multi modal data.
@@ -485,6 +527,7 @@ class SequenceGroupMetadata:
seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]],
token_chunk_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
@@ -499,11 +542,23 @@ class SequenceGroupMetadata:
self.computed_block_nums = computed_block_nums
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
self._token_chunk_size = token_chunk_size
if self._token_chunk_size is None:
if is_prompt:
self._token_chunk_size = list(seq_data.values())[0].get_len()
else:
self._token_chunk_size = 1
@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
@property
def token_chunk_size(self) -> int:
"""Return the number of tokens to be processed (chunk size)."""
return self._token_chunk_size
class SequenceOutput:
"""The model output associated with a sequence.

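Putting the defaulting rule together, a small sketch of `token_chunk_size` for a prefill versus a decode metadata object (request ids, token ids, and block tables are arbitrary):

```python
from vllm import SamplingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata

prompt_tokens = [1, 2, 3, 4]
prefill_meta = SequenceGroupMetadata(request_id="0",
                                     is_prompt=True,
                                     seq_data={0: SequenceData(prompt_tokens)},
                                     sampling_params=SamplingParams(),
                                     block_tables={0: [0]})
decode_meta = SequenceGroupMetadata(request_id="1",
                                    is_prompt=False,
                                    seq_data={0: SequenceData(prompt_tokens)},
                                    sampling_params=SamplingParams(),
                                    block_tables={0: [0]})
assert prefill_meta.token_chunk_size == len(prompt_tokens)  # full prompt
assert decode_meta.token_chunk_size == 1  # one token per decode step
```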
View File

@@ -150,39 +150,58 @@ class ModelRunner:
subquery_lens: List[int] = []
prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled
and computed_block_nums is not None):
raise RuntimeError(
"chunked prefill cannot be used with prefix caching "
"now.")
token_chunk_size = seq_group_metadata.token_chunk_size
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
computed_len = seq_data.get_num_computed_tokens()
# We use get_len here because, in the case of preemption, it
# contains the output tokens as well.
prefill_end = min(seq_data.get_len(),
computed_len + token_chunk_size)
# TODO(sang): Rename it after chunked prefill is introduced.
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
prompt_len = len(prompt_tokens)
# Right now, prefill_end is always the same as the length of the
# sequence. However, once chunked prefill is introduced, this
# assumption may no longer hold.
assert prefill_end == seq_data.get_len()
prompt_lens.append(prompt_len)
computed_len = 0
# NOTE: This only works for oooooooxxx style attention.
computed_block_nums = seq_group_metadata.computed_block_nums
if computed_block_nums is not None and len(
computed_block_nums) > 0 and self.sliding_window is None:
# Prefix is not supported with sliding_window
computed_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:]
prefix_block_tables.append(computed_block_nums)
context_len = computed_len
else:
prefix_block_tables.append([])
context_len = 0
# Right now, prefill start is always 0. However, this assumption
# may no longer hold once chunked prefill is introduced.
assert computed_len == 0
# actual prompt lens
context_lens.append(context_len)
context_lens.append(computed_len)
subquery_lens.append(prompt_len - computed_len)
input_tokens.extend(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(
list(range(computed_len, computed_len + len(prompt_tokens))))
input_positions.extend(list(range(computed_len, prefill_end)))
lora_id = seq_group_metadata.lora_int_id
@@ -218,7 +237,8 @@ class ModelRunner:
"Prefix caching is currently not supported with "
"sliding window attention")
start_idx = max(0, prompt_len - self.sliding_window)
for i in range(computed_len, prompt_len):
for i in range(computed_len, prefill_end):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
@@ -331,6 +351,7 @@ class ModelRunner:
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1
seq_ids = list(seq_group_metadata.seq_data.keys())
lora_id = seq_group_metadata.lora_int_id
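The slicing above is the core of chunked prefill on the worker side: only the token ids from `computed_len` up to `prefill_end` are fed to the model, with positions offset accordingly. A stripped-down sketch of the index math in plain Python:

```python
def prefill_window(token_ids, computed_len, token_chunk_size):
    """Return the token slice and positions for the next prefill chunk."""
    prefill_end = min(len(token_ids), computed_len + token_chunk_size)
    chunk_tokens = token_ids[computed_len:prefill_end]
    positions = list(range(computed_len, prefill_end))
    return chunk_tokens, positions

# A 6-token prompt processed with a chunk budget of 4, then the remainder.
tokens = [10, 11, 12, 13, 14, 15]
assert prefill_window(tokens, 0, 4) == ([10, 11, 12, 13], [0, 1, 2, 3])
assert prefill_window(tokens, 4, 4) == ([14, 15], [4, 5])
```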