[V1][Spec Decode] Respect prompt_lookup_max (#15348)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent 6ebaf9ac71
commit b9bd76ca14
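In short: the V1 n-gram drafter previously passed only prompt_lookup_min into NgramProposer.propose, so prompt_lookup_max was ignored. After this change, propose takes both bounds and tries the longest allowed n-gram first, falling back toward the minimum. The snippet below is a rough, pure-Python sketch of that fallback, not vLLM code: find_ngram_continuation is a naive stand-in for the numba-jitted _find_subarray_kmp helper used in the real implementation.

from typing import Optional

import numpy as np


def find_ngram_continuation(tokens: np.ndarray, n: int,
                            k: int) -> Optional[np.ndarray]:
    # Naive stand-in for vLLM's KMP-based _find_subarray_kmp: match the last
    # n tokens against earlier positions and return on the first match.
    suffix = tokens[-n:]
    for start in range(len(tokens) - n):
        if np.array_equal(tokens[start:start + n], suffix):
            # Return up to k tokens that followed the match (fewer if the
            # match sits near the end of the context).
            return tokens[start + n:start + n + k]
    return None


def propose(tokens: np.ndarray, min_n: int, max_n: int,
            k: int) -> Optional[np.ndarray]:
    # Prefer the longest allowed n-gram; fall back toward min_n, as the
    # updated NgramProposer.propose now does.
    for n in range(max_n, min_n - 1, -1):
        result = find_ngram_continuation(tokens, n, k)
        if result is not None:
            return result
    return None


print(propose(np.array([1, 2, 3, 4, 2, 3]), min_n=2, max_n=3, k=4))  # [4 2 3]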

@@ -2,7 +2,8 @@
 import numpy as np
 
-from vllm.v1.spec_decode.ngram_proposer import (_find_subarray_kmp,
+from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
+                                                _find_subarray_kmp,
                                                 _kmp_lps_array)
@@ -35,3 +36,53 @@ def test_find_subarray_kmp():
     # Return on the first match
     np.testing.assert_array_equal(_find_subarray_kmp(X, 1, 3),
                                   np.array([6, 2, 3]))
+
+
+def test_ngram_proposer():
+    proposer = NgramProposer()
+
+    # No match.
+    result = proposer.propose(
+        context_token_ids=np.array([1, 2, 3, 4, 5]),
+        min_n=2,
+        max_n=2,
+        k=2,
+    )
+    assert result is None
+
+    # No match for 4-gram.
+    result = proposer.propose(
+        context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]),
+        min_n=4,
+        max_n=4,
+        k=2,
+    )
+    assert result is None
+
+    # No match for 4-gram but match for 3-gram.
+    result = proposer.propose(
+        context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]),
+        min_n=3,
+        max_n=4,
+        k=2,
+    )
+    assert np.array_equal(result, np.array([4, 1]))
+
+    # Match for both 4-gram and 3-gram.
+    # In this case, the proposer should return the 4-gram match.
+    result = proposer.propose(
+        context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]),
+        min_n=3,
+        max_n=4,
+        k=2,
+    )
+    assert np.array_equal(result, np.array([1, 2]))  # Not [5, 1]
+
+    # Match for 2-gram and 3-gram, but not 4-gram.
+    result = proposer.propose(
+        context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]),
+        min_n=2,
+        max_n=4,
+        k=2,
+    )
+    assert np.array_equal(result, np.array([1, 2]))  # Not [5, 2]

@@ -10,7 +10,8 @@ class NgramProposer:
     def propose(
         self,
         context_token_ids: np.ndarray,
-        n: int,
+        min_n: int,
+        max_n: int,
         k: int,
     ) -> Optional[np.ndarray]:
         """Proposes the next sequence of tokens based on n-gram pattern
@@ -21,7 +22,8 @@ class NgramProposer:
         Args:
             context_token_ids: Numpy array of token IDs representing the
                 context sequence.
-            n: Length of the n-gram to match.
+            min_n: Minimum length of the n-gram to match.
+            max_n: Maximum length of the n-gram to match.
             k: Number of tokens follow the match. If there are less
                 than k tokens follow the match, we will return
                 the maximum amount of tokens until the end.
@@ -32,14 +34,21 @@ class NgramProposer:
             None: If no matching n-gram pattern is found.
 
         Example:
-            If context_token_ids = [1,2,3,4,2,3], n = 2, and k = 4:
+            If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and
+            k = 4:
+            - The last 3 (= max_n) tokens [4,2,3] cannot find a match.
             - The last 2 tokens [2,3] will be matched against the previous
               4 tokens [1,2,3,4].
            - Finding a match of [2,3] would return the tokens that
              followed that pattern. Here we will return [4,2,3] because
              we only have three tokens after the match.
         """
-        return _find_subarray_kmp(context_token_ids, n, k)
+        # TODO(woosuk): Optimize this.
+        for n in range(max_n, min_n - 1, -1):
+            result = _find_subarray_kmp(context_token_ids, n, k)
+            if result is not None:
+                return result
+        return None
 
 
 @jit(nopython=True)
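A quick way to sanity-check the docstring example above, assuming a vLLM build that includes this change is installed (the module path comes from the import in the test file):

import numpy as np

from vllm.v1.spec_decode.ngram_proposer import NgramProposer

out = NgramProposer().propose(
    context_token_ids=np.array([1, 2, 3, 4, 2, 3], dtype=np.int32),
    min_n=2,
    max_n=3,
    k=4,
)
# The 3-gram [4, 2, 3] has no earlier occurrence, so the proposer falls back
# to the 2-gram [2, 3] and returns the tokens that followed it.
print(out)  # [4 2 3]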

@@ -160,6 +160,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.drafter.propose(
                 np.zeros(1024, dtype=np.int32),
                 self.speculative_config.prompt_lookup_min,
+                self.speculative_config.prompt_lookup_max,
                 self.speculative_config.num_speculative_tokens,
             )
             self.rejection_sampler = RejectionSampler()
@@ -1155,6 +1156,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 drafter_output = self.drafter.propose(
                     self.input_batch.token_ids_cpu[i, :end_idx],
                     self.speculative_config.prompt_lookup_min,
+                    self.speculative_config.prompt_lookup_max,
                     self.speculative_config.num_speculative_tokens,
                 )
                 if drafter_output is None or len(drafter_output) == 0:
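For context, a rough sketch of the call-site wiring, not vLLM code: SimpleNamespace stands in for the real SpeculativeConfig, and the snippet assumes an installed vLLM build with this change (plus its numba dependency). It shows that both lookup bounds now reach the drafter, so the longest matching n-gram wins.

from types import SimpleNamespace

import numpy as np

from vllm.v1.spec_decode.ngram_proposer import NgramProposer

# Stand-in for vLLM's SpeculativeConfig; only the fields used above.
speculative_config = SimpleNamespace(
    prompt_lookup_min=2,
    prompt_lookup_max=4,
    num_speculative_tokens=2,
)

drafter = NgramProposer()
draft = drafter.propose(
    np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4], dtype=np.int32),
    speculative_config.prompt_lookup_min,
    speculative_config.prompt_lookup_max,
    speculative_config.num_speculative_tokens,
)
# [1 2]: the 4-gram match wins. Before this change, propose always used
# prompt_lookup_min as the n-gram length, which would yield [5 1] here.
print(draft)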