[2/N] Chunked prefill data update (#3538)

SangBin Cho 2024-03-29 02:06:01 +09:00 committed by GitHub
parent ce567a2926
commit b51c1cc9d2
11 changed files with 272 additions and 76 deletions

View File

@@ -26,7 +26,9 @@ def main(args: argparse.Namespace):
               kv_cache_dtype=args.kv_cache_dtype,
               device=args.device,
               ray_workers_use_nsight=args.ray_workers_use_nsight,
-              download_dir=args.download_dir)
+              enable_chunked_prefill=args.enable_chunked_prefill,
+              download_dir=args.download_dir,
+              block_size=args.block_size)

     sampling_params = SamplingParams(
         n=args.n,
@@ -145,6 +147,16 @@ if __name__ == '__main__':
         default="cuda",
         choices=["cuda"],
         help='device type for vLLM execution, supporting CUDA only currently.')
+    parser.add_argument('--block-size',
+                        type=int,
+                        default=16,
+                        help='block size of key/value cache')
+    parser.add_argument(
+        '--enable-chunked-prefill',
+        type=bool,
+        default=False,
+        help='If True, the prefill requests can be chunked based on the '
+        'max_num_batched_tokens')
     parser.add_argument(
         "--ray-workers-use-nsight",
         action='store_true',

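A note on the new `--enable-chunked-prefill` flag as declared above: argparse's `type=bool` converts any non-empty string to True, so passing an explicit value such as `False` still enables the feature. The standalone sketch below only reproduces that standard argparse behavior; the flag name, default, and help text come from the diff, the demo around it is illustrative.

```python
import argparse

# Minimal reproduction of the flag as added in the diff above.
parser = argparse.ArgumentParser()
parser.add_argument('--enable-chunked-prefill',
                    type=bool,
                    default=False,
                    help='If True, the prefill requests can be chunked based '
                    'on the max_num_batched_tokens')

# bool("") is False, but bool("False") is True, so any explicit non-empty
# value enables the feature.
print(parser.parse_args([]).enable_chunked_prefill)                      # False
print(parser.parse_args(
    ['--enable-chunked-prefill', 'False']).enable_chunked_prefill)       # True
print(parser.parse_args(
    ['--enable-chunked-prefill', 'True']).enable_chunked_prefill)        # True
```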
View File

@@ -256,6 +256,8 @@ class VllmRunner:
         dtype: str = "half",
         disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
+        block_size: int = 16,
+        enable_chunked_prefill: bool = False,
         **kwargs,
     ) -> None:
         self.model = LLM(
@@ -266,6 +268,8 @@ class VllmRunner:
             swap_space=0,
             disable_log_stats=disable_log_stats,
             tensor_parallel_size=tensor_parallel_size,
+            block_size=block_size,
+            enable_chunked_prefill=enable_chunked_prefill,
             **kwargs,
         )

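The new `VllmRunner` keyword arguments are passed straight through to `LLM`, so an equivalent manual construction would look roughly like the sketch below. The keyword names come from the diff; the model name and other values are placeholders, not part of this commit.

```python
from vllm import LLM

# Mirrors what VllmRunner now forwards; the model name is illustrative only.
llm = LLM(
    model="facebook/opt-125m",
    block_size=16,                 # new pass-through argument
    enable_chunked_prefill=False,  # new pass-through argument
    swap_space=0,
    disable_log_stats=True,
    tensor_parallel_size=1,
)
```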
View File

@@ -10,6 +10,10 @@ from vllm.sequence import Logprob, SequenceGroup
 from .utils import create_dummy_prompt


+def get_sequence_groups(scheduler_output):
+    return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
+
+
 def test_scheduler_add_seq_group():
     block_size = 4
     scheduler_config = SchedulerConfig(100, 64, 1)
@@ -57,9 +61,9 @@ def test_scheduler_schedule_simple():
     cache_config.num_cpu_blocks = 8
     cache_config.num_gpu_blocks = 8
     scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []

     # Add seq groups to scheduler.
-    running: List[SequenceGroup] = []
     for i in range(num_seq_group):
         _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
         scheduler.add_seq_group(seq_group)
@@ -68,7 +72,7 @@ def test_scheduler_schedule_simple():
     # Schedule seq groups prompts.
     num_tokens = block_size * num_seq_group
     seq_group_meta, out = scheduler.schedule()
-    assert set(out.scheduled_seq_groups) == set(running)
+    assert set(get_sequence_groups(out)) == set(running)
     assert out.num_batched_tokens == num_tokens
     assert (not out.blocks_to_copy and not out.blocks_to_swap_in
             and not out.blocks_to_swap_out)
@@ -76,7 +80,7 @@ def test_scheduler_schedule_simple():
     # Schedule seq groups generation.
     seq_group_meta, out = scheduler.schedule()
-    assert set(out.scheduled_seq_groups) == set(running)
+    assert set(get_sequence_groups(out)) == set(running)
     assert out.num_batched_tokens == num_seq_group
     assert (not out.blocks_to_copy and not out.blocks_to_swap_in
             and not out.blocks_to_swap_out)
@@ -100,7 +104,7 @@ def test_scheduler_schedule_preempt_abort():
     # Schedule seq groups prompts.
     seq_group_meta, out = scheduler.schedule()
-    assert out.scheduled_seq_groups == [seq_group_a, seq_group_b]
+    assert get_sequence_groups(out) == [seq_group_a, seq_group_b]
     assert out.num_batched_tokens == block_size * 2  # seq_a and seq_b
     assert (not out.blocks_to_copy and not out.blocks_to_swap_in
             and not out.blocks_to_swap_out)
@@ -115,7 +119,7 @@ def test_scheduler_schedule_preempt_abort():
     # Schedule seq groups generation and preempt seq group b.
     seq_group_meta, out = scheduler.schedule()
-    assert out.scheduled_seq_groups == [seq_group_a]
+    assert get_sequence_groups(out) == [seq_group_a]
     assert out.num_batched_tokens == 1
     assert (not out.blocks_to_copy and not out.blocks_to_swap_in
             and not out.blocks_to_swap_out)
@@ -125,7 +129,7 @@ def test_scheduler_schedule_preempt_abort():
     # Abort seq group a. Re-schedule seq group b prompt with recomputation.
     scheduler.abort_seq_group("1")
     seq_group_meta, out = scheduler.schedule()
-    assert out.scheduled_seq_groups == [seq_group_b]
+    assert get_sequence_groups(out) == [seq_group_b]
     assert out.num_batched_tokens == 5  # 4 prompt + 1 generation.
     assert (not out.blocks_to_copy and not out.blocks_to_swap_in
             and not out.blocks_to_swap_out)
@@ -155,11 +159,11 @@ def test_scheduler_max_seqs():
     # Schedule seq groups prompts.
     _, out = scheduler.schedule()
-    assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]])
+    assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])

     # Schedule seq groups generation.
     _, out = scheduler.schedule()
-    assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]])
+    assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])

     # Append 2 more seq group
     scheduler.add_seq_group(all_seq_groups[1])
@@ -169,7 +173,7 @@ def test_scheduler_max_seqs():
     # Only 1 seq group should be scheduled since max_seq_group is 2
     # and one is prompting.
     _, out = scheduler.schedule()
-    assert set(out.scheduled_seq_groups) == set([all_seq_groups[1]])
+    assert set(get_sequence_groups(out)) == set([all_seq_groups[1]])


 def test_scheduler_delay_factor():

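All of the test changes above follow one pattern: `SchedulerOutputs.scheduled_seq_groups` now yields wrapper objects instead of bare `SequenceGroup`s, so assertions unwrap them through the new `get_sequence_groups` helper. A self-contained illustration of that unwrapping, using stand-in dataclasses rather than the real scheduler types:

```python
from dataclasses import dataclass
from typing import List

@dataclass(frozen=True)
class FakeSeqGroup:          # stand-in for vllm.sequence.SequenceGroup
    request_id: str

@dataclass(frozen=True)
class FakeScheduled:         # stand-in for the new ScheduledSequenceGroup
    seq_group: FakeSeqGroup
    token_chunk_size: int

def get_sequence_groups(scheduled: List[FakeScheduled]) -> List[FakeSeqGroup]:
    # Same shape as the helper added to the tests above.
    return [s.seq_group for s in scheduled]

running = [FakeSeqGroup("0"), FakeSeqGroup("1")]
out = [FakeScheduled(g, token_chunk_size=4) for g in running]
assert set(get_sequence_groups(out)) == set(running)
```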
View File

@@ -1,6 +1,7 @@
 import pytest

-from vllm.sequence import SamplerOutput, SequenceGroupOutput, SequenceOutput
+from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput,
+                           SequenceOutput)


 @pytest.fixture
@@ -48,3 +49,24 @@ def test_sampler_output_eq(sample_outputs):
     sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1])
     assert sampler_output1 == sampler_output2
     assert sampler_output1 != sampler_output3
+
+
+def test_sequence_data_prefill():
+    seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
+    assert seq_data.get_num_uncomputed_tokens() == 4
+    assert seq_data.get_num_computed_tokens() == 0
+    # advance by 2
+    seq_data.update_num_computed_tokens(2)
+    assert seq_data.get_num_uncomputed_tokens() == 2
+    assert seq_data.get_num_computed_tokens() == 2
+
+    # advance by 1
+    seq_data.update_num_computed_tokens(1)
+    assert seq_data.get_num_uncomputed_tokens() == 1
+    assert seq_data.get_num_computed_tokens() == 3
+
+    # append tokens and reset, simulating recompute
+    seq_data.append_token_id(1, logprob=0.0)
+    seq_data.reset_num_computed_tokens()
+    assert seq_data.get_num_uncomputed_tokens() == 5
+    assert seq_data.get_num_computed_tokens() == 0

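The new test exercises the prefill bookkeeping added to `SequenceData`. A minimal standalone model of that bookkeeping (method names mirror the diff; the class itself is a simplified stand-in, not the real implementation) shows the invariant the test relies on: computed plus uncomputed tokens always equals the full sequence length, and a reset moves everything back into the uncomputed bucket.

```python
from typing import List

class MiniSequenceData:
    """Simplified stand-in for the prefill state added to SequenceData."""

    def __init__(self, prompt_token_ids: List[int]) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []
        self._num_computed_tokens = 0

    def get_len(self) -> int:
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    def update_num_computed_tokens(self, n: int) -> None:
        self._num_computed_tokens += n

    def reset_num_computed_tokens(self) -> None:
        self._num_computed_tokens = 0

    def get_num_uncomputed_tokens(self) -> int:
        # Uses get_len() (prompt + output) so a preempted request
        # re-prefills its already generated tokens too.
        return self.get_len() - self._num_computed_tokens

seq = MiniSequenceData([1, 2, 3, 4])
seq.update_num_computed_tokens(2)       # two prompt tokens computed
seq.output_token_ids.append(5)          # one generated token
seq.reset_num_computed_tokens()         # simulate preemption / recompute
assert seq.get_num_uncomputed_tokens() == 5
```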
View File

@@ -18,15 +18,16 @@ def test_prepare_prompt(batch_size):
         # make sure all tokens fit into one block
         prompt_len = i % (model_runner.block_size - 1) + 1
         prompt_lens.append(prompt_len)
-        seq_data = list(range(prompt_len))
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData(seq_data)},
-                sampling_params=SamplingParams(temperature=0),
-                block_tables=block_tables,
-            ))
+        seq_data = SequenceData(list(range(prompt_len)))
+        seq_group_metadata = SequenceGroupMetadata(
+            request_id=f"test_{i}",
+            is_prompt=True,
+            seq_data={0: seq_data},
+            sampling_params=SamplingParams(temperature=0),
+            block_tables=block_tables,
+        )
+        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
+        seq_group_metadata_list.append(seq_group_metadata)

     expected_selected_token_indices = []
     selected_token_start_idx = 0
@@ -131,14 +132,16 @@ def test_prepare_decode_cuda_graph(batch_size):
         prompt_len = i % (model_runner.block_size - 1) + 1
         prompt_lens.append(prompt_len)
         seq_data = list(range(prompt_len))
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=False,
-                seq_data={0: SequenceData(seq_data)},
-                sampling_params=SamplingParams(temperature=0),
-                block_tables={0: [1]},
-            ))
+        seq_data = SequenceData(seq_data)
+        seq_group_metadata = SequenceGroupMetadata(
+            request_id=f"test_{i}",
+            is_prompt=False,
+            seq_data={0: seq_data},
+            sampling_params=SamplingParams(temperature=0),
+            block_tables={0: [1]},
+        )
+        assert seq_group_metadata.token_chunk_size == 1
+        seq_group_metadata_list.append(seq_group_metadata)

     input_tokens, input_positions, attn_metadata, _, _, _ = (
         model_runner._prepare_decode(seq_group_metadata_list))

View File

@@ -533,6 +533,8 @@ class SchedulerConfig:
         delay_factor: Apply a delay (of delay factor multiplied by previous
             prompt latency) before scheduling next prompt.
         use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
+        enable_chunked_prefill: If True, prefill requests can be chunked based
+            on the remaining max_num_batched_tokens.
     """

     def __init__(
@@ -542,6 +544,7 @@ class SchedulerConfig:
         max_model_len: int,
         use_v2_block_manager: bool = False,
         delay_factor: float = 0.0,
+        enable_chunked_prefill: bool = False,
     ) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
@@ -553,6 +556,7 @@ class SchedulerConfig:
         self.max_model_len = max_model_len
         self.delay_factor = delay_factor
         self.use_v2_block_manager = use_v2_block_manager
+        self.chunked_prefill_enabled = enable_chunked_prefill
         self._verify_args()

     def _verify_args(self) -> None:

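The flag ends up as `SchedulerConfig.chunked_prefill_enabled`, which the model runner later checks. A quick construction example based on the keyword names shown in the signature above; the numeric values are arbitrary and only need to satisfy the existing argument validation.

```python
from vllm.config import SchedulerConfig

# Values are arbitrary; the new keyword mirrors the parameter added above.
scheduler_config = SchedulerConfig(
    max_num_batched_tokens=2048,
    max_num_seqs=64,
    max_model_len=2048,
    enable_chunked_prefill=True,
)
assert scheduler_config.chunked_prefill_enabled
```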
View File

@@ -1,6 +1,7 @@
 import enum
 import time
 from collections import deque
+from dataclasses import dataclass
 from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union

 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
@@ -27,11 +28,24 @@ class PreemptionMode(enum.Enum):
     RECOMPUTE = enum.auto()


+# seq_group: SequenceGroup to schedule.
+# token_chunk_size: The number of prefill tokens to be processed in the next
+# step.
+@dataclass
+class ScheduledSequenceGroup:
+    # A sequence group that's scheduled.
+    seq_group: SequenceGroup
+    # The total chunk size (number of tokens) to process for next iteration.
+    # 1 for decoding. Same as prompt tokens for prefill, but if prefill is
+    # chunked, it can be smaller than that.
+    token_chunk_size: int
+
+
 class SchedulerOutputs:

     def __init__(
         self,
-        scheduled_seq_groups: Iterable[SequenceGroup],
+        scheduled_seq_groups: Iterable[ScheduledSequenceGroup],
         prompt_run: bool,
         num_batched_tokens: int,
         blocks_to_swap_in: Dict[int, int],
@@ -39,17 +53,41 @@ class SchedulerOutputs:
         blocks_to_copy: Dict[int, List[int]],
         ignored_seq_groups: List[SequenceGroup],
     ) -> None:
-        self.scheduled_seq_groups = scheduled_seq_groups
-        self.prompt_run = prompt_run
-        self.num_batched_tokens = num_batched_tokens
-        self.blocks_to_swap_in = blocks_to_swap_in
-        self.blocks_to_swap_out = blocks_to_swap_out
-        self.blocks_to_copy = blocks_to_copy
+        """A list of sequence groups to be scheduled as a single batch.
+
+        Args:
+            scheduled_seq_groups: A tuple of scheduled sequence group and its
+                token chunk size.
+            prompt_run: True if all sequence groups are in prefill phase.
+                If False, all sequence groups are in decoding phase.
+            num_batched_tokens: Total number of batched tokens.
+            blocks_to_swap_in: Blocks to swap in. Dict of CPU -> GPU block
+                number.
+            blocks_to_swap_out: Blocks to swap out. Dict of GPU -> CPU block
+                number.
+            blocks_to_copy: Blocks to copy. Source to a list of dest blocks.
+            ignored_seq_groups: Sequence groups that are going to be ignored.
+        """
+        # A tuple of scheduled sequence group and its chunk size.
+        self.scheduled_seq_groups: ScheduledSequenceGroup = scheduled_seq_groups
+        # True if all sequence groups are in prefill phase. If False, all
+        # sequence groups are in decoding phase.
+        self.prompt_run: bool = prompt_run
+        # Total number of batched tokens.
+        self.num_batched_tokens: int = num_batched_tokens
+        # Blocks to swap in. Dict of CPU -> GPU block number.
+        self.blocks_to_swap_in: Dict[int, int] = blocks_to_swap_in
+        # Blocks to swap out. Dict of GPU -> CPU block number.
+        self.blocks_to_swap_out: Dict[int, int] = blocks_to_swap_out
+        # Blocks to copy. Source to a list of dest blocks.
+        self.blocks_to_copy: Dict[int, List[int]] = blocks_to_copy
+        # Sequence groups that are going to be ignored.
+        self.ignored_seq_groups: List[SequenceGroup] = ignored_seq_groups

         # Swap in and swap out should never happen at the same time.
         assert not (blocks_to_swap_in and blocks_to_swap_out)
-        self.ignored_seq_groups = ignored_seq_groups

-        self.num_loras = len(self.lora_requests)
+        self.num_loras: int = len(self.lora_requests)
         if self.num_loras > 0:
             self._sort_by_lora_ids()
@@ -59,13 +97,13 @@ class SchedulerOutputs:
                 and not self.blocks_to_swap_out and not self.blocks_to_copy)

     def _sort_by_lora_ids(self) -> bool:
-        self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
-                                           key=lambda g:
-                                           (g.lora_int_id, g.request_id))
+        self.scheduled_seq_groups = sorted(
+            self.scheduled_seq_groups,
+            key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))

     @property
     def lora_requests(self) -> Set[LoRARequest]:
-        return {g.lora_request for g in self.scheduled_seq_groups}
+        return {g.seq_group.lora_request for g in self.scheduled_seq_groups}


 class Scheduler:
@@ -198,11 +236,13 @@ class Scheduler:
                 assert len(waiting_seqs) == 1, (
                     "Waiting sequence group should have only one prompt "
                     "sequence.")
-                num_prompt_tokens = waiting_seqs[0].get_len()
-                if num_prompt_tokens > self.prompt_limit:
+                # get_len includes output tokens if the request has been
+                # preempted.
+                num_prefill_tokens = waiting_seqs[0].get_len()
+                if num_prefill_tokens > self.prompt_limit:
                     logger.warning(
-                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
-                        f" and exceeds limit of {self.prompt_limit}")
+                        f"Input prompt ({num_prefill_tokens} tokens) is too "
+                        f"long and exceeds limit of {self.prompt_limit}")
                     for seq in waiting_seqs:
                         seq.status = SequenceStatus.FINISHED_IGNORED
                     ignored_seq_groups.append(seq_group)
@@ -215,8 +255,8 @@ class Scheduler:
                     break
                 elif can_allocate == AllocStatus.NEVER:
                     logger.warning(
-                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
-                        f" and exceeds the capacity of block_manager")
+                        f"Input prompt ({num_prefill_tokens} tokens) is too "
+                        f"long and exceeds the capacity of block_manager")
                     for seq in waiting_seqs:
                         seq.status = SequenceStatus.FINISHED_IGNORED
                     ignored_seq_groups.append(seq_group)
@@ -235,7 +275,7 @@ class Scheduler:
                     continue

                 # If the number of batched tokens exceeds the limit, stop.
-                num_batched_tokens += num_prompt_tokens
+                num_batched_tokens += num_prefill_tokens
                 if (num_batched_tokens >
                         self.scheduler_config.max_num_batched_tokens):
                     break
@@ -253,8 +293,10 @@ class Scheduler:
                 self._allocate(seq_group)
                 self.running.append(seq_group)
                 num_curr_seqs += num_new_seqs
-                scheduled.append(seq_group)
+                scheduled.append(
+                    ScheduledSequenceGroup(
+                        seq_group=seq_group,
+                        token_chunk_size=num_prefill_tokens))

             self.waiting.extendleft(leftover_waiting_sequences)

         if scheduled or ignored_seq_groups:
@@ -352,7 +394,11 @@ class Scheduler:
             for seq_group in self.running)

         scheduler_outputs = SchedulerOutputs(
-            scheduled_seq_groups=self.running,
+            scheduled_seq_groups=[
+                ScheduledSequenceGroup(seq_group=running_group,
+                                       token_chunk_size=1)
+                for running_group in self.running
+            ],
             prompt_run=False,
             num_batched_tokens=num_batched_tokens,
             blocks_to_swap_in=blocks_to_swap_in,
@@ -371,10 +417,14 @@ class Scheduler:
         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        for seq_group in scheduler_outputs.scheduled_seq_groups:
+        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
+            seq_group = scheduled_seq_group.seq_group
+            token_chunk_size = scheduled_seq_group.token_chunk_size
             seq_group.maybe_set_first_scheduled_time(now)

+            # seq_id -> SequenceData
             seq_data: Dict[int, SequenceData] = {}
+            # seq_id -> physical block numbers
             block_tables: Dict[int, List[int]] = {}

             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -393,6 +443,7 @@ class Scheduler:
                 seq_data=seq_data,
                 sampling_params=seq_group.sampling_params,
                 block_tables=block_tables,
+                token_chunk_size=token_chunk_size,
                 lora_request=seq_group.lora_request,
                 computed_block_nums=common_computed_block_nums,
                 state=seq_group.state,
@@ -409,8 +460,9 @@ class Scheduler:
         # batch will have been computed before the next scheduling invocation.
         # This is because the engine assumes that a failure in model execution
         # will crash the vLLM instance / will not retry.
-        for seq_group in scheduler_outputs.scheduled_seq_groups:
-            self.block_manager.mark_blocks_as_computed(seq_group)
+        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
+            self.block_manager.mark_blocks_as_computed(
+                scheduled_seq_group.seq_group)

         return seq_group_metadata_list, scheduler_outputs
@@ -418,6 +470,7 @@ class Scheduler:
         self.block_manager.fork(parent_seq, child_seq)

     def free_seq(self, seq: Sequence) -> None:
+        """Free a sequence from a block table."""
         self.block_manager.free(seq)

     def free_finished_seq_groups(self) -> None:
@@ -480,7 +533,8 @@ class Scheduler:
             assert len(seqs) == 1
         for seq in seqs:
             seq.status = SequenceStatus.WAITING
-            self.block_manager.free(seq)
+            self.free_seq(seq)
+            seq.reset_state_for_recompute()

         # NOTE: For FCFS, we insert the preempted sequence group to the front
         # of the waiting queue.
         self.waiting.appendleft(seq_group)

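In this commit `token_chunk_size` is always either the full prefill length or 1; later patches in the series are what actually split a prefill across multiple steps. As a rough illustration of the intended budgeting (purely a sketch, not code from this commit), a scheduler with a per-step token budget would clamp each prefill chunk to whatever budget remains:

```python
def plan_prefill_chunks(uncomputed_tokens: int, budget_per_step: int):
    """Yield per-step chunk sizes for one prefill under a token budget."""
    remaining = uncomputed_tokens
    while remaining > 0:
        chunk = min(remaining, budget_per_step)
        yield chunk
        remaining -= chunk

# A 10-token prompt with a budget of 4 tokens per step prefills as 4, 4, 2.
assert list(plan_prefill_chunks(10, 4)) == [4, 4, 2]
```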
View File

@@ -62,6 +62,7 @@ class EngineArgs:
     image_input_shape: Optional[str] = None
     image_feature_size: Optional[int] = None
     scheduler_delay_factor: float = 0.0
+    enable_chunked_prefill: bool = False

     def __post_init__(self):
         if self.tokenizer is None:
@@ -356,6 +357,12 @@ class EngineArgs:
             default=EngineArgs.scheduler_delay_factor,
             help='Apply a delay (of delay factor multiplied by previous'
             'prompt latency) before scheduling next prompt.')
+        parser.add_argument(
+            '--enable-chunked-prefill',
+            type=bool,
+            default=False,
+            help='If True, the prefill requests can be chunked based on the '
+            'max_num_batched_tokens')
         return parser

     @classmethod
@@ -394,11 +401,14 @@ class EngineArgs:
             self.tokenizer_pool_type,
             self.tokenizer_pool_extra_config,
         ), self.ray_workers_use_nsight)
-        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs,
-                                           model_config.max_model_len,
-                                           self.use_v2_block_manager,
-                                           self.scheduler_delay_factor)
+        scheduler_config = SchedulerConfig(
+            self.max_num_batched_tokens,
+            self.max_num_seqs,
+            model_config.max_model_len,
+            self.use_v2_block_manager,
+            delay_factor=self.scheduler_delay_factor,
+            enable_chunked_prefill=self.enable_chunked_prefill,
+        )
         lora_config = LoRAConfig(
             max_lora_rank=self.max_lora_rank,
             max_loras=self.max_loras,

View File

@@ -553,7 +553,10 @@ class LLMEngine:
         # Update the scheduled sequence groups with the model outputs.
         scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
-        for seq_group, outputs in zip(scheduled_seq_groups, output):
+        for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output):
+            seq_group = scheduled_seq_group.seq_group
+            token_chunk_size = scheduled_seq_group.token_chunk_size
+            seq_group.update_num_computed_tokens(token_chunk_size)
             self._process_sequence_group_outputs(seq_group, outputs)

         # Free the finished sequence groups.
@@ -561,7 +564,8 @@ class LLMEngine:

         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in scheduled_seq_groups:
+        for scheduled_seq_group in scheduled_seq_groups:
+            seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
@@ -676,17 +680,20 @@ class LLMEngine:
             # Number of Tokens.
             if prompt_run:
                 num_prompt_tokens = sum(
-                    len(seq_group.prompt_token_ids)
-                    for seq_group in scheduler_outputs.scheduled_seq_groups)
+                    len(scheduled_seq_group.seq_group.prompt_token_ids)
+                    for scheduled_seq_group in
+                    scheduler_outputs.scheduled_seq_groups)
                 num_generation_tokens = sum(
-                    seq_group.num_seqs()
-                    for seq_group in scheduler_outputs.scheduled_seq_groups)
+                    scheduled_seq_group.seq_group.num_seqs()
+                    for scheduled_seq_group in
+                    scheduler_outputs.scheduled_seq_groups)
             else:
                 num_generation_tokens = scheduler_outputs.num_batched_tokens

             # Latency Timings.
             time_last_iters = []
-            for seq_group in scheduler_outputs.scheduled_seq_groups:
+            for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
+                seq_group = scheduled_seq_group.seq_group
                 # Time since last token.
                 # (n.b. updates seq_group.metrics.last_token_time)
                 time_last_iters.append(seq_group.get_last_latency(now))

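The engine-side change boils down to one bookkeeping call per step, `seq_group.update_num_computed_tokens(token_chunk_size)`, right after the model outputs are processed. A toy trace of that bookkeeping (plain integers instead of the real classes, values chosen for illustration) for a 7-token prompt followed by three decode steps:

```python
# Toy trace of the per-step bookkeeping; the real code goes through
# SequenceGroup / SequenceData as shown in the diff above.
prompt_len = 7
num_computed = 0

# Prefill step: this commit still schedules the whole prompt as one chunk.
token_chunk_size = prompt_len
num_computed += token_chunk_size        # 7 tokens computed after prefill

# Decode steps: one token per sequence group per step.
for _ in range(3):
    token_chunk_size = 1
    num_computed += token_chunk_size

assert num_computed == 10               # 7 prompt + 3 generated tokens
```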
View File

@@ -113,6 +113,8 @@ class SequenceData:
         self.prompt_token_ids = prompt_token_ids
         self.output_token_ids = output_token_ids
         self.cumulative_logprob = 0.0
+        # The number of tokens that are computed (that run against the model).
+        self._num_computed_tokens = 0

     def append_token_id(self, token_id: int, logprob: float) -> None:
         self.output_token_ids.append(token_id)
@@ -130,6 +132,28 @@ class SequenceData:
     def get_token_ids(self) -> List[int]:
         return self.prompt_token_ids + self.output_token_ids

+    def get_num_computed_tokens(self) -> int:
+        """Return the number of prefill tokens that are already computed."""
+        return self._num_computed_tokens
+
+    def update_num_computed_tokens(self, num_new_computed_tokens: int) -> int:
+        """Update number of tokens computed so far."""
+        self._num_computed_tokens += num_new_computed_tokens
+
+    def reset_num_computed_tokens(self) -> None:
+        """Reset the number of computed tokens from this sequence. It is
+        supposed to be called when a sequence needs to be started from
+        the beginning again (e.g., sequence is preempted).
+        """
+        self._num_computed_tokens = 0
+
+    def get_num_uncomputed_tokens(self) -> int:
+        """Return the number of prefill tokens that are not computed."""
+        # We use `get_len()` which includes prompt_len + output_len instead
+        # of prompt_len here. This is because during recompute we need to
+        # prefill for both prompt and output.
+        return self.get_len() - self.get_num_computed_tokens()
+
     def get_last_token_id(self) -> int:
         if not self.output_token_ids:
             return self.prompt_token_ids[-1]
@@ -208,6 +232,10 @@ class Sequence:
     def num_hashed_tokens_of_block(self, logical_idx: int):
         return logical_idx * self.block_size + self.block_size

+    def reset_state_for_recompute(self):
+        """Reset the sequence states for recomputation."""
+        self.data.reset_num_computed_tokens()
+
     def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
             block_number=len(self.logical_token_blocks),
@@ -430,6 +458,18 @@ class SequenceGroup:
     def get_finished_seqs(self) -> List[Sequence]:
         return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

+    def update_num_computed_tokens(self, num_new_computed_tokens: int):
+        """Update number of tokens computed so far."""
+        for seq in self.seqs_dict.values():
+            seq.data.update_num_computed_tokens(num_new_computed_tokens)
+
+    def get_num_uncomputed_tokens(self) -> int:
+        # All sequences in the group should have the same prompt, so the
+        # number of unfinished prefill tokens is the same across all
+        # sequences.
+        return list(
+            self.seqs_dict.values())[0].data.get_num_uncomputed_tokens()
+
     def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
         return len(self.get_seqs(status))
@@ -473,6 +513,8 @@ class SequenceGroupMetadata:
         sampling_params: The sampling parameters used to generate the outputs.
         block_tables: The block tables. (Seq id -> list of physical block
             numbers)
+        token_chunk_size: The number of tokens to be processed. None if
+            chunking is not required.
         state: Internal state tied to this sequence group.
         lora_request: LoRA request.
         multi_modal_data: Multi modal data.
@@ -485,6 +527,7 @@ class SequenceGroupMetadata:
         seq_data: Dict[int, SequenceData],
         sampling_params: SamplingParams,
         block_tables: Dict[int, List[int]],
+        token_chunk_size: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,
         state: Optional[SequenceGroupState] = None,
@@ -499,11 +542,23 @@ class SequenceGroupMetadata:
         self.computed_block_nums = computed_block_nums
         self.multi_modal_data = multi_modal_data
         self.state = SequenceGroupState() if state is None else state
+        self._token_chunk_size = token_chunk_size
+
+        if self._token_chunk_size is None:
+            if is_prompt:
+                self._token_chunk_size = list(seq_data.values())[0].get_len()
+            else:
+                self._token_chunk_size = 1

     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0

+    @property
+    def token_chunk_size(self) -> int:
+        """Return the number of tokens to be processed (chunk size)."""
+        return self._token_chunk_size
+

 class SequenceOutput:
     """The model output associated with a sequence.

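When the scheduler does not pass `token_chunk_size` explicitly, `SequenceGroupMetadata` falls back to the full sequence length for prompts and 1 for decodes. The same default, extracted into a free function for clarity (a sketch only; the real logic lives in `__init__` above):

```python
from typing import Dict, Optional

def default_token_chunk_size(is_prompt: bool,
                             seq_lens: Dict[int, int],
                             token_chunk_size: Optional[int] = None) -> int:
    """Mirror of the fallback in SequenceGroupMetadata.__init__ above."""
    if token_chunk_size is not None:
        return token_chunk_size
    if is_prompt:
        # Prompt: process every token of the (single) prefill sequence.
        return list(seq_lens.values())[0]
    # Decode: exactly one new token per step.
    return 1

assert default_token_chunk_size(True, {0: 12}) == 12
assert default_token_chunk_size(False, {0: 12}) == 1
assert default_token_chunk_size(True, {0: 12}, token_chunk_size=4) == 4
```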
View File

@@ -150,39 +150,58 @@ class ModelRunner:
         subquery_lens: List[int] = []
         prefix_block_tables: List[List[int]] = []
         multi_modal_input_list: List[torch.Tensor] = []
+
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
             seq_ids = list(seq_group_metadata.seq_data.keys())
             assert len(seq_ids) == 1
             seq_id = seq_ids[0]

+            computed_block_nums = seq_group_metadata.computed_block_nums
+            if (self.scheduler_config is not None
+                    and self.scheduler_config.chunked_prefill_enabled
+                    and computed_block_nums is not None):
+                raise RuntimeError(
+                    "chunked prefill cannot be used with prefix caching "
+                    "now.")
+
+            token_chunk_size = seq_group_metadata.token_chunk_size
             seq_data = seq_group_metadata.seq_data[seq_id]
-            prompt_tokens = seq_data.get_token_ids()
+            computed_len = seq_data.get_num_computed_tokens()
+            # We should use get_len here because in case of preemption
+            # it contains output tokens.
+            prefill_end = min(seq_data.get_len(),
+                              computed_len + token_chunk_size)
+            # TODO(sang): Rename it after chunked prefill is introduced.
+            prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
             prompt_len = len(prompt_tokens)
+            # Right now, the prefill_end is always same as the length of
+            # sequence. However, once chunked prefill is introduced, this
+            # assumption can be changed.
+            assert prefill_end == seq_data.get_len()
             prompt_lens.append(prompt_len)

-            computed_len = 0

             # NOTE: This only works for oooooooxxx style attention.
-            computed_block_nums = seq_group_metadata.computed_block_nums
             if computed_block_nums is not None and len(
                     computed_block_nums) > 0 and self.sliding_window is None:
                 # Prefix is not supported with sliding_window
                 computed_len = len(computed_block_nums) * self.block_size
                 prompt_tokens = prompt_tokens[computed_len:]
                 prefix_block_tables.append(computed_block_nums)
-                context_len = computed_len
             else:
                 prefix_block_tables.append([])
-                context_len = 0
+                # Right now, prefill start is always 0. However, this
+                # assumption can be changed once chunked prefill is introduced.
+                assert computed_len == 0

             # actual prompt lens
-            context_lens.append(context_len)
+            context_lens.append(computed_len)
             subquery_lens.append(prompt_len - computed_len)

             input_tokens.extend(prompt_tokens)
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
-            input_positions.extend(
-                list(range(computed_len, computed_len + len(prompt_tokens))))
+            input_positions.extend(list(range(computed_len, prefill_end)))

             lora_id = seq_group_metadata.lora_int_id
@@ -218,7 +237,8 @@ class ModelRunner:
                         "Prefix caching is currently not supported with "
                         "sliding window attention")
                 start_idx = max(0, prompt_len - self.sliding_window)
-            for i in range(computed_len, prompt_len):
+
+            for i in range(computed_len, prefill_end):
                 if i < start_idx:
                     slot_mapping.append(_PAD_SLOT_ID)
                     continue
@@ -331,6 +351,7 @@ class ModelRunner:
         for seq_group_metadata in seq_group_metadata_list:
             assert not seq_group_metadata.is_prompt
+            assert seq_group_metadata.token_chunk_size == 1

             seq_ids = list(seq_group_metadata.seq_data.keys())
             lora_id = seq_group_metadata.lora_int_id
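The key change in `_prepare_prompt` is that prompt tokens and positions are now sliced by `[computed_len:prefill_end]` instead of always starting at position 0, even though this commit still asserts that the slice covers the whole sequence. A standalone sketch of what that slicing yields once chunking is actually allowed (illustrative only; not behavior enabled by this commit):

```python
from typing import List, Tuple

def slice_prefill_chunk(token_ids: List[int], computed_len: int,
                        token_chunk_size: int) -> Tuple[List[int], List[int]]:
    """Return (tokens, positions) for the next prefill chunk."""
    prefill_end = min(len(token_ids), computed_len + token_chunk_size)
    chunk_tokens = token_ids[computed_len:prefill_end]
    chunk_positions = list(range(computed_len, prefill_end))
    return chunk_tokens, chunk_positions

tokens = [10, 11, 12, 13, 14, 15, 16]
# First a chunk of 4 tokens, then the remaining 3.
assert slice_prefill_chunk(tokens, 0, 4) == ([10, 11, 12, 13], [0, 1, 2, 3])
assert slice_prefill_chunk(tokens, 4, 4) == ([14, 15, 16], [4, 5, 6])
```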