[Kernel] Flashinfer correctness fix for v0.1.3 (#7319)
parent 86ab567bae · commit ec2affa8ae
@@ -60,8 +60,6 @@ steps:
   - vllm/
   - tests/basic_correctness
   commands:
-  # This flashinfer installation will fail on AMD ROCm, so it is set as optional.
-  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl || true
   - pytest -v -s basic_correctness/test_basic_correctness.py
   - pytest -v -s basic_correctness/test_cpu_offload.py
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
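The removed install step above was marked optional ("|| true") because the FlashInfer wheel fails to install on AMD ROCm; after this change the CI steps presumably rely on the v0.1.3 wheel baked into the Docker image (see the Dockerfile hunk further down). As a hedged, repo-independent sketch of the pattern an optional dependency like this implies — flashinfer_available and pick_attention_backend are hypothetical names, not vLLM APIs:

import importlib.util
from typing import Optional

def flashinfer_available() -> bool:
    # find_spec returns None when the package cannot be imported,
    # e.g. on a ROCm machine where the CUDA wheel was never installed.
    return importlib.util.find_spec("flashinfer") is not None

def pick_attention_backend(requested: Optional[str] = None) -> str:
    # Honor an explicit request (mirrors VLLM_ATTENTION_BACKEND in the CI
    # commands); otherwise prefer FlashInfer only when it is importable.
    if requested:
        return requested
    return "FLASHINFER" if flashinfer_available() else "XFORMERS"

if __name__ == "__main__":
    print(pick_attention_backend())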
@@ -157,7 +155,6 @@ steps:
   - vllm/
   - tests/models
   commands:
-  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
   - pytest -v -s models -m \"not vlm\"
 
 - label: Vision Language Models Test # 42min
@@ -212,7 +209,6 @@ steps:
   - vllm/attention
   - tests/kernels
   commands:
-  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
   - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
   parallelism: 4
 
@@ -331,7 +327,6 @@ steps:
   # NOTE: don't test llama model here, it seems hf implementation is buggy
   # see https://github.com/vllm-project/vllm/pull/5689 for details
   - pytest -v -s distributed/test_custom_all_reduce.py
-  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
   - TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py
   - pytest -v -s -x lora/test_mixtral.py
 
@@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
     python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
 
 RUN --mount=type=cache,target=/root/.cache/pip \
-    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
+    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.3/flashinfer-0.1.3+cu121torch2.4-cp310-cp310-linux_x86_64.whl
 #################### vLLM installation IMAGE ####################
 
 
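The hunk above (its RUN context suggests the image build file) bumps the pinned FlashInfer wheel from v0.1.2 to v0.1.3. A hedged sketch for confirming which wheel actually landed in the built image; importlib.metadata is standard library, and the printed version string is an expectation, not a guarantee:

from importlib.metadata import PackageNotFoundError, version

try:
    # Expected to report something like 0.1.3+cu121torch2.4 after the bump.
    print("flashinfer", version("flashinfer"))
except PackageNotFoundError:
    print("flashinfer is not installed")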
@@ -117,6 +117,7 @@ class FlashInferMetadata(AttentionMetadata):
     # The data type of the paged kv cache
     data_type: torch.dtype = None
     device: torch.device = torch.device("cuda")
+    is_profile_run: bool = False
 
     def __post_init__(self):
         # Refer to
@@ -127,7 +128,6 @@ class FlashInferMetadata(AttentionMetadata):
             raise ValueError(
                 f"Only {supported_head_sizes} are supported for head_dim,",
                 f"received {self.head_dim}.")
-        self.is_profile_run = is_block_tables_empty(self.block_tables)
 
     def begin_forward(self):
         if self.num_prefill_tokens > 0:
@@ -141,23 +141,20 @@ class FlashInferMetadata(AttentionMetadata):
             assert self.paged_kv_last_page_len is not None
             batch_size = self.query_start_loc.shape[0] - 1
             assert batch_size >= 0
-            # The profile run does not read kv cache.
-            # Both paged_kv_indices and paged_kv_last_page_len are empty.
-            # paged_kv_indptr is a zero tensor with size batch_size + 1.
-            if self.is_profile_run:
-                self.paged_kv_indptr = torch.zeros(batch_size + 1,
-                                                   device=self.device)
-            else:
+            # We will use flash attention for profiling to
+            # determine the number of blocks. Therefore,
+            # we don't need to prepare the input for flashinfer for profile run.
+            if not self.is_profile_run:
                 self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
                 self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                     self.device)
                 self.paged_kv_indices = self.paged_kv_indices.to(self.device)
                 self.prefill_wrapper.end_forward()
                 self.prefill_wrapper.begin_forward(
                     self.query_start_loc, self.paged_kv_indptr,
                     self.paged_kv_indices, self.paged_kv_last_page_len,
                     self.num_qo_heads, self.num_kv_heads, self.head_dim,
                     self.page_size)
         else:
             if not self.use_cuda_graph:
                 assert self.paged_kv_indices is not None
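In short, the hunk above stops preparing FlashInfer's prefill wrapper during the memory-profiling run, which now uses flash attention and has no real paged KV cache to index. A minimal, self-contained sketch of that guard, with a hypothetical _FakePrefillWrapper standing in for the real wrapper:

from dataclasses import dataclass, field
from typing import List

@dataclass
class _FakePrefillWrapper:
    calls: List[str] = field(default_factory=list)

    def end_forward(self) -> None:
        self.calls.append("end_forward")

    def begin_forward(self, *args) -> None:
        self.calls.append("begin_forward")

def prepare_prefill(wrapper: _FakePrefillWrapper, is_profile_run: bool) -> None:
    # Mirrors the guard installed above: all FlashInfer preparation is
    # skipped when this is the profile run.
    if not is_profile_run:
        wrapper.end_forward()
        wrapper.begin_forward()

w = _FakePrefillWrapper()
prepare_prefill(w, is_profile_run=True)
assert w.calls == []  # profile run: FlashInfer is left untouched
prepare_prefill(w, is_profile_run=False)
assert w.calls == ["end_forward", "begin_forward"]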
@@ -249,6 +246,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         # paged_kv_last_page_len is the length of the last page of each request
         self.paged_kv_last_page_len: List[int] = []
 
+        self.is_profile_run: bool = False
+
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool):
@@ -305,6 +304,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         # and paged_kv_last_page_len for profile run because we will
         # create dummy inputs.
         if is_profile_run:
+            self.is_profile_run = is_profile_run
             return
 
         block_table = block_tables[seq_id]
@@ -435,7 +435,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             query_start_loc=query_start_loc,
             device=device,
             data_type=kv_cache_dtype,
-            use_cuda_graph=use_captured_graph)
+            use_cuda_graph=use_captured_graph,
+            is_profile_run=self.is_profile_run)
 
 
 class FlashInferImpl(AttentionImpl):
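Taken together, the builder hunks record the profile-run fact where it is known (in _add_seq_group) and thread it into the metadata through the new is_profile_run keyword, replacing the earlier guess made from empty block tables in __post_init__. A hedged sketch of that flow with simplified stand-in classes (Builder and Meta are hypothetical names, not the vLLM types):

from dataclasses import dataclass

@dataclass
class Meta:
    is_profile_run: bool = False  # mirrors the new FlashInferMetadata field

class Builder:
    def __init__(self) -> None:
        self.is_profile_run = False

    def add_seq_group(self, is_profile_run: bool) -> None:
        # The builder knows whether it is adding dummy profile-run inputs,
        # so it records the fact here instead of the metadata inferring it
        # later from empty block tables.
        if is_profile_run:
            self.is_profile_run = True

    def build(self) -> Meta:
        # ...and passes the flag through when constructing the metadata.
        return Meta(is_profile_run=self.is_profile_run)

b = Builder()
b.add_seq_group(is_profile_run=True)
assert b.build().is_profile_run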