[Core][Bugfix] Support prompt_logprobs returned with speculative decoding (#8047)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
This commit is contained in:
parent 13f9f7a3d0
commit 01b6f9e1f0
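
For reference, a minimal usage sketch (not part of this diff; the model names, prompt, and token counts are illustrative, loosely borrowed from the tiny checkpoints used in the tests below) showing how prompt logprobs are requested while speculative decoding is enabled:

    # Hedged sketch: assumes a vLLM build that includes this fix.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="JackFram/llama-68m",               # illustrative target model
        speculative_model="JackFram/llama-160m",  # illustrative draft model
        num_speculative_tokens=3,
        use_v2_block_manager=True,
    )
    params = SamplingParams(
        temperature=0.0,
        max_tokens=8,
        logprobs=1,          # logprobs of sampled tokens
        prompt_logprobs=1,   # prompt logprobs, now also returned under spec decode
    )
    out = llm.generate(["Hello, my name is"], params)[0]
    print(out.prompt_logprobs)  # previously dropped when spec decode was active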
@@ -675,8 +675,6 @@ class VllmRunner:
         videos: Optional[PromptVideoInput] = None,
     ) -> Union[List[TokensTextLogprobs],
                List[TokensTextLogprobsPromptLogprobs]]:
-        assert sampling_params.logprobs is not None
-
         if images is not None:
             assert len(prompts) == len(images)
 
@@ -754,7 +752,7 @@ class VllmRunner:
             temperature=0.0,
             max_tokens=max_tokens,
             logprobs=num_logprobs,
-            prompt_logprobs=(num_prompt_logprobs),
+            prompt_logprobs=num_prompt_logprobs,
             stop_token_ids=stop_token_ids)
 
         return self.generate_w_logprobs(prompts,
@@ -1,13 +1,16 @@
 from itertools import cycle
-from typing import List, Optional, Tuple
+from typing import List, Optional, Sequence, Tuple, Union
 
 import pytest
 
 from vllm import LLM, SamplingParams
 from vllm.model_executor.utils import set_random_seed
+from vllm.sequence import PromptLogprobs, SampleLogprobs
 
 from ...conftest import cleanup
-from ...models.utils import check_logprobs_close, check_outputs_equal
+from ...models.utils import (TokensTextLogprobs,
+                             TokensTextLogprobsPromptLogprobs,
+                             check_logprobs_close, check_outputs_equal)
 from ...utils import RemoteOpenAIServer
 
 PROMPTS = [
@@ -81,45 +84,77 @@ def get_output_from_llm_generator(
     return tokens, token_ids, acceptance_rate
 
 
-def run_logprob_correctness_test(vllm_runner,
-                                 common_llm_kwargs,
-                                 per_test_common_llm_kwargs,
-                                 baseline_llm_kwargs,
-                                 test_llm_kwargs,
-                                 batch_size: int,
-                                 max_output_len: int,
-                                 seed: Optional[int] = 0,
-                                 temperature: float = 0.0,
-                                 logprobs: int = 1):
-    org_args = {
-        **common_llm_kwargs,
-        **per_test_common_llm_kwargs,
-        **baseline_llm_kwargs,
-    }
-
-    sd_args = {
-        **common_llm_kwargs,
-        **per_test_common_llm_kwargs,
-        **test_llm_kwargs,
-    }
-
-    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
-
-    sampling_params = SamplingParams(temperature=temperature,
-                                     max_tokens=max_output_len,
-                                     seed=seed,
-                                     logprobs=logprobs)
-
-    with vllm_runner(**org_args) as vllm_model:
-        org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
-
-    with vllm_runner(**sd_args) as vllm_model:
-        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
-
-    check_logprobs_close(outputs_0_lst=org_outputs,
-                         outputs_1_lst=sd_outputs,
-                         name_0="org",
-                         name_1="sd")
+def check_logprobs_correctness(
+    spec_outputs: Sequence[Union[TokensTextLogprobs,
+                                 TokensTextLogprobsPromptLogprobs]],
+    baseline_outputs: Sequence[Union[TokensTextLogprobs,
+                                     TokensTextLogprobsPromptLogprobs]],
+    disable_logprobs: bool = False,
+):
+    """Compare sampled and prompt logprobs between baseline and spec decoding
+    """
+    if not disable_logprobs:
+        return check_logprobs_close(
+            outputs_0_lst=baseline_outputs,
+            outputs_1_lst=spec_outputs,
+            name_0="org",
+            name_1="sd",
+        )
+
+    # Check correctness when disable_logprobs == True
+    for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
+        # Check generated token logprobs.
+        spec_logprobs = spec_output[2]
+        baseline_logprobs = baseline_output[2]
+        _check_logprobs_when_output_disabled(spec_logprobs,
+                                             baseline_logprobs,
+                                             is_prompt_logprobs=False)
+
+        # Check prompt logprobs too, if they exist
+        if len(baseline_output) == 4:
+            assert len(spec_output) == 4
+            spec_prompt_logprobs = spec_output[3]
+            baseline_prompt_logprobs = baseline_output[3]
+            _check_logprobs_when_output_disabled(spec_prompt_logprobs,
+                                                 baseline_prompt_logprobs,
+                                                 is_prompt_logprobs=True)
+
+
+def _check_logprobs_when_output_disabled(
+    spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
+    baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
+    is_prompt_logprobs: bool = False,
+):
+    # Prompt logprobs are optional
+    if is_prompt_logprobs and baseline_logprobs is None:
+        assert spec_logprobs is None
+        return
+
+    assert spec_logprobs is not None
+    assert baseline_logprobs is not None
+    assert len(spec_logprobs) == len(baseline_logprobs)
+
+    # For each generated position of the sequence.
+    for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
+            zip(spec_logprobs, baseline_logprobs)):
+
+        # First prompt logprob is expected to be None
+        if is_prompt_logprobs and baseline_pos_logprobs is None:
+            assert spec_pos_logprobs is None
+            assert pos == 0
+            continue
+
+        assert spec_pos_logprobs is not None
+        assert baseline_pos_logprobs is not None
+
+        # When disabled, the 1 logprob is returned with dummy values for the
+        # score and rank, but the token id should match the baseline model
+        assert len(spec_pos_logprobs) == 1
+        (spec_pos_logprob_token_id,
+         spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
+        assert spec_pos_logprob.rank == -1
+        assert spec_pos_logprob.logprob == 0.0
+        assert spec_pos_logprob_token_id in baseline_pos_logprobs
 
 
 def run_equality_correctness_test(
@@ -135,7 +170,10 @@ def run_equality_correctness_test(
         disable_seed: bool = False,
         ignore_eos: bool = True,
         ensure_all_accepted: bool = False,
-        expected_acceptance_rate: Optional[float] = None):
+        expected_acceptance_rate: Optional[float] = None,
+        logprobs: Optional[int] = None,
+        prompt_logprobs: Optional[int] = None,
+        disable_logprobs: bool = False):
 
     org_args = {
         **common_llm_kwargs,
@@ -157,10 +195,12 @@ def run_equality_correctness_test(
     sampling_params = SamplingParams(temperature=temperature,
                                      max_tokens=max_output_len,
                                      seed=seed,
-                                     ignore_eos=ignore_eos)
+                                     ignore_eos=ignore_eos,
+                                     logprobs=logprobs,
+                                     prompt_logprobs=prompt_logprobs)
 
     with vllm_runner(**org_args) as vllm_model:
-        org_outputs = vllm_model.generate(prompts, sampling_params)
+        org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
 
     with vllm_runner(**sd_args) as vllm_model:
         if ensure_all_accepted or expected_acceptance_rate is not None:
@@ -169,7 +209,7 @@ def run_equality_correctness_test(
                 'prometheus']
             stat_logger.local_interval = -100
 
-        sd_outputs = vllm_model.generate(prompts, sampling_params)
+        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
 
     if ensure_all_accepted or expected_acceptance_rate is not None:
         acceptance_rate = (stat_logger.metrics.
@@ -185,11 +225,16 @@ def run_equality_correctness_test(
         if expected_acceptance_rate is not None:
             assert acceptance_rate >= expected_acceptance_rate - 1e-2
 
-    check_outputs_equal(outputs_0_lst=org_outputs,
-                        outputs_1_lst=sd_outputs,
+    # Only pass token entries, not the logprobs
+    check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
+                        outputs_1_lst=[out[0:2] for out in sd_outputs],
                         name_0="org",
                         name_1="sd")
 
+    # Check logprobs if requested
+    if logprobs is not None or prompt_logprobs is not None:
+        check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)
+
 
 def run_equality_correctness_test_tp(model,
                                      common_llm_kwargs,
@@ -80,6 +80,64 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                   batch_size, output_len, seed)
 
 
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model_name": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "disable_logprobs_during_spec_decoding": False,
+    },
+    {
+        "speculative_model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "disable_logprobs_during_spec_decoding": True,
+    },
+])
+@pytest.mark.parametrize("output_len", [
+    128,
+])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("logprobs", [1, 6])
+def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
+                                   per_test_common_llm_kwargs,
+                                   baseline_llm_kwargs, test_llm_kwargs,
+                                   batch_size: int, output_len: int, seed: int,
+                                   logprobs: int):
+
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  output_len,
+                                  seed,
+                                  logprobs=logprobs,
+                                  prompt_logprobs=logprobs,
+                                  disable_logprobs=test_llm_kwargs[
+                                      'disable_logprobs_during_spec_decoding'])
+
+
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -4,7 +4,7 @@ import pytest
 
 from vllm import SamplingParams
 
-from .conftest import run_logprob_correctness_test
+from .conftest import run_equality_correctness_test
 
 
 @pytest.mark.parametrize(
@@ -25,6 +25,10 @@ from .conftest import run_logprob_correctness_test
         "speculative_model": "JackFram/llama-160m",
        "num_speculative_tokens": 3,
        "disable_logprobs_during_spec_decoding": False,
+    }, {
+        "speculative_model": "JackFram/llama-160m",
+        "num_speculative_tokens": 3,
+        "disable_logprobs_during_spec_decoding": True,
     }])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize(
@@ -41,7 +45,7 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
                            seed: int, logprobs: int):
     """Verify output logprobs are equal with and without speculative decoding.
     """
-    run_logprob_correctness_test(vllm_runner,
+    run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
                                   baseline_llm_kwargs,
@@ -50,7 +54,10 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
                                   output_len,
                                   seed,
                                   temperature=0.0,
-                                  logprobs=logprobs)
+                                  logprobs=logprobs,
+                                  prompt_logprobs=logprobs,
+                                  disable_logprobs=test_llm_kwargs[
+                                      'disable_logprobs_during_spec_decoding'])
 
 
 @pytest.mark.parametrize(
@@ -91,7 +98,7 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
                               output_len: int, seed: int, logprobs: int):
     """Veriy logprob greedy equality with different speculation lens.
     """
-    run_logprob_correctness_test(vllm_runner,
+    run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
                                   baseline_llm_kwargs,
@@ -100,7 +107,9 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
                                   output_len,
                                   seed,
                                   temperature=0.0,
-                                  logprobs=logprobs)
+                                  logprobs=logprobs,
+                                  disable_logprobs=test_llm_kwargs[
+                                      'disable_logprobs_during_spec_decoding'])
 
 
 @pytest.mark.parametrize(
@@ -143,7 +152,7 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
                                         seed: int, logprobs: int):
     """Verify logprobs greedy equality when some sequences skip speculation.
     """
-    run_logprob_correctness_test(vllm_runner,
+    run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
                                   baseline_llm_kwargs,
@@ -152,7 +161,9 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
                                   output_len,
                                   seed,
                                   temperature=0.0,
-                                  logprobs=logprobs)
+                                  logprobs=logprobs,
+                                  disable_logprobs=test_llm_kwargs[
+                                      'disable_logprobs_during_spec_decoding'])
 
 
 @pytest.mark.parametrize(
@@ -267,7 +278,7 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
     """Check the behavior when logprobs are disabled.
     Token choices should match with the base model.
     """
-    run_logprob_correctness_test(vllm_runner,
+    run_equality_correctness_test(vllm_runner,
                                   common_llm_kwargs,
                                   per_test_common_llm_kwargs,
                                   baseline_llm_kwargs,
@@ -276,4 +287,6 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
                                   output_len,
                                   seed,
                                   temperature=0.0,
-                                  logprobs=logprobs)
+                                  logprobs=logprobs,
+                                  disable_logprobs=test_llm_kwargs[
+                                      'disable_logprobs_during_spec_decoding'])
@@ -87,6 +87,65 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                   temperature=0.0)
 
 
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model_name": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "disable_logprobs_during_spec_decoding": False,
+    },
+    {
+        "speculative_model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "disable_logprobs_during_spec_decoding": True,
+    },
+])
+@pytest.mark.parametrize("output_len", [
+    8,
+])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("logprobs", [1, 6])
+def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
+                                    per_test_common_llm_kwargs,
+                                    baseline_llm_kwargs, test_llm_kwargs,
+                                    batch_size: int, output_len: int,
+                                    seed: int, logprobs: int):
+    """Verify greedy equality with different batch size."""
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  seed=seed,
+                                  temperature=0.0,
+                                  logprobs=logprobs,
+                                  prompt_logprobs=logprobs,
+                                  disable_logprobs=test_llm_kwargs[
+                                      'disable_logprobs_during_spec_decoding'])
+
+
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -16,7 +16,7 @@ However, we still need to verify below scenario could be passed:
 * Test greedy equality under various number of speculative tokens.
 
 With those tests, we can say at least, MLPSpeculator would not break the
-correctess for the target model outputs.
+correctness for the target model outputs.
 """
 
 from unittest.mock import patch
@@ -88,6 +88,61 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                   temperature=0.0)
 
 
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model_name": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+        "disable_logprobs_during_spec_decoding": False,
+    },
+    {
+        "speculative_model": SPEC_MODEL,
+        "disable_logprobs_during_spec_decoding": True,
+    },
+])
+@pytest.mark.parametrize("output_len", [8])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("logprobs", [1, 6])
+def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
+                                 per_test_common_llm_kwargs,
+                                 baseline_llm_kwargs, test_llm_kwargs,
+                                 batch_size: int, output_len: int, seed: int,
+                                 logprobs: int):
+    """Verify greedy equality with different batch size."""
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  seed=seed,
+                                  temperature=0.0,
+                                  logprobs=logprobs,
+                                  prompt_logprobs=logprobs,
+                                  disable_logprobs=test_llm_kwargs[
+                                      'disable_logprobs_during_spec_decoding'])
+
+
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -76,6 +76,65 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                   temperature=0.0)
 
 
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+    {
+        "model_name": "JackFram/llama-68m",
+    },
+])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+        "disable_logprobs_during_spec_decoding": False,
+    },
+    {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+        "disable_logprobs_during_spec_decoding": True,
+    },
+])
+@pytest.mark.parametrize("output_len", [
+    8,
+])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("logprobs", [1, 6])
+def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
+                                   per_test_common_llm_kwargs,
+                                   baseline_llm_kwargs, test_llm_kwargs,
+                                   batch_size: int, output_len: int, seed: int,
+                                   logprobs: int):
+    """Verify greedy equality on a tiny model with different batch size."""
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  seed=seed,
+                                  temperature=0.0,
+                                  logprobs=logprobs,
+                                  prompt_logprobs=logprobs,
+                                  disable_logprobs=test_llm_kwargs[
+                                      'disable_logprobs_during_spec_decoding'])
+
+
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -9,8 +9,8 @@ from vllm.engine.output_processor.single_step import (
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
-                           SequenceOutput, SequenceStatus)
+from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup,
+                           SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import Counter
@@ -110,10 +110,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
         # we can take the first sample.
         samples = [output.samples[0] for output in outputs]
 
-        # -1 means the output token is not valid (eg. due to spec decode
+        # entries in sample tokens may be invalid (eg. due to spec decode
         # rejecting tokens).
         valid_samples = [
-            sample for sample in samples if sample.output_token != -1
+            sample for sample in samples
+            if sample.output_token != VLLM_INVALID_TOKEN_ID
         ]
         assert valid_samples
 
@@ -15,7 +15,8 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                    SamplingTensors,
                                                    SequenceGroupToSample)
 from vllm.sampling_params import SamplingType
-from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
+from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
+                           CompletionSequenceGroupOutput, Logprob,
                            PromptLogprobs, SampleLogprobs, SequenceOutput)
 from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
 
@@ -759,8 +760,8 @@ def _sample_with_torch(
 
     # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
-        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
-                                               1,
+        sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
+                                              VLLM_INVALID_TOKEN_ID,
                                               dtype=torch.long,
                                               device=logprobs.device)
     else:
@@ -26,6 +26,8 @@ if TYPE_CHECKING:
 
 VLLM_TOKEN_ID_ARRAY_TYPE = "l"
 
+VLLM_INVALID_TOKEN_ID = -1
+
 
 # We use dataclass for now because it is used for
 # openai server output, and msgspec is not serializable.
@@ -6,9 +6,9 @@ import torch
 
 from vllm import SamplingParams
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest,
-                           SequenceData, SequenceGroupMetadata,
-                           get_all_seq_ids)
+from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
+                           ExecuteModelRequest, SequenceData,
+                           SequenceGroupMetadata, get_all_seq_ids)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
@@ -69,10 +69,10 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         proposal_lens_list = proposals.proposal_lens.tolist()
         proposal_token_ids_list = proposals.proposal_token_ids.tolist()
 
-        # Filter the list to ignore -1 proposals.
+        # Filter the list to ignore invalid proposals.
         proposal_token_ids_list_without_skips = [
             proposals for proposals in proposal_token_ids_list
-            if -1 not in proposals
+            if VLLM_INVALID_TOKEN_ID not in proposals
         ]
 
         (spec_indices, non_spec_indices, target_seq_group_metadata_list,
@@ -13,9 +13,10 @@ from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
 from vllm.model_executor.layers.typical_acceptance_sampler import (
     TypicalAcceptanceSampler)
-from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
+from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
+                           CompletionSequenceGroupOutput, ExecuteModelRequest,
                            HiddenStates, SequenceGroupMetadata,
-                           get_all_seq_ids, get_all_seq_ids_and_request_ids)
+                           get_all_seq_ids_and_request_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
 from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
@@ -28,7 +29,8 @@ from vllm.spec_decode.ngram_worker import NGramWorker
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
 from vllm.spec_decode.target_model_runner import TargetModelRunner
-from vllm.spec_decode.util import (Timer, create_sequence_group_output,
+from vllm.spec_decode.util import (Timer, create_logprobs_output,
+                                   create_sequence_group_output,
                                    get_all_num_logprobs,
                                    get_sampled_token_logprobs, nvtx_range,
                                    split_batch_by_proposal_len)
@@ -436,8 +438,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             self, execute_model_req: ExecuteModelRequest,
             sampler_output: SamplerOutput) -> SamplerOutput:
         """
-        Creates and returns a `SamplerOutput` with only the sampled token IDs
-        being serialized to CPU & populated in `CompletionSequenceGroupOutput`.
+        Creates and returns a `SamplerOutput` with only the token IDs being
+        serialized to CPU and populated in `CompletionSequenceGroupOutput`.
         All other parameters in `CompletionSequenceGroupOutput` related to log
         probabilities are skipped.
 
@@ -449,14 +451,46 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
 
         Returns:
             SamplerOutput: A new `SamplerOutput` instance containing a list of
-            `CompletionSequenceGroupOutput` objects with only sampled token
-            IDs populated.
+            `CompletionSequenceGroupOutput` objects with only token IDs
+            populated.
         """
-        seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list)
-        sampled_token_ids_list = sampler_output.sampled_token_ids.tolist()
+        seq_output_prompt_logprobs = [
+            seq.is_prompt and seq.sampling_params.prompt_logprobs is not None
+            and seq.sampling_params.prompt_logprobs > 0
+            for seq in execute_model_req.seq_group_metadata_list
+        ]
+        # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID
+        sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where(
+            # subtracting is faster than testing for equality
+            sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] \
+            if any(seq_output_prompt_logprobs) else \
+            sampler_output.sampled_token_ids).tolist()
+
+        seq_data_entries = (
+            (seq_id, seq_data) for sg in \
+                execute_model_req.seq_group_metadata_list \
+            for seq_id, seq_data in sg.seq_data.items()
+        )
         completion_seq_group_output_list: List[
             CompletionSequenceGroupOutput] = []
-        for index, seq_id in enumerate(seq_ids):
+        for index, ((seq_id, seq_data), needs_prompt_logprobs) in \
+            enumerate(zip(seq_data_entries, seq_output_prompt_logprobs)):
+            if needs_prompt_logprobs:
+                prompt_token_ids = seq_data.get_prompt_token_ids()
+                prompt_logprobs = [
+                    create_logprobs_output(
+                        token_id=p_token_id,
+                        token_id_logprob_rank=-1,
+                        token_id_logprob=0.0,
+                        topk_token_ids=[],
+                        topk_logprobs=[],
+                    )
+                    # no prompt logprobs for the first token
+                    for p_token_id in prompt_token_ids[1:]
+                ]
+            else:
+                prompt_logprobs = None
+
             completion_seq_group_output_list.append(
                 create_sequence_group_output(
                     token_id=sampled_token_ids_list[index][0],
@@ -465,7 +499,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                     seq_id=seq_id,
                     topk_token_ids=[],
                    topk_logprobs=[],
-                ))
+                    prompt_logprobs=prompt_logprobs))
         return SamplerOutput(outputs=completion_seq_group_output_list)
 
     @nvtx_range("spec_decode_worker._run_no_spec")
@@ -485,6 +519,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             # Store hidden states from target model execution.
             hidden_states = sampler_output.hidden_states
             if hidden_states is not None:
+                # remove hidden_states for prompt tokens
+                if any(seq.is_prompt
+                       for seq in execute_model_req.seq_group_metadata_list):
+                    hidden_states = hidden_states[
+                        torch.where(sampler_output.sampled_token_ids -
+                                    VLLM_INVALID_TOKEN_ID)[0]]
                 if self.previous_hidden_states is None:
                     self.previous_hidden_states = HiddenStates(
                         hidden_states, execute_model_req.seq_group_metadata_list)
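
Aside (not part of the diff): the torch.where(sampled_token_ids - VLLM_INVALID_TOKEN_ID) pattern above relies on single-argument torch.where returning the indices of non-zero entries, so subtracting the -1 sentinel keeps exactly the rows whose sampled token id is valid. A small self-contained illustration:

    import torch

    VLLM_INVALID_TOKEN_ID = -1  # same sentinel value as vllm.sequence
    sampled = torch.tensor([[5], [VLLM_INVALID_TOKEN_ID], [7]])
    valid_rows = torch.where(sampled - VLLM_INVALID_TOKEN_ID)[0]
    print(valid_rows)           # tensor([0, 2])
    print(sampled[valid_rows])  # tensor([[5], [7]])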
@@ -6,7 +6,8 @@ import torch
 
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
-                           SequenceGroupMetadata, SequenceOutput)
+                           PromptLogprobs, SequenceGroupMetadata,
+                           SequenceOutput)
 
 SeqId = int
 
@@ -49,21 +50,19 @@ def get_sampled_token_logprobs(
     return sampled_token_ids_ranks, selected_logprobs
 
 
-def create_sequence_group_output(
+def create_logprobs_output(
         token_id: int,
         token_id_logprob_rank: int,
         token_id_logprob: float,
-        seq_id: SeqId,
         topk_token_ids: List[Optional[int]],
         topk_logprobs: List[Optional[float]],
-) -> CompletionSequenceGroupOutput:
-    """Create a SequenceGroupOutput given the sampling results.
+) -> Dict[int, Logprob]:
+    """Create a Logprob Dict for a token given the sampling results.
 
     Args:
         token_id (int): The sampled token for the sequence.
         token_id_logprob_rank (int): The logprob rank of the sampled token.
         token_id_logprob (float): The logprob value of the sampled token.
-        seq_id (int): The sequence id.
         topk_token_ids (List[Optional[int]]): The list of top-k token ids.
         topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
     """
@@ -85,14 +84,44 @@ def create_sequence_group_output(
             if topk_token_id is not None
         })
 
+    return logprobs
+
+
+def create_sequence_group_output(
+        token_id: int,
+        token_id_logprob_rank: int,
+        token_id_logprob: float,
+        seq_id: SeqId,
+        topk_token_ids: List[Optional[int]],
+        topk_logprobs: List[Optional[float]],
+        prompt_logprobs: Optional[PromptLogprobs] = None,
+) -> CompletionSequenceGroupOutput:
+    """Create a SequenceGroupOutput given the sampling results.
+
+    Args:
+        token_id (int): The sampled token for the sequence.
+        token_id_logprob_rank (int): The logprob rank of the sampled token.
+        token_id_logprob (float): The logprob value of the sampled token.
+        seq_id (int): The sequence id.
+        topk_token_ids (List[Optional[int]]): The list of top-k token ids.
+        topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
+    """
+
+    logprobs = create_logprobs_output(
+        token_id,
+        token_id_logprob_rank,
+        token_id_logprob,
+        topk_token_ids,
+        topk_logprobs,
+    )
+
     return CompletionSequenceGroupOutput(
         samples=[
             SequenceOutput(parent_seq_id=seq_id,
                            output_token=token_id,
                            logprobs=logprobs)
         ],
-        # TODO add prompt logprobs support.
-        prompt_logprobs=None,
+        prompt_logprobs=prompt_logprobs,
    )
 
 
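
Aside (a hedged sketch, not part of the diff, assuming a vLLM tree that includes this change): how the refactored helpers compose. When logprobs are disabled during spec decoding, each position still carries a single Logprob entry with a dummy rank and score, which is what the tests above assert:

    from vllm.spec_decode.util import (create_logprobs_output,
                                       create_sequence_group_output)

    # One dummy entry keyed by the token id; rank -1 and logprob 0.0 are placeholders.
    dummy = create_logprobs_output(token_id=42,
                                   token_id_logprob_rank=-1,
                                   token_id_logprob=0.0,
                                   topk_token_ids=[],
                                   topk_logprobs=[])
    # dummy is roughly {42: Logprob(logprob=0.0, rank=-1)}

    # The sequence-group output can now also carry prompt logprobs
    # (None for the first prompt token, per the worker code above).
    sg_output = create_sequence_group_output(token_id=42,
                                             token_id_logprob_rank=-1,
                                             token_id_logprob=0.0,
                                             seq_id=0,
                                             topk_token_ids=[],
                                             topk_logprobs=[],
                                             prompt_logprobs=[None, dummy])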
@@ -1,13 +1,11 @@
 from typing import Dict, List, Optional, Tuple
 
-from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
+from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
+                           Sequence, SequenceGroup)
 
 from .tokenizer import AnyTokenizer
 from .tokenizer_group import BaseTokenizerGroup
 
-# Used eg. for marking rejected tokens in spec decoding.
-INVALID_TOKEN_ID = -1
-
 
 class Detokenizer:
     """Provides methods to decode the output of a model into text."""
@@ -61,7 +59,7 @@ class Detokenizer:
                 continue
             for token_id, sample_logprob in prompt_logprobs_for_token.items():
                 if (sample_logprob.decoded_token is None
-                        and token_id != INVALID_TOKEN_ID):
+                        and token_id != VLLM_INVALID_TOKEN_ID):
                     prompt_token_ids_with_token = (
                         prompt_token_ids[:token_position] + [token_id])
                     (new_tokens, new_text, new_prefix_offset,
@@ -143,7 +141,7 @@ class Detokenizer:
                 continue
 
             if (sample_logprob.decoded_token is None
-                    and token_id != INVALID_TOKEN_ID):
+                    and token_id != VLLM_INVALID_TOKEN_ID):
                 all_input_ids_with_logprob = previous_tokens + [token_id]
                 (_, new_text, _, _) = detokenize_incrementally(
                     tokenizer=tokenizer,
@@ -282,14 +280,14 @@ def detokenize_incrementally(
     assert prev_tokens is not None
 
     # If the new token id is out of bounds, return an empty string.
-    if new_token_id >= len(tokenizer):
-        new_tokens = [""]
-    else:
+    if 0 <= new_token_id < len(tokenizer):
         # Put new_token_id in a list so skip_special_tokens is respected
         new_tokens = tokenizer.convert_ids_to_tokens(
             [new_token_id], skip_special_tokens=skip_special_tokens)
         if isinstance(new_tokens, str):
             new_tokens = [new_tokens]
+    else:
+        new_tokens = [""]
     output_tokens = prev_tokens + new_tokens
 
     # If this is the first iteration, return all tokens.