[V1][Metrics] Initial speculative decoding metrics (#15151)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Mark McLoughlin 2025-04-01 18:45:04 +01:00 committed by GitHub
parent 7e3f7a4ee7
commit a79cc68b3a
5 changed files with 204 additions and 2 deletions

View File

@@ -611,3 +611,98 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
        prompt_logprobs_dict={},
    )
    scheduler.update_from_output(scheduler_output1, model_runner_output)


# Note - these test cases mirror some of those in test_rejection_sampler.py
@pytest.mark.parametrize(
    "spec_tokens,output_tokens,expected",
    [
        ([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)),  # perfect match
        ([[1, 2, 3]], [[1, 5]], (3, 1)),  # early mismatch
        ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)),  # multiple sequences
        ([[1]], [[1, 2]], (1, 1)),  # single token sequence
        ([[]], [[5]], (0, 0)),  # empty sequence
        ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
         (6, 3)),  # multiple mismatches
    ])
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
    """Test scheduling behavior with speculative decoding.

    This test verifies that:
    1. Speculated tokens get scheduled correctly
    2. Spec decoding stats properly count number of draft and accepted tokens
    """
    scheduler = create_scheduler()
    requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    # Schedule a decode, which will also draft speculative tokens
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == len(requests)
    assert output.total_num_scheduled_tokens == len(requests)
    for i in range(len(requests)):
        req_id = requests[i].request_id
        assert output.num_scheduled_tokens[req_id] == 1
        assert req_id not in output.scheduled_spec_decode_tokens

    model_runner_output = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[0] for _ in range(len(requests))],
        spec_token_ids=spec_tokens,
        logprobs=None,
        prompt_logprobs_dict={},
    )
    engine_core_outputs = scheduler.update_from_output(output,
                                                       model_runner_output)

    for i in range(len(requests)):
        running_req = scheduler.running[i]
        # The prompt token
        assert running_req.num_computed_tokens == 1
        # The prompt token and the sampled token
        assert running_req.num_tokens == 2
        # The prompt token, the sampled token, and the speculated tokens
        assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])

    # No draft or accepted tokens counted yet
    assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
    stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
    assert stats.num_draft_tokens == 0
    assert stats.num_accepted_tokens == 0

    # Schedule the speculated tokens for validation
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 0
    # The sampled token and speculated tokens
    assert output.total_num_scheduled_tokens == \
        len(requests) + sum(len(ids) for ids in spec_tokens)
    for i in range(len(requests)):
        req_id = requests[i].request_id
        assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
        if spec_tokens[i]:
            assert len(output.scheduled_spec_decode_tokens[req_id]) == \
                len(spec_tokens[i])
        else:
            assert req_id not in output.scheduled_spec_decode_tokens

    model_runner_output = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=output_tokens,
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )
    engine_core_outputs = scheduler.update_from_output(output,
                                                       model_runner_output)

    assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
    stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
    assert stats.num_draft_tokens == expected[0]
    assert stats.num_accepted_tokens == expected[1]
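
For reference, the expected (num_draft_tokens, num_accepted_tokens) pairs in the parametrize table above follow directly from the inputs: every speculated token counts as a draft token, and every sampled token after the first one in a request's output counts as an accepted token. A minimal sketch of that relationship (the helper name expected_spec_stats is illustrative, not part of the test suite):

def expected_spec_stats(spec_tokens, output_tokens):
    # Every speculated token is counted as a draft token.
    num_draft = sum(len(spec) for spec in spec_tokens)
    # Each request also samples one bonus/recovery token, so the accepted
    # count is the output length minus one, per request.
    num_accepted = sum(len(out) - 1 for out in output_tokens)
    return num_draft, num_accepted

assert expected_spec_stats([[1, 2, 3]], [[1, 5]]) == (3, 1)
assert expected_spec_stats([[1, 2, 3], [4, 5, 6]],
                           [[1, 2, 7], [4, 8]]) == (6, 3)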

View File

@@ -23,6 +23,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager

logger = init_logger(__name__)
@@ -552,6 +553,7 @@ class Scheduler(SchedulerInterface):
        spec_token_ids = model_runner_output.spec_token_ids
        logprobs = model_runner_output.logprobs
        prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
        spec_decoding_stats = SpecDecodingStats() if self.log_stats else None
        num_scheduled_tokens = scheduler_output.num_scheduled_tokens

        new_running: list[Request] = []
@@ -584,6 +586,11 @@
                    len(generated_token_ids))
                request.num_computed_tokens -= num_tokens_rejected

                if spec_decoding_stats is not None:
                    spec_decoding_stats.observe(
                        num_draft_tokens=len(scheduled_spec_token_ids),
                        num_accepted_tokens=len(generated_token_ids) - 1)

            cached_encoder_input_ids = (
                self.encoder_cache_manager.get_cached_input_ids(request))
            # OPTIMIZATION: Avoid list(set) if the set is empty.
@@ -657,7 +664,7 @@
        self.running = new_running

        engine_core_outputs = EngineCoreOutputs(
            outputs=outputs,
            scheduler_stats=self.make_stats(spec_decoding_stats),
        )
        if self.include_finished_set:
            #TODO currently sending duplicates here, improve this
@@ -724,7 +731,10 @@
    def reset_prefix_cache(self) -> bool:
        return self.kv_cache_manager.reset_prefix_cache()

    def make_stats(
        self,
        spec_decoding_stats: Optional[SpecDecodingStats] = None,
    ) -> Optional[SchedulerStats]:
        if not self.log_stats:
            return None
        return SchedulerStats(
@@ -732,4 +742,5 @@
            num_waiting_reqs=len(self.waiting),
            gpu_cache_usage=self.kv_cache_manager.usage,
            prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
            spec_decoding_stats=spec_decoding_stats,
        )
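
To make the accounting above concrete: for each request that had speculative tokens scheduled, update_from_output observes one draft token per scheduled spec token and counts all but one of the generated tokens as accepted, since one token per step would have been sampled even without speculation. A standalone sketch using the SpecDecodingStats class added by this commit, with illustrative token values rather than real scheduler state:

from vllm.v1.spec_decode.metrics import SpecDecodingStats

stats = SpecDecodingStats()

# One request: 3 draft tokens were scheduled and the model emitted 3 tokens,
# meaning the first 2 drafts were accepted and then a recovery token was sampled.
scheduled_spec_token_ids = [7, 8, 9]
generated_token_ids = [7, 8, 42]

stats.observe(num_draft_tokens=len(scheduled_spec_token_ids),
              num_accepted_tokens=len(generated_token_ids) - 1)

assert (stats.num_draft_tokens, stats.num_accepted_tokens) == (3, 2)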

View File

@@ -12,6 +12,7 @@ from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.spec_decode.metrics import SpecDecodingMetrics

logger = init_logger(__name__)
@@ -38,6 +39,7 @@ class LoggingStatLogger(StatLoggerBase):
        # Prefix cache metrics. This cannot be reset.
        # TODO: Make the interval configurable.
        self.prefix_caching_metrics = PrefixCachingMetrics()
        self.spec_decoding_metrics = SpecDecodingMetrics()

    def _reset(self, now):
        self.last_log_time = now
@@ -65,6 +67,10 @@
        self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)

        if scheduler_stats.spec_decoding_stats is not None:
            self.spec_decoding_metrics.observe(
                scheduler_stats.spec_decoding_stats)

        self.last_scheduler_stats = scheduler_stats

    def log(self):
@@ -94,6 +100,9 @@
            self.prefix_caching_metrics.hit_rate * 100,
        )

        if scheduler_stats.spec_decoding_stats is not None:
            self.spec_decoding_metrics.log()


class PrometheusStatLogger(StatLoggerBase):
@@ -302,6 +311,24 @@
                    self.labelname_running_lora_adapters,
                ])

        #
        # Speculative Decoding metrics
        # The acceptance rate can be calculated using a PromQL query:
        #
        #   rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
        #   rate(vllm:spec_decode_num_draft_tokens_total[$interval])
        #
        self.counter_spec_decode_num_draft_tokens = \
            prometheus_client.Counter(
                name="vllm:spec_decode_num_draft_tokens_total",
                documentation="Number of draft tokens.",
                labelnames=labelnames).labels(*labelvalues)
        self.counter_spec_decode_num_accepted_tokens = \
            prometheus_client.Counter(
                name="vllm:spec_decode_num_accepted_tokens_total",
                documentation="Number of accepted tokens.",
                labelnames=labelnames).labels(*labelvalues)

        #
        # Cache config info metric
        #
@@ -338,6 +365,12 @@
        self.counter_gpu_prefix_cache_hits.inc(
            scheduler_stats.prefix_cache_stats.hits)

        if scheduler_stats.spec_decoding_stats is not None:
            self.counter_spec_decode_num_draft_tokens.inc(
                scheduler_stats.spec_decoding_stats.num_draft_tokens)
            self.counter_spec_decode_num_accepted_tokens.inc(
                scheduler_stats.spec_decoding_stats.num_accepted_tokens)

        if iteration_stats is None:
            return
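
The two counters above follow the same pattern as the existing vLLM Prometheus metrics: each is created once per engine with the model name as a label, and record() adds the per-iteration totals from SpecDecodingStats to it. A minimal, self-contained sketch of that prometheus_client Counter pattern (the metric name, label names and label values below are examples for illustration, not the real vLLM ones):

import prometheus_client

labelnames = ["model_name"]
labelvalues = ["example-model"]

# Create a labeled counter once, then bind the label values.
counter_draft_tokens = prometheus_client.Counter(
    name="example_spec_decode_num_draft_tokens_total",
    documentation="Number of draft tokens.",
    labelnames=labelnames).labels(*labelvalues)

# Each scheduler iteration adds its draft-token count to the running total.
counter_draft_tokens.inc(8)
counter_draft_tokens.inc(5)

A dashboard then divides the rate of the accepted-token counter by the rate of the draft-token counter to get the acceptance rate, exactly as the PromQL query in the comment above describes.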

View File

@@ -4,6 +4,8 @@ import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional

from vllm.v1.spec_decode.metrics import SpecDecodingStats

if TYPE_CHECKING:
    from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
    from vllm.v1.engine.output_processor import RequestState
@@ -35,6 +37,8 @@ class SchedulerStats:
    prefix_cache_stats: PrefixCacheStats = field(
        default_factory=PrefixCacheStats)

    spec_decoding_stats: Optional[SpecDecodingStats] = None


@dataclass
class LoRAStats:
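
With this field in place, a SchedulerStats snapshot carries per-step speculative decoding counts only when stats logging is enabled, and consumers treat None as "no spec decoding data for this step". A small sketch, assuming the other SchedulerStats fields keep their defaults (the counts here are illustrative):

from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.spec_decode.metrics import SpecDecodingStats

stats = SchedulerStats(
    spec_decoding_stats=SpecDecodingStats(num_draft_tokens=6,
                                          num_accepted_tokens=3))

if stats.spec_decoding_stats is not None:
    print(stats.spec_decoding_stats.num_accepted_tokens)  # 3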

View File

@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import numpy as np

from vllm.logger import init_logger

logger = init_logger(__name__)


@dataclass
class SpecDecodingStats:
    num_draft_tokens: int = 0
    num_accepted_tokens: int = 0

    def take(self):
        copied = SpecDecodingStats(self.num_draft_tokens,
                                   self.num_accepted_tokens)
        self.reset()
        return copied

    def reset(self):
        self.num_draft_tokens = 0
        self.num_accepted_tokens = 0

    def observe(self, num_draft_tokens: int, num_accepted_tokens: int):
        self.num_draft_tokens += num_draft_tokens
        self.num_accepted_tokens += num_accepted_tokens


class SpecDecodingMetrics:

    def __init__(self):
        self.reset()

    def reset(self):
        self.num_draft_tokens: list[int] = []
        self.num_accepted_tokens: list[int] = []

    def observe(self, spec_decoding_stats: SpecDecodingStats):
        self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
        self.num_accepted_tokens.append(
            spec_decoding_stats.num_accepted_tokens)

    def log(self):
        num_draft_tokens = np.sum(self.num_draft_tokens)
        num_accepted_tokens = np.sum(self.num_accepted_tokens)

        draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens
                                 if num_draft_tokens > 0 else float("nan"))
        logger.info(
            "Speculative metrics: "
            "Draft acceptance rate: %.3f, "
            "Number of accepted tokens: %d, "
            "Number of draft tokens: %d, ", draft_acceptance_rate,
            num_accepted_tokens, num_draft_tokens)
        self.reset()
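
Taken together, SpecDecodingStats is the per-scheduler-step accumulator that travels inside SchedulerStats, while SpecDecodingMetrics aggregates those snapshots in the logging stat logger and reports the draft acceptance rate. A short usage sketch of the two classes defined above (the token counts are illustrative):

from vllm.v1.spec_decode.metrics import SpecDecodingMetrics, SpecDecodingStats

metrics = SpecDecodingMetrics()

# Two scheduler iterations' worth of observations.
step1 = SpecDecodingStats()
step1.observe(num_draft_tokens=3, num_accepted_tokens=2)
step2 = SpecDecodingStats()
step2.observe(num_draft_tokens=3, num_accepted_tokens=1)

metrics.observe(step1)
metrics.observe(step2)

# Logs "Draft acceptance rate: 0.500, ..." (3 accepted out of 6 drafts)
# and then resets the aggregated lists.
metrics.log()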