[Bugfix] Remove hardcoded head_size=256
for Deepseek v2 and v3 (#12067)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
9aa1519f08
commit
dd7c9ad870
@ -31,9 +31,9 @@ NUM_GEN_SEQS = [7] # Arbitrary values for testing
|
||||
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
|
||||
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
|
||||
|
||||
# FlashAttention forward only supports head dimension at most 128
|
||||
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
|
||||
HEAD_SIZES = [64, 80, 120, 256]
|
||||
# This should be sync with get_supported_head_sizes() in
|
||||
# vllm.attention.ops.paged_attn.PagedAttention
|
||||
HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
|
||||
|
||||
BLOCK_SIZES = [16, 32]
|
||||
USE_ALIBI = [False, True]
|
||||
|
@ -733,9 +733,12 @@ class ModelConfig:
|
||||
if hasattr(self.hf_text_config,
|
||||
"model_type") and (self.hf_text_config.model_type
|
||||
in ('deepseek_v2', 'deepseek_v3')):
|
||||
# FlashAttention supports only head_size 32, 64, 128, 256,
|
||||
# we need to pad head_size 192 to 256
|
||||
return 256
|
||||
qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
|
||||
0)
|
||||
qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim",
|
||||
0)
|
||||
if qk_rope_head_dim and qk_nope_head_dim:
|
||||
return qk_rope_head_dim + qk_nope_head_dim
|
||||
|
||||
if self.is_attention_free:
|
||||
return 0
|
||||
|
@ -262,14 +262,8 @@ class DeepseekV2Attention(nn.Module):
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.scaling = self.scaling * mscale * mscale
|
||||
|
||||
# self.attn = Attention(self.num_heads,
|
||||
# self.qk_head_dim,
|
||||
# self.scaling,
|
||||
# num_kv_heads=self.num_heads)
|
||||
|
||||
# TODO, support head_size 192
|
||||
self.attn = Attention(self.num_local_heads,
|
||||
256,
|
||||
self.qk_head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_local_heads,
|
||||
cache_config=cache_config,
|
||||
@ -319,18 +313,14 @@ class DeepseekV2Attention(nn.Module):
|
||||
k = torch.empty_like(q)
|
||||
k[..., :self.qk_nope_head_dim] = k_nope
|
||||
k[..., self.qk_nope_head_dim:] = k_pe
|
||||
q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
|
||||
value=0).view(-1,
|
||||
self.num_local_heads * 256)
|
||||
k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
|
||||
value=0).view(-1,
|
||||
self.num_local_heads * 256)
|
||||
v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
|
||||
value=0).view(-1,
|
||||
self.num_local_heads * 256)
|
||||
# padding value to qk_head_dim for alignment
|
||||
v = torch.nn.functional.pad(
|
||||
v, [0, self.qk_head_dim - self.v_head_dim],
|
||||
value=0).view(-1, self.num_local_heads * self.qk_head_dim)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = attn_output.view(
|
||||
-1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
|
||||
-1, self.num_local_heads,
|
||||
self.qk_head_dim)[..., :self.v_head_dim].reshape(
|
||||
-1, self.num_local_heads * self.v_head_dim)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
@ -269,14 +269,8 @@ class DeepseekV3Attention(nn.Module):
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.scaling = self.scaling * mscale * mscale
|
||||
|
||||
# self.attn = Attention(self.num_heads,
|
||||
# self.qk_head_dim,
|
||||
# self.scaling,
|
||||
# num_kv_heads=self.num_heads)
|
||||
|
||||
# TODO, support head_size 192
|
||||
self.attn = Attention(self.num_local_heads,
|
||||
256,
|
||||
self.qk_head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_local_heads,
|
||||
cache_config=cache_config,
|
||||
@ -326,18 +320,14 @@ class DeepseekV3Attention(nn.Module):
|
||||
k = torch.empty_like(q)
|
||||
k[..., :self.qk_nope_head_dim] = k_nope
|
||||
k[..., self.qk_nope_head_dim:] = k_pe
|
||||
q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
|
||||
value=0).view(-1,
|
||||
self.num_local_heads * 256)
|
||||
k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
|
||||
value=0).view(-1,
|
||||
self.num_local_heads * 256)
|
||||
v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
|
||||
value=0).view(-1,
|
||||
self.num_local_heads * 256)
|
||||
# padding value to qk_head_dim for alignment
|
||||
v = torch.nn.functional.pad(
|
||||
v, [0, self.qk_head_dim - self.v_head_dim],
|
||||
value=0).view(-1, self.num_local_heads * self.qk_head_dim)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = attn_output.view(
|
||||
-1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
|
||||
-1, self.num_local_heads,
|
||||
self.qk_head_dim)[..., :self.v_head_dim].reshape(
|
||||
-1, self.num_local_heads * self.v_head_dim)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
Loading…
x
Reference in New Issue
Block a user