[2/N] Chunked prefill data update (#3538)
This commit is contained in:
parent
ce567a2926
commit
b51c1cc9d2
@ -26,7 +26,9 @@ def main(args: argparse.Namespace):
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
device=args.device,
|
||||
ray_workers_use_nsight=args.ray_workers_use_nsight,
|
||||
download_dir=args.download_dir)
|
||||
enable_chunked_prefill=args.enable_chunked_prefill,
|
||||
download_dir=args.download_dir,
|
||||
block_size=args.block_size)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
n=args.n,
|
||||
@ -145,6 +147,16 @@ if __name__ == '__main__':
|
||||
default="cuda",
|
||||
choices=["cuda"],
|
||||
help='device type for vLLM execution, supporting CUDA only currently.')
|
||||
parser.add_argument('--block-size',
|
||||
type=int,
|
||||
default=16,
|
||||
help='block size of key/value cache')
|
||||
parser.add_argument(
|
||||
'--enable-chunked-prefill',
|
||||
type=bool,
|
||||
default=False,
|
||||
help='If True, the prefill requests can be chunked based on the '
|
||||
'max_num_batched_tokens')
|
||||
parser.add_argument(
|
||||
"--ray-workers-use-nsight",
|
||||
action='store_true',
|
||||
|
@ -256,6 +256,8 @@ class VllmRunner:
|
||||
dtype: str = "half",
|
||||
disable_log_stats: bool = True,
|
||||
tensor_parallel_size: int = 1,
|
||||
block_size: int = 16,
|
||||
enable_chunked_prefill: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.model = LLM(
|
||||
@ -266,6 +268,8 @@ class VllmRunner:
|
||||
swap_space=0,
|
||||
disable_log_stats=disable_log_stats,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
block_size=block_size,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -10,6 +10,10 @@ from vllm.sequence import Logprob, SequenceGroup
|
||||
from .utils import create_dummy_prompt
|
||||
|
||||
|
||||
def get_sequence_groups(scheduler_output):
|
||||
return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
|
||||
|
||||
|
||||
def test_scheduler_add_seq_group():
|
||||
block_size = 4
|
||||
scheduler_config = SchedulerConfig(100, 64, 1)
|
||||
@ -57,9 +61,9 @@ def test_scheduler_schedule_simple():
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||
running: List[SequenceGroup] = []
|
||||
|
||||
# Add seq groups to scheduler.
|
||||
running: List[SequenceGroup] = []
|
||||
for i in range(num_seq_group):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
@ -68,7 +72,7 @@ def test_scheduler_schedule_simple():
|
||||
# Schedule seq groups prompts.
|
||||
num_tokens = block_size * num_seq_group
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
assert set(out.scheduled_seq_groups) == set(running)
|
||||
assert set(get_sequence_groups(out)) == set(running)
|
||||
assert out.num_batched_tokens == num_tokens
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
@ -76,7 +80,7 @@ def test_scheduler_schedule_simple():
|
||||
|
||||
# Schedule seq groups generation.
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
assert set(out.scheduled_seq_groups) == set(running)
|
||||
assert set(get_sequence_groups(out)) == set(running)
|
||||
assert out.num_batched_tokens == num_seq_group
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
@ -100,7 +104,7 @@ def test_scheduler_schedule_preempt_abort():
|
||||
|
||||
# Schedule seq groups prompts.
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
assert out.scheduled_seq_groups == [seq_group_a, seq_group_b]
|
||||
assert get_sequence_groups(out) == [seq_group_a, seq_group_b]
|
||||
assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
@ -115,7 +119,7 @@ def test_scheduler_schedule_preempt_abort():
|
||||
|
||||
# Schedule seq groups generation and preempt seq group b.
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
assert out.scheduled_seq_groups == [seq_group_a]
|
||||
assert get_sequence_groups(out) == [seq_group_a]
|
||||
assert out.num_batched_tokens == 1
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
@ -125,7 +129,7 @@ def test_scheduler_schedule_preempt_abort():
|
||||
# Abort seq group a. Re-schedule seq group b prompt with recomputation.
|
||||
scheduler.abort_seq_group("1")
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
assert out.scheduled_seq_groups == [seq_group_b]
|
||||
assert get_sequence_groups(out) == [seq_group_b]
|
||||
assert out.num_batched_tokens == 5 # 4 prompt + 1 generation.
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
@ -155,11 +159,11 @@ def test_scheduler_max_seqs():
|
||||
|
||||
# Schedule seq groups prompts.
|
||||
_, out = scheduler.schedule()
|
||||
assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]])
|
||||
assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])
|
||||
|
||||
# Schedule seq groups generation.
|
||||
_, out = scheduler.schedule()
|
||||
assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]])
|
||||
assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])
|
||||
|
||||
# Append 2 more seq group
|
||||
scheduler.add_seq_group(all_seq_groups[1])
|
||||
@ -169,7 +173,7 @@ def test_scheduler_max_seqs():
|
||||
# Only 1 seq group should be scheduled since max_seq_group is 2
|
||||
# and one is prompting.
|
||||
_, out = scheduler.schedule()
|
||||
assert set(out.scheduled_seq_groups) == set([all_seq_groups[1]])
|
||||
assert set(get_sequence_groups(out)) == set([all_seq_groups[1]])
|
||||
|
||||
|
||||
def test_scheduler_delay_factor():
|
||||
|
@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupOutput, SequenceOutput
|
||||
from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput,
|
||||
SequenceOutput)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -48,3 +49,24 @@ def test_sampler_output_eq(sample_outputs):
|
||||
sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1])
|
||||
assert sampler_output1 == sampler_output2
|
||||
assert sampler_output1 != sampler_output3
|
||||
|
||||
|
||||
def test_sequence_data_prefill():
|
||||
seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
|
||||
assert seq_data.get_num_uncomputed_tokens() == 4
|
||||
assert seq_data.get_num_computed_tokens() == 0
|
||||
# advance by 2
|
||||
seq_data.update_num_computed_tokens(2)
|
||||
assert seq_data.get_num_uncomputed_tokens() == 2
|
||||
assert seq_data.get_num_computed_tokens() == 2
|
||||
|
||||
# advance by 1
|
||||
seq_data.update_num_computed_tokens(1)
|
||||
assert seq_data.get_num_uncomputed_tokens() == 1
|
||||
assert seq_data.get_num_computed_tokens() == 3
|
||||
|
||||
# append tokens and reset, simulating recompute
|
||||
seq_data.append_token_id(1, logprob=0.0)
|
||||
seq_data.reset_num_computed_tokens()
|
||||
assert seq_data.get_num_uncomputed_tokens() == 5
|
||||
assert seq_data.get_num_computed_tokens() == 0
|
||||
|
@ -18,15 +18,16 @@ def test_prepare_prompt(batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
prompt_len = i % (model_runner.block_size - 1) + 1
|
||||
prompt_lens.append(prompt_len)
|
||||
seq_data = list(range(prompt_len))
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData(seq_data)},
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables=block_tables,
|
||||
))
|
||||
seq_data = SequenceData(list(range(prompt_len)))
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: seq_data},
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables=block_tables,
|
||||
)
|
||||
assert seq_group_metadata.token_chunk_size == seq_data.get_len()
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
expected_selected_token_indices = []
|
||||
selected_token_start_idx = 0
|
||||
@ -131,14 +132,16 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
prompt_len = i % (model_runner.block_size - 1) + 1
|
||||
prompt_lens.append(prompt_len)
|
||||
seq_data = list(range(prompt_len))
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=False,
|
||||
seq_data={0: SequenceData(seq_data)},
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
seq_data = SequenceData(seq_data)
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=False,
|
||||
seq_data={0: seq_data},
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables={0: [1]},
|
||||
)
|
||||
assert seq_group_metadata.token_chunk_size == 1
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
input_tokens, input_positions, attn_metadata, _, _, _ = (
|
||||
model_runner._prepare_decode(seq_group_metadata_list))
|
||||
|
@ -533,6 +533,8 @@ class SchedulerConfig:
|
||||
delay_factor: Apply a delay (of delay factor multiplied by previous
|
||||
prompt latency) before scheduling next prompt.
|
||||
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
|
||||
enable_chunked_prefill: If True, prefill requests can be chunked based
|
||||
on the remaining max_num_batched_tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -542,6 +544,7 @@ class SchedulerConfig:
|
||||
max_model_len: int,
|
||||
use_v2_block_manager: bool = False,
|
||||
delay_factor: float = 0.0,
|
||||
enable_chunked_prefill: bool = False,
|
||||
) -> None:
|
||||
if max_num_batched_tokens is not None:
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
@ -553,6 +556,7 @@ class SchedulerConfig:
|
||||
self.max_model_len = max_model_len
|
||||
self.delay_factor = delay_factor
|
||||
self.use_v2_block_manager = use_v2_block_manager
|
||||
self.chunked_prefill_enabled = enable_chunked_prefill
|
||||
self._verify_args()
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
|
@ -1,6 +1,7 @@
|
||||
import enum
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
@ -27,11 +28,24 @@ class PreemptionMode(enum.Enum):
|
||||
RECOMPUTE = enum.auto()
|
||||
|
||||
|
||||
# seq_group: SequenceGroup to schedule.
|
||||
# token_chunk_size: The number of prefill tokens to be processed in the next
|
||||
# step.
|
||||
@dataclass
|
||||
class ScheduledSequenceGroup:
|
||||
# A sequence group that's scheduled.
|
||||
seq_group: SequenceGroup
|
||||
# The total chunk size (number of tokens) to process for next iteration.
|
||||
# 1 for decoding. Same as prompt tokens for prefill, but if prefill is
|
||||
# chunked, it can be smaller than that.
|
||||
token_chunk_size: int
|
||||
|
||||
|
||||
class SchedulerOutputs:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduled_seq_groups: Iterable[SequenceGroup],
|
||||
scheduled_seq_groups: Iterable[ScheduledSequenceGroup],
|
||||
prompt_run: bool,
|
||||
num_batched_tokens: int,
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
@ -39,17 +53,41 @@ class SchedulerOutputs:
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
ignored_seq_groups: List[SequenceGroup],
|
||||
) -> None:
|
||||
self.scheduled_seq_groups = scheduled_seq_groups
|
||||
self.prompt_run = prompt_run
|
||||
self.num_batched_tokens = num_batched_tokens
|
||||
self.blocks_to_swap_in = blocks_to_swap_in
|
||||
self.blocks_to_swap_out = blocks_to_swap_out
|
||||
self.blocks_to_copy = blocks_to_copy
|
||||
"""A list of sequence groups to be scheduled as a single batch.
|
||||
|
||||
Args:
|
||||
scheduled_seq_groups: A tuple of scheduled sequence group and its
|
||||
token chunk size.
|
||||
prompt_run: True if all sequence groups are in prefill phase.
|
||||
If False, all sequence groups are in decoding phase.
|
||||
num_batched_tokens: Total number of batched tokens.
|
||||
blocks_to_swap_in: Blocks to swap in. Dict of CPU -> GPU block
|
||||
number.
|
||||
blocks_to_swap_out: Blocks to swap out. Dict of GPU -> CPU block
|
||||
number.
|
||||
blocks_to_copy: Blocks to copy. Source to a list of dest blocks.
|
||||
ignored_seq_groups: Sequence groups that are going to be ignored.
|
||||
"""
|
||||
# A tuple of scheduled sequence group and its chunk size.
|
||||
self.scheduled_seq_groups: ScheduledSequenceGroup = scheduled_seq_groups
|
||||
# True if all sequence groups are in prefill phase. If False, all
|
||||
# sequence groups are in decoding phase.
|
||||
self.prompt_run: bool = prompt_run
|
||||
# Total number of batched tokens.
|
||||
self.num_batched_tokens: int = num_batched_tokens
|
||||
# Blocks to swap in. Dict of CPU -> GPU block number.
|
||||
self.blocks_to_swap_in: Dict[int, int] = blocks_to_swap_in
|
||||
# Blocks to swap out. Dict of GPU -> CPU block number.
|
||||
self.blocks_to_swap_out: Dict[int, int] = blocks_to_swap_out
|
||||
# Blocks to copy. Source to a list of dest blocks.
|
||||
self.blocks_to_copy: Dict[int, List[int]] = blocks_to_copy
|
||||
# Sequence groups that are going to be ignored.
|
||||
self.ignored_seq_groups: List[SequenceGroup] = ignored_seq_groups
|
||||
|
||||
# Swap in and swap out should never happen at the same time.
|
||||
assert not (blocks_to_swap_in and blocks_to_swap_out)
|
||||
self.ignored_seq_groups = ignored_seq_groups
|
||||
|
||||
self.num_loras = len(self.lora_requests)
|
||||
self.num_loras: int = len(self.lora_requests)
|
||||
if self.num_loras > 0:
|
||||
self._sort_by_lora_ids()
|
||||
|
||||
@ -59,13 +97,13 @@ class SchedulerOutputs:
|
||||
and not self.blocks_to_swap_out and not self.blocks_to_copy)
|
||||
|
||||
def _sort_by_lora_ids(self) -> bool:
|
||||
self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
|
||||
key=lambda g:
|
||||
(g.lora_int_id, g.request_id))
|
||||
self.scheduled_seq_groups = sorted(
|
||||
self.scheduled_seq_groups,
|
||||
key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
|
||||
|
||||
@property
|
||||
def lora_requests(self) -> Set[LoRARequest]:
|
||||
return {g.lora_request for g in self.scheduled_seq_groups}
|
||||
return {g.seq_group.lora_request for g in self.scheduled_seq_groups}
|
||||
|
||||
|
||||
class Scheduler:
|
||||
@ -198,11 +236,13 @@ class Scheduler:
|
||||
assert len(waiting_seqs) == 1, (
|
||||
"Waiting sequence group should have only one prompt "
|
||||
"sequence.")
|
||||
num_prompt_tokens = waiting_seqs[0].get_len()
|
||||
if num_prompt_tokens > self.prompt_limit:
|
||||
# get_len includes output tokens if the request has been
|
||||
# preempted.
|
||||
num_prefill_tokens = waiting_seqs[0].get_len()
|
||||
if num_prefill_tokens > self.prompt_limit:
|
||||
logger.warning(
|
||||
f"Input prompt ({num_prompt_tokens} tokens) is too long"
|
||||
f" and exceeds limit of {self.prompt_limit}")
|
||||
f"Input prompt ({num_prefill_tokens} tokens) is too "
|
||||
f"long and exceeds limit of {self.prompt_limit}")
|
||||
for seq in waiting_seqs:
|
||||
seq.status = SequenceStatus.FINISHED_IGNORED
|
||||
ignored_seq_groups.append(seq_group)
|
||||
@ -215,8 +255,8 @@ class Scheduler:
|
||||
break
|
||||
elif can_allocate == AllocStatus.NEVER:
|
||||
logger.warning(
|
||||
f"Input prompt ({num_prompt_tokens} tokens) is too long"
|
||||
f" and exceeds the capacity of block_manager")
|
||||
f"Input prompt ({num_prefill_tokens} tokens) is too "
|
||||
f"long and exceeds the capacity of block_manager")
|
||||
for seq in waiting_seqs:
|
||||
seq.status = SequenceStatus.FINISHED_IGNORED
|
||||
ignored_seq_groups.append(seq_group)
|
||||
@ -235,7 +275,7 @@ class Scheduler:
|
||||
continue
|
||||
|
||||
# If the number of batched tokens exceeds the limit, stop.
|
||||
num_batched_tokens += num_prompt_tokens
|
||||
num_batched_tokens += num_prefill_tokens
|
||||
if (num_batched_tokens >
|
||||
self.scheduler_config.max_num_batched_tokens):
|
||||
break
|
||||
@ -253,8 +293,10 @@ class Scheduler:
|
||||
self._allocate(seq_group)
|
||||
self.running.append(seq_group)
|
||||
num_curr_seqs += num_new_seqs
|
||||
scheduled.append(seq_group)
|
||||
|
||||
scheduled.append(
|
||||
ScheduledSequenceGroup(
|
||||
seq_group=seq_group,
|
||||
token_chunk_size=num_prefill_tokens))
|
||||
self.waiting.extendleft(leftover_waiting_sequences)
|
||||
|
||||
if scheduled or ignored_seq_groups:
|
||||
@ -352,7 +394,11 @@ class Scheduler:
|
||||
for seq_group in self.running)
|
||||
|
||||
scheduler_outputs = SchedulerOutputs(
|
||||
scheduled_seq_groups=self.running,
|
||||
scheduled_seq_groups=[
|
||||
ScheduledSequenceGroup(seq_group=running_group,
|
||||
token_chunk_size=1)
|
||||
for running_group in self.running
|
||||
],
|
||||
prompt_run=False,
|
||||
num_batched_tokens=num_batched_tokens,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
@ -371,10 +417,14 @@ class Scheduler:
|
||||
|
||||
# Create input data structures.
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
for seq_group in scheduler_outputs.scheduled_seq_groups:
|
||||
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
token_chunk_size = scheduled_seq_group.token_chunk_size
|
||||
seq_group.maybe_set_first_scheduled_time(now)
|
||||
|
||||
# seq_id -> SequenceData
|
||||
seq_data: Dict[int, SequenceData] = {}
|
||||
# seq_id -> physical block numbers
|
||||
block_tables: Dict[int, List[int]] = {}
|
||||
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
@ -393,6 +443,7 @@ class Scheduler:
|
||||
seq_data=seq_data,
|
||||
sampling_params=seq_group.sampling_params,
|
||||
block_tables=block_tables,
|
||||
token_chunk_size=token_chunk_size,
|
||||
lora_request=seq_group.lora_request,
|
||||
computed_block_nums=common_computed_block_nums,
|
||||
state=seq_group.state,
|
||||
@ -409,8 +460,9 @@ class Scheduler:
|
||||
# batch will have been computed before the next scheduling invocation.
|
||||
# This is because the engine assumes that a failure in model execution
|
||||
# will crash the vLLM instance / will not retry.
|
||||
for seq_group in scheduler_outputs.scheduled_seq_groups:
|
||||
self.block_manager.mark_blocks_as_computed(seq_group)
|
||||
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
|
||||
self.block_manager.mark_blocks_as_computed(
|
||||
scheduled_seq_group.seq_group)
|
||||
|
||||
return seq_group_metadata_list, scheduler_outputs
|
||||
|
||||
@ -418,6 +470,7 @@ class Scheduler:
|
||||
self.block_manager.fork(parent_seq, child_seq)
|
||||
|
||||
def free_seq(self, seq: Sequence) -> None:
|
||||
"""Free a sequence from a block table."""
|
||||
self.block_manager.free(seq)
|
||||
|
||||
def free_finished_seq_groups(self) -> None:
|
||||
@ -480,7 +533,8 @@ class Scheduler:
|
||||
assert len(seqs) == 1
|
||||
for seq in seqs:
|
||||
seq.status = SequenceStatus.WAITING
|
||||
self.block_manager.free(seq)
|
||||
self.free_seq(seq)
|
||||
seq.reset_state_for_recompute()
|
||||
# NOTE: For FCFS, we insert the preempted sequence group to the front
|
||||
# of the waiting queue.
|
||||
self.waiting.appendleft(seq_group)
|
||||
|
@ -62,6 +62,7 @@ class EngineArgs:
|
||||
image_input_shape: Optional[str] = None
|
||||
image_feature_size: Optional[int] = None
|
||||
scheduler_delay_factor: float = 0.0
|
||||
enable_chunked_prefill: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer is None:
|
||||
@ -356,6 +357,12 @@ class EngineArgs:
|
||||
default=EngineArgs.scheduler_delay_factor,
|
||||
help='Apply a delay (of delay factor multiplied by previous'
|
||||
'prompt latency) before scheduling next prompt.')
|
||||
parser.add_argument(
|
||||
'--enable-chunked-prefill',
|
||||
type=bool,
|
||||
default=False,
|
||||
help='If True, the prefill requests can be chunked based on the '
|
||||
'max_num_batched_tokens')
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
@ -394,11 +401,14 @@ class EngineArgs:
|
||||
self.tokenizer_pool_type,
|
||||
self.tokenizer_pool_extra_config,
|
||||
), self.ray_workers_use_nsight)
|
||||
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
|
||||
self.max_num_seqs,
|
||||
model_config.max_model_len,
|
||||
self.use_v2_block_manager,
|
||||
self.scheduler_delay_factor)
|
||||
scheduler_config = SchedulerConfig(
|
||||
self.max_num_batched_tokens,
|
||||
self.max_num_seqs,
|
||||
model_config.max_model_len,
|
||||
self.use_v2_block_manager,
|
||||
delay_factor=self.scheduler_delay_factor,
|
||||
enable_chunked_prefill=self.enable_chunked_prefill,
|
||||
)
|
||||
lora_config = LoRAConfig(
|
||||
max_lora_rank=self.max_lora_rank,
|
||||
max_loras=self.max_loras,
|
||||
|
@ -553,7 +553,10 @@ class LLMEngine:
|
||||
# Update the scheduled sequence groups with the model outputs.
|
||||
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
|
||||
|
||||
for seq_group, outputs in zip(scheduled_seq_groups, output):
|
||||
for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output):
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
token_chunk_size = scheduled_seq_group.token_chunk_size
|
||||
seq_group.update_num_computed_tokens(token_chunk_size)
|
||||
self._process_sequence_group_outputs(seq_group, outputs)
|
||||
|
||||
# Free the finished sequence groups.
|
||||
@ -561,7 +564,8 @@ class LLMEngine:
|
||||
|
||||
# Create the outputs.
|
||||
request_outputs: List[RequestOutput] = []
|
||||
for seq_group in scheduled_seq_groups:
|
||||
for scheduled_seq_group in scheduled_seq_groups:
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
seq_group.maybe_set_first_token_time(now)
|
||||
request_output = RequestOutput.from_seq_group(seq_group)
|
||||
request_outputs.append(request_output)
|
||||
@ -676,17 +680,20 @@ class LLMEngine:
|
||||
# Number of Tokens.
|
||||
if prompt_run:
|
||||
num_prompt_tokens = sum(
|
||||
len(seq_group.prompt_token_ids)
|
||||
for seq_group in scheduler_outputs.scheduled_seq_groups)
|
||||
len(scheduled_seq_group.seq_group.prompt_token_ids)
|
||||
for scheduled_seq_group in
|
||||
scheduler_outputs.scheduled_seq_groups)
|
||||
num_generation_tokens = sum(
|
||||
seq_group.num_seqs()
|
||||
for seq_group in scheduler_outputs.scheduled_seq_groups)
|
||||
scheduled_seq_group.seq_group.num_seqs()
|
||||
for scheduled_seq_group in
|
||||
scheduler_outputs.scheduled_seq_groups)
|
||||
else:
|
||||
num_generation_tokens = scheduler_outputs.num_batched_tokens
|
||||
|
||||
# Latency Timings.
|
||||
time_last_iters = []
|
||||
for seq_group in scheduler_outputs.scheduled_seq_groups:
|
||||
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
# Time since last token.
|
||||
# (n.b. updates seq_group.metrics.last_token_time)
|
||||
time_last_iters.append(seq_group.get_last_latency(now))
|
||||
|
@ -113,6 +113,8 @@ class SequenceData:
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.output_token_ids = output_token_ids
|
||||
self.cumulative_logprob = 0.0
|
||||
# The number of tokens that are computed (that run against the model).
|
||||
self._num_computed_tokens = 0
|
||||
|
||||
def append_token_id(self, token_id: int, logprob: float) -> None:
|
||||
self.output_token_ids.append(token_id)
|
||||
@ -130,6 +132,28 @@ class SequenceData:
|
||||
def get_token_ids(self) -> List[int]:
|
||||
return self.prompt_token_ids + self.output_token_ids
|
||||
|
||||
def get_num_computed_tokens(self) -> int:
|
||||
"""Return the number of prefill tokens that are already computed."""
|
||||
return self._num_computed_tokens
|
||||
|
||||
def update_num_computed_tokens(self, num_new_computed_tokens: int) -> int:
|
||||
"""Update number of tokens computed so far."""
|
||||
self._num_computed_tokens += num_new_computed_tokens
|
||||
|
||||
def reset_num_computed_tokens(self) -> None:
|
||||
"""Reset the number of computed tokens from this sequence. It is
|
||||
supposed to be called when a sequence needs to be started from
|
||||
the beginning again (e.g., sequence is preempted).
|
||||
"""
|
||||
self._num_computed_tokens = 0
|
||||
|
||||
def get_num_uncomputed_tokens(self) -> int:
|
||||
"""Return the number of prefil tokens that are not computed."""
|
||||
# we use `get_len()` which includes prompt_len + output_len instead
|
||||
# of prompt_len here. This is because during recompute we need to
|
||||
# prefill for both prompt and output.
|
||||
return self.get_len() - self.get_num_computed_tokens()
|
||||
|
||||
def get_last_token_id(self) -> int:
|
||||
if not self.output_token_ids:
|
||||
return self.prompt_token_ids[-1]
|
||||
@ -208,6 +232,10 @@ class Sequence:
|
||||
def num_hashed_tokens_of_block(self, logical_idx: int):
|
||||
return logical_idx * self.block_size + self.block_size
|
||||
|
||||
def reset_state_for_recompute(self):
|
||||
"""Reset the sequence states for recomputation."""
|
||||
self.data.reset_num_computed_tokens()
|
||||
|
||||
def _append_logical_block(self) -> None:
|
||||
block = LogicalTokenBlock(
|
||||
block_number=len(self.logical_token_blocks),
|
||||
@ -430,6 +458,18 @@ class SequenceGroup:
|
||||
def get_finished_seqs(self) -> List[Sequence]:
|
||||
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
|
||||
|
||||
def update_num_computed_tokens(self, num_new_computed_tokens: int):
|
||||
"""Update number of tokens computed so far."""
|
||||
for seq in self.seqs_dict.values():
|
||||
seq.data.update_num_computed_tokens(num_new_computed_tokens)
|
||||
|
||||
def get_num_uncomputed_tokens(self) -> int:
|
||||
# All sequences in the group should have the same prompt, so the
|
||||
# number of unfinished prefill tokens are the same across all
|
||||
# sequences.
|
||||
return list(
|
||||
self.seqs_dict.values())[0].data.get_num_uncomputed_tokens()
|
||||
|
||||
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
|
||||
return len(self.get_seqs(status))
|
||||
|
||||
@ -473,6 +513,8 @@ class SequenceGroupMetadata:
|
||||
sampling_params: The sampling parameters used to generate the outputs.
|
||||
block_tables: The block tables. (Seq id -> list of physical block
|
||||
numbers)
|
||||
token_chunk_size: The number of tokens to be processed. None if
|
||||
chunking is not required.
|
||||
state: Internal state tied to this sequence group.
|
||||
lora_request: LoRA request.
|
||||
multi_modal_data: Multi modal data.
|
||||
@ -485,6 +527,7 @@ class SequenceGroupMetadata:
|
||||
seq_data: Dict[int, SequenceData],
|
||||
sampling_params: SamplingParams,
|
||||
block_tables: Dict[int, List[int]],
|
||||
token_chunk_size: Optional[int] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
computed_block_nums: Optional[List[int]] = None,
|
||||
state: Optional[SequenceGroupState] = None,
|
||||
@ -499,11 +542,23 @@ class SequenceGroupMetadata:
|
||||
self.computed_block_nums = computed_block_nums
|
||||
self.multi_modal_data = multi_modal_data
|
||||
self.state = SequenceGroupState() if state is None else state
|
||||
self._token_chunk_size = token_chunk_size
|
||||
|
||||
if self._token_chunk_size is None:
|
||||
if is_prompt:
|
||||
self._token_chunk_size = list(seq_data.values())[0].get_len()
|
||||
else:
|
||||
self._token_chunk_size = 1
|
||||
|
||||
@property
|
||||
def lora_int_id(self) -> int:
|
||||
return self.lora_request.lora_int_id if self.lora_request else 0
|
||||
|
||||
@property
|
||||
def token_chunk_size(self) -> int:
|
||||
"""Return the number of tokens to be processed (chunk size)."""
|
||||
return self._token_chunk_size
|
||||
|
||||
|
||||
class SequenceOutput:
|
||||
"""The model output associated with a sequence.
|
||||
|
@ -150,39 +150,58 @@ class ModelRunner:
|
||||
subquery_lens: List[int] = []
|
||||
prefix_block_tables: List[List[int]] = []
|
||||
multi_modal_input_list: List[torch.Tensor] = []
|
||||
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert seq_group_metadata.is_prompt
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
assert len(seq_ids) == 1
|
||||
seq_id = seq_ids[0]
|
||||
|
||||
computed_block_nums = seq_group_metadata.computed_block_nums
|
||||
if (self.scheduler_config is not None
|
||||
and self.scheduler_config.chunked_prefill_enabled
|
||||
and computed_block_nums is not None):
|
||||
raise RuntimeError(
|
||||
"chunked prefill cannot be used with prefix caching "
|
||||
"now.")
|
||||
|
||||
token_chunk_size = seq_group_metadata.token_chunk_size
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
prompt_tokens = seq_data.get_token_ids()
|
||||
computed_len = seq_data.get_num_computed_tokens()
|
||||
# We should use get_len here because in case of preemption
|
||||
# it contains output tokens.
|
||||
prefill_end = min(seq_data.get_len(),
|
||||
computed_len + token_chunk_size)
|
||||
# TODO(sang): Rename it after chunked prefill is introduced.
|
||||
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
|
||||
prompt_len = len(prompt_tokens)
|
||||
# Right now, the prefill_end is always same as the length of
|
||||
# sequence. However, once chunked prefill is introduced, this
|
||||
# assumption can be changed.
|
||||
assert prefill_end == seq_data.get_len()
|
||||
prompt_lens.append(prompt_len)
|
||||
computed_len = 0
|
||||
|
||||
# NOTE: This only works for oooooooxxx style attention.
|
||||
computed_block_nums = seq_group_metadata.computed_block_nums
|
||||
if computed_block_nums is not None and len(
|
||||
computed_block_nums) > 0 and self.sliding_window is None:
|
||||
# Prefix is not supported with sliding_window
|
||||
computed_len = len(computed_block_nums) * self.block_size
|
||||
prompt_tokens = prompt_tokens[computed_len:]
|
||||
prefix_block_tables.append(computed_block_nums)
|
||||
context_len = computed_len
|
||||
else:
|
||||
prefix_block_tables.append([])
|
||||
context_len = 0
|
||||
# Right now, prefill start is always 0. However, this
|
||||
# assumption can be changed once chunked prefill is introduced.
|
||||
assert computed_len == 0
|
||||
|
||||
# actual prompt lens
|
||||
context_lens.append(context_len)
|
||||
context_lens.append(computed_len)
|
||||
subquery_lens.append(prompt_len - computed_len)
|
||||
|
||||
input_tokens.extend(prompt_tokens)
|
||||
# NOTE(woosuk): Here we assume that the first token in the prompt
|
||||
# is always the first token in the sequence.
|
||||
input_positions.extend(
|
||||
list(range(computed_len, computed_len + len(prompt_tokens))))
|
||||
input_positions.extend(list(range(computed_len, prefill_end)))
|
||||
|
||||
lora_id = seq_group_metadata.lora_int_id
|
||||
|
||||
@ -218,7 +237,8 @@ class ModelRunner:
|
||||
"Prefix caching is currently not supported with "
|
||||
"sliding window attention")
|
||||
start_idx = max(0, prompt_len - self.sliding_window)
|
||||
for i in range(computed_len, prompt_len):
|
||||
|
||||
for i in range(computed_len, prefill_end):
|
||||
if i < start_idx:
|
||||
slot_mapping.append(_PAD_SLOT_ID)
|
||||
continue
|
||||
@ -331,6 +351,7 @@ class ModelRunner:
|
||||
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert not seq_group_metadata.is_prompt
|
||||
assert seq_group_metadata.token_chunk_size == 1
|
||||
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
lora_id = seq_group_metadata.lora_int_id
|
||||
|
Loading…
x
Reference in New Issue
Block a user