[Bugfix] Handle process_weights_after_loading
for QKVCrossParallelLinear
(#15328)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
4ebc0b9640
commit
40b4284fe3
@ -1353,6 +1353,7 @@ class QKVCrossParallelLinear(LinearBase):
|
||||
prefix=f"{prefix}.kv_proj_encoder")
|
||||
|
||||
# `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
|
||||
self.q_size = self.q_proj_decoder.output_size_per_partition
|
||||
self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
|
||||
|
||||
if bias:
|
||||
@ -1364,20 +1365,31 @@ class QKVCrossParallelLinear(LinearBase):
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def process_weights_after_loading(self):
|
||||
for layer in self.proj.values():
|
||||
if self.quant_method is not None:
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
|
||||
@property
|
||||
def q_proj_decoder(self) -> ColumnParallelLinear:
|
||||
layer = self.proj["q_proj_decoder"]
|
||||
for name, param in self.named_parameters():
|
||||
target_param = getattr(layer, name)
|
||||
self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
|
||||
target_param = getattr(layer, name, None)
|
||||
if target_param is not None:
|
||||
self.sync_weight_attrs(param,
|
||||
target_param,
|
||||
mode="q_proj_decoder")
|
||||
return layer
|
||||
|
||||
@property
|
||||
def kv_proj_encoder(self) -> QKVParallelLinear:
|
||||
layer = self.proj["kv_proj_encoder"]
|
||||
for name, param in self.named_parameters():
|
||||
target_param = getattr(layer, name)
|
||||
self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
|
||||
target_param = getattr(layer, name, None)
|
||||
if target_param is not None:
|
||||
self.sync_weight_attrs(param,
|
||||
target_param,
|
||||
mode="kv_proj_encoder")
|
||||
return layer
|
||||
|
||||
def sync_weight_attrs(
|
||||
@ -1466,11 +1478,14 @@ class QKVCrossParallelLinear(LinearBase):
|
||||
if loaded_shard_id == "q" else self.kv_proj_encoder)
|
||||
target_param = self.select_proj_params(layer, param)
|
||||
shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
|
||||
layer.weight_loader(target_param, loaded_weight, *shard_id_args)
|
||||
if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
|
||||
layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
|
||||
else:
|
||||
layer.weight_loader(target_param, loaded_weight, *shard_id_args)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"in_features={self.input_size}"
|
||||
s += f", q_size={self.q_proj_decoder.output_size_per_partition}"
|
||||
s += f", q_size={self.q_size}"
|
||||
s += f", kv_size={self.kv_size}"
|
||||
s += f", bias={self.bias is not None}"
|
||||
s += f", tp_size={get_tensor_model_parallel_world_size()}"
|
||||
|
@ -254,6 +254,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
layer.register_parameter("weight_scale", scale)
|
||||
else:
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
@ -268,6 +269,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
# The weight_scale_inv name is intentional for deepseekv3
|
||||
layer.register_parameter("weight_scale_inv", scale)
|
||||
|
||||
@ -278,6 +280,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
weight_loader=weight_loader)
|
||||
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
set_weight_attrs(scale, {"scale_type": "input_scale"})
|
||||
layer.register_parameter("input_scale", scale)
|
||||
else:
|
||||
layer.register_parameter("input_scale", None)
|
||||
|
@ -33,11 +33,15 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.envs import VLLM_USE_MODELSCOPE
|
||||
from vllm.logger import init_logger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVCrossParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
# yapf: enable
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase)
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
@ -160,6 +164,11 @@ def _initialize_model(
|
||||
def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
|
||||
target_device: torch.device) -> None:
|
||||
for _, module in model.named_modules():
|
||||
if isinstance(module, QKVCrossParallelLinear):
|
||||
# NOTE(Isotr0py): special case for cross QKV layer because
|
||||
# q and kv proj aren't registered as submodules intentionally
|
||||
module.process_weights_after_loading()
|
||||
continue
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
if isinstance(quant_method, QuantizeMethodBase):
|
||||
# When quant methods need to process weights after loading
|
||||
|
Loading…
x
Reference in New Issue
Block a user