[V1][Spec Decode] Avoid logging useless nan metrics (#16023)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Mark McLoughlin 2025-04-04 16:52:41 +01:00 committed by GitHub
parent 4ef0bb1fcf
commit a35a8a8392
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 14 deletions

View File

@ -671,10 +671,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
# No draft or accepted tokens counted yet
assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
assert stats.num_draft_tokens == 0
assert stats.num_accepted_tokens == 0
assert engine_core_outputs.scheduler_stats.spec_decoding_stats is None
# Schedule the speculated tokens for validation
output = scheduler.schedule()
@ -702,7 +699,11 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
assert stats.num_draft_tokens == expected[0]
assert stats.num_accepted_tokens == expected[1]
scheduler_stats = engine_core_outputs.scheduler_stats
if expected[0] == 0:
assert scheduler_stats.spec_decoding_stats is None
else:
assert scheduler_stats.spec_decoding_stats is not None
stats = scheduler_stats.spec_decoding_stats
assert stats.num_draft_tokens == expected[0]
assert stats.num_accepted_tokens == expected[1]

View File

@ -553,11 +553,11 @@ class Scheduler(SchedulerInterface):
spec_token_ids = model_runner_output.spec_token_ids
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
spec_decoding_stats = SpecDecodingStats() if self.log_stats else None
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: list[Request] = []
outputs: list[EngineCoreOutput] = []
spec_decoding_stats: Optional[SpecDecodingStats] = None
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid
@ -585,11 +585,10 @@ class Scheduler(SchedulerInterface):
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected
if spec_decoding_stats is not None:
spec_decoding_stats.observe(
num_draft_tokens=len(scheduled_spec_token_ids),
num_accepted_tokens=len(generated_token_ids) - 1)
spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids),
num_accepted_tokens=len(generated_token_ids) - 1)
cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request))
@ -744,3 +743,17 @@ class Scheduler(SchedulerInterface):
prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
spec_decoding_stats=spec_decoding_stats,
)
def make_spec_decoding_stats(
self,
spec_decoding_stats: Optional[SpecDecodingStats],
num_draft_tokens: int,
num_accepted_tokens: int,
) -> Optional[SpecDecodingStats]:
if not self.log_stats:
return None
if spec_decoding_stats is None:
spec_decoding_stats = SpecDecodingStats()
spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens,
num_accepted_tokens=num_accepted_tokens)
return spec_decoding_stats