[V1][Metrics] Initial speculative decoding metrics (#15151)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
parent
7e3f7a4ee7
commit
a79cc68b3a
@ -611,3 +611,98 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
|
||||
prompt_logprobs_dict={},
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output1, model_runner_output)
|
||||
|
||||
|
||||
# Note - these test cases mirror some of those in test_rejection_sampler.py
|
||||
@pytest.mark.parametrize(
|
||||
"spec_tokens,output_tokens,expected",
|
||||
[
|
||||
([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)), # perfect match
|
||||
([[1, 2, 3]], [[1, 5]], (3, 1)), # early mismatch
|
||||
([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)), # multiple sequences
|
||||
([[1]], [[1, 2]], (1, 1)), # single token sequence
|
||||
([[]], [[5]], (0, 0)), # empty sequence
|
||||
([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
|
||||
(6, 3)), # multiple mismatches
|
||||
])
|
||||
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
|
||||
"""Test scheduling behavior with speculative decoding.
|
||||
|
||||
This test verifies that:
|
||||
1. Speculated tokens get scheduled correctly
|
||||
2. Spec decoding stats properly count number of draft and accepted tokens
|
||||
"""
|
||||
scheduler = create_scheduler()
|
||||
requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
|
||||
req_ids = []
|
||||
req_to_index = {}
|
||||
for i, request in enumerate(requests):
|
||||
scheduler.add_request(request)
|
||||
req_ids.append(request.request_id)
|
||||
req_to_index[request.request_id] = i
|
||||
|
||||
# Schedule a decode, which will also draft speculative tokens
|
||||
output = scheduler.schedule()
|
||||
assert len(output.scheduled_new_reqs) == len(requests)
|
||||
assert output.total_num_scheduled_tokens == len(requests)
|
||||
for i in range(len(requests)):
|
||||
req_id = requests[i].request_id
|
||||
assert output.num_scheduled_tokens[req_id] == 1
|
||||
assert req_id not in output.scheduled_spec_decode_tokens
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||
spec_token_ids=spec_tokens,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
)
|
||||
engine_core_outputs = scheduler.update_from_output(output,
|
||||
model_runner_output)
|
||||
|
||||
for i in range(len(requests)):
|
||||
running_req = scheduler.running[i]
|
||||
# The prompt token
|
||||
assert running_req.num_computed_tokens == 1
|
||||
# The prompt token and the sampled token
|
||||
assert running_req.num_tokens == 2
|
||||
# The prompt token, the sampled token, and the speculated tokens
|
||||
assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
|
||||
|
||||
# No draft or accepted tokens counted yet
|
||||
assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
|
||||
stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
|
||||
assert stats.num_draft_tokens == 0
|
||||
assert stats.num_accepted_tokens == 0
|
||||
|
||||
# Schedule the speculated tokens for validation
|
||||
output = scheduler.schedule()
|
||||
assert len(output.scheduled_new_reqs) == 0
|
||||
# The sampled token and speculated tokens
|
||||
assert output.total_num_scheduled_tokens == \
|
||||
len(requests) + sum(len(ids) for ids in spec_tokens)
|
||||
for i in range(len(requests)):
|
||||
req_id = requests[i].request_id
|
||||
assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
|
||||
if spec_tokens[i]:
|
||||
assert len(output.scheduled_spec_decode_tokens[req_id]) == \
|
||||
len(spec_tokens[i])
|
||||
else:
|
||||
assert req_id not in output.scheduled_spec_decode_tokens
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=output_tokens,
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
)
|
||||
engine_core_outputs = scheduler.update_from_output(output,
|
||||
model_runner_output)
|
||||
|
||||
assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
|
||||
stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
|
||||
assert stats.num_draft_tokens == expected[0]
|
||||
assert stats.num_accepted_tokens == expected[1]
|
||||
|
@ -23,6 +23,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -552,6 +553,7 @@ class Scheduler(SchedulerInterface):
|
||||
spec_token_ids = model_runner_output.spec_token_ids
|
||||
logprobs = model_runner_output.logprobs
|
||||
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
||||
spec_decoding_stats = SpecDecodingStats() if self.log_stats else None
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
|
||||
new_running: list[Request] = []
|
||||
@ -584,6 +586,11 @@ class Scheduler(SchedulerInterface):
|
||||
len(generated_token_ids))
|
||||
request.num_computed_tokens -= num_tokens_rejected
|
||||
|
||||
if spec_decoding_stats is not None:
|
||||
spec_decoding_stats.observe(
|
||||
num_draft_tokens=len(scheduled_spec_token_ids),
|
||||
num_accepted_tokens=len(generated_token_ids) - 1)
|
||||
|
||||
cached_encoder_input_ids = (
|
||||
self.encoder_cache_manager.get_cached_input_ids(request))
|
||||
# OPTIMIZATION: Avoid list(set) if the set is empty.
|
||||
@ -657,7 +664,7 @@ class Scheduler(SchedulerInterface):
|
||||
self.running = new_running
|
||||
engine_core_outputs = EngineCoreOutputs(
|
||||
outputs=outputs,
|
||||
scheduler_stats=self.make_stats(),
|
||||
scheduler_stats=self.make_stats(spec_decoding_stats),
|
||||
)
|
||||
if self.include_finished_set:
|
||||
#TODO currently sending duplicates here, improve this
|
||||
@ -724,7 +731,10 @@ class Scheduler(SchedulerInterface):
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
return self.kv_cache_manager.reset_prefix_cache()
|
||||
|
||||
def make_stats(self) -> Optional[SchedulerStats]:
|
||||
def make_stats(
|
||||
self,
|
||||
spec_decoding_stats: Optional[SpecDecodingStats] = None,
|
||||
) -> Optional[SchedulerStats]:
|
||||
if not self.log_stats:
|
||||
return None
|
||||
return SchedulerStats(
|
||||
@ -732,4 +742,5 @@ class Scheduler(SchedulerInterface):
|
||||
num_waiting_reqs=len(self.waiting),
|
||||
gpu_cache_usage=self.kv_cache_manager.usage,
|
||||
prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
|
||||
spec_decoding_stats=spec_decoding_stats,
|
||||
)
|
||||
|
@ -12,6 +12,7 @@ from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
|
||||
from vllm.v1.engine import FinishReason
|
||||
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingMetrics
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -38,6 +39,7 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
# Prefix cache metrics. This cannot be reset.
|
||||
# TODO: Make the interval configurable.
|
||||
self.prefix_caching_metrics = PrefixCachingMetrics()
|
||||
self.spec_decoding_metrics = SpecDecodingMetrics()
|
||||
|
||||
def _reset(self, now):
|
||||
self.last_log_time = now
|
||||
@ -65,6 +67,10 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
|
||||
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
|
||||
|
||||
if scheduler_stats.spec_decoding_stats is not None:
|
||||
self.spec_decoding_metrics.observe(
|
||||
scheduler_stats.spec_decoding_stats)
|
||||
|
||||
self.last_scheduler_stats = scheduler_stats
|
||||
|
||||
def log(self):
|
||||
@ -94,6 +100,9 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
self.prefix_caching_metrics.hit_rate * 100,
|
||||
)
|
||||
|
||||
if scheduler_stats.spec_decoding_stats is not None:
|
||||
self.spec_decoding_metrics.log()
|
||||
|
||||
|
||||
class PrometheusStatLogger(StatLoggerBase):
|
||||
|
||||
@ -302,6 +311,24 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
self.labelname_running_lora_adapters,
|
||||
])
|
||||
|
||||
#
|
||||
# Speculative Decoding metrics
|
||||
# The acceptance rate can be calculated using a PromQL query:
|
||||
#
|
||||
# rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
|
||||
# rate(vllm:spec_decode_num_draft_tokens_total[$interval])
|
||||
#
|
||||
self.counter_spec_decode_num_draft_tokens = \
|
||||
prometheus_client.Counter(
|
||||
name="vllm:spec_decode_num_draft_tokens_total",
|
||||
documentation="Number of draft tokens.",
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
self.counter_spec_decode_num_accepted_tokens = \
|
||||
prometheus_client.Counter(
|
||||
name="vllm:spec_decode_num_accepted_tokens_total",
|
||||
documentation="Number of accepted tokens.",
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
|
||||
#
|
||||
# Cache config info metric
|
||||
#
|
||||
@ -338,6 +365,12 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
self.counter_gpu_prefix_cache_hits.inc(
|
||||
scheduler_stats.prefix_cache_stats.hits)
|
||||
|
||||
if scheduler_stats.spec_decoding_stats is not None:
|
||||
self.counter_spec_decode_num_draft_tokens.inc(
|
||||
scheduler_stats.spec_decoding_stats.num_draft_tokens)
|
||||
self.counter_spec_decode_num_accepted_tokens.inc(
|
||||
scheduler_stats.spec_decoding_stats.num_accepted_tokens)
|
||||
|
||||
if iteration_stats is None:
|
||||
return
|
||||
|
||||
|
@ -4,6 +4,8 @@ import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
|
||||
from vllm.v1.engine.output_processor import RequestState
|
||||
@ -35,6 +37,8 @@ class SchedulerStats:
|
||||
prefix_cache_stats: PrefixCacheStats = field(
|
||||
default_factory=PrefixCacheStats)
|
||||
|
||||
spec_decoding_stats: Optional[SpecDecodingStats] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAStats:
|
||||
|
59
vllm/v1/spec_decode/metrics.py
Normal file
59
vllm/v1/spec_decode/metrics.py
Normal file
@ -0,0 +1,59 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpecDecodingStats:
|
||||
num_draft_tokens: int = 0
|
||||
num_accepted_tokens: int = 0
|
||||
|
||||
def take(self):
|
||||
copied = SpecDecodingStats(self.num_draft_tokens,
|
||||
self.num_accepted_tokens)
|
||||
self.reset()
|
||||
return copied
|
||||
|
||||
def reset(self):
|
||||
self.num_draft_tokens = 0
|
||||
self.num_accepted_tokens = 0
|
||||
|
||||
def observe(self, num_draft_tokens: int, num_accepted_tokens: int):
|
||||
self.num_draft_tokens += num_draft_tokens
|
||||
self.num_accepted_tokens += num_accepted_tokens
|
||||
|
||||
|
||||
class SpecDecodingMetrics:
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.num_draft_tokens: list[int] = []
|
||||
self.num_accepted_tokens: list[int] = []
|
||||
|
||||
def observe(self, spec_decoding_stats: SpecDecodingStats):
|
||||
self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
|
||||
self.num_accepted_tokens.append(
|
||||
spec_decoding_stats.num_accepted_tokens)
|
||||
|
||||
def log(self):
|
||||
num_draft_tokens = np.sum(self.num_draft_tokens)
|
||||
num_accepted_tokens = np.sum(self.num_accepted_tokens)
|
||||
|
||||
draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens
|
||||
if num_draft_tokens > 0 else float("nan"))
|
||||
|
||||
logger.info(
|
||||
"Speculative metrics: "
|
||||
"Draft acceptance rate: %.3f, "
|
||||
"Number of accepted tokens: %d, "
|
||||
"Number of draft tokens: %d, ", draft_acceptance_rate,
|
||||
num_accepted_tokens, num_draft_tokens)
|
||||
self.reset()
|
Loading…
x
Reference in New Issue
Block a user