[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models (#9559)
Parent: d522034c85
Commit: a78dd3303e
@ -7,12 +7,18 @@ from typing import List, Optional, Tuple
|
||||
import pytest
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
|
||||
from vllm.attention.selector import (_Backend,
|
||||
global_force_attn_backend_context_manager)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
from ..conftest import DecoderPromptType
|
||||
from ..models.utils import check_logprobs_close
|
||||
|
||||
LIST_ENC_DEC_SUPPORTED_BACKENDS = [
|
||||
_Backend.XFORMERS, _Backend.FLASH_ATTN, None
|
||||
]
|
||||
|
||||
|
||||
def vllm_to_hf_output(
|
||||
vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
|
||||
@ -29,7 +35,8 @@ def vllm_to_hf_output(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
|
||||
@ -48,6 +55,7 @@ def test_encoder_decoder_e2e(
|
||||
num_logprobs: int,
|
||||
decoder_prompt_type: DecoderPromptType,
|
||||
enforce_eager: bool,
|
||||
attn_backend: _Backend,
|
||||
) -> None:
|
||||
'''
|
||||
End-to-End (E2E) test for the encoder-decoder framework.
|
||||
@ -56,43 +64,49 @@ def test_encoder_decoder_e2e(
|
||||
implementations to ensure that both implementations produce consistent
|
||||
and correct results.
|
||||
'''
|
||||
test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type]
|
||||
with global_force_attn_backend_context_manager(attn_backend):
|
||||
if attn_backend == _Backend.FLASH_ATTN:
|
||||
# Flash Attention works only with bfloat16 data-type
|
||||
dtype = 'bfloat16'
|
||||
test_case_prompts = example_encoder_decoder_prompts[
|
||||
decoder_prompt_type]
|
||||
|
||||
# Configuration settings for HF baseline
|
||||
hf_kwargs = {
|
||||
"top_k": None,
|
||||
"num_beams": 1,
|
||||
"repetition_penalty": 1.0,
|
||||
"top_p": 1.0,
|
||||
"length_penalty": 1.0,
|
||||
"early_stopping": False,
|
||||
"no_repeat_ngram_size": None,
|
||||
"min_length": 0
|
||||
}
|
||||
# Configuration settings for HF baseline
|
||||
hf_kwargs = {
|
||||
"top_k": None,
|
||||
"num_beams": 1,
|
||||
"repetition_penalty": 1.0,
|
||||
"top_p": 1.0,
|
||||
"length_penalty": 1.0,
|
||||
"early_stopping": False,
|
||||
"no_repeat_ngram_size": None,
|
||||
"min_length": 0
|
||||
}
|
||||
|
||||
with hf_runner(model, dtype=dtype,
|
||||
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
|
||||
hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
|
||||
test_case_prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
**hf_kwargs,
|
||||
))
|
||||
with vllm_runner(model, dtype=dtype,
|
||||
enforce_eager=enforce_eager) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
|
||||
test_case_prompts, max_tokens, num_logprobs)
|
||||
with hf_runner(model, dtype=dtype,
|
||||
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
|
||||
hf_outputs = (
|
||||
hf_model.generate_encoder_decoder_greedy_logprobs_limit(
|
||||
test_case_prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
**hf_kwargs,
|
||||
))
|
||||
with vllm_runner(model, dtype=dtype,
|
||||
enforce_eager=enforce_eager) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
|
||||
test_case_prompts, max_tokens, num_logprobs)
|
||||
|
||||
hf_skip_tokens = (1
|
||||
if decoder_prompt_type == DecoderPromptType.NONE else 0)
|
||||
hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
|
||||
else 0)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
vllm_to_hf_output(vllm_output, decoder_prompt_type)
|
||||
for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
num_outputs_0_skip_tokens=hf_skip_tokens,
|
||||
)
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
vllm_to_hf_output(vllm_output, decoder_prompt_type)
|
||||
for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
num_outputs_0_skip_tokens=hf_skip_tokens,
|
||||
)
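For illustration, a minimal sketch of exercising this backend override outside the test harness; the prompt is a placeholder, and the flash-attn kernels need a 16-bit dtype, hence the bfloat16 override mirrored from the test above:

from vllm import LLM, SamplingParams
from vllm.attention.selector import (_Backend,
                                     global_force_attn_backend_context_manager)

# Force the FLASH_ATTN backend for an encoder-decoder model (sketch only).
with global_force_attn_backend_context_manager(_Backend.FLASH_ATTN):
    llm = LLM(model="facebook/bart-large-cnn",
              dtype="bfloat16",
              enforce_eager=True)
    outputs = llm.generate(["vLLM is a high-throughput inference engine."],
                           SamplingParams(max_tokens=32))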
@ -16,13 +16,13 @@ from tests.kernels.utils import *
|
||||
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
|
||||
AttentionType)
|
||||
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
|
||||
from vllm.attention.selector import (_Backend,
|
||||
from vllm.attention.selector import (_Backend, get_attn_backend,
|
||||
global_force_attn_backend_context_manager)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# List of supported backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
|
||||
HEAD_SIZES = [64, 256]
|
||||
|
||||
NUM_HEADS = [1, 16]
|
||||
@ -145,7 +145,8 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
|
||||
test_pt.num_heads,
|
||||
test_pt.head_size,
|
||||
test_pt.block_size,
|
||||
device=CUDA_DEVICE)
|
||||
device=CUDA_DEVICE,
|
||||
backend=test_pt.backend_name)
|
||||
return TestResources(scale, attn_backend, attn, kv_cache)
|
||||
|
||||
|
||||
@ -592,6 +593,7 @@ def _run_encoder_attention_test(
|
||||
attn: Attention,
|
||||
encoder_test_params: PhaseTestParameters,
|
||||
attn_metadata: AttentionMetadata,
|
||||
test_pt: TestPoint,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Run encoder attention.
|
||||
@ -610,6 +612,8 @@ def _run_encoder_attention_test(
|
||||
(number_of_tokens x num_heads x head_size)
|
||||
query/key/value fields
|
||||
* attn_metadata: attention metadata for encoder/decoder-self attention
|
||||
* test_pt: The TestPoint object containing test details like number of
|
||||
model heads, head size, name of the backend being used etc.
|
||||
|
||||
Returns:
|
||||
* Attention.forward() applied to packed {query,key,value} and
|
||||
@ -619,20 +623,31 @@ def _run_encoder_attention_test(
|
||||
attn_type = AttentionType.ENCODER
|
||||
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
|
||||
assert packed_qkv is not None
|
||||
return attn.forward(packed_qkv.query,
|
||||
packed_qkv.key,
|
||||
packed_qkv.value,
|
||||
torch.tensor([],
|
||||
dtype=torch.float32,
|
||||
device=packed_qkv.query.device),
|
||||
attn_metadata,
|
||||
attn_type=attn_type)
|
||||
with set_forward_context(attn_metadata):
|
||||
# In the test setup the shape of the query is
|
||||
# [batch_size, seq_len, num_heads, head_size]. However,
# the attention backend expects the shape to be
# [num_tokens, hidden_size]. Hence reshape the query before
# invoking the forward method.
|
||||
# TODO - Update the way we construct the query so that it
|
||||
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
|
||||
reshaped_query = packed_qkv.query.view(
|
||||
-1, test_pt.num_heads * test_pt.head_size)
|
||||
return attn.forward(reshaped_query,
|
||||
packed_qkv.key,
|
||||
packed_qkv.value,
|
||||
torch.tensor([],
|
||||
dtype=torch.float32,
|
||||
device=packed_qkv.query.device),
|
||||
attn_metadata,
|
||||
attn_type=attn_type)
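A minimal sketch of the reshape described in the comment above (sizes are arbitrary):

import torch

batch_size, seq_len, num_heads, head_size = 2, 3, 16, 64
query = torch.rand(batch_size, seq_len, num_heads, head_size)

# Collapse the batch/sequence dims and fuse heads into the hidden dimension,
# yielding the [num_tokens, hidden_size] layout the backend expects.
reshaped_query = query.view(-1, num_heads * head_size)
assert reshaped_query.shape == (batch_size * seq_len, num_heads * head_size)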
def _run_decoder_self_attention_test(
|
||||
test_rsrcs: TestResources,
|
||||
decoder_test_params: PhaseTestParameters,
|
||||
attn_metadata: AttentionMetadata,
|
||||
test_pt: TestPoint,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Run decoder self-attention test.
|
||||
@ -650,6 +665,8 @@ def _run_decoder_self_attention_test(
|
||||
query/key/value fields
|
||||
* attn_metadata: attention metadata for decoder-self attention
|
||||
(contains KV cache memory-mapping)
|
||||
* test_pt: The TestPoint object containing test details like number of
|
||||
model heads, head size, name of the backend being used etc.
|
||||
|
||||
Returns:
|
||||
* Attention.forward() applied to packed_{query,key,value}, kv_cache
|
||||
@ -660,12 +677,22 @@ def _run_decoder_self_attention_test(
|
||||
kv_cache = test_rsrcs.kv_cache
|
||||
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
|
||||
assert packed_qkv is not None
|
||||
return attn.forward(packed_qkv.query,
|
||||
packed_qkv.key,
|
||||
packed_qkv.value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
attn_type=attn_type)
|
||||
with set_forward_context(attn_metadata):
|
||||
# In the test setup the shape of the query is
|
||||
# [batch_size, seq_len, num_heads, head_size]. However,
# the attention backend expects the shape to be
# [num_tokens, hidden_size]. Hence reshape the query before
# invoking the forward method.
|
||||
# TODO - Update the way we construct the query so that it
|
||||
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
|
||||
reshaped_query = packed_qkv.query.view(
|
||||
-1, test_pt.num_heads * test_pt.head_size)
|
||||
return attn.forward(reshaped_query,
|
||||
packed_qkv.key,
|
||||
packed_qkv.value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
attn_type=attn_type)
|
||||
|
||||
|
||||
def _run_encoder_decoder_cross_attention_test(
|
||||
@ -673,6 +700,7 @@ def _run_encoder_decoder_cross_attention_test(
|
||||
decoder_test_params: PhaseTestParameters,
|
||||
cross_test_params: Optional[PhaseTestParameters],
|
||||
attn_metadata: AttentionMetadata,
|
||||
test_pt: TestPoint,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
Run encoder/decoder cross-attention test.
|
||||
@ -701,6 +729,8 @@ def _run_encoder_decoder_cross_attention_test(
|
||||
(number_of_tokens x num_heads x head_size)
|
||||
key/value fields
|
||||
* attn_metadata: attention metadata for encoder/decoder-self attention
|
||||
* test_pt: The TestPoint object containing test details like number of
|
||||
model heads, head size, name of the backend being used etc.
|
||||
|
||||
Returns:
|
||||
* Attention.forward() applied to packed_{query,key,value}, kv_cache
|
||||
@ -718,12 +748,37 @@ def _run_encoder_decoder_cross_attention_test(
|
||||
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
|
||||
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
|
||||
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
|
||||
return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
attn_type=attn_type)
|
||||
with set_forward_context(attn_metadata):
|
||||
# In the test setup the shape of the query is
|
||||
# [batch_size, seq_len, num_heads, head_size]. However,
# the attention backend expects the shape to be
# [num_tokens, hidden_size]. Hence reshape the query before
# invoking the forward method.
|
||||
# TODO - Update the way we construct the query so that it
|
||||
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
|
||||
reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
|
||||
-1, test_pt.num_heads * test_pt.head_size)
|
||||
return attn.forward(reshaped_query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
attn_type=attn_type)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_reset_environment(attn_backend):
|
||||
# Set the default torch datatype to bfloat16 to enable
|
||||
# testing of the Flash Attention backend. Also clear the
|
||||
# cached value of the backend.
|
||||
default_dtype = torch.get_default_dtype()
|
||||
if attn_backend.name == 'FLASH_ATTN':
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
get_attn_backend.cache_clear()
|
||||
yield
|
||||
# Reset the torch datatype to what it was before the test
|
||||
# so as not to impact the remaining tests.
|
||||
torch.set_default_dtype(default_dtype)
|
||||
|
||||
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
@ -773,10 +828,8 @@ def test_encoder_only(
|
||||
* max_dec_seq_len: max length of decoder input sequences
|
||||
* max_enc_seq_len: max length of encoder input sequences
|
||||
'''
|
||||
|
||||
# Force Attention wrapper backend
|
||||
with global_force_attn_backend_context_manager(attn_backend):
|
||||
|
||||
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||
# to be more than necessary, since exceeding the kv cache size
|
||||
# is not part of this test
|
||||
@ -807,10 +860,14 @@ def test_encoder_only(
|
||||
# PREFILL: encoder attention
|
||||
|
||||
enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
|
||||
test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
|
||||
test_rsrcs.attn,
|
||||
enc_test_params,
|
||||
prephase_attn_metadata,
|
||||
test_pt=test_pt))
|
||||
|
||||
# - Is encoder attention result correct?
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
|
||||
attn_backend.name)
|
||||
|
||||
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
@ -892,10 +949,8 @@ def test_e2e_enc_dec_attn(
|
||||
* max_dec_seq_len: max length of decoder input sequences
|
||||
* max_enc_seq_len: max length of encoder input sequences
|
||||
'''
|
||||
|
||||
# Force Attention wrapper backend
|
||||
with global_force_attn_backend_context_manager(attn_backend):
|
||||
|
||||
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||
# to be more than necessary, since exceeding the kv cache size
|
||||
# is not part of this test
|
||||
@ -955,29 +1010,39 @@ def test_e2e_enc_dec_attn(
|
||||
|
||||
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
|
||||
enc_test_params,
|
||||
prephase_attn_metadata)
|
||||
prephase_attn_metadata,
|
||||
test_pt=test_pt)
|
||||
|
||||
# - Is encoder attention result correct?
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
|
||||
attn_backend.name)
|
||||
|
||||
# PREFILL: decoder self-attention test
|
||||
|
||||
prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||
test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
|
||||
test_rsrcs,
|
||||
prephase_dec_test_params,
|
||||
prephase_attn_metadata,
|
||||
test_pt=test_pt)
|
||||
|
||||
# - Is prefill decoder self-attention correct?
|
||||
assert_actual_matches_ideal(prephase_dec_test_params,
|
||||
prephase_dec_pckd_act_out)
|
||||
prephase_dec_pckd_act_out,
|
||||
attn_backend.name)
|
||||
|
||||
# PREFILL: encoder/decoder cross-attention test
|
||||
|
||||
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||
test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
|
||||
prephase_attn_metadata)
|
||||
test_rsrcs,
|
||||
prephase_dec_test_params,
|
||||
prephase_cross_test_params,
|
||||
prephase_attn_metadata,
|
||||
test_pt=test_pt)
|
||||
|
||||
# - Is prefill encoder/decoder cross-attention correct?
|
||||
assert_actual_matches_ideal(prephase_cross_test_params,
|
||||
prephase_cross_pckd_act_out)
|
||||
prephase_cross_pckd_act_out,
|
||||
attn_backend.name)
|
||||
|
||||
# DECODE: build decode-phase attention metadata
|
||||
|
||||
@ -993,17 +1058,26 @@ def test_e2e_enc_dec_attn(
|
||||
# DECODE: decoder self-attention test
|
||||
|
||||
decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||
test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
|
||||
test_rsrcs,
|
||||
decphase_dec_test_params,
|
||||
decphase_attn_metadata,
|
||||
test_pt=test_pt)
|
||||
|
||||
# - Is decode-phase decoder self-attention correct?
|
||||
assert_actual_matches_ideal(decphase_dec_test_params,
|
||||
decphase_dec_pckd_act_out)
|
||||
decphase_dec_pckd_act_out,
|
||||
attn_backend.name)
|
||||
|
||||
# DECODE: encoder/decoder cross-attention test
|
||||
|
||||
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||
test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
|
||||
test_rsrcs,
|
||||
decphase_dec_test_params,
|
||||
None,
|
||||
decphase_attn_metadata,
|
||||
test_pt=test_pt)
|
||||
|
||||
# - Is decode-phase encoder/decoder cross-attention correct?
|
||||
assert_actual_matches_ideal(decphase_cross_test_params,
|
||||
decphase_cross_pckd_act_out)
|
||||
decphase_cross_pckd_act_out,
|
||||
attn_backend.name)
|
||||
|
@ -13,8 +13,8 @@ from torch._prims_common import TensorLikeType
|
||||
|
||||
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
|
||||
make_tensor_with_pad)
|
||||
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
|
||||
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
|
||||
|
||||
# For now, disable "test_aot_dispatch_dynamic" since there are some
|
||||
# bugs related to this test in PyTorch 2.4.
|
||||
@ -525,17 +525,22 @@ def make_backend(backend_name: str) -> AttentionBackend:
|
||||
if backend_name == STR_XFORMERS_ATTN_VAL:
|
||||
# NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
|
||||
from vllm.attention.backends.xformers import XFormersBackend
|
||||
|
||||
return XFormersBackend()
|
||||
elif backend_name == STR_FLASH_ATTN_VAL:
|
||||
from vllm.attention.backends.flash_attn import FlashAttentionBackend
|
||||
return FlashAttentionBackend()
|
||||
|
||||
raise AssertionError(
|
||||
f"Unrecognized backend_name {backend_name} for unit test")
|
||||
|
||||
|
||||
def _make_metadata_tensors(
|
||||
seq_lens: Optional[List[int]], context_lens: Optional[List[int]],
|
||||
encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]],
|
||||
torch.Tensor, Optional[int]]:
|
||||
seq_lens: Optional[List[int]],
|
||||
context_lens: Optional[List[int]],
|
||||
encoder_seq_lens: Optional[List[int]],
|
||||
device: Union[torch.device, str],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor],
|
||||
torch.Tensor, torch.Tensor, Optional[int]]:
|
||||
'''
|
||||
Build scalar & tensor values required to build attention metadata structure.
|
||||
|
||||
@ -553,6 +558,8 @@ def _make_metadata_tensors(
|
||||
* max_context_len: max(context_lens)
|
||||
* max_seq_len: max(seq_lens)
|
||||
* seq_start_loc: start idx of each sequence
|
||||
* encoder_seq_lens_tensor: encoder seq_lens list, as tensor
|
||||
* encoder_seq_start_loc: start idx of each encoder sequence
|
||||
* max_encoder_seq_len: max(encoder_seq_lens)
'''
|
||||
seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
|
||||
@ -566,8 +573,26 @@ def _make_metadata_tensors(
|
||||
|
||||
seq_start_loc = None
|
||||
|
||||
if seq_lens_tensor is not None:
|
||||
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device=seq_lens_tensor.device)
|
||||
torch.cumsum(seq_lens_tensor,
|
||||
dim=0,
|
||||
dtype=seq_start_loc.dtype,
|
||||
out=seq_start_loc[1:])
|
||||
|
||||
encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device=encoder_seq_lens_tensor.device)
|
||||
torch.cumsum(encoder_seq_lens_tensor,
|
||||
dim=0,
|
||||
dtype=encoder_seq_start_loc.dtype,
|
||||
out=encoder_seq_start_loc[1:])
|
||||
|
||||
return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
|
||||
seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len)
|
||||
seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc,
|
||||
max_encoder_seq_len)
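A worked example of the cumulative start-location tensor built above, using the [4, 6] -> [0, 4, 10] lengths quoted in the comments elsewhere in this change:

import torch

encoder_seq_lens_tensor = torch.tensor([4, 6], dtype=torch.int32)
encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32)
torch.cumsum(encoder_seq_lens_tensor,
             dim=0,
             dtype=encoder_seq_start_loc.dtype,
             out=encoder_seq_start_loc[1:])
print(encoder_seq_start_loc)  # tensor([ 0,  4, 10], dtype=torch.int32)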
def make_kv_cache(num_blocks: int,
@ -575,6 +600,7 @@ def make_kv_cache(num_blocks: int,
                  head_size: int,
                  block_size: int,
                  device: Union[torch.device, str],
                  backend: str,
                  default_val: float = 0.0) -> torch.Tensor:
    '''
    Create a fake KV cache.
@ -591,10 +617,20 @@ def make_kv_cache(num_blocks: int,
    Returns:

    * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
      * for backend 'XFORMERS'
    * kv_cache: 2 x num_blocks x block_size x num_heads x head_size
      * for backend 'FLASH_ATTN'
    '''

    kv_cache = torch.rand(
        (2, num_blocks, block_size * num_heads * head_size)).to(device)
    if backend == 'XFORMERS':
        kv_cache = torch.rand(
            (2, num_blocks, block_size * num_heads * head_size)).to(device)
    elif backend == 'FLASH_ATTN':
        kv_cache = torch.rand(
            (2, num_blocks, block_size, num_heads, head_size)).to(device)
    else:
        raise ValueError(
            f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
            f"'FLASH_ATTN'.")
    if default_val is not None:
        kv_cache[:, :, :] = default_val
    return kv_cache
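For illustration, the two cache layouts created above hold the same number of elements and differ only in whether the per-block dimensions are flattened (sizes are arbitrary):

import torch

num_blocks, block_size, num_heads, head_size = 4, 16, 8, 64

# XFORMERS layout: 2 x num_blocks x (block_size * num_heads * head_size)
xformers_kv = torch.zeros(2, num_blocks, block_size * num_heads * head_size)

# FLASH_ATTN layout: 2 x num_blocks x block_size x num_heads x head_size
flash_kv = torch.zeros(2, num_blocks, block_size, num_heads, head_size)

# Same number of elements, different view of the block dimension.
assert xformers_kv.numel() == flash_kv.numel()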
@ -858,8 +894,9 @@ def make_test_metadata(
|
||||
context_lens_tensor,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
seq_start_loc,
|
||||
encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc,
|
||||
max_encoder_seq_len,
|
||||
) = _make_metadata_tensors(seq_lens,
|
||||
context_lens,
|
||||
@ -874,6 +911,7 @@ def make_test_metadata(
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
seq_start_loc=seq_start_loc,
|
||||
max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
|
||||
max_decode_seq_len=0,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
@ -882,6 +920,7 @@ def make_test_metadata(
|
||||
num_encoder_tokens=num_encoder_tokens,
|
||||
encoder_seq_lens=encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc=encoder_seq_start_loc,
|
||||
max_encoder_seq_len=max_encoder_seq_len,
|
||||
cross_slot_mapping=(None if cross_kv_mmap is None else
|
||||
cross_kv_mmap.slot_mapping),
|
||||
@ -904,8 +943,9 @@ def make_test_metadata(
|
||||
context_lens_tensor,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
seq_start_loc,
|
||||
encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc,
|
||||
max_encoder_seq_len,
|
||||
) = _make_metadata_tensors(seq_lens,
|
||||
context_lens,
|
||||
@ -920,14 +960,17 @@ def make_test_metadata(
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
seq_start_loc=seq_start_loc,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=max(seq_lens),
|
||||
max_decode_query_len=1,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=kv_mmap.block_tables,
|
||||
use_cuda_graph=False,
|
||||
num_encoder_tokens=num_encoder_tokens,
|
||||
encoder_seq_lens=encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc=encoder_seq_start_loc,
|
||||
max_encoder_seq_len=max_encoder_seq_len,
|
||||
cross_slot_mapping=(None if cross_kv_mmap is None else
|
||||
cross_kv_mmap.slot_mapping),
|
||||
@ -936,7 +979,8 @@ def make_test_metadata(
|
||||
|
||||
|
||||
def assert_actual_matches_ideal(test_params: PhaseTestParameters,
|
||||
output_under_test: torch.Tensor) -> None:
|
||||
output_under_test: torch.Tensor,
|
||||
backend: str) -> None:
|
||||
'''
|
||||
Assert that observed output matches the ideal output
|
||||
contained in the test parameters data structure.
|
||||
@ -947,8 +991,22 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
|
||||
* output_under_test: actually observed output value
|
||||
'''
|
||||
ideal_output = test_params.packed_qkvo.ideal_output
|
||||
torch.testing.assert_close(ideal_output,
|
||||
output_under_test.view_as(ideal_output))
|
||||
if backend == 'XFORMERS':
|
||||
torch.testing.assert_close(ideal_output,
|
||||
output_under_test.view_as(ideal_output))
|
||||
|
||||
elif backend == 'FLASH_ATTN':
|
||||
# For FlashAttention, override the accuracy thresholds with non-default
# values, since we observe a larger difference between the ideal and
# actual output with this backend.
torch.testing.assert_close(ideal_output,
|
||||
output_under_test.view_as(ideal_output),
|
||||
atol=0.01,
|
||||
rtol=0.016)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
|
||||
f"'FLASH_ATTN'.")
|
||||
|
||||
|
||||
# Copied/modified from torch._refs.__init__.py
|
||||
|
@ -85,7 +85,7 @@ def run_test(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(hf_runner, vllm_runner, model, dtype, max_tokens,
|
||||
|
@ -10,10 +10,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType)
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
|
||||
compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.attention.backends.utils import (
|
||||
PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
|
||||
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
|
||||
is_all_encoder_attn_metadata_set, is_block_tables_empty)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
|
||||
@ -73,7 +74,6 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
src_key_cache = src_kv_cache[0]
|
||||
dst_key_cache = dst_kv_cache[0]
|
||||
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
|
||||
|
||||
src_value_cache = src_kv_cache[1]
|
||||
dst_value_cache = dst_kv_cache[1]
|
||||
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
|
||||
@ -85,6 +85,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
) -> None:
|
||||
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||
|
||||
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
|
||||
|
||||
@ -111,26 +112,12 @@ class FlashAttentionMetadata(AttentionMetadata):
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum query length in the batch.
|
||||
max_query_len: Optional[int]
|
||||
|
||||
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int]
|
||||
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor]
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor]
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
@ -146,11 +133,62 @@ class FlashAttentionMetadata(AttentionMetadata):
|
||||
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
|
||||
use_cuda_graph: bool
|
||||
|
||||
# Maximum query length in the batch.
|
||||
max_query_len: Optional[int] = None
|
||||
|
||||
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int] = None
|
||||
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
_cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
|
||||
|
||||
# Begin encoder attn & enc/dec cross-attn fields...
|
||||
|
||||
# Encoder sequence lengths representation
|
||||
encoder_seq_lens: Optional[List[int]] = None
|
||||
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
encoder_seq_start_loc: Optional[torch.Tensor] = None
|
||||
# Maximum sequence length among encoder sequences
|
||||
max_encoder_seq_len: Optional[int] = None
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
cross_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for encoder attention is set.
|
||||
'''
|
||||
return is_all_encoder_attn_metadata_set(self)
|
||||
|
||||
@property
|
||||
def is_all_cross_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for enc/dec cross-attention is set.
|
||||
|
||||
Superset of encoder attention required metadata.
|
||||
'''
|
||||
return is_all_cross_attn_metadata_set(self)
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
@ -159,32 +197,52 @@ class FlashAttentionMetadata(AttentionMetadata):
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.query_start_loc is not None
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.block_tables is not None
|
||||
assert self.seq_start_loc is not None
|
||||
assert ((self.seq_lens is not None)
|
||||
or (self.encoder_seq_lens is not None))
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
seq_start_loc = (None if self.seq_start_loc is None else
|
||||
self.seq_start_loc[:self.num_prefills + 1])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
|
||||
self._cached_prefill_metadata = FlashAttentionMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
seq_lens=self.seq_lens[:self.num_prefills],
|
||||
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_query_len=0,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
|
||||
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
|
||||
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
|
||||
block_tables=self.block_tables[:self.num_prefills],
|
||||
query_start_loc=query_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc=self.encoder_seq_start_loc,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
@ -194,17 +252,25 @@ class FlashAttentionMetadata(AttentionMetadata):
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert self.block_tables is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
|
||||
self._cached_decode_metadata = FlashAttentionMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_decode_query_len=self.max_decode_query_len,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=0,
|
||||
@ -214,9 +280,15 @@ class FlashAttentionMetadata(AttentionMetadata):
|
||||
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
||||
if self.seq_start_loc is not None else None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=self.block_tables[self.num_prefills:],
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
)
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc=self.encoder_seq_start_loc,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
def advance_step(self,
|
||||
@ -586,16 +658,20 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashAttentionImpl")
|
||||
|
||||
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
|
||||
assert k_scale == 1.0 and v_scale == 1.0, (
|
||||
"key/v_scale is not supported in FlashAttention.")
|
||||
|
||||
if (attn_type == AttentionType.ENCODER
|
||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||
raise AttributeError("Encoder attention requires setting "
|
||||
"encoder metadata attributes.")
|
||||
elif (attn_type == AttentionType.ENCODER_DECODER
|
||||
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
||||
raise AttributeError("Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes.")
|
||||
|
||||
output = torch.ops.vllm.unified_flash_attention(
|
||||
query,
|
||||
key,
|
||||
@ -608,6 +684,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
k_scale,
|
||||
v_scale,
|
||||
self.scale,
|
||||
attn_type.value,
|
||||
self.sliding_window,
|
||||
self.alibi_slopes,
|
||||
self.logits_soft_cap,
|
||||
@ -616,6 +693,89 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
return output
|
||||
|
||||
|
||||
def _get_query_key_seq_metadata(
|
||||
attn_metadata,
|
||||
is_prompt: bool,
|
||||
attn_type: AttentionType,
|
||||
) -> tuple:
|
||||
"""
|
||||
Returns sequence metadata for key and query based on the specified
|
||||
attention type and whether input is a prompt.
|
||||
|
||||
This function computes the starting locations and maximum sequence lengths
|
||||
for key and query sequences for different attention types.
|
||||
|
||||
Args:
|
||||
attn_metadata: The attention metadata object
|
||||
is_prompt (bool): A flag indicating if the input is a prompt
|
||||
attn_type (AttentionType): The type of attention being used.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing four integers:
|
||||
- Starting location for the query sequence.
|
||||
- Maximum sequence length for the query sequence.
|
||||
- Starting location for the key sequence.
|
||||
- Maximum sequence length for the key sequence.
|
||||
|
||||
Raises:
|
||||
AttributeError: If an invalid attention type is provided.
|
||||
"""
|
||||
if attn_type == AttentionType.DECODER:
|
||||
# Decoder self-attention
|
||||
# Choose max_seq_len based on whether we are in prompt_run
|
||||
if is_prompt:
|
||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||
else:
|
||||
max_seq_len = attn_metadata.max_decode_seq_len
|
||||
return (attn_metadata.seq_start_loc, max_seq_len,
|
||||
attn_metadata.seq_start_loc, max_seq_len)
|
||||
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
# This is cross-attention, where the key is the precomputed
# encoder output and the query is the decoder input sequence.
# Choose query max length based on whether it is prompt
|
||||
# or not.
|
||||
if is_prompt:
|
||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||
else:
|
||||
max_seq_len = attn_metadata.max_decode_seq_len
|
||||
return (attn_metadata.seq_start_loc, max_seq_len,
|
||||
attn_metadata.encoder_seq_start_loc,
|
||||
attn_metadata.max_encoder_seq_len)
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
# For encoder attention both the query and the key are same i.e the
|
||||
# encoder sequence.
|
||||
return (attn_metadata.encoder_seq_start_loc,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.encoder_seq_start_loc,
|
||||
attn_metadata.max_encoder_seq_len)
|
||||
elif attn_type == AttentionType.ENCODER_ONLY:
|
||||
assert is_prompt, "Should not have decode for encoder only model."
|
||||
return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len,
|
||||
attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len)
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _get_causal_option(attn_type: AttentionType) -> bool:
|
||||
"""
|
||||
Determine whether the given attention type is suitable for causal
|
||||
attention mechanisms.
|
||||
|
||||
Args:
|
||||
attn_type (AttentionType): The type of attention being evaluated
|
||||
|
||||
Returns:
|
||||
bool: Returns `True` if the attention type is suitable for causal
|
||||
attention (i.e., not encoder, encoder-only, or encoder-decoder),
|
||||
otherwise returns `False`.
|
||||
"""
|
||||
return not (attn_type == AttentionType.ENCODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY
|
||||
or attn_type == AttentionType.ENCODER_DECODER)
|
||||
|
||||
|
||||
def unified_flash_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
@ -628,60 +788,76 @@ def unified_flash_attention(
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
softmax_scale: float,
|
||||
attn_type_int_val: int,
|
||||
window_size: Optional[List[int]] = None,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# Convert integer attn_type to enum
|
||||
try:
|
||||
attn_type = AttentionType(attn_type_int_val)
|
||||
except ValueError as err:
|
||||
raise AttributeError(
|
||||
f"Invalid attention type {str(attn_type_int_val)}") from err
|
||||
|
||||
current_metadata = get_forward_context()
|
||||
assert current_metadata is not None
|
||||
assert isinstance(current_metadata, FlashAttentionMetadata)
|
||||
attn_metadata: FlashAttentionMetadata = current_metadata
|
||||
|
||||
num_tokens, hidden_size = query.shape
|
||||
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, num_heads, head_size)
|
||||
key = key.view(-1, num_kv_heads, head_size)
|
||||
value = value.view(-1, num_kv_heads, head_size)
|
||||
if (key is not None) and (value is not None):
|
||||
key = key.view(-1, num_kv_heads, head_size)
|
||||
value = value.view(-1, num_kv_heads, head_size)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache = kv_cache[0]
|
||||
value_cache = kv_cache[1]
|
||||
# We skip updating the KV cache under two conditions:
|
||||
# a. When the Attention Type is ENCODER. In this phase, we compute
|
||||
# only the encoder attention without updating the cache.
|
||||
# b. When both Key and Value are None. This occurs during
|
||||
# cross-attention computation in the decoding phase, where the KV
|
||||
# cache is already populated with the cross-attention tensor.
|
||||
# Thus, we skip cache updates during this time.
|
||||
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
|
||||
value is not None):
|
||||
if attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Update cross-attention KV cache (prefill-only)
|
||||
updated_slot_mapping = attn_metadata.cross_slot_mapping
|
||||
else:
|
||||
# Update self-attention KV cache (prefill/decode)
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory profiling run.
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory profiling run.
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
updated_slot_mapping.flatten(), # type: ignore[union-attr]
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
||||
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
|
||||
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
|
||||
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
||||
num_decode_query_tokens) = \
|
||||
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
|
||||
decode_query = query[num_prefill_query_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_tokens]
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
assert query.shape[0] == num_prefill_tokens
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
query = query[:num_prefill_query_tokens]
|
||||
assert query.shape[0] == num_prefill_query_tokens
|
||||
assert decode_query.shape[0] == num_decode_query_tokens
|
||||
|
||||
prefill_output: Optional[torch.Tensor] = None
|
||||
decode_output: Optional[torch.Tensor] = None
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
|
||||
@ -689,22 +865,30 @@ def unified_flash_attention(
|
||||
# normal attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
|
||||
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
|
||||
|
||||
key = key[:num_prefill_kv_tokens]
|
||||
value = value[:num_prefill_kv_tokens]
|
||||
|
||||
prefill_output = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||
cu_seqlens_q=q_seq_start_loc,
|
||||
cu_seqlens_k=k_seq_start_loc,
|
||||
max_seqlen_q=q_seq_len,
|
||||
max_seqlen_k=k_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
causal=_get_causal_option(attn_type),
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
assert attn_type == AttentionType.DECODER, (
|
||||
"Only decoder-only models support prefix caching")
|
||||
assert prefill_meta.seq_lens is not None
|
||||
max_seq_len = max(prefill_meta.seq_lens)
|
||||
prefill_output = flash_attn_varlen_func( # noqa
|
||||
@ -729,6 +913,8 @@ def unified_flash_attention(
|
||||
# because different queries might have different lengths.
|
||||
assert decode_meta.max_decode_query_len is not None
|
||||
if decode_meta.max_decode_query_len > 1:
|
||||
assert attn_type == AttentionType.DECODER, (
|
||||
"Only decoder-only models support max_decode_query_len > 1")
|
||||
decode_output = flash_attn_varlen_func(
|
||||
q=decode_query,
|
||||
k=key_cache,
|
||||
@ -746,12 +932,17 @@ def unified_flash_attention(
|
||||
)
|
||||
else:
|
||||
# Use flash_attn_with_kvcache for normal decoding.
|
||||
(
|
||||
seq_lens_arg,
|
||||
_,
|
||||
block_tables_arg,
|
||||
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
|
||||
decode_output = flash_attn_with_kvcache(
|
||||
q=decode_query.unsqueeze(1),
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
block_table=decode_meta.block_tables,
|
||||
cache_seqlens=decode_meta.seq_lens_tensor,
|
||||
block_table=block_tables_arg,
|
||||
cache_seqlens=seq_lens_arg,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
@ -761,10 +952,10 @@ def unified_flash_attention(
|
||||
|
||||
if prefill_output is None:
|
||||
assert decode_output is not None
|
||||
return decode_output.view(num_decode_tokens, hidden_size)
|
||||
return decode_output.view(num_decode_query_tokens, hidden_size)
|
||||
if decode_output is None:
|
||||
assert prefill_output is not None
|
||||
return prefill_output.view(num_prefill_tokens, hidden_size)
|
||||
return prefill_output.view(num_prefill_query_tokens, hidden_size)
|
||||
|
||||
# Chunked prefill does not work with speculative decoding.
|
||||
# Therefore, the query length for decode should be 1 in chunked prefill.
|
||||
@ -786,6 +977,7 @@ def unified_flash_attention_fake(
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
softmax_scale: float,
|
||||
attn_type_int_val: int,
|
||||
window_size: Optional[List[int]] = None,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
|
@ -1,13 +1,14 @@
|
||||
"""Attention backend utils"""
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
|
||||
AttentionState)
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
@ -336,11 +337,13 @@ class CommonAttentionState(AttentionState):
|
||||
use_cuda_graph=True,
|
||||
)
|
||||
if is_encoder_decoder_model:
|
||||
# The encoder decoder model works only with XFormers backend.
|
||||
# Assert the same.
|
||||
assert self.runner.attn_backend.get_name() == "XFORMERS", \
|
||||
f"Expected attn_backend name to be 'XFORMERS', but "\
|
||||
f" got '{self.runner.attn_backend.get_name()}'"
|
||||
# The encoder-decoder model works only with the XFormers and
# Flash Attention backends. Assert the same.
assert self.runner.attn_backend.get_name() in\
|
||||
["XFORMERS", "FLASH_ATTN"], \
|
||||
f"Expected attn_backend name to be either 'XFORMERS' or " \
|
||||
f"'FLASH_ATTN', but "\
|
||||
f"got '{self.runner.attn_backend.get_name()}'"
|
||||
self._update_captured_metadata_for_enc_dec_model(
|
||||
batch_size=batch_size, attn_metadata=attn_metadata)
|
||||
|
||||
@ -356,11 +359,13 @@ class CommonAttentionState(AttentionState):
|
||||
"block_tables": attn_metadata.decode_metadata.block_tables,
|
||||
}
|
||||
if is_encoder_decoder_model:
|
||||
# The encoder decoder model works only with XFormers backend.
|
||||
# Assert the same.
|
||||
assert self.runner.attn_backend.get_name() == "XFORMERS", \
|
||||
f"Expected attn_backend name to be 'XFORMERS', but "\
|
||||
f" got '{self.runner.attn_backend.get_name()}'"
|
||||
# The encoder-decoder model works only with the XFormers and
# Flash Attention backends. Assert the same.
assert self.runner.attn_backend.get_name() in\
|
||||
["XFORMERS", "FLASH_ATTN"], \
|
||||
f"Expected attn_backend name to be either 'XFORMERS' or "\
|
||||
f"'FLASH_ATTN', but "\
|
||||
f"got '{self.runner.attn_backend.get_name()}'"
|
||||
self._add_additonal_input_buffers_for_enc_dec_model(
|
||||
attn_metadata=attn_metadata, input_buffers=input_buffers)
|
||||
return input_buffers
|
||||
@ -375,11 +380,13 @@ class CommonAttentionState(AttentionState):
|
||||
input_buffers["block_tables"].copy_(
|
||||
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
||||
if is_encoder_decoder_model:
|
||||
# The encoder decoder model works only with XFormers backend.
|
||||
# Assert the same.
|
||||
assert self.runner.attn_backend.get_name() == "XFORMERS", \
|
||||
f"Expected attn_backend name to be 'XFORMERS', but "\
|
||||
f" got '{self.runner.attn_backend.get_name()}'"
|
||||
# The encoder-decoder model works only with the XFormers and
# Flash Attention backends. Assert the same.
assert self.runner.attn_backend.get_name() in\
|
||||
["XFORMERS", "FLASH_ATTN"], \
|
||||
f"Expected attn_backend name to be either 'XFORMERS' or "\
|
||||
f"'FLASH_ATTN', but "\
|
||||
f"got '{self.runner.attn_backend.get_name()}'"
|
||||
self._prepare_input_buffers_for_enc_dec_model(
|
||||
attn_metadata, input_buffers)
|
||||
|
||||
@ -411,6 +418,7 @@ class CommonAttentionState(AttentionState):
|
||||
attn_metadata.encoder_seq_lens_tensor = torch.full(
|
||||
(batch_size, ), 1, dtype=torch.int).cuda()
|
||||
attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
|
||||
attn_metadata.num_encoder_tokens = 0
|
||||
|
||||
def _add_additonal_input_buffers_for_enc_dec_model(
|
||||
self, attn_metadata, input_buffers: Dict[str, Any]):
|
||||
@ -453,3 +461,122 @@ class CommonAttentionState(AttentionState):
|
||||
input_buffers["cross_block_tables"].copy_(
|
||||
attn_metadata.decode_metadata.cross_block_tables,
|
||||
non_blocking=True)
|
||||
|
||||
|
||||
def is_all_encoder_attn_metadata_set(attn_metadata):
|
||||
'''
|
||||
All attention metadata required for encoder attention is set.
|
||||
'''
|
||||
return ((attn_metadata.encoder_seq_lens is not None)
|
||||
and (attn_metadata.encoder_seq_lens_tensor is not None)
|
||||
and (attn_metadata.max_encoder_seq_len is not None))
|
||||
|
||||
|
||||
def is_all_cross_attn_metadata_set(attn_metadata):
|
||||
'''
|
||||
All attention metadata required for enc/dec cross-attention is set.
|
||||
|
||||
Superset of encoder attention required metadata.
|
||||
'''
|
||||
return (attn_metadata.is_all_encoder_attn_metadata_set
|
||||
and (attn_metadata.cross_slot_mapping is not None)
|
||||
and (attn_metadata.cross_block_tables is not None))
|
||||
|
||||
|
||||
def get_seq_len_block_table_args(
|
||||
attn_metadata,
|
||||
is_prompt: bool,
|
||||
attn_type: AttentionType,
|
||||
) -> tuple:
|
||||
'''
|
||||
The particular choice of sequence-length- and block-table-related
|
||||
attributes which should be extracted from attn_metadata is dependent
|
||||
on the type of attention operation.
|
||||
|
||||
Decoder attn -> select entirely decoder self-attention-related fields
|
||||
Encoder/decoder cross-attn -> select encoder sequence lengths &
|
||||
cross-attn block-tables fields
|
||||
Encoder attn -> select encoder sequence lengths fields & no block tables
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention op
|
||||
* is_prompt: True if prefill, False otherwise
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
|
||||
* Appropriate sequence-lengths tensor
|
||||
* Appropriate max sequence-length scalar
|
||||
* Appropriate block tables (or None)
|
||||
'''
|
||||
|
||||
if attn_type == AttentionType.DECODER:
|
||||
# Decoder self-attention
|
||||
# Choose max_seq_len based on whether we are in prompt_run
|
||||
if is_prompt:
|
||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||
else:
|
||||
max_seq_len = attn_metadata.max_decode_seq_len
|
||||
return (attn_metadata.seq_lens_tensor, max_seq_len,
|
||||
attn_metadata.block_tables)
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Enc/dec cross-attention KVs match encoder sequence length;
|
||||
# cross-attention utilizes special "cross" block tables
|
||||
return (attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.cross_block_tables)
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
# No block tables associated with encoder attention
|
||||
return (attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len, None)
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_num_prefill_decode_query_kv_tokens(
    attn_metadata,
    attn_type: AttentionType,
) -> Tuple[int, int, int]:
    """
    Calculate the number of prefill and decode tokens for query, key/value
    based on the attention metadata and the specified attention type.

    Args:
        attn_metadata (FlashAttentionMetadata): Attention Metadata object.
        attn_type (AttentionType): The type of attention being used.
    Returns:
        Tuple[int, int, int]: A tuple containing three integers:
            - The number of prefill query tokens.
            - The number of prefill key/value tokens.
            - The number of decode query tokens.

    Raises:
        AssertionError: If the number of encoder tokens in `attn_metadata`
        is `None` when required for the calculations.
    """
    num_prefill_query_tokens = 0
    num_decode_query_tokens = 0
    num_prefill_kv_tokens = 0
    if attn_type == AttentionType.ENCODER:
        # Encoder attention is only invoked during the prefill phase.
        # The same input serves as both query and key.
        assert attn_metadata.num_encoder_tokens is not None
        num_prefill_query_tokens = attn_metadata.num_encoder_tokens
        num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
        num_decode_query_tokens = 0
    elif attn_type == AttentionType.ENCODER_DECODER:
        assert attn_metadata.num_encoder_tokens is not None
        num_prefill_query_tokens = attn_metadata.num_prefill_tokens
        # The key is the encoder/cross-attention.
        num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
        num_decode_query_tokens = attn_metadata.num_decode_tokens
    else:  # attn_type == AttentionType.DECODER or
        # attn_type == AttentionType.ENCODER_ONLY
        num_prefill_query_tokens = attn_metadata.num_prefill_tokens
        num_prefill_kv_tokens = attn_metadata.num_prefill_tokens
        num_decode_query_tokens = attn_metadata.num_decode_tokens

    return (num_prefill_query_tokens, num_prefill_kv_tokens,
            num_decode_query_tokens)
|
||||
|
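
A worked example of this token accounting (illustrative only; the metadata object is a hypothetical stand-in): for cross-attention the query counts follow the decoder-side prefill/decode tokens, while the key/value count follows the encoder tokens.

from types import SimpleNamespace

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.utils import (
    get_num_prefill_decode_query_kv_tokens)

meta = SimpleNamespace(num_prefill_tokens=10,
                       num_decode_tokens=2,
                       num_encoder_tokens=16)

# Each result is (prefill query, prefill key/value, decode query).
assert get_num_prefill_decode_query_kv_tokens(
    meta, AttentionType.ENCODER) == (16, 16, 0)
assert get_num_prefill_decode_query_kv_tokens(
    meta, AttentionType.ENCODER_DECODER) == (10, 16, 2)
assert get_num_prefill_decode_query_kv_tokens(
    meta, AttentionType.DECODER) == (10, 10, 2)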
@ -11,8 +11,10 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
                                           CommonMetadataBuilder)
from vllm.attention.backends.utils import (
    CommonAttentionState, CommonMetadataBuilder,
    get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
    is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set)
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
from vllm.logger import init_logger
@ -135,6 +137,11 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    encoder_seq_start_loc: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None
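
To make the comment above concrete, a small sketch (not part of the diff) of the layout encoder_seq_start_loc is expected to hold: a zero-prefixed cumulative sum of the encoder sequence lengths, the cu_seqlens form that FlashAttention's variable-length interface works with.

import torch

encoder_seq_lens = torch.tensor([4, 6], dtype=torch.int32)
encoder_seq_start_loc = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(encoder_seq_lens, dim=0, dtype=torch.int32),
])
print(encoder_seq_start_loc)  # tensor([ 0,  4, 10], dtype=torch.int32)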
@ -162,9 +169,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
        '''
        All attention metadata required for encoder attention is set.
        '''
        return ((self.encoder_seq_lens is not None)
                and (self.encoder_seq_lens_tensor is not None)
                and (self.max_encoder_seq_len is not None))
        return is_all_encoder_attn_metadata_set(self)

    @property
    def is_all_cross_attn_metadata_set(self):
@ -173,9 +178,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):

        Superset of encoder attention required metadata.
        '''
        return (self.is_all_encoder_attn_metadata_set
                and (self.cross_slot_mapping is not None)
                and (self.cross_block_tables is not None))
        return is_all_cross_attn_metadata_set(self)

    @property
    def prefill_metadata(self) -> Optional["XFormersMetadata"]:
@ -329,64 +332,6 @@ def _set_attn_bias(
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


def _get_seq_len_block_table_args(
    attn_metadata: XFormersMetadata,
    is_prompt: bool,
    attn_type: AttentionType,
) -> tuple:
    '''
    The particular choice of sequence-length- and block-table-related
    attributes which should be extracted from attn_metadata is dependent
    on the type of attention operation.

    Decoder attn -> select entirely decoder self-attention-related fields
    Encoder/decoder cross-attn -> select encoder sequence lengths &
                                  cross-attn block-tables fields
    Encoder attn -> select encoder sequence lengths fields & no block tables

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * is_prompt: True if prefill, False otherwise
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)
    '''

    if attn_type == AttentionType.DECODER:
        # Decoder self-attention
        # Choose max_seq_len based on whether we are in prompt_run
        if is_prompt:
            max_seq_len = attn_metadata.max_prefill_seq_len
        else:
            max_seq_len = attn_metadata.max_decode_seq_len
        return (attn_metadata.seq_lens_tensor, max_seq_len,
                attn_metadata.block_tables)
    elif attn_type == AttentionType.ENCODER_DECODER:
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)
    elif attn_type == AttentionType.ENCODER:
        # No block tables associated with encoder attention
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)
    elif attn_type == AttentionType.ENCODER_ONLY:
        assert is_prompt, "Should not have decode for encoder only model."

        # No block tables associated with encoder attention
        return (attn_metadata.seq_lens_tensor,
                attn_metadata.max_prefill_seq_len, None)
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):

    _metadata_cls = XFormersMetadata
@ -574,45 +519,21 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                                                updated_slot_mapping,
                                                self.kv_cache_dtype,
                                                k_scale, v_scale)

        if attn_type == AttentionType.ENCODER:
            # Encoder attention - chunked prefill is not applicable;
            # derive token-count from query shape and treat them
            # as 100% prefill tokens
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens
            num_encoder_tokens = attn_metadata.num_encoder_tokens
            num_decode_tokens = 0
        elif attn_type == AttentionType.DECODER:
            # Decoder self-attention supports chunked prefill.
            num_prefill_tokens = attn_metadata.num_prefill_tokens
            num_encoder_tokens = attn_metadata.num_prefill_tokens
            num_decode_tokens = attn_metadata.num_decode_tokens
            # Only enforce this shape-constraint for decoder
            # self-attention
            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
            assert value.shape[0] == num_prefill_tokens + num_decode_tokens
        else:  # attn_type == AttentionType.ENCODER_DECODER
            # Encoder/decoder cross-attention requires no chunked
            # prefill (100% prefill or 100% decode tokens, no mix)
            num_prefill_tokens = attn_metadata.num_prefill_tokens
            if attn_metadata.num_encoder_tokens is not None:
                num_encoder_tokens = attn_metadata.num_encoder_tokens
            else:
                num_encoder_tokens = attn_metadata.num_prefill_tokens
            num_decode_tokens = attn_metadata.num_decode_tokens
        (num_prefill_query_tokens, num_prefill_kv_tokens,
         num_decode_query_tokens) = \
            get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        decode_query = query[num_prefill_query_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        query = query[:num_prefill_query_tokens]
        if key is not None and value is not None:
            key = key[:num_encoder_tokens]
            value = value[:num_encoder_tokens]
            key = key[:num_prefill_kv_tokens]
            value = value[:num_prefill_kv_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens
        assert query.shape[0] == num_prefill_query_tokens
        assert decode_query.shape[0] == num_decode_query_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
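
A shape-level sketch (made-up sizes, not vLLM code) of the partitioning the forward pass performs with these counts: the flattened query holds prefill tokens first and decode tokens last, while the prefill key/value slice can follow the encoder length in the cross-attention case.

import torch

# Hypothetical counts, e.g. as returned by
# get_num_prefill_decode_query_kv_tokens(...) for cross-attention.
num_prefill_query_tokens, num_prefill_kv_tokens, num_decode_query_tokens = 10, 16, 2
num_heads, head_size = 8, 64

query = torch.randn(num_prefill_query_tokens + num_decode_query_tokens,
                    num_heads, head_size)
key = torch.randn(num_prefill_kv_tokens, num_heads, head_size)

decode_query = query[num_prefill_query_tokens:]   # shape (2, 8, 64)
prefill_query = query[:num_prefill_query_tokens]  # shape (10, 8, 64)
prefill_key = key[:num_prefill_kv_tokens]         # shape (16, 8, 64)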
@ -622,8 +543,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                # prefix.
                out = self._run_memory_efficient_xformers_forward(
                    query, key, value, prefill_meta, attn_type=attn_type)
                assert out.shape == output[:num_prefill_tokens].shape
                output[:num_prefill_tokens] = out
                assert out.shape == output[:num_prefill_query_tokens].shape
                output[:num_prefill_query_tokens] = out
            else:
                assert attn_type != AttentionType.ENCODER_ONLY, (
                    "Encoder-only models should not have prefix attention.")
@ -652,8 +573,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                    k_scale,
                    v_scale,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
                assert output[:num_prefill_query_tokens].shape == out.shape
                output[:num_prefill_query_tokens] = out

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
@ -663,9 +584,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = _get_seq_len_block_table_args(decode_meta, False, attn_type)
            ) = get_seq_len_block_table_args(decode_meta, False, attn_type)

            output[num_prefill_tokens:] = PagedAttention.forward_decode(
            output[num_prefill_query_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
@ -98,7 +98,6 @@ def get_attn_backend(
    is_blocksparse: bool = False,
) -> Type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it."""

    if is_blocksparse:
        logger.info("Using BlocksparseFlashAttention backend.")
        from vllm.attention.backends.blocksparse_attn import (
@ -108,6 +107,7 @@ def get_attn_backend(
    backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
                                is_attention_free)
    if backend == _Backend.FLASH_ATTN:
        logger.info("Using Flash Attention backend.")
        from vllm.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend)
        return FlashAttentionBackend
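
Usage note (a hedged sketch, not part of the diff): with FLASH_ATTN now selectable for encoder-decoder models, the usual override knobs apply; one assumed-but-typical way to pin the backend is the VLLM_ATTENTION_BACKEND environment variable, set before vLLM initializes.

import os

# Assumed override; "XFORMERS" remains the other supported choice.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

from vllm import LLM  # noqa: E402

# dtype chosen here for Flash Attention compatibility (assumption).
llm = LLM(model="facebook/bart-large-cnn", dtype="bfloat16")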
@ -624,8 +624,6 @@ class BartEncoder(nn.Module):
            Decoder output torch.Tensor
        """
        # retrieve input_ids and inputs_embeds

        input_ids = input_ids.view(-1, input_ids.shape[-1])
        inputs_embeds = self.embed_tokens(input_ids)

        embed_pos = self.embed_positions(
@ -80,8 +80,8 @@ STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
                                 "currently supported with encoder/"
                                 "decoder models.")

STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
                                "currently supported with encoder/"
STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers and Flash-Attention are the only "
                                "backends currently supported with encoder/"
                                "decoder models.")

STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
@ -19,6 +19,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.utils import get_architecture_class_name
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
                             MultiModalRegistry)
from vllm.sampling_params import SamplingParams
@ -36,6 +37,11 @@ from vllm.worker.utils import assert_enc_dec_mr_supported_scenario

logger = init_logger(__name__)

# The Mllama model has PagedAttention specific logic because of which it
# can only be run with the XFORMERS backend
# TODO Make Mllama model work with Flash Attention backend.
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"]


@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
@ -101,9 +107,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
        models) but these arguments are present here for compatibility with
        the base-class constructor.
        '''

        self._maybe_force_supported_attention_backend()

        self._maybe_force_supported_attention_backend(model_config)
        super().__init__(
            model_config,
            parallel_config,
@ -119,7 +123,12 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
        # Crash for unsupported encoder/scenarios
        assert_enc_dec_mr_supported_scenario(self)

    def _maybe_force_supported_attention_backend(self):
    def _is_xformers_only_encoder_decoder_model(self,
                                                model: ModelConfig) -> bool:
        return get_architecture_class_name(
            model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS

    def _maybe_force_supported_attention_backend(self, model: ModelConfig):
        '''
        Force vLLM to use the XFormers attention backend,
        which is currently the only supported option.
@ -135,22 +144,26 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
        is_forced_by_global = maybe_global_forced_backend is not None
        is_forced_by_env_var = maybe_env_var_forced_backend is not None

        if not (is_forced_by_global or is_forced_by_env_var):
        if not (is_forced_by_global or is_forced_by_env_var) \
                and self._is_xformers_only_encoder_decoder_model(model):
            # The user has not already specified an attention backend
            # override
            logger.info("EncoderDecoderModelRunner requires "
                        "XFormers backend; overriding backend "
                        "auto-selection and forcing XFormers.")
            logger.info(
                "Encoder-Decoder Model Architecture %s requires XFormers "
                "backend; overriding backend auto-selection and "
                "forcing XFormers.", get_architecture_class_name(model))
            global_force_attn_backend(_Backend.XFORMERS)
        elif is_forced_by_global:
            # Backend override enforced by global variable takes
            # precedence over vLLM backend environment variable.
            if maybe_global_forced_backend != _Backend.XFORMERS:
            if maybe_global_forced_backend not in\
                    [_Backend.XFORMERS, _Backend.FLASH_ATTN]:
                raise_backend_err()
        elif is_forced_by_env_var:
            # Backend override enforced by vLLM backend
            # environment variable
            if maybe_env_var_forced_backend != _Backend.XFORMERS:
            if maybe_env_var_forced_backend not in\
                    [_Backend.XFORMERS, _Backend.FLASH_ATTN]:
                raise_backend_err()

    def _list_to_int32_tensor(
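
A condensed sketch (not the actual method) of the decision the runner now encodes; the helper name below is hypothetical, while _Backend comes from vllm.attention.selector as elsewhere in this change.

from typing import Optional

from vllm.attention.selector import _Backend

_SUPPORTED_ENC_DEC_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]


def pick_enc_dec_backend(forced_backend: Optional[_Backend],
                         is_xformers_only_arch: bool) -> Optional[_Backend]:
    """Return a backend to force, or None to keep auto-selection."""
    if forced_backend is None:
        # No user override: Mllama-style architectures are pinned to
        # XFormers; other encoder-decoder models keep auto-selection.
        return _Backend.XFORMERS if is_xformers_only_arch else None
    if forced_backend not in _SUPPORTED_ENC_DEC_BACKENDS:
        # Mirrors raise_backend_err(): only XFormers and Flash-Attention
        # are accepted for encoder/decoder models.
        raise NotImplementedError(
            "XFormers and Flash-Attention are the only backends currently "
            "supported with encoder/decoder models.")
    return forced_backend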
@ -532,6 +545,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
            attn_metadata.encoder_seq_lens,
            attn_metadata.encoder_seq_lens_tensor,
            attn_metadata.max_encoder_seq_len,
            attn_metadata.encoder_seq_start_loc,
            attn_metadata.cross_slot_mapping,
            attn_metadata.cross_block_tables,
        ) = (
@ -539,6 +553,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
            encoder_seq_lens,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
            encoder_seq_start_loc,
            cross_slot_mapping_tensor,
            cross_block_tables,
        )