[Misc] Add chunked-prefill support on FlashInfer. (#9781)
parent 81f09cfd80
commit 9ff4511e43
@@ -11,6 +11,8 @@ from contextlib import nullcontext
 
 import pytest
 
+from tests.kernels.utils import override_backend_env_variable
+
 from ..models.utils import check_logprobs_close, check_outputs_equal
 from ..utils import multi_gpu_test
 
@@ -28,6 +30,7 @@ MODELS = [
 # NOTE: Increasing this in this suite will fail CI because we currently cannot
 # reset distributed env properly. Use a value > 1 just when you test.
 @pytest.mark.parametrize("tensor_parallel_size", [1])
+@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
 def test_models(
     hf_runner,
     vllm_runner,
@@ -38,11 +41,15 @@ def test_models(
     chunked_prefill_token_size: int,
     enforce_eager: bool,
     tensor_parallel_size: int,
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
     """
     Checks exact-match decode between the HuggingFace model and the vLLM
     runner with chunked prefill.
     """
+    override_backend_env_variable(monkeypatch, attention_backend)
+
     max_num_seqs = chunked_prefill_token_size
     max_num_batched_tokens = chunked_prefill_token_size
 
@@ -71,13 +78,18 @@ def test_models(
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
 @pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
 def test_models_distributed(
     hf_runner,
     vllm_runner,
     example_prompts,
     model: str,
     distributed_executor_backend: str,
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
+    override_backend_env_variable(monkeypatch, attention_backend)
+
     if (model == "meta-llama/Llama-2-7b-hf"
             and distributed_executor_backend == "ray"):
         # test ray adag
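For context on what the parametrization exercises: the backend names map onto vLLM's VLLM_ATTENTION_BACKEND environment variable (which override_backend_env_variable sets via monkeypatch), and chunking is driven by the max_num_batched_tokens budget. A rough sketch of running FlashInfer with chunked prefill outside the test harness; the model name and token budget below are illustrative, not taken from the diff:

import os

# What override_backend_env_variable(monkeypatch, "FLASHINFER") effectively does.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",      # illustrative small model
    enable_chunked_prefill=True,
    max_num_batched_tokens=16,      # prompts longer than this get chunked
    max_num_seqs=16,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)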
@@ -268,6 +268,11 @@ class FlashInferMetadata(AttentionMetadata):
     # Maximum sequence length among prefill batch. 0 if there are decoding
     # requests only.
     max_prefill_seq_len: int
+    # Number of query tokens for each request in the batch.
+    # Currently, we require that all requests have the same number of query
+    # tokens during the decoding phase. When speculative decoding is enabled,
+    # decode_query_len might be greater than 1. In all other cases, it is 1.
+    decode_query_len: Optional[int] = 1
 
     use_cuda_graph: bool = True
 
@@ -335,6 +340,7 @@ class FlashInferMetadata(AttentionMetadata):
             assert self.paged_kv_last_page_len is not None
             assert self.block_table_bound is not None
             assert self.seq_lens_tensor is not None
+            self.query_start_loc = self.query_start_loc[:self.num_prefills + 1]
             batch_size = self.query_start_loc.shape[0] - 1
             assert batch_size >= 0
             # We will use flash attention for profiling to
@@ -349,11 +355,13 @@ class FlashInferMetadata(AttentionMetadata):
             self.paged_kv_indices = self.paged_kv_indices.to(self.device)
             self.prefill_wrapper.end_forward()
             self.prefill_wrapper.begin_forward(
-                self.query_start_loc, self.paged_kv_indptr,
-                self.paged_kv_indices, self.paged_kv_last_page_len,
+                self.query_start_loc,
+                self.paged_kv_indptr[:self.num_prefills + 1],
+                self.paged_kv_indices,
+                self.paged_kv_last_page_len[:self.num_prefills],
                 self.num_qo_heads, self.num_kv_heads, self.head_dim,
                 self.page_size)
-        else:
+        if self.num_decode_tokens > 0:
             assert self.paged_kv_indices is not None
             assert self.paged_kv_indptr is not None
             assert self.paged_kv_last_page_len is not None
@@ -370,9 +378,9 @@ class FlashInferMetadata(AttentionMetadata):
             assert self.decode_wrapper is not None
             self.decode_wrapper.end_forward()
             self.decode_wrapper.begin_forward(
-                self.paged_kv_indptr,
+                self.paged_kv_indptr[self.num_prefills:],
                 self.paged_kv_indices,
-                self.paged_kv_last_page_len,
+                self.paged_kv_last_page_len[self.num_prefills:],
                 self.num_qo_heads,
                 self.num_kv_heads,
                 self.head_dim,
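The slicing above is what lets one mixed batch drive both FlashInfer wrappers: requests [0, num_prefills) go to the prefill wrapper and the remaining requests go to the decode wrapper. A minimal sketch of the idea; the tensors and counts below are made up for illustration:

import torch

# Hypothetical batch: 2 prefill requests followed by 3 decode requests.
num_prefills = 2
# CSR-style pointer into paged_kv_indices, one entry per request boundary.
paged_kv_indptr = torch.tensor([0, 3, 7, 9, 10, 12], dtype=torch.int32)
paged_kv_last_page_len = torch.tensor([16, 5, 8, 1, 11], dtype=torch.int32)

# The prefill wrapper sees the first num_prefills requests...
prefill_indptr = paged_kv_indptr[:num_prefills + 1]            # [0, 3, 7]
prefill_last_page_len = paged_kv_last_page_len[:num_prefills]  # [16, 5]

# ...and the decode wrapper sees the rest. The indptr slice keeps the shared
# boundary entry (7) so the decode side still starts at the right offset.
decode_indptr = paged_kv_indptr[num_prefills:]                 # [7, 9, 10, 12]
decode_last_page_len = paged_kv_last_page_len[num_prefills:]   # [8, 1, 11]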
@@ -397,21 +405,14 @@ class FlashInferMetadata(AttentionMetadata):
 
     @property
     def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
-        # Currently chunked prefill is not supported
-        if self.num_decode_tokens == 0:
-            assert self.num_prefills > 0
-            return self
-
+        if self.num_prefills == 0:
+            return None
+        return self
 
     @property
     def decode_metadata(self) -> Optional["FlashInferMetadata"]:
-        # Currently chunked prefill is not supported
-        if self.num_prefills > 0:
-            assert self.num_decode_tokens == 0, (
-                "Chunked prefill is not supported with flashinfer yet.")
+        if self.num_decode_tokens == 0:
             return None
 
         return self
 
     def advance_step(self,
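With chunked prefill, a single batch's metadata can now expose both views at once: prefill_metadata is non-None whenever num_prefills > 0 and decode_metadata whenever num_decode_tokens > 0, rather than the two being mutually exclusive. A toy illustration of the property logic above (a standalone stand-in class, not the real FlashInferMetadata):

from typing import Optional

class _ToyMetadata:
    """Minimal stand-in that mirrors the property logic in the diff."""

    def __init__(self, num_prefills: int, num_decode_tokens: int):
        self.num_prefills = num_prefills
        self.num_decode_tokens = num_decode_tokens

    @property
    def prefill_metadata(self) -> Optional["_ToyMetadata"]:
        return None if self.num_prefills == 0 else self

    @property
    def decode_metadata(self) -> Optional["_ToyMetadata"]:
        return None if self.num_decode_tokens == 0 else self

# A mixed (chunked-prefill) batch now yields both views instead of asserting.
mixed = _ToyMetadata(num_prefills=2, num_decode_tokens=3)
assert mixed.prefill_metadata is not None and mixed.decode_metadata is not None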
@@ -599,11 +600,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
+        decode_query_len = max(query_lens[self.num_prefills:], default=1)
 
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size
+            num_decode_tokens = batch_size - self.num_prefill_tokens
 
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
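decode_query_len is derived from the per-request query lengths of the decode portion of the batch; with ordinary decoding every decode request contributes one query token, so the max is 1. A quick illustration with made-up query_lens values:

# Hypothetical per-request query lengths: two chunked prefills then three decodes.
query_lens = [128, 57, 1, 1, 1]
num_prefills = 2

decode_query_len = max(query_lens[num_prefills:], default=1)
print(decode_query_len)  # 1; would exceed 1 only with speculative decoding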
@@ -689,6 +691,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             self.runner.kv_cache_dtype, self.runner.model_config.dtype)
 
         return FlashInferMetadata(
+            decode_query_len=decode_query_len,
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
             num_prefill_tokens=self.num_prefill_tokens,
@@ -811,12 +814,6 @@ def unified_flash_infer(
     key = key.view(-1, num_kv_heads, head_size)
     value = value.view(-1, num_kv_heads, head_size)
 
-    if attn_metadata.num_prefill_tokens > 0:
-        assert attn_metadata.num_decode_tokens == 0, (
-            "Chunked prefill is not supported with flashinfer yet.")
-    if attn_metadata.num_decode_tokens > 0:
-        assert attn_metadata.num_prefill_tokens == 0, (
-            "Chunked prefill is not supported with flashinfer yet.")
     if kv_cache.numel() > 0:
         # Use the same reshape and cache kernel as flash attention.
         ops.reshape_and_cache_flash(
@@ -836,14 +833,33 @@ def unified_flash_infer(
                 kv_cache_dtype)
             kv_cache = kv_cache.view(torch_dtype)
 
+    num_prefill_tokens = attn_metadata.num_prefill_tokens
+    num_decode_tokens = attn_metadata.num_decode_tokens
+    assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
+        f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}"  # noqa
+    assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
+        f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}"  # noqa
     query = query.contiguous()  # Flashinfer requires query to be contiguous
+    # Query for decode. KV is not needed because it is already cached.
+    # QKV for prefill.
+    decode_query = query[num_prefill_tokens:]
+    query = query[:num_prefill_tokens]
+
+    key = key[:num_prefill_tokens]
+    value = value[:num_prefill_tokens]
+
+    assert query.shape[0] == num_prefill_tokens
+    assert decode_query.shape[0] == num_decode_tokens
+
+    prefill_output: Optional[torch.Tensor] = None
+    decode_output: Optional[torch.Tensor] = None
     if prefill_meta := attn_metadata.prefill_metadata:
         # We will use flash attention for prefill
         # when kv_cache is not provided.
         # This happens when vllm runs the profiling to
         # determine the number of blocks.
         if kv_cache.numel() == 0:
-            output = flash_attn_varlen_func(
+            prefill_output = flash_attn_varlen_func(
                 q=query,
                 k=key,
                 v=value,
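The flattened token dimension is ordered prefill-first, so a plain slice at num_prefill_tokens is enough to route each half of the batch to its wrapper. A small sketch of that layout; the shapes are invented for illustration:

import torch

# Suppose a chunked-prefill step batches 6 prefill tokens and 3 decode tokens,
# flattened as [prefill tokens..., decode tokens...] along dim 0.
num_prefill_tokens, num_decode_tokens = 6, 3
num_heads, head_size = 8, 64
query = torch.randn(num_prefill_tokens + num_decode_tokens, num_heads, head_size)

decode_query = query[num_prefill_tokens:]    # fed to the decode wrapper
prefill_query = query[:num_prefill_tokens]   # fed to the prefill wrapper

assert prefill_query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens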
@@ -859,18 +875,34 @@ def unified_flash_infer(
         else:
             assert prefill_meta is not None
             assert prefill_meta.prefill_wrapper is not None
-            output = prefill_meta.prefill_wrapper.forward(
+            prefill_output = prefill_meta.prefill_wrapper.forward(
                 query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True)
-    else:
+    if decode_meta := attn_metadata.decode_metadata:
         assert attn_metadata.decode_metadata is not None
         assert attn_metadata.decode_metadata.decode_wrapper is not None
-        output = attn_metadata.decode_metadata.decode_wrapper.forward(
-            query,
+        decode_output = attn_metadata.decode_metadata.decode_wrapper.forward(
+            decode_query,
             kv_cache,
             sm_scale=softmax_scale,
             logits_soft_cap=logits_soft_cap,
             k_scale=k_scale,
             v_scale=v_scale)
 
+    if prefill_output is None and decode_output is not None:
+        # Decode only batch.
+        output, num_tokens = decode_output, num_decode_tokens
+    elif decode_output is None and prefill_output is not None:
+        # Prefill only batch.
+        output, num_tokens = prefill_output, num_prefill_tokens
+    else:
+        # Chunked prefill batch does not work with speculative decoding in
+        # FlashInfer backend, so the query length for decode should be 1.
+        assert prefill_output is not None
+        assert decode_output is not None
+        assert decode_meta is not None
+        assert decode_meta.decode_query_len == 1
+        decode_output = decode_output.squeeze(1)
+        output = torch.cat([prefill_output, decode_output], dim=0)
+
     return output.view(num_tokens, hidden_size)
 
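The final branch is the actual chunked-prefill path: both halves were computed, so the per-token outputs are concatenated back in the same prefill-first order the batch was split in. A toy version of that recombination, with shapes invented for illustration:

import torch

num_prefill_tokens, num_decode_tokens, hidden_size = 6, 3, 512
prefill_output = torch.randn(num_prefill_tokens, 8, 64)
decode_output = torch.randn(num_decode_tokens, 1, 8, 64)  # decode_query_len == 1

# Drop the length-1 query dimension, then restore the original token order.
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
num_tokens = num_prefill_tokens + num_decode_tokens
print(output.view(num_tokens, hidden_size).shape)  # torch.Size([9, 512])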