[Misc] refactor example eagle (#16100)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
This commit is contained in:
Reid 2025-04-06 17:42:48 +08:00 committed by GitHub
parent 9ca710e525
commit b6c502a150
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -7,8 +7,28 @@ from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
parser = argparse.ArgumentParser()
def load_prompts(dataset_path, num_prompts):
if os.path.exists(dataset_path):
prompts = []
try:
with open(dataset_path) as f:
for line in f:
data = json.loads(line)
prompts.append(data["turns"][0])
except Exception as e:
print(f"Error reading dataset: {e}")
return []
else:
prompts = [
"The future of AI is", "The president of the United States is"
]
return prompts[:num_prompts]
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
@ -25,11 +45,8 @@ parser.add_argument("--enforce_eager", action='store_true')
parser.add_argument("--enable_chunked_prefill", action='store_true')
parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
parser.add_argument("--temp", type=float, default=0)
args = parser.parse_args()
print(args)
model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
@ -37,18 +54,7 @@ max_model_len = 2048
tokenizer = AutoTokenizer.from_pretrained(model_dir)
if os.path.exists(args.dataset):
prompts = []
num_prompts = args.num_prompts
with open(args.dataset) as f:
for line in f:
data = json.loads(line)
prompts.append(data["turns"][0])
else:
prompts = ["The future of AI is", "The president of the United States is"]
prompts = prompts[:args.num_prompts]
num_prompts = len(prompts)
prompts = load_prompts(args.dataset, args.num_prompts)
prompt_ids = [
tokenizer.apply_chat_template([{
@ -88,8 +94,15 @@ outputs = llm.generate(prompt_token_ids=prompt_ids,
# accepted
acceptance_counts = [0] * (args.num_spec_tokens + 1)
for output in outputs:
for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
for step, count in enumerate(
output.metrics.spec_token_acceptance_counts):
acceptance_counts[step] += count
print("-" * 50)
print(f"mean acceptance length: \
{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
print("-" * 50)
if __name__ == "__main__":
main()