
- **Add SPDX license headers to python source files** - **Check for SPDX headers using pre-commit** commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745 Author: Russell Bryant <rbryant@redhat.com> Date: Fri Jan 31 14:18:24 2025 -0500 Add SPDX license headers to python source files This commit adds SPDX license headers to python source files as recommended to the project by the Linux Foundation. These headers provide a concise way that is both human and machine readable for communicating license information for each source file. It helps avoid any ambiguity about the license of the code and can also be easily used by tools to help manage license compliance. The Linux Foundation runs license scans against the codebase to help ensure we are in compliance with the licenses of the code we use, including dependencies. Having these headers in place helps that tool do its job. More information can be found on the SPDX site: - https://spdx.dev/learn/handling-license-info/ Signed-off-by: Russell Bryant <rbryant@redhat.com> commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea Author: Russell Bryant <rbryant@redhat.com> Date: Fri Jan 31 14:36:32 2025 -0500 Check for SPDX headers using pre-commit Signed-off-by: Russell Bryant <rbryant@redhat.com> --------- Signed-off-by: Russell Bryant <rbryant@redhat.com>
303 lines
10 KiB
Python
303 lines
10 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import pytest
|
|
|
|
from vllm.sampling_params import SamplingParams
|
|
from vllm.v1.stats.common import RequestStats, RequestStatsUpdate
|
|
|
|
|
|
def make_update(
    request_id: str,
    update_type: RequestStatsUpdate.Type,
    monotonic_ts_s: float,
    **kwargs,
):
    """Build a ``RequestStatsUpdate`` for tests with placeholder fields.

    For the update types handled below, every type-specific field that the
    caller did not supply through ``kwargs`` receives a placeholder value.
    Fields the caller did supply are passed through untouched.
    """
    kinds = RequestStatsUpdate.Type

    # Placeholder values for the fields tied to this particular update type.
    if update_type == kinds.INPUT_PROCESSED:
        placeholders = {
            "sampling_params": SamplingParams(n=1),
            "num_prompt_tokens": 10,
        }
    elif update_type == kinds.PREFILLING:
        placeholders = {"num_computed_tokens": 10, "num_cached_tokens": 10}
    elif update_type == kinds.DETOKENIZED:
        placeholders = {"num_new_tokens": 10}
    elif update_type == kinds.FINISHED:
        placeholders = {"finish_reason": "test_reason"}
    else:
        placeholders = {}

    # Caller-provided values win over the placeholders.
    extra_fields = {**placeholders, **kwargs}

    return RequestStatsUpdate(
        request_id=request_id,
        type=update_type,
        monotonic_ts_s=monotonic_ts_s,
        **extra_fields,
    )
|
|
|
|
|
|
def test_invalid_request_update():
    """Constructing an update without one of its required fields must fail.

    For every update type that has type-specific fields, each field is left
    out in turn and the constructor is expected to raise ``ValueError``.
    """
    kinds = RequestStatsUpdate.Type
    required_by_type = {
        kinds.INPUT_PROCESSED: ["sampling_params", "num_prompt_tokens"],
        kinds.PREFILLING: ["num_computed_tokens", "num_cached_tokens"],
        kinds.DETOKENIZED: ["num_new_tokens"],
        kinds.FINISHED: ["finish_reason"],
    }

    for update_type in kinds:
        # Update types without an entry have nothing to omit.
        field_names = required_by_type.get(update_type, [])
        placeholders = {name: object() for name in field_names}

        for omitted in field_names:
            # Pass every required field except the one under test.
            incomplete = {
                name: value
                for name, value in placeholders.items() if name != omitted
            }
            with pytest.raises(ValueError):
                RequestStatsUpdate(
                    request_id="test_request",
                    type=update_type,
                    **incomplete,
                )
|
|
|
|
|
|
def test_invalid_request_update_transition():
    """Check that ``check_valid_update`` rejects bad transitions and times.

    Every (previous type, new type) pair is tried: pairs listed in
    ``RequestStatsUpdate._VALID_TRANSITIONS`` must be accepted, all other
    pairs must raise ``AssertionError``. An update whose timestamp is older
    than the last recorded timestamp must raise ``AssertionError`` as well.
    """

    def check_transition(prev_type, next_type):
        # New update at t=1 validated against a previous update at t=0, so
        # only the transition itself can be the reason for a failure here.
        RequestStatsUpdate.check_valid_update(
            make_update(
                request_id="test_request",
                update_type=next_type,
                monotonic_ts_s=1,
            ),
            last_update_type=prev_type,
            last_updated_ts_s=0,
        )

    for src in RequestStatsUpdate.Type:
        allowed_next = RequestStatsUpdate._VALID_TRANSITIONS[src]
        for dst in RequestStatsUpdate.Type:
            if dst in allowed_next:
                # Valid transition: must pass without raising.
                check_transition(src, dst)
                continue
            with pytest.raises(AssertionError):
                check_transition(src, dst)

    # An update stamped t=1 arriving after a last-updated time of t=2.
    with pytest.raises(AssertionError):
        stale_update = make_update(
            request_id="test_request",
            update_type=RequestStatsUpdate.Type.ARRIVED,
            monotonic_ts_s=1,
        )
        RequestStatsUpdate.check_valid_update(
            stale_update,
            last_update_type=None,
            last_updated_ts_s=2,
        )
|
|
|
|
|
|
def test_lifecycle_updates():
    """Drive one request through a full lifecycle and check ``RequestStats``.

    The updates are applied in this order: ARRIVED, INPUT_PROCESSED, QUEUED,
    PREFILLING, two DECODING/DETOKENIZED rounds, PREEMPTED, a resumed
    PREFILLING, a third DECODING/DETOKENIZED round, and FINISHED. After each
    update the fields and derived latencies of ``stats`` are asserted.
    """
    request_id = "test_request"
    stats = RequestStats(request_id=request_id)

    # Timestamps (in update order) for the scenario exercised below.
    arrived_ts = 0
    input_processed_ts = 1
    queued_ts = 2
    prefilling_ts = 3
    decoded_ts = 5
    detokenized_ts = 6
    decoded_2_ts = 7
    detokenized_2_ts = 8
    preempted_ts = 9
    resumed_ts = 10
    decoded_3_ts = 11
    detokenized_3_ts = 12
    finished_ts = 13

    # Test ARRIVED
    arrived_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.ARRIVED,
        monotonic_ts_s=arrived_ts,
    )
    stats.update_from(arrived_update)
    assert stats.arrival_ts_s == arrived_ts
    assert stats.last_updated_ts_s == arrived_ts

    # Test INPUT_PROCESSED
    sampling_params = SamplingParams(n=1)
    input_processed_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.INPUT_PROCESSED,
        monotonic_ts_s=input_processed_ts,
        sampling_params=sampling_params,
        num_prompt_tokens=6,
    )
    stats.update_from(input_processed_update)
    assert stats.input_processor_end_ts_s == input_processed_ts
    assert stats.last_updated_ts_s == input_processed_ts
    assert stats.num_prompt_tokens == 6
    assert stats.sampling_params == sampling_params

    # Nothing has been scheduled or generated yet.
    assert stats.first_token_ts_s is None
    assert stats.prefill_ts_s is None

    # Test QUEUED
    queued_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.QUEUED,
        monotonic_ts_s=queued_ts,
    )
    stats.update_from(queued_update)
    assert stats.queued_ts_s == queued_ts
    assert stats.last_updated_ts_s == queued_ts

    # Test PREFILLING
    prefilling_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.PREFILLING,
        monotonic_ts_s=prefilling_ts,
        num_computed_tokens=3,
        num_cached_tokens=1,
    )
    stats.update_from(prefilling_update)
    assert stats.prefill_ts_s == prefilling_ts
    assert stats.num_computed_tokens == 3
    assert stats.num_cached_tokens == 1
    assert stats.queue_duration_s == prefilling_ts - queued_ts

    # Test DECODING
    decoded_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.DECODING,
        monotonic_ts_s=decoded_ts,
    )
    stats.update_from(decoded_update)
    assert stats.last_updated_ts_s == decoded_ts

    # Test DETOKENIZED
    detokenized_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.DETOKENIZED,
        monotonic_ts_s=detokenized_ts,
        num_new_tokens=1,
    )
    stats.update_from(detokenized_update)
    assert stats.last_updated_ts_s == detokenized_ts
    assert stats.num_output_tokens == 1
    # Since arrival
    assert stats.first_token_latency_s == detokenized_ts - arrived_ts
    # Since first scheduled
    assert stats.prefill_latency_s == detokenized_ts - prefilling_ts

    # Test another DECODING and DETOKENIZED should
    # yield correct inter token latency
    decoded_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.DECODING,
        monotonic_ts_s=decoded_2_ts,
    )
    stats.update_from(decoded_update)

    detokenized_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.DETOKENIZED,
        monotonic_ts_s=detokenized_2_ts,
        num_new_tokens=1,
    )
    stats.update_from(detokenized_update)
    assert stats.output_token_latency_s_lst == [
        detokenized_2_ts - detokenized_ts,
    ]
    assert stats.num_output_tokens == 2

    # Test PREEMPTED
    preempted_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.PREEMPTED,
        monotonic_ts_s=preempted_ts,
    )
    stats.update_from(preempted_update)
    assert stats.last_updated_ts_s == preempted_ts
    assert stats.preempted_ts_s_lst == [preempted_ts]
    # States should be reset
    assert stats.num_computed_tokens == 0
    assert stats.num_cached_tokens == 0
    # These states should not be reset
    assert stats.num_output_tokens == 2
    assert stats.output_token_latency_s_lst == [
        detokenized_2_ts - detokenized_ts,
    ]
    # Same expectation as right after the first DETOKENIZED update: the
    # prefill latency must survive preemption unchanged.
    assert stats.prefill_latency_s == detokenized_ts - prefilling_ts
    assert stats.num_prompt_tokens == 6
    assert stats.prefill_start_ts_s_lst == [prefilling_ts]

    # Test resumed
    resumed_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.PREFILLING,
        monotonic_ts_s=resumed_ts,
        num_computed_tokens=6,
        num_cached_tokens=2,
    )
    stats.update_from(resumed_update)
    # prefill timestamp should not be updated since it's a resumed prefill
    assert stats.prefill_ts_s == prefilling_ts
    assert stats.num_computed_tokens == 6
    assert stats.num_cached_tokens == 2
    assert stats.prefill_start_ts_s_lst == [
        prefilling_ts,
        resumed_ts,
    ]
    assert stats.last_updated_ts_s == resumed_ts

    # Test another DECODED/DETOKENIZED should yield correct first token
    # latency.
    decoded_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.DECODING,
        monotonic_ts_s=decoded_3_ts,
    )
    detokenized_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.DETOKENIZED,
        monotonic_ts_s=detokenized_3_ts,
        num_new_tokens=1,
    )
    stats.update_from(decoded_update)
    stats.update_from(detokenized_update)
    # The first-token timestamp still points at the very first detokenized
    # token, so the latency since arrival is unchanged by later tokens.
    assert stats.first_token_ts_s == detokenized_ts
    assert stats.first_token_latency_s == detokenized_ts - arrived_ts
    assert stats.num_output_tokens == 3
    assert stats.output_token_latency_s_lst == [
        detokenized_2_ts - detokenized_ts,
        detokenized_3_ts - detokenized_2_ts,
    ]

    # Test FINISHED
    finished_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.FINISHED,
        monotonic_ts_s=finished_ts,
        finish_reason="test_reason",
    )
    stats.update_from(finished_update)
    assert stats.last_updated_ts_s == finished_ts
    assert stats.e2e_latency_s == finished_ts - arrived_ts
    assert stats.inference_latency_s == finished_ts - prefilling_ts
    assert stats.prefill_latency_s == detokenized_ts - prefilling_ts
    assert stats.decode_latency_s == finished_ts - detokenized_ts
    assert stats.first_token_latency_s == detokenized_ts - arrived_ts
    assert stats.queue_duration_s == prefilling_ts - queued_ts
    assert stats.is_finished
    assert stats.finish_reason == "test_reason"

    # TODO(rickyx): Add model forward/execute time.
    assert stats.model_forward_duration_s == 0.0
    assert stats.model_execute_duration_s == 0.0
|