# SPDX-License-Identifier: Apache-2.0
import pytest

from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                       SchedulerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_model_runner import GPUModelRunner


@pytest.fixture
def model_runner():
    scheduler_config = SchedulerConfig(
        max_num_seqs=10,
        max_num_batched_tokens=512,
        max_model_len=512,
    )
    model_config = ModelConfig(
        model="facebook/opt-125m",
        task="generate",
        tokenizer="facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    cache_config = CacheConfig(
        block_size=16,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        scheduler_config=scheduler_config,
    )
    device = "cuda"
    return GPUModelRunner(vllm_config, device)


def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
    new_reqs = []
    num_scheduled_tokens = {}
    total_num_scheduled_tokens = 0
    for req_id in req_ids:
        new_reqs.append(
            NewRequestData(
                req_id=req_id,
                prompt_token_ids=[1, 2, 3],
                prompt="test",
                mm_inputs=[],
                mm_hashes=[],
                mm_positions=[],
                sampling_params=SamplingParams(),
                block_ids=[0],
                num_computed_tokens=0,
                lora_request=None,
            ))
        num_scheduled_tokens[req_id] = 3
        total_num_scheduled_tokens += num_scheduled_tokens[req_id]

    return SchedulerOutput(
        scheduled_new_reqs=new_reqs,
        scheduled_cached_reqs=[],
        num_scheduled_tokens=num_scheduled_tokens,
        total_num_scheduled_tokens=total_num_scheduled_tokens,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )


def _is_req_scheduled(model_runner, req_id: str) -> bool:
    return req_id in model_runner.input_batch.req_id_to_index


def _is_req_added(model_runner, req_id: str) -> bool:
    return req_id in model_runner.requests


def _is_sampling_metadata_changed(model_runner,
                                  sampling_metadata_before: SamplingMetadata):
    return model_runner.input_batch.sampling_metadata is not (
        sampling_metadata_before)


def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
    req_index = model_runner.input_batch.req_id_to_index[req_id]
    block_table = model_runner.input_batch.block_table
    req_state = model_runner.requests[req_id]
    if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
        return False
    num_blocks = block_table.num_blocks_per_row[req_index]
    return (block_table.block_table_np[req_index, :num_blocks] ==
            req_state.block_ids).all()


def test_update_states_new_request(model_runner):
    req_id = "req_0"

    # new req
    scheduler_output = _schedule_new_request(req_id)

    metadata_before = model_runner.input_batch.sampling_metadata
    model_runner._update_states(scheduler_output)

    assert _is_sampling_metadata_changed(model_runner, metadata_before)
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)
    assert _is_req_state_block_table_match(model_runner, req_id)


def test_update_states_request_finished(model_runner):
    req_id = "req_0"

    # new req
    scheduler_output = _schedule_new_request(req_id)

    model_runner._update_states(scheduler_output)
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)

    # finish req
    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids={req_id},
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    metadata_before = model_runner.input_batch.sampling_metadata
    model_runner._update_states(scheduler_output)
    assert _is_sampling_metadata_changed(model_runner, metadata_before)
    assert not _is_req_added(model_runner, req_id)
    assert not _is_req_scheduled(model_runner, req_id)


def test_update_states_request_resumed(model_runner):
    req_id = "req_0"

    # new req
    scheduler_output = _schedule_new_request(req_id)

    model_runner._update_states(scheduler_output)
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)

    # unschedule req
    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    model_runner._update_states(scheduler_output)
    assert _is_req_added(model_runner, req_id)
    assert not _is_req_scheduled(model_runner, req_id)

    # resume req
    cached_req_data = CachedRequestData(
        req_id=req_id,
        resumed_from_preemption=False,
        new_token_ids=[],
        new_block_ids=[],
        num_computed_tokens=0,
    )

    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[cached_req_data],
        num_scheduled_tokens={req_id: 1},
        total_num_scheduled_tokens=1,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    metadata_before = model_runner.input_batch.sampling_metadata
    model_runner._update_states(scheduler_output)
    assert _is_sampling_metadata_changed(model_runner, metadata_before)
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)
    assert _is_req_state_block_table_match(model_runner, req_id)


def test_update_states_no_changes(model_runner):
    req_id = "req_0"

    # new req
    scheduler_output = _schedule_new_request(req_id)

    model_runner._update_states(scheduler_output)
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)

    # schedule req
    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={req_id: 1},
        total_num_scheduled_tokens=1,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    metadata_before = model_runner.input_batch.sampling_metadata
    model_runner._update_states(scheduler_output)
    assert not _is_sampling_metadata_changed(model_runner, metadata_before)
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)
    assert _is_req_state_block_table_match(model_runner, req_id)


def test_update_states_request_unscheduled(model_runner):
    req_ids = ("req_0", "req_1")

    # new reqs
    scheduler_output = _schedule_new_request(*req_ids)

    model_runner._update_states(scheduler_output)

    assert _is_req_added(model_runner, req_ids[0])
    assert _is_req_scheduled(model_runner, req_ids[0])

    assert _is_req_added(model_runner, req_ids[1])
    assert _is_req_scheduled(model_runner, req_ids[1])

    # unschedule req_1
    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={req_ids[0]: 1},
        total_num_scheduled_tokens=1,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    # Capture the metadata before the update, then apply it; assigning the
    # return value of _update_states to metadata_before (as written before)
    # made the changed-metadata assertion vacuous.
    metadata_before = model_runner.input_batch.sampling_metadata
    model_runner._update_states(scheduler_output)
    assert _is_sampling_metadata_changed(model_runner, metadata_before)

    assert _is_req_added(model_runner, req_ids[0])
    assert _is_req_scheduled(model_runner, req_ids[0])

    assert _is_req_added(model_runner, req_ids[1])
    assert not _is_req_scheduled(model_runner, req_ids[1])
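

# Usage note (an assumption, not part of the original file): the fixture
# builds a GPUModelRunner on device "cuda", so these tests need a CUDA-capable
# GPU and access to the facebook/opt-125m weights. Assuming the file lives at
# tests/v1/worker/test_gpu_model_runner.py, a typical invocation would be:
#   pytest tests/v1/worker/test_gpu_model_runner.py -v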