diff --git a/CMakeLists.txt b/CMakeLists.txt
index ea6d5237..5baa39b6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -561,6 +561,10 @@ set(VLLM_MOE_EXT_SRC
   "csrc/moe/moe_align_sum_kernels.cu"
   "csrc/moe/topk_softmax_kernels.cu")
 
+if(VLLM_GPU_LANG STREQUAL "CUDA")
+  list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu")
+endif()
+
 set_gencode_flags_for_srcs(
   SRCS "${VLLM_MOE_EXT_SRC}"
   CUDA_ARCHS "${CUDA_ARCHS}")
diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp
index 957ac765..718418e6 100644
--- a/csrc/moe/torch_bindings.cpp
+++ b/csrc/moe/torch_bindings.cpp
@@ -52,6 +52,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "int moe_block_size, bool replicate_input, bool apply_weights)"
       " -> Tensor");
   // conditionally compiled so impl registration is in source file
+
 #endif
 }
 
diff --git a/tests/models/embedding/language/test_cls_models.py b/tests/models/embedding/language/test_cls_models.py
index b0420ff5..c6155da5 100644
--- a/tests/models/embedding/language/test_cls_models.py
+++ b/tests/models/embedding/language/test_cls_models.py
@@ -7,6 +7,8 @@ import pytest
 import torch
 from transformers import AutoModelForSequenceClassification
 
+from vllm.platforms import current_platform
+
 
 @pytest.mark.parametrize(
     "model",
@@ -15,14 +17,21 @@ from transformers import AutoModelForSequenceClassification
                      marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
     ],
 )
-@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("dtype",
+                         ["half"] if current_platform.is_rocm() else ["float"])
 def test_classification_models(
     hf_runner,
     vllm_runner,
     example_prompts,
     model: str,
     dtype: str,
+    monkeypatch,
 ) -> None:
+    if current_platform.is_rocm():
+        # ROCm Triton FA does not currently support sliding window attention
+        # switch to use ROCm CK FA backend
+        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
+
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.classify(example_prompts)
 
@@ -43,4 +52,8 @@ def test_classification_models(
         hf_output = torch.tensor(hf_output)
         vllm_output = torch.tensor(vllm_output)
 
-        assert torch.allclose(hf_output, vllm_output, 1e-3)
+        # the tolerance value of 1e-2 is selected based on the
+        # half datatype tests in
+        # tests/models/embedding/language/test_embedding.py
+        assert torch.allclose(hf_output, vllm_output,
+                              1e-3 if dtype == "float" else 1e-2)
diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py
index a8ac70d5..6c28ee91 100644
--- a/tests/models/embedding/language/test_embedding.py
+++ b/tests/models/embedding/language/test_embedding.py
@@ -6,6 +6,7 @@ Run `pytest tests/models/embedding/language/test_embedding.py`.
 import pytest
 
 from vllm.config import PoolerConfig
+from vllm.platforms import current_platform
 
 from ..utils import check_embeddings_close
 
@@ -18,15 +19,15 @@ from ..utils import check_embeddings_close
                      marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
         pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
         pytest.param("intfloat/multilingual-e5-small"),
+        pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
         # [Decoder-only]
         pytest.param("BAAI/bge-multilingual-gemma2",
                      marks=[pytest.mark.core_model]),
         pytest.param("intfloat/e5-mistral-7b-instruct",
                      marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
         pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
-        pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
         pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
-        # [Encoder-decoder]
+        # [Cross-Encoder]
         pytest.param("sentence-transformers/stsb-roberta-base-v2"),
     ],
 )
@@ -37,11 +38,19 @@ def test_models(
     example_prompts,
     model,
     dtype: str,
+    monkeypatch,
 ) -> None:
+
+    if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm():
+        # ROCm Triton FA does not currently support sliding window attention
+        # switch to use ROCm CK FA backend
+        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
+
     vllm_extra_kwargs = {}
     if model == "ssmits/Qwen2-7B-Instruct-embed-base":
         vllm_extra_kwargs["override_pooler_config"] = \
             PoolerConfig(pooling_type="MEAN")
+
     if model == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
         vllm_extra_kwargs["hf_overrides"] = {"is_causal": True}
 
diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py
index 470dc041..cae3e1a5 100644
--- a/tests/models/embedding/language/test_gritlm.py
+++ b/tests/models/embedding/language/test_gritlm.py
@@ -15,8 +15,8 @@ import vllm.config
 from ....utils import RemoteOpenAIServer
 
 # GritLM embedding implementation is only supported by XFormers backend.
-pytest.mark.skipif(not importlib.util.find_spec("xformers"),
-                   reason="GritLM requires XFormers")
+pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"),
+                                reason="GritLM requires XFormers")
 
 MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
 MAX_MODEL_LEN = 4000
diff --git a/tests/models/embedding/vision_language/test_llava_next.py b/tests/models/embedding/vision_language/test_llava_next.py
index 4c2fbd52..8b9a856d 100644
--- a/tests/models/embedding/vision_language/test_llava_next.py
+++ b/tests/models/embedding/vision_language/test_llava_next.py
@@ -4,10 +4,27 @@ import pytest
 import torch.nn.functional as F
 from transformers import AutoModelForVision2Seq
 
+from vllm.platforms import current_platform
+
 from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
 from ....utils import large_gpu_test
 from ..utils import check_embeddings_close
 
+# Llava Next embedding implementation is only supported by CUDA.
+# If run on ROCm, hf_model.model.resize_token_embeddings will
+# cause the following error:
+# RuntimeError: Calling torch.linalg.cholesky on a CUDA tensor
+# requires compiling PyTorch with MAGMA. Please use PyTorch
+# built with MAGMA support.
+# If run on CPU, hf_model.model.resize_token_embeddings will
+# cause the following error:
+# RuntimeError: Calling torch.linalg.cholesky on a CPU tensor
+# requires compiling PyTorch with LAPACK. Please use PyTorch
+# built with LAPACK support.
+pytestmark = pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="Llava Next model uses op that is only supported in CUDA")
+
 llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n'  # noqa: E501
 
 HF_TEXT_PROMPTS = [
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 02a2a48f..c4720209 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """Attention layer ROCm GPUs."""
+import itertools
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
@@ -342,28 +343,27 @@ def _get_seq_len_block_table_args(
     Decoder attn -> select entirely decoder self-attention-related fields
     Encoder/decoder cross-attn -> select encoder sequence lengths
     Encoder attn -> select encoder sequence lengths fields
+    Encoder-only attn -> select prefill sequence lengths with
+        bidirectional attention
 
     Arguments:
 
     * attn_metadata: Attention metadata structure associated with attention op
     * attn_type: encoder attention, decoder self-attention,
-                 encoder/decoder cross-attention
+                 encoder/decoder cross-attention, encoder-only
 
     Returns:
 
     * Appropriate sequence-lengths tensors for query and key
     * Appropriate max sequence-length scalar
+    * Causal masking flag
     '''
 
-    partial_prefix_sum = 0
     if attn_type == AttentionType.ENCODER:
         assert attn_metadata.encoder_seq_lens is not None
         assert attn_metadata.encoder_seq_lens_tensor is not None
         query_seq_start_loc = torch.tensor(
-            [0] + [
-                partial_prefix_sum := partial_prefix_sum + i
-                for i in attn_metadata.encoder_seq_lens
-            ],
+            list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
             device=attn_metadata.encoder_seq_lens_tensor.device,
             dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
         causal_mask = False
@@ -372,16 +372,29 @@ def _get_seq_len_block_table_args(
         return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
                 query_seq_start_loc, attn_metadata.max_encoder_seq_len,
                 attn_metadata.encoder_seq_lens, causal_mask)
+
+    elif attn_type == AttentionType.ENCODER_ONLY:
+        # For encoder-only models, we use the prefill sequence lengths
+        assert attn_metadata.seq_lens is not None
+        assert attn_metadata.seq_lens_tensor is not None
+        query_seq_start_loc = torch.tensor(
+            list(itertools.accumulate([0] + attn_metadata.seq_lens)),
+            device=attn_metadata.seq_lens_tensor.device,
+            dtype=attn_metadata.seq_lens_tensor.dtype)
+        max_seq_len = attn_metadata.max_prefill_seq_len
+        # Encoder-only models typically use bidirectional attention
+        causal_mask = False
+
+        return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
+                max_seq_len, attn_metadata.seq_lens, causal_mask)
+
     elif attn_type == AttentionType.DECODER:
         # Decoder self-attention
         # Choose max_seq_len based on whether we are in prompt_run
         assert attn_metadata.seq_lens is not None
         assert attn_metadata.seq_lens_tensor is not None
         query_seq_start_loc = torch.tensor(
-            [0] + [
-                partial_prefix_sum := partial_prefix_sum + i
-                for i in attn_metadata.seq_lens
-            ],
+            list(itertools.accumulate([0] + attn_metadata.seq_lens)),
             device=attn_metadata.seq_lens_tensor.device,
             dtype=attn_metadata.seq_lens_tensor.dtype)
         max_seq_len = attn_metadata.max_prefill_seq_len
@@ -393,21 +406,14 @@
         assert attn_metadata.seq_lens is not None
         assert attn_metadata.encoder_seq_lens_tensor is not None
         query_start_loc = torch.tensor(
-            [0] + [
-                partial_prefix_sum := partial_prefix_sum + i
-                for i in attn_metadata.seq_lens
-            ],
+            list(itertools.accumulate([0] + attn_metadata.seq_lens)),
             device=attn_metadata.encoder_seq_lens_tensor.device,
             dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
 
-        partial_prefix_sum = 0
         assert attn_metadata.encoder_seq_lens is not None
         assert attn_metadata.seq_lens_tensor is not None
         key_seq_start_loc = torch.tensor(
-            [0] + [
-                partial_prefix_sum := partial_prefix_sum + i
-                for i in attn_metadata.encoder_seq_lens
-            ],
+            list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
             device=attn_metadata.seq_lens_tensor.device,
             dtype=attn_metadata.seq_lens_tensor.dtype)
         causal_mask = False
@@ -584,6 +590,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
           will match encoder sequence lengths, pass encoder sequence
           attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
           max_encoder_seq_len)
+        * ENCODER_ONLY: bidirectional attention with no KV caching;
+          use prefill sequence attributes
 
         Args:
             query: shape = [num_tokens, num_heads * head_size]
@@ -608,7 +616,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         else:
             assert value is None
 
-        if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
+        # Only update KV cache for decoder self-attention
+        # and encoder-decoder cross-attention
+        if self.attn_type not in [
+                AttentionType.ENCODER, AttentionType.ENCODER_ONLY
+        ] and kv_cache.numel() > 0:
             key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
 
@@ -632,6 +644,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 
         if self.attn_type != AttentionType.ENCODER:
             num_prefill_tokens = attn_metadata.num_prefill_tokens
+        elif self.attn_type == AttentionType.ENCODER_ONLY:
+            # For encoder-only models, all tokens are processed in one go
+            num_prefill_tokens = query.shape[0]
         else:
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
@@ -642,8 +657,13 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         # QKV for prefill.
         query = query[:num_prefill_tokens]
 
+        # For encoder-only and encoder models,
+        # we process all tokens at once
+        # For decoder and encoder-decoder,
+        # we may need to limit key/value to prefill tokens
         if key is not None and value is not None \
-            and self.attn_type != AttentionType.ENCODER_DECODER:
+            and self.attn_type not in [AttentionType.ENCODER_DECODER,
+                                       AttentionType.ENCODER_ONLY]:
             key = key[:num_prefill_tokens]
             value = value[:num_prefill_tokens]
 
@@ -678,7 +698,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         self.alibi_slopes,
                         query.dtype,
                         seq_lens,
-                        make_attn_mask=False)  # type: ignore
+                        make_attn_mask=causal_mask)  # type: ignore
                     out, _ = self.attn_func(
                         query,
                         key,
@@ -703,7 +723,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         self.alibi_slopes,
                         query.dtype,
                         attn_metadata.seq_lens,
-                        make_attn_mask=True)  # type: ignore
+                        make_attn_mask=causal_mask)  # type: ignore
                     query = query.movedim(0, query.dim() - 2)
                     key = key.movedim(0, key.dim() - 2)
                     value = value.movedim(0, value.dim() - 2)
@@ -729,7 +749,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     max_seqlen_q=prefill_meta.max_prefill_seq_len,
                     max_seqlen_k=key_max_seq_len,
                     softmax_scale=self.scale,
-                    causal=True,
+                    causal=causal_mask,
                     window_size=self.sliding_window,
                     alibi_slopes=self.alibi_slopes,
                     softcap=self.logits_soft_cap,
@@ -742,25 +762,29 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 else:
                     output = out
             else:
-                # prefix-enabled attention
-                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
-                    query,
-                    key,
-                    value,
-                    self.kv_cache_dtype,
-                    key_cache,
-                    value_cache,
-                    prefill_meta.block_tables,
-                    prefill_meta.query_start_loc,
-                    prefill_meta.seq_lens_tensor,
-                    prefill_meta.max_query_len,
-                    self.alibi_slopes,
-                    self.sliding_window[0],
-                    layer._k_scale,
-                    layer._v_scale,
-                )
-
-        if decode_meta := attn_metadata.decode_metadata:
+                # prefix-enabled attention -
+                # not applicable for encoder-only models
+                if self.attn_type != AttentionType.ENCODER_ONLY:
+                    output[:
+                           num_prefill_tokens] = PagedAttention.forward_prefix(
+                               query,
+                               key,
+                               value,
+                               self.kv_cache_dtype,
+                               key_cache,
+                               value_cache,
+                               prefill_meta.block_tables,
+                               prefill_meta.query_start_loc,
+                               prefill_meta.seq_lens_tensor,
+                               prefill_meta.max_query_len,
+                               self.alibi_slopes,
+                               self.sliding_window[0],
+                               layer._k_scale,
+                               layer._v_scale,
+                           )
+        # Skip decode phase for encoder-only models
+        if (decode_meta := attn_metadata.decode_metadata) and (
+                self.attn_type != AttentionType.ENCODER_ONLY):
             # Decoding run.
             # Whether to use rocm custom paged attention or not
             num_seqs, num_heads, head_size = decode_query.shape
@@ -885,4 +909,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
             and (qtype == torch.half or qtype == torch.bfloat16)
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
-            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
\ No newline at end of file
+            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
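Note: the core refactor in `_get_seq_len_block_table_args` above replaces the walrus-operator prefix sums with `itertools.accumulate`. The standalone sketch below is not part of the patch; the `seq_lens` values are made up for illustration. It only shows that both forms produce the same cumulative start locations that the backend then turns into a tensor.

```python
# Illustrative sketch only -- not part of the patch.
import itertools

import torch

seq_lens = [3, 5, 2]  # hypothetical per-sequence token counts

# Old approach (removed by the patch): running sum via the walrus operator.
partial_prefix_sum = 0
old_start_loc = [0] + [
    partial_prefix_sum := partial_prefix_sum + i for i in seq_lens
]

# New approach (added by the patch): accumulate over [0] + seq_lens.
new_start_loc = list(itertools.accumulate([0] + seq_lens))

assert old_start_loc == new_start_loc == [0, 3, 8, 10]

# The backend materialises these offsets as a tensor, e.g.:
query_seq_start_loc = torch.tensor(new_start_loc, dtype=torch.int32)
print(query_seq_start_loc)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```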