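"""End-to-end logprob tests for speculative decoding.

Each test below drives `run_equality_correctness_test` from this package's
conftest: the same prompts are run through a baseline vLLM instance and a
speculative-decoding instance, and the outputs (token choices and, where
enabled, logprobs) are compared.

A typical invocation for this file alone (the path is an assumption about
where it lives in a vLLM checkout; adjust as needed):

    pytest tests/spec_decode/e2e/test_logprobs.py -v
"""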
from itertools import cycle

import pytest

from vllm import SamplingParams

from .conftest import run_equality_correctness_test


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": "JackFram/llama-160m",
                             "num_speculative_tokens": 3,
                             "disable_logprobs_during_spec_decoding": False,
                         }, {
                             "speculative_model": "JackFram/llama-160m",
                             "num_speculative_tokens": 3,
                             "disable_logprobs_during_spec_decoding": True,
                         }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        7,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_logprobs_equality(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int, logprobs: int):
    """Verify output logprobs are equal with and without speculative decoding.
    """
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  temperature=0.0,
                                  logprobs=logprobs,
                                  prompt_logprobs=logprobs,
                                  disable_logprobs=test_llm_kwargs[
                                      'disable_logprobs_during_spec_decoding'])
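
# Greedy speculative decoding is lossless by construction: a draft token is
# accepted only if it matches what the target model would have produced, so at
# temperature=0.0 the outputs (and logprobs) should not depend on the draft
# length k. The next test checks this for k=3 and k=6.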
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": "JackFram/llama-160m",
                             "num_speculative_tokens": 3,
                             "disable_logprobs_during_spec_decoding": False,
                         }, {
                             "speculative_model": "JackFram/llama-160m",
                             "num_speculative_tokens": 6,
                             "disable_logprobs_during_spec_decoding": False,
                         }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
                              per_test_common_llm_kwargs, baseline_llm_kwargs,
                              test_llm_kwargs, batch_size: int,
                              output_len: int, seed: int, logprobs: int):
    """Verify logprob greedy equality with different speculation lengths.
    """
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  temperature=0.0,
                                  logprobs=logprobs,
                                  disable_logprobs=test_llm_kwargs[
                                      'disable_logprobs_during_spec_decoding'])
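
# The next test caps the draft model's max length so vLLM must stop proposing
# tokens (and fall back to normal decoding) once a sequence approaches
# speculative_max_model_len; logprobs must remain correct across that
# transition.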
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model": "JackFram/llama-160m",
        "num_speculative_tokens": 3,
        "disable_logprobs_during_spec_decoding": False,

        # Artificially limit the draft model max model len; this forces vLLM
        # to skip speculation once the sequences grow beyond 32 - k tokens
        # (where k = num_speculative_tokens).
        "speculative_max_model_len": 32,
    }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1])
def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
                                        per_test_common_llm_kwargs,
                                        baseline_llm_kwargs, test_llm_kwargs,
                                        batch_size: int, output_len: int,
                                        seed: int, logprobs: int):
    """Verify logprob greedy equality when some sequences skip speculation.
    """
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  temperature=0.0,
                                  logprobs=logprobs,
                                  disable_logprobs=test_llm_kwargs[
                                      'disable_logprobs_during_spec_decoding'])
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": "JackFram/llama-160m",
                             "num_speculative_tokens": 3,
                             "disable_logprobs_during_spec_decoding": False,
                         }])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [6])
def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
                         per_test_common_llm_kwargs, baseline_llm_kwargs,
                         test_llm_kwargs, batch_size: int, output_len: int,
                         seed: int, logprobs: int):
    """Verify at least one logprob result has num_logprobs + 1 entries, which
    exercises the case where the sampled token is not among the top-k
    logprobs.

    Ideally this test would also validate equality with a non-speculative
    baseline by comparing logprobs; that is left as a future improvement.
    """
    temperature = 1.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is known for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
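    # zip() against range(batch_size) truncates or repeats the cycled prompt
    # list so that exactly batch_size prompts are generated.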

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
        logprobs=logprobs,
    )

    sd_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }
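    # Dict unpacking merges the three kwarg layers; on key collisions the
    # later ** entry wins, so test_llm_kwargs overrides the common settings.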

    with vllm_runner(**sd_args) as vllm_model:
        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

    num_returned_logprobs = [
        len(seq_logprobs) for seq_logprobs in sd_outputs[-1]
    ]

    # Assert at least one returned logprob entry has more than num_logprobs
    # items (indicating the sampled token was not in the top-k).
    assert any(num_returned > logprobs
               for num_returned in num_returned_logprobs)
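
# With disable_logprobs_during_spec_decoding=True, logprob computation is
# skipped during speculative decoding; the next test verifies that token
# choices still match the base model under greedy sampling.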
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-160m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": "JackFram/llama-68m",
                             "num_speculative_tokens": 3,
                             "disable_logprobs_during_spec_decoding": True,
                         }])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("logprobs", [0])
def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int, logprobs: int):
    """Check the behavior when logprobs are disabled.

    Token choices should match those of the base model.
    """
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  temperature=0.0,
                                  logprobs=logprobs,
                                  disable_logprobs=test_llm_kwargs[
                                      'disable_logprobs_during_spec_decoding'])