[Misc] Add qwen2.5-vl BNB support (#12944)

Isotr0py 2025-02-08 20:24:47 +08:00 committed by GitHub
parent 256a2d29dc
commit 4c8dd12ef3

@@ -40,7 +40,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
@@ -207,11 +207,12 @@ class Qwen2_5_VisionAttention(nn.Module):
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
-            num_heads, world_size)
+            num_heads, self.tp_size)

         self.qkv = ColumnParallelLinear(input_size=embed_dim,
                                         output_size=3 * projection_size,
@@ -231,6 +232,29 @@ class Qwen2_5_VisionAttention(nn.Module):
                 f"Qwen2.5-VL does not support {self.attn_backend} backend now."
             )

+    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        # [s, b, 3 * head * head_dim]
+        seq_len, bs, _ = qkv.shape
+        if self.tp_size > 1:
+            qkv = tensor_model_parallel_all_gather(qkv)
+
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
+        q, k, v = qkv.chunk(3, dim=2)
+
+        # 3 * [s, b, head * head_dim]
+        if self.tp_size > 1:
+            splitter = partial(dist_utils.split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
+
+        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
+        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
+                     self.hidden_size_per_attention_head)
+        q, k, v = (x.view(*new_shape) for x in (q, k, v))
+        return q, k, v
+
     def forward(
         self,
         x: torch.Tensor,
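For reference, the new split_qkv path all-gathers the fused QKV projection across tensor-parallel ranks, chunks the full tensor into q, k and v, and then re-splits each one along the head dimension so every rank keeps only its own heads. Below is a minimal single-process sketch of that shape bookkeeping with made-up sizes; the real method uses tensor_model_parallel_all_gather and dist_utils.split_tensor_along_last_dim across ranks rather than the local chunk calls shown here.

# Minimal shape-only sketch of split_qkv (hypothetical sizes, single process).
import torch

seq_len, bs = 4, 2
tp_size, tp_rank = 2, 0
num_heads, head_dim = 8, 16
heads_per_rank = num_heads // tp_size  # num_attention_heads_per_partition

# Fused projection output for all heads after the (simulated) all-gather:
# [s, b, 3 * head * head_dim], laid out as [Q | K | V] along the last dim.
qkv = torch.randn(seq_len, bs, 3 * num_heads * head_dim)

# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
q, k, v = qkv.chunk(3, dim=2)

# Each rank keeps only its slice of the head dimension
# (local stand-in for dist_utils.split_tensor_along_last_dim).
q, k, v = (x.chunk(tp_size, dim=-1)[tp_rank] for x in (q, k, v))

# 3 * [s, b, heads_per_rank * head_dim] -> 3 * [s, b, heads_per_rank, head_dim]
new_shape = (seq_len, bs, heads_per_rank, head_dim)
q, k, v = (x.view(*new_shape) for x in (q, k, v))
assert q.shape == (4, 2, 4, 16)

Working on the gathered [Q | K | V] layout lets the fused weight stay in its checkpoint order, at the cost of one extra all-gather per vision attention layer when tp_size > 1.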
@@ -240,15 +264,8 @@
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)

-        # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
-        new_x_shape = x.size()[:-1] + (
-            self.num_attention_heads_per_partition,
-            3 * self.hidden_size_per_attention_head,
-        )
-        x = x.view(*new_x_shape)
-
-        # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
-        q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
+        q, k, v = self.split_qkv(x)
         batch_size = q.shape[1]

         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
@@ -665,24 +682,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if name.endswith("qkv.weight"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size,
-                                                       visual_embed_dim)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif name.endswith("qkv.bias"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
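The deleted else-branch used to rewrite the checkpoint's fused qkv tensors from the stacked [Q; K; V] layout into a per-head interleaved layout, which the old forward() reshape relied on. Since split_qkv now chunks the full [Q | K | V] projection output directly, the weights can be loaded in their native layout; that is presumably what enables bitsandbytes (BNB) checkpoints here, as packed pre-quantized weights cannot be viewed and transposed this way at load time. A small sketch, with hypothetical sizes, of what the removed reordering did:

# Sketch of the removed qkv.weight reordering (hypothetical sizes).
import torch

num_heads, head_size, embed_dim = 2, 3, 6
rows = 3 * num_heads * head_size
w = torch.arange(rows * embed_dim, dtype=torch.float32).view(rows, embed_dim)
# Checkpoint layout: rows stacked as [Q; K; V].

reordered = (w.view(3, num_heads, head_size, embed_dim)
              .transpose(0, 1)
              .reshape(-1, embed_dim))
# Rows are now grouped per head: [q0; k0; v0; q1; k1; v1].

assert torch.equal(reordered[:head_size], w[:head_size])  # head 0's Q rows
k_start = num_heads * head_size  # K block begins here in the checkpoint
assert torch.equal(reordered[head_size:2 * head_size],
                   w[k_start:k_start + head_size])         # head 0's K rows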