[Bugfix] Remove hardcoded head_size=256 for Deepseek v2 and v3 (#12067)

Signed-off-by: Isotr0py <2037008807@qq.com>

parent 9aa1519f08
commit dd7c9ad870
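For context: the MLA head size of the DeepSeek models is the sum of the RoPE and non-RoPE query/key sub-dimensions, so the hardcoded 256 was only FlashAttention padding on top of the real 192. A minimal sketch of that arithmetic, using the published DeepSeek-V2/V3 config values as an assumption (they are illustrative, not read from this diff):

```python
# Illustrative values matching the public DeepSeek-V2/V3 configs;
# treat them as assumptions, not values taken from this repository.
qk_nope_head_dim = 128  # non-RoPE portion of the query/key head
qk_rope_head_dim = 64   # RoPE portion of the query/key head
v_head_dim = 128        # value head size

qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
assert qk_head_dim == 192  # the real head size; 256 was only padding
```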
```diff
@@ -31,9 +31,9 @@ NUM_GEN_SEQS = [7]  # Arbitrary values for testing
 NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
 NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
 
-# FlashAttention forward only supports head dimension at most 128
-# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [64, 80, 120, 256]
+# This should be sync with get_supported_head_sizes() in
+# vllm.attention.ops.paged_attn.PagedAttention
+HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
 
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
```
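HEAD_SIZES now includes 192 alongside the other sizes supported by the paged-attention kernel. A sketch of how such a constant typically feeds the tests via pytest parametrization; the test body below is a placeholder stand-in, not the real kernel test:

```python
import pytest

HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]


@pytest.mark.parametrize("head_size", HEAD_SIZES)
def test_head_size_is_supported(head_size):
    # Placeholder assertion; the real test exercises the attention kernels
    # for each parametrized head size.
    assert head_size % 8 == 0
```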
```diff
@@ -733,9 +733,12 @@ class ModelConfig:
         if hasattr(self.hf_text_config,
                    "model_type") and (self.hf_text_config.model_type
                                       in ('deepseek_v2', 'deepseek_v3')):
-            # FlashAttention supports only head_size 32, 64, 128, 256,
-            # we need to pad head_size 192 to 256
-            return 256
+            qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
+                                       0)
+            qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim",
+                                       0)
+            if qk_rope_head_dim and qk_nope_head_dim:
+                return qk_rope_head_dim + qk_nope_head_dim
 
         if self.is_attention_free:
             return 0
```
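The config change above derives the head size from the model config instead of returning 256. A standalone sketch of that derivation, with types.SimpleNamespace standing in for hf_text_config and the non-DeepSeek fallback reduced to a single hypothetical getattr; the surrounding ModelConfig plumbing is omitted:

```python
from types import SimpleNamespace


def get_head_size(hf_text_config) -> int:
    """Mirror the new branch: derive the MLA head size, not a padded 256."""
    if getattr(hf_text_config, "model_type", None) in ("deepseek_v2",
                                                       "deepseek_v3"):
        qk_rope_head_dim = getattr(hf_text_config, "qk_rope_head_dim", 0)
        qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 0)
        if qk_rope_head_dim and qk_nope_head_dim:
            return qk_rope_head_dim + qk_nope_head_dim
    # Hypothetical fallback for other models; the real method has more cases.
    return getattr(hf_text_config, "head_dim", 0)


cfg = SimpleNamespace(model_type="deepseek_v2",
                      qk_rope_head_dim=64,
                      qk_nope_head_dim=128)
assert get_head_size(cfg) == 192
```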
```diff
@@ -262,14 +262,8 @@ class DeepseekV2Attention(nn.Module):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
-        # self.attn = Attention(self.num_heads,
-        #                       self.qk_head_dim,
-        #                       self.scaling,
-        #                       num_kv_heads=self.num_heads)
-
-        # TODO, support head_size 192
         self.attn = Attention(self.num_local_heads,
-                              256,
+                              self.qk_head_dim,
                               self.scaling,
                               num_kv_heads=self.num_local_heads,
                               cache_config=cache_config,
```
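Passing self.qk_head_dim (192 for DeepSeek-V2/V3) instead of the padded 256 also shrinks each KV-cache entry. A rough back-of-the-envelope sketch, assuming an fp16 cache and 128 KV heads (both assumptions for illustration, not values taken from this diff):

```python
# Rough KV-cache sizing per token, assuming fp16 (2 bytes per element).
bytes_per_elem = 2
num_kv_heads = 128  # assumption: one KV head per attention head in this path


def kv_bytes_per_token(head_size: int) -> int:
    # Key plus value entries for every KV head.
    return 2 * num_kv_heads * head_size * bytes_per_elem


before = kv_bytes_per_token(256)  # padded head size
after = kv_bytes_per_token(192)   # real qk_head_dim
print(before, after, 1 - after / before)  # roughly 25% less cache per token
```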
```diff
@@ -319,18 +313,14 @@ class DeepseekV2Attention(nn.Module):
         k = torch.empty_like(q)
         k[..., :self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim:] = k_pe
-        q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
-        k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
-        v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
+        # padding value to qk_head_dim for alignment
+        v = torch.nn.functional.pad(
+            v, [0, self.qk_head_dim - self.v_head_dim],
+            value=0).view(-1, self.num_local_heads * self.qk_head_dim)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         attn_output = attn_output.view(
-            -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
+            -1, self.num_local_heads,
+            self.qk_head_dim)[..., :self.v_head_dim].reshape(
                 -1, self.num_local_heads * self.v_head_dim)
         output, _ = self.o_proj(attn_output)
         return output
```
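The forward pass now pads only v up to qk_head_dim and slices the padding back off after attention; q and k already have qk_head_dim. A minimal round-trip sketch with dummy tensors (shapes are illustrative and no attention kernel is invoked):

```python
import torch

num_local_heads, qk_head_dim, v_head_dim = 16, 192, 128
num_tokens = 4

v = torch.randn(num_tokens, num_local_heads, v_head_dim)
# Pad the value head up to qk_head_dim so q, k and v share one head size.
v_padded = torch.nn.functional.pad(
    v, [0, qk_head_dim - v_head_dim],
    value=0).view(-1, num_local_heads * qk_head_dim)

# Pretend the attention kernel returned the padded values unchanged,
# then drop the padding exactly as the new code does.
attn_output = v_padded.view(
    -1, num_local_heads,
    qk_head_dim)[..., :v_head_dim].reshape(-1, num_local_heads * v_head_dim)

assert torch.equal(attn_output, v.reshape(num_tokens, -1))
```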
```diff
@@ -269,14 +269,8 @@ class DeepseekV3Attention(nn.Module):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
-        # self.attn = Attention(self.num_heads,
-        #                       self.qk_head_dim,
-        #                       self.scaling,
-        #                       num_kv_heads=self.num_heads)
-
-        # TODO, support head_size 192
         self.attn = Attention(self.num_local_heads,
-                              256,
+                              self.qk_head_dim,
                               self.scaling,
                               num_kv_heads=self.num_local_heads,
                               cache_config=cache_config,
```
```diff
@@ -326,18 +320,14 @@ class DeepseekV3Attention(nn.Module):
         k = torch.empty_like(q)
         k[..., :self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim:] = k_pe
-        q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
-        k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
-        v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
+        # padding value to qk_head_dim for alignment
+        v = torch.nn.functional.pad(
+            v, [0, self.qk_head_dim - self.v_head_dim],
+            value=0).view(-1, self.num_local_heads * self.qk_head_dim)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         attn_output = attn_output.view(
-            -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
+            -1, self.num_local_heads,
+            self.qk_head_dim)[..., :self.v_head_dim].reshape(
                 -1, self.num_local_heads * self.v_head_dim)
         output, _ = self.o_proj(attn_output)
         return output
```