2025-02-02 14:58:18 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
2024-08-06 16:51:47 -04:00
|
|
|
import pytest # noqa
|
|
|
|
|
|
|
|
from vllm.config import CacheConfig, SchedulerConfig
|
|
|
|
from vllm.core.scheduler import Scheduler
|
|
|
|
from vllm.sequence import SequenceGroup
|
|
|
|
|
|
|
|
from .utils import (append_new_token, create_dummy_prompt_encoder_decoder,
|
|
|
|
get_sequence_groups, schedule_and_update_computed_tokens)
|
|
|
|
|
|
|
|
|
|
|
|
def test_scheduler_schedule_simple_encoder_decoder():
    '''
    Exercise basic scheduler functionality for an encoder/decoder
    model. Only enc/dec-specific behavior is targeted here, since tests
    already exist for decoder-only functionality.

    Test behavior:
    * Construct Scheduler
    * Construct dummy encoder/decoder sequence groups
    * Add dummy seq groups to scheduler backlog
    * Schedule the next seq group & validate:
        * Cross-attn block tables
        * Updated states of seq groups
        * Number of batched tokens
        * Number of blocks to copy/swap-in/swap-out
        * Number of scheduled seq groups
    * Repeat for both prefill- and decode-phase
    * Abort scheduled seq groups
    * Assert that aborted seq groups no longer appear in
      cross-attention block table
    '''

    block_size = 4
    num_seq_group = 4
    max_model_len = 16

    # Build the scheduler under test.
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=num_seq_group,
        max_model_len=max_model_len,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    # Block budget covers the enc and dec prompts per seq_group.
    cache_config.num_cpu_blocks = 16
    cache_config.num_gpu_blocks = 16
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # Add seq groups to scheduler, remembering each one so the
    # scheduler outputs can be compared against them later.
    req_id_list = [str(idx) for idx in range(num_seq_group)]
    running: list[SequenceGroup] = []
    for request_id in req_id_list:
        # Every length argument of the dummy prompt equals block_size.
        _, _, group = create_dummy_prompt_encoder_decoder(
            request_id, block_size, block_size, block_size)
        scheduler.add_seq_group(group)
        running.append(group)

    # ---- Prefill phase ----
    expected_prefill_tokens = block_size * num_seq_group
    metas, out = schedule_and_update_computed_tokens(scheduler)

    # - Every request must have a cross-attention block table
    #   registered with the block manager
    cross_tables = scheduler.block_manager.cross_block_tables
    assert all(request_id in cross_tables for request_id in req_id_list)
    # - Sequence-group status
    assert set(get_sequence_groups(out)) == set(running)
    # - Number of batched tokens
    assert out.num_batched_tokens == expected_prefill_tokens
    # - No remaining blocks to copy or swap
    assert not (out.blocks_to_copy or out.blocks_to_swap_in
                or out.blocks_to_swap_out)
    # - All seq groups were scheduled
    assert len(metas) == num_seq_group
    append_new_token(out, 1)

    # ---- Decode phase ----
    metas, out = schedule_and_update_computed_tokens(scheduler)

    # - Sequence group metadata must carry both encoder attention
    #   and cross-attention metadata
    assert all(meta.encoder_seq_data is not None
               and meta.cross_block_table is not None for meta in metas)
    # - Sequence-group status
    assert set(get_sequence_groups(out)) == set(running)
    # - Exactly one batched token per seq group
    assert out.num_batched_tokens == num_seq_group
    # - No remaining blocks to copy or swap
    assert not (out.blocks_to_copy or out.blocks_to_swap_in
                or out.blocks_to_swap_out)
    # - All seq groups were scheduled
    assert len(metas) == num_seq_group
    append_new_token(out, 1)

    # ---- Abort phase ----
    for request_id in req_id_list:
        scheduler.abort_seq_group(request_id)
        # - The aborted request's cross-attention block table must
        #   NO LONGER be registered with the block manager
        assert request_id not in scheduler.block_manager.cross_block_tables
|