diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 1ae57407..21035a9e 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -1353,6 +1353,7 @@ class QKVCrossParallelLinear(LinearBase):
             prefix=f"{prefix}.kv_proj_encoder")
 
         # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
+        self.q_size = self.q_proj_decoder.output_size_per_partition
         self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
 
         if bias:
@@ -1364,20 +1365,31 @@ class QKVCrossParallelLinear(LinearBase):
         else:
             self.bias = None
 
+    def process_weights_after_loading(self):
+        for layer in self.proj.values():
+            if self.quant_method is not None:
+                self.quant_method.process_weights_after_loading(layer)
+
     @property
     def q_proj_decoder(self) -> ColumnParallelLinear:
         layer = self.proj["q_proj_decoder"]
         for name, param in self.named_parameters():
-            target_param = getattr(layer, name)
-            self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
+            target_param = getattr(layer, name, None)
+            if target_param is not None:
+                self.sync_weight_attrs(param,
+                                       target_param,
+                                       mode="q_proj_decoder")
         return layer
 
     @property
     def kv_proj_encoder(self) -> QKVParallelLinear:
         layer = self.proj["kv_proj_encoder"]
         for name, param in self.named_parameters():
-            target_param = getattr(layer, name)
-            self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
+            target_param = getattr(layer, name, None)
+            if target_param is not None:
+                self.sync_weight_attrs(param,
+                                       target_param,
+                                       mode="kv_proj_encoder")
         return layer
 
     def sync_weight_attrs(
@@ -1466,11 +1478,14 @@ class QKVCrossParallelLinear(LinearBase):
                  if loaded_shard_id == "q" else self.kv_proj_encoder)
         target_param = self.select_proj_params(layer, param)
         shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
-        layer.weight_loader(target_param, loaded_weight, *shard_id_args)
+        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
+            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
+        else:
+            layer.weight_loader(target_param, loaded_weight, *shard_id_args)
 
     def extra_repr(self) -> str:
         s = f"in_features={self.input_size}"
-        s += f", q_size={self.q_proj_decoder.output_size_per_partition}"
+        s += f", q_size={self.q_size}"
         s += f", kv_size={self.kv_size}"
         s += f", bias={self.bias is not None}"
         s += f", tp_size={get_tensor_model_parallel_world_size()}"
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 512d6449..b7327f47 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -254,6 +254,7 @@ class Fp8LinearMethod(LinearMethodBase):
                     weight_loader=weight_loader,
                 )
                 scale[:] = torch.finfo(torch.float32).min
+                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                 layer.register_parameter("weight_scale", scale)
             else:
                 assert self.quant_config.activation_scheme == "dynamic"
@@ -268,6 +269,7 @@ class Fp8LinearMethod(LinearMethodBase):
                     weight_loader=weight_loader,
                 )
                 scale[:] = torch.finfo(torch.float32).min
+                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                 # The weight_scale_inv name is intentional for deepseekv3
                 layer.register_parameter("weight_scale_inv", scale)
 
@@ -278,6 +280,7 @@ class Fp8LinearMethod(LinearMethodBase):
                                              weight_loader=weight_loader)
 
             scale[:] = torch.finfo(torch.float32).min
+            set_weight_attrs(scale, {"scale_type": "input_scale"})
             layer.register_parameter("input_scale", scale)
         else:
             layer.register_parameter("input_scale", None)
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 7e434388..03934ba0 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -33,11 +33,15 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.model_executor.layers.linear import (LinearBase,
                                                MergedColumnParallelLinear,
+                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
+# yapf: enable
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizeMethodBase)
 from vllm.model_executor.model_loader.tensorizer import (
@@ -160,6 +164,11 @@ def _initialize_model(
 def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
                                    target_device: torch.device) -> None:
     for _, module in model.named_modules():
+        if isinstance(module, QKVCrossParallelLinear):
+            # NOTE(Isotr0py): special case for cross QKV layer because
+            # q and kv proj aren't registered as submodules intentionally
+            module.process_weights_after_loading()
+            continue
         quant_method = getattr(module, "quant_method", None)
         if isinstance(quant_method, QuantizeMethodBase):
             # When quant methods need to process weights after loading
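
A note on the v1/v2 loader split in `weight_loader` above: vLLM keeps a
`WEIGHT_LOADER_V2_SUPPORTED` list of quant-method class names, and a layer
routes shards through `weight_loader_v2` only when the active quant method is
on that list. A minimal sketch of that dispatch, where the `load_shard` helper
and its arguments are hypothetical and only the membership check mirrors the
patch:

    # Sketch only: the WEIGHT_LOADER_V2_SUPPORTED membership test is the
    # pattern from the diff; the helper and names below are illustrative.
    WEIGHT_LOADER_V2_SUPPORTED = ["Fp8LinearMethod"]

    def load_shard(layer, quant_method, target_param, loaded_weight,
                   *shard_id_args):
        # v2-capable quant methods expect vLLM's newer parameter classes
        # and take the v2 loader; everything else stays on the legacy path.
        if quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            layer.weight_loader_v2(target_param, loaded_weight,
                                   *shard_id_args)
        else:
            layer.weight_loader(target_param, loaded_weight, *shard_id_args)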
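
A note on the loader special case: `QKVCrossParallelLinear` holds its two
projections in a plain dict (`self.proj`) rather than registering them as
submodules, so `model.named_modules()` never reaches them and the generic
per-module pass in `_process_weights_after_loading` would skip their post-load
processing. A minimal runnable sketch of that traversal gap, where the
`CrossQKV` class and layer sizes are made up for illustration:

    from torch import nn

    class CrossQKV(nn.Module):
        def __init__(self):
            super().__init__()
            # A plain dict, mirroring `self.proj` in QKVCrossParallelLinear:
            # these entries are deliberately NOT registered as submodules.
            self.proj = {
                "q_proj_decoder": nn.Linear(8, 8),
                "kv_proj_encoder": nn.Linear(8, 16),
            }

    m = CrossQKV()
    # Prints [''] only: named_modules() recurses through registered
    # submodules, so the dict-held projections are invisible to it.
    print([name for name, _ in m.named_modules()])

Hence the explicit `isinstance(module, QKVCrossParallelLinear)` branch in the
loader, which calls the layer's own `process_weights_after_loading()` to walk
`self.proj.values()` by hand.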