[Misc] Add qwen2.5-vl BNB support (#12944)
parent 256a2d29dc
commit 4c8dd12ef3
@@ -40,7 +40,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
@@ -207,11 +207,12 @@ class Qwen2_5_VisionAttention(nn.Module):
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
-            num_heads, world_size)
+            num_heads, self.tp_size)
 
         self.qkv = ColumnParallelLinear(input_size=embed_dim,
                                         output_size=3 * projection_size,
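Note: dist_utils.divide here is the Megatron-style checked integer division, so both the per-head and per-partition sizes fail loudly when a dimension does not split evenly across ranks. A rough sketch of its semantics (an illustrative reimplementation, not vLLM's actual source):

def divide(numerator: int, denominator: int) -> int:
    # Tensor-parallel sharding requires exact splits, so assert first.
    assert numerator % denominator == 0, (
        f"{numerator} is not evenly divisible by {denominator}")
    return numerator // denominator

With illustrative numbers (projection_size=1280, num_heads=16, tp_size=2), each head is divide(1280, 16) = 80 wide and each rank owns divide(16, 2) = 8 heads.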
@@ -231,6 +232,29 @@ class Qwen2_5_VisionAttention(nn.Module):
                 f"Qwen2.5-VL does not support {self.attn_backend} backend now."
             )
 
+    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        # [s, b, 3 * head * head_dim]
+        seq_len, bs, _ = qkv.shape
+        if self.tp_size > 1:
+            qkv = tensor_model_parallel_all_gather(qkv)
+
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
+        q, k, v = qkv.chunk(3, dim=2)
+
+        # 3 * [s, b, head * head_dim]
+        if self.tp_size > 1:
+            splitter = partial(dist_utils.split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
+
+        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
+        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
+                     self.hidden_size_per_attention_head)
+        q, k, v = (x.view(*new_shape) for x in (q, k, v))
+        return q, k, v
+
     def forward(
         self,
         x: torch.Tensor,
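Note: the new split_qkv undoes the mismatch between ColumnParallelLinear's sharding and the checkpoint's [Q | K | V] layout. Each rank's slice of the fused output no longer falls on Q/K/V boundaries, so the shards are all-gathered (vLLM's tensor_model_parallel_all_gather concatenates along the last dim), re-split into full Q, K and V, and each rank then keeps only its head-aligned slice. A single-process sketch of that dataflow, emulating the collectives with plain tensor ops (sizes and variable names are illustrative):

import torch

tp_size, seq_len, bs = 2, 4, 1
num_heads, head_dim = 8, 16
hidden = num_heads * head_dim        # width of each of Q, K, V

# Full fused projection output, last dim laid out [Q | K | V], matching the
# unpermuted checkpoint ordering of the qkv weight.
full_qkv = torch.randn(seq_len, bs, 3 * hidden)

# ColumnParallelLinear hands each rank one contiguous slice of the output
# dim; the slice boundaries do NOT line up with the Q/K/V boundaries.
shards = list(full_qkv.chunk(tp_size, dim=-1))

for tp_rank in range(tp_size):
    # tensor_model_parallel_all_gather(qkv), emulated by concatenating the
    # rank shards back together along the last dim.
    qkv = torch.cat(shards, dim=-1)

    # Recover the full Q, K, V, each [s, b, num_heads * head_dim] ...
    q, k, v = qkv.chunk(3, dim=2)

    # ... then keep this rank's head-aligned slice of each, as
    # split_tensor_along_last_dim does in the real code.
    q, k, v = (t.chunk(tp_size, dim=-1)[tp_rank] for t in (q, k, v))

    heads_per_rank = num_heads // tp_size
    q = q.view(seq_len, bs, heads_per_rank, head_dim)

    # Rank r ends up with heads [r * heads_per_rank, (r + 1) * heads_per_rank).
    shard = hidden // tp_size
    expected = full_qkv[..., tp_rank * shard:(tp_rank + 1) * shard]
    assert torch.equal(q.flatten(2), expected)

The extra all-gather costs some communication, but it means the stored weight layout is never touched at load time, which appears to be what lets pre-quantized checkpoints load as-is.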
@@ -240,15 +264,8 @@ class Qwen2_5_VisionAttention(nn.Module):
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
 
-        # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
-        new_x_shape = x.size()[:-1] + (
-            self.num_attention_heads_per_partition,
-            3 * self.hidden_size_per_attention_head,
-        )
-        x = x.view(*new_x_shape)
-
-        # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
-        q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
+        q, k, v = self.split_qkv(x)
         batch_size = q.shape[1]
 
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
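Note: the deleted view/split sequence was only correct because the old loader (next hunk) permuted the qkv weight rows into per-head (q, k, v) triples, making the rank-local output viewable as [s, b, head, 3 * head_dim]; with the weight left in checkpoint order that assumption no longer holds, hence the split_qkv call. The rearrange that follows is just a sequence/batch transpose; a toy check (shapes illustrative):

import torch
from einops import rearrange

t = torch.randn(4, 1, 8, 16)   # [s, b, head, head_dim]
assert rearrange(t, "s b ... -> b s ...").shape == (1, 4, 8, 16)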
@@ -665,24 +682,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if name.endswith("qkv.weight"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size,
-                                                       visual_embed_dim)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif name.endswith("qkv.bias"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
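Note: the deleted branch was the load-time counterpart of the old forward path: it permuted the checkpoint's [Q; K; V]-stacked qkv weight into head-interleaved (q, k, v) row triples. A toy demonstration of exactly that permutation in isolation (sizes and names are illustrative):

import torch

num_heads, head_size = 2, 3
embed_dim = num_heads * head_size

# Fused qkv weight in checkpoint order: rows are [Q; K; V].
w = torch.arange(3 * embed_dim * embed_dim, dtype=torch.float32)
w = w.view(3 * embed_dim, embed_dim)

# The removed loader logic: view -> transpose -> reshape.
permuted = (w.view(3, num_heads, head_size, embed_dim)
            .transpose(0, 1)          # (num_heads, 3, head_size, embed_dim)
            .reshape(-1, embed_dim))  # rows: [h0_q; h0_k; h0_v; h1_q; ...]

# Head 0's q rows are unchanged, but head 0's k rows now follow immediately
# instead of sitting embed_dim rows later.
assert torch.equal(permuted[:head_size], w[:head_size])
assert torch.equal(permuted[head_size:2 * head_size],
                   w[embed_dim:embed_dim + head_size])

Dropping this reshuffle is presumably the key to the BitsAndBytes support in the commit title: packed quantized tensors cannot be safely view/transpose'd the way fp16 weights can, so the weight now loads in its original layout and the reordering happens at runtime in split_qkv instead.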