[V1][Metrics] Initial speculative decoding metrics (#15151)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>

parent 7e3f7a4ee7
commit a79cc68b3a
@@ -611,3 +611,98 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
         prompt_logprobs_dict={},
     )
     scheduler.update_from_output(scheduler_output1, model_runner_output)
+
+
+# Note - these test cases mirror some of those in test_rejection_sampler.py
+@pytest.mark.parametrize(
+    "spec_tokens,output_tokens,expected",
+    [
+        ([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)),  # perfect match
+        ([[1, 2, 3]], [[1, 5]], (3, 1)),  # early mismatch
+        ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)),  # multiple sequences
+        ([[1]], [[1, 2]], (1, 1)),  # single token sequence
+        ([[]], [[5]], (0, 0)),  # empty sequence
+        ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
+         (6, 3)),  # multiple mismatches
+    ])
+def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
+    """Test scheduling behavior with speculative decoding.
+
+    This test verifies that:
+    1. Speculated tokens get scheduled correctly
+    2. Spec decoding stats properly count number of draft and accepted tokens
+    """
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
+    req_ids = []
+    req_to_index = {}
+    for i, request in enumerate(requests):
+        scheduler.add_request(request)
+        req_ids.append(request.request_id)
+        req_to_index[request.request_id] = i
+
+    # Schedule a decode, which will also draft speculative tokens
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == len(requests)
+    assert output.total_num_scheduled_tokens == len(requests)
+    for i in range(len(requests)):
+        req_id = requests[i].request_id
+        assert output.num_scheduled_tokens[req_id] == 1
+        assert req_id not in output.scheduled_spec_decode_tokens
+
+    model_runner_output = ModelRunnerOutput(
+        req_ids=req_ids,
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[0] for _ in range(len(requests))],
+        spec_token_ids=spec_tokens,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    engine_core_outputs = scheduler.update_from_output(output,
+                                                       model_runner_output)
+
+    for i in range(len(requests)):
+        running_req = scheduler.running[i]
+        # The prompt token
+        assert running_req.num_computed_tokens == 1
+        # The prompt token and the sampled token
+        assert running_req.num_tokens == 2
+        # The prompt token, the sampled token, and the speculated tokens
+        assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
+
+    # No draft or accepted tokens counted yet
+    assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
+    stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
+    assert stats.num_draft_tokens == 0
+    assert stats.num_accepted_tokens == 0
+
+    # Schedule the speculated tokens for validation
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 0
+    # The sampled token and speculated tokens
+    assert output.total_num_scheduled_tokens == \
+        len(requests) + sum(len(ids) for ids in spec_tokens)
+    for i in range(len(requests)):
+        req_id = requests[i].request_id
+        assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
+        if spec_tokens[i]:
+            assert len(output.scheduled_spec_decode_tokens[req_id]) == \
+                len(spec_tokens[i])
+        else:
+            assert req_id not in output.scheduled_spec_decode_tokens
+
+    model_runner_output = ModelRunnerOutput(
+        req_ids=req_ids,
+        req_id_to_index=req_to_index,
+        sampled_token_ids=output_tokens,
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    engine_core_outputs = scheduler.update_from_output(output,
+                                                       model_runner_output)
+
+    assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
+    stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
+    assert stats.num_draft_tokens == expected[0]
+    assert stats.num_accepted_tokens == expected[1]
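Note: the expected (num_draft, num_accepted) tuples above follow the same counting rule the scheduler applies when recording stats: every scheduled draft token counts as a draft token, and each request contributes len(generated_token_ids) - 1 accepted tokens, because the final sampled token (the bonus token on full acceptance, or the recovery token at the first rejection) is not itself an accepted draft. A minimal illustrative sketch of that rule (the expected_spec_stats helper is hypothetical, not part of this commit):

def expected_spec_stats(spec_tokens, output_tokens):
    # Hypothetical helper mirroring the counting rule encoded by the test cases.
    num_draft = sum(len(draft) for draft in spec_tokens)
    num_accepted = sum(len(out) - 1 for out in output_tokens)
    return num_draft, num_accepted

# Reproduces the parametrized expectations, e.g. the "early mismatch" case:
assert expected_spec_stats([[1, 2, 3]], [[1, 5]]) == (3, 1)
assert expected_spec_stats([[1, 2, 3], [4, 5, 6]],
                           [[1, 2, 7], [4, 8]]) == (6, 3)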
@@ -23,6 +23,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager

 logger = init_logger(__name__)
@@ -552,6 +553,7 @@ class Scheduler(SchedulerInterface):
         spec_token_ids = model_runner_output.spec_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
+        spec_decoding_stats = SpecDecodingStats() if self.log_stats else None
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens

         new_running: list[Request] = []
@@ -584,6 +586,11 @@ class Scheduler(SchedulerInterface):
                                        len(generated_token_ids))
                 request.num_computed_tokens -= num_tokens_rejected

+                if spec_decoding_stats is not None:
+                    spec_decoding_stats.observe(
+                        num_draft_tokens=len(scheduled_spec_token_ids),
+                        num_accepted_tokens=len(generated_token_ids) - 1)
+
             cached_encoder_input_ids = (
                 self.encoder_cache_manager.get_cached_input_ids(request))
             # OPTIMIZATION: Avoid list(set) if the set is empty.
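For context, spec_decoding_stats is the single accumulator created earlier in update_from_output when log_stats is enabled, so observe() runs once per request that had draft tokens scheduled and the totals cover the whole scheduler step. A rough sketch of that accumulation, using the SpecDecodingStats class added by this commit and made-up token counts:

from vllm.v1.spec_decode.metrics import SpecDecodingStats

step_stats = SpecDecodingStats()
# Request A: 3 draft tokens scheduled, 3 tokens generated -> 2 drafts accepted.
step_stats.observe(num_draft_tokens=3, num_accepted_tokens=3 - 1)
# Request B: 1 draft token scheduled, 2 tokens generated -> 1 draft accepted.
step_stats.observe(num_draft_tokens=1, num_accepted_tokens=2 - 1)
assert (step_stats.num_draft_tokens, step_stats.num_accepted_tokens) == (4, 3)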
@@ -657,7 +664,7 @@ class Scheduler(SchedulerInterface):
         self.running = new_running
         engine_core_outputs = EngineCoreOutputs(
             outputs=outputs,
-            scheduler_stats=self.make_stats(),
+            scheduler_stats=self.make_stats(spec_decoding_stats),
         )
         if self.include_finished_set:
             #TODO currently sending duplicates here, improve this
@@ -724,7 +731,10 @@ class Scheduler(SchedulerInterface):
     def reset_prefix_cache(self) -> bool:
         return self.kv_cache_manager.reset_prefix_cache()

-    def make_stats(self) -> Optional[SchedulerStats]:
+    def make_stats(
+        self,
+        spec_decoding_stats: Optional[SpecDecodingStats] = None,
+    ) -> Optional[SchedulerStats]:
         if not self.log_stats:
             return None
         return SchedulerStats(
@@ -732,4 +742,5 @@ class Scheduler(SchedulerInterface):
             num_waiting_reqs=len(self.waiting),
             gpu_cache_usage=self.kv_cache_manager.usage,
             prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
+            spec_decoding_stats=spec_decoding_stats,
         )
@@ -12,6 +12,7 @@ from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
 from vllm.v1.engine import FinishReason
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
+from vllm.v1.spec_decode.metrics import SpecDecodingMetrics

 logger = init_logger(__name__)

@@ -38,6 +39,7 @@ class LoggingStatLogger(StatLoggerBase):
         # Prefix cache metrics. This cannot be reset.
         # TODO: Make the interval configurable.
         self.prefix_caching_metrics = PrefixCachingMetrics()
+        self.spec_decoding_metrics = SpecDecodingMetrics()

     def _reset(self, now):
         self.last_log_time = now
@@ -65,6 +67,10 @@ class LoggingStatLogger(StatLoggerBase):

         self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)

+        if scheduler_stats.spec_decoding_stats is not None:
+            self.spec_decoding_metrics.observe(
+                scheduler_stats.spec_decoding_stats)
+
         self.last_scheduler_stats = scheduler_stats

     def log(self):
@@ -94,6 +100,9 @@ class LoggingStatLogger(StatLoggerBase):
             self.prefix_caching_metrics.hit_rate * 100,
         )

+        if scheduler_stats.spec_decoding_stats is not None:
+            self.spec_decoding_metrics.log()
+

 class PrometheusStatLogger(StatLoggerBase):

@@ -302,6 +311,24 @@ class PrometheusStatLogger(StatLoggerBase):
                 self.labelname_running_lora_adapters,
             ])

+        #
+        # Speculative Decoding metrics
+        # The acceptance rate can be calculated using a PromQL query:
+        #
+        #   rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
+        #     rate(vllm:spec_decode_num_draft_tokens_total[$interval])
+        #
+        self.counter_spec_decode_num_draft_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_draft_tokens_total",
+                documentation="Number of draft tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+        self.counter_spec_decode_num_accepted_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_accepted_tokens_total",
+                documentation="Number of accepted tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+
         #
         # Cache config info metric
         #
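The PromQL recipe in the comment above is just the ratio of per-interval increases of the two counters. A small Python illustration of the same arithmetic (the helper and the counter readings are invented for illustration, not part of the commit):

def acceptance_rate(draft_before, draft_after, accepted_before, accepted_after):
    # Acceptance rate over one interval, from two samples of each counter.
    drafts = draft_after - draft_before
    accepted = accepted_after - accepted_before
    return accepted / drafts if drafts > 0 else float("nan")

# e.g. 500 new draft tokens and 350 newly accepted tokens in the interval:
print(acceptance_rate(1000, 1500, 700, 1050))  # 0.7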
@@ -338,6 +365,12 @@ class PrometheusStatLogger(StatLoggerBase):
         self.counter_gpu_prefix_cache_hits.inc(
             scheduler_stats.prefix_cache_stats.hits)

+        if scheduler_stats.spec_decoding_stats is not None:
+            self.counter_spec_decode_num_draft_tokens.inc(
+                scheduler_stats.spec_decoding_stats.num_draft_tokens)
+            self.counter_spec_decode_num_accepted_tokens.inc(
+                scheduler_stats.spec_decoding_stats.num_accepted_tokens)
+
         if iteration_stats is None:
             return

@@ -4,6 +4,8 @@ import time
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Optional

+from vllm.v1.spec_decode.metrics import SpecDecodingStats
+
 if TYPE_CHECKING:
     from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
     from vllm.v1.engine.output_processor import RequestState
@@ -35,6 +37,8 @@ class SchedulerStats:
     prefix_cache_stats: PrefixCacheStats = field(
         default_factory=PrefixCacheStats)

+    spec_decoding_stats: Optional[SpecDecodingStats] = None
+

 @dataclass
 class LoRAStats:
vllm/v1/spec_decode/metrics.py (new file, 59 lines)
@@ -0,0 +1,59 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+
+import numpy as np
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class SpecDecodingStats:
+    num_draft_tokens: int = 0
+    num_accepted_tokens: int = 0
+
+    def take(self):
+        copied = SpecDecodingStats(self.num_draft_tokens,
+                                   self.num_accepted_tokens)
+        self.reset()
+        return copied
+
+    def reset(self):
+        self.num_draft_tokens = 0
+        self.num_accepted_tokens = 0
+
+    def observe(self, num_draft_tokens: int, num_accepted_tokens: int):
+        self.num_draft_tokens += num_draft_tokens
+        self.num_accepted_tokens += num_accepted_tokens
+
+
+class SpecDecodingMetrics:
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.num_draft_tokens: list[int] = []
+        self.num_accepted_tokens: list[int] = []
+
+    def observe(self, spec_decoding_stats: SpecDecodingStats):
+        self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
+        self.num_accepted_tokens.append(
+            spec_decoding_stats.num_accepted_tokens)
+
+    def log(self):
+        num_draft_tokens = np.sum(self.num_draft_tokens)
+        num_accepted_tokens = np.sum(self.num_accepted_tokens)
+
+        draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens
+                                 if num_draft_tokens > 0 else float("nan"))
+
+        logger.info(
+            "Speculative metrics: "
+            "Draft acceptance rate: %.3f, "
+            "Number of accepted tokens: %d, "
+            "Number of draft tokens: %d, ", draft_acceptance_rate,
+            num_accepted_tokens, num_draft_tokens)
+        self.reset()