[V1][Metrics] Initial speculative decoding metrics (#15151)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Mark McLoughlin 2025-04-01 18:45:04 +01:00 committed by GitHub
parent 7e3f7a4ee7
commit a79cc68b3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 204 additions and 2 deletions

View File

@ -611,3 +611,98 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
prompt_logprobs_dict={},
)
scheduler.update_from_output(scheduler_output1, model_runner_output)
# Note - these test cases mirror some of those in test_rejection_sampler.py
@pytest.mark.parametrize(
"spec_tokens,output_tokens,expected",
[
([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)), # perfect match
([[1, 2, 3]], [[1, 5]], (3, 1)), # early mismatch
([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)), # multiple sequences
([[1]], [[1, 2]], (1, 1)), # single token sequence
([[]], [[5]], (0, 0)), # empty sequence
([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
(6, 3)), # multiple mismatches
])
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
"""Test scheduling behavior with speculative decoding.
This test verifies that:
1. Speculated tokens get scheduled correctly
2. Spec decoding stats properly count number of draft and accepted tokens
"""
scheduler = create_scheduler()
requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
# Schedule a decode, which will also draft speculative tokens
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.total_num_scheduled_tokens == len(requests)
for i in range(len(requests)):
req_id = requests[i].request_id
assert output.num_scheduled_tokens[req_id] == 1
assert req_id not in output.scheduled_spec_decode_tokens
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
spec_token_ids=spec_tokens,
logprobs=None,
prompt_logprobs_dict={},
)
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
for i in range(len(requests)):
running_req = scheduler.running[i]
# The prompt token
assert running_req.num_computed_tokens == 1
# The prompt token and the sampled token
assert running_req.num_tokens == 2
# The prompt token, the sampled token, and the speculated tokens
assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
# No draft or accepted tokens counted yet
assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
assert stats.num_draft_tokens == 0
assert stats.num_accepted_tokens == 0
# Schedule the speculated tokens for validation
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 0
# The sampled token and speculated tokens
assert output.total_num_scheduled_tokens == \
len(requests) + sum(len(ids) for ids in spec_tokens)
for i in range(len(requests)):
req_id = requests[i].request_id
assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
if spec_tokens[i]:
assert len(output.scheduled_spec_decode_tokens[req_id]) == \
len(spec_tokens[i])
else:
assert req_id not in output.scheduled_spec_decode_tokens
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
assert stats.num_draft_tokens == expected[0]
assert stats.num_accepted_tokens == expected[1]

View File

@ -23,6 +23,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
logger = init_logger(__name__)
@ -552,6 +553,7 @@ class Scheduler(SchedulerInterface):
spec_token_ids = model_runner_output.spec_token_ids
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
spec_decoding_stats = SpecDecodingStats() if self.log_stats else None
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: list[Request] = []
@ -584,6 +586,11 @@ class Scheduler(SchedulerInterface):
len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected
if spec_decoding_stats is not None:
spec_decoding_stats.observe(
num_draft_tokens=len(scheduled_spec_token_ids),
num_accepted_tokens=len(generated_token_ids) - 1)
cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request))
# OPTIMIZATION: Avoid list(set) if the set is empty.
@ -657,7 +664,7 @@ class Scheduler(SchedulerInterface):
self.running = new_running
engine_core_outputs = EngineCoreOutputs(
outputs=outputs,
scheduler_stats=self.make_stats(),
scheduler_stats=self.make_stats(spec_decoding_stats),
)
if self.include_finished_set:
#TODO currently sending duplicates here, improve this
@ -724,7 +731,10 @@ class Scheduler(SchedulerInterface):
def reset_prefix_cache(self) -> bool:
return self.kv_cache_manager.reset_prefix_cache()
def make_stats(self) -> Optional[SchedulerStats]:
def make_stats(
self,
spec_decoding_stats: Optional[SpecDecodingStats] = None,
) -> Optional[SchedulerStats]:
if not self.log_stats:
return None
return SchedulerStats(
@ -732,4 +742,5 @@ class Scheduler(SchedulerInterface):
num_waiting_reqs=len(self.waiting),
gpu_cache_usage=self.kv_cache_manager.usage,
prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
spec_decoding_stats=spec_decoding_stats,
)

View File

@ -12,6 +12,7 @@ from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.spec_decode.metrics import SpecDecodingMetrics
logger = init_logger(__name__)
@ -38,6 +39,7 @@ class LoggingStatLogger(StatLoggerBase):
# Prefix cache metrics. This cannot be reset.
# TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics()
self.spec_decoding_metrics = SpecDecodingMetrics()
def _reset(self, now):
self.last_log_time = now
@ -65,6 +67,10 @@ class LoggingStatLogger(StatLoggerBase):
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_metrics.observe(
scheduler_stats.spec_decoding_stats)
self.last_scheduler_stats = scheduler_stats
def log(self):
@ -94,6 +100,9 @@ class LoggingStatLogger(StatLoggerBase):
self.prefix_caching_metrics.hit_rate * 100,
)
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_metrics.log()
class PrometheusStatLogger(StatLoggerBase):
@ -302,6 +311,24 @@ class PrometheusStatLogger(StatLoggerBase):
self.labelname_running_lora_adapters,
])
#
# Speculative Decoding metrics
# The acceptance rate can be calculated using a PromQL query:
#
# rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
# rate(vllm:spec_decode_num_draft_tokens_total[$interval])
#
self.counter_spec_decode_num_draft_tokens = \
prometheus_client.Counter(
name="vllm:spec_decode_num_draft_tokens_total",
documentation="Number of draft tokens.",
labelnames=labelnames).labels(*labelvalues)
self.counter_spec_decode_num_accepted_tokens = \
prometheus_client.Counter(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
labelnames=labelnames).labels(*labelvalues)
#
# Cache config info metric
#
@ -338,6 +365,12 @@ class PrometheusStatLogger(StatLoggerBase):
self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits)
if scheduler_stats.spec_decoding_stats is not None:
self.counter_spec_decode_num_draft_tokens.inc(
scheduler_stats.spec_decoding_stats.num_draft_tokens)
self.counter_spec_decode_num_accepted_tokens.inc(
scheduler_stats.spec_decoding_stats.num_accepted_tokens)
if iteration_stats is None:
return

View File

@ -4,6 +4,8 @@ import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional
from vllm.v1.spec_decode.metrics import SpecDecodingStats
if TYPE_CHECKING:
from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
from vllm.v1.engine.output_processor import RequestState
@ -35,6 +37,8 @@ class SchedulerStats:
prefix_cache_stats: PrefixCacheStats = field(
default_factory=PrefixCacheStats)
spec_decoding_stats: Optional[SpecDecodingStats] = None
@dataclass
class LoRAStats:

View File

@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import numpy as np
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass
class SpecDecodingStats:
num_draft_tokens: int = 0
num_accepted_tokens: int = 0
def take(self):
copied = SpecDecodingStats(self.num_draft_tokens,
self.num_accepted_tokens)
self.reset()
return copied
def reset(self):
self.num_draft_tokens = 0
self.num_accepted_tokens = 0
def observe(self, num_draft_tokens: int, num_accepted_tokens: int):
self.num_draft_tokens += num_draft_tokens
self.num_accepted_tokens += num_accepted_tokens
class SpecDecodingMetrics:
def __init__(self):
self.reset()
def reset(self):
self.num_draft_tokens: list[int] = []
self.num_accepted_tokens: list[int] = []
def observe(self, spec_decoding_stats: SpecDecodingStats):
self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
self.num_accepted_tokens.append(
spec_decoding_stats.num_accepted_tokens)
def log(self):
num_draft_tokens = np.sum(self.num_draft_tokens)
num_accepted_tokens = np.sum(self.num_accepted_tokens)
draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens
if num_draft_tokens > 0 else float("nan"))
logger.info(
"Speculative metrics: "
"Draft acceptance rate: %.3f, "
"Number of accepted tokens: %d, "
"Number of draft tokens: %d, ", draft_acceptance_rate,
num_accepted_tokens, num_draft_tokens)
self.reset()