[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)

Cade Daniel 2024-04-23 01:02:36 -07:00 committed by GitHub
parent 050f285ff6
commit 62b8aebc6f
22 changed files with 1164 additions and 175 deletions

View File

@ -91,12 +91,16 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
bonus_token_ids,
)
# Bonus tokens are currently disabled. Verify they're set to -1.
# See https://github.com/vllm-project/vllm/issues/4212
expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1
if which_tokens_accepted == "all_tokens_accepted":
# Expect all tokens to be equal to draft tokens.
assert torch.equal(output_token_ids[:, :-1], draft_token_ids)
# Expect all bonus tokens to be included.
assert torch.equal(output_token_ids[:, -1:], bonus_token_ids)
assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)
elif which_tokens_accepted == "no_tokens_accepted":
# Expect first token to be equal to recovered tokens.
assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
@ -106,7 +110,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
torch.ones_like(output_token_ids[:, 1:]) * -1)
elif which_tokens_accepted == "some_tokens_accepted":
recovered_plus_bonus = torch.cat(
(recovered_token_ids, bonus_token_ids), dim=-1)
(recovered_token_ids, expected_bonus_token_ids), dim=-1)
# Assert first rejected token is a recovered token or bonus token.
assert torch.equal(
recovered_plus_bonus[torch.arange(0, batch_size),

View File

@ -636,7 +636,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
def mock_sample(probs, *args, **kwargs):
nonlocal sample_probs
sample_probs = probs
return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
for prob in probs], None)
with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

View File

View File

@ -1,3 +1,5 @@
from typing import List, Tuple
import pytest
from tests.conftest import cleanup
@ -6,28 +8,34 @@ from vllm.model_executor.utils import set_random_seed
@pytest.fixture
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, seed):
return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
def baseline_llm_generator(request, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
seed):
return create_llm_generator("baseline", request, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, seed)
@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, seed):
return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, seed)
return create_llm_generator("test", request, common_llm_kwargs,
per_test_common_llm_kwargs, test_llm_kwargs,
seed)
def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
distinct_llm_kwargs, seed):
def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
per_test_common_llm_kwargs, distinct_llm_kwargs,
seed):
kwargs = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**distinct_llm_kwargs,
}
test_name = request.node.name
def generator_inner():
print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
llm = LLM(**kwargs)
set_random_seed(seed)
@ -36,6 +44,23 @@ def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
del llm
cleanup()
def generator_outer():
for llm in generator_inner():
yield llm
del llm
return generator_outer
def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]:
tokens = []
token_ids = []
for llm in llm_generator():
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]
del llm
return tokens, token_ids

View File

@ -0,0 +1,169 @@
import pytest
from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Expect failure as spec decode not supported by
# Ray backend.
"worker_use_ray": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_ray(test_llm_generator):
"""Verify that speculative decoding with Ray fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(AssertionError,
match="Speculative decoding not yet supported for "):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"enable_chunked_prefill": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
"""Verify that speculative decoding with chunked prefill fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(ValueError,
match="Speculative decoding and chunked prefill"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "meta-llama/Llama-2-7b-chat-hf",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Speculative max model len > overridden max model len should raise.
"max_model_len": 128,
"speculative_max_model_len": 129,
},
{
# Speculative max model len > draft max model len should raise.
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
"speculative_max_model_len": 2048 + 1,
},
{
# Speculative max model len > target max model len should raise.
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12
"speculative_max_model_len": 4096 + 1,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
"""Verify that speculative decoding validates speculative_max_model_len.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(ValueError, match="cannot be larger than"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_block_manager_v1(test_llm_generator):
"""Verify that speculative decoding with block manager v1 fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(ValueError,
match="Speculative decoding requires usage of the V2"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)

View File

@ -1,11 +1,42 @@
"""The tests in this file verify end-to-end speculative decoding correctness.
This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality. This gives us good coverage of temp=0.
For temp>0, we rely on unit tests on the rejection sampler to verify that the
output distribution is the same with spec decode vs. no spec decode (this would
be prohibitively expensive to run with a real model).
NOTE: Speculative decoding's distribution equality requires that the measured
distributions of the target model and proposal model be deterministic given the
same input. vLLM largely guarantees this.
@cadedaniel has seen cases where the output probabilities of a draft/target
model change slightly with certain batch sizes or prompts, even with Torch
determinism flags set. It is unclear if this is a bug in vLLM, due to non-
determinism in on-device batched operations, a bug in vLLM's spec decode
implementation, or the "hardware numerics" limitations. Either way, rejection
sampling ensures the output distribution matches the target model, but it breaks
greedy-equality tests for those batch sizes/prompts.
"""
from itertools import cycle
from typing import List, Tuple
import pytest
from transformers import AutoTokenizer
from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize(
"common_llm_kwargs",
@ -14,9 +45,6 @@ from vllm import SamplingParams
# Note this is repeated in the test body to initialize a tokenizer.
"model": "JackFram/llama-68m",
# Skip real loading for fast test.
"load_format": "dummy",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
@ -31,22 +59,15 @@ from vllm import SamplingParams
"num_speculative_tokens": 5,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 1,
},
{
# No spec decode.
# Verify the detokenizer assertions in the test work when spec
# decode is disabled.
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [1])
# NOTE: We should run more permutations of this test (more BS, more seeds). But
# because our spec decode generates gibberish token ids, the likelihood of
# emitting an invalid token combination is nontrivial. This causes divergence in
# behavior of vLLM detokenization vs. hf tokenizer, for example when two "utf-
# start" bytes are emitted.
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
def test_spec_decode_e2e_with_detokenization(test_llm_generator,
batch_size: int):
"""Run generation with speculative decoding on a batch. Verify the engine
generates the correct number of tokens (via ignore_eos=True), and that the
detokenization matches HF transformers.
@ -67,8 +88,6 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
skip_special_tokens=True,
spaces_between_special_tokens=False,
)
batch_tokens, batch_token_ids = get_output_from_llm_generator(
@ -77,9 +96,10 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
# Expect a generation for each prompt in the batch.
assert len(batch_token_ids) == len(prompts)
# Expect each generation to have expected number of tokens (note
# ignore_eos=True).
assert all(len(token_ids) == output_len for token_ids in batch_token_ids)
# Expect each generation to have expected number of tokens (note ignore_eos
# is True).
assert [len(token_ids)
for token_ids in batch_token_ids] == ([output_len] * batch_size)
# Expect detokenized string to match.
tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
@ -92,13 +112,293 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# Try two different tiny base models.
# Note that one is equal to the draft model, the other isn't.
{
"model": "JackFram/llama-68m",
},
{
"model": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use long output len for the small model test.
1536,
])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality on a tiny model with batch size of one.
# Skip real loading for fast test.
"load_format": "dummy",
Since this test is cheaper than other e2e correctness tests, we generate
with a higher output_len.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# Try two different tiny base models.
# Note that one is equal to the draft model, the other isn't.
{
"model": "JackFram/llama-68m",
},
{
"model": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
256,
])
@pytest.mark.parametrize("batch_size", [64])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality on a tiny model and large batch size.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# Try two different tiny base models.
# Note that one is equal to the draft model, the other isn't.
{
"model": "JackFram/llama-68m",
},
{
"model": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("max_output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
baseline_llm_generator, test_llm_generator, batch_size: int,
max_output_len: int):
"""Verify greedy equality on a tiny model, with a large batch size, and when
sampling respects the EOS token.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len=False)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# A "real" model (not tiny).
"model": "meta-llama/Llama-2-7b-chat-hf",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
"output_len",
[
# Use decently long output len for a high quality test.
256,
])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality on a "real" model and batch size of 1. This is
separate from large BS tests to make identifying the source of bugs easier.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# A "real" model (not tiny).
"model": "meta-llama/Llama-2-7b-chat-hf",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
64,
])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality with a "real" model on a nontrivial batch size.
This is the closest test to a real production workload.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
256,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_with_preemption(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
@ -109,43 +409,189 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# As of this writing, vLLM only compiles with these 3 block sizes by
# default.
{
# Expect failure as spec decode not supported by
# Ray backend.
"worker_use_ray": True,
"block_size": 8,
},
{
"block_size": 16,
},
{
"block_size": 32,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail(test_llm_generator):
"""Verify that speculative decoding with Ray fails.
def test_spec_decode_different_block_size(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality over different block sizes.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
},
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# This must be a good bit larger than speculative_max_model_len so that
# we can test the case where all seqs are skipped, but still small enough
# to keep the test fast.
64,
])
@pytest.mark.parametrize("seed", [1])
def test_skip_speculation(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality when some (or all) sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
are skipped in speculative decoding.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
}
# Try a range of common k, as well as large speculation.
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
"""Verify that speculative decoding produces exact equality to without spec
decode with many different values of k.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
def run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
print_tokens: bool = False):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
# If the test requires that we generate max_output_len tokens, then set the
# sampling params to ignore the eos token.
ignore_eos = force_output_len
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
)
with pytest.raises(AssertionError,
match="Speculative decoding not yet supported for "):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)
(baseline_batch_tokens,
baseline_batch_token_ids) = get_output_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]:
for llm in llm_generator:
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]
del llm
assert len(baseline_batch_token_ids) == len(prompts)
assert len(spec_batch_token_ids) == len(prompts)
return tokens, token_ids
for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
spec_tokens) in enumerate(
zip(baseline_batch_token_ids, baseline_batch_tokens,
spec_batch_token_ids, spec_batch_tokens)):
if print_tokens:
print(f'{i=} {baseline_tokens=}')
print(f'{i=} {spec_tokens=}')
print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids

View File

@ -119,7 +119,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
num_draft_tokens = 0
k = 5
num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens(
max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens(
num_draft_tokens, k)
rej_sampler = MagicMock()
@ -153,7 +153,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
assert (metrics.draft_acceptance_rate == num_accepted_tokens /
num_draft_tokens)
assert (metrics.system_efficiency == num_emitted_tokens /
num_possible_tokens)
max_num_emitted_tokens)
else:
assert math.isnan(metrics.draft_acceptance_rate)
assert math.isnan(metrics.system_efficiency)

View File

@ -344,8 +344,8 @@ def test_draft_proposals_no_speculations():
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([0, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([0, k])
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]

View File

@ -1,4 +1,5 @@
import random
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
@ -62,8 +63,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
"""Verify SpecDecodeWorker calls the target model with correct
inputs. Everything else is mocked out.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@ -144,8 +145,10 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
target_worker = mock_worker(vocab_size=vocab_size)
draft_worker = mock_worker(cls=MultiStepWorker,
vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@ -202,17 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
num_lookahead_slots=k)
assert len(rejection_sampler.call_args_list) == 1
args, _ = rejection_sampler.call_args_list[0]
(actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs,
actual_proposal_token_ids) = args
_, kwargs = rejection_sampler.call_args_list[0]
actual = SimpleNamespace(**kwargs)
assert torch.equal(actual_bonus_token_ids,
assert torch.equal(actual.bonus_token_ids,
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
assert torch.equal(
actual_proposal_scores,
actual.target_probs,
target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
assert torch.equal(actual_proposal_token_ids, proposal_token_ids)
assert torch.equal(actual_proposal_probs, proposal_probs)
assert torch.equal(actual.draft_token_ids, proposal_token_ids)
assert torch.equal(actual.draft_probs, proposal_probs)
@pytest.mark.parametrize('k', [1, 2, 6])
@ -224,8 +226,10 @@ def test_correctly_formats_output(k: int, batch_size: int):
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
target_worker = mock_worker(vocab_size=vocab_size)
draft_worker = mock_worker(cls=MultiStepWorker,
vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@ -336,8 +340,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
target_worker = mock_worker(vocab_size=vocab_size)
draft_worker = mock_worker(cls=MultiStepWorker,
vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@ -500,8 +506,8 @@ def test_init_device():
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)

View File

@ -63,11 +63,14 @@ def create_execute_model_data(
def mock_worker(cls=None,
vocab_size: int = 30_000,
max_model_len: int = 2048,
rank: int = 0) -> MagicMock:
rank: int = 0,
use_spec: bool = True) -> MagicMock:
if cls is None:
cls = Worker
worker = MagicMock(spec=cls)
spec = cls if use_spec else None
worker = MagicMock(spec=spec)
worker.vocab_size = vocab_size
worker.max_model_len = max_model_len
worker.rank = rank

View File

@ -655,6 +655,9 @@ class SpeculativeConfig:
target_dtype: str,
speculative_model: Optional[str],
num_speculative_tokens: Optional[int],
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
use_v2_block_manager: bool,
) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None.
@ -672,6 +675,15 @@ class SpeculativeConfig:
model, if provided.
num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided.
speculative_max_model_len (Optional[int]): The maximum model len of
the speculative model. Used when testing the ability to skip
speculation for some sequences.
enable_chunked_prefill (bool): Whether vLLM is configured to use
chunked prefill or not. Used for raising an error since it's not
yet compatible with spec decode.
use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode.
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
@ -690,12 +702,21 @@ class SpeculativeConfig:
assert (speculative_model is not None
and num_speculative_tokens is not None)
if enable_chunked_prefill:
raise ValueError(
"Speculative decoding and chunked prefill are "
f"currently mutually exclusive ({enable_chunked_prefill=}).")
if not use_v2_block_manager:
raise ValueError(
"Speculative decoding requires usage of the V2 "
"block manager. Enable it with --use-v2-block-manager.")
# TODO: The user should be able to specify revision/quantization/max
# model len for the draft model. It is not currently supported.
draft_revision = None
draft_code_revision = None
draft_quantization = None
draft_max_model_len = None
draft_model_config = ModelConfig(
model=speculative_model,
@ -707,7 +728,7 @@ class SpeculativeConfig:
revision=draft_revision,
code_revision=draft_code_revision,
tokenizer_revision=target_model_config.tokenizer_revision,
max_model_len=draft_max_model_len,
max_model_len=None,
quantization=draft_quantization,
enforce_eager=target_model_config.enforce_eager,
max_context_len_to_capture=target_model_config.
@ -715,6 +736,13 @@ class SpeculativeConfig:
max_logprobs=target_model_config.max_logprobs,
)
draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,
draft_model_config.max_model_len,
target_model_config.max_model_len,
))
draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config))
@ -725,6 +753,41 @@ class SpeculativeConfig:
num_speculative_tokens,
)
@staticmethod
def _maybe_override_draft_max_model_len(
speculative_max_model_len: Optional[int],
draft_max_model_len: int,
target_max_model_len: int,
) -> int:
"""Determine the max sequence len for the draft model. This is usually
the draft_max_model_len, but may be the target_max_model_len if it is
less than the draft_max_model_len, or may be speculative_max_model_len
if it is specified.
This is necessary so that sequences do not exceed the capacity of the
draft model or the target model.
speculative_max_model_len is mainly used for testing that sequences can
skip speculation.
"""
if speculative_max_model_len is not None:
if speculative_max_model_len > draft_max_model_len:
raise ValueError(f"{speculative_max_model_len=} cannot be "
f"larger than {draft_max_model_len=}")
if speculative_max_model_len > target_max_model_len:
raise ValueError(f"{speculative_max_model_len=} cannot be "
f"larger than {target_max_model_len=}")
return speculative_max_model_len
return min(
draft_max_model_len,
target_max_model_len,
)
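For orientation, a hedged sketch of how this resolution plays out with the model lens used in the tests above (2048 for JackFram/llama-68m, 4096 for meta-llama/Llama-2-7b-chat-hf); the asserts are illustrative and not part of the diff:

from vllm.config import SpeculativeConfig

# With no override, the smaller of the draft/target max model lens wins.
assert SpeculativeConfig._maybe_override_draft_max_model_len(
    None, draft_max_model_len=2048, target_max_model_len=4096) == 2048

# An explicit speculative_max_model_len is honored when it fits both models,
# e.g. the artificially low value of 32 used by test_skip_speculation.
assert SpeculativeConfig._maybe_override_draft_max_model_len(
    32, draft_max_model_len=2048, target_max_model_len=4096) == 32

# A value larger than either model len raises ValueError ("cannot be larger than").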
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig) -> ParallelConfig:

View File

@ -73,6 +73,7 @@ class EngineArgs:
# Speculative decoding configuration.
speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
def __post_init__(self):
if self.tokenizer is None:
@ -237,7 +238,7 @@ class EngineArgs:
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32, 128],
choices=[8, 16, 32],
help='Token block size for contiguous chunks of '
'tokens.')
@ -420,17 +421,25 @@ class EngineArgs:
parser.add_argument(
'--speculative-model',
type=str,
default=None,
default=EngineArgs.speculative_model,
help=
'The name of the draft model to be used in speculative decoding.')
parser.add_argument(
'--num-speculative-tokens',
type=int,
default=None,
default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-max-model-len',
type=int,
default=EngineArgs.speculative_max_model_len,
help='The maximum sequence length supported by the '
'draft model. Sequences over this length will skip '
'speculation.')
parser.add_argument('--model-loader-extra-config',
type=str,
default=EngineArgs.model_loader_extra_config,
@ -481,6 +490,9 @@ class EngineArgs:
target_dtype=self.dtype,
speculative_model=self.speculative_model,
num_speculative_tokens=self.num_speculative_tokens,
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager,
)
scheduler_config = SchedulerConfig(

View File

@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
SequenceGroup)
SequenceGroup, SequenceStage)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
@ -480,9 +480,12 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
# If uncomputed tokens > 0, it means prefill is chunked.
# We don't need to process outputs in that case.
if seq_group.get_num_uncomputed_tokens() == 0:
# If all sequences in the sequence group are in DECODE, then we can
# process the output tokens. Otherwise, they are (chunked) prefill
# samples and should not be processed.
stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
if all(stage == SequenceStage.DECODE for stage in stages):
self.output_processor.process_outputs(seq_group, outputs)
# Free the finished sequence groups.
@ -569,7 +572,8 @@ class LLMEngine:
# Log stats.
if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs))
self.stat_logger.log(
self._get_stats(scheduler_outputs, model_output=output))
return request_outputs
@ -578,9 +582,18 @@ class LLMEngine:
if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs=None))
def _get_stats(self,
scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
"""Get Stats to be Logged to Prometheus."""
def _get_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs],
model_output: Optional[List[SamplerOutput]] = None) -> Stats:
"""Get Stats to be Logged to Prometheus.
Args:
scheduler_outputs: Optional, used to populate metrics related to
the scheduled batch.
model_output: Optional, used to emit speculative decoding metrics
which are created by the workers.
"""
now = time.time()
# KV Cache Usage in %.
@ -637,6 +650,14 @@ class LLMEngine:
time_to_first_tokens = time_last_iters if prompt_run else []
time_per_output_tokens = [] if prompt_run else time_last_iters
# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
if model_output and (model_output[0].spec_decode_worker_metrics
is not None):
spec_decode_metrics = model_output[0].spec_decode_worker_metrics
else:
spec_decode_metrics = None
return Stats(
now=now,
num_running=num_running,
@ -649,6 +670,7 @@ class LLMEngine:
time_to_first_tokens=time_to_first_tokens,
time_per_output_tokens=time_per_output_tokens,
time_e2e_requests=time_e2e_requests,
spec_decode_metrics=spec_decode_metrics,
)
def add_lora(self, lora_request: LoRARequest) -> bool:

View File

@ -1,6 +1,6 @@
import time
from dataclasses import dataclass
from typing import Dict, List, Protocol
from typing import TYPE_CHECKING, Dict, List, Optional, Protocol
import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
@ -8,6 +8,9 @@ from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
logger = init_logger(__name__)
disable_created_metrics()
@ -118,6 +121,8 @@ class Stats:
time_per_output_tokens: List[float]
time_e2e_requests: List[float]
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
class SupportsMetricsInfo(Protocol):
@ -235,3 +240,19 @@ class StatLogger:
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now
if stats.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
stats.spec_decode_metrics))
def _format_spec_decode_metrics_str(
self, metrics: "SpecDecodeWorkerMetrics") -> str:
return ("Speculative metrics: "
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
f"System efficiency: {metrics.system_efficiency:.3f}, "
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
f"Number of accepted tokens: {metrics.accepted_tokens}, "
f"Number of draft tokens tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")

View File

@ -83,6 +83,7 @@ class GPUExecutor(ExecutorBase):
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
# TODO allow draft-model specific load config.
load_config=self.load_config,
local_rank=0,
rank=0,

View File

@ -144,6 +144,7 @@ class RejectionSampler(nn.Module):
recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(recovered_probs,
num_samples=1).reshape(
batch_size, k)
@ -307,6 +308,12 @@ class RejectionSampler(nn.Module):
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1)
# We disable bonus tokens because it causes corrupt KV cache for
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
output_with_bonus_tokens[:, -1] = -1
# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
recovered_token_ids.mul(after_false_mask))

View File

@ -35,6 +35,14 @@ class Sampler(nn.Module):
in logits for each token in the input prompt.
"""
def __init__(self):
super().__init__()
# Whether or not the SamplerOutput should have on-device tensors
# containing the sampled token ids and probabilities. This is used by
# speculative decoding.
self.include_gpu_probs_tensor = False
def forward(
self,
logits: torch.Tensor,
@ -79,13 +87,45 @@ class Sampler(nn.Module):
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
sample_results = _sample(probs, logprobs, sampling_metadata,
sampling_tensors)
sample_results, maybe_sampled_tokens_tensor = _sample(
probs,
logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
)
if self.include_gpu_probs_tensor:
assert maybe_sampled_tokens_tensor is not None
sampled_tokens_tensor = maybe_sampled_tokens_tensor
on_device_tensors = (probs, sampled_tokens_tensor)
else:
on_device_tensors = None
# Get the logprobs query results.
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, sampling_metadata, sample_results)
return _build_sampler_output(sample_results, sampling_metadata,
prompt_logprobs, sample_logprobs)
return _build_sampler_output(sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors)
@property
def _should_modify_greedy_probs_inplace(self) -> bool:
"""Whether or not the sampler should modify the probability distribution
of greedily-sampled tokens such that multinomial sampling would sample
the greedily-sampled token.
In other words, if True then we set the probability of the greedily-
sampled token to 1.
This is used by speculative decoding, which requires that the sampling
method be encoded into the probability distribution.
"""
# Modify greedy probs if include_gpu_probs_tensor is set.
return self.include_gpu_probs_tensor
def _get_bin_counts_and_mask(
@ -359,7 +399,9 @@ def _sample_with_torch(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> List[Tuple[List[int], List[int]]]:
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
@ -371,6 +413,15 @@ def _sample_with_torch(
sample_metadata = {}
multinomial_samples = {}
# Create output tensor for sampled token ids.
if include_gpu_probs_tensor:
sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
1,
dtype=torch.long,
device=logprobs.device)
else:
sampled_token_ids_tensor = None
# Counterintuitively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType:
@ -383,9 +434,25 @@ def _sample_with_torch(
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
is_prompts, sample_indices)
long_sample_indices = sample_indices.long()
if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[sample_indices.long()],
greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1)
if include_gpu_probs_tensor:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[
long_sample_indices] = greedy_samples.unsqueeze(-1)
if modify_greedy_probs:
# If required, modify the probabilities such that sampling from
# the modified distribution would always sample the argmax
# token id.
_modify_greedy_probs_inplace(logprobs, probs,
long_sample_indices,
greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of_in_batch = 1
for seq_group, is_prompt in zip(seq_groups, is_prompts):
@ -397,15 +464,23 @@ def _sample_with_torch(
"seq_groups": seq_groups,
"generators": sampling_metadata.generators,
}
multinomial_samples[sampling_type] = _multinomial(
probs[sample_indices.long()], max_best_of_in_batch,
probs[long_sample_indices], max_best_of_in_batch,
**seeded_args)
if include_gpu_probs_tensor:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[
long_sample_indices] = multinomial_samples[sampling_type]
elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")
# GPU<->CPU sync happens in the loop below.
# This also converts the sample output to Python objects.
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
@ -427,7 +502,7 @@ def _sample_with_torch(
sample_results_dict[i]
for i in range(len(sampling_metadata.seq_groups))
]
return sample_results
return sample_results, sampled_token_ids_tensor
def _sample_with_triton_kernel(
@ -511,12 +586,17 @@ def _sample_with_triton_kernel(
def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
return _sample_with_torch(probs, logprobs, sampling_metadata)
probs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
return _sample_with_torch(
probs,
logprobs,
sampling_metadata,
include_gpu_probs_tensor=include_gpu_probs_tensor,
modify_greedy_probs=modify_greedy_probs,
)
# TODO: Enable once Triton kernel & associated code is faster.
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
@ -680,12 +760,73 @@ def _get_logprobs(
return result_prompt_logprobs, result_sample_logprobs
def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
sample_indices: torch.Tensor,
greedy_samples: torch.Tensor) -> None:
"""Modify the probability distributions of the greedily-sampled tokens such
that each sampled token has a "probability" of 1.0. This is required by
speculative decoding, which depends on the sampling method being encoded
within the probability distribution for correctness.
# Why do we only need to do this for greedy sampling?
vLLM's sampler performs the following steps for greedy or multinomial
(random) sampling:
1. Get logits from model.
2. Modify logits according to per-sequence sampling parameters.
- Multiply by temperature, top-k and top-p masking, penalize tokens
according to their frequency, etc.
3. Sample a token.
- Random sampling simply samples from the modified probability
distribution.
- Greedy sampling performs `argmax` to obtain the token with the
highest likelihood.
Ignoring greedy sampling for a moment, the computed probability distribution
has the following property: if we sample from it independently, each token is
drawn with the same frequency with which the Sampler would emit it. In other
words, for tokens sampled with vLLM's random SamplingType, the computed
probability distribution completely encodes the sampling methodology.
Greedy sampling does not normally have this property. vLLM modifies logits
according to sampling params, then performs `argmax`, then returns the
sampled token and the computed probability distribution. If we sample from
the distribution, we'll find the likelihood of the greedily-sampled token
is not always 1.0.
Since lossless speculative decoding requires that the sampling methodology
be encoded within the probability distribution, we are motivated to modify
the probability distribution such that the sampled token has probability 1
when speculative decoding is used.
NOTE: Alternatively, we could use an extremely low temperature to achieve
greedy sampling using multinomial computation and unite the codepaths. This
has implications for the overall design of the sampler, e.g. how to record
accurate logprobs for the user, so this improvement is deferred to later.
"""
logprobs[sample_indices, :] = -float('inf')
logprobs[sample_indices, greedy_samples] = 0.0
probs[sample_indices, :] = 0
probs[sample_indices, greedy_samples] = 1.0
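A minimal sketch (not part of the diff) of the property this function provides, applying the same in-place updates to a toy one-row distribution and checking that multinomial sampling then reproduces the greedy choice:

import torch

probs = torch.tensor([[0.6, 0.3, 0.1]])       # hypothetical 1 x vocab distribution
logprobs = probs.log()
sample_indices = torch.tensor([0])
greedy_samples = probs.argmax(dim=-1)         # tensor([0])

logprobs[sample_indices, :] = -float('inf')
logprobs[sample_indices, greedy_samples] = 0.0
probs[sample_indices, :] = 0
probs[sample_indices, greedy_samples] = 1.0

# Sampling from the modified distribution always returns the argmax token.
assert int(torch.multinomial(probs[0], num_samples=1)) == int(greedy_samples)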
def _build_sampler_output(
sample_results: List[Tuple[List[int], List[int]]],
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs],
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> SamplerOutput:
"""Construct Python objects with the output of sampling.
Args:
on_device_tensors: Tuple containing on-device tensors with the
probabilities used in sampling and the sampled token ids. This
allows post-processing without copies to CPU/serialization, e.g. in
speculative decoding rejection sampling.
"""
sampler_output = []
for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
@ -701,4 +842,15 @@ def _build_sampler_output(
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
sampler_output.append(
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
return SamplerOutput(outputs=sampler_output)
# If not specified, store None values in SamplerOutput.
if on_device_tensors is not None:
sampled_token_probs, sampled_token_ids = on_device_tensors
else:
sampled_token_probs, sampled_token_ids = (None, None)
return SamplerOutput(
outputs=sampler_output,
sampled_token_probs=sampled_token_probs,
sampled_token_ids=sampled_token_ids,
)

View File

@ -6,8 +6,8 @@ import torch
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (get_all_seq_ids, maybe_mock_device_tensors,
nvtx_range, sampler_output_to_torch,
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
sampler_output_to_torch,
split_batch_by_proposal_len)
from vllm.worker.worker_base import WorkerBase
@ -72,10 +72,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
proposal_lens_list = proposals.proposal_lens.tolist()
proposal_token_ids_list = proposals.proposal_token_ids.tolist()
# Filter the list to ignore -1 proposals.
proposal_token_ids_list_without_skips = [
proposals for proposals in proposal_token_ids_list
if -1 not in proposals
]
(spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) = self._expand_batch(
seq_group_metadata_list=seq_group_metadata_list,
proposal_token_ids_list=proposal_token_ids_list,
proposal_token_ids_list=proposal_token_ids_list_without_skips,
proposal_lens_list=proposal_lens_list,
)
@ -89,7 +95,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
target_sampler_output = target_sampler_output[0]
all_tokens, all_probs = self._contract_batch(
original_bs=len(seq_group_metadata_list),
contracted_bs=len(seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
@ -128,14 +134,21 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
select_proposal_len_zero=True)
target_seq_group_metadata_list = self._create_scoring_model_input(
spec_seqs, proposal_token_ids_list)
seq_group_metadata_list=spec_seqs,
proposal_token_ids=proposal_token_ids_list,
# NOTE: We determine the seq ids in the expanded batch using the
# full seq_group_metadata_list, instead of only spec_seqs.
target_seq_ids_iter=self._create_target_seq_id_iterator(
seq_ids=get_all_seq_ids(seq_group_metadata_list)),
)
num_scoring_tokens = len(target_seq_group_metadata_list)
target_seq_group_metadata_list.extend(non_spec_seqs)
return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens)
def _contract_batch(self, original_bs: int,
def _contract_batch(self, contracted_bs: int,
target_sampler_output: List[SamplerOutput],
proposals: SpeculativeProposals,
num_scoring_tokens: int, non_spec_indices: List[int],
@ -144,42 +157,41 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
contracted_bs is the original batch size, i.e. the batch size that
target_sampler_output will be contracted to.
"""
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
maybe_mock_device_tensors(
sampler_output=target_sampler_output,
batch_size=len(non_spec_indices) + num_scoring_tokens,
vocab_size=self._vocab_size,
device=self._device,
)
(target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)
# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
batch_size, k = proposals.proposal_token_ids.shape
expanded_batch_size, k = proposals.proposal_token_ids.shape
# The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for
# non-speculative sequences.
non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
target_token_ids = target_token_ids.squeeze().reshape(
batch_size, k + 1)
target_probs = target_probs.squeeze().reshape(batch_size, k + 1,
spec_expanded_bs, k + 1)
target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
self._vocab_size)
all_tokens = torch.full(size=(original_bs, k + 1),
all_tokens = torch.full(size=(contracted_bs, k + 1),
fill_value=-1,
device=self._device,
dtype=torch.long)
all_probs = torch.zeros(original_bs,
all_probs = torch.zeros(contracted_bs,
k + 1,
self._vocab_size,
device=self._device,
dtype=torch.float32)
if non_spec_indices:
all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
if spec_indices:
@ -192,17 +204,19 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]:
"""Given the original input sequences and proposed tokens from the draft
model, create a list of target sequences that can be used for scoring.
target_seq_ids_iter provides sequence ids for the expanded batch,
fulfilling the requirement that no seq id in the expanded batch is equal
to any seq id in the original batch.
"""
if not seq_group_metadata_list:
return []
target_seq_ids_iter = self._create_target_seq_id_iterator(
get_all_seq_ids(seq_group_metadata_list))
target_seq_group_metadata = list(
chain.from_iterable(
self._create_target_seq_group_metadata(

View File

@ -24,9 +24,9 @@ class SpeculativeProposals:
def __repr__(self):
return (f"SpeculativeProposals("
f"proposal_token_ids={self.proposal_token_ids.shape}, "
f"proposal_token_ids={self.proposal_token_ids}, "
f"proposal_probs={self.proposal_probs.shape}, "
f"proposal_lens={self.proposal_lens.shape})")
f"proposal_lens={self.proposal_lens})")
@dataclass

View File

@ -147,15 +147,16 @@ class AsyncMetricsCollector:
emitted_tokens = self._aggregate_num_emitted_tokens.item()
draft_tokens = self._aggregate_num_draft_tokens
num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k)
max_num_emitted_tokens = self.get_max_num_emitted_tokens(
draft_tokens, k)
if draft_tokens > 0:
draft_acceptance_rate = accepted_tokens / draft_tokens
else:
draft_acceptance_rate = float("nan")
if num_possible_tokens > 0:
system_efficiency = emitted_tokens / num_possible_tokens
if max_num_emitted_tokens > 0:
system_efficiency = emitted_tokens / max_num_emitted_tokens
else:
system_efficiency = float("nan")
@ -169,8 +170,22 @@ class AsyncMetricsCollector:
)
@staticmethod
def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int:
# Divide by k since batch size can be variable.
total_num_spec_seqs = draft_tokens / k
num_accepted_per_seq_if_all_accepted = k + 1
return int(total_num_spec_seqs / num_accepted_per_seq_if_all_accepted)
def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
"""Calculate the number of emitted tokens, assuming all tokens are
accepted.
This is equal to the number of sequences that have been speculated on,
times (speculation len + 1). The +1 comes from the bonus token.
"""
# Determine the number of sequences that have been speculated on. Since
# the batch size can be variable, we divide by k.
assert draft_tokens % k == 0
total_num_spec_seqs = draft_tokens // k
# A single sequence may emit k accepted tokens and one bonus token in
# the best case.
num_emitted_per_seq_if_all_accepted = k + 1
# The max num of emitted tokens is the number of speculated sequences
# times the max emitted per seq.
return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted
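A small worked example of this arithmetic, illustrative and assuming the import path as of this commit:

from vllm.spec_decode.metrics import AsyncMetricsCollector

# With k=5 and 15 draft tokens observed in total, 3 sequences were speculated
# on; each can emit at most k accepted tokens plus 1 bonus token.
assert AsyncMetricsCollector.get_max_num_emitted_tokens(
    draft_tokens=15, k=5) == 3 * (5 + 1)  # == 18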

View File

@ -6,8 +6,7 @@ import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.util import (maybe_mock_device_tensors,
sampler_output_to_torch)
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.worker.worker import Worker
@ -329,12 +328,15 @@ class DraftModelTop1Proposer(SpeculativeProposer):
"""
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty tensors.
proposal_tokens = torch.zeros(0,
# In this case we return empty proposals.
proposal_tokens = torch.full(size=(
batch_size,
max_proposal_len,
),
fill_value=-1,
dtype=torch.long,
device=self._device)
proposal_probs = torch.zeros(0,
proposal_probs = torch.zeros(batch_size,
max_proposal_len,
self._vocab_size,
dtype=torch.float32,
@ -345,17 +347,6 @@ class DraftModelTop1Proposer(SpeculativeProposer):
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
for step_output in sampler_output:
maybe_mock_device_tensors(
sampler_output=step_output,
batch_size=len(proposal_lens),
vocab_size=self._vocab_size,
device=self._device,
)
proposal_tokens, proposal_probs = sampler_output_to_torch(
sampler_output)

View File

@ -111,6 +111,32 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
device=self.device,
vocab_size=self._vocab_size)
self._configure_model_sampler_for_spec_decode()
def _configure_model_sampler_for_spec_decode(self):
"""Configure model sampler to emit GPU tensors. This allows spec decode
to keep data on device without transferring to CPU and serializing,
which significantly reduces overhead of rejection sampling.
NOTE(cade): This breaks abstraction boundaries pretty badly. The better
design is to have the "move to CPU and serialize" sampling decision be
done outside of the model/sampler; this way the "last-mile" worker
object which interfaces with the scheduler can serialize and incur the
performance hit as necessary. This allows us to run the worker several
iterations in a row without incurring the "move to CPU and serialize"
performance penalty.
Since this requires a large change to vLLM, we defer it to later and
temporarily accept this broken abstraction boundary.
NOTE(cade): This will require a special check if the proposer worker
does not have a sampler (e.g. ngram speculation).
"""
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
) = True
(self.proposer_worker.model_runner.model.sampler.
include_gpu_probs_tensor) = True
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of cache blocks to use.
@ -286,15 +312,26 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
select_proposal_len_zero=True)
original_indices = spec_indices + non_spec_indices
proposal_probs = proposal_scores.probs[spec_indices, :-1]
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
# Get probabilities of target model, excluding bonus token.
proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]
# Get non-speculative sampled tokens from target model.
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
# Get bonus tokens from target model.
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
# Get probabilities according to proposal method.
proposal_probs = proposals.proposal_probs[spec_indices]
# Get proposed tokens.
proposal_token_ids = proposals.proposal_token_ids[spec_indices]
accepted_token_ids = self.rejection_sampler(
proposal_probs,
bonus_token_ids,
proposals.proposal_probs,
proposals.proposal_token_ids,
target_probs=proposal_verifier_probs,
bonus_token_ids=bonus_token_ids,
draft_probs=proposal_probs,
draft_token_ids=proposal_token_ids,
)
# Append output tokens from non-speculative sequences to