From b51c1cc9d2f223cfa3aef1426bced10dfde28dbb Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 29 Mar 2024 02:06:01 +0900 Subject: [PATCH] [2/N] Chunked prefill data update (#3538) --- benchmarks/benchmark_latency.py | 14 +++- tests/conftest.py | 4 ++ tests/core/test_scheduler.py | 22 +++--- tests/test_sequence.py | 24 ++++++- tests/worker/test_model_runner.py | 37 +++++----- vllm/config.py | 4 ++ vllm/core/scheduler.py | 108 ++++++++++++++++++++++-------- vllm/engine/arg_utils.py | 20 ++++-- vllm/engine/llm_engine.py | 21 ++++-- vllm/sequence.py | 55 +++++++++++++++ vllm/worker/model_runner.py | 39 ++++++++--- 11 files changed, 272 insertions(+), 76 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 0f223571..da02493b 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -26,7 +26,9 @@ def main(args: argparse.Namespace): kv_cache_dtype=args.kv_cache_dtype, device=args.device, ray_workers_use_nsight=args.ray_workers_use_nsight, - download_dir=args.download_dir) + enable_chunked_prefill=args.enable_chunked_prefill, + download_dir=args.download_dir, + block_size=args.block_size) sampling_params = SamplingParams( n=args.n, @@ -145,6 +147,16 @@ if __name__ == '__main__': default="cuda", choices=["cuda"], help='device type for vLLM execution, supporting CUDA only currently.') + parser.add_argument('--block-size', + type=int, + default=16, + help='block size of key/value cache') + parser.add_argument( + '--enable-chunked-prefill', + type=bool, + default=False, + help='If True, the prefill requests can be chunked based on the ' + 'max_num_batched_tokens') parser.add_argument( "--ray-workers-use-nsight", action='store_true', diff --git a/tests/conftest.py b/tests/conftest.py index 3409f873..cb823893 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -256,6 +256,8 @@ class VllmRunner: dtype: str = "half", disable_log_stats: bool = True, tensor_parallel_size: int = 1, + block_size: int = 16, + enable_chunked_prefill: bool = False, **kwargs, ) -> None: self.model = LLM( @@ -266,6 +268,8 @@ class VllmRunner: swap_space=0, disable_log_stats=disable_log_stats, tensor_parallel_size=tensor_parallel_size, + block_size=block_size, + enable_chunked_prefill=enable_chunked_prefill, **kwargs, ) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index f40969cf..88c2c37f 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -10,6 +10,10 @@ from vllm.sequence import Logprob, SequenceGroup from .utils import create_dummy_prompt +def get_sequence_groups(scheduler_output): + return [s.seq_group for s in scheduler_output.scheduled_seq_groups] + + def test_scheduler_add_seq_group(): block_size = 4 scheduler_config = SchedulerConfig(100, 64, 1) @@ -57,9 +61,9 @@ def test_scheduler_schedule_simple(): cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] # Add seq groups to scheduler. - running: List[SequenceGroup] = [] for i in range(num_seq_group): _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) scheduler.add_seq_group(seq_group) @@ -68,7 +72,7 @@ def test_scheduler_schedule_simple(): # Schedule seq groups prompts. 
num_tokens = block_size * num_seq_group seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(running) + assert set(get_sequence_groups(out)) == set(running) assert out.num_batched_tokens == num_tokens assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -76,7 +80,7 @@ def test_scheduler_schedule_simple(): # Schedule seq groups generation. seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(running) + assert set(get_sequence_groups(out)) == set(running) assert out.num_batched_tokens == num_seq_group assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -100,7 +104,7 @@ def test_scheduler_schedule_preempt_abort(): # Schedule seq groups prompts. seq_group_meta, out = scheduler.schedule() - assert out.scheduled_seq_groups == [seq_group_a, seq_group_b] + assert get_sequence_groups(out) == [seq_group_a, seq_group_b] assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -115,7 +119,7 @@ def test_scheduler_schedule_preempt_abort(): # Schedule seq groups generation and preempt seq group b. seq_group_meta, out = scheduler.schedule() - assert out.scheduled_seq_groups == [seq_group_a] + assert get_sequence_groups(out) == [seq_group_a] assert out.num_batched_tokens == 1 assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -125,7 +129,7 @@ def test_scheduler_schedule_preempt_abort(): # Abort seq group a. Re-schedule seq group b prompt with recomputation. scheduler.abort_seq_group("1") seq_group_meta, out = scheduler.schedule() - assert out.scheduled_seq_groups == [seq_group_b] + assert get_sequence_groups(out) == [seq_group_b] assert out.num_batched_tokens == 5 # 4 prompt + 1 generation. assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -155,11 +159,11 @@ def test_scheduler_max_seqs(): # Schedule seq groups prompts. _, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]]) + assert set(get_sequence_groups(out)) == set([all_seq_groups[0]]) # Schedule seq groups generation. _, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]]) + assert set(get_sequence_groups(out)) == set([all_seq_groups[0]]) # Append 2 more seq group scheduler.add_seq_group(all_seq_groups[1]) @@ -169,7 +173,7 @@ def test_scheduler_max_seqs(): # Only 1 seq group should be scheduled since max_seq_group is 2 # and one is prompting. 
_, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set([all_seq_groups[1]]) + assert set(get_sequence_groups(out)) == set([all_seq_groups[1]]) def test_scheduler_delay_factor(): diff --git a/tests/test_sequence.py b/tests/test_sequence.py index bb6bcddf..1dec9281 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,6 +1,7 @@ import pytest -from vllm.sequence import SamplerOutput, SequenceGroupOutput, SequenceOutput +from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput, + SequenceOutput) @pytest.fixture @@ -48,3 +49,24 @@ def test_sampler_output_eq(sample_outputs): sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1]) assert sampler_output1 == sampler_output2 assert sampler_output1 != sampler_output3 + + +def test_sequence_data_prefill(): + seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4]) + assert seq_data.get_num_uncomputed_tokens() == 4 + assert seq_data.get_num_computed_tokens() == 0 + # advance by 2 + seq_data.update_num_computed_tokens(2) + assert seq_data.get_num_uncomputed_tokens() == 2 + assert seq_data.get_num_computed_tokens() == 2 + + # advance by 1 + seq_data.update_num_computed_tokens(1) + assert seq_data.get_num_uncomputed_tokens() == 1 + assert seq_data.get_num_computed_tokens() == 3 + + # append tokens and reset, simulating recompute + seq_data.append_token_id(1, logprob=0.0) + seq_data.reset_num_computed_tokens() + assert seq_data.get_num_uncomputed_tokens() == 5 + assert seq_data.get_num_computed_tokens() == 0 diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 930ecad3..5b6f001f 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -18,15 +18,16 @@ def test_prepare_prompt(batch_size): # make sure all tokens fit into one block prompt_len = i % (model_runner.block_size - 1) + 1 prompt_lens.append(prompt_len) - seq_data = list(range(prompt_len)) - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData(seq_data)}, - sampling_params=SamplingParams(temperature=0), - block_tables=block_tables, - )) + seq_data = SequenceData(list(range(prompt_len))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + ) + assert seq_group_metadata.token_chunk_size == seq_data.get_len() + seq_group_metadata_list.append(seq_group_metadata) expected_selected_token_indices = [] selected_token_start_idx = 0 @@ -131,14 +132,16 @@ def test_prepare_decode_cuda_graph(batch_size): prompt_len = i % (model_runner.block_size - 1) + 1 prompt_lens.append(prompt_len) seq_data = list(range(prompt_len)) - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=False, - seq_data={0: SequenceData(seq_data)}, - sampling_params=SamplingParams(temperature=0), - block_tables={0: [1]}, - )) + seq_data = SequenceData(seq_data) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables={0: [1]}, + ) + assert seq_group_metadata.token_chunk_size == 1 + seq_group_metadata_list.append(seq_group_metadata) input_tokens, input_positions, attn_metadata, _, _, _ = ( model_runner._prepare_decode(seq_group_metadata_list)) diff --git a/vllm/config.py b/vllm/config.py index 5025b046..265cfa56 100644 --- a/vllm/config.py 
+++ b/vllm/config.py @@ -533,6 +533,8 @@ class SchedulerConfig: delay_factor: Apply a delay (of delay factor multiplied by previous prompt latency) before scheduling next prompt. use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not. + enable_chunked_prefill: If True, prefill requests can be chunked based + on the remaining max_num_batched_tokens. """ def __init__( @@ -542,6 +544,7 @@ class SchedulerConfig: max_model_len: int, use_v2_block_manager: bool = False, delay_factor: float = 0.0, + enable_chunked_prefill: bool = False, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -553,6 +556,7 @@ class SchedulerConfig: self.max_model_len = max_model_len self.delay_factor = delay_factor self.use_v2_block_manager = use_v2_block_manager + self.chunked_prefill_enabled = enable_chunked_prefill self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 85c2fdf7..04e8056a 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,6 +1,7 @@ import enum import time from collections import deque +from dataclasses import dataclass from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig @@ -27,11 +28,24 @@ class PreemptionMode(enum.Enum): RECOMPUTE = enum.auto() +# seq_group: SequenceGroup to schedule. +# token_chunk_size: The number of prefill tokens to be processed in the next +# step. +@dataclass +class ScheduledSequenceGroup: + # A sequence group that's scheduled. + seq_group: SequenceGroup + # The total chunk size (number of tokens) to process for next iteration. + # 1 for decoding. Same as prompt tokens for prefill, but if prefill is + # chunked, it can be smaller than that. + token_chunk_size: int + + class SchedulerOutputs: def __init__( self, - scheduled_seq_groups: Iterable[SequenceGroup], + scheduled_seq_groups: Iterable[ScheduledSequenceGroup], prompt_run: bool, num_batched_tokens: int, blocks_to_swap_in: Dict[int, int], @@ -39,17 +53,41 @@ class SchedulerOutputs: blocks_to_copy: Dict[int, List[int]], ignored_seq_groups: List[SequenceGroup], ) -> None: - self.scheduled_seq_groups = scheduled_seq_groups - self.prompt_run = prompt_run - self.num_batched_tokens = num_batched_tokens - self.blocks_to_swap_in = blocks_to_swap_in - self.blocks_to_swap_out = blocks_to_swap_out - self.blocks_to_copy = blocks_to_copy + """A list of sequence groups to be scheduled as a single batch. + + Args: + scheduled_seq_groups: A tuple of scheduled sequence group and its + token chunk size. + prompt_run: True if all sequence groups are in prefill phase. + If False, all sequence groups are in decoding phase. + num_batched_tokens: Total number of batched tokens. + blocks_to_swap_in: Blocks to swap in. Dict of CPU -> GPU block + number. + blocks_to_swap_out: Blocks to swap out. Dict of GPU -> CPU block + number. + blocks_to_copy: Blocks to copy. Source to a list of dest blocks. + ignored_seq_groups: Sequence groups that are going to be ignored. + """ + # A tuple of scheduled sequence group and its chunk size. + self.scheduled_seq_groups: ScheduledSequenceGroup = scheduled_seq_groups + # True if all sequence groups are in prefill phase. If False, all + # sequence groups are in decoding phase. + self.prompt_run: bool = prompt_run + # Total number of batched tokens. + self.num_batched_tokens: int = num_batched_tokens + # Blocks to swap in. Dict of CPU -> GPU block number. 
+ self.blocks_to_swap_in: Dict[int, int] = blocks_to_swap_in + # Blocks to swap out. Dict of GPU -> CPU block number. + self.blocks_to_swap_out: Dict[int, int] = blocks_to_swap_out + # Blocks to copy. Source to a list of dest blocks. + self.blocks_to_copy: Dict[int, List[int]] = blocks_to_copy + # Sequence groups that are going to be ignored. + self.ignored_seq_groups: List[SequenceGroup] = ignored_seq_groups + # Swap in and swap out should never happen at the same time. assert not (blocks_to_swap_in and blocks_to_swap_out) - self.ignored_seq_groups = ignored_seq_groups - self.num_loras = len(self.lora_requests) + self.num_loras: int = len(self.lora_requests) if self.num_loras > 0: self._sort_by_lora_ids() @@ -59,13 +97,13 @@ class SchedulerOutputs: and not self.blocks_to_swap_out and not self.blocks_to_copy) def _sort_by_lora_ids(self) -> bool: - self.scheduled_seq_groups = sorted(self.scheduled_seq_groups, - key=lambda g: - (g.lora_int_id, g.request_id)) + self.scheduled_seq_groups = sorted( + self.scheduled_seq_groups, + key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id)) @property def lora_requests(self) -> Set[LoRARequest]: - return {g.lora_request for g in self.scheduled_seq_groups} + return {g.seq_group.lora_request for g in self.scheduled_seq_groups} class Scheduler: @@ -198,11 +236,13 @@ class Scheduler: assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") - num_prompt_tokens = waiting_seqs[0].get_len() - if num_prompt_tokens > self.prompt_limit: + # get_len includes output tokens if the request has been + # preempted. + num_prefill_tokens = waiting_seqs[0].get_len() + if num_prefill_tokens > self.prompt_limit: logger.warning( - f"Input prompt ({num_prompt_tokens} tokens) is too long" - f" and exceeds limit of {self.prompt_limit}") + f"Input prompt ({num_prefill_tokens} tokens) is too " + f"long and exceeds limit of {self.prompt_limit}") for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -215,8 +255,8 @@ class Scheduler: break elif can_allocate == AllocStatus.NEVER: logger.warning( - f"Input prompt ({num_prompt_tokens} tokens) is too long" - f" and exceeds the capacity of block_manager") + f"Input prompt ({num_prefill_tokens} tokens) is too " + f"long and exceeds the capacity of block_manager") for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -235,7 +275,7 @@ class Scheduler: continue # If the number of batched tokens exceeds the limit, stop. - num_batched_tokens += num_prompt_tokens + num_batched_tokens += num_prefill_tokens if (num_batched_tokens > self.scheduler_config.max_num_batched_tokens): break @@ -253,8 +293,10 @@ class Scheduler: self._allocate(seq_group) self.running.append(seq_group) num_curr_seqs += num_new_seqs - scheduled.append(seq_group) - + scheduled.append( + ScheduledSequenceGroup( + seq_group=seq_group, + token_chunk_size=num_prefill_tokens)) self.waiting.extendleft(leftover_waiting_sequences) if scheduled or ignored_seq_groups: @@ -352,7 +394,11 @@ class Scheduler: for seq_group in self.running) scheduler_outputs = SchedulerOutputs( - scheduled_seq_groups=self.running, + scheduled_seq_groups=[ + ScheduledSequenceGroup(seq_group=running_group, + token_chunk_size=1) + for running_group in self.running + ], prompt_run=False, num_batched_tokens=num_batched_tokens, blocks_to_swap_in=blocks_to_swap_in, @@ -371,10 +417,14 @@ class Scheduler: # Create input data structures. 
seq_group_metadata_list: List[SequenceGroupMetadata] = [] - for seq_group in scheduler_outputs.scheduled_seq_groups: + for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + seq_group = scheduled_seq_group.seq_group + token_chunk_size = scheduled_seq_group.token_chunk_size seq_group.maybe_set_first_scheduled_time(now) + # seq_id -> SequenceData seq_data: Dict[int, SequenceData] = {} + # seq_id -> physical block numbers block_tables: Dict[int, List[int]] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): @@ -393,6 +443,7 @@ class Scheduler: seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, + token_chunk_size=token_chunk_size, lora_request=seq_group.lora_request, computed_block_nums=common_computed_block_nums, state=seq_group.state, @@ -409,8 +460,9 @@ class Scheduler: # batch will have been computed before the next scheduling invocation. # This is because the engine assumes that a failure in model execution # will crash the vLLM instance / will not retry. - for seq_group in scheduler_outputs.scheduled_seq_groups: - self.block_manager.mark_blocks_as_computed(seq_group) + for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + self.block_manager.mark_blocks_as_computed( + scheduled_seq_group.seq_group) return seq_group_metadata_list, scheduler_outputs @@ -418,6 +470,7 @@ class Scheduler: self.block_manager.fork(parent_seq, child_seq) def free_seq(self, seq: Sequence) -> None: + """Free a sequence from a block table.""" self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: @@ -480,7 +533,8 @@ class Scheduler: assert len(seqs) == 1 for seq in seqs: seq.status = SequenceStatus.WAITING - self.block_manager.free(seq) + self.free_seq(seq) + seq.reset_state_for_recompute() # NOTE: For FCFS, we insert the preempted sequence group to the front # of the waiting queue. 
self.waiting.appendleft(seq_group) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 09f90d10..83ef7ca1 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -62,6 +62,7 @@ class EngineArgs: image_input_shape: Optional[str] = None image_feature_size: Optional[int] = None scheduler_delay_factor: float = 0.0 + enable_chunked_prefill: bool = False def __post_init__(self): if self.tokenizer is None: @@ -356,6 +357,12 @@ class EngineArgs: default=EngineArgs.scheduler_delay_factor, help='Apply a delay (of delay factor multiplied by previous' 'prompt latency) before scheduling next prompt.') + parser.add_argument( + '--enable-chunked-prefill', + type=bool, + default=False, + help='If True, the prefill requests can be chunked based on the ' + 'max_num_batched_tokens') return parser @classmethod @@ -394,11 +401,14 @@ class EngineArgs: self.tokenizer_pool_type, self.tokenizer_pool_extra_config, ), self.ray_workers_use_nsight) - scheduler_config = SchedulerConfig(self.max_num_batched_tokens, - self.max_num_seqs, - model_config.max_model_len, - self.use_v2_block_manager, - self.scheduler_delay_factor) + scheduler_config = SchedulerConfig( + self.max_num_batched_tokens, + self.max_num_seqs, + model_config.max_model_len, + self.use_v2_block_manager, + delay_factor=self.scheduler_delay_factor, + enable_chunked_prefill=self.enable_chunked_prefill, + ) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 649cd040..a977a23d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -553,7 +553,10 @@ class LLMEngine: # Update the scheduled sequence groups with the model outputs. scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups - for seq_group, outputs in zip(scheduled_seq_groups, output): + for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output): + seq_group = scheduled_seq_group.seq_group + token_chunk_size = scheduled_seq_group.token_chunk_size + seq_group.update_num_computed_tokens(token_chunk_size) self._process_sequence_group_outputs(seq_group, outputs) # Free the finished sequence groups. @@ -561,7 +564,8 @@ class LLMEngine: # Create the outputs. request_outputs: List[RequestOutput] = [] - for seq_group in scheduled_seq_groups: + for scheduled_seq_group in scheduled_seq_groups: + seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) @@ -676,17 +680,20 @@ class LLMEngine: # Number of Tokens. if prompt_run: num_prompt_tokens = sum( - len(seq_group.prompt_token_ids) - for seq_group in scheduler_outputs.scheduled_seq_groups) + len(scheduled_seq_group.seq_group.prompt_token_ids) + for scheduled_seq_group in + scheduler_outputs.scheduled_seq_groups) num_generation_tokens = sum( - seq_group.num_seqs() - for seq_group in scheduler_outputs.scheduled_seq_groups) + scheduled_seq_group.seq_group.num_seqs() + for scheduled_seq_group in + scheduler_outputs.scheduled_seq_groups) else: num_generation_tokens = scheduler_outputs.num_batched_tokens # Latency Timings. time_last_iters = [] - for seq_group in scheduler_outputs.scheduled_seq_groups: + for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + seq_group = scheduled_seq_group.seq_group # Time since last token. # (n.b. 
updates seq_group.metrics.last_token_time) time_last_iters.append(seq_group.get_last_latency(now)) diff --git a/vllm/sequence.py b/vllm/sequence.py index 8292e207..a40f38f7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -113,6 +113,8 @@ class SequenceData: self.prompt_token_ids = prompt_token_ids self.output_token_ids = output_token_ids self.cumulative_logprob = 0.0 + # The number of tokens that are computed (that run against the model). + self._num_computed_tokens = 0 def append_token_id(self, token_id: int, logprob: float) -> None: self.output_token_ids.append(token_id) @@ -130,6 +132,28 @@ class SequenceData: def get_token_ids(self) -> List[int]: return self.prompt_token_ids + self.output_token_ids + def get_num_computed_tokens(self) -> int: + """Return the number of prefill tokens that are already computed.""" + return self._num_computed_tokens + + def update_num_computed_tokens(self, num_new_computed_tokens: int) -> int: + """Update number of tokens computed so far.""" + self._num_computed_tokens += num_new_computed_tokens + + def reset_num_computed_tokens(self) -> None: + """Reset the number of computed tokens from this sequence. It is + supposed to be called when a sequence needs to be started from + the beginning again (e.g., sequence is preempted). + """ + self._num_computed_tokens = 0 + + def get_num_uncomputed_tokens(self) -> int: + """Return the number of prefil tokens that are not computed.""" + # we use `get_len()` which includes prompt_len + output_len instead + # of prompt_len here. This is because during recompute we need to + # prefill for both prompt and output. + return self.get_len() - self.get_num_computed_tokens() + def get_last_token_id(self) -> int: if not self.output_token_ids: return self.prompt_token_ids[-1] @@ -208,6 +232,10 @@ class Sequence: def num_hashed_tokens_of_block(self, logical_idx: int): return logical_idx * self.block_size + self.block_size + def reset_state_for_recompute(self): + """Reset the sequence states for recomputation.""" + self.data.reset_num_computed_tokens() + def _append_logical_block(self) -> None: block = LogicalTokenBlock( block_number=len(self.logical_token_blocks), @@ -430,6 +458,18 @@ class SequenceGroup: def get_finished_seqs(self) -> List[Sequence]: return [seq for seq in self.seqs_dict.values() if seq.is_finished()] + def update_num_computed_tokens(self, num_new_computed_tokens: int): + """Update number of tokens computed so far.""" + for seq in self.seqs_dict.values(): + seq.data.update_num_computed_tokens(num_new_computed_tokens) + + def get_num_uncomputed_tokens(self) -> int: + # All sequences in the group should have the same prompt, so the + # number of unfinished prefill tokens are the same across all + # sequences. + return list( + self.seqs_dict.values())[0].data.get_num_uncomputed_tokens() + def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: return len(self.get_seqs(status)) @@ -473,6 +513,8 @@ class SequenceGroupMetadata: sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. (Seq id -> list of physical block numbers) + token_chunk_size: The number of tokens to be processed. None if + chunking is not required. state: Internal state tied to this sequence group. lora_request: LoRA request. multi_modal_data: Multi modal data. 
@@ -485,6 +527,7 @@ class SequenceGroupMetadata: seq_data: Dict[int, SequenceData], sampling_params: SamplingParams, block_tables: Dict[int, List[int]], + token_chunk_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, @@ -499,11 +542,23 @@ class SequenceGroupMetadata: self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state + self._token_chunk_size = token_chunk_size + + if self._token_chunk_size is None: + if is_prompt: + self._token_chunk_size = list(seq_data.values())[0].get_len() + else: + self._token_chunk_size = 1 @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + @property + def token_chunk_size(self) -> int: + """Return the number of tokens to be processed (chunk size).""" + return self._token_chunk_size + class SequenceOutput: """The model output associated with a sequence. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f0c98700..31fa5247 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -150,39 +150,58 @@ class ModelRunner: subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] multi_modal_input_list: List[torch.Tensor] = [] + for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) assert len(seq_ids) == 1 seq_id = seq_ids[0] + computed_block_nums = seq_group_metadata.computed_block_nums + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and computed_block_nums is not None): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + + token_chunk_size = seq_group_metadata.token_chunk_size seq_data = seq_group_metadata.seq_data[seq_id] - prompt_tokens = seq_data.get_token_ids() + computed_len = seq_data.get_num_computed_tokens() + # We should use get_len here because in case of preemption + # it contains output tokens. + prefill_end = min(seq_data.get_len(), + computed_len + token_chunk_size) + # TODO(sang): Rename it after chunked prefill is introduced. + prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] prompt_len = len(prompt_tokens) + # Right now, the prefill_end is always same as the length of + # sequence. However, once chunked prefill is introduced, this + # assumption can be changed. + assert prefill_end == seq_data.get_len() prompt_lens.append(prompt_len) - computed_len = 0 # NOTE: This only works for oooooooxxx style attention. - computed_block_nums = seq_group_metadata.computed_block_nums if computed_block_nums is not None and len( computed_block_nums) > 0 and self.sliding_window is None: # Prefix is not supported with sliding_window computed_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) - context_len = computed_len else: prefix_block_tables.append([]) - context_len = 0 + # Right now, prefill start is always 0. However, this + # assumption can be changed once chunked prefill is introduced. 
+ assert computed_len == 0 + # actual prompt lens - context_lens.append(context_len) + context_lens.append(computed_len) subquery_lens.append(prompt_len - computed_len) input_tokens.extend(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend( - list(range(computed_len, computed_len + len(prompt_tokens)))) + input_positions.extend(list(range(computed_len, prefill_end))) lora_id = seq_group_metadata.lora_int_id @@ -218,7 +237,8 @@ class ModelRunner: "Prefix caching is currently not supported with " "sliding window attention") start_idx = max(0, prompt_len - self.sliding_window) - for i in range(computed_len, prompt_len): + + for i in range(computed_len, prefill_end): if i < start_idx: slot_mapping.append(_PAD_SLOT_ID) continue @@ -331,6 +351,7 @@ class ModelRunner: for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt + assert seq_group_metadata.token_chunk_size == 1 seq_ids = list(seq_group_metadata.seq_data.keys()) lora_id = seq_group_metadata.lora_int_id
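
For reference, a minimal sketch of the prefill bookkeeping that SequenceData gains in this change, mirroring the new tests/test_sequence.py::test_sequence_data_prefill. It assumes a vLLM build that already includes this patch; nothing here is part of the diff itself.

    from vllm.sequence import SequenceData

    # A fresh sequence: nothing has run through the model yet.
    seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
    assert seq_data.get_num_computed_tokens() == 0
    assert seq_data.get_num_uncomputed_tokens() == 4

    # The scheduler records progress after each model step.
    seq_data.update_num_computed_tokens(2)
    assert seq_data.get_num_computed_tokens() == 2
    assert seq_data.get_num_uncomputed_tokens() == 2

    # Preemption with recompute: the generated token is kept, but the
    # computed-token counter is reset, so the uncomputed count now covers
    # prompt + output tokens (5 here), matching get_len().
    seq_data.append_token_id(1, logprob=0.0)
    seq_data.reset_num_computed_tokens()
    assert seq_data.get_num_computed_tokens() == 0
    assert seq_data.get_num_uncomputed_tokens() == 5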
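
Likewise, a sketch of how scheduler output is consumed after this change, following the pattern used in vllm/engine/llm_engine.py above. Here `scheduler` is assumed to be an already-constructed vllm.core.scheduler.Scheduler (as in tests/core/test_scheduler.py); chunked prefill is not yet activated by this patch, so token_chunk_size is the full prompt length for prefill and 1 for decode.

    # SchedulerOutputs.scheduled_seq_groups now holds ScheduledSequenceGroup
    # entries (seq_group + token_chunk_size) instead of bare SequenceGroup
    # objects, so callers unpack them explicitly.
    seq_group_metadata_list, scheduler_outputs = scheduler.schedule()
    for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
        seq_group = scheduled_seq_group.seq_group
        token_chunk_size = scheduled_seq_group.token_chunk_size
        # Advance the per-sequence computed-token counters by the amount
        # that was just scheduled, as the engine does after model execution.
        seq_group.update_num_computed_tokens(token_chunk_size)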