From 0a049c7d8608ea54ca2aa4543cae7c3927e0a0ca Mon Sep 17 00:00:00 2001
From: yarongmu-google <150371854+yarongmu-google@users.noreply.github.com>
Date: Tue, 25 Mar 2025 09:27:16 -0700
Subject: [PATCH] [CI/Build] Add tests for the V1 tpu_model_runner. (#14843)

Signed-off-by: Yarong Mu
---
 .buildkite/run-tpu-v1-test.sh                |   4 +-
 tests/v1/tpu/worker/__init__.py              |   0
 tests/v1/tpu/worker/test_tpu_model_runner.py | 307 +++++++++++++++++++
 3 files changed, 310 insertions(+), 1 deletion(-)
 create mode 100644 tests/v1/tpu/worker/__init__.py
 create mode 100644 tests/v1/tpu/worker/test_tpu_model_runner.py

diff --git a/.buildkite/run-tpu-v1-test.sh b/.buildkite/run-tpu-v1-test.sh
index f0f53d3b..d557feef 100755
--- a/.buildkite/run-tpu-v1-test.sh
+++ b/.buildkite/run-tpu-v1-test.sh
@@ -30,7 +30,9 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_4 \
     && pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
     && echo TEST_5 \
-    && python3 /workspace/vllm/examples/offline_inference/tpu.py" \
+    && python3 /workspace/vllm/examples/offline_inference/tpu.py \
+    && echo TEST_6 \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py" \


 # TODO: This test fails because it uses RANDOM_SEED sampling

diff --git a/tests/v1/tpu/worker/__init__.py b/tests/v1/tpu/worker/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py
new file mode 100644
index 00000000..40ae52ef
--- /dev/null
+++ b/tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -0,0 +1,307 @@
+# SPDX-License-Identifier: Apache-2.0
+import unittest.mock as mock
+
+import pytest
+
+from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
+from vllm.sampling_params import SamplingParams
+from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
+                                        SchedulerOutput)
+from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.worker.tpu_model_runner import TPUModelRunner
+
+# Mock the torch_xla module since it may not be available in the test environments
+torch_xla_patcher = mock.patch.dict(
+    "sys.modules", {
+        "torch_xla": mock.MagicMock(),
+        "torch_xla.core.xla_model": mock.MagicMock(),
+        "torch_xla.runtime": mock.MagicMock(),
+    })
+torch_xla_patcher.start()
+
+# Mock the PallasAttentionBackend
+pallas_attention_backend_patcher = mock.patch(
+    "vllm.v1.worker.tpu_model_runner.PallasAttentionBackend", )
+pallas_attention_backend_patcher.start()
+
+
+@pytest.fixture
+def model_runner():
+    # Patchers have already been started at module level.
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=10,
+        max_num_batched_tokens=512,
+        max_model_len=512,
+    )
+    model_config = ModelConfig(
+        model="facebook/opt-125m",
+        task="generate",
+        tokenizer="facebook/opt-125m",
+        tokenizer_mode="auto",
+        trust_remote_code=True,
+        dtype="bfloat16",  # TPUs typically use bfloat16
+        seed=42,
+    )
+    cache_config = CacheConfig(
+        block_size=16,
+        gpu_memory_utilization=0.9,
+        swap_space=0,
+        cache_dtype="auto",
+    )
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=cache_config,
+        scheduler_config=scheduler_config,
+    )
+    device = "xla:0"  # Mocking TPU device
+    with mock.patch("vllm.v1.worker.tpu_model_runner.torch"), \
+            mock.patch("vllm.v1.worker.tpu_model_runner.xm"), \
+            mock.patch("vllm.v1.worker.tpu_model_runner.xr"):
+        return TPUModelRunner(vllm_config, device)
+
+
+@pytest.fixture(autouse=True, scope="session")
+def cleanup_patches():
+    yield
+    torch_xla_patcher.stop()
+    pallas_attention_backend_patcher.stop()
+
+
+def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
+    new_reqs = []
+    num_scheduled_tokens = {}
+    total_num_scheduled_tokens = 0
+    for req_id in req_ids:
+        new_reqs.append(
+            NewRequestData(
+                req_id=req_id,
+                prompt_token_ids=[1, 2, 3],
+                prompt="test",
+                mm_inputs=[],
+                mm_hashes=[],
+                mm_positions=[],
+                sampling_params=SamplingParams(),
+                block_ids=[0],
+                num_computed_tokens=0,
+                lora_request=None,
+            ))
+        num_scheduled_tokens[req_id] = 3
+        total_num_scheduled_tokens += num_scheduled_tokens[req_id]
+
+    return SchedulerOutput(
+        scheduled_new_reqs=new_reqs,
+        scheduled_cached_reqs=[],
+        num_scheduled_tokens=num_scheduled_tokens,
+        total_num_scheduled_tokens=total_num_scheduled_tokens,
+        scheduled_spec_decode_tokens={},
+        scheduled_encoder_inputs={},
+        num_common_prefix_blocks=0,
+        finished_req_ids=set(),
+        free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
+    )
+
+
+def _is_req_scheduled(model_runner, req_id: str) -> bool:
+    return req_id in model_runner.input_batch.req_id_to_index
+
+
+def _is_req_added(model_runner, req_id: str) -> bool:
+    return req_id in model_runner.requests
+
+
+def _is_sampling_metadata_changed(model_runner,
+                                  sampling_metadata_before: SamplingMetadata):
+    return model_runner.input_batch.sampling_metadata is not (
+        sampling_metadata_before)
+
+
+def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
+    req_index = model_runner.input_batch.req_id_to_index[req_id]
+    block_table = model_runner.input_batch.block_table
+    req_state = model_runner.requests[req_id]
+    if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
+        return False
+    num_blocks = block_table.num_blocks_per_row[req_index]
+    return (block_table.block_table_np[req_index, :num_blocks] ==
+            req_state.block_ids).all()
+
+
+def test_update_states_new_request(model_runner):
+    req_id = "req_0"
+
+    # new req
+    scheduler_output = _schedule_new_request(req_id)
+
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+
+    assert _is_sampling_metadata_changed(model_runner, metadata_before)
+    assert _is_req_added(model_runner, req_id)
+    assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
+
+
+def test_update_states_request_finished(model_runner):
+    req_id = "req_0"
+
+    # new req
+    scheduler_output = _schedule_new_request(req_id)
+
+    model_runner._update_states(scheduler_output)
+    assert _is_req_added(model_runner, req_id)
+    assert _is_req_scheduled(model_runner, req_id)
+
+    # finish req
+    scheduler_output = SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=[],
+        num_scheduled_tokens={},
+        total_num_scheduled_tokens=0,
+        scheduled_spec_decode_tokens={},
+        scheduled_encoder_inputs={},
+        num_common_prefix_blocks=0,
+        finished_req_ids={req_id},
+        free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
+    )
+
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+    assert _is_sampling_metadata_changed(model_runner, metadata_before)
+    assert not _is_req_added(model_runner, req_id)
+    assert not _is_req_scheduled(model_runner, req_id)
+
+
+def test_update_states_request_resumed(model_runner):
+    req_id = "req_0"
+
+    # new req
+    scheduler_output = _schedule_new_request(req_id)
+
+    model_runner._update_states(scheduler_output)
+    assert _is_req_added(model_runner, req_id)
+    assert _is_req_scheduled(model_runner, req_id)
+
+    # unschedule req
+    scheduler_output = SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=[],
+        num_scheduled_tokens={},
+        total_num_scheduled_tokens=0,
+        scheduled_spec_decode_tokens={},
+        scheduled_encoder_inputs={},
+        num_common_prefix_blocks=0,
+        finished_req_ids=set(),
+        free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
+    )
+
+    model_runner._update_states(scheduler_output)
+    assert _is_req_added(model_runner, req_id)
+    assert not _is_req_scheduled(model_runner, req_id)
+
+    # resume req
+    cached_req_data = CachedRequestData(
+        req_id=req_id,
+        resumed_from_preemption=False,
+        new_token_ids=[],
+        new_block_ids=[],
+        num_computed_tokens=0,
+    )
+
+    scheduler_output = SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=[cached_req_data],
+        num_scheduled_tokens={req_id: 1},
+        total_num_scheduled_tokens=1,
+        scheduled_spec_decode_tokens={},
+        scheduled_encoder_inputs={},
+        num_common_prefix_blocks=0,
+        finished_req_ids=set(),
+        free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
+    )
+
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+    assert _is_sampling_metadata_changed(model_runner, metadata_before)
+    assert _is_req_added(model_runner, req_id)
+    assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
+
+
+def test_update_states_no_changes(model_runner):
+    req_id = "req_0"
+
+    # new req
+    scheduler_output = _schedule_new_request(req_id)
+
+    model_runner._update_states(scheduler_output)
+    assert _is_req_added(model_runner, req_id)
+    assert _is_req_scheduled(model_runner, req_id)
+
+    # schedule req
+    scheduler_output = SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=[],
+        num_scheduled_tokens={req_id: 1},
+        total_num_scheduled_tokens=1,
+        scheduled_spec_decode_tokens={},
+        scheduled_encoder_inputs={},
+        num_common_prefix_blocks=0,
+        finished_req_ids=set(),
+        free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
+    )
+
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+    assert not _is_sampling_metadata_changed(model_runner, metadata_before)
+    assert _is_req_added(model_runner, req_id)
+    assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
+
+
+def test_update_states_request_unscheduled(model_runner):
+    req_ids = ("req_0", "req_1")
+
+    # new reqs
+    scheduler_output = _schedule_new_request(*req_ids)
+
+    model_runner._update_states(scheduler_output)
+
+    assert _is_req_added(model_runner, req_ids[0])
+    assert _is_req_scheduled(model_runner, req_ids[0])
+
+    assert _is_req_added(model_runner, req_ids[1])
+    assert _is_req_scheduled(model_runner, req_ids[1])
+
+    # unschedule req_1
+    scheduler_output = SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=[],
+        num_scheduled_tokens={req_ids[0]: 1},
+        total_num_scheduled_tokens=1,
+        scheduled_spec_decode_tokens={},
+        scheduled_encoder_inputs={},
+        num_common_prefix_blocks=0,
+        finished_req_ids=set(),
+        free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
+    )
+
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+    assert _is_sampling_metadata_changed(model_runner, metadata_before)
+
+    assert _is_req_added(model_runner, req_ids[0])
+    assert _is_req_scheduled(model_runner, req_ids[0])
+
+    assert _is_req_added(model_runner, req_ids[1])
+    assert not _is_req_scheduled(model_runner, req_ids[1])
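
A note on running the new suite outside CI (a hedged sketch, not taken from the patch): the Buildkite step above invokes the file with pytest inside the TPU Docker image, but because the test mocks torch_xla and PallasAttentionBackend at module level, it should in principle also run on a host without TPU hardware, assuming a vLLM development install and network access to fetch the facebook/opt-125m tokenizer:

    # assumed to be run from the root of a vLLM checkout
    pytest -s -v tests/v1/tpu/worker/test_tpu_model_runner.py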