[V1] EP/TP MoE + DP Attention (#13931)

commit 72c62eae5f (parent 0a995d5434)
Author: Tyler Michael Smith
Date:   2025-03-05 00:27:26 -05:00
Committed by: GitHub
17 changed files with 250 additions and 75 deletions

View File

@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
-# usage: VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
+# usage:
+# VLLM_TEST_ENABLE_EP=1 VLLM_USE_V1=1 \
+#     python examples/offline_inference/data_parallel.py
 # we need to have a launcher to create multiple data parallel
 # ranks. And each rank will create a vLLM instance to process its own prompts.
 import os
@@ -7,6 +9,9 @@ import os
 from vllm import LLM, SamplingParams
 from vllm.utils import get_open_port
 
+GPUs_per_dp_rank = 2
+DP_size = 2
+
 
 def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
     os.environ["VLLM_DP_RANK"] = str(dp_rank)
@@ -48,8 +53,8 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
                                      max_tokens=16 * (dp_rank + 1))
 
     # Create an LLM.
-    llm = LLM(model="facebook/opt-125m",
-              tensor_parallel_size=2,
+    llm = LLM(model="ibm-research/PowerMoE-3b",
+              tensor_parallel_size=GPUs_per_dp_rank,
              enforce_eager=True)
     outputs = llm.generate(prompts, sampling_params)
 
     # Print the outputs.
@@ -62,14 +67,12 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
 if __name__ == "__main__":
     from multiprocessing import Process
 
-    dp_size = 2
-    GPUs_per_dp_rank = 2
     dp_master_ip = "127.0.0.1"
     dp_master_port = get_open_port()
     procs = []
-    for i in range(dp_size):
+    for i in range(DP_size):
         proc = Process(target=main,
-                       args=(dp_size, i, dp_master_ip, dp_master_port,
+                       args=(DP_size, i, dp_master_ip, dp_master_port,
                              GPUs_per_dp_rank))
         proc.start()
         procs.append(proc)
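
Aside (not part of the diff): with the constants above, the launcher spawns DP_size processes and each one builds a TP group of GPUs_per_dp_rank GPUs. A minimal sketch of the resulting layout, assuming each DP rank is pinned to a contiguous block of GPUs (the excerpt does not show how the script actually assigns devices):

    # Illustrative only: process/GPU layout for DP_size=2, GPUs_per_dp_rank=2.
    DP_size = 2
    GPUs_per_dp_rank = 2

    for dp_rank in range(DP_size):
        gpus = list(range(dp_rank * GPUs_per_dp_rank,
                          (dp_rank + 1) * GPUs_per_dp_rank))
        print(f"DP rank {dp_rank}: TP group of {GPUs_per_dp_rank} GPUs -> {gpus}")
    # DP rank 0: TP group of 2 GPUs -> [0, 1]
    # DP rank 1: TP group of 2 GPUs -> [2, 3]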

View File

@@ -217,6 +217,7 @@ def test_mixtral_moe(dtype: torch.dtype):
         intermediate_size=config.intermediate_size,
         params_dtype=dtype,
         tp_size=1,
+        dp_size=1,
     ).cuda()
 
     # Load the weights

View File

@@ -324,7 +324,7 @@ def unified_attention(
 ) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
-    self = forward_context.attn_layers[layer_name]
+    self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     return self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
 
@@ -356,7 +356,7 @@ def unified_attention_with_output(
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
-    self = forward_context.attn_layers[layer_name]
+    self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(self,
                       query,

View File

@@ -396,8 +396,9 @@ class VllmBackend:
 
         cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
-        local_cache_dir = os.path.join(
-            cache_dir, f"rank_{vllm_config.parallel_config.rank}")
+        rank = vllm_config.parallel_config.rank
+        dp_rank = vllm_config.parallel_config.data_parallel_rank
+        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
         self.compilation_config.local_cache_dir = local_cache_dir
 
         disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE

View File

@@ -25,16 +25,22 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
 batchsize_forward_time: defaultdict = defaultdict(list)
 
 
+@dataclass
+class DPMetadata:
+    num_tokens_across_dp: list[int]
+    cu_tokens_across_dp_cpu: torch.Tensor
+
+
 @dataclass
 class ForwardContext:
     # copy from vllm_config.compilation_config.static_forward_context
-    attn_layers: dict[str, Any]
+    no_compile_layers: dict[str, Any]
     # TODO: extend to support per-layer dynamic forward context
     attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
     # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
-    num_tokens_across_dp: Optional[
-        list[int]] = None  # set dynamically for each forward pass
+    # set dynamically for each forward pass
+    dp_metadata: Optional[DPMetadata] = None
 
 
 _forward_context: Optional[ForwardContext] = None
@@ -61,7 +67,7 @@ def set_forward_context(attn_metadata: Any,
     need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
-    num_tokens_across_dp = None
+    dp_metadata: Optional[DPMetadata] = None
     if vllm_config.parallel_config.data_parallel_size > 1:
         dp_size = vllm_config.parallel_config.data_parallel_size
         dp_rank = vllm_config.parallel_config.data_parallel_rank
@@ -82,15 +88,17 @@
                                          dtype=torch.int32)
         from vllm.distributed.parallel_state import get_dp_group
         dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
-        num_tokens_across_dp = num_tokens_tensor.tolist()
+        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
+        dp_metadata = DPMetadata(num_tokens_across_dp, cu_tokens_across_dp_cpu)
 
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
-        attn_layers=vllm_config.compilation_config.static_forward_context,
+        no_compile_layers=vllm_config.compilation_config.
+        static_forward_context,
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
-        num_tokens_across_dp=num_tokens_across_dp)
+        dp_metadata=dp_metadata)
     try:
         yield
     finally:
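
Aside (not part of the diff): the cumulative sum is what lets each DP rank find its own slice of any buffer that is laid out across ranks. A small sketch of the bookkeeping, assuming two DP ranks contributing 5 and 3 tokens:

    import torch

    num_tokens_across_dp = [5, 3]  # per-rank token counts after the all_reduce
    num_tokens_tensor = torch.tensor(num_tokens_across_dp, dtype=torch.int32)
    cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
    print(cu_tokens_across_dp_cpu.tolist())  # [5, 8]

    for dp_rank in range(len(num_tokens_across_dp)):
        start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
        end = int(cu_tokens_across_dp_cpu[dp_rank])
        print(f"DP rank {dp_rank} owns rows [{start}, {end})")
    # DP rank 0 owns rows [0, 5)
    # DP rank 1 owns rows [5, 8)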

View File

@@ -8,9 +8,11 @@ import torch
 from torch.nn.parameter import UninitializedParameter
 
 import vllm.envs as envs
-from vllm.distributed import (get_tensor_model_parallel_rank,
+from vllm.config import get_current_vllm_config
+from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
+from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
@@ -18,6 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
+from vllm.utils import direct_register_custom_op
 
 if current_platform.is_cuda_alike():
     from .fused_moe import fused_experts
@@ -246,6 +249,51 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     forward_native = forward_cuda
 
 
+def determine_expert_map(
+        ep_size: int, ep_rank: int,
+        global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]:
+    """
+        Calculates how many experts should be assigned to each rank for EP and
+        creates a mapping from global to local expert index. Experts are
+        distributed evenly across ranks. Any remaining are assigned to the
+        last rank.
+
+        Args:
+            ep_size (int): The size of the expert parallel group
+            global_num_experts (int): The total number of experts in the model.
+
+        Returns:
+            Tuple[int, Optional[torch.Tensor]]: A tuple containing:
+                - local_num_experts (int): The number of experts assigned
+                    to the current rank.
+                - expert_map (Optional[torch.Tensor]): A tensor of shape
+                    (global_num_experts,) mapping from global to local index.
+                    Contains -1 for experts not assigned to the current rank.
+                    Returns None if ep_size is 1.
+    """
+    assert ep_size > 0
+    if ep_size == 1:
+        return (global_num_experts, None)
+
+    local_num_experts = global_num_experts // ep_size
+
+    # Create a tensor of size num_experts filled with -1
+    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
+    # Create a expert map for the local experts
+    if ep_rank < (ep_size - 1):
+        # Each non-last rank gets local_num_experts experts.
+        expert_map[ep_rank * local_num_experts:
+                   (ep_rank + 1) * local_num_experts] = \
+            torch.arange(0, local_num_experts, dtype=torch.int32)
+    else:
+        # All remaining experts are assigned to the last rank.
+        local_num_experts = (global_num_experts - ep_rank * local_num_experts)
+        expert_map[-local_num_experts:] = \
+            torch.arange(0, local_num_experts, dtype=torch.int32)
+    return (local_num_experts, expert_map)
+
+
 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.
@@ -282,6 +330,7 @@ class FusedMoE(torch.nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         ep_size: Optional[int] = None,
+        dp_size: Optional[int] = None,
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
@@ -293,16 +342,48 @@
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
 
+        # For smuggling this layer into the fused moe custom op
+        compilation_config = get_current_vllm_config().compilation_config
+        if prefix in compilation_config.static_forward_context:
+            raise ValueError("Duplicate layer name: {}".format(prefix))
+        compilation_config.static_forward_context[prefix] = self
+        self.layer_name = prefix
+        self.use_direct_call = not envs.VLLM_TEST_ENABLE_EP
+
+        # Note: here we guard against accessing the TP and DP groups when
+        # uninitialized (this happens when testing)
         self.tp_size = (tp_size if tp_size is not None else
                         get_tensor_model_parallel_world_size())
+        tp_rank = 0 if self.tp_size == 1 else get_tensor_model_parallel_rank()
+        self.dp_size = (dp_size
+                        if dp_size is not None else get_dp_group().world_size)
+        self.dp_rank = (0
+                        if self.dp_size == 1 else get_dp_group().rank_in_group)
+        self.global_num_experts = num_experts
+
         if envs.VLLM_TEST_ENABLE_EP:
-            self.ep_size = self.tp_size
+            # Set TP size to 1 to adjust for EP and adjust EP size and rank
+            # for DP attention.
+            self.ep_rank = tp_rank + self.tp_size * self.dp_rank
+            self.tp_rank = 0
+            self.ep_size = self.tp_size * self.dp_size
             self.tp_size = 1
+
+            self.local_num_experts, self.expert_map = determine_expert_map(
+                ep_size=self.ep_size,
+                ep_rank=self.ep_rank,
+                global_num_experts=self.global_num_experts)
         else:
+            # Adjust TP size for DP attention
+            self.tp_rank = tp_rank + self.tp_size * self.dp_rank
+            self.ep_rank = 0
+            self.tp_size = self.tp_size * self.dp_size
             self.ep_size = 1
+            self.local_num_experts = self.global_num_experts
+            self.expert_map = None
         self.top_k = top_k
         self.global_num_experts = num_experts
-        self.local_num_experts = self.global_num_experts // self.ep_size
+
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
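
Aside (not part of the diff): a worked example of the rank adjustment above for tp_size=2, dp_size=2 (four GPUs in total). With VLLM_TEST_ENABLE_EP set, every GPU becomes its own EP rank and the MoE TP size collapses to 1; without it, the DP ranks are folded into one larger TP group for the MoE weights:

    tp_size, dp_size = 2, 2

    print("EP enabled (VLLM_TEST_ENABLE_EP=1):")
    for dp_rank in range(dp_size):
        for tp_rank in range(tp_size):
            ep_rank = tp_rank + tp_size * dp_rank
            print(f"  tp_rank={tp_rank} dp_rank={dp_rank} -> "
                  f"ep_rank={ep_rank}, ep_size={tp_size * dp_size}, moe tp_size=1")

    print("EP disabled:")
    for dp_rank in range(dp_size):
        for tp_rank in range(tp_size):
            moe_tp_rank = tp_rank + tp_size * dp_rank
            print(f"  tp_rank={tp_rank} dp_rank={dp_rank} -> "
                  f"moe tp_rank={moe_tp_rank}, moe tp_size={tp_size * dp_size}, "
                  f"ep_size=1")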
@@ -316,26 +397,6 @@
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
         self.activation = activation
-        self.expert_map = None
-
-        if self.ep_size > 1:
-            # Create a tensor of size num_experts filled with -1
-            self.expert_map = torch.full((self.global_num_experts, ),
-                                         -1,
-                                         dtype=torch.int32)
-            # Create a expert map for the local experts
-            ep_rank = get_tensor_model_parallel_rank()
-            if ep_rank < (self.ep_size - 1):
-                # Each non-last rank gets local_num_experts experts.
-                self.expert_map[ep_rank * self.local_num_experts:
-                                (ep_rank + 1) * self.local_num_experts] = \
-                    torch.arange(0, self.local_num_experts, dtype=torch.int32)
-            else:
-                # All remaining experts are assigned to the last rank.
-                self.local_num_experts = (self.global_num_experts -
-                                          ep_rank * self.local_num_experts)
-                self.expert_map[-self.local_num_experts:] = \
-                    torch.arange(0, self.local_num_experts, dtype=torch.int32)
 
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
@@ -493,9 +554,6 @@
         if expert_id == -1:
             return
 
-        # TP rank is set to 0 if EP is enabled
-        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
-
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
@@ -539,8 +597,7 @@
             final_shape = list(loaded_weight.shape)
             if shard_id in ["w1", "w3"]:
                 final_shape[1] *= 2
-            final_shape[shard_dim] = final_shape[
-                shard_dim] // get_tensor_model_parallel_world_size()
+            final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size
             param.materialize(final_shape, dtype=loaded_weight.dtype)
 
         expert_data = param.data if full_load else param.data[expert_id]
@@ -567,7 +624,7 @@
                 shard_id=shard_id,
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
-                tp_rank=tp_rank)
+                tp_rank=self.tp_rank)
             return
 
@@ -584,7 +641,7 @@
                 shard_dim=shard_dim,
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
-                tp_rank=tp_rank)
+                tp_rank=self.tp_rank)
         elif quant_method in [
                 FusedMoeWeightScaleSupported.GROUP.value,
                 FusedMoeWeightScaleSupported.BLOCK.value,
@@ -594,7 +651,7 @@
                 shard_dim=shard_dim,
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
-                tp_rank=tp_rank,
+                tp_rank=self.tp_rank,
                 load_full_w2=getattr(param, "load_full_w2", False))
         elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
             self._load_per_tensor_weight_scale(shard_id=shard_id,
@@ -621,7 +678,7 @@
                 shard_dim=shard_dim,
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
-                tp_rank=tp_rank)
+                tp_rank=self.tp_rank)
             return
 
     @staticmethod
@@ -665,10 +722,45 @@
 
         return topk_weights, topk_ids
 
+    def naive_multicast(self, x: torch.Tensor,
+                        cu_tokens_across_dp_cpu: torch.Tensor):
+        assert (len(x.shape) == 2)
+        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+                             device=x.device,
+                             dtype=x.dtype)
+
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        buffer[start:end, :].copy_(x)
+        for idx in range(get_dp_group().world_size):
+            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
+            end = cu_tokens_across_dp_cpu[idx]
+            get_dp_group().broadcast(buffer[start:end, :], idx)
+
+        return buffer
+
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
+        if self.use_direct_call:
+            return self.forward_impl(hidden_states, router_logits)
+        else:
+            return torch.ops.vllm.moe_forward(hidden_states, router_logits,
+                                              self.layer_name)
+
+    def forward_impl(self, hidden_states: torch.Tensor,
+                     router_logits: torch.Tensor):
         assert self.quant_method is not None
 
+        if self.dp_size > 1:
+            cu_tokens_across_dp_cpu = get_forward_context(
+            ).dp_metadata.cu_tokens_across_dp_cpu
+
+            hidden_states = self.naive_multicast(hidden_states,
+                                                 cu_tokens_across_dp_cpu)
+            router_logits = self.naive_multicast(router_logits,
+                                                 cu_tokens_across_dp_cpu)
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
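
Aside (not part of the diff): naive_multicast is effectively an uneven all-gather over the DP group. Each rank writes its rows into its slot of a buffer sized cu_tokens_across_dp_cpu[-1], then every slot is broadcast from its owner so all ranks end up with the combined batch. A single-process sketch of the buffer layout, assuming two DP ranks with 5 and 3 tokens and hidden_size=4 (local copies stand in for get_dp_group().broadcast):

    import torch

    cu = torch.tensor([5, 8])  # cu_tokens_across_dp_cpu
    x_per_rank = [torch.randn(5, 4), torch.randn(3, 4)]

    buffer = torch.empty(int(cu[-1]), 4)
    for idx, x in enumerate(x_per_rank):  # stand-in for the per-slot broadcasts
        start = 0 if idx == 0 else int(cu[idx - 1])
        end = int(cu[idx])
        buffer[start:end, :] = x

    assert torch.equal(buffer[:5], x_per_rank[0])  # rank 0's tokens
    assert torch.equal(buffer[5:], x_per_rank[1])  # rank 1's tokens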
@@ -687,6 +779,14 @@
             activation=self.activation,
         )
 
+        if self.dp_size > 1:
+            start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+                self.dp_rank - 1]
+            end = cu_tokens_across_dp_cpu[self.dp_rank]
+
+            all_hidden_states = get_dp_group().all_reduce(final_hidden_states)
+            final_hidden_states = all_hidden_states[start:end, :]
+
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             # Default set to False. (May have to add shared expert outputs.)
             final_hidden_states = tensor_model_parallel_all_reduce(
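
Aside (not part of the diff): after the multicast, each DP rank runs its local experts (or its TP shard of them) over the combined batch, so its output is, roughly speaking, a partial sum for every row; the all_reduce over the DP group completes those sums, and each rank then slices back only the rows it originally contributed using the same cu_tokens offsets. A toy check of the slicing arithmetic, assuming 2 DP ranks with 5 and 3 tokens:

    import torch

    cu = torch.tensor([5, 8])                         # cumulative tokens across DP ranks
    partials = [torch.randn(8, 4) for _ in range(2)]  # each rank's partial output, all 8 rows
    full = partials[0] + partials[1]                  # what the DP-group all_reduce produces

    for dp_rank, own_tokens in enumerate([5, 3]):
        start = 0 if dp_rank == 0 else int(cu[dp_rank - 1])
        end = int(cu[dp_rank])
        local_out = full[start:end, :]                # rank keeps only its own tokens
        assert local_out.shape == (own_tokens, 4)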
@@ -757,3 +857,26 @@
         s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'"  # noqa: E501
         return s
+
+
+def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
+                layer_name: str) -> torch.Tensor:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    assert self.quant_method is not None
+
+    return self.forward_impl(hidden_states, router_logits)
+
+
+def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
+                     layer_name: str) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
+direct_register_custom_op(
+    op_name="moe_forward",
+    op_func=moe_forward,
+    mutates_args=[],
+    fake_impl=moe_forward_fake,
+    dispatch_key=current_platform.dispatch_key,
+)

View File

@@ -46,7 +46,7 @@ class AriaImagePixelInputs(TypedDict):
     pixel_values: torch.Tensor
     pixel_mask: Optional[torch.Tensor]
     """
     Shape:
         pixel_values: `(batch_size * num_images, num_channels, height, width)`
         pixel_mask: `(batch_size * num_images, height, width)`
     """
@@ -135,11 +135,11 @@ class AriaProjector(nn.Module):
             query numbers,
             e.g., {1225: 128, 4900: 256}. This allows for different query sizes
             based on image resolution.
         embed_dim (int): Embedding dimension.
         num_heads (int): Number of attention heads.
         kv_dim (int): Dimension of key and value.
         ff_dim (int): Hidden dimension of the feed-forward network.
         output_dim (int): Output dimension.
         norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm.
 
     Outputs:
@@ -239,6 +239,7 @@ class AriaTextMoELayer(nn.Module):
         self,
         config: AriaTextConfig,
         quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -254,6 +255,7 @@ class AriaTextMoELayer(nn.Module):
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
             reduce_results=True,
+            prefix=f"{prefix}.experts",
         )
         self.shared_experts = LlamaMLP(
             config.hidden_size,
@@ -301,7 +303,9 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
         prefix: str = "",
     ) -> None:
         super().__init__(config, cache_config, quant_config, prefix)
-        self.mlp = AriaTextMoELayer(config, quant_config=quant_config)
+        self.mlp = AriaTextMoELayer(config,
+                                    quant_config=quant_config,
+                                    prefix=f"{prefix}.mlp")
 
 
 class AriaTextModel(LlamaModel, SupportsQuant):

View File

@@ -65,6 +65,7 @@ class DbrxExperts(FusedMoE):
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
         params_dtype: Optional[torch.dtype] = None,
+        prefix: str = "",
     ):
         super().__init__(
             num_experts=config.ffn_config.moe_num_experts,
@@ -76,6 +77,7 @@ class DbrxExperts(FusedMoE):
             renormalize=True,
             quant_config=quant_config,
             tp_size=get_tensor_model_parallel_world_size(),
+            prefix=prefix,
         )
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -139,6 +141,7 @@ class DbrxMoE(nn.Module):
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
         params_dtype: Optional[torch.dtype] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.d_model = config.d_model
@@ -150,7 +153,8 @@ class DbrxMoE(nn.Module):
 
         self.experts = DbrxExperts(config=config,
                                    quant_config=quant_config,
-                                   params_dtype=self.params_dtype)
+                                   params_dtype=self.params_dtype,
+                                   prefix=f"{prefix}.experts")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_shape = hidden_states.shape
@@ -291,7 +295,7 @@ class DbrxBlock(nn.Module):
                                               cache_config,
                                               quant_config,
                                               prefix=f"{prefix}.norm_attn_norm")
-        self.ffn = DbrxMoE(config, quant_config)
+        self.ffn = DbrxMoE(config, quant_config, prefix=f"{prefix}.ffn")
 
     def forward(
         self,

View File

@@ -47,7 +47,8 @@ class JambaMoE(nn.Module):
                  top_k: Optional[int] = None,
                  params_dtype: Optional[torch.dtype] = None,
                  tp_size: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
         self.num_total_experts = num_experts or config.num_experts
         self.top_k = top_k or config.num_experts_per_tok
@@ -70,7 +71,8 @@ class JambaMoE(nn.Module):
                                 reduce_results=True,
                                 renormalize=False,
                                 use_grouped_topk=False,
-                                quant_config=quant_config)
+                                quant_config=quant_config,
+                                prefix=f"{prefix}.experts")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_shape = hidden_states.shape
@@ -92,13 +94,15 @@ class JambaMLP(JambaMoE):
                  config: JambaConfig,
                  params_dtype: Optional[torch.dtype] = None,
                  tp_size: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__(config,
                          num_experts=1,
                          top_k=1,
                          params_dtype=params_dtype,
                          tp_size=tp_size,
-                         quant_config=quant_config)
+                         quant_config=quant_config,
+                         prefix=prefix)
 
 
 class JambaMambaDecoderLayer(nn.Module):
@@ -109,6 +113,7 @@ class JambaMambaDecoderLayer(nn.Module):
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  is_lora_enabled: Optional[bool] = False,
+                 prefix: str = "",
                  **kwargs) -> None:
         super().__init__()
         self.config = config
@@ -129,7 +134,9 @@ class JambaMambaDecoderLayer(nn.Module):
 
         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
-        self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
+        self.feed_forward = ffn_layer_class(config,
+                                            quant_config=quant_config,
+                                            prefix=f"{prefix}.feed_forward")
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,
@@ -211,7 +218,9 @@ class JambaAttentionDecoderLayer(nn.Module):
 
         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
-        self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
+        self.feed_forward = ffn_layer_class(config,
+                                            quant_config=quant_config,
+                                            prefix=f"{prefix}.feed_forward")
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,

View File

@@ -71,6 +71,7 @@ class MixtralMoE(nn.Module):
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  tp_size: Optional[int] = None,
+                 dp_size: Optional[int] = None,
                  prefix: str = ""):
         super().__init__()
         self.hidden_size = hidden_size
@@ -93,6 +94,7 @@ class MixtralMoE(nn.Module):
                                 renormalize=True,
                                 quant_config=quant_config,
                                 tp_size=tp_size,
+                                dp_size=dp_size,
                                 prefix=f"{prefix}.experts")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

View File

@@ -80,7 +80,8 @@ class OlmoeMoE(nn.Module):
                                 reduce_results=True,
                                 renormalize=False,
                                 quant_config=quant_config,
-                                tp_size=tp_size)
+                                tp_size=tp_size,
+                                prefix=f"{prefix}.experts")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
@@ -212,6 +213,7 @@ class OlmoeDecoderLayer(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)

View File

@@ -249,6 +249,7 @@ class PhiMoE(nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -272,7 +273,8 @@ class PhiMoE(nn.Module):
                                 renormalize=False,
                                 quant_config=quant_config,
                                 tp_size=tp_size,
-                                custom_routing_function=phimoe_routing_function)
+                                custom_routing_function=phimoe_routing_function,
+                                prefix=f"{prefix}.experts")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
@@ -396,6 +398,7 @@ class PhiMoEDecoderLayer(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.block_sparse_moe",
         )
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                             eps=config.rms_norm_eps,

View File

@@ -100,6 +100,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -115,7 +116,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             intermediate_size=config.moe_intermediate_size,
             reduce_results=False,
             renormalize=config.norm_topk_prob,
-            quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts")
 
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
@@ -277,7 +279,8 @@ class Qwen2MoeDecoderLayer(nn.Module):
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
             self.mlp = Qwen2MoeSparseMoeBlock(config=config,
-                                              quant_config=quant_config)
+                                              quant_config=quant_config,
+                                              prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen2MoeMLP(
                 hidden_size=config.hidden_size,

View File

@@ -111,6 +111,7 @@ class CudaPlatformBase(Platform):
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
+        compilation_config = vllm_config.compilation_config
 
         if parallel_config.worker_cls == "auto":
             if scheduler_config.is_multi_step:
@@ -150,6 +151,14 @@ class CudaPlatformBase(Platform):
                     "FlashMLA: Forcing kv cache block size to 64 since this"
                     " is currently the only block size supported by the kernel.")
 
+        if (parallel_config.data_parallel_size > 1
+                and compilation_config.use_cudagraph):
+            logger.info(
+                "Data Parallel: Forcing enforce eager to be True since DP is "
+                "currently not supported with CUDA Graphs.")
+            vllm_config.model_config.enforce_eager = True
+            compilation_config.use_cudagraph = False
+
     @classmethod
     def get_current_memory_usage(cls,
                                  device: Optional[torch.types.Device] = None

View File

@@ -2196,8 +2196,8 @@ def bind_kv_cache(
     from vllm.model_executor.models.utils import extract_layer_index
     layer_need_kv_cache = [
         layer_name for layer_name in ctx
-        if ctx[layer_name].attn_type in (AttentionType.DECODER,
-                                         AttentionType.ENCODER_DECODER)
+        if (hasattr(ctx[layer_name], 'attn_type') and ctx[layer_name].attn_type
+            in (AttentionType.DECODER, AttentionType.ENCODER_DECODER))
     ]
     layer_index_sorted = sorted(
         set(

View File

@@ -149,7 +149,6 @@ class EngineCore:
         if not self.scheduler.has_unfinished_requests():
             return EngineCoreOutputs(
                 outputs=[], scheduler_stats=self.scheduler.make_stats())
-
         scheduler_output = self.scheduler.schedule()
         output = self.model_executor.execute_model(scheduler_output)
         engine_core_outputs = self.scheduler.update_from_output(

View File

@@ -17,6 +17,7 @@ from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
@@ -1357,7 +1358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         """
         Initialize KV cache based on `kv_cache_config`.
         Args:
             kv_cache_config: Configuration for the KV cache, including the KV
                 cache size of each layer
         """
         if len(kv_cache_config.groups) > 1:
@@ -1389,10 +1390,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def get_kv_cache_spec(self) -> KVCacheSpec:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
         Attention module in the static forward context.
 
         Returns:
             KVCacheSpec: A dictionary mapping layer names to their KV cache
             format. Layers that do not need KV cache are not included.
         """
 
@@ -1400,6 +1401,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         block_size = self.vllm_config.cache_config.block_size
         kv_cache_spec: KVCacheSpec = {}
         for layer_name, attn_module in forward_ctx.items():
+            if isinstance(attn_module, FusedMoE):
+                continue
+
             # TODO: Support other attention modules, e.g., sliding window,
             # cross-attention, MLA.
             assert isinstance(attn_module, Attention)