[V1][Spec Decode] Remove deprecated spec decode config params (#15466)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Parent: 09e974d483
Commit: 239b7befdd
@@ -63,10 +63,12 @@
         "model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
         "disable_log_requests": "",
         "tensor_parallel_size": 4,
         "swap_space": 16,
-        "speculative_model": "turboderp/Qwama-0.5B-Instruct",
-        "num_speculative_tokens": 4,
-        "speculative_draft_tensor_parallel_size": 1
+        "speculative_config": {
+            "model": "turboderp/Qwama-0.5B-Instruct",
+            "num_speculative_tokens": 4,
+            "draft_tensor_parallel_size": 1
+        }
     },
     "client_parameters": {
         "model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
@@ -52,7 +52,7 @@ python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model
 ```

 :::{warning}
-Note: Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately will be deprecated in the next release.
+Note: Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately has been deprecated now.
 :::

 Then use a client:
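For readers migrating, here is a minimal sketch (not part of the commit) of the new-style configuration; the target and draft model names are simply the ones appearing in this diff and are only illustrative. The same dict can be passed to `LLM(...)` offline or serialized to JSON for the `--speculative_config` CLI flag:

```python
import json

from vllm import LLM

# One nested speculative_config replaces the old standalone flags
# (--speculative_model, --num_speculative_tokens, ...).
spec_config = {
    "model": "turboderp/Qwama-0.5B-Instruct",  # draft model (illustrative)
    "num_speculative_tokens": 4,
}

# Offline engine: pass the dict directly.
llm = LLM(model="meta-llama/Meta-Llama-3.1-70B-Instruct",
          speculative_config=spec_config)

# Online server: serialize to a JSON string for the CLI flag, e.g.
#   --speculative_config '{"model": "...", "num_speculative_tokens": 4}'
print(json.dumps(spec_config))
```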
@@ -69,10 +69,12 @@ llm = LLM(
     max_model_len=max_model_len,
     max_num_seqs=args.max_num_seqs,
     gpu_memory_utilization=0.8,
-    speculative_model=eagle_dir,
-    num_speculative_tokens=args.num_spec_tokens,
-    speculative_draft_tensor_parallel_size=args.draft_tp,
-    speculative_max_model_len=max_model_len,
+    speculative_config={
+        "model": eagle_dir,
+        "num_speculative_tokens": args.num_spec_tokens,
+        "draft_tensor_parallel_size": args.draft_tp,
+        "max_model_len": max_model_len,
+    },
     disable_log_stats=False,
 )

@@ -248,8 +248,10 @@ def test_metric_spec_decode(
             dtype=dtype,
             disable_log_stats=False,
             gpu_memory_utilization=0.4,
-            speculative_model=model,
-            num_speculative_tokens=k,
+            speculative_config={
+                "model": model,
+                "num_speculative_tokens": k,
+            },
     ) as vllm_model:

         # Force log interval to be 0 to catch all metrics.
@@ -300,8 +302,10 @@ def test_metric_spec_decode_interval(
         dtype=dtype,
         disable_log_stats=False,
         gpu_memory_utilization=0.4,
-        speculative_model=model,
-        num_speculative_tokens=k,
+        speculative_config={
+            "model": model,
+            "num_speculative_tokens": k,
+        },
         enforce_eager=True,
     )

@@ -54,8 +54,10 @@ def test_can_initialize(model_arch):
         model_info.default,
         tokenizer=model_info.tokenizer,
         tokenizer_mode=model_info.tokenizer_mode,
-        speculative_model=model_info.speculative_model,
-        num_speculative_tokens=1 if model_info.speculative_model else None,
+        speculative_config={
+            "model": model_info.speculative_model,
+            "num_speculative_tokens": 1,
+        } if model_info.speculative_model else None,
         trust_remote_code=model_info.trust_remote_code,
         load_format="dummy",
         hf_overrides=hf_overrides,
@@ -3,6 +3,7 @@
 tensor parallelism.
 """

+import json
 from typing import Optional

 import pytest
@@ -28,14 +29,14 @@ from .conftest import run_equality_correctness_test_tp
 @pytest.mark.parametrize("test_llm_kwargs", [
     [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": "JackFram/llama-68m",
             "num_speculative_tokens": 3,
         }),
     ],
     [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": "ngram",
             "num_speculative_tokens": 5,
             "prompt_lookup_max": 3,
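A brief aside on why these tests switch from `str({...})` to `json.dumps({...})`: the `--speculative-config` flag is now parsed with `json.loads` (see the EngineArgs hunk further down), and Python's `str()` of a dict yields single-quoted text that is not valid JSON. A minimal, standalone illustration:

```python
import json

cfg = {"model": "JackFram/llama-68m", "num_speculative_tokens": 3}

str(cfg)         # "{'model': 'JackFram/llama-68m', 'num_speculative_tokens': 3}"
json.dumps(cfg)  # '{"model": "JackFram/llama-68m", "num_speculative_tokens": 3}'

# Only the json.dumps form round-trips through the CLI parser:
assert json.loads(json.dumps(cfg)) == cfg
try:
    json.loads(str(cfg))
except json.JSONDecodeError:
    print("str(dict) is not valid JSON")  # single quotes are rejected
```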
@@ -88,7 +89,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
     "model, test_llm_kwargs",
     [("JackFram/llama-68m", [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": "JackFram/llama-68m",
             "num_speculative_tokens": 5,
             "draft_tensor_parallel_size": 1,
@@ -96,7 +97,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
     ]),
     ("ibm-granite/granite-3b-code-instruct", [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": "ibm-granite/granite-3b-code-instruct",
             "num_speculative_tokens": 5,
             "draft_tensor_parallel_size": 1,
@@ -147,20 +148,20 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
 @pytest.mark.parametrize("model, test_llm_kwargs",
                          [("JackFram/llama-68m", [
                              "--speculative_config",
-                             str({
+                             json.dumps({
                                  "model": "JackFram/llama-68m",
                                  "num_speculative_tokens": 3,
                              }),
                          ]),
                           ("JackFram/llama-68m", [
                               "--speculative_config",
-                              str({
+                              json.dumps({
                                   "model": "JackFram/llama-68m",
                                   "num_speculative_tokens": 3,
                                   "draft_tensor_parallel_size": 1,
                               }),
                           ])])
-@pytest.mark.parametrize("logprobs", [None, 2])
+@pytest.mark.parametrize("logprobs", [None])
 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize("seed", [1])
 def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
@@ -171,9 +172,68 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
     """Verify spec decode works well with same and different TP size for
     the draft model with chunked prefill.
     """
-    if logprobs:
-        test_llm_kwargs.extend(
-            ["--disable_logprobs_during_spec_decoding", "False"])
+    run_equality_correctness_test_tp(model,
+                                     common_llm_kwargs,
+                                     per_test_common_llm_kwargs,
+                                     baseline_llm_kwargs,
+                                     test_llm_kwargs,
+                                     batch_size,
+                                     max_output_len=32,
+                                     seed=seed,
+                                     temperature=0.0,
+                                     logprobs=logprobs)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [[
+        # Skip cuda graph recording for fast test.
+        "--enforce-eager",
+        "--tensor_parallel_size",
+        "2",
+
+        # precision
+        "--dtype",
+        "bfloat16",
+    ]])
+@pytest.mark.parametrize(
+    "per_test_common_llm_kwargs",
+    [["--enable-chunked-prefill", "False"],
+     [
+         "--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4",
+         "--max-num-seqs", "4"
+     ]])
+@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
+@pytest.mark.parametrize("model, test_llm_kwargs",
+                         [("JackFram/llama-68m", [
+                             "--speculative_config",
+                             json.dumps({
+                                 "model": "JackFram/llama-68m",
+                                 "num_speculative_tokens": 3,
+                                 "disable_logprobs": False,
+                             }),
+                         ]),
+                          ("JackFram/llama-68m", [
+                              "--speculative_config",
+                              json.dumps({
+                                  "model": "JackFram/llama-68m",
+                                  "num_speculative_tokens": 3,
+                                  "draft_tensor_parallel_size": 1,
+                                  "disable_logprobs": False,
+                              }),
+                          ])])
+@pytest.mark.parametrize("logprobs", [2])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_chunked_prefill_tp2_with_logprobs(
+        model, common_llm_kwargs, per_test_common_llm_kwargs,
+        baseline_llm_kwargs, test_llm_kwargs, logprobs: Optional[int],
+        batch_size: int, seed: int):
+    """Verify spec decode works well with same and different TP size for
+    the draft model with chunked prefill.
+    """
     run_equality_correctness_test_tp(model,
                                      common_llm_kwargs,
                                      per_test_common_llm_kwargs,
@@ -3,6 +3,8 @@
 tensor parallelism.
 """

+import json
+
 import openai
 import pytest
 import torch
@@ -33,7 +35,7 @@ SPEC_MODEL = "JackFram/llama-68m"
     #TODO(wooyeon): add spec_draft_dp=2 case
     [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": f"{SPEC_MODEL}",
             "num_speculative_tokens": 5,
             "draft_tensor_parallel_size": 1,
@@ -80,7 +82,7 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
         # Artificially limit the draft model max model len; this forces vLLM
         # to skip speculation once the sequences grow beyond 32-k tokens.
         "--speculative_config",
-        str({
+        json.dumps({
             "model": f"{SPEC_MODEL}",
             "num_speculative_tokens": 5,
             "max_model_len": 32,
@@ -49,7 +49,9 @@ def test_unsupported_configs(monkeypatch):
     with pytest.raises(NotImplementedError):
         AsyncEngineArgs(
             model=MODEL,
-            speculative_model=MODEL,
+            speculative_config={
+                "model": MODEL,
+            },
         ).create_engine_config()

     with pytest.raises(NotImplementedError):
@@ -2047,14 +2047,13 @@ class SpeculativeConfig:

     def __post_init__(self):

-        # Note: After next release, the method parameter will be used to
-        # specify the speculative method, which helps to extend the
-        # configuration of non-model-based proposers, and the model parameter
-        # will be used when the draft model or head is needed.
-        # If users do not specify the method, the speculative method will
-        # be detected automatically if possible. If the speculative method can
-        # not be detected, it will be considered as the draft-model-based
-        # method by default.
+        # Note: "method" is a new parameter that helps to extend the
+        # configuration of non-model-based proposers, and the "model" parameter
+        # will be used to set the draft model, eagle head, or additional weight
+        # when needed. If users do not specify "method", the speculative method
+        # will be detected automatically if possible. If the speculative method
+        # can not be detected, it will be considered as the "draft_model" by
+        # default.

         if self.model is None and self.num_speculative_tokens is not None:
             # TODO(Shangming): Refactor mtp configuration logic when supporting
@@ -2069,8 +2068,8 @@
                 raise ValueError("num_speculative_tokens was provided without "
                                  "speculative model.")

-            # Automatically configure the ngram method during configuration
-            # refactoring to ensure a smooth transition.
+            # Automatically configure the method for ngram when "model" is used
+            # instead of "method"
             if self.method is None and (self.model is not None
                                         and self.model in ("ngram", "[ngram]")):
                 self.method = "ngram"
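To illustrate the behaviour the rewritten comment describes, here is a condensed, hypothetical sketch of the detection rule; the helper name and the restriction to the ngram case are mine, and the real logic lives inside `SpeculativeConfig.__post_init__`:

```python
def detect_method(method, model):
    # Mirrors the hunk above: ngram can be requested through "model" for
    # backwards compatibility; anything else falls back to a draft-model
    # based proposer ("draft_model") when no method is given.
    if method is None and model in ("ngram", "[ngram]"):
        return "ngram"
    return method if method is not None else "draft_model"

assert detect_method(None, "ngram") == "ngram"
assert detect_method(None, "JackFram/llama-68m") == "draft_model"
assert detect_method("eagle", "some/eagle-head") == "eagle"  # explicit method wins
```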
@@ -181,22 +181,7 @@ class EngineArgs:
     guided_decoding_backend: str = 'xgrammar'
     logits_processor_pattern: Optional[str] = None

-    speculative_config: Optional[Union[str, Dict[str, Any]]] = None
-
-    # TODO(Shangming): Deprecate these out-of-date params after next release
-    speculative_model: Optional[str] = None
-    speculative_model_quantization: Optional[str] = None
-    speculative_draft_tensor_parallel_size: Optional[int] = None
-    num_speculative_tokens: Optional[int] = None
-    speculative_disable_mqa_scorer: Optional[bool] = False
-    speculative_max_model_len: Optional[int] = None
-    speculative_disable_by_batch_size: Optional[int] = None
-    ngram_prompt_lookup_max: Optional[int] = None
-    ngram_prompt_lookup_min: Optional[int] = None
-    spec_decoding_acceptance_method: str = 'rejection_sampler'
-    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
-    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
-    disable_logprobs_during_spec_decoding: Optional[bool] = None
+    speculative_config: Optional[Dict[str, Any]] = None

     qlora_adapter_name_or_path: Optional[str] = None
     show_hidden_metrics_for_version: Optional[str] = None
@@ -793,122 +778,10 @@
             help='If set, the prefill requests can be chunked based on the '
             'max_num_batched_tokens.')
         parser.add_argument('--speculative-config',
-                            type=nullable_str,
+                            type=json.loads,
                             default=None,
                             help='The configurations for speculative decoding.'
                             ' Should be a JSON string.')
-        parser.add_argument(
-            '--speculative-model',
-            type=nullable_str,
-            default=EngineArgs.speculative_model,
-            help=
-            'The name of the draft model to be used in speculative decoding.')
-        # Quantization settings for speculative model.
-        parser.add_argument(
-            '--speculative-model-quantization',
-            type=nullable_str,
-            choices=[*QUANTIZATION_METHODS, None],
-            default=EngineArgs.speculative_model_quantization,
-            help='Method used to quantize the weights of speculative model. '
-            'If None, we first check the `quantization_config` '
-            'attribute in the model config file. If that is '
-            'None, we assume the model weights are not '
-            'quantized and use `dtype` to determine the data '
-            'type of the weights.')
-        parser.add_argument(
-            '--num-speculative-tokens',
-            type=int,
-            default=EngineArgs.num_speculative_tokens,
-            help='The number of speculative tokens to sample from '
-            'the draft model in speculative decoding.')
-        parser.add_argument(
-            '--speculative-disable-mqa-scorer',
-            action='store_true',
-            help=
-            'If set to True, the MQA scorer will be disabled in speculative '
-            ' and fall back to batch expansion')
-        parser.add_argument(
-            '--speculative-draft-tensor-parallel-size',
-            '-spec-draft-tp',
-            type=int,
-            default=EngineArgs.speculative_draft_tensor_parallel_size,
-            help='Number of tensor parallel replicas for '
-            'the draft model in speculative decoding.')
-
-        parser.add_argument(
-            '--speculative-max-model-len',
-            type=int,
-            default=EngineArgs.speculative_max_model_len,
-            help='The maximum sequence length supported by the '
-            'draft model. Sequences over this length will skip '
-            'speculation.')
-
-        parser.add_argument(
-            '--speculative-disable-by-batch-size',
-            type=int,
-            default=EngineArgs.speculative_disable_by_batch_size,
-            help='Disable speculative decoding for new incoming requests '
-            'if the number of enqueue requests is larger than this value.')
-
-        parser.add_argument(
-            '--ngram-prompt-lookup-max',
-            type=int,
-            default=EngineArgs.ngram_prompt_lookup_max,
-            help='Max size of window for ngram prompt lookup in speculative '
-            'decoding.')
-
-        parser.add_argument(
-            '--ngram-prompt-lookup-min',
-            type=int,
-            default=EngineArgs.ngram_prompt_lookup_min,
-            help='Min size of window for ngram prompt lookup in speculative '
-            'decoding.')
-
-        parser.add_argument(
-            '--spec-decoding-acceptance-method',
-            type=str,
-            default=EngineArgs.spec_decoding_acceptance_method,
-            choices=['rejection_sampler', 'typical_acceptance_sampler'],
-            help='Specify the acceptance method to use during draft token '
-            'verification in speculative decoding. Two types of acceptance '
-            'routines are supported: '
-            '1) RejectionSampler which does not allow changing the '
-            'acceptance rate of draft tokens, '
-            '2) TypicalAcceptanceSampler which is configurable, allowing for '
-            'a higher acceptance rate at the cost of lower quality, '
-            'and vice versa.')
-
-        parser.add_argument(
-            '--typical-acceptance-sampler-posterior-threshold',
-            type=float,
-            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
-            help='Set the lower bound threshold for the posterior '
-            'probability of a token to be accepted. This threshold is '
-            'used by the TypicalAcceptanceSampler to make sampling decisions '
-            'during speculative decoding.')
-
-        parser.add_argument(
-            '--typical-acceptance-sampler-posterior-alpha',
-            type=float,
-            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
-            help='A scaling factor for the entropy-based threshold for token '
-            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
-            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
-            'i.e. 0.3')
-
-        parser.add_argument(
-            '--disable-logprobs-during-spec-decoding',
-            action=StoreBoolean,
-            default=EngineArgs.disable_logprobs_during_spec_decoding,
-            nargs="?",
-            const="True",
-            help='If set to True, token log probabilities are not returned '
-            'during speculative decoding. If set to False, log probabilities '
-            'are returned according to the settings in SamplingParams. If '
-            'not specified, it defaults to True. Disabling log probabilities '
-            'during speculative decoding reduces latency by skipping logprob '
-            'calculation in proposal sampling, target sampling, and after '
-            'accepted tokens are determined.')
-
         parser.add_argument('--model-loader-extra-config',
                             type=nullable_str,
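To make the `type=json.loads` change concrete, here is a small standalone sketch using plain `argparse` (vLLM's actual parser class differs, but the conversion behaviour is the same idea): the flag value now arrives as a Python dict rather than a raw string.

```python
import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument('--speculative-config',
                    type=json.loads,
                    default=None,
                    help='The configurations for speculative decoding.'
                    ' Should be a JSON string.')

args = parser.parse_args([
    '--speculative-config',
    '{"model": "JackFram/llama-68m", "num_speculative_tokens": 3}',
])

# argparse applied json.loads to the raw string, so this is already a dict.
assert args.speculative_config == {
    "model": "JackFram/llama-68m",
    "num_speculative_tokens": 3,
}
```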
@@ -1221,58 +1094,14 @@
         This function utilizes `speculative_config` to create a
         SpeculativeConfig object. The `speculative_config` can either be
         provided as a JSON string input via CLI arguments or directly as a
-        dictionary from the engine. If `speculative_config` is not set, this
-        function will attempt to construct a configuration dictionary using
-        certain parameters, which are scheduled for deprecation in the next
-        release. Note that in next releases, `speculative_config` must be
-        provided, and the deprecated standalone speculative-related parameters
-        will be removed.
+        dictionary from the engine.
         """
         if self.speculative_config is None:
-            if (self.speculative_model is None
-                    and self.num_speculative_tokens is None):
-                return None
-
-            # TODO(Shangming): Deprecate this way of setting SpeculativeConfig,
-            # only allow '--speculative-config' after next release
-            logger.warning_once(
-                "Please use '--speculative-config' to set all configurations "
-                "related to speculative decoding. The current method of "
-                "specifying the model through '--speculative-model' and "
-                "adding related parameters (e.g., '--num-speculative-tokens') "
-                "separately will be deprecated in the next release.")
-
-            spec_config_dict = {
-                "model": self.speculative_model,
-                "quantization": self.speculative_model_quantization,
-                "max_model_len": self.speculative_max_model_len,
-                "draft_tensor_parallel_size":
-                self.speculative_draft_tensor_parallel_size,
-                "num_speculative_tokens": self.num_speculative_tokens,
-                "disable_mqa_scorer": self.speculative_disable_mqa_scorer,
-                "disable_by_batch_size":
-                self.speculative_disable_by_batch_size,
-                "prompt_lookup_max": self.ngram_prompt_lookup_max,
-                "prompt_lookup_min": self.ngram_prompt_lookup_min,
-                "acceptance_method": self.spec_decoding_acceptance_method,
-                "posterior_threshold":
-                self.typical_acceptance_sampler_posterior_threshold,
-                "posterior_alpha":
-                self.typical_acceptance_sampler_posterior_alpha,
-                "disable_logprobs": self.disable_logprobs_during_spec_decoding,
-            }
-
-            self.speculative_config = spec_config_dict
-        else:
-            if isinstance(self.speculative_config, str):
-                import ast
-                self.speculative_config = ast.literal_eval(
-                    self.speculative_config)
+            return None
+
         # Note(Shangming): These parameters are not obtained from the cli arg
         # '--speculative-config' and must be passed in when creating the engine
         # config.
-        assert isinstance(self.speculative_config, dict)
         self.speculative_config.update({
             "target_model_config": target_model_config,
             "target_parallel_config": target_parallel_config,
@@ -1638,11 +1467,15 @@
             return False

         # Only Ngram speculative decoding so far.
-        if (self.speculative_model is not None
-                or self.num_speculative_tokens is not None):
+        is_ngram_enabled = False
+        if self.speculative_config is not None:
             # This is supported but experimental (handled below).
-            if self.speculative_model in ("ngram", "[ngram]"):
-                pass
+            if (("method" in self.speculative_config
+                 and self.speculative_config["method"] in ("ngram", "[ngram]"))
+                    or
+                ("model" in self.speculative_config and
+                 self.speculative_config["model"] in ("ngram", "[ngram]"))):
+                is_ngram_enabled = True
             else:
                 _raise_or_fallback(feature_name="Speculative Decoding",
                                    recommend_to_remove=False)
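A condensed sketch of the predicate this hunk introduces, pulled out as a standalone function for readability (the function name is mine; in the diff the checks are written inline in the V1-support gate):

```python
def ngram_requested(speculative_config) -> bool:
    # ngram can be selected either via the new "method" key or, for
    # backwards compatibility, via "model".
    if speculative_config is None:
        return False
    if speculative_config.get("method") in ("ngram", "[ngram]"):
        return True
    return speculative_config.get("model") in ("ngram", "[ngram]")

assert ngram_requested({"method": "ngram", "num_speculative_tokens": 5})
assert ngram_requested({"model": "[ngram]", "prompt_lookup_max": 3})
assert not ngram_requested({"model": "JackFram/llama-68m"})
assert not ngram_requested(None)
```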
@@ -1691,8 +1524,7 @@
             return False

         # ngram is supported on V1, but off by default for now.
-        if self.speculative_model in (
-                "ngram", "[ngram]") and _warn_or_fallback("ngram"):
+        if is_ngram_enabled and _warn_or_fallback("ngram"):
             return False

         # Non-CUDA is supported on V1, but off by default for now.
@@ -1721,7 +1553,7 @@
         is_gpu = current_platform.is_cuda()
         use_sliding_window = (model_config.get_sliding_window()
                               is not None)
-        use_spec_decode = self.speculative_model is not None
+        use_spec_decode = self.speculative_config is not None

         if (is_gpu and not use_sliding_window and not use_spec_decode
                 and not self.enable_lora