[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)
This commit is contained in: parent 050f285ff6, commit 62b8aebc6f
@@ -91,12 +91,16 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
bonus_token_ids,
)

# Bonus tokens are currently disabled. Verify they're set to -1.
# See https://github.com/vllm-project/vllm/issues/4212
expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1

if which_tokens_accepted == "all_tokens_accepted":
# Expect all tokens to be equal to draft tokens.
assert torch.equal(output_token_ids[:, :-1], draft_token_ids)

# Expect all bonus tokens to be included.
assert torch.equal(output_token_ids[:, -1:], bonus_token_ids)
assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)
elif which_tokens_accepted == "no_tokens_accepted":
# Expect first token to be equal to recovered tokens.
assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
@@ -106,7 +110,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
torch.ones_like(output_token_ids[:, 1:]) * -1)
elif which_tokens_accepted == "some_tokens_accepted":
recovered_plus_bonus = torch.cat(
(recovered_token_ids, bonus_token_ids), dim=-1)
(recovered_token_ids, expected_bonus_token_ids), dim=-1)
# Assert first rejected token is a recovered token or bonus token.
assert torch.equal(
recovered_plus_bonus[torch.arange(0, batch_size),

@@ -636,7 +636,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
def mock_sample(probs, *args, **kwargs):
nonlocal sample_probs
sample_probs = probs
return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
for prob in probs], None)

with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
tests/spec_decode/e2e/__init__.py (new file, 0 lines)
@@ -1,3 +1,5 @@
from typing import List, Tuple

import pytest

from tests.conftest import cleanup
@@ -6,28 +8,34 @@ from vllm.model_executor.utils import set_random_seed

@pytest.fixture
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, seed):
return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
def baseline_llm_generator(request, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
seed):
return create_llm_generator("baseline", request, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, seed)

@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, seed):
return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, seed)
return create_llm_generator("test", request, common_llm_kwargs,
per_test_common_llm_kwargs, test_llm_kwargs,
seed)

def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
distinct_llm_kwargs, seed):
def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
per_test_common_llm_kwargs, distinct_llm_kwargs,
seed):
kwargs = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**distinct_llm_kwargs,
}
test_name = request.node.name

def generator_inner():
print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
llm = LLM(**kwargs)

set_random_seed(seed)
@@ -36,6 +44,23 @@ def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
del llm
cleanup()

for llm in generator_inner():
yield llm
def generator_outer():
for llm in generator_inner():
yield llm
del llm

return generator_outer

def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]:
tokens = []
token_ids = []
for llm in llm_generator():
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]
del llm

return tokens, token_ids
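For orientation, a hypothetical test would consume these fixtures roughly as in the sketch below. This is illustrative only; the prompt, the assertion, and the function name are not part of this commit.

# Illustrative sketch of how a test drives the fixtures defined above.
from vllm import SamplingParams
from .conftest import get_output_from_llm_generator

def example_test(test_llm_generator):
    # The fixture resolves the parametrized kwargs and returns a generator
    # factory; get_output_from_llm_generator builds the LLM, generates, and
    # tears it down.
    sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
    tokens, token_ids = get_output_from_llm_generator(
        test_llm_generator, ["Hello, my name is"], sampling_params)
    assert len(token_ids) == 1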
tests/spec_decode/e2e/test_compatibility.py (new file, 169 lines)
@@ -0,0 +1,169 @@
import pytest

from vllm import SamplingParams

from .conftest import get_output_from_llm_generator

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,

# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Expect failure as spec decode not supported by
# Ray backend.
"worker_use_ray": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_ray(test_llm_generator):
"""Verify that speculative decoding with Ray fails.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

with pytest.raises(AssertionError,
match="Speculative decoding not yet supported for "):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,

# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"enable_chunked_prefill": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
"""Verify that speculative decoding with chunked prefill fails.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

with pytest.raises(ValueError,
match="Speculative decoding and chunked prefill"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "meta-llama/Llama-2-7b-chat-hf",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,

# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Speculative max model len > overridden max model len should raise.
"max_model_len": 128,
"speculative_max_model_len": 129,
},
{
# Speculative max model len > draft max model len should raise.
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
"speculative_max_model_len": 2048 + 1,
},
{
# Speculative max model len > target max model len should raise.
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12
"speculative_max_model_len": 4096 + 1,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
"""Verify that speculative decoding validates speculative_max_model_len.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

with pytest.raises(ValueError, match="cannot be larger than"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)

@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_block_manager_v1(test_llm_generator):
"""Verify that speculative decoding with block manager v1 fails.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

with pytest.raises(ValueError,
match="Speculative decoding requires usage of the V2"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@@ -1,11 +1,42 @@
"""The tests in this file verify end-to-end speculative decoding correctness.

This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality. This gives us good coverage of temp=0.

For temp>0, we rely on unit tests on the rejection sampler to verify that the
output distribution is the same with spec decode vs. no spec decode (this would
be prohibitively expensive to run with a real model).

NOTE: Speculative decoding's distribution equality requires that the measured
distributions of the target model and proposal model be deterministic given the
same input. vLLM largely guarantees this.

@cadedaniel has seen cases where the output probabilities of a draft/target
model change slightly with certain batch sizes or prompts, even with Torch
determinism flags set. It is unclear if this is a bug in vLLM, due to non-
determinism in on-device batched operations, a bug in vLLM's spec decode
implementation, or the "hardware numerics" limitations. Either way, rejection
sampling ensures the output distribution matches the target model, but it breaks
greedy-equality tests for those batch sizes/prompts.
"""

from itertools import cycle
from typing import List, Tuple

import pytest
from transformers import AutoTokenizer

from vllm import SamplingParams

from .conftest import get_output_from_llm_generator

@pytest.mark.parametrize(
"common_llm_kwargs",
@@ -14,9 +45,6 @@ from vllm import SamplingParams
# Note this is repeated in the test body; to initialize a tokenizer.
"model": "JackFram/llama-68m",

# Skip real loading for fast test.
"load_format": "dummy",

# Skip cuda graph recording for fast test.
"enforce_eager": True,

@@ -31,22 +59,15 @@ from vllm import SamplingParams
"num_speculative_tokens": 5,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 1,
},
{
# No spec decode.
# Verify the detokenizer assertions in the test work when spec
# decode is disabled.
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [1])
# NOTE: We should run more permutations of this test (more BS, more seeds). But
# because our spec decode generates gibberish token ids, the likelihood of
# emitting an invalid token combination is nontrivial. This causes divergence in
# behavior of vLLM detokenization vs. hf tokenizer, for example when two "utf-
# start" bytes are emitted.
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
def test_spec_decode_e2e_with_detokenization(test_llm_generator,
batch_size: int):
"""Run generation with speculative decoding on a batch. Verify the engine
generates the correct number of tokens (via ignore_eos=True), and that the
detokenization matches HF transformers.
@@ -67,8 +88,6 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
skip_special_tokens=True,
spaces_between_special_tokens=False,
)

batch_tokens, batch_token_ids = get_output_from_llm_generator(
@@ -77,9 +96,10 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
# Expect a generation for each prompt in the batch.
assert len(batch_token_ids) == len(prompts)

# Expect each generation to have expected number of tokens (note
# ignore_eos=True).
assert all(len(token_ids) == output_len for token_ids in batch_token_ids)
# Expect each generation to have expected number of tokens (note ignore_eos
# is True).
assert [len(token_ids)
for token_ids in batch_token_ids] == ([output_len] * batch_size)

# Expect detokenized string to match.
tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
@@ -92,13 +112,293 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,

# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model": "JackFram/llama-68m",
},
{
"model": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use long output len for the small model test.
1536,
])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality on a tiny model with batch size of one.

# Skip real loading for fast test.
"load_format": "dummy",
Since this test is cheaper than other e2e correctness tests, we generate
with a higher output_len.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,

# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model": "JackFram/llama-68m",
},
{
"model": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
256,
])
@pytest.mark.parametrize("batch_size", [64])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality on a tiny model and large batch size.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model": "JackFram/llama-68m",
},
{
"model": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("max_output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
baseline_llm_generator, test_llm_generator, batch_size: int,
max_output_len: int):
"""Verify greedy equality on a tiny model, with a large batch size, and when
sampling respects the EOS token.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len=False)

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# A "real" model (not tiny).
"model": "meta-llama/Llama-2-7b-chat-hf",

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,

# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
"output_len",
[
# Use decently long output len for a high quality test.
256,
])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality on a "real" model and batch size of 1. This is
separate from large BS tests to make identifying the source of bugs easier.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# A "real" model (not tiny).
"model": "meta-llama/Llama-2-7b-chat-hf",

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,

# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
64,
])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality with a "real" model on a nontrivial batch size.
This is the closest test to a real production workload.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
256,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_with_preemption(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -109,43 +409,189 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# As of this writing, vLLM only compiles with these 3 block sizes by
# default.
{
# Expect failure as spec decode not supported by
# Ray backend.
"worker_use_ray": True,
"block_size": 8,
},
{
"block_size": 16,
},
{
"block_size": 32,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail(test_llm_generator):
"""Verify that speculative decoding with Ray fails.
def test_spec_decode_different_block_size(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality over different block sizes.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,

# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
},
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# This must be a good bit larger than speculative_max_model_len so that
# we can test the case where all seqs are skipped, but still small to
# ensure fast test.
64,
])
@pytest.mark.parametrize("seed", [1])
def test_skip_speculation(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality when some (or all) sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
are skipped in speculative decoding.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
}
# Try a range of common k, as well as large speculation.
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify that speculative decoding produces exact equality to without spec
decode with many different values of k.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)

def run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
print_tokens: bool = False):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]

prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos = force_output_len

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
)

with pytest.raises(AssertionError,
match="Speculative decoding not yet supported for "):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)

(baseline_batch_tokens,
baseline_batch_token_ids) = get_output_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)

def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]:
for llm in llm_generator:
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]
del llm
assert len(baseline_batch_token_ids) == len(prompts)
assert len(spec_batch_token_ids) == len(prompts)

return tokens, token_ids
for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
spec_tokens) in enumerate(
zip(baseline_batch_token_ids, baseline_batch_tokens,
spec_batch_token_ids, spec_batch_tokens)):
if print_tokens:
print(f'{i=} {baseline_tokens=}')
print(f'{i=} {spec_tokens=}')
print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids
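To make the "greedy equality" methodology in the module docstring above concrete, the check reduces to the comparison sketched below. This is a simplification of run_greedy_equality_correctness_test, assuming both generators were built from the same parametrized kwargs and temperature is 0.0.

# Minimal sketch: with temperature=0.0 and rejection sampling, speculative
# decoding output token ids should match the non-speculative baseline exactly.
def assert_greedy_equality(baseline_batch_token_ids, spec_batch_token_ids):
    for baseline_ids, spec_ids in zip(baseline_batch_token_ids,
                                      spec_batch_token_ids):
        assert baseline_ids == spec_ids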
@@ -119,7 +119,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
num_draft_tokens = 0
k = 5

num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens(
max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens(
num_draft_tokens, k)

rej_sampler = MagicMock()
@@ -153,7 +153,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
assert (metrics.draft_acceptance_rate == num_accepted_tokens /
num_draft_tokens)
assert (metrics.system_efficiency == num_emitted_tokens /
num_possible_tokens)
max_num_emitted_tokens)
else:
assert math.isnan(metrics.draft_acceptance_rate)
assert math.isnan(metrics.system_efficiency)
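As a hedged reading of the metrics assertions above: the two ratios are computed as sketched below. The ceiling formula is an assumption about AsyncMetricsCollector.get_max_num_emitted_tokens and should be checked against vllm/spec_decode/metrics.py before relying on it.

# draft_acceptance_rate = accepted_tokens / draft_tokens
# system_efficiency     = emitted_tokens / max_num_emitted_tokens
def assumed_max_num_emitted_tokens(num_draft_tokens: int, k: int) -> int:
    # Assumption: each proposal step of k draft tokens can emit at most
    # k + 1 tokens (k accepted draft tokens plus one bonus/recovered token).
    return num_draft_tokens // k * (k + 1)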
@@ -344,8 +344,8 @@ def test_draft_proposals_no_speculations():
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)

assert proposals.proposal_token_ids.shape == torch.Size([0, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([0, k])
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])

assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]
@@ -1,4 +1,5 @@
import random
from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest
@@ -62,8 +63,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
"""Verify SpecDecodeWorker calls the target model with correct
inputs. Everything else is mocked out.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@@ -144,8 +145,10 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
"""
vocab_size = 32_000

draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
target_worker = mock_worker(vocab_size=vocab_size)
draft_worker = mock_worker(cls=MultiStepWorker,
vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@@ -202,17 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
num_lookahead_slots=k)

assert len(rejection_sampler.call_args_list) == 1
args, _ = rejection_sampler.call_args_list[0]
(actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs,
actual_proposal_token_ids) = args
_, kwargs = rejection_sampler.call_args_list[0]
actual = SimpleNamespace(**kwargs)

assert torch.equal(actual_bonus_token_ids,
assert torch.equal(actual.bonus_token_ids,
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
assert torch.equal(
actual_proposal_scores,
actual.target_probs,
target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
assert torch.equal(actual_proposal_token_ids, proposal_token_ids)
assert torch.equal(actual_proposal_probs, proposal_probs)
assert torch.equal(actual.draft_token_ids, proposal_token_ids)
assert torch.equal(actual.draft_probs, proposal_probs)

@pytest.mark.parametrize('k', [1, 2, 6])
@@ -224,8 +226,10 @@ def test_correctly_formats_output(k: int, batch_size: int):
"""
vocab_size = 32_000

draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
target_worker = mock_worker(vocab_size=vocab_size)
draft_worker = mock_worker(cls=MultiStepWorker,
vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@@ -336,8 +340,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
"""
vocab_size = 32_000

draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
target_worker = mock_worker(vocab_size=vocab_size)
draft_worker = mock_worker(cls=MultiStepWorker,
vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@@ -500,8 +506,8 @@ def test_init_device():
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@@ -63,11 +63,14 @@ def create_execute_model_data(
def mock_worker(cls=None,
vocab_size: int = 30_000,
max_model_len: int = 2048,
rank: int = 0) -> MagicMock:
rank: int = 0,
use_spec: bool = True) -> MagicMock:
if cls is None:
cls = Worker

worker = MagicMock(spec=cls)
spec = cls if use_spec else None

worker = MagicMock(spec=spec)
worker.vocab_size = vocab_size
worker.max_model_len = max_model_len
worker.rank = rank
@@ -655,6 +655,9 @@ class SpeculativeConfig:
target_dtype: str,
speculative_model: Optional[str],
num_speculative_tokens: Optional[int],
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
use_v2_block_manager: bool,
) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None.

@@ -672,6 +675,15 @@ class SpeculativeConfig:
model, if provided.
num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided.
speculative_max_model_len (Optional[int]): The maximum model len of
the speculative model. Used when testing the ability to skip
speculation for some sequences.
enable_chunked_prefill (bool): Whether vLLM is configured to use
chunked prefill or not. Used for raising an error since its not
yet compatible with spec decode.
use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode.

Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
@@ -690,12 +702,21 @@ class SpeculativeConfig:
assert (speculative_model is not None
and num_speculative_tokens is not None)

if enable_chunked_prefill:
raise ValueError(
"Speculative decoding and chunked prefill are "
f"currently mutually exclusive ({enable_chunked_prefill=}).")

if not use_v2_block_manager:
raise ValueError(
"Speculative decoding requires usage of the V2 "
"block manager. Enable it with --use-v2-block-manager.")

# TODO: The user should be able to specify revision/quantization/max
# model len for the draft model. It is not currently supported.
draft_revision = None
draft_code_revision = None
draft_quantization = None
draft_max_model_len = None

draft_model_config = ModelConfig(
model=speculative_model,
@@ -707,7 +728,7 @@ class SpeculativeConfig:
revision=draft_revision,
code_revision=draft_code_revision,
tokenizer_revision=target_model_config.tokenizer_revision,
max_model_len=draft_max_model_len,
max_model_len=None,
quantization=draft_quantization,
enforce_eager=target_model_config.enforce_eager,
max_context_len_to_capture=target_model_config.
@@ -715,6 +736,13 @@ class SpeculativeConfig:
max_logprobs=target_model_config.max_logprobs,
)

draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,
draft_model_config.max_model_len,
target_model_config.max_model_len,
))

draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config))
@@ -725,6 +753,41 @@ class SpeculativeConfig:
num_speculative_tokens,
)

@staticmethod
def _maybe_override_draft_max_model_len(
speculative_max_model_len: Optional[int],
draft_max_model_len: int,
target_max_model_len: int,
) -> int:
"""Determine the max sequence len for the draft model. This is usually
the draft_max_model_len, but may be the target_max_model_len if it is
less than the draft_max_model_len, or may be speculative_max_model_len
if it is specified.

This is necessary so that sequences do not exceed the capacity of the
draft model or the target model.

speculative_max_model_len is mainly used for testing that sequences can
skip speculation.
"""

if speculative_max_model_len is not None:

if speculative_max_model_len > draft_max_model_len:
raise ValueError(f"{speculative_max_model_len=} cannot be "
f"larger than {draft_max_model_len=}")

if speculative_max_model_len > target_max_model_len:
raise ValueError(f"{speculative_max_model_len=} cannot be "
f"larger than {target_max_model_len=}")

return speculative_max_model_len

return min(
draft_max_model_len,
target_max_model_len,
)

@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig) -> ParallelConfig:
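A small worked example of the _maybe_override_draft_max_model_len logic above (values are illustrative):

# speculative_max_model_len=None, draft=2048, target=4096 -> min(2048, 4096) == 2048
# speculative_max_model_len=32,   draft=2048, target=4096 -> 32
#   (sequences longer than 32 tokens then skip speculation, as exercised by
#   test_skip_speculation above)
# speculative_max_model_len=4097, draft=2048, target=4096 -> ValueError("... cannot be larger than ...")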
@@ -73,6 +73,7 @@ class EngineArgs:
# Speculative decoding configuration.
speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None

def __post_init__(self):
if self.tokenizer is None:
@@ -237,7 +238,7 @@ class EngineArgs:
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32, 128],
choices=[8, 16, 32],
help='Token block size for contiguous chunks of '
'tokens.')

@@ -420,17 +421,25 @@ class EngineArgs:
parser.add_argument(
'--speculative-model',
type=str,
default=None,
default=EngineArgs.speculative_model,
help=
'The name of the draft model to be used in speculative decoding.')

parser.add_argument(
'--num-speculative-tokens',
type=int,
default=None,
default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.')

parser.add_argument(
'--speculative-max-model-len',
type=str,
default=EngineArgs.speculative_max_model_len,
help='The maximum sequence length supported by the '
'draft model. Sequences over this length will skip '
'speculation.')

parser.add_argument('--model-loader-extra-config',
type=str,
default=EngineArgs.model_loader_extra_config,
@@ -481,6 +490,9 @@ class EngineArgs:
target_dtype=self.dtype,
speculative_model=self.speculative_model,
num_speculative_tokens=self.num_speculative_tokens,
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager,
)

scheduler_config = SchedulerConfig(
@@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
SequenceGroup)
SequenceGroup, SequenceStage)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
@@ -480,9 +480,12 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
# If uncomputed tokens > 0, it means prefill is chunked.
# We don't need to process outputs in that case.
if seq_group.get_num_uncomputed_tokens() == 0:

# If all sequences in the sequence group are in DECODE, then we can
# process the output tokens. Otherwise, they are (chunked) prefill
# samples and should not be processed.
stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
if all(stage == SequenceStage.DECODE for stage in stages):
self.output_processor.process_outputs(seq_group, outputs)

# Free the finished sequence groups.
@@ -569,7 +572,8 @@ class LLMEngine:

# Log stats.
if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs))
self.stat_logger.log(
self._get_stats(scheduler_outputs, model_output=output))

return request_outputs

@@ -578,9 +582,18 @@ class LLMEngine:
if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs=None))

def _get_stats(self,
scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
"""Get Stats to be Logged to Prometheus."""
def _get_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs],
model_output: Optional[List[SamplerOutput]] = None) -> Stats:
"""Get Stats to be Logged to Prometheus.

Args:
scheduler_outputs: Optional, used to populate metrics related to
the scheduled batch,
model_output: Optional, used to emit speculative decoding metrics
which are created by the workers.
"""
now = time.time()

# KV Cache Usage in %.
@@ -637,6 +650,14 @@ class LLMEngine:
time_to_first_tokens = time_last_iters if prompt_run else []
time_per_output_tokens = [] if prompt_run else time_last_iters

# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
if model_output and (model_output[0].spec_decode_worker_metrics
is not None):
spec_decode_metrics = model_output[0].spec_decode_worker_metrics
else:
spec_decode_metrics = None

return Stats(
now=now,
num_running=num_running,
@@ -649,6 +670,7 @@ class LLMEngine:
time_to_first_tokens=time_to_first_tokens,
time_per_output_tokens=time_per_output_tokens,
time_e2e_requests=time_e2e_requests,
spec_decode_metrics=spec_decode_metrics,
)

def add_lora(self, lora_request: LoRARequest) -> bool:
@@ -1,6 +1,6 @@
import time
from dataclasses import dataclass
from typing import Dict, List, Protocol
from typing import TYPE_CHECKING, Dict, List, Optional, Protocol

import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
@@ -8,6 +8,9 @@ from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,

from vllm.logger import init_logger

if TYPE_CHECKING:
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

logger = init_logger(__name__)

disable_created_metrics()
@@ -118,6 +121,8 @@ class Stats:
time_per_output_tokens: List[float]
time_e2e_requests: List[float]

spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None

class SupportsMetricsInfo(Protocol):

@@ -235,3 +240,19 @@ class StatLogger:
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now

if stats.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
stats.spec_decode_metrics))

def _format_spec_decode_metrics_str(
self, metrics: "SpecDecodeWorkerMetrics") -> str:

return ("Speculative metrics: "
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
f"System efficiency: {metrics.system_efficiency:.3f}, "
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
f"Number of accepted tokens: {metrics.accepted_tokens}, "
f"Number of draft tokens tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")
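With disable_log_stats=False (as the correctness tests above set), the formatter in the diff above would produce a log line of roughly the following shape; the numbers here are made up purely for illustration.

# Speculative metrics: Draft acceptance rate: 0.816, System efficiency: 0.640,
# Number of speculative tokens: 5, Number of accepted tokens: 4892,
# Number of draft tokens tokens: 5995, Number of emitted tokens tokens: 6150.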
@@ -83,6 +83,7 @@ class GPUExecutor(ExecutorBase):
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
# TODO allow draft-model specific load config.
load_config=self.load_config,
local_rank=0,
rank=0,
@@ -144,6 +144,7 @@ class RejectionSampler(nn.Module):
recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)

# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(recovered_probs,
num_samples=1).reshape(
batch_size, k)
@@ -307,6 +308,12 @@ class RejectionSampler(nn.Module):
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1)

# We disable bonus tokens because it causes corrupt KV cache for
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
output_with_bonus_tokens[:, -1] = -1

# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
recovered_token_ids.mul(after_false_mask))
@ -35,6 +35,14 @@ class Sampler(nn.Module):
|
||||
in logits for each token in the input prompt.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Whether or not the SamplerOutput should have on-device tensors
|
||||
# containing the sampled token ids and probabilities. This is used by
|
||||
# speculative decoding.
|
||||
self.include_gpu_probs_tensor = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
@ -79,13 +87,45 @@ class Sampler(nn.Module):
|
||||
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
||||
|
||||
# Sample the next tokens.
|
||||
sample_results = _sample(probs, logprobs, sampling_metadata,
|
||||
sampling_tensors)
|
||||
sample_results, maybe_sampled_tokens_tensor = _sample(
|
||||
probs,
|
||||
logprobs,
|
||||
sampling_metadata,
|
||||
sampling_tensors,
|
||||
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
|
||||
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
|
||||
)
|
||||
|
||||
if self.include_gpu_probs_tensor:
|
||||
assert maybe_sampled_tokens_tensor is not None
|
||||
sampled_tokens_tensor = maybe_sampled_tokens_tensor
|
||||
on_device_tensors = (probs, sampled_tokens_tensor)
|
||||
else:
|
||||
on_device_tensors = None
|
||||
|
||||
# Get the logprobs query results.
|
||||
prompt_logprobs, sample_logprobs = _get_logprobs(
|
||||
logprobs, sampling_metadata, sample_results)
|
||||
return _build_sampler_output(sample_results, sampling_metadata,
|
||||
prompt_logprobs, sample_logprobs)
|
||||
return _build_sampler_output(sample_results,
|
||||
sampling_metadata,
|
||||
prompt_logprobs,
|
||||
sample_logprobs,
|
||||
on_device_tensors=on_device_tensors)
|
||||
|
||||
@property
|
||||
def _should_modify_greedy_probs_inplace(self) -> bool:
|
||||
"""Whether or not the sampler should modify the probability distribution
|
||||
of greedily-sampled tokens such that multinomial sampling would sample
|
||||
the greedily-sampled token.
|
||||
|
||||
In other words, if True then we set the probability of the greedily-
|
||||
sampled token to 1.
|
||||
This is used by speculative decoding, which requires that the sampling
method be encoded into the probability distribution.
"""
# Modify greedy probs if include_gpu_probs_tensor is set.
return self.include_gpu_probs_tensor


def _get_bin_counts_and_mask(

@ -359,7 +399,9 @@ def _sample_with_torch(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> List[Tuple[List[int], List[int]]]:
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):

@ -371,6 +413,15 @@ def _sample_with_torch(
sample_metadata = {}
multinomial_samples = {}

# Create output tensor for sampled token ids.
if include_gpu_probs_tensor:
sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
1,
dtype=torch.long,
device=logprobs.device)
else:
sampled_token_ids_tensor = None

# Counterintuitively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType:

@ -383,9 +434,25 @@ def _sample_with_torch(
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
is_prompts, sample_indices)
long_sample_indices = sample_indices.long()

if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[sample_indices.long()],
greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1)

if include_gpu_probs_tensor:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[
long_sample_indices] = greedy_samples.unsqueeze(-1)

if modify_greedy_probs:
# If required, modify the probabilities such that sampling from
# the modified distribution would always sample the argmax
# token id.
_modify_greedy_probs_inplace(logprobs, probs,
long_sample_indices,
greedy_samples)

elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of_in_batch = 1
for seq_group, is_prompt in zip(seq_groups, is_prompts):

@ -397,15 +464,23 @@ def _sample_with_torch(
"seq_groups": seq_groups,
"generators": sampling_metadata.generators,
}

multinomial_samples[sampling_type] = _multinomial(
probs[sample_indices.long()], max_best_of_in_batch,
probs[long_sample_indices], max_best_of_in_batch,
**seeded_args)

if include_gpu_probs_tensor:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[
long_sample_indices] = multinomial_samples[sampling_type]

elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")

# GPU<->CPU sync happens in the loop below.
# This also converts the sample output to Python objects.

for sampling_type in SamplingType:
if sampling_type not in sample_metadata:

@ -427,7 +502,7 @@ def _sample_with_torch(
sample_results_dict[i]
for i in range(len(sampling_metadata.seq_groups))
]
return sample_results
return sample_results, sampled_token_ids_tensor
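The change above preallocates an on-device tensor for sampled token ids and scatters greedy and multinomial samples into it by their sample indices, so the ids never have to leave the GPU. A minimal standalone sketch of that indexing pattern (toy shapes and names, illustrative only, not vLLM code):

import torch

num_seqs, vocab_size = 6, 32
logprobs = torch.randn(num_seqs, vocab_size)

# Preallocated output: one sampled token id per sequence.
sampled_token_ids = torch.empty(num_seqs, 1, dtype=torch.long)

# Suppose sequences 0, 2, 3 are greedy and 1, 4, 5 are random.
greedy_indices = torch.tensor([0, 2, 3])
random_indices = torch.tensor([1, 4, 5])

# Greedy: argmax over the vocab, scattered back by index.
greedy_samples = torch.argmax(logprobs[greedy_indices], dim=-1)
sampled_token_ids[greedy_indices] = greedy_samples.unsqueeze(-1)

# Random: multinomial over softmax probabilities, scattered the same way.
probs = torch.softmax(logprobs[random_indices], dim=-1)
sampled_token_ids[random_indices] = torch.multinomial(probs, num_samples=1)

assert sampled_token_ids.shape == (num_seqs, 1)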


def _sample_with_triton_kernel(

@ -511,12 +586,17 @@ def _sample_with_triton_kernel(


def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
return _sample_with_torch(probs, logprobs, sampling_metadata)
probs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
return _sample_with_torch(
probs,
logprobs,
sampling_metadata,
include_gpu_probs_tensor=include_gpu_probs_tensor,
modify_greedy_probs=modify_greedy_probs,
)

# TODO: Enable once Triton kernel & associated code is faster.
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,

@ -680,12 +760,73 @@ def _get_logprobs(
return result_prompt_logprobs, result_sample_logprobs


def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
sample_indices: torch.Tensor,
greedy_samples: torch.Tensor) -> None:
"""Modify the probability distributions of the greedily-sampled tokens such
that each sampled token has a "probability" of 1.0. This is required by
speculative decoding, which depends on the sampling method being encoded
within the probability distribution for correctness.

# Why do we only need to do this for greedy sampling?

vLLM's sampler performs the following steps for greedy or multinomial
(random) sampling:
1. Get logits from model.
2. Modify logits according to per-sequence sampling parameters.
- Multiply by temperature, top-k and top-p masking, penalize tokens
according to their frequency, etc.
3. Sample a token.
- Random sampling simply samples from the modified probability
distribution.
- Greedy sampling performs `argmax` to obtain the token with the
highest likelihood.

Ignoring greedy sampling for a moment, we find that the computed probability
distribution has the following property: we can sample from it independently
and find that the token sampled by the Sampler has a frequency corresponding
to how often we see it in our sampling. In other words, for tokens sampled
with vLLM's random SamplingType, the computed probability distribution
encodes the sampling methodology completely.

Greedy sampling does not normally have this property. vLLM modifies logits
according to sampling params, then performs `argmax`, then returns the
sampled token and the computed probability distribution. If we sample from
the distribution, we'll find the likelihood of the greedily-sampled token
is not always 1.0.

Since lossless speculative decoding requires that the sampling methodology
be encoded within the probability distribution, we are motivated to modify
the probability distribution such that the sampled token has probability 1
when speculative decoding is used.

NOTE: Alternatively, we could use an extremely low temperature to achieve
greedy sampling using multinomial computation and unite the codepaths. This
has implications on the overall design of the sampler, e.g. how to record
accurate logprobs for the user, so this improvement is deferred to later.
"""
logprobs[sample_indices, :] = -float('inf')
logprobs[sample_indices, greedy_samples] = 0.0
probs[sample_indices, :] = 0
probs[sample_indices, greedy_samples] = 1.0
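To see the effect concretely, here is a small self-contained check (toy tensors, not vLLM code): after the in-place rewrite, the probability row of each greedily-sampled sequence is one-hot at the argmax, so sampling from the modified distribution can only reproduce the greedy choice.

import torch

logprobs = torch.log_softmax(torch.randn(4, 10), dim=-1)
probs = torch.exp(logprobs)

sample_indices = torch.tensor([0, 2])                       # greedy sequences
greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)

# Same in-place rewrite as above: one-hot probs, 0.0 logprob for the argmax.
logprobs[sample_indices, :] = -float('inf')
logprobs[sample_indices, greedy_samples] = 0.0
probs[sample_indices, :] = 0
probs[sample_indices, greedy_samples] = 1.0

# Sampling from the modified rows now always returns the greedy token.
drawn = torch.multinomial(probs[sample_indices], num_samples=1).squeeze(-1)
assert torch.equal(drawn, greedy_samples)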


def _build_sampler_output(
sample_results: List[Tuple[List[int], List[int]]],
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs],
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> SamplerOutput:
"""Construct Python objects with the output of sampling.

Args:
on_device_tensors: Tuple containing on-device tensors with the
probabilities used in sampling and the sampled token ids. This
allows post-processing without copies to CPU/serialization, e.g. in
speculative decoding rejection sampling.
"""

sampler_output = []
for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups,

@ -701,4 +842,15 @@ def _build_sampler_output(
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
sampler_output.append(
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
return SamplerOutput(outputs=sampler_output)

# If not specified, store None values in SamplerOutput.
if on_device_tensors is not None:
sampled_token_probs, sampled_token_ids = on_device_tensors
else:
sampled_token_probs, sampled_token_ids = (None, None)

return SamplerOutput(
outputs=sampler_output,
sampled_token_probs=sampled_token_probs,
sampled_token_ids=sampled_token_ids,
)
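As a rough illustration of the new on_device_tensors argument (a sketch with made-up shapes, standalone rather than vLLM code): the first element is the probability tensor used for sampling and the second is the sampled token ids, both left on the device so spec decode can feed them straight into rejection sampling.

from typing import Optional, Tuple

import torch


def unpack_on_device_tensors(
    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Mirrors the branch above: callers that don't need on-device outputs
    # pass None and get (None, None) back.
    if on_device_tensors is None:
        return None, None
    sampled_token_probs, sampled_token_ids = on_device_tensors
    # Assumed shapes: probs [num_tokens, vocab_size], ids [num_tokens, 1].
    assert sampled_token_probs.shape[0] == sampled_token_ids.shape[0]
    return sampled_token_probs, sampled_token_ids


probs, ids = torch.rand(8, 32), torch.randint(0, 32, (8, 1))
p, t = unpack_on_device_tensors((probs, ids))
assert p is probs and t is ids
assert unpack_on_device_tensors(None) == (None, None)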

@ -6,8 +6,8 @@ import torch
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (get_all_seq_ids, maybe_mock_device_tensors,
nvtx_range, sampler_output_to_torch,
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
sampler_output_to_torch,
split_batch_by_proposal_len)
from vllm.worker.worker_base import WorkerBase

@ -72,10 +72,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
proposal_lens_list = proposals.proposal_lens.tolist()
proposal_token_ids_list = proposals.proposal_token_ids.tolist()

# Filter the list to ignore -1 proposals.
proposal_token_ids_list_without_skips = [
proposals for proposals in proposal_token_ids_list
if -1 not in proposals
]

(spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) = self._expand_batch(
seq_group_metadata_list=seq_group_metadata_list,
proposal_token_ids_list=proposal_token_ids_list,
proposal_token_ids_list=proposal_token_ids_list_without_skips,
proposal_lens_list=proposal_lens_list,
)
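A tiny worked example of the -1 filter above (toy values): rows containing -1 represent sequences that made no proposal this step, and they are dropped before the batch is expanded.

proposal_token_ids_list = [[5, 9, 2], [-1, -1, -1], [7, 7, 1]]

# Same comprehension as above, applied to made-up data.
proposal_token_ids_list_without_skips = [
    proposal for proposal in proposal_token_ids_list if -1 not in proposal
]

assert proposal_token_ids_list_without_skips == [[5, 9, 2], [7, 7, 1]]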

@ -89,7 +95,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
target_sampler_output = target_sampler_output[0]

all_tokens, all_probs = self._contract_batch(
original_bs=len(seq_group_metadata_list),
contracted_bs=len(seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,

@ -128,14 +134,21 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
select_proposal_len_zero=True)

target_seq_group_metadata_list = self._create_scoring_model_input(
spec_seqs, proposal_token_ids_list)
seq_group_metadata_list=spec_seqs,
proposal_token_ids=proposal_token_ids_list,
# NOTE: We determine the seq ids in the expanded batch using the
# full seq_group_metadata_list, instead of only spec_seqs.
target_seq_ids_iter=self._create_target_seq_id_iterator(
seq_ids=get_all_seq_ids(seq_group_metadata_list)),
)

num_scoring_tokens = len(target_seq_group_metadata_list)
target_seq_group_metadata_list.extend(non_spec_seqs)

return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens)

def _contract_batch(self, original_bs: int,
def _contract_batch(self, contracted_bs: int,
target_sampler_output: List[SamplerOutput],
proposals: SpeculativeProposals,
num_scoring_tokens: int, non_spec_indices: List[int],

@ -144,42 +157,41 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.

contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
"""

# We mock the device tensors until PR 7/9 is merged (e2e correctness).
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
maybe_mock_device_tensors(
sampler_output=target_sampler_output,
batch_size=len(non_spec_indices) + num_scoring_tokens,
vocab_size=self._vocab_size,
device=self._device,
)

(target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)

# Map distinct sequences used to score each token
# of shape [batch_size * (k + 1)] back to [batch_size, k + 1].
batch_size, k = proposals.proposal_token_ids.shape
expanded_batch_size, k = proposals.proposal_token_ids.shape

# The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for
# non-speculative sequences.
non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

target_token_ids = target_token_ids.squeeze().reshape(
batch_size, k + 1)
target_probs = target_probs.squeeze().reshape(batch_size, k + 1,
spec_expanded_bs, k + 1)
target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
self._vocab_size)

all_tokens = torch.full(size=(original_bs, k + 1),
all_tokens = torch.full(size=(contracted_bs, k + 1),
fill_value=-1,
device=self._device,
dtype=torch.long)
all_probs = torch.zeros(original_bs,
all_probs = torch.zeros(contracted_bs,
k + 1,
self._vocab_size,
device=self._device,
dtype=torch.float32)

if non_spec_indices:
all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs

if spec_indices:

@ -189,20 +201,22 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
return all_tokens, all_probs
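To make the contraction concrete, here is a toy shape walk-through (illustrative sizes only, standalone rather than vLLM code): each speculative sequence contributes k + 1 scored rows to the expanded batch, which are reshaped back into one row per original sequence, while -1 marks slots that only carry a non-speculative token.

import torch

k = 3                                   # proposal length
num_spec_seqs, num_non_spec = 4, 1      # sequences with / without speculation
contracted_bs = num_spec_seqs + num_non_spec
vocab_size = 32

# Expanded batch: k + 1 scored rows per speculative sequence.
spec_expanded_bs = num_spec_seqs * (k + 1)
target_token_ids = torch.randint(0, vocab_size, (spec_expanded_bs, 1))
target_probs = torch.rand(spec_expanded_bs, 1, vocab_size)

# Contract back to per-sequence shape [num_spec_seqs, k + 1(, vocab)].
spec_token_ids = target_token_ids.squeeze(-1).reshape(num_spec_seqs, k + 1)
spec_probs = target_probs.squeeze(1).reshape(num_spec_seqs, k + 1, vocab_size)
assert spec_probs.shape == (num_spec_seqs, k + 1, vocab_size)

# Final output spans the whole original batch; -1 fills the rows that will
# only receive a single non-speculative token.
all_tokens = torch.full((contracted_bs, k + 1), -1, dtype=torch.long)
all_tokens[:num_spec_seqs] = spec_token_ids
assert all_tokens.shape == (contracted_bs, k + 1)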

def _create_scoring_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]:
"""Given the original input sequences and proposed tokens from the draft
model, create a list of target sequences that can be used for scoring.

target_seq_ids_iter provides sequence ids for the expanded batch,
fulfilling the requirement that no seq id in the expanded batch is equal
to the seq id in the original batch.
"""

if not seq_group_metadata_list:
return []

target_seq_ids_iter = self._create_target_seq_id_iterator(
get_all_seq_ids(seq_group_metadata_list))

target_seq_group_metadata = list(
chain.from_iterable(
self._create_target_seq_group_metadata(

@ -24,9 +24,9 @@ class SpeculativeProposals:

def __repr__(self):
return (f"SpeculativeProposals("
f"proposal_token_ids={self.proposal_token_ids.shape}, "
f"proposal_token_ids={self.proposal_token_ids}, "
f"proposal_probs={self.proposal_probs.shape}, "
f"proposal_lens={self.proposal_lens.shape})")
f"proposal_lens={self.proposal_lens})")


@dataclass

@ -147,15 +147,16 @@ class AsyncMetricsCollector:
emitted_tokens = self._aggregate_num_emitted_tokens.item()
draft_tokens = self._aggregate_num_draft_tokens

num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k)
max_num_emitted_tokens = self.get_max_num_emitted_tokens(
draft_tokens, k)

if draft_tokens > 0:
draft_acceptance_rate = accepted_tokens / draft_tokens
else:
draft_acceptance_rate = float("nan")

if num_possible_tokens > 0:
system_efficiency = emitted_tokens / num_possible_tokens
if max_num_emitted_tokens > 0:
system_efficiency = emitted_tokens / max_num_emitted_tokens
else:
system_efficiency = float("nan")

@ -169,8 +170,22 @@ class AsyncMetricsCollector:
)

@staticmethod
def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int:
# Divide by k since batch size can be variable.
total_num_spec_seqs = draft_tokens / k
num_accepted_per_seq_if_all_accepted = k + 1
return int(total_num_spec_seqs / num_accepted_per_seq_if_all_accepted)
def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
"""Calculate the number of emitted tokens, assuming all tokens are
accepted.

This is equal to the number of sequences that have been speculated on,
times (speculation len + 1). The +1 comes from the bonus token.
"""
# Determine the number of sequences that have been speculated on. Since
# the batch size can be variable, we divide by k.
assert draft_tokens % k == 0
total_num_spec_seqs = draft_tokens // k

# A single sequence may emit k accepted tokens and one bonus token in
# the best case.
num_emitted_per_seq_if_all_accepted = k + 1

# The max num of emitted tokens is the number of speculated sequences
# times the max emitted per seq.
return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted
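A quick worked example of the metric above (numbers are illustrative): with k = 4 and 40 draft tokens, 10 sequences were speculated on, so at most 10 * (4 + 1) = 50 tokens could have been emitted, and system efficiency is the emitted count divided by that bound.

draft_tokens, k, emitted_tokens = 40, 4, 35

total_num_spec_seqs = draft_tokens // k                   # 10 sequences
max_num_emitted_tokens = total_num_spec_seqs * (k + 1)    # 10 * 5 = 50
system_efficiency = emitted_tokens / max_num_emitted_tokens

assert max_num_emitted_tokens == 50
assert abs(system_efficiency - 0.7) < 1e-9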

@ -6,8 +6,7 @@ import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.util import (maybe_mock_device_tensors,
sampler_output_to_torch)
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.worker.worker import Worker


@ -329,12 +328,15 @@ class DraftModelTop1Proposer(SpeculativeProposer):
"""
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty tensors.
proposal_tokens = torch.zeros(0,
max_proposal_len,
dtype=torch.long,
device=self._device)
proposal_probs = torch.zeros(0,
# In this case we return empty proposals.
proposal_tokens = torch.full(size=(
batch_size,
max_proposal_len,
),
fill_value=-1,
dtype=torch.long,
device=self._device)
proposal_probs = torch.zeros(batch_size,
max_proposal_len,
self._vocab_size,
dtype=torch.float32,

@ -345,17 +347,6 @@ class DraftModelTop1Proposer(SpeculativeProposer):
return proposal_tokens, proposal_probs, proposal_lens_tensor

sampler_output = maybe_sampler_output

# We mock the device tensors until PR 7/9 is merged (e2e correctness).
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
for step_output in sampler_output:
maybe_mock_device_tensors(
sampler_output=step_output,
batch_size=len(proposal_lens),
vocab_size=self._vocab_size,
device=self._device,
)

proposal_tokens, proposal_probs = sampler_output_to_torch(
sampler_output)
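The proposer change above keeps the batch dimension even when nothing is speculated: instead of zero-row tensors it returns [batch_size, max_proposal_len] tensors filled with -1, the same sentinel the scorer's "-1 proposal" filter looks for. A minimal sketch of that contract (toy values, standalone rather than vLLM code):

import torch

batch_size, max_proposal_len, vocab_size = 3, 4, 32

# No speculation this step: every row is a -1 "skip" proposal, but the
# batch dimension is preserved.
proposal_tokens = torch.full((batch_size, max_proposal_len),
                             fill_value=-1,
                             dtype=torch.long)
proposal_probs = torch.zeros(batch_size, max_proposal_len, vocab_size)

# Downstream, rows containing -1 are treated as "no proposal" and skipped.
skipped = [row for row in proposal_tokens.tolist() if -1 in row]
assert len(skipped) == batch_size
assert proposal_probs.shape == (batch_size, max_proposal_len, vocab_size)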

@ -111,6 +111,32 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
device=self.device,
vocab_size=self._vocab_size)

self._configure_model_sampler_for_spec_decode()

def _configure_model_sampler_for_spec_decode(self):
"""Configure model sampler to emit GPU tensors. This allows spec decode
to keep data on device without transferring to CPU and serializing,
which significantly reduces overhead of rejection sampling.

NOTE(cade): This breaks abstraction boundaries pretty badly. The better
design is to have the "move to CPU and serialize" sampling decision be
done outside of the model/sampler; this way the "last-mile" worker
object which interfaces with the scheduler can serialize and incur the
performance hit as necessary. This allows us to run the worker several
iterations in a row without incurring the "move to CPU and serialize"
performance penalty.

Since this requires a large change to vLLM, we defer it to later and
temporarily accept this broken abstraction boundary.

NOTE(cade): This will require a special check if the proposer worker
does not have a sampler (e.g. ngram speculation).
"""
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
) = True
(self.proposer_worker.model_runner.model.sampler.
include_gpu_probs_tensor) = True

def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of cache blocks to use.

@ -286,15 +312,26 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
select_proposal_len_zero=True)
original_indices = spec_indices + non_spec_indices

proposal_probs = proposal_scores.probs[spec_indices, :-1]
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
# Get probabilities of target model, excluding bonus token.
proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]

# Get non-speculative sampled tokens from target model.
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

# Get bonus tokens from target model.
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]

# Get probabilities according to proposal method.
proposal_probs = proposals.proposal_probs[spec_indices]

# Get proposed tokens.
proposal_token_ids = proposals.proposal_token_ids[spec_indices]

accepted_token_ids = self.rejection_sampler(
proposal_probs,
bonus_token_ids,
proposals.proposal_probs,
proposals.proposal_token_ids,
target_probs=proposal_verifier_probs,
bonus_token_ids=bonus_token_ids,
draft_probs=proposal_probs,
draft_token_ids=proposal_token_ids,
)

# Append output tokens from non-speculative sequences to
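For orientation, a toy sketch of the tensor shapes handed to the rejection sampler above, for batch_size speculative sequences and proposal length k (sizes are illustrative; only the keyword names come from the call above, and the -1 padding convention for rejected positions is an assumption):

import torch

batch_size, k, vocab_size = 4, 3, 32

target_probs = torch.rand(batch_size, k, vocab_size)        # target scores, bonus slot excluded
bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1))
draft_probs = torch.rand(batch_size, k, vocab_size)         # proposal distribution
draft_token_ids = torch.randint(0, vocab_size, (batch_size, k))

assert target_probs.shape == draft_probs.shape == (batch_size, k, vocab_size)
assert bonus_token_ids.shape == (batch_size, 1)
assert draft_token_ids.shape == (batch_size, k)

# The accepted output has one slot per proposed token plus the bonus slot,
# i.e. shape [batch_size, k + 1], with -1 filling positions after the
# first rejection.
expected_output_shape = (batch_size, k + 1)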