[Bugfix] EAGLE output norm bug (#14464)

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Bryan Lu 2025-03-14 23:50:33 -07:00 committed by GitHub
parent ee3778d5fc
commit 9ed6ee92d6
8 changed files with 152 additions and 35 deletions


@ -162,7 +162,7 @@ A variety of speculative models of this type are available on HF hub:
## Speculating using EAGLE based draft models
The following code configures vLLM to use speculative decoding where proposals are generated by
an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model.
an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model. A more detailed example for offline mode, including how to extract the request-level acceptance rate, can be found [here](<gh-file:examples/offline_inference/eagle.py>).
```python
from vllm import LLM, SamplingParams


@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import json
import os
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    type=str,
    default="./examples/data/gsm8k.jsonl",
    help="downloaded from the eagle repo " \
         "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
)
parser.add_argument("--max_num_seqs", type=int, default=8)
parser.add_argument("--num_prompts", type=int, default=80)
parser.add_argument("--num_spec_tokens", type=int, default=2)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--draft_tp", type=int, default=1)
parser.add_argument("--enforce_eager", action='store_true')
parser.add_argument("--enable_chunked_prefill", action='store_true')
parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
parser.add_argument("--temp", type=float, default=0)
args = parser.parse_args()
print(args)
model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
max_model_len = 2048
tokenizer = AutoTokenizer.from_pretrained(model_dir)
if os.path.exists(args.dataset):
    prompts = []
    num_prompts = args.num_prompts
    with open(args.dataset) as f:
        for line in f:
            data = json.loads(line)
            prompts.append(data["turns"][0])
else:
    prompts = ["The future of AI is", "The president of the United States is"]
prompts = prompts[:args.num_prompts]
num_prompts = len(prompts)
prompt_ids = [
    tokenizer.apply_chat_template([{
        "role": "user",
        "content": prompt
    }],
                                  add_generation_prompt=True)
    for prompt in prompts
]
llm = LLM(
    model=model_dir,
    trust_remote_code=True,
    tensor_parallel_size=args.tp,
    enable_chunked_prefill=args.enable_chunked_prefill,
    max_num_batched_tokens=args.max_num_batched_tokens,
    enforce_eager=args.enforce_eager,
    max_model_len=max_model_len,
    max_num_seqs=args.max_num_seqs,
    gpu_memory_utilization=0.8,
    speculative_model=eagle_dir,
    num_speculative_tokens=args.num_spec_tokens,
    speculative_draft_tensor_parallel_size=args.draft_tp,
    speculative_max_model_len=max_model_len,
    disable_log_stats=False,
)
sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
outputs = llm.generate(prompt_token_ids=prompt_ids,
                       sampling_params=sampling_params)
# calculate the average number of accepted tokens per forward pass, +1 is
# to account for the token from the target model that's always going to be
# accepted
acceptance_counts = [0] * (args.num_spec_tokens + 1)
for output in outputs:
    for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
        acceptance_counts[step] += count
print(f"mean acceptance length: \
{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
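
For reference, the arithmetic behind the printed metric, worked on a small made-up counts list. The numbers below are hypothetical and match the default `--num_spec_tokens 2`, so the list has three slots; this is a standalone sketch, not part of the diff.

```python
# Standalone sketch: reproducing the mean acceptance length computed by
# eagle.py above. The counts are hypothetical.
acceptance_counts = [50, 30, 12]  # [forward passes, accepted @ step 1, @ step 2]

# Each forward pass always yields one accepted token from the target model
# (slot 0), plus any draft tokens that survived verification (slots 1..k).
num_forward_passes = acceptance_counts[0]
total_accepted = sum(acceptance_counts)  # 50 + 30 + 12 = 92
print(f"mean acceptance length: {total_accepted / num_forward_passes:.2f}")  # 1.84
```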


@ -853,6 +853,10 @@ class LLMEngine:
self.generation_config_fields, seq.eos_token_id)
# Create the sequence group.
draft_size = 1
if self.vllm_config.speculative_config is not None:
    draft_size = \
        self.vllm_config.speculative_config.num_speculative_tokens + 1
seq_group = SequenceGroup(
request_id=request_id,
seqs=[seq],
@ -862,7 +866,8 @@ class LLMEngine:
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
priority=priority,
draft_size=draft_size)
return seq_group
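
In isolation, the new `draft_size` plumbing amounts to the following: with `num_speculative_tokens = k`, each request's acceptance counter gets `k + 1` slots, one per speculation step plus one for the always-accepted target-model token. A minimal sketch under that reading, with vLLM's `SpeculativeConfig` replaced by a stub:

```python
# Standalone sketch, not part of the diff. SpecConfigStub stands in for
# vLLM's SpeculativeConfig, reduced to the single field used above.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SpecConfigStub:
    num_speculative_tokens: int


def compute_draft_size(spec_config: Optional[SpecConfigStub]) -> int:
    # No speculation: a step produces only the one target-model token.
    if spec_config is None:
        return 1
    # k draft positions per step, plus the target-model token.
    return spec_config.num_speculative_tokens + 1


assert compute_draft_size(None) == 1
assert compute_draft_size(SpecConfigStub(num_speculative_tokens=2)) == 3
# A SequenceGroup built with draft_size=3 starts with
# spec_token_acceptance_counts == [0, 0, 0].
```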


@ -100,6 +100,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
seqs = sequence_group.get_seqs(
    status=SequenceStatus.FINISHED_ABORTED)
for output in outputs:
    if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID:
        sequence_group.metrics.spec_token_acceptance_counts[
            output.step_index] += 1
assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
assert len(seqs) == 1, (
    "Beam search not supported in multi-step decoding.")


@ -38,7 +38,7 @@ class DummyOutputNorm(nn.Module):
if residual is None:
    return x
else:
    return x, residual
    return x + residual, None
class EAGLE(nn.Module):

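The fix itself, pulled out as a self-contained module: when the dummy norm is handed a residual it must fold it into the hidden states (the contribution the fused norm it replaces would have made) rather than return the pair untouched, otherwise the final residual branch is dropped from the returned hidden states. A minimal sketch assuming only `torch`:

```python
# Sketch of the corrected behavior shown in the diff, assuming torch.
import torch
from torch import nn


class DummyOutputNorm(nn.Module):
    # Stand-in for the final norm inside the EAGLE draft model: EAGLE applies
    # no final normalization, but the residual still has to be added back,
    # because callers typically do `hidden, _ = norm(hidden, residual)` and
    # keep only the first element.
    def forward(self, x, residual=None):
        if residual is None:
            return x
        # Previously this returned (x, residual), silently dropping the
        # residual branch from the returned hidden states.
        return x + residual, None


norm = DummyOutputNorm()
h = torch.ones(2, 4)
r = torch.full((2, 4), 2.0)
out, _ = norm(h, r)
print(out)  # all elements are 3.0: the residual is folded in
```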

@ -111,6 +111,13 @@ class RequestMetrics:
model_execute_time: The time spent in the model execute function. This
                    will include model forward, block/sync across
                    workers, cpu-gpu sync time and sampling time.
spec_token_acceptance_counts: number of accepted speculative tokens at
                              each position; the first entry counts the
                              token from the target model, which is always
                              accepted; e.g., [10, 8, 4, 2] for a request
                              means 10 forward passes in total, with 8, 4,
                              and 2 tokens accepted at the 1st, 2nd, and
                              3rd speculation steps, respectively.
"""
arrival_time: float
last_token_time: float
@ -121,6 +128,7 @@ class RequestMetrics:
scheduler_time: Optional[float] = None
model_forward_time: Optional[float] = None
model_execute_time: Optional[float] = None
spec_token_acceptance_counts: Optional[list[int]] = None
class SequenceDataDelta(
@ -639,10 +647,13 @@ class SequenceGroup:
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request.
priority: User-defined priority of the request.
draft_size: The number of speculative tokens plus one for the token from
            the target model; this equals the maximum number of tokens a
            step can generate for single-draft speculative decoding, and
            is larger than that for multi-draft SD (currently not
            supported).
"""
def __init__(
self,
def __init__(self,
request_id: str,
seqs: list[Sequence],
arrival_time: float,
@ -654,7 +665,7 @@ class SequenceGroup:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
draft_size: int = 1) -> None:
self.request_id = request_id
self.seqs = seqs
self.first_seq = seqs[0]
@ -667,7 +678,9 @@ class SequenceGroup:
last_token_time=arrival_time,
first_scheduled_time=None,
first_token_time=None,
time_in_queue=None)
time_in_queue=None,
spec_token_acceptance_counts=[0] *
draft_size)
self.last_token_latency = 0.0
self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None
@ -1079,6 +1092,7 @@ class CompletionSequenceGroupOutput(
samples: list[SequenceOutput]
# Prompt logprob for each prompt query token.
prompt_logprobs: Optional[PromptLogprobs]
step_index: Optional[int] = 0
def __repr__(self) -> str:
return (f"CompletionSequenceGroupOutput(samples={self.samples}, "

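Reading the new field back out: with the `[10, 8, 4, 2]` example from the docstring, slot 0 counts forward passes and each later slot counts how often the draft token at that speculation step was accepted, so per-step acceptance rates follow directly. A small illustrative sketch (the helper below is hypothetical, not a vLLM API):

```python
# Standalone sketch, not part of the diff: per-step acceptance rates from
# RequestMetrics.spec_token_acceptance_counts. acceptance_rates() is a
# hypothetical helper for illustration.
def acceptance_rates(counts: list[int]) -> list[float]:
    num_forward_passes = counts[0]  # the target token is accepted every pass
    return [c / num_forward_passes for c in counts[1:]]


counts = [10, 8, 4, 2]  # the example from the RequestMetrics docstring
print(acceptance_rates(counts))  # [0.8, 0.4, 0.2]
```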

@ -1080,7 +1080,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
[sequence_index][:num_logprobs],
topk_logprobs=topk_logprobs_by_step[step_index]
[sequence_index][:num_logprobs],
))
step_index=step_index))
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))


@ -100,7 +100,7 @@ def create_sequence_group_output(
topk_token_ids: List[Optional[int]],
topk_logprobs: List[Optional[float]],
prompt_logprobs: Optional[PromptLogprobs] = None,
) -> CompletionSequenceGroupOutput:
step_index: Optional[int] = 0) -> CompletionSequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results.
Args:
@ -110,6 +110,7 @@ def create_sequence_group_output(
seq_id (int): The sequence id.
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
step_index (Optional[int]): The index of the speculative token.
"""
logprobs = create_logprobs_output(
@ -120,14 +121,13 @@ def create_sequence_group_output(
topk_logprobs,
)
return CompletionSequenceGroupOutput(
samples=[
return CompletionSequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs=logprobs)
],
prompt_logprobs=prompt_logprobs,
)
step_index=step_index)
def split_batch_by_proposal_len(
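
Taken together, the `step_index` changes above let the spec-decode worker tag each per-step output with the speculation step that produced it, which is what the multi-step output processor later uses to pick the right slot of `spec_token_acceptance_counts`. A simplified sketch of the producer side, with the vLLM output classes reduced to dataclass stand-ins:

```python
# Standalone sketch, not part of the diff: building one tagged output per
# speculation step. SeqOutputStub / SeqGroupOutputStub are illustrative
# stand-ins for SequenceOutput / CompletionSequenceGroupOutput.
from dataclasses import dataclass


@dataclass
class SeqOutputStub:
    parent_seq_id: int
    output_token: int


@dataclass
class SeqGroupOutputStub:
    samples: list[SeqOutputStub]
    step_index: int = 0


def outputs_for_steps(seq_id: int,
                      token_ids_by_step: list[int]) -> list[SeqGroupOutputStub]:
    # One output object per position; step_index records which speculation
    # step produced it, with step 0 being the target-model token.
    return [
        SeqGroupOutputStub(samples=[SeqOutputStub(seq_id, tok)], step_index=step)
        for step, tok in enumerate(token_ids_by_step)
    ]


for out in outputs_for_steps(seq_id=7, token_ids_by_step=[11, 42, -1]):
    print(out.step_index, out.samples[0].output_token)
```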