[Bugfix] Remove hardcoded head_size=256 for Deepseek v2 and v3 (#12067)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2025-01-16 18:11:54 +08:00 committed by GitHub
parent 9aa1519f08
commit dd7c9ad870
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 23 additions and 40 deletions

View File

@ -31,9 +31,9 @@ NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
# FlashAttention forward only supports head dimension at most 128 # This should be sync with get_supported_head_sizes() in
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 # vllm.attention.ops.paged_attn.PagedAttention
HEAD_SIZES = [64, 80, 120, 256] HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
BLOCK_SIZES = [16, 32] BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True] USE_ALIBI = [False, True]

View File

@ -733,9 +733,12 @@ class ModelConfig:
if hasattr(self.hf_text_config, if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type "model_type") and (self.hf_text_config.model_type
in ('deepseek_v2', 'deepseek_v3')): in ('deepseek_v2', 'deepseek_v3')):
# FlashAttention supports only head_size 32, 64, 128, 256, qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
# we need to pad head_size 192 to 256 0)
return 256 qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim",
0)
if qk_rope_head_dim and qk_nope_head_dim:
return qk_rope_head_dim + qk_nope_head_dim
if self.is_attention_free: if self.is_attention_free:
return 0 return 0

View File

@ -262,14 +262,8 @@ class DeepseekV2Attention(nn.Module):
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale self.scaling = self.scaling * mscale * mscale
# self.attn = Attention(self.num_heads,
# self.qk_head_dim,
# self.scaling,
# num_kv_heads=self.num_heads)
# TODO, support head_size 192
self.attn = Attention(self.num_local_heads, self.attn = Attention(self.num_local_heads,
256, self.qk_head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_local_heads, num_kv_heads=self.num_local_heads,
cache_config=cache_config, cache_config=cache_config,
@ -319,18 +313,14 @@ class DeepseekV2Attention(nn.Module):
k = torch.empty_like(q) k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope k[..., :self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim:] = k_pe k[..., self.qk_nope_head_dim:] = k_pe
q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], # padding value to qk_head_dim for alignment
value=0).view(-1, v = torch.nn.functional.pad(
self.num_local_heads * 256) v, [0, self.qk_head_dim - self.v_head_dim],
k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value=0).view(-1, self.num_local_heads * self.qk_head_dim)
value=0).view(-1,
self.num_local_heads * 256)
v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = attn_output.view( attn_output = attn_output.view(
-1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape( -1, self.num_local_heads,
self.qk_head_dim)[..., :self.v_head_dim].reshape(
-1, self.num_local_heads * self.v_head_dim) -1, self.num_local_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output

View File

@ -269,14 +269,8 @@ class DeepseekV3Attention(nn.Module):
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale self.scaling = self.scaling * mscale * mscale
# self.attn = Attention(self.num_heads,
# self.qk_head_dim,
# self.scaling,
# num_kv_heads=self.num_heads)
# TODO, support head_size 192
self.attn = Attention(self.num_local_heads, self.attn = Attention(self.num_local_heads,
256, self.qk_head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_local_heads, num_kv_heads=self.num_local_heads,
cache_config=cache_config, cache_config=cache_config,
@ -326,18 +320,14 @@ class DeepseekV3Attention(nn.Module):
k = torch.empty_like(q) k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope k[..., :self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim:] = k_pe k[..., self.qk_nope_head_dim:] = k_pe
q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], # padding value to qk_head_dim for alignment
value=0).view(-1, v = torch.nn.functional.pad(
self.num_local_heads * 256) v, [0, self.qk_head_dim - self.v_head_dim],
k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value=0).view(-1, self.num_local_heads * self.qk_head_dim)
value=0).view(-1,
self.num_local_heads * 256)
v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = attn_output.view( attn_output = attn_output.view(
-1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape( -1, self.num_local_heads,
self.qk_head_dim)[..., :self.v_head_dim].reshape(
-1, self.num_local_heads * self.v_head_dim) -1, self.num_local_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output