[V1][Minor] Move cascade attn logic outside _prepare_inputs (#12943)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent e31498bdcb
commit 4ea48fb35c
@@ -476,67 +476,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.device, non_blocking=True).long()
 
         # Prepare for cascade attention if needed.
-        common_prefix_len = (scheduler_output.num_common_prefix_blocks *
-                             self.block_size)
-        if common_prefix_len == 0:
-            # Common case.
-            use_cascade = False
-        else:
-            # NOTE(woosuk): Cascade attention uses two attention kernels: one
-            # for the common prefix and the other for the rest. For the first
-            # kernel, we concatenate all the query tokens (possibly from
-            # different requests) and treat them as if they are from the same
-            # request. Then, we use bi-directional attention to process the
-            # common prefix in the KV cache. Importantly, this means that the
-            # first kernel does not do any masking.
-
-            # Consider the following example:
-            # Request 1's input query: [D, E, X]
-            # Request 1's kv cache: [A, B, C, D, E, X]
-            # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
-            # Request 2's input query: [E, Y]
-            # Request 2's kv cache: [A, B, C, D, E, Y]
-            # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
-
-            # If we use [A, B, C, D, E] as the common prefix, then the
-            # first kernel will compute the bi-directional attention between
-            # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
-            # However, this is wrong because D in Request 1 should not attend to
-            # E in the common prefix (i.e., we need masking).
-            # To avoid this, [A, B, C, D] should be the common prefix.
-            # That is, the common prefix should be capped by the minimum
-            # num_computed_tokens among the requests, and plus one to include
-            # the first token of the query.
-
-            # In practice, we use [A, B, C] as the common prefix, instead of
-            # [A, B, C, D] (i.e., the common prefix is capped by the minimum
-            # num_computed_tokens, without plus one).
-            # This is because of an implementation detail: We want to always
-            # use two kernels for cascade attention. Let's imagine:
-            # Request 3's input query: [D]
-            # Request 3's kv cache: [A, B, C, D]
-            # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
-            # If we use [A, B, C, D] as the common prefix for Request 1-3,
-            # then Request 3 will be processed only by the first kernel,
-            # and the second kernel will get an empty input. While this is not
-            # a fundamental problem, our current implementation does not support
-            # this case.
-            common_prefix_len = min(
-                common_prefix_len,
-                self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
-            # common_prefix_len should be a multiple of the block size.
-            common_prefix_len = (common_prefix_len // self.block_size *
-                                 self.block_size)
-            use_cascade = FlashAttentionBackend.use_cascade_attention(
-                common_prefix_len=common_prefix_len,
-                query_lens=num_scheduled_tokens,
-                num_query_heads=self.num_query_heads,
-                num_kv_heads=self.num_kv_heads,
-                use_alibi=False,  # FIXME
-                use_sliding_window=self.sliding_window is not None,
-                num_sms=self.num_sms,
-            )
-
+        common_prefix_len = self._compute_cascade_attn_prefix_len(
+            num_scheduled_tokens,
+            scheduler_output.num_common_prefix_blocks,
+        )
+        use_cascade = common_prefix_len > 0
         if use_cascade:
             # TODO: Optimize.
             cu_prefix_query_lens = torch.tensor(
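
Review aside: the capping rule spelled out in the removed comments above can be summarized in a few lines. This is a standalone sketch, not part of the diff; the helper name `cap_common_prefix_len` is hypothetical, and a block_size of 1 is assumed so the token-level example from the comments maps directly onto blocks.

```python
# Minimal sketch (not vLLM code) of the capping rule described in the comments.
def cap_common_prefix_len(num_common_prefix_blocks: int,
                          num_computed_tokens: list[int],
                          block_size: int) -> int:
    # Raw shared length implied by the shared KV cache blocks.
    common_prefix_len = num_common_prefix_blocks * block_size
    # Cap by the minimum num_computed_tokens among the requests (without the
    # "+1") so the second kernel never receives an empty input.
    common_prefix_len = min(common_prefix_len, min(num_computed_tokens))
    # Round down to a multiple of the block size.
    return common_prefix_len // block_size * block_size

# Request 1 has computed [A, B, C] (3 tokens); Request 2 has [A, B, C, D] (4).
# With block_size=1 and 5 shared blocks ([A, B, C, D, E]), the capped prefix
# is min(5, 3) = 3 tokens, i.e. [A, B, C], matching the comment.
assert cap_common_prefix_len(5, [3, 4], block_size=1) == 3
```
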
@@ -581,6 +525,90 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         logits_indices = query_start_loc[1:] - 1
         return attn_metadata, logits_indices
 
+    def _compute_cascade_attn_prefix_len(
+        self,
+        num_scheduled_tokens: np.ndarray,
+        num_common_prefix_blocks: int,
+    ) -> int:
+        """Compute the length of the common prefix for cascade attention.
+
+        NOTE(woosuk): The common prefix length returned by this function
+        represents the length used specifically for cascade attention, not the
+        actual number of tokens shared between requests. When cascade attention
+        is disabled (use_cascade=False), this function returns 0 even if
+        requests share common tokens. Additionally, the common prefix length is
+        truncated to a multiple of the block size and may be further truncated
+        due to implementation details explained below.
+
+        Args:
+            num_scheduled_tokens: Number of tokens scheduled per request.
+            num_common_prefix_blocks: Number of shared KV cache blocks.
+
+        Returns:
+            int: Length of common prefix in tokens.
+        """
+        common_prefix_len = num_common_prefix_blocks * self.block_size
+        if common_prefix_len == 0:
+            # Common case.
+            return 0
+
+        # NOTE(woosuk): Cascade attention uses two attention kernels: one
+        # for the common prefix and the other for the rest. For the first
+        # kernel, we concatenate all the query tokens (possibly from
+        # different requests) and treat them as if they are from the same
+        # request. Then, we use bi-directional attention to process the
+        # common prefix in the KV cache. Importantly, this means that the
+        # first kernel does not do any masking.
+
+        # Consider the following example:
+        # Request 1's input query: [D, E, X]
+        # Request 1's kv cache: [A, B, C, D, E, X]
+        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
+        # Request 2's input query: [E, Y]
+        # Request 2's kv cache: [A, B, C, D, E, Y]
+        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
+
+        # If we use [A, B, C, D, E] as the common prefix, then the
+        # first kernel will compute the bi-directional attention between
+        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
+        # However, this is wrong because D in Request 1 should not attend to
+        # E in the common prefix (i.e., we need masking).
+        # To avoid this, [A, B, C, D] should be the common prefix.
+        # That is, the common prefix should be capped by the minimum
+        # num_computed_tokens among the requests, and plus one to include
+        # the first token of the query.
+
+        # In practice, we use [A, B, C] as the common prefix, instead of
+        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
+        # num_computed_tokens, without plus one).
+        # This is because of an implementation detail: We want to always
+        # use two kernels for cascade attention. Let's imagine:
+        # Request 3's input query: [D]
+        # Request 3's kv cache: [A, B, C, D]
+        # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
+        # If we use [A, B, C, D] as the common prefix for Request 1-3,
+        # then Request 3 will be processed only by the first kernel,
+        # and the second kernel will get an empty input. While this is not
+        # a fundamental problem, our current implementation does not support
+        # this case.
+        num_reqs = len(num_scheduled_tokens)
+        common_prefix_len = min(
+            common_prefix_len,
+            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
+        # common_prefix_len should be a multiple of the block size.
+        common_prefix_len = (common_prefix_len // self.block_size *
+                             self.block_size)
+        use_cascade = FlashAttentionBackend.use_cascade_attention(
+            common_prefix_len=common_prefix_len,
+            query_lens=num_scheduled_tokens,
+            num_query_heads=self.num_query_heads,
+            num_kv_heads=self.num_kv_heads,
+            use_alibi=False,  # FIXME
+            use_sliding_window=self.sliding_window is not None,
+            num_sms=self.num_sms,
+        )
+        return common_prefix_len if use_cascade else 0
+
     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
         mrope_pos_ptr = 0
         num_reqs = self.input_batch.num_reqs
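
Review aside: the new method's docstring promises that 0 is returned whenever cascade attention is not used, even if requests do share tokens. The sketch below mirrors that contract outside of vLLM; `compute_cascade_attn_prefix_len` and `should_cascade` are hypothetical stand-ins (the real gate is the backend-specific FlashAttentionBackend.use_cascade_attention heuristic).

```python
import numpy as np

def compute_cascade_attn_prefix_len(
        num_scheduled_tokens: np.ndarray,
        num_common_prefix_blocks: int,
        num_computed_tokens: np.ndarray,
        block_size: int,
        should_cascade=lambda prefix_len: prefix_len > 0) -> int:
    # Raw shared length implied by the shared KV cache blocks.
    common_prefix_len = num_common_prefix_blocks * block_size
    if common_prefix_len == 0:
        # Common case: no shared blocks, cascade attention is not used.
        return 0
    num_reqs = len(num_scheduled_tokens)
    # Cap by the minimum num_computed_tokens, then round down to a multiple
    # of the block size, as in the method's comments.
    common_prefix_len = min(common_prefix_len,
                            int(num_computed_tokens[:num_reqs].min()))
    common_prefix_len = common_prefix_len // block_size * block_size
    # The caller derives use_cascade = (returned length > 0).
    return common_prefix_len if should_cascade(common_prefix_len) else 0

# Two requests, 16-token blocks, 2 shared blocks, 20/36 computed tokens:
# the prefix is capped to 20 and rounded down to 16.
print(compute_cascade_attn_prefix_len(np.array([3, 2]), 2,
                                      np.array([20, 36]), block_size=16))  # 16

# If the (assumed) heuristic declines cascade attention, 0 is returned even
# though the requests share tokens, so the caller sets use_cascade = False.
assert compute_cascade_attn_prefix_len(np.array([3, 2]), 2,
                                       np.array([20, 36]), block_size=16,
                                       should_cascade=lambda _: False) == 0
```
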