[ROCm] Enable chunked prefill/paged attention in MLA on ROCm (#14316)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
parent 4a754fcf15
commit d9f83d6206
@@ -1327,21 +1327,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                                [0, q.shape[-1] - v.shape[-1]],
                                value=0)
 
-        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
-            attn_output, attn_softmax_lse = self.triton_fa_func(
-                q,
-                k,
-                v_padded,
-                None,
-                prefill_metadata.query_start_loc,
-                prefill_metadata.context_chunk_cu_seq_lens[i],
-                prefill_metadata.max_query_len,
-                prefill_metadata.context_chunk_max_seq_lens[i],
-                False,  # causal
-                self.scale,
-                None,  # attn_mask is None unless applying ALiBi mask
-            )
-        elif is_vllm_fa:
+        if is_vllm_fa:
             attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
                 q=q,
                 k=k,
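The chunked-context path above now always routes through a kernel that returns the per-query softmax log-sum-exp (attn_softmax_lse); the Triton branch is dropped because the partial outputs computed over each context chunk must later be merged, and that merge needs the LSE. A minimal sketch of the standard LSE merge, assuming outputs shaped [num_tokens, num_heads, head_dim] and LSEs shaped [num_tokens, num_heads] (illustrative only, not the vLLM kernel):

import torch

def merge_attn_partials(o1, lse1, o2, lse2):
    # Combine partial attention results computed over disjoint KV chunks.
    # lse1/lse2 are log-sum-exp of the attention logits per query head.
    lse = torch.logaddexp(lse1, lse2)           # merged normalizer
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)    # weight of chunk 1
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)    # weight of chunk 2
    return w1 * o1 + w2 * o2, lse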
@@ -1416,7 +1402,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
 
-        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context:
             output = self.triton_fa_func(
                 q,
                 k,
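Here the Triton flash-attention fast path is kept, but gated on not has_context: that kernel returns only the output tensor, so it can serve a request only when there is no cached prefix context whose partial results would need to be merged in afterwards. A hypothetical restatement of the dispatch (helper name invented for illustration):

def pick_prefill_attn_path(is_hip, use_triton_fa, has_context):
    # Triton FA returns no softmax LSE, so it is only safe when there is
    # nothing to merge, i.e. the request has no prefix context.
    if is_hip and use_triton_fa and not has_context:
        return "triton_fa"
    return "flash_attn_varlen"  # LSE-returning path; merge-safe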
@@ -3450,9 +3450,9 @@ class VllmConfig:
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
 
         if self.model_config and self.model_config.use_mla and \
-            not current_platform.is_cuda():
+            not (current_platform.is_cuda() or current_platform.is_rocm()):
             logger.info(
-                "MLA is enabled on a non-cuda platform; forcing chunked "
+                "MLA is enabled on a non-GPU platform; forcing chunked "
                 "prefill and prefix caching to be disabled.")
             self.scheduler_config.enable_chunked_prefill = False
             self.scheduler_config.chunked_prefill_enabled = False
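The config hunk widens the allow-list: ROCm now joins CUDA as a platform where MLA models keep chunked prefill and prefix caching enabled. A standalone restatement of the new condition (hypothetical helper, not vLLM API):

def mla_forces_chunked_prefill_off(use_mla, is_cuda, is_rocm):
    # True when MLA must fall back to non-chunked prefill: any platform
    # other than CUDA or ROCm after this change.
    return use_mla and not (is_cuda or is_rocm)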