[Bugfix] EAGLE output norm bug (#14464)
Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
This commit is contained in:
parent: ee3778d5fc
commit: 9ed6ee92d6
@@ -162,7 +162,7 @@ A variety of speculative models of this type are available on HF hub:

 ## Speculating using EAGLE based draft models

 The following code configures vLLM to use speculative decoding where proposals are generated by
-an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model.
+an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model. A more detailed example for offline mode, including how to extract request level acceptance rate, can be found [here](<gh-file:examples/offline_inference/eagle.py>).

 ```python
 from vllm import LLM, SamplingParams
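A minimal sketch of the configuration this documentation section describes, using the target and draft checkpoints from the `examples/offline_inference/eagle.py` file added in this commit (the prompt and surrounding arguments are illustrative, not a verbatim copy of the documentation snippet):

```python
from vllm import LLM, SamplingParams

# Sketch: offline EAGLE speculative decoding. The target/draft model names are
# the ones used by examples/offline_inference/eagle.py in this commit.
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    speculative_model="abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm",
    num_speculative_tokens=2,
)

outputs = llm.generate(
    ["The future of AI is"],
    SamplingParams(temperature=0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```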
examples/offline_inference/eagle.py (new file, +93 lines)
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import json
import os

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

parser = argparse.ArgumentParser()

parser.add_argument(
    "--dataset",
    type=str,
    default="./examples/data/gsm8k.jsonl",
    help="downloaded from the eagle repo " \
    "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
)
parser.add_argument("--max_num_seqs", type=int, default=8)
parser.add_argument("--num_prompts", type=int, default=80)
parser.add_argument("--num_spec_tokens", type=int, default=2)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--draft_tp", type=int, default=1)
parser.add_argument("--enforce_eager", action='store_true')
parser.add_argument("--enable_chunked_prefill", action='store_true')
parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
parser.add_argument("--temp", type=float, default=0)

args = parser.parse_args()

print(args)

model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"

max_model_len = 2048

tokenizer = AutoTokenizer.from_pretrained(model_dir)

if os.path.exists(args.dataset):
    prompts = []
    num_prompts = args.num_prompts
    with open(args.dataset) as f:
        for line in f:
            data = json.loads(line)
            prompts.append(data["turns"][0])
else:
    prompts = ["The future of AI is", "The president of the United States is"]

prompts = prompts[:args.num_prompts]
num_prompts = len(prompts)

prompt_ids = [
    tokenizer.apply_chat_template([{
        "role": "user",
        "content": prompt
    }],
                                  add_generation_prompt=True)
    for prompt in prompts
]

llm = LLM(
    model=model_dir,
    trust_remote_code=True,
    tensor_parallel_size=args.tp,
    enable_chunked_prefill=args.enable_chunked_prefill,
    max_num_batched_tokens=args.max_num_batched_tokens,
    enforce_eager=args.enforce_eager,
    max_model_len=max_model_len,
    max_num_seqs=args.max_num_seqs,
    gpu_memory_utilization=0.8,
    speculative_model=eagle_dir,
    num_speculative_tokens=args.num_spec_tokens,
    speculative_draft_tensor_parallel_size=args.draft_tp,
    speculative_max_model_len=max_model_len,
    disable_log_stats=False,
)

sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)

outputs = llm.generate(prompt_token_ids=prompt_ids,
                       sampling_params=sampling_params)

# calculate the average number of accepted tokens per forward pass, +1 is
# to account for the token from the target model that's always going to be
# accepted
acceptance_counts = [0] * (args.num_spec_tokens + 1)
for output in outputs:
    for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
        acceptance_counts[step] += count

print(f"mean acceptance length: \
{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
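The script above only reports an aggregate acceptance length. A request-level acceptance rate, which the documentation change mentions, can be sketched from the same `output.metrics.spec_token_acceptance_counts` field; the `request_acceptance_rate` helper below is illustrative (not part of the example file) and reuses the script's `outputs` and `args` variables:

```python
# Sketch: request-level acceptance rate from a single vLLM RequestOutput.
def request_acceptance_rate(output, num_spec_tokens: int) -> float:
    counts = output.metrics.spec_token_acceptance_counts  # e.g. [10, 8, 4, 2]
    forward_passes = counts[0]           # target-model token, accepted every pass
    accepted_drafts = sum(counts[1:])    # accepted speculative tokens
    return accepted_drafts / (forward_passes * num_spec_tokens)

rates = [request_acceptance_rate(o, args.num_spec_tokens) for o in outputs]
print(f"mean request acceptance rate: {sum(rates) / len(rates):.2f}")
```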
@@ -853,6 +853,10 @@ class LLMEngine:
             self.generation_config_fields, seq.eos_token_id)

         # Create the sequence group.
+        draft_size = 1
+        if self.vllm_config.speculative_config is not None:
+            draft_size = \
+                self.vllm_config.speculative_config.num_speculative_tokens + 1
         seq_group = SequenceGroup(
             request_id=request_id,
             seqs=[seq],
@@ -862,7 +866,8 @@ class LLMEngine:
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
             encoder_seq=encoder_seq,
-            priority=priority)
+            priority=priority,
+            draft_size=draft_size)

         return seq_group

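These two hunks size the per-request acceptance counter: `draft_size` is the number of speculative tokens plus one slot for the token the target model contributes on every forward pass. A stand-alone illustration with hypothetical values:

```python
# Illustration: how the acceptance-count buffer is sized for one request.
num_speculative_tokens = 2               # draft proposes 2 tokens per step
draft_size = num_speculative_tokens + 1  # +1 for the always-accepted target token
spec_token_acceptance_counts = [0] * draft_size
print(spec_token_acceptance_counts)      # [0, 0, 0]
```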
@@ -100,6 +100,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
             seqs = sequence_group.get_seqs(
                 status=SequenceStatus.FINISHED_ABORTED)

+        for output in outputs:
+            if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID:
+                sequence_group.metrics.spec_token_acceptance_counts[
+                    output.step_index] += 1
+
         assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
         assert len(seqs) == 1, (
             "Beam search not supported in multi-step decoding.")
@@ -38,7 +38,7 @@ class DummyOutputNorm(nn.Module):
         if residual is None:
             return x
         else:
-            return x, residual
+            return x + residual, None


 class EAGLE(nn.Module):
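This hunk is the output-norm bug named in the commit title. In vLLM's norm convention, a layer that is given a residual returns a `(hidden_states, residual)` pair, so a pass-through stand-in has to fold the residual back into the hidden states rather than hand it along untouched. A minimal sketch of that contract (illustrative module, not the vLLM class):

```python
import torch
from torch import nn


class PassthroughNorm(nn.Module):
    """Identity stand-in for a norm that follows the (hidden, residual) convention."""

    def forward(self, x: torch.Tensor, residual: torch.Tensor | None = None):
        if residual is None:
            return x
        # Fold the residual stream back into the hidden states; returning
        # (x, residual) unchanged would let callers that only consume the first
        # element drop the residual contribution entirely.
        return x + residual, None


x = torch.ones(2, 4)
residual = torch.full((2, 4), 2.0)
out, new_residual = PassthroughNorm()(x, residual)
assert torch.equal(out, torch.full((2, 4), 3.0)) and new_residual is None
```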
@@ -111,6 +111,13 @@ class RequestMetrics:
         model_execute_time: The time spent in the model execute function. This
                             will include model forward, block/sync across
                             workers, cpu-gpu sync time and sampling time.
+        spec_token_acceptance_counts: number of accepted speculative tokens at
+                                      each position; the first token is from
+                                      the target model and is always accepted;
+                                      e.g., when it's [10, 8, 4, 2] for a req,
+                                      it means there were 10 forward passes in
+                                      total, and there were 8, 4, 2 accepted
+                                      tokens at 1st, 2nd, 3rd speculation step.
     """
     arrival_time: float
     last_token_time: float
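Reading the docstring example concretely: `[10, 8, 4, 2]` means the request took 10 forward passes and produced 10 + 8 + 4 + 2 = 24 tokens in total, so the mean acceptance length reported by examples/offline_inference/eagle.py would be 24 / 10 = 2.4 tokens per pass.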
@@ -121,6 +128,7 @@ class RequestMetrics:
     scheduler_time: Optional[float] = None
     model_forward_time: Optional[float] = None
     model_execute_time: Optional[float] = None
+    spec_token_acceptance_counts: Optional[list[int]] = None


 class SequenceDataDelta(
@@ -639,22 +647,25 @@ class SequenceGroup:
         trace_headers: OpenTelemetry trace headers.
         prompt_adapter_request: Prompt Adapter request.
         priority: User-defined priority of the request.
+        draft_size: The number of speculative tokens plus one from the target
+                    model; equal to max number of tokens a step can generate
+                    for single-draft speculative decoding but larger than
+                    that for multi-draft SD (currently not supported).
     """

-    def __init__(
-        self,
-        request_id: str,
-        seqs: list[Sequence],
-        arrival_time: float,
-        sampling_params: Optional[SamplingParams] = None,
-        lora_request: Optional[LoRARequest] = None,
-        pooling_params: Optional[PoolingParams] = None,
-        pooled_data: Optional[torch.Tensor] = None,
-        encoder_seq: Optional[Sequence] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-        priority: int = 0,
-    ) -> None:
+    def __init__(self,
+                 request_id: str,
+                 seqs: list[Sequence],
+                 arrival_time: float,
+                 sampling_params: Optional[SamplingParams] = None,
+                 lora_request: Optional[LoRARequest] = None,
+                 pooling_params: Optional[PoolingParams] = None,
+                 pooled_data: Optional[torch.Tensor] = None,
+                 encoder_seq: Optional[Sequence] = None,
+                 trace_headers: Optional[Mapping[str, str]] = None,
+                 prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+                 priority: int = 0,
+                 draft_size: int = 1) -> None:
         self.request_id = request_id
         self.seqs = seqs
         self.first_seq = seqs[0]
@@ -667,7 +678,9 @@ class SequenceGroup:
                                       last_token_time=arrival_time,
                                       first_scheduled_time=None,
                                       first_token_time=None,
-                                      time_in_queue=None)
+                                      time_in_queue=None,
+                                      spec_token_acceptance_counts=[0] *
+                                      draft_size)
         self.last_token_latency = 0.0
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
@@ -1079,6 +1092,7 @@ class CompletionSequenceGroupOutput(
     samples: list[SequenceOutput]
     # Prompt logprob for each prompt query token.
     prompt_logprobs: Optional[PromptLogprobs]
+    step_index: Optional[int] = 0

     def __repr__(self) -> str:
         return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
@@ -1080,7 +1080,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
                         [sequence_index][:num_logprobs],
                         topk_logprobs=topk_logprobs_by_step[step_index]
                         [sequence_index][:num_logprobs],
-                        ))
+                        step_index=step_index))
             sampler_output_list.append(
                 SamplerOutput(outputs=step_output_token_ids))

@@ -93,14 +93,14 @@ def create_logprobs_output(


 def create_sequence_group_output(
         token_id: int,
         token_id_logprob_rank: int,
         token_id_logprob: float,
         seq_id: SeqId,
         topk_token_ids: List[Optional[int]],
         topk_logprobs: List[Optional[float]],
         prompt_logprobs: Optional[PromptLogprobs] = None,
-) -> CompletionSequenceGroupOutput:
+        step_index: Optional[int] = 0) -> CompletionSequenceGroupOutput:
     """Create a SequenceGroupOutput given the sampling results.

     Args:
@@ -110,6 +110,7 @@ def create_sequence_group_output(
         seq_id (int): The sequence id.
         topk_token_ids (List[Optional[int]]): The list of top-k token ids.
         topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
+        step_index (Optional[int]): The index of the speculative token.
     """

     logprobs = create_logprobs_output(
@@ -120,14 +121,13 @@ def create_sequence_group_output(
         topk_logprobs,
     )

-    return CompletionSequenceGroupOutput(
-        samples=[
-            SequenceOutput(parent_seq_id=seq_id,
-                           output_token=token_id,
-                           logprobs=logprobs)
-        ],
-        prompt_logprobs=prompt_logprobs,
-    )
+    return CompletionSequenceGroupOutput(samples=[
+        SequenceOutput(parent_seq_id=seq_id,
+                       output_token=token_id,
+                       logprobs=logprobs)
+    ],
+                                         prompt_logprobs=prompt_logprobs,
+                                         step_index=step_index)


 def split_batch_by_proposal_len(