[Bugfix] EAGLE output norm bug (#14464)

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Bryan Lu 2025-03-14 23:50:33 -07:00 committed by GitHub
parent ee3778d5fc
commit 9ed6ee92d6
8 changed files with 152 additions and 35 deletions

View File

@@ -162,7 +162,7 @@ A variety of speculative models of this type are available on HF hub:
## Speculating using EAGLE based draft models
The following code configures vLLM to use speculative decoding where proposals are generated by
-an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model.
+an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model. A more detailed example for offline mode, including how to extract request level acceptance rate, can be found [here](<gh-file:examples/offline_inference/eagle.py>).
```python
from vllm import LLM, SamplingParams
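Since the hunk above truncates the code block the docs paragraph refers to, here is a minimal hedged sketch of that configuration, abridged from the full offline example added by this commit (examples/offline_inference/eagle.py); the model names and token count are taken from that example rather than from the docs page itself:

```python
from vllm import LLM, SamplingParams

# Abridged from examples/offline_inference/eagle.py in this commit.
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    speculative_model="abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm",
    num_speculative_tokens=2,
)

outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0, max_tokens=64))
print(outputs[0].outputs[0].text)
```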

View File

@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0

import argparse
import json
import os

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

parser = argparse.ArgumentParser()

parser.add_argument(
    "--dataset",
    type=str,
    default="./examples/data/gsm8k.jsonl",
    help="downloaded from the eagle repo " \
    "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
)
parser.add_argument("--max_num_seqs", type=int, default=8)
parser.add_argument("--num_prompts", type=int, default=80)
parser.add_argument("--num_spec_tokens", type=int, default=2)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--draft_tp", type=int, default=1)
parser.add_argument("--enforce_eager", action='store_true')
parser.add_argument("--enable_chunked_prefill", action='store_true')
parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
parser.add_argument("--temp", type=float, default=0)

args = parser.parse_args()

print(args)

model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"

max_model_len = 2048

tokenizer = AutoTokenizer.from_pretrained(model_dir)

if os.path.exists(args.dataset):
    prompts = []
    num_prompts = args.num_prompts
    with open(args.dataset) as f:
        for line in f:
            data = json.loads(line)
            prompts.append(data["turns"][0])
else:
    prompts = ["The future of AI is", "The president of the United States is"]

prompts = prompts[:args.num_prompts]
num_prompts = len(prompts)

prompt_ids = [
    tokenizer.apply_chat_template([{
        "role": "user",
        "content": prompt
    }],
                                  add_generation_prompt=True)
    for prompt in prompts
]

llm = LLM(
    model=model_dir,
    trust_remote_code=True,
    tensor_parallel_size=args.tp,
    enable_chunked_prefill=args.enable_chunked_prefill,
    max_num_batched_tokens=args.max_num_batched_tokens,
    enforce_eager=args.enforce_eager,
    max_model_len=max_model_len,
    max_num_seqs=args.max_num_seqs,
    gpu_memory_utilization=0.8,
    speculative_model=eagle_dir,
    num_speculative_tokens=args.num_spec_tokens,
    speculative_draft_tensor_parallel_size=args.draft_tp,
    speculative_max_model_len=max_model_len,
    disable_log_stats=False,
)

sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)

outputs = llm.generate(prompt_token_ids=prompt_ids,
                       sampling_params=sampling_params)

# calculate the average number of accepted tokens per forward pass, +1 is
# to account for the token from the target model that's always going to be
# accepted
acceptance_counts = [0] * (args.num_spec_tokens + 1)
for output in outputs:
    for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
        acceptance_counts[step] += count

print(f"mean acceptance length: \
{sum(acceptance_counts) / acceptance_counts[0]:.2f}")

View File

@@ -853,6 +853,10 @@ class LLMEngine:
self.generation_config_fields, seq.eos_token_id)
# Create the sequence group.
+draft_size = 1
+if self.vllm_config.speculative_config is not None:
+    draft_size = \
+        self.vllm_config.speculative_config.num_speculative_tokens + 1
seq_group = SequenceGroup(
request_id=request_id,
seqs=[seq],
@@ -862,7 +866,8 @@ class LLMEngine:
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
-priority=priority)
+priority=priority,
+draft_size=draft_size)
return seq_group

View File

@@ -100,6 +100,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
seqs = sequence_group.get_seqs(
status=SequenceStatus.FINISHED_ABORTED)
+for output in outputs:
+    if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID:
+        sequence_group.metrics.spec_token_acceptance_counts[
+            output.step_index] += 1
assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
assert len(seqs) == 1, (
"Beam search not supported in multi-step decoding.")

View File

@@ -38,7 +38,7 @@ class DummyOutputNorm(nn.Module):
if residual is None:
return x
else:
-return x, residual
+return x + residual, None
class EAGLE(nn.Module):
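For readers of the one-line fix above, a hedged reading: vLLM norm layers that take a residual are expected to add it to the hidden states before normalizing and to hand back the updated residual. DummyOutputNorm skips the normalization in the EAGLE draft model, but by returning (x, residual) untouched it also dropped the residual addition from the final hidden states. The sketch below is not vLLM's actual layer code; it only illustrates the convention the stand-in has to honor:

```python
from typing import Optional

import torch


def norm_with_residual(x: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, eps: float = 1e-6):
    """RMSNorm-style layer in the assumed (x, residual) -> (out, residual) convention."""
    x = x + residual  # the residual add happens inside the norm call
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight, x


def dummy_output_norm(x: torch.Tensor, residual: Optional[torch.Tensor] = None):
    """Identity stand-in: skips normalization but must still fold in the residual."""
    if residual is None:
        return x
    # Previously `return x, residual`, which silently lost the addition.
    return x + residual, None
```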

View File

@@ -111,6 +111,13 @@ class RequestMetrics:
model_execute_time: The time spent in the model execute function. This
will include model forward, block/sync across
workers, cpu-gpu sync time and sampling time.
+spec_token_acceptance_counts: number of accepted speculative tokens at
+    each position; the first token is from
+    the target model and is always accepted;
+    e.g., when it's [10, 8, 4, 2] for a req,
+    it means there were 10 forward passes in
+    total, and there were 8, 4, 2 accepted
+    tokens at 1st, 2nd, 3rd speculation step.
"""
arrival_time: float
last_token_time: float
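To make the [10, 8, 4, 2] example in the new docstring concrete, here is a small worked sketch of how a per-request acceptance rate and the mean acceptance length printed by the new offline example fall out of these counters:

```python
# Values from the docstring example: 10 forward passes for the request, with
# 8, 4 and 2 draft tokens accepted at speculation steps 1, 2 and 3.
counts = [10, 8, 4, 2]

# Acceptance rate per speculation step, relative to the number of passes.
step_rates = [c / counts[0] for c in counts[1:]]  # [0.8, 0.4, 0.2]

# Mean acceptance length: tokens emitted per forward pass, counting the
# always-accepted target-model token at position 0.
mean_acceptance_length = sum(counts) / counts[0]  # 2.4

print(step_rates, mean_acceptance_length)
```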
@@ -121,6 +128,7 @@ class RequestMetrics:
scheduler_time: Optional[float] = None
model_forward_time: Optional[float] = None
model_execute_time: Optional[float] = None
+spec_token_acceptance_counts: Optional[list[int]] = None
class SequenceDataDelta(
@@ -639,22 +647,25 @@ class SequenceGroup:
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request.
priority: User-defined priority of the request.
+draft_size: The number of speculative tokens plus one from the target
+    model; equal to max number of tokens a step can generate
+    for single-draft speculative decoding but larger than
+    that for multi-draft SD (currently not supported).
"""
-def __init__(
-    self,
-    request_id: str,
-    seqs: list[Sequence],
-    arrival_time: float,
-    sampling_params: Optional[SamplingParams] = None,
-    lora_request: Optional[LoRARequest] = None,
-    pooling_params: Optional[PoolingParams] = None,
-    pooled_data: Optional[torch.Tensor] = None,
-    encoder_seq: Optional[Sequence] = None,
-    trace_headers: Optional[Mapping[str, str]] = None,
-    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    priority: int = 0,
-) -> None:
+def __init__(self,
+             request_id: str,
+             seqs: list[Sequence],
+             arrival_time: float,
+             sampling_params: Optional[SamplingParams] = None,
+             lora_request: Optional[LoRARequest] = None,
+             pooling_params: Optional[PoolingParams] = None,
+             pooled_data: Optional[torch.Tensor] = None,
+             encoder_seq: Optional[Sequence] = None,
+             trace_headers: Optional[Mapping[str, str]] = None,
+             prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+             priority: int = 0,
+             draft_size: int = 1) -> None:
self.request_id = request_id
self.seqs = seqs
self.first_seq = seqs[0]
@@ -667,7 +678,9 @@ class SequenceGroup:
last_token_time=arrival_time,
first_scheduled_time=None,
first_token_time=None,
-time_in_queue=None)
+time_in_queue=None,
+spec_token_acceptance_counts=[0] *
+    draft_size)
self.last_token_latency = 0.0
self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None
@@ -1079,6 +1092,7 @@ class CompletionSequenceGroupOutput(
samples: list[SequenceOutput]
# Prompt logprob for each prompt query token.
prompt_logprobs: Optional[PromptLogprobs]
+step_index: Optional[int] = 0
def __repr__(self) -> str:
return (f"CompletionSequenceGroupOutput(samples={self.samples}, "

View File

@@ -1080,7 +1080,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
[sequence_index][:num_logprobs],
topk_logprobs=topk_logprobs_by_step[step_index]
[sequence_index][:num_logprobs],
-))
+step_index=step_index))
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))

View File

@@ -93,14 +93,14 @@ def create_logprobs_output(
def create_sequence_group_output(
token_id: int,
token_id_logprob_rank: int,
token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[Optional[int]],
topk_logprobs: List[Optional[float]],
prompt_logprobs: Optional[PromptLogprobs] = None,
-) -> CompletionSequenceGroupOutput:
+step_index: Optional[int] = 0) -> CompletionSequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results.
Args:
@@ -110,6 +110,7 @@ def create_sequence_group_output(
seq_id (int): The sequence id.
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
+step_index: (Optional[int]): The index of the speculative token.
"""
logprobs = create_logprobs_output(
@@ -120,14 +121,13 @@ def create_sequence_group_output(
topk_logprobs,
)
-return CompletionSequenceGroupOutput(
-    samples=[
-        SequenceOutput(parent_seq_id=seq_id,
-                       output_token=token_id,
-                       logprobs=logprobs)
-    ],
-    prompt_logprobs=prompt_logprobs,
-)
+return CompletionSequenceGroupOutput(samples=[
+    SequenceOutput(parent_seq_id=seq_id,
+                   output_token=token_id,
+                   logprobs=logprobs)
+],
+    prompt_logprobs=prompt_logprobs,
+    step_index=step_index)
def split_batch_by_proposal_len(