[Core][Bugfix] Support prompt_logprobs returned with speculative decoding (#8047)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
This commit is contained in:
parent 13f9f7a3d0
commit 01b6f9e1f0
@@ -675,8 +675,6 @@ class VllmRunner:
videos: Optional[PromptVideoInput] = None,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
assert sampling_params.logprobs is not None

if images is not None:
assert len(prompts) == len(images)

@@ -754,7 +752,7 @@ class VllmRunner:
temperature=0.0,
max_tokens=max_tokens,
logprobs=num_logprobs,
prompt_logprobs=(num_prompt_logprobs),
prompt_logprobs=num_prompt_logprobs,
stop_token_ids=stop_token_ids)

return self.generate_w_logprobs(prompts,
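The runner's logprobs helper now returns either a three-tuple or, when prompt logprobs are requested, a four-tuple. A minimal sketch of the two shapes, assuming plain floats in place of vLLM's Logprob objects (the real aliases come from the models test utilities imported below):

```python
from typing import Dict, List, Optional, Tuple

SampleLogprobsish = List[Dict[int, float]]
PromptLogprobsish = List[Optional[Dict[int, float]]]

# (token ids, generated text, sample logprobs)
TokensTextLogprobs = Tuple[List[int], str, Optional[SampleLogprobsish]]

# Same tuple with per-position prompt logprobs appended; the first prompt
# position has no logprob, hence the Optional entry.
TokensTextLogprobsPromptLogprobs = Tuple[List[int], str,
                                         Optional[SampleLogprobsish],
                                         Optional[PromptLogprobsish]]
```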
@@ -1,13 +1,16 @@
from itertools import cycle
from typing import List, Optional, Tuple
from typing import List, Optional, Sequence, Tuple, Union

import pytest

from vllm import LLM, SamplingParams
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import PromptLogprobs, SampleLogprobs

from ...conftest import cleanup
from ...models.utils import check_logprobs_close, check_outputs_equal
from ...models.utils import (TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
check_logprobs_close, check_outputs_equal)
from ...utils import RemoteOpenAIServer

PROMPTS = [
@@ -81,45 +84,77 @@ def get_output_from_llm_generator(
return tokens, token_ids, acceptance_rate


def run_logprob_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
max_output_len: int,
seed: Optional[int] = 0,
temperature: float = 0.0,
logprobs: int = 1):
org_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**baseline_llm_kwargs,
}
def check_logprobs_correctness(
spec_outputs: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs]],
baseline_outputs: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs]],
disable_logprobs: bool = False,
):
"""Compare sampled and prompt logprobs between baseline and spec decoding
"""
if not disable_logprobs:
return check_logprobs_close(
outputs_0_lst=baseline_outputs,
outputs_1_lst=spec_outputs,
name_0="org",
name_1="sd",
)

sd_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**test_llm_kwargs,
}
# Check correctness when disable_logprobs == True
for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
# Check generated token logprobs.
spec_logprobs = spec_output[2]
baseline_logprobs = baseline_output[2]
_check_logprobs_when_output_disabled(spec_logprobs,
baseline_logprobs,
is_prompt_logprobs=False)

prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
# Check prompt logprobs too, if they exist
if len(baseline_output) == 4:
assert len(spec_output) == 4
spec_prompt_logprobs = spec_output[3]
baseline_prompt_logprobs = baseline_output[3]
_check_logprobs_when_output_disabled(spec_prompt_logprobs,
baseline_prompt_logprobs,
is_prompt_logprobs=True)

sampling_params = SamplingParams(temperature=temperature,
max_tokens=max_output_len,
seed=seed,
logprobs=logprobs)

with vllm_runner(**org_args) as vllm_model:
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
def _check_logprobs_when_output_disabled(
spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
is_prompt_logprobs: bool = False,
):
# Prompt logprobs are optional
if is_prompt_logprobs and baseline_logprobs is None:
assert spec_logprobs is None
return

with vllm_runner(**sd_args) as vllm_model:
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
assert spec_logprobs is not None
assert baseline_logprobs is not None
assert len(spec_logprobs) == len(baseline_logprobs)

check_logprobs_close(outputs_0_lst=org_outputs,
outputs_1_lst=sd_outputs,
name_0="org",
name_1="sd")
# For each generated position of the sequence.
for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
zip(spec_logprobs, baseline_logprobs)):

# First prompt logprob is expected to be None
if is_prompt_logprobs and baseline_pos_logprobs is None:
assert spec_pos_logprobs is None
assert pos == 0
continue

assert spec_pos_logprobs is not None
assert baseline_pos_logprobs is not None

# When disabled, the 1 logprob is returned with dummy values for the
# score and rank, but the token id should match the baseline model
assert len(spec_pos_logprobs) == 1
(spec_pos_logprob_token_id,
spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
assert spec_pos_logprob.rank == -1
assert spec_pos_logprob.logprob == 0.0
assert spec_pos_logprob_token_id in baseline_pos_logprobs


def run_equality_correctness_test(
@@ -135,7 +170,10 @@ def run_equality_correctness_test(
disable_seed: bool = False,
ignore_eos: bool = True,
ensure_all_accepted: bool = False,
expected_acceptance_rate: Optional[float] = None):
expected_acceptance_rate: Optional[float] = None,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
disable_logprobs: bool = False):

org_args = {
**common_llm_kwargs,
@@ -157,10 +195,12 @@ def run_equality_correctness_test(
sampling_params = SamplingParams(temperature=temperature,
max_tokens=max_output_len,
seed=seed,
ignore_eos=ignore_eos)
ignore_eos=ignore_eos,
logprobs=logprobs,
prompt_logprobs=prompt_logprobs)

with vllm_runner(**org_args) as vllm_model:
org_outputs = vllm_model.generate(prompts, sampling_params)
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

with vllm_runner(**sd_args) as vllm_model:
if ensure_all_accepted or expected_acceptance_rate is not None:
@@ -169,7 +209,7 @@
'prometheus']
stat_logger.local_interval = -100

sd_outputs = vllm_model.generate(prompts, sampling_params)
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

if ensure_all_accepted or expected_acceptance_rate is not None:
acceptance_rate = (stat_logger.metrics.
@@ -185,11 +225,16 @@ def run_equality_correctness_test(
if expected_acceptance_rate is not None:
assert acceptance_rate >= expected_acceptance_rate - 1e-2

check_outputs_equal(outputs_0_lst=org_outputs,
outputs_1_lst=sd_outputs,
# Only pass token entries, not the logprobs
check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
outputs_1_lst=[out[0:2] for out in sd_outputs],
name_0="org",
name_1="sd")

# Check logprobs if requested
if logprobs is not None or prompt_logprobs is not None:
check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)


def run_equality_correctness_test_tp(model,
common_llm_kwargs,
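When `disable_logprobs_during_spec_decoding` is set, the helper above does not compare scores; it only checks that each position carries a single placeholder entry whose token id agrees with the baseline. A small self-contained sketch of that invariant, with plain dicts standing in for Logprob objects:

```python
# Sketch of the invariant checked by _check_logprobs_when_output_disabled:
# exactly one entry per position, dummy rank/score, matching token choice.
def check_disabled_position(spec_pos: dict, baseline_pos: dict) -> None:
    assert len(spec_pos) == 1
    token_id, logprob = next(iter(spec_pos.items()))
    assert logprob["rank"] == -1          # dummy rank
    assert logprob["logprob"] == 0.0      # dummy score
    assert token_id in baseline_pos       # token choice still matches baseline

check_disabled_position(
    {42: {"rank": -1, "logprob": 0.0}},
    {42: {"rank": 1, "logprob": -0.31}, 7: {"rank": 2, "logprob": -1.8}},
)
```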
@@ -80,6 +80,64 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
batch_size, output_len, seed)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,

# Print spec metrics.
"disable_log_stats": False,

# Precision
"dtype": PRECISION,

# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False,
},
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):

run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -4,7 +4,7 @@ import pytest

from vllm import SamplingParams

from .conftest import run_logprob_correctness_test
from .conftest import run_equality_correctness_test


@pytest.mark.parametrize(
@@ -25,6 +25,10 @@ from .conftest import run_logprob_correctness_test
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}, {
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True,
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
@@ -41,16 +45,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
seed: int, logprobs: int):
"""Verify output logprobs are equal with and without speculative decoding.
"""
run_logprob_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])


@pytest.mark.parametrize(
@@ -91,16 +98,18 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
output_len: int, seed: int, logprobs: int):
"""Veriy logprob greedy equality with different speculation lens.
"""
run_logprob_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])


@pytest.mark.parametrize(
@@ -143,16 +152,18 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
seed: int, logprobs: int):
"""Verify logprobs greedy equality when some sequences skip speculation.
"""
run_logprob_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])


@pytest.mark.parametrize(
@@ -267,13 +278,15 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
"""Check the behavior when logprobs are disabled.
Token choices should match with the base model.
"""
run_logprob_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
@@ -87,6 +87,65 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,

# Print spec metrics.
"disable_log_stats": False,

# Precision
"dtype": PRECISION,

# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False,
},
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True,
},
])
@pytest.mark.parametrize("output_len", [
8,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int, logprobs: int):
"""Verify greedy equality with different batch size."""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -16,7 +16,7 @@ However, we still need to verify below scenario could be passed:
* Test greedy equality under various number of speculative tokens.

With those tests, we can say at least, MLPSpeculator would not break the
correctess for the target model outputs.
correctness for the target model outputs.
"""

from unittest.mock import patch
@@ -88,6 +88,61 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,

# Print spec metrics.
"disable_log_stats": False,

# Precision
"dtype": PRECISION,

# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"disable_logprobs_during_spec_decoding": False,
},
{
"speculative_model": SPEC_MODEL,
"disable_logprobs_during_spec_decoding": True,
},
])
@pytest.mark.parametrize("output_len", [8])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
"""Verify greedy equality with different batch size."""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -76,6 +76,65 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,

# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model_name": "JackFram/llama-68m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"disable_logprobs_during_spec_decoding": False,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"disable_logprobs_during_spec_decoding": True,
},
])
@pytest.mark.parametrize("output_len", [
8,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
"""Verify greedy equality on a tiny model with different batch size."""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
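The test additions above exercise the new behaviour end to end across the EAGLE, Medusa, MLPSpeculator, and ngram proposers. For reference, a hedged usage sketch of what they ultimately verify, requesting logprobs and prompt logprobs while a speculative model is configured (model names and sizes are illustrative only):

```python
from vllm import LLM, SamplingParams

# Illustrative configuration; any draft/target pair supported by spec decode
# should behave the same way.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="[ngram]",
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=3,
    use_v2_block_manager=True,
    # When True, returned logprobs carry placeholder rank/score values.
    disable_logprobs_during_spec_decoding=False,
)
params = SamplingParams(temperature=0.0, max_tokens=8,
                        logprobs=1, prompt_logprobs=1)
out = llm.generate(["San Francisco is known for its"], params)[0]
print(out.prompt_logprobs)          # populated instead of None after this change
print(out.outputs[0].logprobs)
```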
@@ -9,8 +9,8 @@ from vllm.engine.output_processor.single_step import (
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
@@ -110,10 +110,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# we can take the first sample.
samples = [output.samples[0] for output in outputs]

# -1 means the output token is not valid (eg. due to spec decode
# entries in sample tokens may be invalid (eg. due to spec decode
# rejecting tokens).
valid_samples = [
sample for sample in samples if sample.output_token != -1
sample for sample in samples
if sample.output_token != VLLM_INVALID_TOKEN_ID
]
assert valid_samples
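The hunk above replaces the literal -1 with the shared VLLM_INVALID_TOKEN_ID sentinel when dropping speculative samples that were rejected. A toy sketch of the filtering step, using bare ints instead of sample objects:

```python
VLLM_INVALID_TOKEN_ID = -1

# Rejected speculative positions carry the sentinel and are dropped before
# output processing continues.
sampled = [101, VLLM_INVALID_TOKEN_ID, 204, VLLM_INVALID_TOKEN_ID]
valid = [t for t in sampled if t != VLLM_INVALID_TOKEN_ID]
assert valid == [101, 204]
```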
@@ -15,7 +15,8 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors,
SequenceGroupToSample)
from vllm.sampling_params import SamplingType
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

@@ -759,10 +760,10 @@ def _sample_with_torch(

# Create output tensor for sampled token ids.
if include_gpu_probs_tensor:
sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
1,
dtype=torch.long,
device=logprobs.device)
sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
VLLM_INVALID_TOKEN_ID,
dtype=torch.long,
device=logprobs.device)
else:
sampled_token_ids_tensor = None
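Pre-filling the output tensor with the sentinel (torch.full) instead of leaving it uninitialized (torch.empty) is what lets later stages tell unwritten slots apart from real samples. A small sketch of why that matters (shapes and values are arbitrary):

```python
import torch

VLLM_INVALID_TOKEN_ID = -1

# Pre-fill with the sentinel so slots the sampler never writes (e.g. prompt
# positions during spec decode) remain identifiable afterwards; torch.empty
# would leave arbitrary memory that can masquerade as a valid token id.
sampled = torch.full((4, 1), VLLM_INVALID_TOKEN_ID, dtype=torch.long)
sampled[2, 0] = 1234                       # only row 2 receives a real sample
written_rows = torch.where(sampled[:, 0] != VLLM_INVALID_TOKEN_ID)[0]
assert written_rows.tolist() == [2]
```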
@@ -26,6 +26,8 @@ if TYPE_CHECKING:

VLLM_TOKEN_ID_ARRAY_TYPE = "l"

VLLM_INVALID_TOKEN_ID = -1


# We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable.
@@ -6,9 +6,9 @@ import torch

from vllm import SamplingParams
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest,
SequenceData, SequenceGroupMetadata,
get_all_seq_ids)
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
ExecuteModelRequest, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
@@ -69,10 +69,10 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
proposal_lens_list = proposals.proposal_lens.tolist()
proposal_token_ids_list = proposals.proposal_token_ids.tolist()

# Filter the list to ignore -1 proposals.
# Filter the list to ignore invalid proposals.
proposal_token_ids_list_without_skips = [
proposals for proposals in proposal_token_ids_list
if -1 not in proposals
if VLLM_INVALID_TOKEN_ID not in proposals
]

(spec_indices, non_spec_indices, target_seq_group_metadata_list,
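Same sentinel, used here to skip draft proposals that still contain rejected slots before the batch is expanded. A toy sketch of the filter:

```python
VLLM_INVALID_TOKEN_ID = -1

proposal_token_ids_list = [[5, 9, 11], [7, VLLM_INVALID_TOKEN_ID, 3], [2, 2, 2]]
# Proposals still carrying the sentinel are ignored, mirroring the filter above.
proposal_token_ids_list_without_skips = [
    p for p in proposal_token_ids_list if VLLM_INVALID_TOKEN_ID not in p
]
assert proposal_token_ids_list_without_skips == [[5, 9, 11], [2, 2, 2]]
```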
@@ -13,9 +13,10 @@ from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SequenceGroupMetadata,
get_all_seq_ids, get_all_seq_ids_and_request_ids)
get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
@@ -28,7 +29,8 @@ from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (Timer, create_sequence_group_output,
from vllm.spec_decode.util import (Timer, create_logprobs_output,
create_sequence_group_output,
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
@@ -436,8 +438,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self, execute_model_req: ExecuteModelRequest,
sampler_output: SamplerOutput) -> SamplerOutput:
"""
Creates and returns a `SamplerOutput` with only the sampled token IDs
being serialized to CPU & populated in `CompletionSequenceGroupOutput`.
Creates and returns a `SamplerOutput` with only the token IDs being
serialized to CPU and populated in `CompletionSequenceGroupOutput`.
All other parameters in `CompletionSequenceGroupOutput` related to log
probabilities are skipped.

@@ -449,14 +451,46 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

Returns:
SamplerOutput: A new `SamplerOutput` instance containing a list of
`CompletionSequenceGroupOutput` objects with only sampled token
IDs populated.
`CompletionSequenceGroupOutput` objects with only token IDs
populated.
"""
seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list)
sampled_token_ids_list = sampler_output.sampled_token_ids.tolist()
seq_output_prompt_logprobs = [
seq.is_prompt and seq.sampling_params.prompt_logprobs is not None
and seq.sampling_params.prompt_logprobs > 0
for seq in execute_model_req.seq_group_metadata_list
]
# ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID
sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where(
# subtracting is faster than testing for equality
sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] \
if any(seq_output_prompt_logprobs) else \
sampler_output.sampled_token_ids).tolist()

seq_data_entries = (
(seq_id, seq_data) for sg in \
execute_model_req.seq_group_metadata_list \
for seq_id, seq_data in sg.seq_data.items()
)
completion_seq_group_output_list: List[
CompletionSequenceGroupOutput] = []
for index, seq_id in enumerate(seq_ids):
for index, ((seq_id, seq_data), needs_prompt_logprobs) in \
enumerate(zip(seq_data_entries, seq_output_prompt_logprobs)):
if needs_prompt_logprobs:
prompt_token_ids = seq_data.get_prompt_token_ids()
prompt_logprobs = [
create_logprobs_output(
token_id=p_token_id,
token_id_logprob_rank=-1,
token_id_logprob=0.0,
topk_token_ids=[],
topk_logprobs=[],
)
# no prompt logprobs for the first token
for p_token_id in prompt_token_ids[1:]
]
else:
prompt_logprobs = None

completion_seq_group_output_list.append(
create_sequence_group_output(
token_id=sampled_token_ids_list[index][0],
@@ -465,7 +499,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
seq_id=seq_id,
topk_token_ids=[],
topk_logprobs=[],
))
prompt_logprobs=prompt_logprobs))
return SamplerOutput(outputs=completion_seq_group_output_list)

@nvtx_range("spec_decode_worker._run_no_spec")
@@ -485,6 +519,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Store hidden states from target model execution.
hidden_states = sampler_output.hidden_states
if hidden_states is not None:
# remove hidden_states for prompt tokens
if any(seq.is_prompt
for seq in execute_model_req.seq_group_metadata_list):
hidden_states = hidden_states[
torch.where(sampler_output.sampled_token_ids -
VLLM_INVALID_TOKEN_ID)[0]]
if self.previous_hidden_states is None:
self.previous_hidden_states = HiddenStates(
hidden_states, execute_model_req.seq_group_metadata_list)
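This is the core of the fix: when logprobs are disabled during spec decode but prompt logprobs were requested, the worker now emits placeholder prompt logprobs (rank -1, logprob 0.0) for every prompt token except the first, and filters sentinel-filled prompt slots out of the sampled-token list. A simplified, self-contained sketch of the placeholder construction, with a dict standing in for create_logprobs_output's return value:

```python
from typing import Dict, List


def fake_logprobs_output(token_id: int, rank: int, logprob: float) -> Dict[int, dict]:
    # Stand-in for vllm.spec_decode.util.create_logprobs_output.
    return {token_id: {"rank": rank, "logprob": logprob}}


def dummy_prompt_logprobs(prompt_token_ids: List[int]) -> List[Dict[int, dict]]:
    # One placeholder entry per prompt token, skipping the first token,
    # which never has a prompt logprob.
    return [
        fake_logprobs_output(token_id=p, rank=-1, logprob=0.0)
        for p in prompt_token_ids[1:]
    ]


assert len(dummy_prompt_logprobs([1, 15, 27, 42])) == 3
```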
@@ -6,7 +6,8 @@ import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceGroupMetadata, SequenceOutput)
PromptLogprobs, SequenceGroupMetadata,
SequenceOutput)

SeqId = int

@@ -49,21 +50,19 @@ def get_sampled_token_logprobs(
return sampled_token_ids_ranks, selected_logprobs


def create_sequence_group_output(
def create_logprobs_output(
token_id: int,
token_id_logprob_rank: int,
token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[Optional[int]],
topk_logprobs: List[Optional[float]],
) -> CompletionSequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results.
) -> Dict[int, Logprob]:
"""Create a Logprob Dict for a token given the sampling results.

Args:
token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
seq_id (int): The sequence id.
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
"""
@@ -85,14 +84,44 @@ def create_sequence_group_output(
if topk_token_id is not None
})

return logprobs


def create_sequence_group_output(
token_id: int,
token_id_logprob_rank: int,
token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[Optional[int]],
topk_logprobs: List[Optional[float]],
prompt_logprobs: Optional[PromptLogprobs] = None,
) -> CompletionSequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results.

Args:
token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
seq_id (int): The sequence id.
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
"""

logprobs = create_logprobs_output(
token_id,
token_id_logprob_rank,
token_id_logprob,
topk_token_ids,
topk_logprobs,
)

return CompletionSequenceGroupOutput(
samples=[
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs=logprobs)
],
# TODO add prompt logprobs support.
prompt_logprobs=None,
prompt_logprobs=prompt_logprobs,
)
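The refactor above splits the logprob-dict construction out of create_sequence_group_output so the worker can reuse it for prompt positions, and threads an optional prompt_logprobs argument through to CompletionSequenceGroupOutput. A hedged sketch of the resulting call pattern; the argument values are illustrative, and real prompt logprobs would normally start with a None entry for the first prompt token:

```python
from vllm.spec_decode.util import (create_logprobs_output,
                                   create_sequence_group_output)

# Reusable logprob dict for a single (prompt or sampled) token position.
position_logprobs = create_logprobs_output(token_id=42,
                                           token_id_logprob_rank=-1,
                                           token_id_logprob=0.0,
                                           topk_token_ids=[],
                                           topk_logprobs=[])

# Sequence-group output that now forwards prompt logprobs instead of
# hard-coding None.
output = create_sequence_group_output(token_id=42,
                                      token_id_logprob_rank=-1,
                                      token_id_logprob=0.0,
                                      seq_id=0,
                                      topk_token_ids=[],
                                      topk_logprobs=[],
                                      prompt_logprobs=[position_logprobs])
```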
@@ -1,13 +1,11 @@
from typing import Dict, List, Optional, Tuple

from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
Sequence, SequenceGroup)

from .tokenizer import AnyTokenizer
from .tokenizer_group import BaseTokenizerGroup

# Used eg. for marking rejected tokens in spec decoding.
INVALID_TOKEN_ID = -1


class Detokenizer:
"""Provides methods to decode the output of a model into text."""
@@ -61,7 +59,7 @@ class Detokenizer:
continue
for token_id, sample_logprob in prompt_logprobs_for_token.items():
if (sample_logprob.decoded_token is None
and token_id != INVALID_TOKEN_ID):
and token_id != VLLM_INVALID_TOKEN_ID):
prompt_token_ids_with_token = (
prompt_token_ids[:token_position] + [token_id])
(new_tokens, new_text, new_prefix_offset,
@@ -143,7 +141,7 @@ class Detokenizer:
continue

if (sample_logprob.decoded_token is None
and token_id != INVALID_TOKEN_ID):
and token_id != VLLM_INVALID_TOKEN_ID):
all_input_ids_with_logprob = previous_tokens + [token_id]
(_, new_text, _, _) = detokenize_incrementally(
tokenizer=tokenizer,
@@ -282,14 +280,14 @@ def detokenize_incrementally(
assert prev_tokens is not None

# If the new token id is out of bounds, return an empty string.
if new_token_id >= len(tokenizer):
new_tokens = [""]
else:
if 0 <= new_token_id < len(tokenizer):
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens)
if isinstance(new_tokens, str):
new_tokens = [new_tokens]
else:
new_tokens = [""]
output_tokens = prev_tokens + new_tokens

# If this is the first iteration, return all tokens.
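Because sentinel ids (-1) can now reach the detokenizer via placeholder prompt logprobs, the bounds check gains a lower bound in addition to the existing upper bound; ids outside the vocabulary decode to an empty string. A small sketch of the guard, with the tokenizer lookup stubbed out:

```python
from typing import List


def tokens_for(new_token_id: int, vocab_size: int) -> List[str]:
    # Stub for the real tokenizer.convert_ids_to_tokens lookup.
    if 0 <= new_token_id < vocab_size:
        return [f"<token-{new_token_id}>"]
    return [""]  # sentinel (-1) or out-of-range ids decode to an empty string


assert tokens_for(-1, 32000) == [""]
assert tokens_for(31999, 32000) == ["<token-31999>"]
assert tokens_for(32001, 32000) == [""]
```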