[Bugfix] Fix bug of xformer prefill for encoder-decoder (#9026)
parent 89feb4c84d
commit 00298e092c
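In the prefill path below, key and value were sliced to num_prefill_tokens unconditionally. For encoder/decoder cross-attention, however, key and value come from the encoder sequence, whose length can differ from the decoder-side prefill token count, so the slice could take the wrong number of tokens. The fix replaces the two-way ENCODER/non-ENCODER branch with an explicit three-way branch over ENCODER, DECODER, and ENCODER_DECODER, tracks num_encoder_tokens in each case, and uses that count for the key/value slices.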
@@ -559,25 +559,32 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                                                 self.kv_cache_dtype,
                                                 k_scale, v_scale)
 
-        if attn_type != AttentionType.ENCODER:
-            # Decoder self-attention supports chunked prefill.
-            # Encoder/decoder cross-attention requires no chunked
-            # prefill (100% prefill or 100% decode tokens, no mix)
-            num_prefill_tokens = attn_metadata.num_prefill_tokens
-            num_decode_tokens = attn_metadata.num_decode_tokens
-        else:
+        if attn_type == AttentionType.ENCODER:
             # Encoder attention - chunked prefill is not applicable;
             # derive token-count from query shape and treat them
             # as 100% prefill tokens
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
+            num_encoder_tokens = attn_metadata.num_encoder_tokens
             num_decode_tokens = 0
-
-        if attn_type == AttentionType.DECODER:
+        elif attn_type == AttentionType.DECODER:
+            # Decoder self-attention supports chunked prefill.
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            num_encoder_tokens = attn_metadata.num_prefill_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
             # Only enforce this shape-constraint for decoder
             # self-attention
             assert key.shape[0] == num_prefill_tokens + num_decode_tokens
             assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+        else:  # attn_type == AttentionType.ENCODER_DECODER
+            # Encoder/decoder cross-attention requires no chunked
+            # prefill (100% prefill or 100% decode tokens, no mix)
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            if attn_metadata.num_encoder_tokens is not None:
+                num_encoder_tokens = attn_metadata.num_encoder_tokens
+            else:
+                num_encoder_tokens = attn_metadata.num_prefill_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
 
         output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
@@ -585,8 +592,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         # QKV for prefill.
         query = query[:num_prefill_tokens]
         if key is not None and value is not None:
-            key = key[:num_prefill_tokens]
-            value = value[:num_prefill_tokens]
+            key = key[:num_encoder_tokens]
+            value = value[:num_encoder_tokens]
 
         assert query.shape[0] == num_prefill_tokens
         assert decode_query.shape[0] == num_decode_tokens
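For readers outside the diff context, here is a minimal, self-contained sketch of the three-way token-count selection above. AttnMeta and select_token_counts are illustrative stand-ins, not vLLM's actual classes; only the AttentionType cases and the attn_metadata field names are taken from the diff.

# Minimal sketch of the fixed token-count logic; names other than the
# AttentionType cases and the metadata field names are hypothetical.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class AttentionType(Enum):
    ENCODER = auto()
    DECODER = auto()
    ENCODER_DECODER = auto()


@dataclass
class AttnMeta:
    num_prefill_tokens: int
    num_decode_tokens: int
    num_encoder_tokens: Optional[int] = None


def select_token_counts(attn_type: AttentionType, meta: AttnMeta):
    """Return (num_prefill_tokens, num_encoder_tokens, num_decode_tokens)
    following the three-way branch introduced by this commit."""
    if attn_type is AttentionType.ENCODER:
        # Encoder attention: 100% prefill tokens, counted on the encoder side.
        assert meta.num_encoder_tokens is not None
        return meta.num_encoder_tokens, meta.num_encoder_tokens, 0
    elif attn_type is AttentionType.DECODER:
        # Decoder self-attention: KV length equals query length, so the
        # encoder-side count degenerates to the prefill count.
        return (meta.num_prefill_tokens, meta.num_prefill_tokens,
                meta.num_decode_tokens)
    else:  # AttentionType.ENCODER_DECODER
        # Cross-attention: KV comes from the encoder sequence, whose length
        # is independent of the decoder-side prefill token count.
        num_encoder = (meta.num_encoder_tokens
                       if meta.num_encoder_tokens is not None
                       else meta.num_prefill_tokens)
        return meta.num_prefill_tokens, num_encoder, meta.num_decode_tokens


# Cross-attention prefill with 4 decoder-side tokens but 7 encoder tokens:
# key/value must be sliced with the encoder count, not the prefill count.
meta = AttnMeta(num_prefill_tokens=4, num_decode_tokens=0, num_encoder_tokens=7)
assert select_token_counts(AttentionType.ENCODER_DECODER, meta) == (4, 7, 0)

With these counts, the second hunk's key[:num_encoder_tokens] and value[:num_encoder_tokens] slices line up with the encoder sequence length, whereas the old key[:num_prefill_tokens] slice was only correct when the two counts happened to coincide.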