diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index a4a45c9c..41075068 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -468,7 +468,8 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size - elif config.architectures[0] == "DeepseekV3ForCausalLM": + elif (config.architectures[0] == "DeepseekV3ForCausalLM" + or config.architectures[0] == "DeepseekV2ForCausalLM"): E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size diff --git a/tests/distributed/test_expert_parallel.py b/tests/distributed/test_expert_parallel.py new file mode 100644 index 00000000..bc577064 --- /dev/null +++ b/tests/distributed/test_expert_parallel.py @@ -0,0 +1,227 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import List, Literal, NamedTuple, Optional + +import pytest + +from vllm.config import TaskOption +from vllm.logger import init_logger + +from ..utils import compare_two_settings, fork_new_process_for_each_test + +logger = init_logger("test_expert_parallel") + + +class ParallelSetup(NamedTuple): + tp_size: int + eager_mode: bool + chunked_prefill: bool + + +class EPTestOptions(NamedTuple): + trust_remote_code: bool + tokenizer_mode: Optional[str] + load_format: Optional[str] = None + hf_overrides: Optional[str] = None + + +@dataclass +class EPTestSettings: + parallel_setups: List[ParallelSetup] + distributed_backends: List[str] + task: TaskOption + test_options: EPTestOptions + + @staticmethod + def detailed( + *, + tp_base: int = 2, + task: TaskOption = "auto", + trust_remote_code: bool = False, + tokenizer_mode: Optional[str] = None, + load_format: Optional[str] = None, + hf_overrides: Optional[str] = None, + ): + return EPTestSettings( + parallel_setups=[ + ParallelSetup(tp_size=tp_base, + eager_mode=False, + chunked_prefill=False), + ParallelSetup(tp_size=tp_base, + eager_mode=False, + chunked_prefill=True), + ParallelSetup(tp_size=tp_base, + eager_mode=True, + chunked_prefill=False), + ParallelSetup(tp_size=2 * tp_base, + eager_mode=False, + chunked_prefill=True), + ParallelSetup(tp_size=2 * tp_base, + eager_mode=True, + chunked_prefill=False), + ], + distributed_backends=["mp", "ray"], + task=task, + test_options=EPTestOptions(trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + load_format=load_format, + hf_overrides=hf_overrides), + ) + + @staticmethod + def fast( + *, + tp_base: int = 2, + task: TaskOption = "auto", + trust_remote_code: bool = False, + tokenizer_mode: Optional[str] = None, + load_format: Optional[str] = None, + hf_overrides: Optional[str] = None, + ): + return EPTestSettings( + parallel_setups=[ + ParallelSetup(tp_size=tp_base, + eager_mode=True, + chunked_prefill=False), + ], + distributed_backends=["mp"], + task=task, + test_options=EPTestOptions(trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + load_format=load_format, + hf_overrides=hf_overrides), + ) + + def iter_params(self, model_name: str): + opts = self.test_options + + for parallel_setup in self.parallel_setups: + for distributed_backend in self.distributed_backends: + yield (model_name, parallel_setup, distributed_backend, + self.task, opts) + + +# NOTE: You can adjust tp_base locally to fit the model in GPU +# The values displayed here are only a rough indicator of the size of 
the model + +# yapf: disable +TEST_MODELS = { + "deepseek-ai/DeepSeek-V2-Lite-Chat": EPTestSettings.fast( + trust_remote_code=True), + "mistralai/Mixtral-8x7B-Instruct-v0.1": EPTestSettings.fast(tp_base=4), +} + + +def _compare_tp( + model_name: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + task: TaskOption, + test_options: EPTestOptions, + num_gpus_available: int, + *, + method: Literal["generate"], +): + ( + tp_size, + eager_mode, + chunked_prefill, + ) = parallel_setup + ( + trust_remote_code, + tokenizer_mode, + load_format, + hf_overrides, + ) = test_options + + if num_gpus_available < tp_size: + pytest.skip(f"Need at least {tp_size} GPUs") + + common_args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--max-model-len", + "2048", + "--max-num-seqs", + "8", + "--load-format", + "auto", + ] + if chunked_prefill: + common_args.append("--enable-chunked-prefill") + if eager_mode: + common_args.append("--enforce-eager") + if task != "auto": + common_args.extend(["--task", task]) + if trust_remote_code: + common_args.append("--trust-remote-code") + if tokenizer_mode: + common_args.extend(["--tokenizer-mode", tokenizer_mode]) + if load_format: + common_args.extend(["--load-format", load_format]) + if hf_overrides: + common_args.extend(["--hf-overrides", hf_overrides]) + + ep_env = { + "VLLM_TEST_ENABLE_EP": "1", + } + + ep_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + distributed_backend, + ] + + # compare without expert parallelism + tp_env = { + "VLLM_TEST_ENABLE_EP": "0", + } + + tp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + "mp", + ] + + try: + compare_two_settings(model_name, + ep_args, + tp_args, + ep_env, + tp_env, + method=method, + max_wait_seconds=360) + except Exception: + raise + + +@pytest.mark.parametrize( + ("model_name", "parallel_setup", "distributed_backend", "task", + "test_options"), + [ + params for model_name, settings in TEST_MODELS.items() + for params in settings.iter_params(model_name) + ], +) +@fork_new_process_for_each_test +def test_ep( + model_name: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + task: TaskOption, + test_options: EPTestOptions, + num_gpus_available, +): + _compare_tp(model_name, + parallel_setup, + distributed_backend, + task, + test_options, + num_gpus_available, + method="generate") diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py index 67595010..939b0e71 100644 --- a/tests/kernels/test_awq_marlin.py +++ b/tests/kernels/test_awq_marlin.py @@ -99,13 +99,8 @@ def test_fused_marlin_moe_awq( num_bits=num_bits, ) - torch_output = torch_moe( - a, - w_ref1.transpose(1, 2), - w_ref2.transpose(1, 2), - score, - topk, - ) + torch_output = torch_moe(a, w_ref1.transpose(1, 2), w_ref2.transpose(1, 2), + score, topk, None) assert compute_max_diff(marlin_output, torch_output) < 4e-2 diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 0f13fbc9..2f5c6904 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -26,6 +26,7 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types NUM_EXPERTS = [8, 64] +EP_SIZE = [1, 4] TOP_KS = [2, 6] @@ -34,6 +35,7 @@ TOP_KS = [2, 6] @pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("ep_size", EP_SIZE) 
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_fused_moe( m: int, @@ -41,6 +43,7 @@ def test_fused_moe( k: int, e: int, topk: int, + ep_size: int, dtype: torch.dtype, ): a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 @@ -48,10 +51,38 @@ def test_fused_moe( w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 score = torch.randn((m, e), device="cuda", dtype=dtype) - triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) - torch_output = torch_moe(a, w1, w2, score, topk) + + if ep_size > 1: + local_e = e // ep_size + e_ids = torch.randint(0, + e, (local_e, ), + device="cuda", + dtype=torch.int32) + e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) + w1 = w1[e_ids] + w2 = w2[e_ids] + else: + e_map = None + + triton_output = fused_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) + torch_output = torch_moe(a, w1, w2, score, topk, e_map) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) - iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False) + iterative_output = iterative_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) torch.testing.assert_close(iterative_output, torch_output, atol=2e-2, @@ -63,13 +94,14 @@ def test_fused_moe( @pytest.mark.parametrize("k", [128, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("group_size", [64, 128]) @pytest.mark.parametrize("has_zp", [True, False]) @pytest.mark.parametrize("weight_bits", [4, 8]) def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, - dtype: torch.dtype, group_size: int, has_zp: bool, - weight_bits: int): + ep_size: int, dtype: torch.dtype, group_size: int, + has_zp: bool, weight_bits: int): print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 @@ -130,6 +162,25 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, if has_zp: w_qzeros[expert_id] = qzeros + if ep_size > 1: + local_e = e // ep_size + e_ids = torch.randint(0, + e, (local_e, ), + device="cuda", + dtype=torch.int32) + e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) + w1_ref = w1_ref[e_ids] + w2_ref = w2_ref[e_ids] + w1_qweight = w1_qweight[e_ids] + w2_qweight = w2_qweight[e_ids] + w1_scales = w1_scales[e_ids] + w2_scales = w2_scales[e_ids] + w1_qzeros = w1_qzeros[e_ids] + w2_qzeros = w2_qzeros[e_ids] + else: + e_map = None + triton_output = fused_moe(a, w1_qweight, w2_qweight, @@ -138,12 +189,14 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, renormalize=False, use_int4_w4a16=weight_bits == 4, use_int8_w8a16=weight_bits == 8, + global_num_experts=e, + expert_map=e_map, w1_scale=w1_scales, w2_scale=w2_scales, w1_zp=w1_qzeros if has_zp else None, w2_zp=w2_qzeros if has_zp else None, block_shape=[0, group_size]) - torch_output = torch_moe(a, w1_ref, w2_ref, score, topk) + torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) diff --git a/tests/kernels/utils.py 
b/tests/kernels/utils.py index 5be111d7..1ee3a332 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1053,7 +1053,7 @@ def compute_max_diff(output, output_ref): torch.abs(output_ref)) -def torch_moe(a, w1, w2, score, topk): +def torch_moe(a, w1, w2, score, topk, expert_map): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) @@ -1061,6 +1061,8 @@ def torch_moe(a, w1, w2, score, topk): topk_weight, topk_ids = torch.topk(score, topk) topk_weight = topk_weight.view(-1) topk_ids = topk_ids.view(-1) + if expert_map is not None: + topk_ids = expert_map[topk_ids] for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): diff --git a/tests/utils.py b/tests/utils.py index f39cbe7e..2ad91ca2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -297,12 +297,12 @@ def _test_completion_close( logprobs=5, temperature=0.0) - logporbs = completion.choices[0].logprobs.top_logprobs[0] - logporbs = {k: round(v, 2) for k, v in logporbs.items()} + logprobs = completion.choices[0].logprobs.top_logprobs[0] + logprobs = {k: round(v, 2) for k, v in logprobs.items()} results.append({ "test": "completion_close", - "logprobs": logporbs, + "logprobs": logprobs, }) return results diff --git a/vllm/config.py b/vllm/config.py index 6bcf34c3..ace49a86 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -677,6 +677,23 @@ class ModelConfig: "fallback to the eager mode.") self.enforce_eager = True + def _verify_with_expert_parallelism(self) -> None: + num_expert_names = [ + "moe_num_experts", # Dbrx + "num_experts", # Jamba + "n_routed_experts", # DeepSeek + "num_local_experts", # Mixtral + ] + num_experts = 0 + for name in num_expert_names: + num_experts = getattr(self.hf_text_config, name, 0) + if num_experts > 0: + break + if num_experts < 1: + raise ValueError( + "Number of experts in the model must be greater than 0 " + "when expert parallelism is enabled.") + def verify_async_output_proc(self, parallel_config, speculative_config, device_config) -> None: if not self.use_async_output_proc: @@ -730,6 +747,9 @@ class ModelConfig: " must be divisible by tensor parallel size " f"({tensor_parallel_size}).") + if envs.VLLM_TEST_ENABLE_EP: + self._verify_with_expert_parallelism() + pipeline_parallel_size = parallel_config.pipeline_parallel_size if pipeline_parallel_size > 1: architectures = getattr(self.hf_config, "architectures", []) diff --git a/vllm/envs.py b/vllm/envs.py index dbf1d462..84426cb5 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -86,6 +86,7 @@ if TYPE_CHECKING: VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True VLLM_MLA_DISABLE_REQUANTIZATION: bool = False VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True + VLLM_TEST_ENABLE_EP: bool = False VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_BUNDLE_INDICES: str = "" @@ -570,6 +571,12 @@ environment_variables: Dict[str, Callable[[], Any]] = { lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0")) ), + # If set, vLLM will use the experimental expert parallel implementation on + # the FusedMoE layer, using tensor parallelism size as expert parallelism + # size. + "VLLM_TEST_ENABLE_EP": + lambda: bool(int(os.getenv("VLLM_TEST_ENABLE_EP", "0"))), + # Number of GPUs per worker in Ray, if it is set to be a fraction, # it allows ray to schedule multiple actors on a single GPU, # so that users can colocate other actors on the same GPUs as vLLM. 
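For context before the kernel changes below: when VLLM_TEST_ENABLE_EP=1, FusedMoE reuses the tensor-parallel world size as the expert-parallel size and builds an "expert map", a tensor of length global_num_experts that holds the local slot index for every expert owned by the current rank and -1 everywhere else. A minimal sketch of that convention, assuming a contiguous split with the remainder going to the last rank (the helper name make_expert_map is illustrative and not part of the patch):

import torch

def make_expert_map(global_num_experts: int, ep_size: int,
                    ep_rank: int) -> torch.Tensor:
    # -1 marks experts that live on other expert-parallel ranks.
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    base = global_num_experts // ep_size
    start = ep_rank * base
    # The last rank absorbs any remainder, mirroring FusedMoE.__init__ below.
    end = global_num_experts if ep_rank == ep_size - 1 else start + base
    expert_map[start:end] = torch.arange(end - start, dtype=torch.int32)
    return expert_map

# e.g. 8 experts over 4 ranks; rank 1 owns global experts 2 and 3:
# make_expert_map(8, 4, 1) -> tensor([-1, -1, 0, 1, -1, -1, -1, -1], dtype=torch.int32)

Routed token ids are translated through this map (topk_ids = expert_map[topk_ids]); pairs that land on -1 match no local expert and produce zeros, which the all-reduce later combines across ranks.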
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 543c8ced..4cab72a2 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -20,6 +20,18 @@ from vllm.utils import direct_register_custom_op logger = init_logger(__name__) +@triton.jit +def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, + token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, + compute_type): + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel_gptq_awq( # Pointers to matrices @@ -120,17 +132,26 @@ def fused_moe_kernel_gptq_awq( offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + if off_experts == -1: + # ----------------------------------------------------------- + # Write back zeros to the output when the expert is not + # in the current expert parallel rank. + write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, + offs_token, token_mask, BLOCK_SIZE_M, + BLOCK_SIZE_N, compute_type) + return + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) - off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) - if use_int4_w4a16: b_ptrs = b_ptr + off_experts * stride_be + \ - (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn + (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * \ + stride_bn b_shifter = (offs_k[:, None] % 2) * 4 elif use_int8_w8a16: b_ptrs = b_ptr + off_experts * stride_be + \ @@ -170,7 +191,8 @@ def fused_moe_kernel_gptq_awq( b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \ offs_bn[None, :] * stride_bsn + \ - ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * \ + stride_bsk b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) b_scale = b_scale.to(tl.float32) @@ -319,13 +341,22 @@ def fused_moe_kernel( offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + if off_experts == -1: + # ----------------------------------------------------------- + # Write back zeros to the output when the expert is not + # in the current expert parallel rank. + write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, + offs_token, token_mask, BLOCK_SIZE_M, + BLOCK_SIZE_N, compute_type) + return + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) - off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) if use_int8_w8a16: @@ -349,7 +380,6 @@ def fused_moe_kernel( # of fp32 values for higher accuracy. # `accumulator` will be converted back to fp16 after the loop. 
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. @@ -544,8 +574,11 @@ def moe_align_block_size_triton( def moe_align_block_size( - topk_ids: torch.Tensor, block_size: int, - num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + topk_ids: torch.Tensor, + block_size: int, + num_experts: int, + expert_map: torch.Tensor = None +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block size for matrix multiplication. @@ -555,6 +588,10 @@ def moe_align_block_size( top-k expert indices for each token. - block_size: The block size used in block matrix multiplication. - num_experts: The total number of experts. + - expert_map: A tensor of shape [num_experts] that maps the expert index + from the global space to the local index space of the current + expert parallel shard. If the expert is not in the current expert + parallel shard, the mapping is set to -1. Returns: - sorted_token_ids: A tensor containing the sorted token indices according @@ -589,7 +626,9 @@ def moe_align_block_size( device=topk_ids.device) sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - expert_ids = torch.empty((max_num_m_blocks, ), + # Expert ids must be zeroed out to prevent index out of bounds error while + # mapping global expert ids to local expert ids in expert parallelism. + expert_ids = torch.zeros((max_num_m_blocks, ), dtype=torch.int32, device=topk_ids.device) num_tokens_post_pad = torch.empty((1), @@ -618,6 +657,9 @@ def moe_align_block_size( else: ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad) + if expert_map is not None: + expert_ids = expert_map[expert_ids] + return sorted_ids, expert_ids, num_tokens_post_pad @@ -1001,6 +1043,8 @@ def inplace_fused_experts(hidden_states: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1009,8 +1053,9 @@ def inplace_fused_experts(hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, - use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale, - w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) + use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, + global_num_experts, expert_map, w1_scale, w2_scale, + w1_zp, w2_zp, a1_scale, a2_scale, block_shape) def inplace_fused_experts_fake( @@ -1022,6 +1067,8 @@ def inplace_fused_experts_fake( use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1049,6 +1096,8 @@ def outplace_fused_experts( use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1058,8 
+1107,9 @@ def outplace_fused_experts( block_shape: Optional[List[int]] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, use_fp8_w8a8, use_int8_w8a16, - use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp, - a1_scale, a2_scale, block_shape) + use_int4_w4a16, global_num_experts, expert_map, + w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, + a2_scale, block_shape) def outplace_fused_experts_fake( @@ -1071,6 +1121,8 @@ def outplace_fused_experts_fake( use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1098,26 +1150,27 @@ def fused_experts(hidden_states: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None): + block_shape: Optional[List[int]] = None) -> torch.Tensor: + if inplace: - torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2, - topk_weights, topk_ids, - use_fp8_w8a8, use_int8_w8a16, - use_int4_w4a16, w1_scale, - w2_scale, w1_zp, w2_zp, a1_scale, - a2_scale, block_shape) + torch.ops.vllm.inplace_fused_experts( + hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, + use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map, + w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) return hidden_states else: return torch.ops.vllm.outplace_fused_experts( hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, - use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp, - a1_scale, a2_scale, block_shape) + use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map, + w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) def fused_experts_impl(hidden_states: torch.Tensor, @@ -1129,6 +1182,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1153,6 +1208,9 @@ def fused_experts_impl(hidden_states: torch.Tensor, num_tokens, _ = hidden_states.shape E, N, _ = w1.shape + if global_num_experts == -1: + global_num_experts = E + top_k_num = topk_ids.shape[1] # We execute the fused_moe kernel in chunks to circumvent this issue: # https://github.com/vllm-project/vllm/issues/5938 CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE @@ -1166,20 +1224,20 @@ def fused_experts_impl(hidden_states: torch.Tensor, try_get_optimal_moe_config, w1.shape, w2.shape, - topk_ids.shape[1], + top_k_num, config_dtype, block_shape=block_shape, ) config = get_config_func(M) - intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), + intermediate_cache1 = torch.empty((M, top_k_num, N), device=hidden_states.device, dtype=hidden_states.dtype) - intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), + intermediate_cache2 = torch.empty((M * top_k_num, N // 2), device=hidden_states.device, 
dtype=hidden_states.dtype) - intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), + intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype) @@ -1221,7 +1279,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E)) + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map)) invoke_fused_moe_kernel(curr_hidden_states, w1, @@ -1235,7 +1294,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, expert_ids, num_tokens_post_padded, False, - topk_ids.shape[1], + top_k_num, config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, @@ -1286,6 +1345,8 @@ def fused_moe( use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1320,6 +1381,11 @@ def fused_moe( - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 activation to compute the inner products for w1 and w2. Defaults to False. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for @@ -1334,8 +1400,6 @@ def fused_moe( Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ - # Check constraints. 
- assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" if use_grouped_topk: assert num_expert_group is not None and topk_group is not None @@ -1358,6 +1422,8 @@ def fused_moe( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + global_num_experts=global_num_experts, + expert_map=expert_map, w1_scale=w1_scale, w2_scale=w2_scale, w1_zp=w1_zp, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f18c0313..49400b69 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Tuple import torch +import vllm.envs as envs from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -55,6 +56,8 @@ class FusedMoEMethodBase(QuantizeMethodBase): use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None @@ -113,6 +116,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None @@ -125,6 +130,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): use_grouped_topk=use_grouped_topk, topk_group=topk_group, num_expert_group=num_expert_group, + global_num_experts=global_num_experts, + expert_map=expert_map, custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) @@ -139,6 +146,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None @@ -160,7 +169,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=True) + inplace=True, + global_num_experts=global_num_experts, + expert_map=expert_map) def forward_cpu( self, @@ -172,6 +183,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, **kwargs, ): @@ -196,6 +209,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None @@ -215,6 +230,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): w2=layer.w2_weight, topk=top_k, 
gating_output=router_logits, + global_num_experts=global_num_experts, + expert_map=expert_map, renormalize=renormalize) forward_native = forward_cuda @@ -255,6 +272,7 @@ class FusedMoE(torch.nn.Module): topk_group: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, + ep_size: Optional[int] = None, prefix: str = "", custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", @@ -267,8 +285,13 @@ class FusedMoE(torch.nn.Module): self.tp_size = (tp_size if tp_size is not None else get_tensor_model_parallel_world_size()) + if envs.VLLM_TEST_ENABLE_EP: + self.ep_size = self.tp_size + self.tp_size = 1 + else: + self.ep_size = 1 self.top_k = top_k - self.num_experts = num_experts + self.num_experts = num_experts # Global number of experts assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results @@ -281,6 +304,26 @@ class FusedMoE(torch.nn.Module): self.custom_routing_function = custom_routing_function self.scoring_func = scoring_func self.e_score_correction_bias = e_score_correction_bias + self.expert_map = None + + if self.ep_size > 1: + # Create a tensor of size num_experts filled with -1 + self.expert_map = torch.full((self.num_experts, ), + -1, + dtype=torch.int32) + # Create a expert map for the local experts + local_num_experts = num_experts // self.ep_size + ep_rank = get_tensor_model_parallel_rank() + if ep_rank < (self.ep_size - 1): + # Each non-last rank gets local_num_experts experts. + self.expert_map[ep_rank * local_num_experts: + (ep_rank + 1) * local_num_experts] = \ + torch.arange(0, local_num_experts, dtype=torch.int32) + else: + # All remaining experts are assigned to the last rank. 
+ local_num_experts = num_experts - ep_rank * local_num_experts + self.expert_map[-local_num_experts:] = \ + torch.arange(0, local_num_experts, dtype=torch.int32) if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " @@ -293,8 +336,11 @@ class FusedMoE(torch.nn.Module): self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None + local_num_experts = torch.sum(self.expert_map != -1) \ + if self.expert_map is not None else num_experts + moe_quant_params = { - "num_experts": num_experts, + "num_experts": local_num_experts, "hidden_size": hidden_size, "intermediate_size_per_partition": self.intermediate_size_per_partition, @@ -423,10 +469,22 @@ class FusedMoE(torch.nn.Module): assert shard_id in ("w1", "w3") expert_data.copy_(loaded_weight) + def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: + if self.expert_map is None: + return expert_id + return self.expert_map[expert_id].item() + def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: + expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) + if expert_id == -1: + return + + # TP rank is set to 0 if EP is enabled + tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank() + # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality @@ -447,7 +505,6 @@ class FusedMoE(torch.nn.Module): SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} expert_data = param.data[expert_id] - tp_rank = get_tensor_model_parallel_rank() # is_transposed: if the dim to shard the weight # should be flipped. Required by GPTQ, compressed-tensors @@ -590,13 +647,16 @@ class FusedMoE(torch.nn.Module): top_k=self.top_k, renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.num_experts, + expert_map=self.expert_map, topk_group=self.topk_group, num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias) - if self.reduce_results and self.tp_size > 1: + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + # Default set to False. (May have to add shared expert outputs.) 
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) diff --git a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py index d9a5de1b..da27633f 100644 --- a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py +++ b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py @@ -10,7 +10,9 @@ def fused_moe( w2: torch.Tensor, gating_output: torch.Tensor, topk: int, - renormalize: bool, + global_num_experts: int, + expert_map: torch.Tensor = None, + renormalize: bool = False, ) -> torch.Tensor: """ Args: @@ -18,6 +20,7 @@ def fused_moe( w1: [num_experts, intermediate_size * 2, hidden_size] w2: [num_experts, hidden_size, intermediate_size] gating_output: [*, num_experts] + expert_map: [num_experts] """ orig_shape = hidden_states.shape hidden_size = hidden_states.shape[-1] @@ -27,13 +30,16 @@ def fused_moe( dtype = hidden_states.dtype hidden_states = hidden_states.view(num_tokens, hidden_size) - gating_output = gating_output.view(num_tokens, num_experts) + gating_output = gating_output.view(num_tokens, global_num_experts) topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) topk_weights, selected_experts = topk_weights.topk(topk, dim=-1) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights.to(dtype) + if expert_map is not None: + selected_experts = expert_map[selected_experts] + final_hidden_states = None for expert_idx in range(num_experts): expert_w1 = w1[expert_idx] diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 111b3f74..0e8c4c7b 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -464,10 +464,17 @@ class AWQMoEMethod(FusedMoEMethodBase): use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if expert_map is not None: + raise NotImplementedError( + "Expert Parallelism is not supported for " + "fused Marlin MoE method.") + topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index db8e8a4b..389359a6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -214,6 +214,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, @@ -239,6 +241,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): topk_ids=topk_ids, inplace=True, use_fp8_w8a8=True, + global_num_experts=global_num_experts, + expert_map=expert_map, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, 
a1_scale=layer.w13_input_scale, @@ -540,10 +544,16 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if expert_map is not None: + raise NotImplementedError( + "Expert Parallelism is not supported for " + "fused Marlin MoE method.") topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 663fb8bf..0767926e 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -108,6 +108,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, @@ -133,6 +135,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): topk_ids=topk_ids, inplace=True, use_int8_w8a16=True, + global_num_experts=global_num_experts, + expert_map=expert_map, w1_scale=layer.w13_scale, w2_scale=layer.w2_scale) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 1ca39b0f..9f4cd2aa 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -670,6 +670,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, @@ -697,6 +699,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): topk_ids=topk_ids, inplace=True, use_fp8_w8a8=True, + global_num_experts=global_num_experts, + expert_map=expert_map, w1_scale=(layer.w13_weight_scale_inv if self.block_quant else layer.w13_weight_scale), w2_scale=(layer.w2_weight_scale_inv diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 9f960d9f..241fc7d7 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -585,6 +585,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index da06ca3f..a3adac1b 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -288,6 +288,8 @@ class MoeWNA16Method(FusedMoEMethodBase): use_grouped_topk: bool = False, topk_group: 
Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, @@ -317,6 +319,8 @@ class MoeWNA16Method(FusedMoEMethodBase): inplace=True, use_int4_w4a16=weight_bits == 4, use_int8_w8a16=weight_bits == 8, + global_num_experts=global_num_experts, + expert_map=expert_map, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, w1_zp=layer.w13_qzeros if has_zp else None, diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 98743b15..36b08589 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -198,6 +198,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, @@ -223,6 +225,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): topk_ids=topk_ids, inplace=True, use_fp8_w8a8=True, + global_num_experts=global_num_experts, + expert_map=expert_map, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index a4d52c61..9bf3ec2f 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -106,10 +106,6 @@ class DeepseekV2MoE(nn.Module): self.routed_scaling_factor = config.routed_scaling_factor self.n_shared_experts = config.n_shared_experts self.routed_scaling_factor = config.routed_scaling_factor - if self.tp_size > config.n_routed_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.n_routed_experts}.") if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. "
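Taken together, the changes above imply a simple invariant: because the kernel writes zeros for token/expert pairs routed off-rank, the partial outputs of all expert-parallel shards sum to the unsharded result, which is what the all-reduce in FusedMoE.forward restores. A rough single-GPU sketch of that check (shapes, dtypes, and the simulated contiguous expert split are assumptions, not taken from the patch):

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

e, topk, ep_size, m, n, k = 8, 2, 4, 16, 128, 128
a = torch.randn((m, k), device="cuda", dtype=torch.float16) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.float16) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.float16) / 10
score = torch.randn((m, e), device="cuda", dtype=torch.float16)

full = fused_moe(a, w1, w2, score, topk, renormalize=False)

# Simulate each EP rank in-process: slice out its experts, build its expert
# map, and accumulate the partial outputs (FusedMoE.forward does this with a
# tensor_model_parallel_all_reduce across real ranks).
partial_sum = torch.zeros_like(full)
per_rank = e // ep_size
for rank in range(ep_size):
    local_ids = torch.arange(rank * per_rank, (rank + 1) * per_rank,
                             device="cuda")
    e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
    e_map[local_ids] = torch.arange(per_rank, device="cuda", dtype=torch.int32)
    partial_sum += fused_moe(a, w1[local_ids], w2[local_ids], score, topk,
                             renormalize=False, global_num_experts=e,
                             expert_map=e_map)

torch.testing.assert_close(partial_sum, full, atol=2e-2, rtol=0)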