[Misc] Add chunked-prefill support on FlashInfer. (#9781)
parent 81f09cfd80 · commit 9ff4511e43
@@ -11,6 +11,8 @@ from contextlib import nullcontext
 
 import pytest
 
+from tests.kernels.utils import override_backend_env_variable
+
 from ..models.utils import check_logprobs_close, check_outputs_equal
 from ..utils import multi_gpu_test
 
@@ -28,6 +30,7 @@ MODELS = [
 # NOTE: Increasing this in this suite will fail CI because we currently cannot
 # reset distributed env properly. Use a value > 1 just when you test.
 @pytest.mark.parametrize("tensor_parallel_size", [1])
+@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
 def test_models(
     hf_runner,
     vllm_runner,
@@ -38,11 +41,15 @@ def test_models(
     chunked_prefill_token_size: int,
     enforce_eager: bool,
     tensor_parallel_size: int,
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
     """
     Checks exact match decode between huggingface model and vllm runner with
     chunked prefill.
     """
+    override_backend_env_variable(monkeypatch, attention_backend)
+
     max_num_seqs = chunked_prefill_token_size
     max_num_batched_tokens = chunked_prefill_token_size
 
@@ -71,13 +78,18 @@ def test_models(
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
 @pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
 def test_models_distributed(
     hf_runner,
     vllm_runner,
     example_prompts,
     model: str,
     distributed_executor_backend: str,
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
+    override_backend_env_variable(monkeypatch, attention_backend)
+
     if (model == "meta-llama/Llama-2-7b-hf"
             and distributed_executor_backend == "ray"):
         # test ray adag
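The two tests above now run once per attention backend. A minimal sketch of exercising the same configuration by hand, assuming override_backend_env_variable simply points vLLM at the chosen backend through the VLLM_ATTENTION_BACKEND environment variable; the model name and token budgets below are arbitrary:

    import os

    # Assumed mechanism: select the attention backend before the engine is built.
    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # or "FLASH_ATTN"

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="facebook/opt-125m",    # small model, smoke test only
        enable_chunked_prefill=True,  # split long prompts across scheduler steps
        max_num_batched_tokens=16,    # tiny budget so prefill actually chunks
        max_num_seqs=16,
        enforce_eager=True,
    )
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))
    print(outputs[0].outputs[0].text)
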
@@ -268,6 +268,11 @@ class FlashInferMetadata(AttentionMetadata):
     # Maximum sequence length among prefill batch. 0 if there are decoding
     # requests only.
     max_prefill_seq_len: int
+    # Number of query tokens for each request in the batch.
+    # Currently, we require that all requests have the same number of query
+    # tokens during the decoding phase. When speculative decoding is enabled,
+    # decode_query_len might be greater than 1. In all other cases, it is 1.
+    decode_query_len: Optional[int] = 1
 
     use_cuda_graph: bool = True
 
@@ -335,6 +340,7 @@ class FlashInferMetadata(AttentionMetadata):
             assert self.paged_kv_last_page_len is not None
             assert self.block_table_bound is not None
             assert self.seq_lens_tensor is not None
+            self.query_start_loc = self.query_start_loc[:self.num_prefills + 1]
             batch_size = self.query_start_loc.shape[0] - 1
             assert batch_size >= 0
             # We will use flash attention for profiling to
@@ -349,11 +355,13 @@ class FlashInferMetadata(AttentionMetadata):
                 self.paged_kv_indices = self.paged_kv_indices.to(self.device)
                 self.prefill_wrapper.end_forward()
                 self.prefill_wrapper.begin_forward(
-                    self.query_start_loc, self.paged_kv_indptr,
-                    self.paged_kv_indices, self.paged_kv_last_page_len,
+                    self.query_start_loc,
+                    self.paged_kv_indptr[:self.num_prefills + 1],
+                    self.paged_kv_indices,
+                    self.paged_kv_last_page_len[:self.num_prefills],
                     self.num_qo_heads, self.num_kv_heads, self.head_dim,
                     self.page_size)
-        else:
+        if self.num_decode_tokens > 0:
             assert self.paged_kv_indices is not None
             assert self.paged_kv_indptr is not None
             assert self.paged_kv_last_page_len is not None
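For intuition, a toy example of the slicing above, with made-up numbers: requests are ordered with prefills first, query_start_loc and paged_kv_indptr hold cumulative offsets with one entry per request plus one, and paged_kv_last_page_len has one entry per request, so the prefill wrapper only needs the leading slices:

    import torch

    num_prefills = 2                                   # followed by 2 decode requests
    query_start_loc = torch.tensor([0, 128, 185, 186, 187])
    paged_kv_indptr = torch.tensor([0, 3, 5, 6, 7])    # offsets into paged_kv_indices
    paged_kv_last_page_len = torch.tensor([13, 9, 1, 5])

    prefill_qo_indptr = query_start_loc[:num_prefills + 1]         # tensor([0, 128, 185])
    prefill_kv_indptr = paged_kv_indptr[:num_prefills + 1]         # tensor([0, 3, 5])
    prefill_last_page_len = paged_kv_last_page_len[:num_prefills]  # tensor([13, 9])
    decode_kv_indptr = paged_kv_indptr[num_prefills:]              # tensor([5, 6, 7])
    assert prefill_kv_indptr[-1] == decode_kv_indptr[0]  # contiguous split of the page list
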
@@ -370,9 +378,9 @@ class FlashInferMetadata(AttentionMetadata):
                 assert self.decode_wrapper is not None
                 self.decode_wrapper.end_forward()
                 self.decode_wrapper.begin_forward(
-                    self.paged_kv_indptr,
+                    self.paged_kv_indptr[self.num_prefills:],
                     self.paged_kv_indices,
-                    self.paged_kv_last_page_len,
+                    self.paged_kv_last_page_len[self.num_prefills:],
                     self.num_qo_heads,
                     self.num_kv_heads,
                     self.head_dim,
@@ -397,21 +405,14 @@ class FlashInferMetadata(AttentionMetadata):
 
     @property
     def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
-        # Currently chunked prefill is not supported
-        if self.num_decode_tokens == 0:
-            assert self.num_prefills > 0
-            return self
-
-        return None
+        if self.num_prefills == 0:
+            return None
+        return self
 
     @property
     def decode_metadata(self) -> Optional["FlashInferMetadata"]:
-        # Currently chunked prefill is not supported
-        if self.num_prefills > 0:
-            assert self.num_decode_tokens == 0, (
-                "Chunked prefill is not supported with flashinfer yet.")
+        if self.num_decode_tokens == 0:
             return None
-
         return self
 
     def advance_step(self,
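A minimal sketch of the new contract, using a stand-in class rather than the real FlashInferMetadata: for a mixed chunked-prefill batch both properties can now return the same object, whereas the old code treated prefill and decode as mutually exclusive:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ToyMetadata:  # stand-in, not the real FlashInferMetadata
        num_prefills: int
        num_decode_tokens: int

        @property
        def prefill_metadata(self) -> Optional["ToyMetadata"]:
            return self if self.num_prefills > 0 else None

        @property
        def decode_metadata(self) -> Optional["ToyMetadata"]:
            return self if self.num_decode_tokens > 0 else None

    mixed = ToyMetadata(num_prefills=2, num_decode_tokens=3)
    assert mixed.prefill_metadata is mixed
    assert mixed.decode_metadata is mixed
    decode_only = ToyMetadata(num_prefills=0, num_decode_tokens=3)
    assert decode_only.prefill_metadata is None
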
@@ -599,11 +600,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
+        decode_query_len = max(query_lens[self.num_prefills:], default=1)
 
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size
+            num_decode_tokens = batch_size - self.num_prefill_tokens
 
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
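With made-up query lengths, the new decode_query_len computation looks like this; the slice after num_prefills covers only the decode requests:

    query_lens = [128, 57, 1, 1, 1]   # 2 prefill chunks, then 3 decode requests
    num_prefills = 2
    decode_query_len = max(query_lens[num_prefills:], default=1)
    assert decode_query_len == 1      # would exceed 1 only with speculative decoding
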
@@ -689,6 +691,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             self.runner.kv_cache_dtype, self.runner.model_config.dtype)
 
         return FlashInferMetadata(
+            decode_query_len=decode_query_len,
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
             num_prefill_tokens=self.num_prefill_tokens,
@@ -811,12 +814,6 @@ def unified_flash_infer(
     key = key.view(-1, num_kv_heads, head_size)
     value = value.view(-1, num_kv_heads, head_size)
 
-    if attn_metadata.num_prefill_tokens > 0:
-        assert attn_metadata.num_decode_tokens == 0, (
-            "Chunked prefill is not supported with flashinfer yet.")
-    if attn_metadata.num_decode_tokens > 0:
-        assert attn_metadata.num_prefill_tokens == 0, (
-            "Chunked prefill is not supported with flashinfer yet.")
     if kv_cache.numel() > 0:
         # Use the same reshape and cache kernel as flash attention.
         ops.reshape_and_cache_flash(
@@ -836,14 +833,33 @@ def unified_flash_infer(
                 kv_cache_dtype)
             kv_cache = kv_cache.view(torch_dtype)
 
+    num_prefill_tokens = attn_metadata.num_prefill_tokens
+    num_decode_tokens = attn_metadata.num_decode_tokens
+    assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
+        f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
+    assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
+        f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
     query = query.contiguous()  # Flashinfer requires query to be contiguous
+    # Query for decode. KV is not needed because it is already cached.
+    # QKV for prefill.
+    decode_query = query[num_prefill_tokens:]
+    query = query[:num_prefill_tokens]
+
+    key = key[:num_prefill_tokens]
+    value = value[:num_prefill_tokens]
+
+    assert query.shape[0] == num_prefill_tokens
+    assert decode_query.shape[0] == num_decode_tokens
+
+    prefill_output: Optional[torch.Tensor] = None
+    decode_output: Optional[torch.Tensor] = None
     if prefill_meta := attn_metadata.prefill_metadata:
         # We will use flash attention for prefill
         # when kv_cache is not provided.
         # This happens when vllm runs the profiling to
         # determine the number of blocks.
         if kv_cache.numel() == 0:
-            output = flash_attn_varlen_func(
+            prefill_output = flash_attn_varlen_func(
                 q=query,
                 k=key,
                 v=value,
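A toy example of the split above, with assumed shapes: tokens are laid out as prefill tokens followed by decode tokens, so a single slice at num_prefill_tokens separates the two phases, and key/value are only kept for the prefill part because decode reads from the KV cache:

    import torch

    num_prefill_tokens, num_decode_tokens = 5, 3
    num_heads, head_size = 4, 8
    query = torch.randn(num_prefill_tokens + num_decode_tokens, num_heads, head_size)

    decode_query = query[num_prefill_tokens:]     # shape [3, 4, 8]
    prefill_query = query[:num_prefill_tokens]    # shape [5, 4, 8]
    assert prefill_query.shape[0] == num_prefill_tokens
    assert decode_query.shape[0] == num_decode_tokens
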
@@ -859,18 +875,34 @@ def unified_flash_infer(
         else:
             assert prefill_meta is not None
             assert prefill_meta.prefill_wrapper is not None
-            output = prefill_meta.prefill_wrapper.forward(
+            prefill_output = prefill_meta.prefill_wrapper.forward(
                 query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True)
-    else:
+    if decode_meta := attn_metadata.decode_metadata:
         assert attn_metadata.decode_metadata is not None
         assert attn_metadata.decode_metadata.decode_wrapper is not None
-        output = attn_metadata.decode_metadata.decode_wrapper.forward(
-            query,
+        decode_output = attn_metadata.decode_metadata.decode_wrapper.forward(
+            decode_query,
             kv_cache,
             sm_scale=softmax_scale,
             logits_soft_cap=logits_soft_cap,
             k_scale=k_scale,
             v_scale=v_scale)
+
+    if prefill_output is None and decode_output is not None:
+        # Decode only batch.
+        output, num_tokens = decode_output, num_decode_tokens
+    elif decode_output is None and prefill_output is not None:
+        # Prefill only batch.
+        output, num_tokens = prefill_output, num_prefill_tokens
+    else:
+        # Chunked prefill batch does not work with speculative decoding in
+        # FlashInfer backend, so the query length for decode should be 1.
+        assert prefill_output is not None
+        assert decode_output is not None
+        assert decode_meta is not None
+        assert decode_meta.decode_query_len == 1
+        decode_output = decode_output.squeeze(1)
+        output = torch.cat([prefill_output, decode_output], dim=0)
     return output.view(num_tokens, hidden_size)
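A toy example of the merge path for a mixed batch, with assumed shapes; the decode wrapper's output is taken to carry a length-1 query dimension per request here, matching the squeeze(1) above:

    import torch

    num_prefill_tokens, num_decode_tokens = 5, 3
    num_heads, head_size = 4, 8
    prefill_output = torch.randn(num_prefill_tokens, num_heads, head_size)
    decode_output = torch.randn(num_decode_tokens, 1, num_heads, head_size)  # decode_query_len == 1

    merged = torch.cat([prefill_output, decode_output.squeeze(1)], dim=0)
    flat = merged.view(num_prefill_tokens + num_decode_tokens, num_heads * head_size)
    assert flat.shape == (8, 32)
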