[Bugfix] adding chunking mechanism to fused_moe to handle large inputs (#6029)
This commit is contained in:
parent dec6fc6f3b
commit 12a59959ed
tests/kernels/test_moe.py
@@ -29,7 +29,7 @@ def torch_moe(a, w1, w2, score, topk):
             topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


-@pytest.mark.parametrize("m", [512, 222, 33, 1])
+@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
 @pytest.mark.parametrize("n", [2048, 256, 1024])
 @pytest.mark.parametrize("k", [128, 511, 1024])
 @pytest.mark.parametrize("e", [8, 64])
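The new m = 1024 * 128 case runs 131072 tokens through the kernel, which is above the previous 65536-token ceiling and therefore exercises the chunked path. A minimal sketch of the chunk arithmetic for that case, assuming the default VLLM_FUSED_MOE_CHUNK_SIZE of 64 * 1024 (illustrative only, not part of the diff):

# How many loop iterations the largest test case produces under the default chunk size.
num_tokens = 1024 * 128                      # the new "m" value in the test
CHUNK_SIZE = 64 * 1024                       # default VLLM_FUSED_MOE_CHUNK_SIZE
num_iters = (num_tokens // CHUNK_SIZE) + 1   # loop bound used by fused_experts below
print(num_iters)                             # 3: two full chunks, then an empty one that breaks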
vllm/envs.py
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
     VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
     VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
     VLLM_XLA_CACHE_PATH: str = "~/.vllm/xla_cache/"
+    VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
     VLLM_USE_RAY_COMPILED_DAG: bool = False
     VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
     VLLM_IMAGE_FETCH_TIMEOUT: int = 5
@@ -248,6 +249,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # Only used for XLA devices such as TPUs.
     "VLLM_XLA_CACHE_PATH":
     lambda: os.getenv("VLLM_XLA_CACHE_PATH", "~/.vllm/xla_cache/"),
+    "VLLM_FUSED_MOE_CHUNK_SIZE":
+    lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")),
 }

 # end-env-vars-definition
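Registering the value in environment_variables means the chunk size can be tuned per deployment without code changes. A small standalone illustration of how the lookup behaves, mirroring the os.getenv pattern in the lambda above (the 32 * 1024 override is just an example value):

import os

# Simulate a user override set before vllm.envs is evaluated (illustrative value).
os.environ["VLLM_FUSED_MOE_CHUNK_SIZE"] = str(32 * 1024)

# Same fallback as the registered lambda: 65536 tokens when the variable is unset.
chunk_size = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536"))
print(chunk_size)  # 32768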
vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -8,6 +8,7 @@ import torch
 import triton
 import triton.language as tl

+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger

@@ -420,13 +421,12 @@ def fused_experts(hidden_states: torch.Tensor,
         torch.float32, torch.float16, torch.bfloat16
     ]

-    M, _ = hidden_states.shape
+    num_tokens, _ = hidden_states.shape
     E, N, _ = w1.shape
-
-    if M > 65536:
-        # https://github.com/vllm-project/vllm/issues/5938
-        raise ValueError("MoE kernel does not support more than 65536 tokens, "
-                         f"but got {M}")
+    # We execute the fused_moe kernel in chunks to circumvent this issue:
+    # https://github.com/vllm-project/vllm/issues/5938
+    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
+    M = min(num_tokens, CHUNK_SIZE)

     if override_config:
         config = override_config
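Replacing the hard error with M = min(num_tokens, CHUNK_SIZE) means the intermediate caches are sized per chunk rather than for the whole batch. A rough, hypothetical sizing estimate, assuming caches shaped (M, top_k, N) as intermediate_cache1 is allocated in fused_experts; the top_k and N values are made-up examples, not from the diff:

# Back-of-the-envelope cache size, fp16 elements (2 bytes each).
num_tokens, top_k, N = 1024 * 128, 2, 14336   # example values only
CHUNK_SIZE = 64 * 1024
bytes_unchunked = num_tokens * top_k * N * 2
bytes_chunked = min(num_tokens, CHUNK_SIZE) * top_k * N * 2
print(bytes_unchunked // 2**30, bytes_chunked // 2**30)  # roughly 7 vs 3 GiB (rounded down)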
@@ -455,51 +455,74 @@ def fused_experts(hidden_states: torch.Tensor,
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)

-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config['BLOCK_SIZE_M'], E)
     compute_type = (tl.bfloat16
                     if hidden_states.dtype == torch.bfloat16 else tl.float16)

-    invoke_fused_moe_kernel(hidden_states,
-                            w1,
-                            intermediate_cache1,
-                            a1_scale,
-                            w1_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            False,
-                            topk_ids.shape[1],
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
-
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
-
-    invoke_fused_moe_kernel(intermediate_cache2,
-                            w2,
-                            intermediate_cache3,
-                            a2_scale,
-                            w2_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            True,
-                            1,
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
-
     if inplace:
-        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                         dim=1,
-                         out=hidden_states)
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
+        out_hidden_states = hidden_states
+    else:
+        out_hidden_states = torch.empty_like(hidden_states)
+
+    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
+        begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
+                                          min((chunk + 1) * CHUNK_SIZE,
+                                              num_tokens))
+        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
+        tokens_in_chunk, _ = curr_hidden_states.shape
+
+        if tokens_in_chunk == 0:
+            break
+
+        if tokens_in_chunk < CHUNK_SIZE:
+            # will only happen in the last chunk
+            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
+            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
+            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
+
+        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
+        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
+
+        sorted_token_ids, expert_ids, num_tokens_post_padded = (
+            moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
+
+        invoke_fused_moe_kernel(curr_hidden_states,
+                                w1,
+                                intermediate_cache1,
+                                a1_scale,
+                                w1_scale,
+                                curr_topk_weights,
+                                curr_topk_ids,
+                                sorted_token_ids,
+                                expert_ids,
+                                num_tokens_post_padded,
+                                False,
+                                topk_ids.shape[1],
+                                config,
+                                compute_type=compute_type,
+                                use_fp8=use_fp8)
+
+        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+
+        invoke_fused_moe_kernel(intermediate_cache2,
+                                w2,
+                                intermediate_cache3,
+                                a2_scale,
+                                w2_scale,
+                                curr_topk_weights,
+                                curr_topk_ids,
+                                sorted_token_ids,
+                                expert_ids,
+                                num_tokens_post_padded,
+                                True,
+                                1,
+                                config,
+                                compute_type=compute_type,
+                                use_fp8=use_fp8)
+
+        torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+                  dim=1,
+                  out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
+    return out_hidden_states


 def fused_moe(
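The loop bound (num_tokens // CHUNK_SIZE) + 1 together with the tokens_in_chunk == 0 break covers every token index exactly once, including the case where num_tokens is an exact multiple of the chunk size. A self-contained sketch of just that index arithmetic, detached from the kernel calls:

# Illustrative check of the chunking bounds used above (pure Python, no vLLM needed).
def chunk_ranges(num_tokens: int, chunk_size: int):
    for chunk in range((num_tokens // chunk_size) + 1):
        begin, end = chunk * chunk_size, min((chunk + 1) * chunk_size, num_tokens)
        if end - begin == 0:  # mirrors the tokens_in_chunk == 0 early exit
            break
        yield begin, end

for n in (1, 33, 65536, 131072, 200000):
    spans = list(chunk_ranges(n, 64 * 1024))
    # Chunks start at 0, end at num_tokens, and are contiguous with no overlap.
    assert spans[0][0] == 0 and spans[-1][1] == n
    assert all(e1 == b2 for (_, e1), (b2, _) in zip(spans, spans[1:]))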