[Bugfix] Fix bug of xformer prefill for encoder-decoder (#9026)

Xiang Xu 2024-10-12 00:00:43 -07:00 committed by GitHub
parent 89feb4c84d
commit 00298e092c

@@ -559,25 +559,32 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                                                 self.kv_cache_dtype,
                                                 k_scale, v_scale)
 
-        if attn_type != AttentionType.ENCODER:
-            # Decoder self-attention supports chunked prefill.
-            # Encoder/decoder cross-attention requires no chunked
-            # prefill (100% prefill or 100% decode tokens, no mix)
-            num_prefill_tokens = attn_metadata.num_prefill_tokens
-            num_decode_tokens = attn_metadata.num_decode_tokens
-        else:
+        if attn_type == AttentionType.ENCODER:
             # Encoder attention - chunked prefill is not applicable;
             # derive token-count from query shape & and treat them
             # as 100% prefill tokens
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
+            num_encoder_tokens = attn_metadata.num_encoder_tokens
             num_decode_tokens = 0
-
-        if attn_type == AttentionType.DECODER:
+        elif attn_type == AttentionType.DECODER:
+            # Decoder self-attention supports chunked prefill.
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            num_encoder_tokens = attn_metadata.num_prefill_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
             # Only enforce this shape-constraint for decoder
             # self-attention
             assert key.shape[0] == num_prefill_tokens + num_decode_tokens
             assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+        else:  # attn_type == AttentionType.ENCODER_DECODER
+            # Encoder/decoder cross-attention requires no chunked
+            # prefill (100% prefill or 100% decode tokens, no mix)
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            if attn_metadata.num_encoder_tokens is not None:
+                num_encoder_tokens = attn_metadata.num_encoder_tokens
+            else:
+                num_encoder_tokens = attn_metadata.num_prefill_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
 
         output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
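
For context, the token-count selection after this hunk reduces to the following mapping. This is a minimal paraphrase of the branch above, not the literal vLLM source; the helper name token_counts and the shorthand meta (standing in for attn_metadata) are invented for illustration:

def token_counts(attn_type, meta):
    # Returns (num_prefill_tokens, num_encoder_tokens, num_decode_tokens).
    if attn_type == AttentionType.ENCODER:
        # Encoder attention: pure prefill over encoder tokens, no decode.
        return meta.num_encoder_tokens, meta.num_encoder_tokens, 0
    elif attn_type == AttentionType.DECODER:
        # Decoder self-attention: KV length tracks the decoder query, so
        # the encoder count degenerates to the prefill count.
        return (meta.num_prefill_tokens, meta.num_prefill_tokens,
                meta.num_decode_tokens)
    else:  # AttentionType.ENCODER_DECODER
        # Cross-attention: KV comes from the encoder, so its length is
        # num_encoder_tokens, which generally differs from the decoder's
        # num_prefill_tokens.
        enc = (meta.num_encoder_tokens
               if meta.num_encoder_tokens is not None
               else meta.num_prefill_tokens)
        return meta.num_prefill_tokens, enc, meta.num_decode_tokens

The key point of the fix is the third branch: before this commit, cross-attention had no separate encoder token count and reused num_prefill_tokens, which is only correct when the encoder and decoder sequences happen to be the same length.
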
@@ -585,8 +592,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         # QKV for prefill.
         query = query[:num_prefill_tokens]
         if key is not None and value is not None:
-            key = key[:num_prefill_tokens]
-            value = value[:num_prefill_tokens]
+            key = key[:num_encoder_tokens]
+            value = value[:num_encoder_tokens]
 
         assert query.shape[0] == num_prefill_tokens
         assert decode_query.shape[0] == num_decode_tokens
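
The shape mismatch this hunk fixes is easiest to see with concrete numbers. The sizes below are hypothetical and chosen only to make the bug visible:

import torch

num_prefill_tokens = 4   # decoder-side prefill tokens (hypothetical)
num_encoder_tokens = 7   # encoder-side tokens (hypothetical)
head_size = 8

# Cross-attention: query comes from the decoder, key/value from the encoder,
# so their leading dimensions differ whenever the sequence lengths differ.
query = torch.randn(num_prefill_tokens, head_size)
key = torch.randn(num_encoder_tokens, head_size)
value = torch.randn(num_encoder_tokens, head_size)

# Before the fix: slicing KV with the decoder token count silently
# dropped encoder keys/values.
assert key[:num_prefill_tokens].shape[0] == 4   # 3 encoder tokens lost

# After the fix: KV keeps the full encoder length.
assert key[:num_encoder_tokens].shape[0] == num_encoder_tokens
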