[Misc] Add qwen2.5-vl BNB support (#12944)
This commit is contained in:
parent
256a2d29dc
commit
4c8dd12ef3
@ -40,7 +40,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
|||||||
|
|
||||||
from vllm.attention import AttentionMetadata
|
from vllm.attention import AttentionMetadata
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import parallel_state
|
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
|
||||||
from vllm.distributed import utils as dist_utils
|
from vllm.distributed import utils as dist_utils
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor import SamplingMetadata
|
from vllm.model_executor import SamplingMetadata
|
||||||
@ -207,11 +207,12 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Per attention head and per partition values.
|
# Per attention head and per partition values.
|
||||||
world_size = parallel_state.get_tensor_model_parallel_world_size()
|
self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
|
||||||
|
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
|
||||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||||
projection_size, num_heads)
|
projection_size, num_heads)
|
||||||
self.num_attention_heads_per_partition = dist_utils.divide(
|
self.num_attention_heads_per_partition = dist_utils.divide(
|
||||||
num_heads, world_size)
|
num_heads, self.tp_size)
|
||||||
|
|
||||||
self.qkv = ColumnParallelLinear(input_size=embed_dim,
|
self.qkv = ColumnParallelLinear(input_size=embed_dim,
|
||||||
output_size=3 * projection_size,
|
output_size=3 * projection_size,
|
||||||
@ -231,6 +232,29 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
|
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||||
|
# [s, b, 3 * head * head_dim]
|
||||||
|
seq_len, bs, _ = qkv.shape
|
||||||
|
if self.tp_size > 1:
|
||||||
|
qkv = tensor_model_parallel_all_gather(qkv)
|
||||||
|
|
||||||
|
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
|
||||||
|
q, k, v = qkv.chunk(3, dim=2)
|
||||||
|
|
||||||
|
# 3 * [s, b, head * head_dim]
|
||||||
|
if self.tp_size > 1:
|
||||||
|
splitter = partial(dist_utils.split_tensor_along_last_dim,
|
||||||
|
num_partitions=self.tp_size)
|
||||||
|
q = splitter(q)[self.tp_rank]
|
||||||
|
k = splitter(k)[self.tp_rank]
|
||||||
|
v = splitter(v)[self.tp_rank]
|
||||||
|
|
||||||
|
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||||
|
new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
|
||||||
|
self.hidden_size_per_attention_head)
|
||||||
|
q, k, v = (x.view(*new_shape) for x in (q, k, v))
|
||||||
|
return q, k, v
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -240,15 +264,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||||
x, _ = self.qkv(x)
|
x, _ = self.qkv(x)
|
||||||
|
|
||||||
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
|
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||||
new_x_shape = x.size()[:-1] + (
|
q, k, v = self.split_qkv(x)
|
||||||
self.num_attention_heads_per_partition,
|
|
||||||
3 * self.hidden_size_per_attention_head,
|
|
||||||
)
|
|
||||||
x = x.view(*new_x_shape)
|
|
||||||
|
|
||||||
# [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
|
|
||||||
q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
|
|
||||||
batch_size = q.shape[1]
|
batch_size = q.shape[1]
|
||||||
|
|
||||||
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
|
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
|
||||||
@ -665,24 +682,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
if name.endswith("qkv.weight"):
|
|
||||||
visual_num_heads = self.num_heads
|
|
||||||
visual_embed_dim = self.hidden_size
|
|
||||||
head_size = visual_embed_dim // visual_num_heads
|
|
||||||
loaded_weight = loaded_weight.view(3, visual_num_heads,
|
|
||||||
head_size,
|
|
||||||
visual_embed_dim)
|
|
||||||
loaded_weight = loaded_weight.transpose(0, 1)
|
|
||||||
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
|
|
||||||
elif name.endswith("qkv.bias"):
|
|
||||||
visual_num_heads = self.num_heads
|
|
||||||
visual_embed_dim = self.hidden_size
|
|
||||||
head_size = visual_embed_dim // visual_num_heads
|
|
||||||
loaded_weight = loaded_weight.view(3, visual_num_heads,
|
|
||||||
head_size)
|
|
||||||
loaded_weight = loaded_weight.transpose(0, 1)
|
|
||||||
loaded_weight = loaded_weight.reshape(-1)
|
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user