[V1][Minor] Move cascade attn logic outside _prepare_inputs (#12943)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent e31498bdcb
commit 4ea48fb35c
@@ -476,67 +476,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.device, non_blocking=True).long()
 
-        # Prepare for cascade attention if needed.
-        common_prefix_len = (scheduler_output.num_common_prefix_blocks *
-                             self.block_size)
-        if common_prefix_len == 0:
-            # Common case.
-            use_cascade = False
-        else:
-            # NOTE(woosuk): Cascade attention uses two attention kernels: one
-            # for the common prefix and the other for the rest. For the first
-            # kernel, we concatenate all the query tokens (possibly from
-            # different requests) and treat them as if they are from the same
-            # request. Then, we use bi-directional attention to process the
-            # common prefix in the KV cache. Importantly, this means that the
-            # first kernel does not do any masking.
-
-            # Consider the following example:
-            # Request 1's input query: [D, E, X]
-            # Request 1's kv cache: [A, B, C, D, E, X]
-            # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
-            # Request 2's input query: [E, Y]
-            # Request 2's kv cache: [A, B, C, D, E, Y]
-            # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
-
-            # If we use [A, B, C, D, E] as the common prefix, then the
-            # first kernel will compute the bi-directional attention between
-            # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
-            # However, this is wrong because D in Request 1 should not attend to
-            # E in the common prefix (i.e., we need masking).
-            # To avoid this, [A, B, C, D] should be the common prefix.
-            # That is, the common prefix should be capped by the minimum
-            # num_computed_tokens among the requests, and plus one to include
-            # the first token of the query.
-
-            # In practice, we use [A, B, C] as the common prefix, instead of
-            # [A, B, C, D] (i.e., the common prefix is capped by the minimum
-            # num_computed_tokens, without plus one).
-            # This is because of an implementation detail: We want to always
-            # use two kernels for cascade attention. Let's imagine:
-            # Request 3's input query: [D]
-            # Request 3's kv cache: [A, B, C, D]
-            # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
-            # If we use [A, B, C, D] as the common prefix for Request 1-3,
-            # then Request 3 will be processed only by the first kernel,
-            # and the second kernel will get an empty input. While this is not
-            # a fundamental problem, our current implementation does not support
-            # this case.
-            common_prefix_len = min(
-                common_prefix_len,
-                self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
-            # common_prefix_len should be a multiple of the block size.
-            common_prefix_len = (common_prefix_len // self.block_size *
-                                 self.block_size)
-            use_cascade = FlashAttentionBackend.use_cascade_attention(
-                common_prefix_len=common_prefix_len,
-                query_lens=num_scheduled_tokens,
-                num_query_heads=self.num_query_heads,
-                num_kv_heads=self.num_kv_heads,
-                use_alibi=False,  # FIXME
-                use_sliding_window=self.sliding_window is not None,
-                num_sms=self.num_sms,
-            )
+        common_prefix_len = self._compute_cascade_attn_prefix_len(
+            num_scheduled_tokens,
+            scheduler_output.num_common_prefix_blocks,
+        )
+        use_cascade = common_prefix_len > 0
         if use_cascade:
             # TODO: Optimize.
             cu_prefix_query_lens = torch.tensor(
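The capping rule spelled out in the NOTE comment above comes down to two lines of arithmetic. The following is a minimal standalone sketch, assuming a block size of 1 so the token-level example maps directly onto lengths; the variable names are illustrative and not taken from the code:

# Worked example for Requests 1 and 2 above (hypothetical standalone snippet).
num_computed_tokens = [3, 4]  # Request 1: [A, B, C]; Request 2: [A, B, C, D]
block_size = 1                # block size 1 keeps the token-level example exact
common_prefix_len = 5         # [A, B, C, D, E] reported as shared by the scheduler

# Cap at the minimum num_computed_tokens (no "+ 1"), so every request keeps at
# least one query token for the second kernel.
common_prefix_len = min(common_prefix_len, min(num_computed_tokens))  # 3
# Truncate to a multiple of the block size.
common_prefix_len = common_prefix_len // block_size * block_size      # 3
assert common_prefix_len == 3  # i.e., [A, B, C]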
@@ -581,6 +525,90 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         logits_indices = query_start_loc[1:] - 1
         return attn_metadata, logits_indices
 
+    def _compute_cascade_attn_prefix_len(
+        self,
+        num_scheduled_tokens: np.ndarray,
+        num_common_prefix_blocks: int,
+    ) -> int:
+        """Compute the length of the common prefix for cascade attention.
+
+        NOTE(woosuk): The common prefix length returned by this function
+        represents the length used specifically for cascade attention, not the
+        actual number of tokens shared between requests. When cascade attention
+        is disabled (use_cascade=False), this function returns 0 even if
+        requests share common tokens. Additionally, the common prefix length is
+        truncated to a multiple of the block size and may be further truncated
+        due to implementation details explained below.
+
+        Args:
+            num_scheduled_tokens: Number of tokens scheduled per request.
+            num_common_prefix_blocks: Number of shared KV cache blocks.
+
+        Returns:
+            int: Length of common prefix in tokens.
+        """
+        common_prefix_len = num_common_prefix_blocks * self.block_size
+        if common_prefix_len == 0:
+            # Common case.
+            return 0
+
+        # NOTE(woosuk): Cascade attention uses two attention kernels: one
+        # for the common prefix and the other for the rest. For the first
+        # kernel, we concatenate all the query tokens (possibly from
+        # different requests) and treat them as if they are from the same
+        # request. Then, we use bi-directional attention to process the
+        # common prefix in the KV cache. Importantly, this means that the
+        # first kernel does not do any masking.
+
+        # Consider the following example:
+        # Request 1's input query: [D, E, X]
+        # Request 1's kv cache: [A, B, C, D, E, X]
+        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
+        # Request 2's input query: [E, Y]
+        # Request 2's kv cache: [A, B, C, D, E, Y]
+        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
+
+        # If we use [A, B, C, D, E] as the common prefix, then the
+        # first kernel will compute the bi-directional attention between
+        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
+        # However, this is wrong because D in Request 1 should not attend to
+        # E in the common prefix (i.e., we need masking).
+        # To avoid this, [A, B, C, D] should be the common prefix.
+        # That is, the common prefix should be capped by the minimum
+        # num_computed_tokens among the requests, and plus one to include
+        # the first token of the query.
+
+        # In practice, we use [A, B, C] as the common prefix, instead of
+        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
+        # num_computed_tokens, without plus one).
+        # This is because of an implementation detail: We want to always
+        # use two kernels for cascade attention. Let's imagine:
+        # Request 3's input query: [D]
+        # Request 3's kv cache: [A, B, C, D]
+        # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
+        # If we use [A, B, C, D] as the common prefix for Request 1-3,
+        # then Request 3 will be processed only by the first kernel,
+        # and the second kernel will get an empty input. While this is not
+        # a fundamental problem, our current implementation does not support
+        # this case.
+        num_reqs = len(num_scheduled_tokens)
+        common_prefix_len = min(
+            common_prefix_len,
+            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
+        # common_prefix_len should be a multiple of the block size.
+        common_prefix_len = (common_prefix_len // self.block_size *
+                             self.block_size)
+        use_cascade = FlashAttentionBackend.use_cascade_attention(
+            common_prefix_len=common_prefix_len,
+            query_lens=num_scheduled_tokens,
+            num_query_heads=self.num_query_heads,
+            num_kv_heads=self.num_kv_heads,
+            use_alibi=False,  # FIXME
+            use_sliding_window=self.sliding_window is not None,
+            num_sms=self.num_sms,
+        )
+        return common_prefix_len if use_cascade else 0
+
     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
         mrope_pos_ptr = 0
         num_reqs = self.input_batch.num_reqs