diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests.json b/.buildkite/nightly-benchmarks/tests/serving-tests.json
index 415171e2..13fd5aa8 100644
--- a/.buildkite/nightly-benchmarks/tests/serving-tests.json
+++ b/.buildkite/nightly-benchmarks/tests/serving-tests.json
@@ -63,10 +63,12 @@
             "model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
             "disable_log_requests": "",
             "tensor_parallel_size": 4,
-            "swap_space": 16,
-            "speculative_model": "turboderp/Qwama-0.5B-Instruct",
-            "num_speculative_tokens": 4,
-            "speculative_draft_tensor_parallel_size": 1
+            "swap_space": 16,
+            "speculative_config": {
+                "model": "turboderp/Qwama-0.5B-Instruct",
+                "num_speculative_tokens": 4,
+                "draft_tensor_parallel_size": 1
+            }
         },
         "client_parameters": {
             "model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
diff --git a/docs/source/features/spec_decode.md b/docs/source/features/spec_decode.md
index 3e1f1d5b..f16e0d96 100644
--- a/docs/source/features/spec_decode.md
+++ b/docs/source/features/spec_decode.md
@@ -52,7 +52,7 @@ python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model
 ```
 :::{warning}
-Note: Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately will be deprecated in the next release.
+Note: Please use `--speculative_config` to set all configurations related to speculative decoding. The previous method of specifying the model through `--speculative_model` and adding related parameters (e.g., `--num_speculative_tokens`) separately is now deprecated.
 :::
 Then use a client:
diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py
index baa91b2d..db5012ba 100644
--- a/examples/offline_inference/eagle.py
+++ b/examples/offline_inference/eagle.py
@@ -69,10 +69,12 @@ llm = LLM(
     max_model_len=max_model_len,
     max_num_seqs=args.max_num_seqs,
     gpu_memory_utilization=0.8,
-    speculative_model=eagle_dir,
-    num_speculative_tokens=args.num_spec_tokens,
-    speculative_draft_tensor_parallel_size=args.draft_tp,
-    speculative_max_model_len=max_model_len,
+    speculative_config={
+        "model": eagle_dir,
+        "num_speculative_tokens": args.num_spec_tokens,
+        "draft_tensor_parallel_size": args.draft_tp,
+        "max_model_len": max_model_len,
+    },
     disable_log_stats=False,
 )
diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py
index 8ddcefd9..e71c87ff 100644
--- a/tests/metrics/test_metrics.py
+++ b/tests/metrics/test_metrics.py
@@ -248,8 +248,10 @@ def test_metric_spec_decode(
             dtype=dtype,
             disable_log_stats=False,
             gpu_memory_utilization=0.4,
-            speculative_model=model,
-            num_speculative_tokens=k,
+            speculative_config={
+                "model": model,
+                "num_speculative_tokens": k,
+            },
     ) as vllm_model:
         # Force log interval to be 0 to catch all metrics.
@@ -300,8 +302,10 @@ def test_metric_spec_decode_interval(
         dtype=dtype,
         disable_log_stats=False,
         gpu_memory_utilization=0.4,
-        speculative_model=model,
-        num_speculative_tokens=k,
+        speculative_config={
+            "model": model,
+            "num_speculative_tokens": k,
+        },
         enforce_eager=True,
     )
diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py
index adb2d6d0..58705637 100644
--- a/tests/models/test_initialization.py
+++ b/tests/models/test_initialization.py
@@ -54,8 +54,10 @@ def test_can_initialize(model_arch):
         model_info.default,
         tokenizer=model_info.tokenizer,
         tokenizer_mode=model_info.tokenizer_mode,
-        speculative_model=model_info.speculative_model,
-        num_speculative_tokens=1 if model_info.speculative_model else None,
+        speculative_config={
+            "model": model_info.speculative_model,
+            "num_speculative_tokens": 1,
+        } if model_info.speculative_model else None,
         trust_remote_code=model_info.trust_remote_code,
         load_format="dummy",
         hf_overrides=hf_overrides,
diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py
index b8a2631b..b1129747 100644
--- a/tests/spec_decode/e2e/test_integration_dist_tp2.py
+++ b/tests/spec_decode/e2e/test_integration_dist_tp2.py
@@ -3,6 +3,7 @@ tensor parallelism.
 """
+import json
 from typing import Optional
 import pytest
@@ -28,14 +29,14 @@ from .conftest import run_equality_correctness_test_tp
 @pytest.mark.parametrize("test_llm_kwargs", [
     [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": "JackFram/llama-68m",
             "num_speculative_tokens": 3,
         }),
     ],
     [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": "ngram",
             "num_speculative_tokens": 5,
             "prompt_lookup_max": 3,
@@ -88,7 +89,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
     "model, test_llm_kwargs",
     [("JackFram/llama-68m", [
         "--speculative_config",
-        str({
+        json.dumps({
             "model": "JackFram/llama-68m",
             "num_speculative_tokens": 5,
             "draft_tensor_parallel_size": 1,
@@ -96,7 +97,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
     ]),
      ("ibm-granite/granite-3b-code-instruct", [
          "--speculative_config",
-         str({
+         json.dumps({
              "model": "ibm-granite/granite-3b-code-instruct",
              "num_speculative_tokens": 5,
              "draft_tensor_parallel_size": 1,
@@ -147,20 +148,20 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
 @pytest.mark.parametrize("model, test_llm_kwargs",
                          [("JackFram/llama-68m", [
                              "--speculative_config",
-                             str({
+                             json.dumps({
                                  "model": "JackFram/llama-68m",
                                  "num_speculative_tokens": 3,
                              }),
                          ]),
                          ("JackFram/llama-68m", [
                              "--speculative_config",
-                             str({
+                             json.dumps({
                                  "model": "JackFram/llama-68m",
                                  "num_speculative_tokens": 3,
                                  "draft_tensor_parallel_size": 1,
                              }),
                          ])])
-@pytest.mark.parametrize("logprobs", [None, 2])
+@pytest.mark.parametrize("logprobs", [None])
 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize("seed", [1])
 def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
@@ -171,9 +172,68 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
     """Verify spec decode works well with same and different TP size for
     the draft model with chunked prefill.
""" - if logprobs: - test_llm_kwargs.extend( - ["--disable_logprobs_during_spec_decoding", "False"]) + run_equality_correctness_test_tp(model, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=32, + seed=seed, + temperature=0.0, + logprobs=logprobs) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize( + "common_llm_kwargs", + [[ + # Skip cuda graph recording for fast test. + "--enforce-eager", + "--tensor_parallel_size", + "2", + + # precision + "--dtype", + "bfloat16", + ]]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [["--enable-chunked-prefill", "False"], + [ + "--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4", + "--max-num-seqs", "4" + ]]) +@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) +@pytest.mark.parametrize("model, test_llm_kwargs", + [("JackFram/llama-68m", [ + "--speculative_config", + json.dumps({ + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + "disable_logprobs": False, + }), + ]), + ("JackFram/llama-68m", [ + "--speculative_config", + json.dumps({ + "model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + "draft_tensor_parallel_size": 1, + "disable_logprobs": False, + }), + ])]) +@pytest.mark.parametrize("logprobs", [2]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_chunked_prefill_tp2_with_logprobs( + model, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, logprobs: Optional[int], + batch_size: int, seed: int): + """Verify spec decode works well with same and different TP size for + the draft model with chunked prefill. + """ run_equality_correctness_test_tp(model, common_llm_kwargs, per_test_common_llm_kwargs, diff --git a/tests/spec_decode/e2e/test_integration_dist_tp4.py b/tests/spec_decode/e2e/test_integration_dist_tp4.py index d42d9029..a1b7c8b4 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp4.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp4.py @@ -3,6 +3,8 @@ tensor parallelism. """ +import json + import openai import pytest import torch @@ -33,7 +35,7 @@ SPEC_MODEL = "JackFram/llama-68m" #TODO(wooyeon): add spec_draft_dp=2 case [ "--speculative_config", - str({ + json.dumps({ "model": f"{SPEC_MODEL}", "num_speculative_tokens": 5, "draft_tensor_parallel_size": 1, @@ -80,7 +82,7 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs, # Artificially limit the draft model max model len; this forces vLLM # to skip speculation once the sequences grow beyond 32-k tokens. 
"--speculative_config", - str({ + json.dumps({ "model": f"{SPEC_MODEL}", "num_speculative_tokens": 5, "max_model_len": 32, diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index d74a96fb..762c7bad 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -49,7 +49,9 @@ def test_unsupported_configs(monkeypatch): with pytest.raises(NotImplementedError): AsyncEngineArgs( model=MODEL, - speculative_model=MODEL, + speculative_config={ + "model": MODEL, + }, ).create_engine_config() with pytest.raises(NotImplementedError): diff --git a/vllm/config.py b/vllm/config.py index bd192af2..b06f1196 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2047,14 +2047,13 @@ class SpeculativeConfig: def __post_init__(self): - # Note: After next release, the method parameter will be used to - # specify the speculative method, which helps to extend the - # configuration of non-model-based proposers, and the model parameter - # will be used when the draft model or head is needed. - # If users do not specify the method, the speculative method will - # be detected automatically if possible. If the speculative method can - # not be detected, it will be considered as the draft-model-based - # method by default. + # Note: "method" is a new parameter that helps to extend the + # configuration of non-model-based proposers, and the "model" parameter + # will be used to set the draft model, eagle head, or additional weight + # when needed. If users do not specify "method", the speculative method + # will be detected automatically if possible. If the speculative method + # can not be detected, it will be considered as the "draft_model" by + # default. if self.model is None and self.num_speculative_tokens is not None: # TODO(Shangming): Refactor mtp configuration logic when supporting @@ -2069,8 +2068,8 @@ class SpeculativeConfig: raise ValueError("num_speculative_tokens was provided without " "speculative model.") - # Automatically configure the ngram method during configuration - # refactoring to ensure a smooth transition. 
+        # Automatically configure the method for ngram when "model" is used
+        # instead of "method"
         if self.method is None and (self.model is not None and
                                     self.model in ("ngram", "[ngram]")):
             self.method = "ngram"
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 1da021d7..e29b04ab 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -181,22 +181,7 @@ class EngineArgs:
     guided_decoding_backend: str = 'xgrammar'
     logits_processor_pattern: Optional[str] = None
-    speculative_config: Optional[Union[str, Dict[str, Any]]] = None
-
-    # TODO(Shangming): Deprecate these out-of-date params after next release
-    speculative_model: Optional[str] = None
-    speculative_model_quantization: Optional[str] = None
-    speculative_draft_tensor_parallel_size: Optional[int] = None
-    num_speculative_tokens: Optional[int] = None
-    speculative_disable_mqa_scorer: Optional[bool] = False
-    speculative_max_model_len: Optional[int] = None
-    speculative_disable_by_batch_size: Optional[int] = None
-    ngram_prompt_lookup_max: Optional[int] = None
-    ngram_prompt_lookup_min: Optional[int] = None
-    spec_decoding_acceptance_method: str = 'rejection_sampler'
-    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
-    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
-    disable_logprobs_during_spec_decoding: Optional[bool] = None
+    speculative_config: Optional[Dict[str, Any]] = None
     qlora_adapter_name_or_path: Optional[str] = None
     show_hidden_metrics_for_version: Optional[str] = None
@@ -793,122 +778,10 @@
             help='If set, the prefill requests can be chunked based on the '
             'max_num_batched_tokens.')
         parser.add_argument('--speculative-config',
-                            type=nullable_str,
+                            type=json.loads,
                             default=None,
                             help='The configurations for speculative decoding.'
                             ' Should be a JSON string.')
-        parser.add_argument(
-            '--speculative-model',
-            type=nullable_str,
-            default=EngineArgs.speculative_model,
-            help=
-            'The name of the draft model to be used in speculative decoding.')
-        # Quantization settings for speculative model.
-        parser.add_argument(
-            '--speculative-model-quantization',
-            type=nullable_str,
-            choices=[*QUANTIZATION_METHODS, None],
-            default=EngineArgs.speculative_model_quantization,
-            help='Method used to quantize the weights of speculative model. '
-            'If None, we first check the `quantization_config` '
-            'attribute in the model config file. If that is '
-            'None, we assume the model weights are not '
-            'quantized and use `dtype` to determine the data '
-            'type of the weights.')
-        parser.add_argument(
-            '--num-speculative-tokens',
-            type=int,
-            default=EngineArgs.num_speculative_tokens,
-            help='The number of speculative tokens to sample from '
-            'the draft model in speculative decoding.')
-        parser.add_argument(
-            '--speculative-disable-mqa-scorer',
-            action='store_true',
-            help=
-            'If set to True, the MQA scorer will be disabled in speculative '
-            ' and fall back to batch expansion')
-        parser.add_argument(
-            '--speculative-draft-tensor-parallel-size',
-            '-spec-draft-tp',
-            type=int,
-            default=EngineArgs.speculative_draft_tensor_parallel_size,
-            help='Number of tensor parallel replicas for '
-            'the draft model in speculative decoding.')
-
-        parser.add_argument(
-            '--speculative-max-model-len',
-            type=int,
-            default=EngineArgs.speculative_max_model_len,
-            help='The maximum sequence length supported by the '
-            'draft model. Sequences over this length will skip '
-            'speculation.')
-
-        parser.add_argument(
-            '--speculative-disable-by-batch-size',
-            type=int,
-            default=EngineArgs.speculative_disable_by_batch_size,
-            help='Disable speculative decoding for new incoming requests '
-            'if the number of enqueue requests is larger than this value.')
-
-        parser.add_argument(
-            '--ngram-prompt-lookup-max',
-            type=int,
-            default=EngineArgs.ngram_prompt_lookup_max,
-            help='Max size of window for ngram prompt lookup in speculative '
-            'decoding.')
-
-        parser.add_argument(
-            '--ngram-prompt-lookup-min',
-            type=int,
-            default=EngineArgs.ngram_prompt_lookup_min,
-            help='Min size of window for ngram prompt lookup in speculative '
-            'decoding.')
-
-        parser.add_argument(
-            '--spec-decoding-acceptance-method',
-            type=str,
-            default=EngineArgs.spec_decoding_acceptance_method,
-            choices=['rejection_sampler', 'typical_acceptance_sampler'],
-            help='Specify the acceptance method to use during draft token '
-            'verification in speculative decoding. Two types of acceptance '
-            'routines are supported: '
-            '1) RejectionSampler which does not allow changing the '
-            'acceptance rate of draft tokens, '
-            '2) TypicalAcceptanceSampler which is configurable, allowing for '
-            'a higher acceptance rate at the cost of lower quality, '
-            'and vice versa.')
-
-        parser.add_argument(
-            '--typical-acceptance-sampler-posterior-threshold',
-            type=float,
-            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
-            help='Set the lower bound threshold for the posterior '
-            'probability of a token to be accepted. This threshold is '
-            'used by the TypicalAcceptanceSampler to make sampling decisions '
-            'during speculative decoding.')
-
-        parser.add_argument(
-            '--typical-acceptance-sampler-posterior-alpha',
-            type=float,
-            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
-            help='A scaling factor for the entropy-based threshold for token '
-            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
-            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
-            'i.e. 0.3')
-
-        parser.add_argument(
-            '--disable-logprobs-during-spec-decoding',
-            action=StoreBoolean,
-            default=EngineArgs.disable_logprobs_during_spec_decoding,
-            nargs="?",
-            const="True",
-            help='If set to True, token log probabilities are not returned '
-            'during speculative decoding. If set to False, log probabilities '
-            'are returned according to the settings in SamplingParams. If '
-            'not specified, it defaults to True. Disabling log probabilities '
-            'during speculative decoding reduces latency by skipping logprob '
-            'calculation in proposal sampling, target sampling, and after '
-            'accepted tokens are determined.')
         parser.add_argument('--model-loader-extra-config',
                             type=nullable_str,
@@ -1221,58 +1094,14 @@
         This function utilizes `speculative_config` to create a
         SpeculativeConfig object. The `speculative_config` can either be
        provided as a JSON string input via CLI arguments or directly as a
-        dictionary from the engine. If `speculative_config` is not set, this
-        function will attempt to construct a configuration dictionary using
-        certain parameters, which are scheduled for deprecation in the next
-        release. Note that in next releases, `speculative_config` must be
-        provided, and the deprecated standalone speculative-related parameters
-        will be removed.
+        dictionary from the engine.
""" if self.speculative_config is None: - if (self.speculative_model is None - and self.num_speculative_tokens is None): - return None + return None - # TODO(Shangming): Deprecate this way of setting SpeculativeConfig, - # only allow '--speculative-config' after next release - logger.warning_once( - "Please use '--speculative-config' to set all configurations " - "related to speculative decoding. The current method of " - "specifying the model through '--speculative-model' and " - "adding related parameters (e.g., '--num-speculative-tokens') " - "separately will be deprecated in the next release.") - - spec_config_dict = { - "model": self.speculative_model, - "quantization": self.speculative_model_quantization, - "max_model_len": self.speculative_max_model_len, - "draft_tensor_parallel_size": - self.speculative_draft_tensor_parallel_size, - "num_speculative_tokens": self.num_speculative_tokens, - "disable_mqa_scorer": self.speculative_disable_mqa_scorer, - "disable_by_batch_size": - self.speculative_disable_by_batch_size, - "prompt_lookup_max": self.ngram_prompt_lookup_max, - "prompt_lookup_min": self.ngram_prompt_lookup_min, - "acceptance_method": self.spec_decoding_acceptance_method, - "posterior_threshold": - self.typical_acceptance_sampler_posterior_threshold, - "posterior_alpha": - self.typical_acceptance_sampler_posterior_alpha, - "disable_logprobs": self.disable_logprobs_during_spec_decoding, - } - - self.speculative_config = spec_config_dict - else: - if isinstance(self.speculative_config, str): - import ast - self.speculative_config = ast.literal_eval( - self.speculative_config) # Note(Shangming): These parameters are not obtained from the cli arg # '--speculative-config' and must be passed in when creating the engine # config. - - assert isinstance(self.speculative_config, dict) self.speculative_config.update({ "target_model_config": target_model_config, "target_parallel_config": target_parallel_config, @@ -1638,11 +1467,15 @@ class EngineArgs: return False # Only Ngram speculative decoding so far. - if (self.speculative_model is not None - or self.num_speculative_tokens is not None): + is_ngram_enabled = False + if self.speculative_config is not None: # This is supported but experimental (handled below). - if self.speculative_model in ("ngram", "[ngram]"): - pass + if (("method" in self.speculative_config + and self.speculative_config["method"] in ("ngram", "[ngram]")) + or + ("model" in self.speculative_config and + self.speculative_config["model"] in ("ngram", "[ngram]"))): + is_ngram_enabled = True else: _raise_or_fallback(feature_name="Speculative Decoding", recommend_to_remove=False) @@ -1691,8 +1524,7 @@ class EngineArgs: return False # ngram is supported on V1, but off by default for now. - if self.speculative_model in ( - "ngram", "[ngram]") and _warn_or_fallback("ngram"): + if is_ngram_enabled and _warn_or_fallback("ngram"): return False # Non-CUDA is supported on V1, but off by default for now. @@ -1721,7 +1553,7 @@ class EngineArgs: is_gpu = current_platform.is_cuda() use_sliding_window = (model_config.get_sliding_window() is not None) - use_spec_decode = self.speculative_model is not None + use_spec_decode = self.speculative_config is not None if (is_gpu and not use_sliding_window and not use_spec_decode and not self.enable_lora