[V1] EP/TP MoE + DP Attention (#13931)

parent 0a995d5434
commit 72c62eae5f
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
-# usage: VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
+# usage:
+# VLLM_TEST_ENABLE_EP=1 VLLM_USE_V1=1 \
+# python examples/offline_inference/data_parallel.py
 # we need to have a launcher to create multiple data parallel
 # ranks. And each rank will create a vLLM instance to process its own prompts.
 import os
@@ -7,6 +9,9 @@ import os
 from vllm import LLM, SamplingParams
 from vllm.utils import get_open_port

+GPUs_per_dp_rank = 2
+DP_size = 2
+

 def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
     os.environ["VLLM_DP_RANK"] = str(dp_rank)
@@ -48,8 +53,8 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
                                      max_tokens=16 * (dp_rank + 1))

     # Create an LLM.
-    llm = LLM(model="facebook/opt-125m",
-              tensor_parallel_size=2,
+    llm = LLM(model="ibm-research/PowerMoE-3b",
+              tensor_parallel_size=GPUs_per_dp_rank,
               enforce_eager=True)
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
@@ -62,14 +67,12 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):

 if __name__ == "__main__":
     from multiprocessing import Process
-    dp_size = 2
-    GPUs_per_dp_rank = 2
    dp_master_ip = "127.0.0.1"
    dp_master_port = get_open_port()
    procs = []
-    for i in range(dp_size):
+    for i in range(DP_size):
        proc = Process(target=main,
-                      args=(dp_size, i, dp_master_ip, dp_master_port,
+                      args=(DP_size, i, dp_master_ip, dp_master_port,
                             GPUs_per_dp_rank))
        proc.start()
        procs.append(proc)
@@ -217,6 +217,7 @@ def test_mixtral_moe(dtype: torch.dtype):
                          intermediate_size=config.intermediate_size,
                          params_dtype=dtype,
                          tp_size=1,
+                         dp_size=1,
                          ).cuda()

     # Load the weights
@@ -324,7 +324,7 @@ def unified_attention(
 ) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
-    self = forward_context.attn_layers[layer_name]
+    self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     return self.impl.forward(self, query, key, value, kv_cache, attn_metadata)

@@ -356,7 +356,7 @@ def unified_attention_with_output(
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
-    self = forward_context.attn_layers[layer_name]
+    self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(self,
                       query,
@@ -396,8 +396,9 @@ class VllmBackend:

         cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
-        local_cache_dir = os.path.join(
-            cache_dir, f"rank_{vllm_config.parallel_config.rank}")
+        rank = vllm_config.parallel_config.rank
+        dp_rank = vllm_config.parallel_config.data_parallel_rank
+        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
         self.compilation_config.local_cache_dir = local_cache_dir

         disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE
@@ -25,16 +25,22 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
 batchsize_forward_time: defaultdict = defaultdict(list)


+@dataclass
+class DPMetadata:
+    num_tokens_across_dp: list[int]
+    cu_tokens_across_dp_cpu: torch.Tensor
+
+
 @dataclass
 class ForwardContext:
     # copy from vllm_config.compilation_config.static_forward_context
-    attn_layers: dict[str, Any]
+    no_compile_layers: dict[str, Any]
     # TODO: extend to support per-layer dynamic forward context
     attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
     # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
-    num_tokens_across_dp: Optional[
-        list[int]] = None  # set dynamically for each forward pass
+    # set dynamically for each forward pass
+    dp_metadata: Optional[DPMetadata] = None


 _forward_context: Optional[ForwardContext] = None
@@ -61,7 +67,7 @@ def set_forward_context(attn_metadata: Any,
     need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
-    num_tokens_across_dp = None
+    dp_metadata: Optional[DPMetadata] = None
     if vllm_config.parallel_config.data_parallel_size > 1:
         dp_size = vllm_config.parallel_config.data_parallel_size
         dp_rank = vllm_config.parallel_config.data_parallel_rank
@@ -82,15 +88,17 @@ def set_forward_context(attn_metadata: Any,
                                          dtype=torch.int32)
         from vllm.distributed.parallel_state import get_dp_group
         dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
-        num_tokens_across_dp = num_tokens_tensor.tolist()
+        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
+        dp_metadata = DPMetadata(num_tokens_across_dp, cu_tokens_across_dp_cpu)

     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
-        attn_layers=vllm_config.compilation_config.static_forward_context,
+        no_compile_layers=vllm_config.compilation_config.
+        static_forward_context,
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
-        num_tokens_across_dp=num_tokens_across_dp)
+        dp_metadata=dp_metadata)
     try:
         yield
     finally:
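Note: a small standalone sketch (plain torch, made-up counts, no vLLM state) of what the new DPMetadata.cu_tokens_across_dp_cpu encodes. The per-rank token counts gathered by the CPU all_reduce become cumulative offsets, and each DP rank later addresses its own slice of the DP-wide token buffer with them.

# Illustrative only: suppose the two DP ranks schedule 3 and 5 tokens for
# this step; after the all_reduce every rank holds the same counts.
import torch

num_tokens_tensor = torch.tensor([3, 5], dtype=torch.int32)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)  # tensor([3, 8])

# Rank r owns rows [start, end) of the DP-wide token buffer.
for dp_rank in range(len(num_tokens_tensor)):
    start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
    end = int(cu_tokens_across_dp_cpu[dp_rank])
    print(f"DP rank {dp_rank}: rows {start}:{end}")  # 0:3, then 3:8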
@@ -8,9 +8,11 @@ import torch
 from torch.nn.parameter import UninitializedParameter

 import vllm.envs as envs
-from vllm.distributed import (get_tensor_model_parallel_rank,
+from vllm.config import get_current_vllm_config
+from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
+from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
@@ -18,6 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
+from vllm.utils import direct_register_custom_op

 if current_platform.is_cuda_alike():
     from .fused_moe import fused_experts
@@ -246,6 +249,51 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     forward_native = forward_cuda


+def determine_expert_map(
+        ep_size: int, ep_rank: int,
+        global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]:
+    """
+    Calculates how many experts should be assigned to each rank for EP and
+    creates a mapping from global to local expert index. Experts are
+    distributed evenly across ranks. Any remaining are assigned to the
+    last rank.
+
+    Args:
+        ep_size (int): The size of the expert parallel group
+        global_num_experts (int): The total number of experts in the model.
+
+    Returns:
+        Tuple[int, Optional[torch.Tensor]]: A tuple containing:
+            - local_num_experts (int): The number of experts assigned
+                to the current rank.
+            - expert_map (Optional[torch.Tensor]): A tensor of shape
+                (global_num_experts,) mapping from global to local index.
+                Contains -1 for experts not assigned to the current rank.
+                Returns None if ep_size is 1.
+    """
+    assert ep_size > 0
+    if ep_size == 1:
+        return (global_num_experts, None)
+
+    local_num_experts = global_num_experts // ep_size
+
+    # Create a tensor of size num_experts filled with -1
+    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
+    # Create a expert map for the local experts
+    if ep_rank < (ep_size - 1):
+        # Each non-last rank gets local_num_experts experts.
+        expert_map[ep_rank * local_num_experts:
+                   (ep_rank + 1) * local_num_experts] = \
+            torch.arange(0, local_num_experts, dtype=torch.int32)
+    else:
+        # All remaining experts are assigned to the last rank.
+        local_num_experts = (global_num_experts - ep_rank * local_num_experts)
+
+        expert_map[-local_num_experts:] = \
+            torch.arange(0, local_num_experts, dtype=torch.int32)
+    return (local_num_experts, expert_map)
+
+
 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.

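Note: a hand-worked example of the mapping determine_expert_map produces, with illustrative sizes (10 experts over 4 EP ranks); the expected values follow directly from the function body in the hunk above.

# Hand-worked example: local_num_experts = 10 // 4 = 2 for every rank except
# the last, which takes the remainder (10 - 3 * 2 = 4 experts).
#
#   ep_rank 0 -> (2, [ 0,  1, -1, -1, -1, -1, -1, -1, -1, -1])
#   ep_rank 1 -> (2, [-1, -1,  0,  1, -1, -1, -1, -1, -1, -1])
#   ep_rank 2 -> (2, [-1, -1, -1, -1,  0,  1, -1, -1, -1, -1])
#   ep_rank 3 -> (4, [-1, -1, -1, -1, -1, -1,  0,  1,  2,  3])
#
# -1 means "this expert lives on another rank"; non-negative entries are the
# local expert indices used to index that rank's expert weights.
local_num_experts, expert_map = determine_expert_map(ep_size=4,
                                                     ep_rank=3,
                                                     global_num_experts=10)
assert local_num_experts == 4
assert expert_map.tolist() == [-1, -1, -1, -1, -1, -1, 0, 1, 2, 3]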
@@ -282,6 +330,7 @@ class FusedMoE(torch.nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         ep_size: Optional[int] = None,
+        dp_size: Optional[int] = None,
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
@@ -293,16 +342,48 @@
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

+        # For smuggling this layer into the fused moe custom op
+        compilation_config = get_current_vllm_config().compilation_config
+        if prefix in compilation_config.static_forward_context:
+            raise ValueError("Duplicate layer name: {}".format(prefix))
+        compilation_config.static_forward_context[prefix] = self
+        self.layer_name = prefix
+        self.use_direct_call = not envs.VLLM_TEST_ENABLE_EP
+
+        # Note: here we guard against accessing the TP and DP groups when
+        # uninitialized (this happens when testing)
         self.tp_size = (tp_size if tp_size is not None else
                         get_tensor_model_parallel_world_size())
+        tp_rank = 0 if self.tp_size == 1 else get_tensor_model_parallel_rank()
+        self.dp_size = (dp_size
+                        if dp_size is not None else get_dp_group().world_size)
+        self.dp_rank = (0
+                        if self.dp_size == 1 else get_dp_group().rank_in_group)
+        self.global_num_experts = num_experts
+
         if envs.VLLM_TEST_ENABLE_EP:
-            self.ep_size = self.tp_size
+            # Set TP size to 1 to adjust for EP and adjust EP size and rank
+            # for DP attention.
+            self.ep_rank = tp_rank + self.tp_size * self.dp_rank
+            self.tp_rank = 0
+            self.ep_size = self.tp_size * self.dp_size
             self.tp_size = 1
+
+            self.local_num_experts, self.expert_map = determine_expert_map(
+                ep_size=self.ep_size,
+                ep_rank=self.ep_rank,
+                global_num_experts=self.global_num_experts)
         else:
+            # Adjust TP size for DP attention
+            self.tp_rank = tp_rank + self.tp_size * self.dp_rank
+            self.ep_rank = 0
+            self.tp_size = self.tp_size * self.dp_size
             self.ep_size = 1
+            self.local_num_experts = self.global_num_experts
+            self.expert_map = None
         self.top_k = top_k
         self.global_num_experts = num_experts
-        self.local_num_experts = self.global_num_experts // self.ep_size
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
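Note: the bookkeeping above folds the TP and DP groups together. With expert parallelism enabled, the layer joins one EP group of size tp_size * dp_size and disables TP for the expert weights; with it disabled, the expert weights are instead tensor-parallel across tp_size * dp_size ranks. A small sketch of the resulting rank assignments for a hypothetical TP=2, DP=2 deployment (values computed by hand, not read from vLLM):

# Illustrative rank arithmetic mirroring __init__ above (TP=2, DP=2, 4 GPUs).
TP_SIZE, DP_SIZE = 2, 2

for dp_rank in range(DP_SIZE):
    for tp_rank in range(TP_SIZE):
        # VLLM_TEST_ENABLE_EP=1: one EP group of size TP*DP, no TP for experts.
        ep_rank = tp_rank + TP_SIZE * dp_rank   # 0, 1, 2, 3
        ep_size = TP_SIZE * DP_SIZE             # 4
        # EP disabled: the expert weights are tensor-parallel across TP*DP ranks.
        moe_tp_rank = tp_rank + TP_SIZE * dp_rank
        moe_tp_size = TP_SIZE * DP_SIZE
        print(f"dp={dp_rank} tp={tp_rank} -> ep {ep_rank}/{ep_size}, "
              f"moe-tp {moe_tp_rank}/{moe_tp_size}")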
@@ -316,26 +397,6 @@
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
         self.activation = activation
-        self.expert_map = None
-
-        if self.ep_size > 1:
-            # Create a tensor of size num_experts filled with -1
-            self.expert_map = torch.full((self.global_num_experts, ),
-                                         -1,
-                                         dtype=torch.int32)
-            # Create a expert map for the local experts
-            ep_rank = get_tensor_model_parallel_rank()
-            if ep_rank < (self.ep_size - 1):
-                # Each non-last rank gets local_num_experts experts.
-                self.expert_map[ep_rank * self.local_num_experts:
-                                (ep_rank + 1) * self.local_num_experts] = \
-                    torch.arange(0, self.local_num_experts, dtype=torch.int32)
-            else:
-                # All remaining experts are assigned to the last rank.
-                self.local_num_experts = (self.global_num_experts -
-                                          ep_rank * self.local_num_experts)
-                self.expert_map[-self.local_num_experts:] = \
-                    torch.arange(0, self.local_num_experts, dtype=torch.int32)

         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
@@ -493,9 +554,6 @@
         if expert_id == -1:
             return

-        # TP rank is set to 0 if EP is enabled
-        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
-
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
@@ -539,8 +597,7 @@
             final_shape = list(loaded_weight.shape)
             if shard_id in ["w1", "w3"]:
                 final_shape[1] *= 2
-            final_shape[shard_dim] = final_shape[
-                shard_dim] // get_tensor_model_parallel_world_size()
+            final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size
             param.materialize(final_shape, dtype=loaded_weight.dtype)

         expert_data = param.data if full_load else param.data[expert_id]
@@ -567,7 +624,7 @@
                 shard_id=shard_id,
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
-                tp_rank=tp_rank)
+                tp_rank=self.tp_rank)
             return

         # Case weight scales and zero_points
@@ -584,7 +641,7 @@
                     shard_dim=shard_dim,
                     loaded_weight=loaded_weight,
                     expert_data=expert_data,
-                    tp_rank=tp_rank)
+                    tp_rank=self.tp_rank)
             elif quant_method in [
                     FusedMoeWeightScaleSupported.GROUP.value,
                     FusedMoeWeightScaleSupported.BLOCK.value,
@@ -594,7 +651,7 @@
                     shard_dim=shard_dim,
                     loaded_weight=loaded_weight,
                     expert_data=expert_data,
-                    tp_rank=tp_rank,
+                    tp_rank=self.tp_rank,
                     load_full_w2=getattr(param, "load_full_w2", False))
             elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                 self._load_per_tensor_weight_scale(shard_id=shard_id,
@@ -621,7 +678,7 @@
                 shard_dim=shard_dim,
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
-                tp_rank=tp_rank)
+                tp_rank=self.tp_rank)
             return

     @staticmethod
@@ -665,10 +722,45 @@

         return topk_weights, topk_ids

+    def naive_multicast(self, x: torch.Tensor,
+                        cu_tokens_across_dp_cpu: torch.Tensor):
+        assert (len(x.shape) == 2)
+        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+                             device=x.device,
+                             dtype=x.dtype)
+
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        buffer[start:end, :].copy_(x)
+        for idx in range(get_dp_group().world_size):
+            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
+            end = cu_tokens_across_dp_cpu[idx]
+            get_dp_group().broadcast(buffer[start:end, :], idx)
+
+        return buffer
+
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
+        if self.use_direct_call:
+            return self.forward_impl(hidden_states, router_logits)
+        else:
+            return torch.ops.vllm.moe_forward(hidden_states, router_logits,
+                                              self.layer_name)
+
+    def forward_impl(self, hidden_states: torch.Tensor,
+                     router_logits: torch.Tensor):
         assert self.quant_method is not None

+        if self.dp_size > 1:
+            cu_tokens_across_dp_cpu = get_forward_context(
+            ).dp_metadata.cu_tokens_across_dp_cpu
+
+            hidden_states = self.naive_multicast(hidden_states,
+                                                 cu_tokens_across_dp_cpu)
+            router_logits = self.naive_multicast(router_logits,
+                                                 cu_tokens_across_dp_cpu)
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
@@ -687,6 +779,14 @@
             activation=self.activation,
         )

+        if self.dp_size > 1:
+            start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+                self.dp_rank - 1]
+            end = cu_tokens_across_dp_cpu[self.dp_rank]
+
+            all_hidden_states = get_dp_group().all_reduce(final_hidden_states)
+            final_hidden_states = all_hidden_states[start:end, :]
+
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             # Default set to False. (May have to add shared expert outputs.)
             final_hidden_states = tensor_model_parallel_all_reduce(
@@ -757,3 +857,26 @@
         s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'"  # noqa: E501

         return s
+
+
+def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
+                layer_name: str) -> torch.Tensor:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    assert self.quant_method is not None
+
+    return self.forward_impl(hidden_states, router_logits)
+
+
+def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
+                     layer_name: str) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
+direct_register_custom_op(
+    op_name="moe_forward",
+    op_func=moe_forward,
+    mutates_args=[],
+    fake_impl=moe_forward_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
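Note: a single-process sketch of the data-parallel path added to forward_impl above: each DP rank gathers every rank's tokens into one buffer via naive_multicast (each rank copies its rows at its cumulative offset, then broadcasts them), runs the experts on the combined batch, all-reduces the partial results across the DP group, and slices its own rows back out. The collectives are simulated with plain tensor ops here just to show the shapes; the token counts are made up.

import torch

hidden = 8
per_rank_tokens = [3, 5]                     # tokens on DP ranks 0 and 1
cu = torch.cumsum(torch.tensor(per_rank_tokens), dim=0)  # tensor([3, 8])

# Each rank starts with only its own activations...
rank_inputs = [torch.randn(n, hidden) for n in per_rank_tokens]

# ...naive_multicast leaves every rank holding the full buffer, with rank r's
# rows living at [cu[r-1], cu[r]).
full_batch = torch.cat(rank_inputs, dim=0)

# The experts run on the combined batch (partial results when experts are
# sharded), the DP all_reduce sums the partials, and each rank keeps its rows.
moe_out = full_batch * 2.0                   # stand-in for the expert MLPs
for dp_rank, n in enumerate(per_rank_tokens):
    start = 0 if dp_rank == 0 else int(cu[dp_rank - 1])
    end = int(cu[dp_rank])
    own_rows = moe_out[start:end, :]
    assert own_rows.shape == (n, hidden)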
@@ -46,7 +46,7 @@ class AriaImagePixelInputs(TypedDict):
     pixel_values: torch.Tensor
     pixel_mask: Optional[torch.Tensor]
     """
     Shape:
         pixel_values: `(batch_size * num_images, num_channels, height, width)`
         pixel_mask: `(batch_size * num_images, height, width)`
     """
@@ -135,11 +135,11 @@ class AriaProjector(nn.Module):
            query numbers,
            e.g., {1225: 128, 4900: 256}. This allows for different query sizes
            based on image resolution.
        embed_dim (int): Embedding dimension.
        num_heads (int): Number of attention heads.
        kv_dim (int): Dimension of key and value.
        ff_dim (int): Hidden dimension of the feed-forward network.
        output_dim (int): Output dimension.
        norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm.

    Outputs:
@@ -239,6 +239,7 @@ class AriaTextMoELayer(nn.Module):
         self,
         config: AriaTextConfig,
         quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -254,6 +255,7 @@ class AriaTextMoELayer(nn.Module):
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
             reduce_results=True,
+            prefix=f"{prefix}.experts",
         )
         self.shared_experts = LlamaMLP(
             config.hidden_size,
@@ -301,7 +303,9 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
         prefix: str = "",
     ) -> None:
         super().__init__(config, cache_config, quant_config, prefix)
-        self.mlp = AriaTextMoELayer(config, quant_config=quant_config)
+        self.mlp = AriaTextMoELayer(config,
+                                    quant_config=quant_config,
+                                    prefix=f"{prefix}.mlp")


 class AriaTextModel(LlamaModel, SupportsQuant):
@@ -65,6 +65,7 @@ class DbrxExperts(FusedMoE):
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
         params_dtype: Optional[torch.dtype] = None,
+        prefix: str = "",
     ):
         super().__init__(
             num_experts=config.ffn_config.moe_num_experts,
@@ -76,6 +77,7 @@ class DbrxExperts(FusedMoE):
             renormalize=True,
             quant_config=quant_config,
             tp_size=get_tensor_model_parallel_world_size(),
+            prefix=prefix,
         )
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -139,6 +141,7 @@ class DbrxMoE(nn.Module):
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
         params_dtype: Optional[torch.dtype] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.d_model = config.d_model
@@ -150,7 +153,8 @@ class DbrxMoE(nn.Module):

         self.experts = DbrxExperts(config=config,
                                    quant_config=quant_config,
-                                   params_dtype=self.params_dtype)
+                                   params_dtype=self.params_dtype,
+                                   prefix=f"{prefix}.experts")

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_shape = hidden_states.shape
@@ -291,7 +295,7 @@ class DbrxBlock(nn.Module):
            cache_config,
            quant_config,
            prefix=f"{prefix}.norm_attn_norm")
-        self.ffn = DbrxMoE(config, quant_config)
+        self.ffn = DbrxMoE(config, quant_config, prefix=f"{prefix}.ffn")

     def forward(
         self,
@@ -47,7 +47,8 @@ class JambaMoE(nn.Module):
                  top_k: Optional[int] = None,
                  params_dtype: Optional[torch.dtype] = None,
                  tp_size: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
         self.num_total_experts = num_experts or config.num_experts
         self.top_k = top_k or config.num_experts_per_tok
@@ -70,7 +71,8 @@ class JambaMoE(nn.Module):
            reduce_results=True,
            renormalize=False,
            use_grouped_topk=False,
-            quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts")

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_shape = hidden_states.shape
@@ -92,13 +94,15 @@ class JambaMLP(JambaMoE):
                  config: JambaConfig,
                  params_dtype: Optional[torch.dtype] = None,
                  tp_size: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__(config,
                          num_experts=1,
                          top_k=1,
                          params_dtype=params_dtype,
                          tp_size=tp_size,
-                         quant_config=quant_config)
+                         quant_config=quant_config,
+                         prefix=prefix)


 class JambaMambaDecoderLayer(nn.Module):
@@ -109,6 +113,7 @@ class JambaMambaDecoderLayer(nn.Module):
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  is_lora_enabled: Optional[bool] = False,
+                 prefix: str = "",
                  **kwargs) -> None:
         super().__init__()
         self.config = config
@@ -129,7 +134,9 @@ class JambaMambaDecoderLayer(nn.Module):

         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
-        self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
+        self.feed_forward = ffn_layer_class(config,
+                                            quant_config=quant_config,
+                                            prefix=f"{prefix}.feed_forward")
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,
@@ -211,7 +218,9 @@ class JambaAttentionDecoderLayer(nn.Module):

         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
-        self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
+        self.feed_forward = ffn_layer_class(config,
+                                            quant_config=quant_config,
+                                            prefix=f"{prefix}.feed_forward")
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,
@@ -71,6 +71,7 @@ class MixtralMoE(nn.Module):
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  tp_size: Optional[int] = None,
+                 dp_size: Optional[int] = None,
                  prefix: str = ""):
         super().__init__()
         self.hidden_size = hidden_size
@@ -93,6 +94,7 @@ class MixtralMoE(nn.Module):
            renormalize=True,
            quant_config=quant_config,
            tp_size=tp_size,
+            dp_size=dp_size,
            prefix=f"{prefix}.experts")

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -80,7 +80,8 @@ class OlmoeMoE(nn.Module):
            reduce_results=True,
            renormalize=False,
            quant_config=quant_config,
-            tp_size=tp_size)
+            tp_size=tp_size,
+            prefix=f"{prefix}.experts")

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
@@ -212,6 +213,7 @@ class OlmoeDecoderLayer(nn.Module):
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
@@ -249,6 +249,7 @@ class PhiMoE(nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -272,7 +273,8 @@ class PhiMoE(nn.Module):
            renormalize=False,
            quant_config=quant_config,
            tp_size=tp_size,
-            custom_routing_function=phimoe_routing_function)
+            custom_routing_function=phimoe_routing_function,
+            prefix=f"{prefix}.experts")

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
@@ -396,6 +398,7 @@ class PhiMoEDecoderLayer(nn.Module):
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
+            prefix=f"{prefix}.block_sparse_moe",
         )
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                             eps=config.rms_norm_eps,
@@ -100,6 +100,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -115,7 +116,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
-            quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts")

         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
@@ -277,7 +279,8 @@ class Qwen2MoeDecoderLayer(nn.Module):
                config.num_experts > 0 and
                (layer_idx + 1) % config.decoder_sparse_step == 0):
             self.mlp = Qwen2MoeSparseMoeBlock(config=config,
-                                              quant_config=quant_config)
+                                              quant_config=quant_config,
+                                              prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen2MoeMLP(
                 hidden_size=config.hidden_size,
@@ -111,6 +111,7 @@ class CudaPlatformBase(Platform):
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
+        compilation_config = vllm_config.compilation_config

         if parallel_config.worker_cls == "auto":
             if scheduler_config.is_multi_step:
@@ -150,6 +151,14 @@ class CudaPlatformBase(Platform):
                 "FlashMLA: Forcing kv cache block size to 64 since this"
                 " is currently the only block size supported by the kernel.")

+        if (parallel_config.data_parallel_size > 1
+                and compilation_config.use_cudagraph):
+            logger.info(
+                "Data Parallel: Forcing enforce eager to be True since DP is "
+                "currently not supported with CUDA Graphs.")
+            vllm_config.model_config.enforce_eager = True
+            compilation_config.use_cudagraph = False
+
     @classmethod
     def get_current_memory_usage(cls,
                                  device: Optional[torch.types.Device] = None
@@ -2196,8 +2196,8 @@ def bind_kv_cache(
     from vllm.model_executor.models.utils import extract_layer_index
     layer_need_kv_cache = [
         layer_name for layer_name in ctx
-        if ctx[layer_name].attn_type in (AttentionType.DECODER,
-                                         AttentionType.ENCODER_DECODER)
+        if (hasattr(ctx[layer_name], 'attn_type') and ctx[layer_name].attn_type
+            in (AttentionType.DECODER, AttentionType.ENCODER_DECODER))
     ]
     layer_index_sorted = sorted(
         set(
@@ -149,7 +149,6 @@ class EngineCore:
         if not self.scheduler.has_unfinished_requests():
             return EngineCoreOutputs(
                 outputs=[], scheduler_stats=self.scheduler.make_stats())
-
         scheduler_output = self.scheduler.schedule()
         output = self.model_executor.execute_model(scheduler_output)
         engine_core_outputs = self.scheduler.update_from_output(
@@ -17,6 +17,7 @@ from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
@@ -1357,7 +1358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         """
         Initialize KV cache based on `kv_cache_config`.
         Args:
             kv_cache_config: Configuration for the KV cache, including the KV
                 cache size of each layer
         """
         if len(kv_cache_config.groups) > 1:
@@ -1389,10 +1390,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):

     def get_kv_cache_spec(self) -> KVCacheSpec:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
         Attention module in the static forward context.
         Returns:
             KVCacheSpec: A dictionary mapping layer names to their KV cache
                 format. Layers that do not need KV cache are not included.
         """

@@ -1400,6 +1401,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         block_size = self.vllm_config.cache_config.block_size
         kv_cache_spec: KVCacheSpec = {}
         for layer_name, attn_module in forward_ctx.items():
+            if isinstance(attn_module, FusedMoE):
+                continue
+
             # TODO: Support other attention modules, e.g., sliding window,
             # cross-attention, MLA.
             assert isinstance(attn_module, Attention)
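Note: taken together, the attn_layers -> no_compile_layers rename and the new FusedMoE registration mean the static forward context is no longer attention-only: both Attention and FusedMoE layers register themselves under their prefix, the custom ops (unified_attention*, moe_forward) resolve their layer by name at run time, and attention-only consumers such as get_kv_cache_spec filter by type. A toy sketch of that lookup pattern (toy classes, not the vLLM types):

# Toy illustration of the registry pattern used by the custom ops in this PR.
class ToyAttention:
    def forward(self, x):
        return x


class ToyFusedMoE:
    def forward_impl(self, x):
        return x * 2


no_compile_layers = {
    "model.layers.0.self_attn.attn": ToyAttention(),
    "model.layers.0.mlp.experts": ToyFusedMoE(),
}

# A custom op resolves its layer by the name it was registered under...
layer = no_compile_layers["model.layers.0.mlp.experts"]
out = layer.forward_impl(3)

# ...while KV-cache bookkeeping walks the same registry and skips MoE entries.
attn_only = {name: m for name, m in no_compile_layers.items()
             if isinstance(m, ToyAttention)}
assert list(attn_only) == ["model.layers.0.self_attn.attn"]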