[FEAT] [ROCm] [Embedding] Add encoder-only model support into ROCm Flash Attention to enable embedding models. (#14664)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
TJian 2025-03-13 00:31:19 +08:00 committed by GitHub
parent d9f83d6206
commit 916836bbfb
7 changed files with 118 additions and 50 deletions


@@ -561,6 +561,10 @@ set(VLLM_MOE_EXT_SRC
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu")
endif()
set_gencode_flags_for_srcs(
SRCS "${VLLM_MOE_EXT_SRC}"
CUDA_ARCHS "${CUDA_ARCHS}")


@@ -52,6 +52,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
// conditionally compiled so impl registration is in source file
#endif
}


@@ -7,6 +7,8 @@ import pytest
import torch
from transformers import AutoModelForSequenceClassification
from vllm.platforms import current_platform
@pytest.mark.parametrize(
"model",
@@ -15,14 +17,21 @@ from transformers import AutoModelForSequenceClassification
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
],
)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype",
["half"] if current_platform.is_rocm() else ["float"])
def test_classification_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
monkeypatch,
) -> None:
if current_platform.is_rocm():
# ROCm Triton FA does not currently support sliding window attention
# switch to use ROCm CK FA backend
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.classify(example_prompts)
@@ -43,4 +52,8 @@ def test_classification_models(
hf_output = torch.tensor(hf_output)
vllm_output = torch.tensor(vllm_output)
assert torch.allclose(hf_output, vllm_output, 1e-3)
# the tolerance value of 1e-2 is selected based on the
# half datatype tests in
# tests/models/embedding/language/test_embedding.py
assert torch.allclose(hf_output, vllm_output,
1e-3 if dtype == "float" else 1e-2)
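Aside (not part of the diff): the same backend switch can be applied outside the test harness by exporting the environment variable before the engine is created. A minimal sketch; the model name is only a placeholder:

import os

# ROCm's Triton flash-attention path currently lacks sliding-window support,
# so fall back to the CK flash-attention backend before vLLM picks a backend.
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "False"

from vllm import LLM

# Placeholder model name -- substitute the sequence-classification checkpoint under test.
llm = LLM(model="<your-sequence-classification-model>", dtype="half")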


@@ -6,6 +6,7 @@ Run `pytest tests/models/embedding/language/test_embedding.py`.
import pytest
from vllm.config import PoolerConfig
from vllm.platforms import current_platform
from ..utils import check_embeddings_close
@@ -18,15 +19,15 @@ from ..utils import check_embeddings_close
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
pytest.param("intfloat/multilingual-e5-small"),
pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
# [Decoder-only]
pytest.param("BAAI/bge-multilingual-gemma2",
marks=[pytest.mark.core_model]),
pytest.param("intfloat/e5-mistral-7b-instruct",
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
# [Encoder-decoder]
# [Cross-Encoder]
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
],
)
@@ -37,11 +38,19 @@ def test_models(
example_prompts,
model,
dtype: str,
monkeypatch,
) -> None:
if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm():
# ROCm Triton FA does not currently support sliding window attention
# switch to use ROCm CK FA backend
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
vllm_extra_kwargs = {}
if model == "ssmits/Qwen2-7B-Instruct-embed-base":
vllm_extra_kwargs["override_pooler_config"] = \
PoolerConfig(pooling_type="MEAN")
if model == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": True}
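Aside (not part of the diff): the per-model overrides above are forwarded by the test runner to the engine as keyword arguments. A rough sketch of the equivalent direct usage, assuming vllm.LLM accepts these engine arguments the same way the test runner does:

from vllm import LLM
from vllm.config import PoolerConfig

# Model without a built-in pooler config: declare MEAN pooling explicitly,
# mirroring the ssmits/Qwen2-7B-Instruct-embed-base branch above.
llm = LLM(model="ssmits/Qwen2-7B-Instruct-embed-base",
          override_pooler_config=PoolerConfig(pooling_type="MEAN"))

# Force causal attention via an HF config override,
# mirroring the Alibaba-NLP/gte-Qwen2-1.5B-instruct branch above.
llm = LLM(model="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
          hf_overrides={"is_causal": True})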


@@ -15,8 +15,8 @@ import vllm.config
from ....utils import RemoteOpenAIServer
# GritLM embedding implementation is only supported by XFormers backend.
pytest.mark.skipif(not importlib.util.find_spec("xformers"),
reason="GritLM requires XFormers")
pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"),
reason="GritLM requires XFormers")
MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
MAX_MODEL_LEN = 4000
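Aside (not part of the diff): this change matters because a marker object that is not attached to anything is silently discarded by pytest, so the old bare expression skipped nothing. Assigning it to the module-level pytestmark variable applies the skip to every test in the file. A minimal illustration in a hypothetical module:

import importlib.util
import pytest

# Bare expression: the marker is created and immediately dropped; tests still run.
# pytest.mark.skipif(not importlib.util.find_spec("xformers"), reason="...")

# Module-level assignment: pytest collects this and skips every test in the module.
pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"),
                                reason="GritLM requires XFormers")

def test_placeholder():
    ...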


@@ -4,10 +4,27 @@ import pytest
import torch.nn.functional as F
from transformers import AutoModelForVision2Seq
from vllm.platforms import current_platform
from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
from ..utils import check_embeddings_close
# Llava Next embedding implementation is only supported by CUDA.
# If run on ROCm, hf_model.model.resize_token_embeddings will
# cause the following error:
# RuntimeError: Calling torch.linalg.cholesky on a CUDA tensor
# requires compiling PyTorch with MAGMA. Please use PyTorch
# built with MAGMA support.
# If run on CPU, hf_model.model.resize_token_embeddings will
# cause the following error:
# RuntimeError: Calling torch.linalg.cholesky on a CPU tensor
# requires compiling PyTorch with LAPACK. Please use PyTorch
# built with LAPACK support.
pytestmark = pytest.mark.skipif(
not current_platform.is_cuda(),
reason="Llava Next model uses op that is only supported in CUDA")
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501
HF_TEXT_PROMPTS = [


@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""Attention layer ROCm GPUs."""
import itertools
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
@@ -342,28 +343,27 @@ def _get_seq_len_block_table_args(
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths
Encoder attn -> select encoder sequence lengths fields
Encoder-only attn -> select prefill sequence lengths with
bidirectional attention
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
encoder/decoder cross-attention, encoder-only
Returns:
* Appropriate sequence-lengths tensors for query and key
* Appropriate max sequence-length scalar
* Causal masking flag
'''
partial_prefix_sum = 0
if attn_type == AttentionType.ENCODER:
assert attn_metadata.encoder_seq_lens is not None
assert attn_metadata.encoder_seq_lens_tensor is not None
query_seq_start_loc = torch.tensor(
[0] + [
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.encoder_seq_lens
],
list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
device=attn_metadata.encoder_seq_lens_tensor.device,
dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
causal_mask = False
@@ -372,16 +372,29 @@ def _get_seq_len_block_table_args(
return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
query_seq_start_loc, attn_metadata.max_encoder_seq_len,
attn_metadata.encoder_seq_lens, causal_mask)
elif attn_type == AttentionType.ENCODER_ONLY:
# For encoder-only models, we use the prefill sequence lengths
assert attn_metadata.seq_lens is not None
assert attn_metadata.seq_lens_tensor is not None
query_seq_start_loc = torch.tensor(
list(itertools.accumulate([0] + attn_metadata.seq_lens)),
device=attn_metadata.seq_lens_tensor.device,
dtype=attn_metadata.seq_lens_tensor.dtype)
max_seq_len = attn_metadata.max_prefill_seq_len
# Encoder-only models typically use bidirectional attention
causal_mask = False
return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
max_seq_len, attn_metadata.seq_lens, causal_mask)
elif attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
assert attn_metadata.seq_lens is not None
assert attn_metadata.seq_lens_tensor is not None
query_seq_start_loc = torch.tensor(
[0] + [
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.seq_lens
],
list(itertools.accumulate([0] + attn_metadata.seq_lens)),
device=attn_metadata.seq_lens_tensor.device,
dtype=attn_metadata.seq_lens_tensor.dtype)
max_seq_len = attn_metadata.max_prefill_seq_len
@@ -393,21 +406,14 @@ def _get_seq_len_block_table_args(
assert attn_metadata.seq_lens is not None
assert attn_metadata.encoder_seq_lens_tensor is not None
query_start_loc = torch.tensor(
[0] + [
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.seq_lens
],
list(itertools.accumulate([0] + attn_metadata.seq_lens)),
device=attn_metadata.encoder_seq_lens_tensor.device,
dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
partial_prefix_sum = 0
assert attn_metadata.encoder_seq_lens is not None
assert attn_metadata.seq_lens_tensor is not None
key_seq_start_loc = torch.tensor(
[0] + [
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.encoder_seq_lens
],
list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
device=attn_metadata.seq_lens_tensor.device,
dtype=attn_metadata.seq_lens_tensor.dtype)
causal_mask = False
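Aside (not part of the diff): all three start-location tensors above are now built with itertools.accumulate, which produces the same running prefix sums as the previous walrus-operator comprehension. A standalone illustration with made-up sequence lengths:

import itertools

seq_lens = [3, 5, 2]  # hypothetical per-sequence token counts
# Prepending 0 yields each sequence's start offset plus the total token
# count at the end: [0, 3, 8, 10]. This is the cumulative-sequence-length
# layout consumed by the varlen flash-attention path.
cu_seqlens = list(itertools.accumulate([0] + seq_lens))
assert cu_seqlens == [0, 3, 8, 10]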
@@ -584,6 +590,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
will match encoder sequence lengths, pass encoder sequence
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len)
* ENCODER_ONLY: bidirectional attention with no KV caching;
use prefill sequence attributes
Args:
query: shape = [num_tokens, num_heads * head_size]
@@ -608,7 +616,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else:
assert value is None
if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
# Only update KV cache for decoder self-attention
# and encoder-decoder cross-attention
if self.attn_type not in [
AttentionType.ENCODER, AttentionType.ENCODER_ONLY
] and kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
@@ -632,6 +644,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if self.attn_type != AttentionType.ENCODER:
num_prefill_tokens = attn_metadata.num_prefill_tokens
elif self.attn_type == AttentionType.ENCODER_ONLY:
# For encoder-only models, all tokens are processed in one go
num_prefill_tokens = query.shape[0]
else:
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
@@ -642,8 +657,13 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# QKV for prefill.
query = query[:num_prefill_tokens]
# For encoder-only and encoder models,
# we process all tokens at once
# For decoder and encoder-decoder,
# we may need to limit key/value to prefill tokens
if key is not None and value is not None \
and self.attn_type != AttentionType.ENCODER_DECODER:
and self.attn_type not in [AttentionType.ENCODER_DECODER,
AttentionType.ENCODER_ONLY]:
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
@@ -678,7 +698,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.alibi_slopes,
query.dtype,
seq_lens,
make_attn_mask=False) # type: ignore
make_attn_mask=causal_mask) # type: ignore
out, _ = self.attn_func(
query,
key,
@@ -703,7 +723,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.alibi_slopes,
query.dtype,
attn_metadata.seq_lens,
make_attn_mask=True) # type: ignore
make_attn_mask=causal_mask) # type: ignore
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
@@ -729,7 +749,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=key_max_seq_len,
softmax_scale=self.scale,
causal=True,
causal=causal_mask,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
@@ -742,25 +762,29 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else:
output = out
else:
# prefix-enabled attention
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
layer._k_scale,
layer._v_scale,
)
if decode_meta := attn_metadata.decode_metadata:
# prefix-enabled attention -
# not applicable for encoder-only models
if self.attn_type != AttentionType.ENCODER_ONLY:
output[:
num_prefill_tokens] = PagedAttention.forward_prefix(
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
layer._k_scale,
layer._v_scale,
)
# Skip decode phase for encoder-only models
if (decode_meta := attn_metadata.decode_metadata) and (
self.attn_type != AttentionType.ENCODER_ONLY):
# Decoding run.
# Whether to use rocm custom paged attention or not
num_seqs, num_heads, head_size = decode_query.shape
@@ -885,4 +909,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
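Aside (not part of the diff): with ENCODER_ONLY handled by the ROCm flash-attention backend (bidirectional mask, no KV-cache update, no prefix or decode phase), encoder-only embedding models become usable on ROCm. A rough end-to-end sketch; the model name is illustrative and the pooling API is assumed to behave as in current vLLM releases:

from vllm import LLM

# Encoder-only (BERT-family) embedding model; name is illustrative.
llm = LLM(model="BAAI/bge-base-en-v1.5", dtype="half")

outputs = llm.encode(["The quick brown fox jumps over the lazy dog."])
# Each element of `outputs` carries the pooled embedding for one prompt.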