[Bugfix] Handle process_weights_after_loading for QKVCrossParallelLinear (#15328)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2025-04-09 01:02:23 +08:00 committed by GitHub
parent 4ebc0b9640
commit 40b4284fe3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 33 additions and 6 deletions

View File

@ -1353,6 +1353,7 @@ class QKVCrossParallelLinear(LinearBase):
prefix=f"{prefix}.kv_proj_encoder") prefix=f"{prefix}.kv_proj_encoder")
# `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1. # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
self.q_size = self.q_proj_decoder.output_size_per_partition
self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
if bias: if bias:
@ -1364,20 +1365,31 @@ class QKVCrossParallelLinear(LinearBase):
else: else:
self.bias = None self.bias = None
def process_weights_after_loading(self):
for layer in self.proj.values():
if self.quant_method is not None:
self.quant_method.process_weights_after_loading(layer)
@property @property
def q_proj_decoder(self) -> ColumnParallelLinear: def q_proj_decoder(self) -> ColumnParallelLinear:
layer = self.proj["q_proj_decoder"] layer = self.proj["q_proj_decoder"]
for name, param in self.named_parameters(): for name, param in self.named_parameters():
target_param = getattr(layer, name) target_param = getattr(layer, name, None)
self.sync_weight_attrs(param, target_param, mode="q_proj_decoder") if target_param is not None:
self.sync_weight_attrs(param,
target_param,
mode="q_proj_decoder")
return layer return layer
@property @property
def kv_proj_encoder(self) -> QKVParallelLinear: def kv_proj_encoder(self) -> QKVParallelLinear:
layer = self.proj["kv_proj_encoder"] layer = self.proj["kv_proj_encoder"]
for name, param in self.named_parameters(): for name, param in self.named_parameters():
target_param = getattr(layer, name) target_param = getattr(layer, name, None)
self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder") if target_param is not None:
self.sync_weight_attrs(param,
target_param,
mode="kv_proj_encoder")
return layer return layer
def sync_weight_attrs( def sync_weight_attrs(
@ -1466,11 +1478,14 @@ class QKVCrossParallelLinear(LinearBase):
if loaded_shard_id == "q" else self.kv_proj_encoder) if loaded_shard_id == "q" else self.kv_proj_encoder)
target_param = self.select_proj_params(layer, param) target_param = self.select_proj_params(layer, param)
shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else () shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
layer.weight_loader(target_param, loaded_weight, *shard_id_args) if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
else:
layer.weight_loader(target_param, loaded_weight, *shard_id_args)
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"in_features={self.input_size}" s = f"in_features={self.input_size}"
s += f", q_size={self.q_proj_decoder.output_size_per_partition}" s += f", q_size={self.q_size}"
s += f", kv_size={self.kv_size}" s += f", kv_size={self.kv_size}"
s += f", bias={self.bias is not None}" s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}" s += f", tp_size={get_tensor_model_parallel_world_size()}"

View File

@ -254,6 +254,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_loader=weight_loader, weight_loader=weight_loader,
) )
scale[:] = torch.finfo(torch.float32).min scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"})
layer.register_parameter("weight_scale", scale) layer.register_parameter("weight_scale", scale)
else: else:
assert self.quant_config.activation_scheme == "dynamic" assert self.quant_config.activation_scheme == "dynamic"
@ -268,6 +269,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_loader=weight_loader, weight_loader=weight_loader,
) )
scale[:] = torch.finfo(torch.float32).min scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"})
# The weight_scale_inv name is intentional for deepseekv3 # The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale) layer.register_parameter("weight_scale_inv", scale)
@ -278,6 +280,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_loader=weight_loader) weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "input_scale"})
layer.register_parameter("input_scale", scale) layer.register_parameter("input_scale", scale)
else: else:
layer.register_parameter("input_scale", None) layer.register_parameter("input_scale", None)

View File

@ -33,11 +33,15 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVCrossParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
# yapf: enable
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase) QuantizeMethodBase)
from vllm.model_executor.model_loader.tensorizer import ( from vllm.model_executor.model_loader.tensorizer import (
@ -160,6 +164,11 @@ def _initialize_model(
def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig, def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
target_device: torch.device) -> None: target_device: torch.device) -> None:
for _, module in model.named_modules(): for _, module in model.named_modules():
if isinstance(module, QKVCrossParallelLinear):
# NOTE(Isotr0py): special case for cross QKV layer because
# q and kv proj aren't registered as submodules intentionally
module.process_weights_after_loading()
continue
quant_method = getattr(module, "quant_method", None) quant_method = getattr(module, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase): if isinstance(quant_method, QuantizeMethodBase):
# When quant methods need to process weights after loading # When quant methods need to process weights after loading