# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import random
from typing import Any

import pytest

from vllm import LLM, SamplingParams


@pytest.fixture
def test_prompts():
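    # Each returned prompt is a single-turn chat message list in the
    # format expected by LLM.chat().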
    prompt_types = ["repeat", "sentence"]
    num_prompts = 100
    prompts = []

    random.seed(0)
    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)

    # Generate a mixed batch of prompts, some of which can be easily
    # predicted by n-gram matching and some of which likely cannot.
    for kind in random_prompt_type_choices:
        word_choices = ["test", "temp", "hello", "where"]
        word = random.choice(word_choices)
        if kind == "repeat":
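            # Highly repetitive output that n-gram matching can predict
            # directly from the prompt.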
            prompt = f"""
            please repeat the word '{word}' 10 times.
            give no other output than the word at least ten times in a row,
            in lowercase with spaces between each word and without quotes.
            """
        elif kind == "sentence":
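            # Free-form output that prompt lookup is unlikely to predict.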
            prompt = f"""
            please give a ten-word sentence that
            uses the word {word} at least once.
            give no other output than that simple sentence without quotes.
            """
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append([{"role": "user", "content": prompt}])

    return prompts


@pytest.fixture
def sampling_config():
    # Only greedy sampling is supported for now.
    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)


@pytest.fixture
def model_name():
    return "meta-llama/Meta-Llama-3-8B-Instruct"


def test_ngram_correctness(
    monkeypatch: pytest.MonkeyPatch,
    test_prompts: list[list[dict[str, Any]]],
    sampling_config: SamplingParams,
    model_name: str,
):
    '''
    Compare the outputs of the original LLM and the speculative LLM.
    They should be the same when using ngram speculative decoding.
    '''
    with monkeypatch.context() as m:
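        # Force the V1 engine; this test exercises V1 ngram speculative
        # decoding.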
        m.setenv("VLLM_USE_V1", "1")

        ref_llm = LLM(model=model_name, max_model_len=1024)
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
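        # Free the reference engine before constructing the speculative
        # one so both never hold GPU memory at the same time.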
        del ref_llm
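        # '[ngram]' enables prompt-lookup (n-gram) speculative decoding:
        # drafts are proposed by matching recent n-grams of 3 to 5 tokens
        # against the context, proposing 3 draft tokens per step.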
        spec_llm = LLM(model=model_name,
                       speculative_model='[ngram]',
                       ngram_prompt_lookup_max=5,
                       ngram_prompt_lookup_min=3,
                       num_speculative_tokens=3,
                       max_model_len=1024)
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
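        # With greedy sampling the two runs should mostly agree; count
        # exact text matches and print mismatches for debugging.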
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect at least 70% of the prompts to match exactly.
        # Upon failure, inspect the outputs to check for inaccuracy.
        assert matches > int(0.7 * len(ref_outputs))
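        # Explicitly drop the engine so its GPU memory is released.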
        del spec_llm