[FEAT] [ROCm] [Embedding] Add encoder-only model support into ROCm Flash Attention to enable embedding models. (#14664)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Parent: d9f83d6206
Commit: 916836bbfb
@@ -561,6 +561,10 @@ set(VLLM_MOE_EXT_SRC
    "csrc/moe/moe_align_sum_kernels.cu"
    "csrc/moe/topk_softmax_kernels.cu")

if(VLLM_GPU_LANG STREQUAL "CUDA")
    list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu")
endif()

set_gencode_flags_for_srcs(
    SRCS "${VLLM_MOE_EXT_SRC}"
    CUDA_ARCHS "${CUDA_ARCHS}")

@@ -52,6 +52,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
        "int moe_block_size, bool replicate_input, bool apply_weights)"
        " -> Tensor");
    // conditionally compiled so impl registration is in source file

#endif
}

@@ -7,6 +7,8 @@ import pytest
import torch
from transformers import AutoModelForSequenceClassification

from vllm.platforms import current_platform


@pytest.mark.parametrize(
    "model",
@@ -15,14 +17,21 @@ from transformers import AutoModelForSequenceClassification
                 marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
    ],
)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype",
                         ["half"] if current_platform.is_rocm() else ["float"])
def test_classification_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    monkeypatch,
) -> None:
    if current_platform.is_rocm():
        # ROCm Triton FA does not currently support sliding window attention
        # switch to use ROCm CK FA backend
        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.classify(example_prompts)

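The same switch the test flips with monkeypatch can be set at process level when running models outside pytest; a minimal sketch, assuming VLLM_USE_TRITON_FLASH_ATTN is read once at engine initialization:

import os

# Equivalent of monkeypatch.setenv(...) above: disable the ROCm Triton
# flash-attention kernel (no sliding-window support) and fall back to the
# CK-based kernels before any vLLM engine is constructed.
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "False"
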
@@ -43,4 +52,8 @@ def test_classification_models(
        hf_output = torch.tensor(hf_output)
        vllm_output = torch.tensor(vllm_output)

        assert torch.allclose(hf_output, vllm_output, 1e-3)
        # the tolerance value of 1e-2 is selected based on the
        # half datatype tests in
        # tests/models/embedding/language/test_embedding.py
        assert torch.allclose(hf_output, vllm_output,
                              1e-3 if dtype == "float" else 1e-2)

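The relaxed 1e-2 tolerance for half precision tracks the much coarser machine epsilon of fp16 compared with fp32; a quick, purely illustrative check:

import torch

# fp16 resolves roughly 3 decimal digits, fp32 roughly 7, so a 1e-3
# tolerance that is comfortable for "float" is too tight for "half".
print(torch.finfo(torch.half).eps)   # ~9.77e-04
print(torch.finfo(torch.float).eps)  # ~1.19e-07
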
@@ -6,6 +6,7 @@ Run `pytest tests/models/embedding/language/test_embedding.py`.
import pytest

from vllm.config import PoolerConfig
from vllm.platforms import current_platform

from ..utils import check_embeddings_close

@@ -18,15 +19,15 @@ from ..utils import check_embeddings_close
                 marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
    pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
    pytest.param("intfloat/multilingual-e5-small"),
    pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
    # [Decoder-only]
    pytest.param("BAAI/bge-multilingual-gemma2",
                 marks=[pytest.mark.core_model]),
    pytest.param("intfloat/e5-mistral-7b-instruct",
                 marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
    pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
    pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
    pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
    # [Encoder-decoder]
    # [Cross-Encoder]
    pytest.param("sentence-transformers/stsb-roberta-base-v2"),
    ],
)

@@ -37,11 +38,19 @@ def test_models(
    example_prompts,
    model,
    dtype: str,
    monkeypatch,
) -> None:

    if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm():
        # ROCm Triton FA does not currently support sliding window attention
        # switch to use ROCm CK FA backend
        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    vllm_extra_kwargs = {}
    if model == "ssmits/Qwen2-7B-Instruct-embed-base":
        vllm_extra_kwargs["override_pooler_config"] = \
            PoolerConfig(pooling_type="MEAN")

    if model == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
        vllm_extra_kwargs["hf_overrides"] = {"is_causal": True}

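Outside the fixture, the same per-model overrides can be handed to the engine directly; a hedged sketch, assuming the public LLM entrypoint accepts task="embed", override_pooler_config, and hf_overrides the way vllm_runner forwards them here:

from vllm import LLM
from vllm.config import PoolerConfig

# Mirrors vllm_extra_kwargs above: force MEAN pooling for a checkpoint whose
# config does not describe a pooling head. The hf_overrides={"is_causal": True}
# case for gte-Qwen2-1.5B-instruct would be passed the same way.
llm = LLM(
    model="ssmits/Qwen2-7B-Instruct-embed-base",
    task="embed",
    override_pooler_config=PoolerConfig(pooling_type="MEAN"),
)
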
@@ -15,8 +15,8 @@ import vllm.config
from ....utils import RemoteOpenAIServer

# GritLM embedding implementation is only supported by XFormers backend.
pytest.mark.skipif(not importlib.util.find_spec("xformers"),
                   reason="GritLM requires XFormers")
pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"),
                                reason="GritLM requires XFormers")

MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
MAX_MODEL_LEN = 4000

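The two-line change above fixes a silent no-op: a bare pytest.mark.skipif(...) expression builds a marker object and throws it away, while binding it to the module-level name pytestmark is what makes pytest apply the skip to every test in the file. A self-contained illustration:

import pytest

# Created and immediately discarded -- has no effect on collection.
pytest.mark.skipif(True, reason="never applied")

# Bound to the special module-level name, so pytest attaches it to every
# test function collected from this module.
pytestmark = pytest.mark.skip(reason="applied module-wide")


def test_example():
    assert True  # skipped because of pytestmark above
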
@@ -4,10 +4,27 @@ import pytest
import torch.nn.functional as F
from transformers import AutoModelForVision2Seq

from vllm.platforms import current_platform

from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
from ..utils import check_embeddings_close

# Llava Next embedding implementation is only supported by CUDA.
# If run on ROCm, hf_model.model.resize_token_embeddings will
# cause the following error:
# RuntimeError: Calling torch.linalg.cholesky on a CUDA tensor
# requires compiling PyTorch with MAGMA. Please use PyTorch
# built with MAGMA support.
# If run on CPU, hf_model.model.resize_token_embeddings will
# cause the following error:
# RuntimeError: Calling torch.linalg.cholesky on a CPU tensor
# requires compiling PyTorch with LAPACK. Please use PyTorch
# built with LAPACK support.
pytestmark = pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="Llava Next model uses op that is only supported in CUDA")

llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n'  # noqa: E501

HF_TEXT_PROMPTS = [

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""Attention layer for ROCm GPUs."""
import itertools
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

@@ -342,28 +343,27 @@ def _get_seq_len_block_table_args(
    Decoder attn -> select entirely decoder self-attention-related fields
    Encoder/decoder cross-attn -> select encoder sequence lengths
    Encoder attn -> select encoder sequence lengths fields
    Encoder-only attn -> select prefill sequence lengths with
                         bidirectional attention

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention
                 encoder/decoder cross-attention, encoder-only

    Returns:

    * Appropriate sequence-lengths tensors for query and key
    * Appropriate max sequence-length scalar
    * Causal masking flag
    '''

    partial_prefix_sum = 0
    if attn_type == AttentionType.ENCODER:
        assert attn_metadata.encoder_seq_lens is not None
        assert attn_metadata.encoder_seq_lens_tensor is not None
        query_seq_start_loc = torch.tensor(
            [0] + [
                partial_prefix_sum := partial_prefix_sum + i
                for i in attn_metadata.encoder_seq_lens
            ],
            list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
            device=attn_metadata.encoder_seq_lens_tensor.device,
            dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
        causal_mask = False

@@ -372,16 +372,29 @@ def _get_seq_len_block_table_args(
        return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
                query_seq_start_loc, attn_metadata.max_encoder_seq_len,
                attn_metadata.encoder_seq_lens, causal_mask)

    elif attn_type == AttentionType.ENCODER_ONLY:
        # For encoder-only models, we use the prefill sequence lengths
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        query_seq_start_loc = torch.tensor(
            list(itertools.accumulate([0] + attn_metadata.seq_lens)),
            device=attn_metadata.seq_lens_tensor.device,
            dtype=attn_metadata.seq_lens_tensor.dtype)
        max_seq_len = attn_metadata.max_prefill_seq_len
        # Encoder-only models typically use bidirectional attention
        causal_mask = False

        return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
                max_seq_len, attn_metadata.seq_lens, causal_mask)

    elif attn_type == AttentionType.DECODER:
        # Decoder self-attention
        # Choose max_seq_len based on whether we are in prompt_run
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        query_seq_start_loc = torch.tensor(
            [0] + [
                partial_prefix_sum := partial_prefix_sum + i
                for i in attn_metadata.seq_lens
            ],
            list(itertools.accumulate([0] + attn_metadata.seq_lens)),
            device=attn_metadata.seq_lens_tensor.device,
            dtype=attn_metadata.seq_lens_tensor.dtype)
        max_seq_len = attn_metadata.max_prefill_seq_len

@@ -393,21 +406,14 @@ def _get_seq_len_block_table_args(
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.encoder_seq_lens_tensor is not None
        query_start_loc = torch.tensor(
            [0] + [
                partial_prefix_sum := partial_prefix_sum + i
                for i in attn_metadata.seq_lens
            ],
            list(itertools.accumulate([0] + attn_metadata.seq_lens)),
            device=attn_metadata.encoder_seq_lens_tensor.device,
            dtype=attn_metadata.encoder_seq_lens_tensor.dtype)

        partial_prefix_sum = 0
        assert attn_metadata.encoder_seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        key_seq_start_loc = torch.tensor(
            [0] + [
                partial_prefix_sum := partial_prefix_sum + i
                for i in attn_metadata.encoder_seq_lens
            ],
            list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
            device=attn_metadata.seq_lens_tensor.device,
            dtype=attn_metadata.seq_lens_tensor.dtype)
        causal_mask = False

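All of the start-location tensors built in this function are exclusive prefix sums of per-sequence lengths; swapping the walrus-operator list comprehensions for itertools.accumulate keeps the values identical while reading more directly. A small check with made-up lengths:

import itertools

seq_lens = [3, 5, 2]  # hypothetical per-sequence token counts

# Old formulation: running prefix sum via an assignment expression.
partial_prefix_sum = 0
old = [0] + [partial_prefix_sum := partial_prefix_sum + n for n in seq_lens]

# New formulation used above.
new = list(itertools.accumulate([0] + seq_lens))

assert old == new == [0, 3, 8, 10]
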
@@ -584,6 +590,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
            will match encoder sequence lengths, pass encoder sequence
            attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
            max_encoder_seq_len)
        * ENCODER_ONLY: bidirectional attention with no KV caching;
            use prefill sequence attributes

        Args:
            query: shape = [num_tokens, num_heads * head_size]
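The ENCODER_ONLY case documented here boils down to running the prefill path with the causal mask disabled and never touching the KV cache; a toy sketch of the masking difference, independent of the kernels involved:

import torch

seq_len = 4

# Decoder-style attention (causal=True): token i attends to tokens <= i.
causal = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()

# Encoder-only attention (causal_mask=False above): every token attends to
# every other token, and nothing is written to the KV cache because no
# incremental decoding follows.
bidirectional = torch.ones(seq_len, seq_len, dtype=torch.bool)

print(causal.int())
print(bidirectional.int())
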
@@ -608,7 +616,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
        else:
            assert value is None

        if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
        # Only update KV cache for decoder self-attention
        # and encoder-decoder cross-attention
        if self.attn_type not in [
                AttentionType.ENCODER, AttentionType.ENCODER_ONLY
        ] and kv_cache.numel() > 0:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

@@ -632,6 +644,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):

        if self.attn_type != AttentionType.ENCODER:
            num_prefill_tokens = attn_metadata.num_prefill_tokens
        elif self.attn_type == AttentionType.ENCODER_ONLY:
            # For encoder-only models, all tokens are processed in one go
            num_prefill_tokens = query.shape[0]
        else:
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens
@@ -642,8 +657,13 @@ class ROCmFlashAttentionImpl(AttentionImpl):
        # QKV for prefill.
        query = query[:num_prefill_tokens]

        # For encoder-only and encoder models,
        # we process all tokens at once
        # For decoder and encoder-decoder,
        # we may need to limit key/value to prefill tokens
        if key is not None and value is not None \
            and self.attn_type != AttentionType.ENCODER_DECODER:
            and self.attn_type not in [AttentionType.ENCODER_DECODER,
                                       AttentionType.ENCODER_ONLY]:
            key = key[:num_prefill_tokens]
            value = value[:num_prefill_tokens]

@@ -678,7 +698,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                        self.alibi_slopes,
                        query.dtype,
                        seq_lens,
                        make_attn_mask=False)  # type: ignore
                        make_attn_mask=causal_mask)  # type: ignore
                    out, _ = self.attn_func(
                        query,
                        key,
@@ -703,7 +723,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                        self.alibi_slopes,
                        query.dtype,
                        attn_metadata.seq_lens,
                        make_attn_mask=True)  # type: ignore
                        make_attn_mask=causal_mask)  # type: ignore
                    query = query.movedim(0, query.dim() - 2)
                    key = key.movedim(0, key.dim() - 2)
                    value = value.movedim(0, value.dim() - 2)
@@ -729,7 +749,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
                        max_seqlen_k=key_max_seq_len,
                        softmax_scale=self.scale,
                        causal=True,
                        causal=causal_mask,
                        window_size=self.sliding_window,
                        alibi_slopes=self.alibi_slopes,
                        softcap=self.logits_soft_cap,
@@ -742,25 +762,29 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                else:
                    output = out
            else:
                # prefix-enabled attention
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    self.kv_cache_dtype,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window[0],
                    layer._k_scale,
                    layer._v_scale,
                )

        if decode_meta := attn_metadata.decode_metadata:
                # prefix-enabled attention -
                # not applicable for encoder-only models
                if self.attn_type != AttentionType.ENCODER_ONLY:
                    output[:
                           num_prefill_tokens] = PagedAttention.forward_prefix(
                               query,
                               key,
                               value,
                               self.kv_cache_dtype,
                               key_cache,
                               value_cache,
                               prefill_meta.block_tables,
                               prefill_meta.query_start_loc,
                               prefill_meta.seq_lens_tensor,
                               prefill_meta.max_query_len,
                               self.alibi_slopes,
                               self.sliding_window[0],
                               layer._k_scale,
                               layer._v_scale,
                           )
        # Skip decode phase for encoder-only models
        if (decode_meta := attn_metadata.decode_metadata) and (
                self.attn_type != AttentionType.ENCODER_ONLY):
            # Decoding run.
            # Whether to use rocm custom paged attention or not
            num_seqs, num_heads, head_size = decode_query.shape
@@ -885,4 +909,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
            and (qtype == torch.half or qtype == torch.bfloat16)
            and (head_size == 64 or head_size == 128)
            and (block_size == 16 or block_size == 32)
            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)

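Taken together, these changes let encoder-only embedding models run through the ROCm flash-attention backend. A hedged end-to-end sketch, assuming the LLM entrypoint exposes task="embed" and an embed() method as in recent vLLM releases; set VLLM_USE_TRITON_FLASH_ATTN=False as shown earlier if the model needs the CK fallback:

from vllm import LLM

# Encoder-only model taken from the embedding test matrix above; "half" is
# the dtype the ROCm tests exercise.
llm = LLM(model="intfloat/multilingual-e5-small", task="embed", dtype="half")

outputs = llm.embed(["vLLM embedding models now run on ROCm."])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality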