[TPU][CI] Fix TPUModelRunner Test (#15667)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
parent
355f66348c
commit
2d9045fce8
@ -30,7 +30,7 @@ docker run --privileged --net host --shm-size=16G -it \
|
|||||||
&& echo TEST_4 \
|
&& echo TEST_4 \
|
||||||
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \
|
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \
|
||||||
&& echo TEST_5 \
|
&& echo TEST_5 \
|
||||||
&& pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py \
|
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
|
||||||
&& echo TEST_6 \
|
&& echo TEST_6 \
|
||||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \
|
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \
|
||||||
|
|
||||||
|
@ -7,7 +7,6 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
|
|||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
||||||
SchedulerOutput)
|
SchedulerOutput)
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
|
||||||
from vllm.v1.worker.tpu_model_runner import (TPUModelRunner,
|
from vllm.v1.worker.tpu_model_runner import (TPUModelRunner,
|
||||||
_get_padded_token_len,
|
_get_padded_token_len,
|
||||||
_get_paddings)
|
_get_paddings)
|
||||||
@ -113,12 +112,6 @@ def _is_req_added(model_runner, req_id: str) -> bool:
|
|||||||
return req_id in model_runner.requests
|
return req_id in model_runner.requests
|
||||||
|
|
||||||
|
|
||||||
def _is_sampling_metadata_changed(model_runner,
|
|
||||||
sampling_metadata_before: SamplingMetadata):
|
|
||||||
return model_runner.input_batch.sampling_metadata is not (
|
|
||||||
sampling_metadata_before)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
|
def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
|
||||||
req_index = model_runner.input_batch.req_id_to_index[req_id]
|
req_index = model_runner.input_batch.req_id_to_index[req_id]
|
||||||
block_table = model_runner.input_batch.block_table
|
block_table = model_runner.input_batch.block_table
|
||||||
@ -136,10 +129,8 @@ def test_update_states_new_request(model_runner):
|
|||||||
# new req
|
# new req
|
||||||
scheduler_output = _schedule_new_request(req_id)
|
scheduler_output = _schedule_new_request(req_id)
|
||||||
|
|
||||||
metadata_before = model_runner.input_batch.sampling_metadata
|
|
||||||
model_runner._update_states(scheduler_output)
|
model_runner._update_states(scheduler_output)
|
||||||
|
|
||||||
assert _is_sampling_metadata_changed(model_runner, metadata_before)
|
|
||||||
assert _is_req_added(model_runner, req_id)
|
assert _is_req_added(model_runner, req_id)
|
||||||
assert _is_req_scheduled(model_runner, req_id)
|
assert _is_req_scheduled(model_runner, req_id)
|
||||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
assert _is_req_state_block_table_match(model_runner, req_id)
|
||||||
@ -170,9 +161,7 @@ def test_update_states_request_finished(model_runner):
|
|||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata_before = model_runner.input_batch.sampling_metadata
|
|
||||||
model_runner._update_states(scheduler_output)
|
model_runner._update_states(scheduler_output)
|
||||||
assert _is_sampling_metadata_changed(model_runner, metadata_before)
|
|
||||||
assert not _is_req_added(model_runner, req_id)
|
assert not _is_req_added(model_runner, req_id)
|
||||||
assert not _is_req_scheduled(model_runner, req_id)
|
assert not _is_req_scheduled(model_runner, req_id)
|
||||||
|
|
||||||
@ -229,9 +218,7 @@ def test_update_states_request_resumed(model_runner):
|
|||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata_before = model_runner.input_batch.sampling_metadata
|
|
||||||
model_runner._update_states(scheduler_output)
|
model_runner._update_states(scheduler_output)
|
||||||
assert _is_sampling_metadata_changed(model_runner, metadata_before)
|
|
||||||
assert _is_req_added(model_runner, req_id)
|
assert _is_req_added(model_runner, req_id)
|
||||||
assert _is_req_scheduled(model_runner, req_id)
|
assert _is_req_scheduled(model_runner, req_id)
|
||||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
assert _is_req_state_block_table_match(model_runner, req_id)
|
||||||
@ -262,9 +249,7 @@ def test_update_states_no_changes(model_runner):
|
|||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata_before = model_runner.input_batch.sampling_metadata
|
|
||||||
model_runner._update_states(scheduler_output)
|
model_runner._update_states(scheduler_output)
|
||||||
assert not _is_sampling_metadata_changed(model_runner, metadata_before)
|
|
||||||
assert _is_req_added(model_runner, req_id)
|
assert _is_req_added(model_runner, req_id)
|
||||||
assert _is_req_scheduled(model_runner, req_id)
|
assert _is_req_scheduled(model_runner, req_id)
|
||||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
assert _is_req_state_block_table_match(model_runner, req_id)
|
||||||
@ -299,8 +284,7 @@ def test_update_states_request_unscheduled(model_runner):
|
|||||||
grammar_bitmask=None,
|
grammar_bitmask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata_before = model_runner._update_states(scheduler_output)
|
model_runner._update_states(scheduler_output)
|
||||||
assert _is_sampling_metadata_changed(model_runner, metadata_before)
|
|
||||||
|
|
||||||
assert _is_req_added(model_runner, req_ids[0])
|
assert _is_req_added(model_runner, req_ids[0])
|
||||||
assert _is_req_scheduled(model_runner, req_ids[0])
|
assert _is_req_scheduled(model_runner, req_ids[0])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user