From b9bd76ca14bca7ddb912e65c8aa45a8044869f1b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 23 Mar 2025 10:41:44 -0700 Subject: [PATCH] [V1][Spec Decode] Respect prompt_lookup_max (#15348) Signed-off-by: Woosuk Kwon --- tests/v1/spec_decode/test_ngram.py | 53 ++++++++++++++++++++++++++- vllm/v1/spec_decode/ngram_proposer.py | 17 +++++++-- vllm/v1/worker/gpu_model_runner.py | 2 + 3 files changed, 67 insertions(+), 5 deletions(-) diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 2c2e125a..a81b4897 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -2,7 +2,8 @@ import numpy as np -from vllm.v1.spec_decode.ngram_proposer import (_find_subarray_kmp, +from vllm.v1.spec_decode.ngram_proposer import (NgramProposer, + _find_subarray_kmp, _kmp_lps_array) @@ -35,3 +36,53 @@ def test_find_subarray_kmp(): # Return on the first match np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3), np.array([6, 2, 3])) + + +def test_ngram_proposer(): + proposer = NgramProposer() + + # No match. + result = proposer.propose( + context_token_ids=np.array([1, 2, 3, 4, 5]), + min_n=2, + max_n=2, + k=2, + ) + assert result is None + + # No match for 4-gram. + result = proposer.propose( + context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]), + min_n=4, + max_n=4, + k=2, + ) + assert result is None + + # No match for 4-gram but match for 3-gram. + result = proposer.propose( + context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]), + min_n=3, + max_n=4, + k=2, + ) + assert np.array_equal(result, np.array([4, 1])) + + # Match for both 4-gram and 3-gram. + # In this case, the proposer should return the 4-gram match. + result = proposer.propose( + context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]), + min_n=3, + max_n=4, + k=2, + ) + assert np.array_equal(result, np.array([1, 2])) # Not [5, 1] + + # Match for 2-gram and 3-gram, but not 4-gram. + result = proposer.propose( + context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]), + min_n=2, + max_n=4, + k=2, + ) + assert np.array_equal(result, np.array([1, 2])) # Not [5, 2] diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 33289d05..0bef349e 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -10,7 +10,8 @@ class NgramProposer: def propose( self, context_token_ids: np.ndarray, - n: int, + min_n: int, + max_n: int, k: int, ) -> Optional[np.ndarray]: """Proposes the next sequence of tokens based on n-gram pattern @@ -21,7 +22,8 @@ class NgramProposer: Args: context_token_ids: Numpy array of token IDs representing the context sequence. - n: Length of the n-gram to match. + min_n: Minimum length of the n-gram to match. + max_n: Maximum length of the n-gram to match. k: Number of tokens follow the match. If there are less than k tokens follow the match, we will return the maximum amount of tokens until the end. @@ -32,14 +34,21 @@ class NgramProposer: None: If no matching n-gram pattern is found. Example: - If context_token_ids = [1,2,3,4,2,3], n = 2, and k = 4: + If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and + k = 4: + - The last 3 (= max_n) tokens [4,2,3] cannot find a match. - The last 2 tokens [2,3] will be matched against the previous 4 tokens [1,2,3,4]. - Finding a match of [2,3] would return the tokens that followed that pattern. Here we will return [4,2,3] because we only have three tokens after the match. """ - return _find_subarray_kmp(context_token_ids, n, k) + # TODO(woosuk): Optimize this. + for n in range(max_n, min_n - 1, -1): + result = _find_subarray_kmp(context_token_ids, n, k) + if result is not None: + return result + return None @jit(nopython=True) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 46ad5239..66358d96 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -160,6 +160,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.drafter.propose( np.zeros(1024, dtype=np.int32), self.speculative_config.prompt_lookup_min, + self.speculative_config.prompt_lookup_max, self.speculative_config.num_speculative_tokens, ) self.rejection_sampler = RejectionSampler() @@ -1155,6 +1156,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): drafter_output = self.drafter.propose( self.input_batch.token_ids_cpu[i, :end_idx], self.speculative_config.prompt_lookup_min, + self.speculative_config.prompt_lookup_max, self.speculative_config.num_speculative_tokens, ) if drafter_output is None or len(drafter_output) == 0: