Expert Parallelism (EP) Support for DeepSeek V2 (#12583)
This commit is contained in: parent 7940d8a6a7, commit 781096e385
@@ -468,7 +468,8 @@ def main(args: argparse.Namespace):
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    elif config.architectures[0] == "DeepseekV3ForCausalLM":
    elif (config.architectures[0] == "DeepseekV3ForCausalLM"
          or config.architectures[0] == "DeepseekV2ForCausalLM"):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
tests/distributed/test_expert_parallel.py (new file, 227 lines)
@@ -0,0 +1,227 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import List, Literal, NamedTuple, Optional

import pytest

from vllm.config import TaskOption
from vllm.logger import init_logger

from ..utils import compare_two_settings, fork_new_process_for_each_test

logger = init_logger("test_expert_parallel")


class ParallelSetup(NamedTuple):
    tp_size: int
    eager_mode: bool
    chunked_prefill: bool


class EPTestOptions(NamedTuple):
    trust_remote_code: bool
    tokenizer_mode: Optional[str]
    load_format: Optional[str] = None
    hf_overrides: Optional[str] = None


@dataclass
class EPTestSettings:
    parallel_setups: List[ParallelSetup]
    distributed_backends: List[str]
    task: TaskOption
    test_options: EPTestOptions

    @staticmethod
    def detailed(
        *,
        tp_base: int = 2,
        task: TaskOption = "auto",
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
        load_format: Optional[str] = None,
        hf_overrides: Optional[str] = None,
    ):
        return EPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              eager_mode=False,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=tp_base,
                              eager_mode=True,
                              chunked_prefill=False),
                ParallelSetup(tp_size=2 * tp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=2 * tp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp", "ray"],
            task=task,
            test_options=EPTestOptions(trust_remote_code=trust_remote_code,
                                       tokenizer_mode=tokenizer_mode,
                                       load_format=load_format,
                                       hf_overrides=hf_overrides),
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 2,
        task: TaskOption = "auto",
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
        load_format: Optional[str] = None,
        hf_overrides: Optional[str] = None,
    ):
        return EPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp"],
            task=task,
            test_options=EPTestOptions(trust_remote_code=trust_remote_code,
                                       tokenizer_mode=tokenizer_mode,
                                       load_format=load_format,
                                       hf_overrides=hf_overrides),
        )

    def iter_params(self, model_name: str):
        opts = self.test_options

        for parallel_setup in self.parallel_setups:
            for distributed_backend in self.distributed_backends:
                yield (model_name, parallel_setup, distributed_backend,
                       self.task, opts)


# NOTE: You can adjust tp_base locally to fit the model in GPU
# The values displayed here are only a rough indicator of the size of the model

# yapf: disable
TEST_MODELS = {
    "deepseek-ai/DeepSeek-V2-Lite-Chat": EPTestSettings.fast(
        trust_remote_code=True),
    "mistralai/Mixtral-8x7B-Instruct-v0.1": EPTestSettings.fast(tp_base=4),
}


def _compare_tp(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    test_options: EPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate"],
):
    (
        tp_size,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup
    (
        trust_remote_code,
        tokenizer_mode,
        load_format,
        hf_overrides,
    ) = test_options

    if num_gpus_available < tp_size:
        pytest.skip(f"Need at least {tp_size} GPUs")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
        "--load-format",
        "auto",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if task != "auto":
        common_args.extend(["--task", task])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", hf_overrides])

    ep_env = {
        "VLLM_TEST_ENABLE_EP": "1",
    }

    ep_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
    ]

    # compare without expert parallelism
    tp_env = {
        "VLLM_TEST_ENABLE_EP": "0",
    }

    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    try:
        compare_two_settings(model_name,
                             ep_args,
                             tp_args,
                             ep_env,
                             tp_env,
                             method=method,
                             max_wait_seconds=360)
    except Exception:
        raise


@pytest.mark.parametrize(
    ("model_name", "parallel_setup", "distributed_backend", "task",
     "test_options"),
    [
        params for model_name, settings in TEST_MODELS.items()
        for params in settings.iter_params(model_name)
    ],
)
@fork_new_process_for_each_test
def test_ep(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    test_options: EPTestOptions,
    num_gpus_available,
):
    _compare_tp(model_name,
                parallel_setup,
                distributed_backend,
                task,
                test_options,
                num_gpus_available,
                method="generate")
@@ -99,13 +99,8 @@ def test_fused_marlin_moe_awq(
        num_bits=num_bits,
    )

    torch_output = torch_moe(
        a,
        w_ref1.transpose(1, 2),
        w_ref2.transpose(1, 2),
        score,
        topk,
    )
    torch_output = torch_moe(a, w_ref1.transpose(1, 2), w_ref2.transpose(1, 2),
                             score, topk, None)

    assert compute_max_diff(marlin_output, torch_output) < 4e-2
@@ -26,6 +26,7 @@ from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]
@@ -34,6 +35,7 @@ TOP_KS = [2, 6]
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe(
    m: int,
@@ -41,6 +43,7 @@ def test_fused_moe(
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
):
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
@@ -48,10 +51,38 @@ def test_fused_moe(
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
    torch_output = torch_moe(a, w1, w2, score, topk)

    if ep_size > 1:
        local_e = e // ep_size
        e_ids = torch.randint(0,
                              e, (local_e, ),
                              device="cuda",
                              dtype=torch.int32)
        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    triton_output = fused_moe(a,
                              w1,
                              w2,
                              score,
                              topk,
                              global_num_experts=e,
                              expert_map=e_map,
                              renormalize=False)
    torch_output = torch_moe(a, w1, w2, score, topk, e_map)
    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
    iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
    iterative_output = iterative_moe(a,
                                     w1,
                                     w2,
                                     score,
                                     topk,
                                     global_num_experts=e,
                                     expert_map=e_map,
                                     renormalize=False)
    torch.testing.assert_close(iterative_output,
                               torch_output,
                               atol=2e-2,
@@ -63,13 +94,14 @@ def test_fused_moe(
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                        dtype: torch.dtype, group_size: int, has_zp: bool,
                        weight_bits: int):
                        ep_size: int, dtype: torch.dtype, group_size: int,
                        has_zp: bool, weight_bits: int):
    print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -130,6 +162,25 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
        if has_zp:
            w_qzeros[expert_id] = qzeros

    if ep_size > 1:
        local_e = e // ep_size
        e_ids = torch.randint(0,
                              e, (local_e, ),
                              device="cuda",
                              dtype=torch.int32)
        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1_ref = w1_ref[e_ids]
        w2_ref = w2_ref[e_ids]
        w1_qweight = w1_qweight[e_ids]
        w2_qweight = w2_qweight[e_ids]
        w1_scales = w1_scales[e_ids]
        w2_scales = w2_scales[e_ids]
        w1_qzeros = w1_qzeros[e_ids]
        w2_qzeros = w2_qzeros[e_ids]
    else:
        e_map = None

    triton_output = fused_moe(a,
                              w1_qweight,
                              w2_qweight,
@@ -138,12 +189,14 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                              renormalize=False,
                              use_int4_w4a16=weight_bits == 4,
                              use_int8_w8a16=weight_bits == 8,
                              global_num_experts=e,
                              expert_map=e_map,
                              w1_scale=w1_scales,
                              w2_scale=w2_scales,
                              w1_zp=w1_qzeros if has_zp else None,
                              w2_zp=w2_qzeros if has_zp else None,
                              block_shape=[0, group_size])
    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@@ -1053,7 +1053,7 @@ def compute_max_diff(output, output_ref):
                      torch.abs(output_ref))


def torch_moe(a, w1, w2, score, topk):
def torch_moe(a, w1, w2, score, topk, expert_map):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
@@ -1061,6 +1061,8 @@ def torch_moe(a, w1, w2, score, topk):
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
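The change to the reference `torch_moe` is deliberately small: routed expert ids are pushed through `expert_map`, so experts owned by another EP rank become -1 and never match a local expert index in the loop, and their contribution is simply dropped. A small illustration of that remapping with made-up values (not taken from the diff):

```python
import torch

expert_map = torch.tensor([0, 1, -1, -1])  # this rank owns global experts 0 and 1
topk_ids = torch.tensor([2, 0, 3, 1])      # global expert ids chosen by top-k

local_ids = expert_map[topk_ids]
print(local_ids)  # tensor([-1,  0, -1,  1]) -> tokens routed to experts 2 and 3 are skipped here
```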
@@ -297,12 +297,12 @@ def _test_completion_close(
        logprobs=5,
        temperature=0.0)

    logporbs = completion.choices[0].logprobs.top_logprobs[0]
    logporbs = {k: round(v, 2) for k, v in logporbs.items()}
    logprobs = completion.choices[0].logprobs.top_logprobs[0]
    logprobs = {k: round(v, 2) for k, v in logprobs.items()}

    results.append({
        "test": "completion_close",
        "logprobs": logporbs,
        "logprobs": logprobs,
    })

    return results
@@ -677,6 +677,23 @@ class ModelConfig:
                "fallback to the eager mode.")
            self.enforce_eager = True

    def _verify_with_expert_parallelism(self) -> None:
        num_expert_names = [
            "moe_num_experts",  # Dbrx
            "num_experts",  # Jamba
            "n_routed_experts",  # DeepSeek
            "num_local_experts",  # Mixtral
        ]
        num_experts = 0
        for name in num_expert_names:
            num_experts = getattr(self.hf_text_config, name, 0)
            if num_experts > 0:
                break
        if num_experts < 1:
            raise ValueError(
                "Number of experts in the model must be greater than 0 "
                "when expert parallelism is enabled.")

    def verify_async_output_proc(self, parallel_config, speculative_config,
                                 device_config) -> None:
        if not self.use_async_output_proc:
@@ -730,6 +747,9 @@ class ModelConfig:
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        if envs.VLLM_TEST_ENABLE_EP:
            self._verify_with_expert_parallelism()

        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if pipeline_parallel_size > 1:
            architectures = getattr(self.hf_config, "architectures", [])
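`_verify_with_expert_parallelism` only checks that the HF text config exposes one of the known expert-count attributes before EP is allowed. A standalone sketch of that lookup (the config object below is a stand-in for illustration, not vLLM's):

```python
from types import SimpleNamespace

num_expert_names = [
    "moe_num_experts", "num_experts", "n_routed_experts", "num_local_experts"
]
cfg = SimpleNamespace(n_routed_experts=64)  # DeepSeek-style text config

num_experts = 0
for name in num_expert_names:
    num_experts = getattr(cfg, name, 0)
    if num_experts > 0:
        break
assert num_experts == 64  # a dense config would fail this check and raise in vLLM
```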
@@ -86,6 +86,7 @@ if TYPE_CHECKING:
    VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
    VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
    VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
    VLLM_TEST_ENABLE_EP: bool = False
    VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
    VLLM_RAY_PER_WORKER_GPUS: float = 1.0
    VLLM_RAY_BUNDLE_INDICES: str = ""
@@ -570,6 +571,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
                 ),

    # If set, vLLM will use the experimental expert parallel implementation on
    # the FusedMoE layer, using tensor parallelism size as expert parallelism
    # size.
    "VLLM_TEST_ENABLE_EP":
    lambda: bool(int(os.getenv("VLLM_TEST_ENABLE_EP", "0"))),

    # Number of GPUs per worker in Ray, if it is set to be a fraction,
    # it allows ray to schedule multiple actors on a single GPU,
    # so that users can colocate other actors on the same GPUs as vLLM.
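With this change, expert parallelism is opted into purely through the environment variable; the tensor-parallel world size is reused as the expert-parallel size. A minimal usage sketch, assuming a checkout with this PR and at least 2 GPUs (the model name is just the one exercised by the new test):

```python
import os

os.environ["VLLM_TEST_ENABLE_EP"] = "1"  # route FusedMoE through the EP path

from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat",
          tensor_parallel_size=2,  # reused as the expert-parallel size
          trust_remote_code=True,
          enforce_eager=True)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```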
@@ -20,6 +20,18 @@ from vllm.utils import direct_register_custom_op
logger = init_logger(__name__)


@triton.jit
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
                          token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N,
                          compute_type):
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)


@triton.jit
def fused_moe_kernel_gptq_awq(
        # Pointers to matrices
@@ -120,17 +132,26 @@ def fused_moe_kernel_gptq_awq(
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
                              offs_token, token_mask, BLOCK_SIZE_M,
                              BLOCK_SIZE_N, compute_type)
        return

    offs_bn = (pid_n * BLOCK_SIZE_N +
               tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)

    if use_int4_w4a16:
        b_ptrs = b_ptr + off_experts * stride_be + \
            (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
            (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * \
            stride_bn
        b_shifter = (offs_k[:, None] % 2) * 4
    elif use_int8_w8a16:
        b_ptrs = b_ptr + off_experts * stride_be + \
@@ -170,7 +191,8 @@ def fused_moe_kernel_gptq_awq(
        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
            offs_bn[None, :] * stride_bsn + \
            ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
            ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * \
            stride_bsk
        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
        b_scale = b_scale.to(tl.float32)

@@ -319,13 +341,22 @@ def fused_moe_kernel(
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
                              offs_token, token_mask, BLOCK_SIZE_M,
                              BLOCK_SIZE_N, compute_type)
        return

    offs_bn = (pid_n * BLOCK_SIZE_N +
               tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                offs_bn[None, :] * stride_bn)
    if use_int8_w8a16:
@@ -349,7 +380,6 @@ def fused_moe_kernel(
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
@@ -544,8 +574,11 @@ def moe_align_block_size_triton(


def moe_align_block_size(
        topk_ids: torch.Tensor, block_size: int,
        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        topk_ids: torch.Tensor,
        block_size: int,
        num_experts: int,
        expert_map: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block
    size for matrix multiplication.
@@ -555,6 +588,10 @@ def moe_align_block_size(
        top-k expert indices for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The total number of experts.
    - expert_map: A tensor of shape [num_experts] that maps the expert index
        from the global space to the local index space of the current
        expert parallel shard. If the expert is not in the current expert
        parallel shard, the mapping is set to -1.

    Returns:
    - sorted_token_ids: A tensor containing the sorted token indices according
@@ -589,7 +626,9 @@ def moe_align_block_size(
                              device=topk_ids.device)
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty((max_num_m_blocks, ),
    # Expert ids must be zeroed out to prevent index out of bounds error while
    # mapping global expert ids to local expert ids in expert parallelism.
    expert_ids = torch.zeros((max_num_m_blocks, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    num_tokens_post_pad = torch.empty((1),
@@ -618,6 +657,9 @@ def moe_align_block_size(
    else:
        ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                                 expert_ids, num_tokens_post_pad)
    if expert_map is not None:
        expert_ids = expert_map[expert_ids]

    return sorted_ids, expert_ids, num_tokens_post_pad


@@ -1001,6 +1043,8 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                          use_fp8_w8a8: bool = False,
                          use_int8_w8a16: bool = False,
                          use_int4_w4a16: bool = False,
                          global_num_experts: int = -1,
                          expert_map: Optional[torch.Tensor] = None,
                          w1_scale: Optional[torch.Tensor] = None,
                          w2_scale: Optional[torch.Tensor] = None,
                          w1_zp: Optional[torch.Tensor] = None,
@@ -1009,8 +1053,9 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                          a2_scale: Optional[torch.Tensor] = None,
                          block_shape: Optional[List[int]] = None) -> None:
    fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
                       use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale,
                       w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
                       use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16,
                       global_num_experts, expert_map, w1_scale, w2_scale,
                       w1_zp, w2_zp, a1_scale, a2_scale, block_shape)


def inplace_fused_experts_fake(
@@ -1022,6 +1067,8 @@ def inplace_fused_experts_fake(
        use_fp8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        w1_zp: Optional[torch.Tensor] = None,
@@ -1049,6 +1096,8 @@ def outplace_fused_experts(
        use_fp8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        w1_zp: Optional[torch.Tensor] = None,
@@ -1058,8 +1107,9 @@ def outplace_fused_experts(
        block_shape: Optional[List[int]] = None) -> torch.Tensor:
    return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
                              False, use_fp8_w8a8, use_int8_w8a16,
                              use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
                              a1_scale, a2_scale, block_shape)
                              use_int4_w4a16, global_num_experts, expert_map,
                              w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
                              a2_scale, block_shape)


def outplace_fused_experts_fake(
@@ -1071,6 +1121,8 @@ def outplace_fused_experts_fake(
        use_fp8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        w1_zp: Optional[torch.Tensor] = None,
@@ -1098,26 +1150,27 @@ def fused_experts(hidden_states: torch.Tensor,
                  use_fp8_w8a8: bool = False,
                  use_int8_w8a16: bool = False,
                  use_int4_w4a16: bool = False,
                  global_num_experts: int = -1,
                  expert_map: Optional[torch.Tensor] = None,
                  w1_scale: Optional[torch.Tensor] = None,
                  w2_scale: Optional[torch.Tensor] = None,
                  w1_zp: Optional[torch.Tensor] = None,
                  w2_zp: Optional[torch.Tensor] = None,
                  a1_scale: Optional[torch.Tensor] = None,
                  a2_scale: Optional[torch.Tensor] = None,
                  block_shape: Optional[List[int]] = None):
                  block_shape: Optional[List[int]] = None) -> torch.Tensor:

    if inplace:
        torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
                                             topk_weights, topk_ids,
                                             use_fp8_w8a8, use_int8_w8a16,
                                             use_int4_w4a16, w1_scale,
                                             w2_scale, w1_zp, w2_zp, a1_scale,
                                             a2_scale, block_shape)
        torch.ops.vllm.inplace_fused_experts(
            hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
            use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map,
            w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
        return hidden_states
    else:
        return torch.ops.vllm.outplace_fused_experts(
            hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
            use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
            a1_scale, a2_scale, block_shape)
            use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map,
            w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)


def fused_experts_impl(hidden_states: torch.Tensor,
@@ -1129,6 +1182,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                       use_fp8_w8a8: bool = False,
                       use_int8_w8a16: bool = False,
                       use_int4_w4a16: bool = False,
                       global_num_experts: int = -1,
                       expert_map: Optional[torch.Tensor] = None,
                       w1_scale: Optional[torch.Tensor] = None,
                       w2_scale: Optional[torch.Tensor] = None,
                       w1_zp: Optional[torch.Tensor] = None,
@@ -1153,6 +1208,9 @@ def fused_experts_impl(hidden_states: torch.Tensor,

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.shape[1]
    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
@@ -1166,20 +1224,20 @@ def fused_experts_impl(hidden_states: torch.Tensor,
        try_get_optimal_moe_config,
        w1.shape,
        w2.shape,
        topk_ids.shape[1],
        top_k_num,
        config_dtype,
        block_shape=block_shape,
    )

    config = get_config_func(M)

    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
    intermediate_cache1 = torch.empty((M, top_k_num, N),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)
    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
    intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)
    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
    intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)

@@ -1221,7 +1279,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        sorted_token_ids, expert_ids, num_tokens_post_padded = (
            moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
            moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
                                 global_num_experts, expert_map))

        invoke_fused_moe_kernel(curr_hidden_states,
                                w1,
@@ -1235,7 +1294,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                expert_ids,
                                num_tokens_post_padded,
                                False,
                                topk_ids.shape[1],
                                top_k_num,
                                config,
                                compute_type=compute_type,
                                use_fp8_w8a8=use_fp8_w8a8,
@@ -1286,6 +1345,8 @@ def fused_moe(
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
@@ -1320,6 +1381,11 @@ def fused_moe(
    - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
        activation to compute the inner products for w1 and w2.
        Defaults to False.
    - global_num_experts (int): The total number of experts in the global
        expert space.
    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
        from the global expert space to the local expert space of the expert
        parallel shard.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
        w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -1334,8 +1400,6 @@ def fused_moe(
    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

    if use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
@@ -1358,6 +1422,8 @@ def fused_moe(
                         use_fp8_w8a8=use_fp8_w8a8,
                         use_int8_w8a16=use_int8_w8a16,
                         use_int4_w4a16=use_int4_w4a16,
                         global_num_experts=global_num_experts,
                         expert_map=expert_map,
                         w1_scale=w1_scale,
                         w2_scale=w2_scale,
                         w1_zp=w1_zp,
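The kernel-side EP handling above boils down to two steps: `moe_align_block_size` remaps each block's global expert id through `expert_map` (so blocks for experts owned by other ranks get id -1), and the Triton kernels write zeros and return for those blocks, so the later all-reduce across ranks still sums to the full MoE output. A torch-level sketch of that per-block behavior (illustrative only, not the Triton code):

```python
import torch

def run_expert_block(local_expert_id: int, a_block: torch.Tensor,
                     w_local: torch.Tensor) -> torch.Tensor:
    # w_local: [num_local_experts, K, N]; a_block: [M, K]
    out = torch.zeros(a_block.shape[0], w_local.shape[2], dtype=a_block.dtype)
    if local_expert_id == -1:
        # Expert lives on another EP rank: leave zeros, like write_zeros_to_output.
        return out
    return a_block @ w_local[local_expert_id]

expert_map = torch.tensor([-1, -1, 0, 1])          # this rank owns global experts 2 and 3
block_expert_ids = expert_map[torch.tensor([1, 2])]  # block-level global ids 1 and 2
a = torch.randn(4, 8)
w = torch.randn(2, 8, 16)
outs = [run_expert_block(int(e), a, w) for e in block_expert_ids]
print([o.abs().sum().item() > 0 for o in outs])  # [False, True]
```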
@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Tuple

import torch

import vllm.envs as envs
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
@@ -55,6 +56,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
              use_grouped_topk: bool = False,
              topk_group: Optional[int] = None,
              num_expert_group: Optional[int] = None,
              global_num_experts: int = -1,
              expert_map: Optional[torch.Tensor] = None,
              custom_routing_function: Optional[Callable] = None,
              scoring_func: str = "softmax",
              e_score_correction_bias: Optional[torch.Tensor] = None
@@ -113,6 +116,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
              use_grouped_topk: bool = False,
              topk_group: Optional[int] = None,
              num_expert_group: Optional[int] = None,
              global_num_experts: int = -1,
              expert_map: Optional[torch.Tensor] = None,
              custom_routing_function: Optional[Callable] = None,
              scoring_func: str = "softmax",
              e_score_correction_bias: Optional[torch.Tensor] = None
@@ -125,6 +130,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)
@@ -139,6 +146,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None
@@ -160,7 +169,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                             w2=layer.w2_weight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True)
                             inplace=True,
                             global_num_experts=global_num_experts,
                             expert_map=expert_map)

    def forward_cpu(
        self,
@@ -172,6 +183,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        **kwargs,
    ):
@@ -196,6 +209,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None
@@ -215,6 +230,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            w2=layer.w2_weight,
            topk=top_k,
            gating_output=router_logits,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            renormalize=renormalize)

    forward_native = forward_cuda
@@ -255,6 +272,7 @@ class FusedMoE(torch.nn.Module):
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        ep_size: Optional[int] = None,
        prefix: str = "",
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
@@ -267,8 +285,13 @@ class FusedMoE(torch.nn.Module):

        self.tp_size = (tp_size if tp_size is not None else
                        get_tensor_model_parallel_world_size())
        if envs.VLLM_TEST_ENABLE_EP:
            self.ep_size = self.tp_size
            self.tp_size = 1
        else:
            self.ep_size = 1
        self.top_k = top_k
        self.num_experts = num_experts
        self.num_experts = num_experts  # Global number of experts
        assert intermediate_size % self.tp_size == 0
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
@@ -281,6 +304,26 @@ class FusedMoE(torch.nn.Module):
        self.custom_routing_function = custom_routing_function
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias
        self.expert_map = None

        if self.ep_size > 1:
            # Create a tensor of size num_experts filled with -1
            self.expert_map = torch.full((self.num_experts, ),
                                         -1,
                                         dtype=torch.int32)
            # Create a expert map for the local experts
            local_num_experts = num_experts // self.ep_size
            ep_rank = get_tensor_model_parallel_rank()
            if ep_rank < (self.ep_size - 1):
                # Each non-last rank gets local_num_experts experts.
                self.expert_map[ep_rank * local_num_experts:
                                (ep_rank + 1) * local_num_experts] = \
                    torch.arange(0, local_num_experts, dtype=torch.int32)
            else:
                # All remaining experts are assigned to the last rank.
                local_num_experts = num_experts - ep_rank * local_num_experts
                self.expert_map[-local_num_experts:] = \
                    torch.arange(0, local_num_experts, dtype=torch.int32)

        if self.scoring_func != "softmax" and not self.use_grouped_topk:
            raise ValueError("Only softmax scoring function is supported for "
@@ -293,8 +336,11 @@ class FusedMoE(torch.nn.Module):
            self.quant_method = quant_config.get_quant_method(self, prefix)
        assert self.quant_method is not None

        local_num_experts = torch.sum(self.expert_map != -1) \
            if self.expert_map is not None else num_experts

        moe_quant_params = {
            "num_experts": num_experts,
            "num_experts": local_num_experts,
            "hidden_size": hidden_size,
            "intermediate_size_per_partition":
            self.intermediate_size_per_partition,
@@ -423,10 +469,22 @@ class FusedMoE(torch.nn.Module):
            assert shard_id in ("w1", "w3")
            expert_data.copy_(loaded_weight)

    def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
        if self.expert_map is None:
            return expert_id
        return self.expert_map[expert_id].item()

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor, weight_name: str,
                      shard_id: str, expert_id: int) -> None:

        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
        if expert_id == -1:
            return

        # TP rank is set to 0 if EP is enabled
        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()

        # compressed-tensors checkpoints with packed weights are stored flipped
        # TODO (mgoin): check self.quant_method.quant_config.quant_format
        # against known CompressionFormat enum values that have this quality
@@ -447,7 +505,6 @@ class FusedMoE(torch.nn.Module):
        SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

        expert_data = param.data[expert_id]
        tp_rank = get_tensor_model_parallel_rank()

        # is_transposed: if the dim to shard the weight
        # should be flipped. Required by GPTQ, compressed-tensors
@@ -590,13 +647,16 @@ class FusedMoE(torch.nn.Module):
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            global_num_experts=self.num_experts,
            expert_map=self.expert_map,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias)

        if self.reduce_results and self.tp_size > 1:
        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
            # Default set to False. (May have to add shared expert outputs.)
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)
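The `FusedMoE` constructor above partitions the global experts contiguously across EP ranks, with the last rank also absorbing any remainder, and keeps the result as a global-to-local `expert_map`. A standalone restatement of that logic for a quick sanity check (hypothetical helper, not part of the PR):

```python
import torch

def build_expert_map(num_experts: int, ep_size: int, ep_rank: int) -> torch.Tensor:
    expert_map = torch.full((num_experts, ), -1, dtype=torch.int32)
    local_num_experts = num_experts // ep_size
    if ep_rank < ep_size - 1:
        # Each non-last rank gets a contiguous block of local_num_experts experts.
        expert_map[ep_rank * local_num_experts:(ep_rank + 1) *
                   local_num_experts] = torch.arange(0,
                                                     local_num_experts,
                                                     dtype=torch.int32)
    else:
        # The last rank also takes the remainder experts.
        local_num_experts = num_experts - ep_rank * local_num_experts
        expert_map[-local_num_experts:] = torch.arange(0,
                                                       local_num_experts,
                                                       dtype=torch.int32)
    return expert_map

# 6 experts over 4 ranks: rank 3 owns global experts 3, 4, 5 as local 0, 1, 2.
print(build_expert_map(6, 4, 3))  # tensor([-1, -1, -1,  0,  1,  2], dtype=torch.int32)
```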
@@ -10,7 +10,9 @@ def fused_moe(
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    global_num_experts: int,
    expert_map: torch.Tensor = None,
    renormalize: bool = False,
) -> torch.Tensor:
    """
    Args:
@@ -18,6 +20,7 @@ def fused_moe(
        w1: [num_experts, intermediate_size * 2, hidden_size]
        w2: [num_experts, hidden_size, intermediate_size]
        gating_output: [*, num_experts]
        expert_map: [num_experts]
    """
    orig_shape = hidden_states.shape
    hidden_size = hidden_states.shape[-1]
@@ -27,13 +30,16 @@ def fused_moe(
    dtype = hidden_states.dtype

    hidden_states = hidden_states.view(num_tokens, hidden_size)
    gating_output = gating_output.view(num_tokens, num_experts)
    gating_output = gating_output.view(num_tokens, global_num_experts)
    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
    topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights.to(dtype)

    if expert_map is not None:
        selected_experts = expert_map[selected_experts]

    final_hidden_states = None
    for expert_idx in range(num_experts):
        expert_w1 = w1[expert_idx]
@@ -464,10 +464,17 @@ class AWQMoEMethod(FusedMoEMethodBase):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if expert_map is not None:
            raise NotImplementedError(
                "Expert Parallelism is not supported for "
                "fused Marlin MoE method.")

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
@@ -214,6 +214,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
@@ -239,6 +241,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
            topk_ids=topk_ids,
            inplace=True,
            use_fp8_w8a8=True,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
@@ -540,10 +544,16 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if expert_map is not None:
            raise NotImplementedError(
                "Expert Parallelism is not supported for "
                "fused Marlin MoE method.")

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
@@ -108,6 +108,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
@@ -133,6 +135,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
            topk_ids=topk_ids,
            inplace=True,
            use_int8_w8a16=True,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_scale,
            w2_scale=layer.w2_scale)
@@ -670,6 +670,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
@@ -697,6 +699,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            topk_ids=topk_ids,
            inplace=True,
            use_fp8_w8a8=True,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=(layer.w13_weight_scale_inv
                      if self.block_quant else layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale_inv
@@ -585,6 +585,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
@@ -288,6 +288,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
@@ -317,6 +319,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
            inplace=True,
            use_int4_w4a16=weight_bits == 4,
            use_int8_w8a16=weight_bits == 8,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_scales,
            w2_scale=layer.w2_scales,
            w1_zp=layer.w13_qzeros if has_zp else None,
@@ -198,6 +198,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
@@ -223,6 +225,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
            topk_ids=topk_ids,
            inplace=True,
            use_fp8_w8a8=True,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
@@ -106,10 +106,6 @@ class DeepseekV2MoE(nn.Module):
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}.")

        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "