[Bugfix] Handle process_weights_after_loading for QKVCrossParallelLinear (#15328)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in: parent 4ebc0b9640, commit 40b4284fe3
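The underlying issue: the model loader's generic post-loading pass walks named_modules() and calls each module's quant method (see the last hunk below), but QKVCrossParallelLinear intentionally keeps its q/kv projection layers outside the submodule tree, so those layers were never post-processed. The sketch below illustrates the flow this commit establishes, using simplified stand-in names (QKVCrossSketch, run_post_loading) rather than the real vLLM classes:

import torch.nn as nn


class QKVCrossSketch(nn.Module):
    """Stand-in for a layer whose inner projections are not submodules."""

    def __init__(self, proj: dict, quant_method=None):
        super().__init__()
        self.proj = proj  # plain dict: contents are invisible to modules()
        self.quant_method = quant_method

    def process_weights_after_loading(self):
        # Forward the post-loading hook to each hidden projection layer.
        for layer in self.proj.values():
            if self.quant_method is not None:
                self.quant_method.process_weights_after_loading(layer)


def run_post_loading(model: nn.Module) -> None:
    for module in model.modules():
        if isinstance(module, QKVCrossSketch):
            # Special case: call the layer's own hook, skip the generic path.
            module.process_weights_after_loading()
            continue
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)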
@@ -1353,6 +1353,7 @@ class QKVCrossParallelLinear(LinearBase):
                               prefix=f"{prefix}.kv_proj_encoder")
 
         # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
+        self.q_size = self.q_proj_decoder.output_size_per_partition
         self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
 
         if bias:
@@ -1364,20 +1365,31 @@ class QKVCrossParallelLinear(LinearBase):
         else:
             self.bias = None
 
+    def process_weights_after_loading(self):
+        for layer in self.proj.values():
+            if self.quant_method is not None:
+                self.quant_method.process_weights_after_loading(layer)
+
     @property
     def q_proj_decoder(self) -> ColumnParallelLinear:
         layer = self.proj["q_proj_decoder"]
         for name, param in self.named_parameters():
-            target_param = getattr(layer, name)
-            self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
+            target_param = getattr(layer, name, None)
+            if target_param is not None:
+                self.sync_weight_attrs(param,
+                                       target_param,
+                                       mode="q_proj_decoder")
         return layer
 
     @property
     def kv_proj_encoder(self) -> QKVParallelLinear:
         layer = self.proj["kv_proj_encoder"]
         for name, param in self.named_parameters():
-            target_param = getattr(layer, name)
-            self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
+            target_param = getattr(layer, name, None)
+            if target_param is not None:
+                self.sync_weight_attrs(param,
+                                       target_param,
+                                       mode="kv_proj_encoder")
         return layer
 
     def sync_weight_attrs(
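A note on the getattr(layer, name, None) guard introduced above: presumably the wrapper's placeholder parameters and the underlying projection layers' parameters can diverge once quantization methods post-process or re-register weights, so the attribute sync now skips names the target layer no longer has. A hedged illustration of that pattern, with a hypothetical copy_attrs callback standing in for sync_weight_attrs:

import torch.nn as nn


def sync_params_if_present(src: nn.Module, dst: nn.Module, copy_attrs) -> None:
    """Copy weight-loading metadata only for parameters both modules share."""
    for name, src_param in src.named_parameters():
        dst_param = getattr(dst, name, None)  # missing names yield None
        if dst_param is not None:
            copy_attrs(src_param, dst_param)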
@@ -1466,11 +1478,14 @@ class QKVCrossParallelLinear(LinearBase):
                  if loaded_shard_id == "q" else self.kv_proj_encoder)
         target_param = self.select_proj_params(layer, param)
         shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
-        layer.weight_loader(target_param, loaded_weight, *shard_id_args)
+        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
+            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
+        else:
+            layer.weight_loader(target_param, loaded_weight, *shard_id_args)
 
     def extra_repr(self) -> str:
         s = f"in_features={self.input_size}"
-        s += f", q_size={self.q_proj_decoder.output_size_per_partition}"
+        s += f", q_size={self.q_size}"
         s += f", kv_size={self.kv_size}"
         s += f", bias={self.bias is not None}"
         s += f", tp_size={get_tensor_model_parallel_world_size()}"
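The weight-loading branch above mirrors what the regular linear layers do: if the class name of the layer's quant method appears in WEIGHT_LOADER_V2_SUPPORTED, the v2 loader entry point is used, otherwise the legacy one. A hypothetical helper expressing the same dispatch (it assumes the layer exposes quant_method, weight_loader and weight_loader_v2, as in the hunk above):

def pick_weight_loader(layer, supported_v2_names):
    # Choose the loader entry point based on the quant method's class name.
    name = layer.quant_method.__class__.__name__
    return (layer.weight_loader_v2
            if name in supported_v2_names else layer.weight_loader)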
@@ -254,6 +254,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 weight_loader=weight_loader,
             )
             scale[:] = torch.finfo(torch.float32).min
+            set_weight_attrs(scale, {"scale_type": "weight_scale"})
             layer.register_parameter("weight_scale", scale)
         else:
             assert self.quant_config.activation_scheme == "dynamic"
@@ -268,6 +269,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 weight_loader=weight_loader,
             )
             scale[:] = torch.finfo(torch.float32).min
+            set_weight_attrs(scale, {"scale_type": "weight_scale"})
             # The weight_scale_inv name is intentional for deepseekv3
             layer.register_parameter("weight_scale_inv", scale)
 
@@ -278,6 +280,7 @@ class Fp8LinearMethod(LinearMethodBase):
                                            weight_loader=weight_loader)
 
             scale[:] = torch.finfo(torch.float32).min
+            set_weight_attrs(scale, {"scale_type": "input_scale"})
             layer.register_parameter("input_scale", scale)
         else:
             layer.register_parameter("input_scale", None)
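These three Fp8LinearMethod hunks tag each scale parameter with a scale_type attribute via vLLM's set_weight_attrs, which attaches the given key/value pairs to the parameter object. Code that later needs to tell weight scales from activation scales can branch on that tag; a hedged sketch of such a consumer (the helper name is hypothetical, the attribute values are the ones used in the diff):

def describe_scales(layer):
    """Map a layer's parameter names to their fp8 scale_type tag, if any."""
    tags = {}
    for name, param in layer.named_parameters():
        scale_type = getattr(param, "scale_type", None)
        if scale_type is not None:
            tags[name] = scale_type  # "weight_scale" or "input_scale"
    return tags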
@@ -33,11 +33,15 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.model_executor.layers.linear import (LinearBase,
                                                MergedColumnParallelLinear,
+                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
+# yapf: enable
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizeMethodBase)
 from vllm.model_executor.model_loader.tensorizer import (
@@ -160,6 +164,11 @@ def _initialize_model(
 def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
                                    target_device: torch.device) -> None:
     for _, module in model.named_modules():
+        if isinstance(module, QKVCrossParallelLinear):
+            # NOTE(Isotr0py): special case for cross QKV layer because
+            # q and kv proj aren't registered as submodules intentionally
+            module.process_weights_after_loading()
+            continue
         quant_method = getattr(module, "quant_method", None)
         if isinstance(quant_method, QuantizeMethodBase):
             # When quant methods need to process weights after loading