vllm/tests/v1/test_stats.py
Russell Bryant e489ad7a21
[Misc] Add SPDX-License-Identifier headers to python source files (#12628)
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**

commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:18:24 2025 -0500

    Add SPDX license headers to python source files
    
    This commit adds SPDX license headers to python source files as
    recommended to the project by the Linux Foundation. These headers
    provide a concise way that is both human and machine readable for
    communicating license information for each source file. It helps
    avoid any ambiguity about the license of the code and can also be
    easily used by tools to help manage license compliance.
    
    The Linux Foundation runs license scans against the codebase to help
    ensure we are in compliance with the licenses of the code we use,
    including dependencies. Having these headers in place helps that tool
    do its job.
    
    More information can be found on the SPDX site:
    
    - https://spdx.dev/learn/handling-license-info/
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:36:32 2025 -0500

    Check for SPDX headers using pre-commit
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

---------

Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-02-02 11:58:18 -08:00

303 lines
10 KiB
Python

# SPDX-License-Identifier: Apache-2.0
import pytest
from vllm.sampling_params import SamplingParams
from vllm.v1.stats.common import RequestStats, RequestStatsUpdate
def make_update(
    request_id: str,
    update_type: RequestStatsUpdate.Type,
    monotonic_ts_s: float,
    **kwargs,
):
    """Build a ``RequestStatsUpdate`` with placeholder type-specific fields.

    Args:
        request_id: Passed through as the update's ``request_id``.
        update_type: Passed through as the update's ``type``; it also
            selects which placeholder fields are filled in.
        monotonic_ts_s: Passed through as the update's ``monotonic_ts_s``.
        **kwargs: Extra fields forwarded to ``RequestStatsUpdate``. Any
            value given here takes precedence over the placeholder.

    Returns:
        The constructed ``RequestStatsUpdate``.
    """
    update_types = RequestStatsUpdate.Type

    # Choose the placeholder values that go with this update type; types not
    # listed here get no extra fields.
    if update_type == update_types.INPUT_PROCESSED:
        placeholders = {
            "sampling_params": SamplingParams(n=1),
            "num_prompt_tokens": 10,
        }
    elif update_type == update_types.PREFILLING:
        placeholders = {
            "num_computed_tokens": 10,
            "num_cached_tokens": 10,
        }
    elif update_type == update_types.DETOKENIZED:
        placeholders = {"num_new_tokens": 10}
    elif update_type == update_types.FINISHED:
        placeholders = {"finish_reason": "test_reason"}
    else:
        placeholders = {}

    # setdefault keeps whatever the caller passed explicitly.
    for field_name, placeholder in placeholders.items():
        kwargs.setdefault(field_name, placeholder)

    return RequestStatsUpdate(
        request_id=request_id,
        type=update_type,
        monotonic_ts_s=monotonic_ts_s,
        **kwargs,
    )
def test_invalid_request_update():
    """Check that omitting a type-specific field raises ``ValueError``.

    For every update type with an entry in the table below, each listed
    field is left out in turn while the remaining ones are set to
    placeholder objects; constructing the ``RequestStatsUpdate`` is then
    expected to raise ``ValueError``.
    """
    fields_by_type = {
        RequestStatsUpdate.Type.INPUT_PROCESSED: [
            "sampling_params",
            "num_prompt_tokens",
        ],
        RequestStatsUpdate.Type.PREFILLING: [
            "num_computed_tokens",
            "num_cached_tokens",
        ],
        RequestStatsUpdate.Type.DETOKENIZED: ["num_new_tokens"],
        RequestStatsUpdate.Type.FINISHED: ["finish_reason"],
    }

    for update_type in RequestStatsUpdate.Type:
        # Types without an entry have nothing to omit, so the inner loop
        # is simply skipped for them.
        field_names = fields_by_type.get(update_type, [])

        for omitted in field_names:
            # Supply every listed field except the one under test.
            partial_kwargs = {
                name: object()
                for name in field_names if name != omitted
            }
            with pytest.raises(ValueError):
                RequestStatsUpdate(
                    request_id="test_request",
                    type=update_type,
                    **partial_kwargs,
                )
def test_invalid_request_update_transition():
    """Check ``RequestStatsUpdate.check_valid_update`` on transitions.

    Every (source, destination) pair of update types is tried: pairs
    listed in ``RequestStatsUpdate._VALID_TRANSITIONS`` must be accepted,
    all others must raise ``AssertionError``. A final case checks that an
    update whose timestamp is earlier than the last recorded one is
    rejected as well.
    """

    def check(new_type, last_type, last_ts_s):
        # Build an update stamped at t=1 and validate it against the
        # given previous state.
        RequestStatsUpdate.check_valid_update(
            make_update(
                request_id="test_request",
                update_type=new_type,
                monotonic_ts_s=1,
            ),
            last_update_type=last_type,
            last_updated_ts_s=last_ts_s,
        )

    # Transition type checks.
    for src in RequestStatsUpdate.Type:
        allowed = RequestStatsUpdate._VALID_TRANSITIONS[src]
        for dst in RequestStatsUpdate.Type:
            if dst in allowed:
                check(dst, src, 0)
            else:
                with pytest.raises(AssertionError):
                    check(dst, src, 0)

    # Timestamp check: the update (t=1) is older than the last update (t=2).
    with pytest.raises(AssertionError):
        check(RequestStatsUpdate.Type.ARRIVED, None, 2)
def test_lifecycle_updates():
    """Drive one request through a full lifecycle and check ``RequestStats``.

    The scenario applies, in order: ARRIVED, INPUT_PROCESSED, QUEUED,
    PREFILLING, two DECODING/DETOKENIZED rounds, PREEMPTED, a resumed
    PREFILLING, one more DECODING/DETOKENIZED round, and FINISHED. After
    each step the fields and derived latencies exposed by ``RequestStats``
    are asserted.
    """
    request_id = "test_request"
    stats = RequestStats(request_id=request_id)

    # Timestamps for each step of the scenario, in the order applied.
    arrived_ts = 0
    input_processed_ts = 1
    queued_ts = 2
    prefilling_ts = 3
    decoded_ts = 5
    detokenized_ts = 6
    decoded_2_ts = 7
    detokenized_2_ts = 8
    preempted_ts = 9
    resumed_ts = 10
    decoded_3_ts = 11
    detokenized_3_ts = 12
    finished_ts = 13

    # Test ARRIVED
    arrived_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.ARRIVED,
        monotonic_ts_s=arrived_ts,
    )
    stats.update_from(arrived_update)
    assert stats.arrival_ts_s == arrived_ts
    assert stats.last_updated_ts_s == arrived_ts

    # Test INPUT_PROCESSED
    sampling_params = SamplingParams(n=1)
    input_processed_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.INPUT_PROCESSED,
        monotonic_ts_s=input_processed_ts,
        sampling_params=sampling_params,
        num_prompt_tokens=6,
    )
    stats.update_from(input_processed_update)
    assert stats.input_processor_end_ts_s == input_processed_ts
    assert stats.last_updated_ts_s == input_processed_ts
    assert stats.num_prompt_tokens == 6
    assert stats.sampling_params == sampling_params

    # Nothing has been prefilled or generated yet.
    assert stats.first_token_ts_s is None
    assert stats.prefill_ts_s is None

    # Test QUEUED
    queued_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.QUEUED,
        monotonic_ts_s=queued_ts,
    )
    stats.update_from(queued_update)
    assert stats.queued_ts_s == queued_ts
    assert stats.last_updated_ts_s == queued_ts

    # Test PREFILLING
    prefilling_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.PREFILLING,
        monotonic_ts_s=prefilling_ts,
        num_computed_tokens=3,
        num_cached_tokens=1,
    )
    stats.update_from(prefilling_update)
    assert stats.prefill_ts_s == prefilling_ts
    assert stats.num_computed_tokens == 3
    assert stats.num_cached_tokens == 1
    assert stats.queue_duration_s == prefilling_ts - queued_ts

    # Test DECODING
    decoded_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.DECODING,
        monotonic_ts_s=decoded_ts,
    )
    stats.update_from(decoded_update)
    assert stats.last_updated_ts_s == decoded_ts

    # Test DETOKENIZED
    detokenized_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.DETOKENIZED,
        monotonic_ts_s=detokenized_ts,
        num_new_tokens=1,
    )
    stats.update_from(detokenized_update)
    assert stats.last_updated_ts_s == detokenized_ts
    assert stats.num_output_tokens == 1
    # Since arrival
    assert stats.first_token_latency_s == detokenized_ts - arrived_ts
    # Since first scheduled
    assert stats.prefill_latency_s == detokenized_ts - prefilling_ts

    # Test another DECODING and DETOKENIZED should
    # yield correct inter token latency
    decoded_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.DECODING,
        monotonic_ts_s=decoded_2_ts,
    )
    stats.update_from(decoded_update)

    detokenized_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.DETOKENIZED,
        monotonic_ts_s=detokenized_2_ts,
        num_new_tokens=1,
    )
    stats.update_from(detokenized_update)
    assert stats.output_token_latency_s_lst == [
        detokenized_2_ts - detokenized_ts,
    ]
    assert stats.num_output_tokens == 2

    # Test PREEMPTED
    preempted_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.PREEMPTED,
        monotonic_ts_s=preempted_ts,
    )
    stats.update_from(preempted_update)
    assert stats.last_updated_ts_s == preempted_ts
    assert stats.preempted_ts_s_lst == [preempted_ts]
    # States should be reset
    assert stats.num_computed_tokens == 0
    assert stats.num_cached_tokens == 0
    # These states should not be reset
    assert stats.num_output_tokens == 2
    assert stats.output_token_latency_s_lst == [
        detokenized_2_ts - detokenized_ts,
    ]
    # Same expression as the pre-preemption check above: the value must be
    # unchanged. (Previously written as `prefilling_ts - arrived_ts`, which
    # only matched because both differences happen to equal 3.)
    assert stats.prefill_latency_s == detokenized_ts - prefilling_ts
    assert stats.num_prompt_tokens == 6
    assert stats.prefill_start_ts_s_lst == [prefilling_ts]

    # Test resumed
    resumed_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.PREFILLING,
        monotonic_ts_s=resumed_ts,
        num_computed_tokens=6,
        num_cached_tokens=2,
    )
    stats.update_from(resumed_update)
    # prefill timestamp should not be updated since it's a resumed prefill
    assert stats.prefill_ts_s == prefilling_ts
    assert stats.num_computed_tokens == 6
    assert stats.num_cached_tokens == 2
    assert stats.prefill_start_ts_s_lst == [
        prefilling_ts,
        resumed_ts,
    ]
    assert stats.last_updated_ts_s == resumed_ts

    # Test another DECODING/DETOKENIZED after the resume: the first token
    # timestamp must stay at the first DETOKENIZED update.
    decoded_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.DECODING,
        monotonic_ts_s=decoded_3_ts,
    )
    detokenized_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.DETOKENIZED,
        monotonic_ts_s=detokenized_3_ts,
        num_new_tokens=1,
    )
    stats.update_from(decoded_update)
    stats.update_from(detokenized_update)
    # Compare the timestamp with a timestamp. (Previously compared against
    # the duration `detokenized_ts - arrived_ts`, which only matched because
    # arrived_ts is 0.)
    assert stats.first_token_ts_s == detokenized_ts
    assert stats.num_output_tokens == 3
    assert stats.output_token_latency_s_lst == [
        detokenized_2_ts - detokenized_ts,
        detokenized_3_ts - detokenized_2_ts,
    ]

    # Test FINISHED
    finished_update = RequestStatsUpdate(
        request_id=request_id,
        type=RequestStatsUpdate.Type.FINISHED,
        monotonic_ts_s=finished_ts,
        finish_reason="test_reason",
    )
    stats.update_from(finished_update)
    assert stats.last_updated_ts_s == finished_ts
    assert stats.e2e_latency_s == finished_ts - arrived_ts
    assert stats.inference_latency_s == finished_ts - prefilling_ts
    assert stats.prefill_latency_s == detokenized_ts - prefilling_ts
    assert stats.decode_latency_s == finished_ts - detokenized_ts
    assert stats.first_token_latency_s == detokenized_ts - arrived_ts
    assert stats.queue_duration_s == prefilling_ts - queued_ts
    assert stats.is_finished
    assert stats.finish_reason == "test_reason"

    # TODO(rickyx): Add model forward/execute time.
    assert stats.model_forward_duration_s == 0.0
    assert stats.model_execute_duration_s == 0.0