Co-authored-by: Roger Wang <ywang@roblox.com>
Patrick von Platen 2024-09-11 23:41:55 +02:00 committed by GitHub
parent 775f00f81e
commit d394787e52
8 changed files with 807 additions and 9 deletions

View File

@ -247,6 +247,11 @@ Multimodal Language Models
- Image\ :sup:`E+`
- :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
-
* - :code:`PixtralForConditionalGeneration`
- Pixtral
- Image\ :sup:`+`
- :code:`mistralai/Pixtral-12B-2409`
-
* - :code:`QWenLMHeadModel`
- Qwen-VL
- Image\ :sup:`E`

View File

@ -0,0 +1,164 @@
# ruff: noqa
import argparse
from vllm import LLM
from vllm.sampling_params import SamplingParams
# This script is an offline demo for running Pixtral.
#
# If you want to run a server/client setup, please follow this code:
#
# - Server:
#
# ```bash
# vllm serve mistralai/Pixtral-12B-2409 --tokenizer_mode mistral --limit_mm_per_prompt 'image=4' --max_num_batched_tokens 16384
# ```
#
# - Client:
#
# ```bash
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
# --header 'Content-Type: application/json' \
# --header 'Authorization: Bearer token' \
# --data '{
# "model": "mistralai/Pixtral-12B-2409",
# "messages": [
# {
# "role": "user",
# "content": [
# {"type" : "text", "text": "Describe this image in detail please."},
# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
# {"type" : "text", "text": "and this one as well. Answer in French."},
# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
# ]
# }
# ]
# }'
# ```
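#
# - Client (Python): a minimal sketch using the OpenAI-compatible API exposed
#   by `vllm serve`. The `openai` package, node URL, and image URL below are
#   illustrative assumptions, not taken from this example:
#
# ```python
# from openai import OpenAI
#
# client = OpenAI(base_url="http://<your-node-url>:8000/v1", api_key="token")
# response = client.chat.completions.create(
#     model="mistralai/Pixtral-12B-2409",
#     messages=[{
#         "role": "user",
#         "content": [
#             {"type": "text", "text": "Describe this image in one sentence."},
#             {"type": "image_url",
#              "image_url": {"url": "https://picsum.photos/id/237/200/300"}},
#         ],
#     }],
# )
# print(response.choices[0].message.content)
# ```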
#
# Usage:
# python demo.py simple
# python demo.py advanced


def run_simple_demo():
model_name = "mistralai/Pixtral-12B-2409"
sampling_params = SamplingParams(max_tokens=8192)
llm = LLM(model=model_name, tokenizer_mode="mistral")
prompt = "Describe this image in one sentence."
image_url = "https://picsum.photos/id/237/200/300"
messages = [
{
"role":
"user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
],
},
]
outputs = llm.chat(messages, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)


def run_advanced_demo():
model_name = "mistralai/Pixtral-12B-2409"
max_img_per_msg = 5
max_tokens_per_img = 4096
sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)
llm = LLM(
model=model_name,
tokenizer_mode="mistral",
limit_mm_per_prompt={"image": max_img_per_msg},
max_num_batched_tokens=max_img_per_msg * max_tokens_per_img,
)
prompt = "Describe the following image."
url_1 = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png"
url_2 = "https://picsum.photos/seed/picsum/200/300"
url_3 = "https://picsum.photos/id/32/512/512"
messages = [
{
"role":
"user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": url_1
}
},
{
"type": "image_url",
"image_url": {
"url": url_2
}
},
],
},
{
"role": "assistant",
"content": "The images show nature.",
},
{
"role": "user",
"content": "More details please and answer only in French!.",
},
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": url_3
}
},
],
},
]
outputs = llm.chat(messages=messages, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)


def main():
parser = argparse.ArgumentParser(
description="Run a demo in simple or advanced mode.")
parser.add_argument(
"mode",
choices=["simple", "advanced"],
help="Specify the demo mode: 'simple' or 'advanced'",
)
args = parser.parse_args()
if args.mode == "simple":
print("Running simple demo...")
run_simple_demo()
elif args.mode == "advanced":
print("Running advanced demo...")
run_advanced_demo()


if __name__ == "__main__":
main()

View File

@ -25,7 +25,7 @@ pyzmq
msgspec
gguf == 0.9.1
importlib_metadata
mistral_common >= 1.3.4
mistral_common >= 1.4.0
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
einops # Required for Qwen2-VL.

View File

@ -0,0 +1,64 @@
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
Run `pytest tests/models/test_mistral.py`.
"""
import pytest
from vllm.sampling_params import SamplingParams
pytestmark = pytest.mark.vlm
MODELS = ["mistralai/Pixtral-12B-2409"]
@pytest.mark.skip(
reason=
"Model is too big, test passed on A100 locally but will OOM on CI machine."
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
image_urls = [
"https://picsum.photos/id/237/200/300",
"https://picsum.photos/seed/picsum/200/300"
]
expected = [
"The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.", # noqa
"The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset." # noqa
]
prompt = "Describe the image in one short sentence."
sampling_params = SamplingParams(max_tokens=512, temperature=0.0)
with vllm_runner(model, dtype=dtype,
tokenizer_mode="mistral") as vllm_model:
for i, image_url in enumerate(image_urls):
messages = [
{
"role":
"user",
"content": [{
"type": "text",
"text": prompt
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}]
},
]
outputs = vllm_model.model.chat(messages,
sampling_params=sampling_params)
assert outputs[0].outputs[0].text == expected[i]

View File

@ -148,7 +148,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return f"<|image_{current_count}|>"
if model_type == "minicpmv":
return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
"pixtral"):
# These models do not use image tokens in the prompt
return None
if model_type == "qwen":

View File

@ -92,6 +92,8 @@ _MULTIMODAL_MODELS = {
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"PixtralForConditionalGeneration": ("pixtral",
"PixtralForConditionalGeneration"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
"Qwen2VLForConditionalGeneration"),
}
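# A brief note for readers (an assumption about the registry mechanics, not
# stated in this diff): each entry maps the architecture name declared in a
# checkpoint's config to a (module, class) pair, which is resolved lazily,
# roughly like `from vllm.model_executor.models.pixtral import
# PixtralForConditionalGeneration`.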

View File

@ -0,0 +1,551 @@
import math
from array import array
from dataclasses import dataclass, fields
from itertools import tee
from typing import Iterable, List, Mapping, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk
from PIL import Image
from transformers import PretrainedConfig
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from .interfaces import SupportsMultiModal
from .utils import init_vllm_registered_model
def get_max_pixtral_image_tokens(ctx: InputContext):
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)
mm_encoder = tokenizer.instruct.mm_encoder
max_image_size = mm_encoder.mm_config.max_image_size
image_patch_size = mm_encoder.mm_config.image_patch_size
return ((max_image_size // image_patch_size)**2)
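# Worked example (a sketch assuming the released Pixtral-12B values
# max_image_size=1024 and image_patch_size=16, which are not shown in this
# diff): the maximum is (1024 // 16) ** 2 = 4096 image tokens per image,
# matching the 4096-token-per-image budget used in the offline example above.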
def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)
mm_encoder = tokenizer.instruct.mm_encoder
mm_config = ctx.model_config.multimodal_config
max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1)
    # Approximate an image whose patch grid holds roughly `seq_len` patches
    # (i.e. roughly `seq_len` image tokens).
    size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size)
image = Image.new("RGB", (size, size), color=0)
img_chunk = ImageChunk(image=image)
tokens = mm_encoder(img_chunk).tokens
token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE,
tokens)
seq_data = SequenceData(token_ids)
mm_data = {"image": max_num_images_per_request * [image]}
return seq_data, mm_data
def input_mapper_for_pixtral(ctx: InputContext,
data: object) -> MultiModalInputs:
"""Maps the input data to its MultiModalInputs (if any).
Args:
ctx: Context of the loaded model.
data: data potentially containing image/image embeddings to be mapped
to pixel_values in .forward() for a visual QWenLMHeadModel model.
Returns:
MultiModalInputs containing the stacked normalized images tensor or
image embeddings.
"""
# Early exit if we have provided an image to a language only Qwen model
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
data_list = data if isinstance(data, list) else [data]
images = []
for image_data in data_list:
image = ImageChunk(image=image_data)
encoding = tokenizer.instruct.mm_encoder(image)
image = torch.from_numpy(encoding.image).to(device="cuda",
dtype=torch.float16)
images.append(image)
return MultiModalInputs({"images": images})
def merge_multimodal_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: Optional[List[torch.Tensor]],
image_id: int) -> torch.Tensor:
text_locations = input_ids != image_id
image_locations = input_ids == image_id
seq_len = input_ids.shape[0]
N_txt = text_locations.sum().item()
_, D_txt = inputs_embeds.shape
N_img, D_img = image_features.shape
    assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
                              f"to image features dim {D_img}")
assert (seq_len == N_txt +
N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
f"{(N_txt, N_img, image_locations.sum().item())}")
inputs_embeds[image_locations, :] = image_features
return inputs_embeds
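# Shape sketch (illustrative numbers only): for input_ids containing 10 text
# tokens and 4 image-placeholder tokens, inputs_embeds is (14, D),
# image_features is (4, D), and the four rows at the image-token positions are
# overwritten in place with the projected vision features.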
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.multimodal_config = multimodal_config
dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
vision_args = {
key: value
for key, value in self.config.vision_config.to_dict().items()
if key in dataclass_fields
}
self.vision_args = VisionEncoderArgs(**vision_args)
# init MistralForCausalLM
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
self.vision_encoder = VisionTransformer(self.vision_args)
self.vision_language_adapter = VisionLanguageAdapter(
self.vision_args, dim=config.text_config.hidden_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
"""Run forward pass for pixtral.
TODO
"""
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vision_args.image_token_id)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
None,
inputs_embeds=inputs_embeds)
return hidden_states
def _parse_and_validate_image_input(
self,
images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
torch.Tensor]] = None
) -> Optional[List[torch.Tensor]]:
if images is None:
return None
if isinstance(images, torch.Tensor):
# always take last images
images = [images[-1][i] for i in range(images.size(1))]
elif isinstance(images, list):
# always take last images
images = [images[-1][i] for i in range(len(images[0]))]
return images
def _process_image_input(self,
image_input: List[torch.Tensor]) -> torch.Tensor:
return self.vision_language_adapter(self.vision_encoder(image_input))
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
return weight[0].startswith("vision_encoder")
def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
return weight[0].startswith("vision_language_adapter")
def is_vision_weights(weight: Tuple[str, torch.Tensor]):
return is_vision_encoder_weights(
weight) or is_vision_lang_adapter_weights(weight)
llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(
weights, 3)
# llm
llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights)
self.language_model.load_weights(llm_weights)
# vision encoder
vision_encoder_weights = filter(is_vision_encoder_weights,
vision_encoder_weights)
vision_encoder_dict = dict(self.vision_encoder.named_parameters())
for name, loaded_weight in vision_encoder_weights:
# cut 'vision_encoder.'
name = '.'.join(name.split(".")[1:])
param = vision_encoder_dict[name]
default_weight_loader(param, loaded_weight)
# adapter
vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights,
vision_lang_adapter_weights)
        vision_lang_adapter_dict = dict(
            self.vision_language_adapter.named_parameters())
for name, loaded_weight in vision_lang_adapter_weights:
# cut 'vision_language_adapter.'
name = '.'.join(name.split(".")[1:])
            param = vision_lang_adapter_dict[name]
default_weight_loader(param, loaded_weight)
# Vision encoder
@dataclass
class VisionEncoderArgs:
hidden_size: int
num_channels: int
image_size: int
patch_size: int
intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
rope_theta: float # for rope-2D
image_token_id: int
def _reshape_for_broadcast(freqs_cis: torch.Tensor,
x: torch.Tensor) -> torch.Tensor:
"""
freqs_cis: complex - (seq_len, head_dim / 2)
x: complex - (bsz, seq_len, head_dim / 2)
"""
ndim = x.ndim
assert ndim > 1
assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
freqs_cis.shape,
(x.shape[1], x.shape[-1]),
)
shape = [
d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
]
return freqs_cis.view(*shape)
def precompute_freqs_cis_2d(
dim: int,
height: int,
width: int,
theta: float,
) -> torch.Tensor:
"""
freqs_cis: 2D complex tensor of shape (height, width, dim // 2)
to be indexed by (height, width) position tuples
"""
# (dim / 2) frequency bases
freqs = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim))
h = torch.arange(height, device=freqs.device)
w = torch.arange(width, device=freqs.device)
freqs_h = torch.outer(h, freqs[::2]).float()
freqs_w = torch.outer(w, freqs[1::2]).float()
freqs_2d = torch.cat(
[
freqs_h[:, None, :].repeat(1, width, 1),
freqs_w[None, :, :].repeat(height, 1, 1),
],
dim=-1,
)
return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
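# A quick sanity check of the shapes (illustrative values, not part of the
# module):
#
#   freqs = precompute_freqs_cis_2d(dim=64, height=4, width=4, theta=10000.0)
#   assert freqs.shape == (4, 4, 32)
#   assert freqs.dtype == torch.complex64
#
# Indexing this table with per-patch (row, column) positions yields the rotary
# phases that apply_rotary_emb_vit applies to the ViT queries and keys.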
def apply_rotary_emb_vit(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
assert freqs_cis.dtype == torch.complex64
freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class FeedForward(nn.Module):
def __init__(self, args: VisionEncoderArgs):
super().__init__()
assert args.intermediate_size is not None
self.w1 = nn.Linear(args.hidden_size,
args.intermediate_size,
bias=False)
self.w2 = nn.Linear(args.intermediate_size,
args.hidden_size,
bias=False)
self.w3 = nn.Linear(args.hidden_size,
args.intermediate_size,
bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class Attention(nn.Module):
def __init__(self, args: VisionEncoderArgs):
super().__init__()
self.args = args
assert not args.hidden_size % args.num_attention_heads
self.n_heads = args.num_attention_heads
self.head_dim = args.hidden_size // args.num_attention_heads
self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
def forward(
self,
x: torch.Tensor,
mask: BlockDiagonalMask,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
batch, patches, _ = x.shape
q, k, v = self.wq(x), self.wk(x), self.wv(x)
q = q.reshape(batch, patches, self.n_heads, self.head_dim)
k = k.reshape(batch, patches, self.n_heads, self.head_dim)
v = v.reshape(batch, patches, self.n_heads, self.head_dim)
q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
out = memory_efficient_attention(q, k, v, attn_bias=mask)
out = out.reshape(batch, patches, self.n_heads * self.head_dim)
return self.wo(out)
class TransformerBlock(nn.Module):
def __init__(self, args: VisionEncoderArgs):
super().__init__()
self.attention = Attention(args)
self.feed_forward = FeedForward(args)
self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)
def forward(
self,
x: torch.Tensor,
mask: BlockDiagonalMask,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
r = self.attention.forward(self.attention_norm(x),
mask=mask,
freqs_cis=freqs_cis)
h = x + r
r = self.feed_forward.forward(self.ffn_norm(h))
out = h + r
return out
class Transformer(nn.Module):
def __init__(self, args: VisionEncoderArgs):
super().__init__()
self.layers = torch.nn.ModuleList()
for _ in range(args.num_hidden_layers):
self.layers.append(TransformerBlock(args))
def forward(
self,
x: torch.Tensor,
mask: BlockDiagonalMask,
freqs_cis: Optional[torch.Tensor],
) -> torch.Tensor:
for layer in self.layers:
x = layer(x, mask=mask, freqs_cis=freqs_cis)
return x
def position_meshgrid(patch_embeds_list: List[torch.Tensor]) -> torch.Tensor:
positions = torch.cat([
torch.stack(
torch.meshgrid(
torch.arange(p.shape[-2]),
torch.arange(p.shape[-1]),
indexing="ij",
),
dim=-1,
).reshape(-1, 2) for p in patch_embeds_list
])
return positions
class VisionTransformer(nn.Module):
def __init__(self, args: VisionEncoderArgs):
super().__init__()
self.args = args
self.patch_conv = nn.Conv2d(
in_channels=args.num_channels,
out_channels=args.hidden_size,
kernel_size=args.patch_size,
stride=args.patch_size,
bias=False,
)
self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
self.transformer = Transformer(args)
head_dim = self.args.hidden_size // self.args.num_attention_heads
assert head_dim % 2 == 0, "ROPE requires even head_dim"
self._freqs_cis: Optional[torch.Tensor] = None
@property
def max_patches_per_side(self) -> int:
return self.args.image_size // self.args.patch_size
@property
def device(self) -> torch.device:
return next(self.parameters()).device
@property
    def dtype(self) -> torch.dtype:
return next(self.parameters()).dtype
@property
def freqs_cis(self) -> torch.Tensor:
if self._freqs_cis is None:
self._freqs_cis = precompute_freqs_cis_2d(
dim=self.args.hidden_size // self.args.num_attention_heads,
height=self.max_patches_per_side,
width=self.max_patches_per_side,
theta=self.args.rope_theta,
)
if self._freqs_cis.device != self.device:
self._freqs_cis = self._freqs_cis.to(device=self.device)
return self._freqs_cis
def forward(
self,
images: List[torch.Tensor],
) -> torch.Tensor:
"""
Args:
images: list of N_img images of variable sizes,
each of shape (C, H, W)
Returns:
image_features: tensor of token features for
all tokens of all images of shape (N_toks, D)
"""
# pass images through initial convolution independently
patch_embeds_list = [
self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
]
# flatten to a single sequence
patch_embeds = torch.cat(
[p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
patch_embeds = self.ln_pre(patch_embeds)
# positional embeddings
positions = position_meshgrid(patch_embeds_list).to(self.device)
freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]
# pass through Transformer with a block diagonal mask delimiting images
mask = BlockDiagonalMask.from_seqlens(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)
# remove batch dimension of the single sequence
return out.squeeze(0)
class VisionLanguageAdapter(nn.Module):
def __init__(self, args: VisionEncoderArgs, dim: int):
super().__init__()
assert isinstance(args, VisionEncoderArgs)
self.w_in = nn.Linear(
args.hidden_size,
dim,
bias=True,
)
self.gelu = nn.GELU()
self.w_out = nn.Linear(dim, dim, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w_out(self.gelu(self.w_in(x)))
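# End-to-end shape sketch (assumed example sizes, not taken from this file):
# for one 512x512 RGB image with patch_size=16 and hidden_size H, patch_conv
# yields a (1, H, 32, 32) feature map, flattened into 32 * 32 = 1024 patch
# embeddings. VisionTransformer then returns a (1024, H) tensor, and
# VisionLanguageAdapter projects it to the language model's hidden size before
# merge_multimodal_embeddings splices it into the prompt embeddings.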

View File

@ -70,7 +70,7 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
if Path(model).exists():
return (Path(model) / config_name).is_file()
return file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)
return file_exists(model, config_name, revision=revision, token=token)
def get_config(
@ -205,14 +205,25 @@ def load_params_config(model, revision) -> PretrainedConfig:
config_dict["hidden_act"] = config_dict.get("activation", "silu")
config_dict["tie_word_embeddings"] = config_dict.get(
"tie_embeddings", False)
config_dict["max_seq_len"] = config_dict.get("max_seq_len", 128_000)
if config_dict["model_type"] == "transformer":
if "moe" in config_dict:
config_dict["architectures"] = ["MixtralForCausalLM"]
else:
config_dict["architectures"] = ["MistralForCausalLM"]
if config_dict.get("moe") is not None:
config_dict["architectures"] = ["MixtralForCausalLM"]
else:
config_dict["architectures"] = ["MistralForCausalLM"]
return recurse_elems(config_dict)
if config_dict.get("vision_encoder") is not None:
multimodal_config = config_dict.pop("vision_encoder")
config_dict = {
"text_config": config_dict,
"vision_config": multimodal_config
}
config_dict["architectures"] = ["PixtralForConditionalGeneration"]
config_dict["model_type"] = "pixtral"
config = recurse_elems(config_dict)
return config
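# Illustrative result (the params.json layout is assumed, not shown in this
# diff): a Pixtral-style checkpoint whose params.json contains a
# "vision_encoder" section is rewritten into roughly
#
#     {
#         "model_type": "pixtral",
#         "architectures": ["PixtralForConditionalGeneration"],
#         "text_config": {...the original text-model parameters...},
#         "vision_config": {...the former "vision_encoder" section...},
#     }
#
# before being converted to a PretrainedConfig by recurse_elems.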
def get_hf_image_processor_config(