[Kernel] Flashinfer correctness fix for v0.1.3 (#7319)

This commit is contained in:
Lily Liu 2024-08-12 00:59:17 -07:00 committed by GitHub
parent 86ab567bae
commit ec2affa8ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 20 additions and 24 deletions

View File

@ -60,8 +60,6 @@ steps:
- vllm/ - vllm/
- tests/basic_correctness - tests/basic_correctness
commands: commands:
# This flashinfer installation will fail on AMD ROCm, so it is set as optional.
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl || true
- pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_basic_correctness.py
- pytest -v -s basic_correctness/test_cpu_offload.py - pytest -v -s basic_correctness/test_cpu_offload.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
@ -157,7 +155,6 @@ steps:
- vllm/ - vllm/
- tests/models - tests/models
commands: commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
- pytest -v -s models -m \"not vlm\" - pytest -v -s models -m \"not vlm\"
- label: Vision Language Models Test # 42min - label: Vision Language Models Test # 42min
@ -212,7 +209,6 @@ steps:
- vllm/attention - vllm/attention
- tests/kernels - tests/kernels
commands: commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
- pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4 parallelism: 4
@ -331,7 +327,6 @@ steps:
# NOTE: don't test llama model here, it seems hf implementation is buggy # NOTE: don't test llama model here, it seems hf implementation is buggy
# see https://github.com/vllm-project/vllm/pull/5689 for details # see https://github.com/vllm-project/vllm/pull/5689 for details
- pytest -v -s distributed/test_custom_all_reduce.py - pytest -v -s distributed/test_custom_all_reduce.py
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
- TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py - TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s -x lora/test_mixtral.py - pytest -v -s -x lora/test_mixtral.py

View File

@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.3/flashinfer-0.1.3+cu121torch2.4-cp310-cp310-linux_x86_64.whl
#################### vLLM installation IMAGE #################### #################### vLLM installation IMAGE ####################

View File

@ -117,6 +117,7 @@ class FlashInferMetadata(AttentionMetadata):
# The data type of the paged kv cache # The data type of the paged kv cache
data_type: torch.dtype = None data_type: torch.dtype = None
device: torch.device = torch.device("cuda") device: torch.device = torch.device("cuda")
is_profile_run: bool = False
def __post_init__(self): def __post_init__(self):
# Refer to # Refer to
@ -127,7 +128,6 @@ class FlashInferMetadata(AttentionMetadata):
raise ValueError( raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim,", f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.") f"received {self.head_dim}.")
self.is_profile_run = is_block_tables_empty(self.block_tables)
def begin_forward(self): def begin_forward(self):
if self.num_prefill_tokens > 0: if self.num_prefill_tokens > 0:
@ -141,23 +141,20 @@ class FlashInferMetadata(AttentionMetadata):
assert self.paged_kv_last_page_len is not None assert self.paged_kv_last_page_len is not None
batch_size = self.query_start_loc.shape[0] - 1 batch_size = self.query_start_loc.shape[0] - 1
assert batch_size >= 0 assert batch_size >= 0
# The profile run does not read kv cache. # We will use flash attention for profiling to
# Both paged_kv_indices and paged_kv_last_page_len are empty. # determine the number of blocks. Therefore,
# paged_kv_indptr is a zero tensor with size batch_size + 1. # we don't need to prepare the input for flashinfer for profile run.
if self.is_profile_run: if not self.is_profile_run:
self.paged_kv_indptr = torch.zeros(batch_size + 1,
device=self.device)
else:
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device) self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward() self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward( self.prefill_wrapper.begin_forward(
self.query_start_loc, self.paged_kv_indptr, self.query_start_loc, self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len, self.paged_kv_indices, self.paged_kv_last_page_len,
self.num_qo_heads, self.num_kv_heads, self.head_dim, self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size) self.page_size)
else: else:
if not self.use_cuda_graph: if not self.use_cuda_graph:
assert self.paged_kv_indices is not None assert self.paged_kv_indices is not None
@ -249,6 +246,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# paged_kv_last_page_len is the length of the last page of each request # paged_kv_last_page_len is the length of the last page of each request
self.paged_kv_last_page_len: List[int] = [] self.paged_kv_last_page_len: List[int] = []
self.is_profile_run: bool = False
def _add_seq_group( def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool): chunked_prefill_enabled: bool):
@ -305,6 +304,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# and paged_kv_last_page_len for profile run because we will # and paged_kv_last_page_len for profile run because we will
# create dummy inputs. # create dummy inputs.
if is_profile_run: if is_profile_run:
self.is_profile_run = is_profile_run
return return
block_table = block_tables[seq_id] block_table = block_tables[seq_id]
@ -435,7 +435,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
device=device, device=device,
data_type=kv_cache_dtype, data_type=kv_cache_dtype,
use_cuda_graph=use_captured_graph) use_cuda_graph=use_captured_graph,
is_profile_run=self.is_profile_run)
class FlashInferImpl(AttentionImpl): class FlashInferImpl(AttentionImpl):