vllm/tests/v1/e2e/test_ngram_spec_decode.py

# SPDX-License-Identifier: Apache-2.0
import pytest

from vllm import LLM, SamplingParams


@pytest.fixture
def test_prompts():
    return [
        "Can you repeat the sentence ten times, this is a sentence.",
        "Can you repeat the sentence ten times, this is a test.",
    ]


@pytest.fixture
def sampling_config():
    # Only greedy sampling is supported for now.
    return SamplingParams(temperature=0, max_tokens=30, ignore_eos=False)


@pytest.fixture
def model_name():
    return "meta-llama/Meta-Llama-3-8B-Instruct"


def test_ngram_correctness(monkeypatch, test_prompts, sampling_config,
                           model_name):
    '''
    Compare the outputs of an original LLM and a speculative LLM.
    They should be the same when using ngram speculative decoding.
    '''
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        ref_llm = LLM(model=model_name)
        ref_outputs = ref_llm.generate(test_prompts, sampling_config)
        # Free the reference model before loading the speculative one.
        del ref_llm

        spec_llm = LLM(model=model_name,
                       speculative_model='[ngram]',
                       ngram_prompt_lookup_max=5,
                       ngram_prompt_lookup_min=3,
                       num_speculative_tokens=3)
        spec_outputs = spec_llm.generate(test_prompts, sampling_config)
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            assert ref_output.outputs[0].text == spec_output.outputs[0].text, \
                (f"ref_output: {ref_output.outputs[0].text}, "
                 f"spec_output: {spec_output.outputs[0].text}")
        del spec_llm
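
# Running this test locally should look roughly like the following
# (assumption: invoked from the repo root, on a GPU with enough memory
# to hold the 8B model twice in sequence):
#   pytest tests/v1/e2e/test_ngram_spec_decode.py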