[Model] Rename MiniCPMVQwen2 to MiniCPMV2.6 (#7273)
commit 757ac70a64 (parent 6dffa4b0a6)
@@ -222,7 +222,7 @@ Vision Language Models
     -
   * - :code:`MiniCPMV`
     - MiniCPM-V
-    - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
+    - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
     -

 .. note::
@@ -22,8 +22,8 @@ def run_llava(question):
     prompt = f"USER: <image>\n{question}\nASSISTANT:"

     llm = LLM(model="llava-hf/llava-1.5-7b-hf")
-
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids


 # LLaVA-1.6/LLaVA-NeXT
@@ -31,8 +31,8 @@ def run_llava_next(question):

     prompt = f"[INST] <image>\n{question} [/INST]"
     llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf")
-
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids


 # Fuyu
@@ -40,8 +40,8 @@ def run_fuyu(question):

     prompt = f"{question}\n"
     llm = LLM(model="adept/fuyu-8b")
-
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids


 # Phi-3-Vision
@@ -59,7 +59,8 @@ def run_phi3v(question):
         trust_remote_code=True,
         max_num_seqs=5,
     )
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids


 # PaliGemma
@@ -68,8 +69,8 @@ def run_paligemma(question):
     # PaliGemma has special prompt format for VQA
     prompt = "caption en"
     llm = LLM(model="google/paligemma-3b-mix-224")
-
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids


 # Chameleon
@@ -77,7 +78,8 @@ def run_chameleon(question):

     prompt = f"{question}<image>"
     llm = LLM(model="facebook/chameleon-7b")
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids


 # MiniCPM-V
@@ -89,13 +91,26 @@ def run_minicpmv(question):
     # model_name = "HwwwH/MiniCPM-V-2"

     # 2.5
-    model_name = "openbmb/MiniCPM-Llama3-V-2_5"
+    # model_name = "openbmb/MiniCPM-Llama3-V-2_5"

+    #2.6
+    model_name = "openbmb/MiniCPM-V-2_6"
     tokenizer = AutoTokenizer.from_pretrained(model_name,
                                               trust_remote_code=True)
     llm = LLM(
         model=model_name,
         trust_remote_code=True,
     )
+    # NOTE The stop_token_ids are different for various versions of MiniCPM-V
+    # 2.0
+    # stop_token_ids = [tokenizer.eos_id]
+
+    # 2.5
+    # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
+
+    # 2.6
+    stop_tokens = ['<|im_end|>', '<|endoftext|>']
+    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+
     messages = [{
         'role': 'user',
@@ -104,7 +119,7 @@ def run_minicpmv(question):
     prompt = tokenizer.apply_chat_template(messages,
                                            tokenize=False,
                                            add_generation_prompt=True)
-    return llm, prompt
+    return llm, prompt, stop_token_ids


 # InternVL
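Aside: the 2.6 stop strings introduced above only take effect once they are mapped to token ids. A minimal standalone sketch of that resolution, assuming hub access to the released openbmb/MiniCPM-V-2_6 tokenizer (the printed ids are illustrative, not asserted by this commit):

from transformers import AutoTokenizer

# Resolve the MiniCPM-V 2.6 stop strings to token ids, mirroring the example above.
tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-V-2_6",
                                          trust_remote_code=True)
stop_tokens = ['<|im_end|>', '<|endoftext|>']
stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in stop_tokens]
print(stop_token_ids)  # two integer ids, later handed to SamplingParams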
@@ -118,7 +133,8 @@ def run_internvl(question):
         trust_remote_code=True,
         max_num_seqs=5,
     )
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids


 # BLIP-2
@@ -128,7 +144,8 @@ def run_blip2(question):
     # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
     prompt = f"Question: {question} Answer:"
     llm = LLM(model="Salesforce/blip2-opt-2.7b")
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids


 model_example_map = {
@@ -149,11 +166,13 @@ def main(args):
     if model not in model_example_map:
         raise ValueError(f"Model type {model} is not supported.")

-    llm, prompt = model_example_map[model](question)
+    llm, prompt, stop_token_ids = model_example_map[model](question)

     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
-    sampling_params = SamplingParams(temperature=0.2, max_tokens=64)
+    sampling_params = SamplingParams(temperature=0.2,
+                                     max_tokens=64,
+                                     stop_token_ids=stop_token_ids)

     assert args.num_prompts > 0
     if args.num_prompts == 1:
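For orientation, a sketch of how the new three-value return is consumed end to end, following the pattern of this example script; the question text and image path are placeholders, not part of the commit:

from PIL import Image
from vllm import SamplingParams

# run_minicpmv is the example function changed above; any entry of
# model_example_map follows the same (llm, prompt, stop_token_ids) contract.
llm, prompt, stop_token_ids = run_minicpmv("What is in this image?")
sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)

image = Image.open("example.jpg").convert("RGB")  # placeholder image path
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params)
print(outputs[0].outputs[0].text)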
@@ -216,7 +216,6 @@ class BaseResampler(nn.Module):
-
         self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
         trunc_normal_(self.query, std=0.02)

         if kv_dim is not None and kv_dim != embed_dim:
             self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
         else:
@@ -225,7 +224,6 @@ class BaseResampler(nn.Module):
                 nn.Identity()(*args, **kwargs),
                 None,
             )
-
         self.attn = nn.MultiheadAttention(embed_dim, num_heads)
         self.ln_q = norm_layer(embed_dim)
         self.ln_kv = norm_layer(embed_dim)
@@ -261,7 +259,6 @@ class Resampler2(BaseResampler):
                          norm_layer)
-
         self.adaptive = adaptive

         pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
                                                 grid_size,
                                                 version=(2, 0))
@@ -717,7 +714,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
         raise NotImplementedError


-class MiniCPMV2(MiniCPMVBaseModel):
+class MiniCPMV2_0(MiniCPMVBaseModel):

     def __init__(
         self,
@@ -890,10 +887,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
         return "resampler" in name


-# NOTE: Currently, information about this model is unavailable. We are
-# temporarily using `MiniCPMVQwen2` as it's name. The name may need
-# to be modified in the future.
-class MiniCPMVQwen2(MiniCPMVBaseModel):
+class MiniCPMV2_6(MiniCPMVBaseModel):

     def __init__(
         self,
@@ -903,6 +897,7 @@ class MiniCPMVQwen2(MiniCPMVBaseModel):
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__(config, multimodal_config, cache_config, quant_config)
+        assert self.version == (2, 6)

     def init_llm(
         self,
@@ -930,6 +925,7 @@ class MiniCPMVQwen2(MiniCPMVBaseModel):

     def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
         with set_default_torch_dtype(torch.float16):
+            # The resampler in 2.6 remains consistent with the one in 2.5.
             resampler = Resampler2_5(
                 num_queries=self.config.query_num,
                 embed_dim=embed_dim,
@@ -989,6 +985,13 @@ class MiniCPMVQwen2(MiniCPMVBaseModel):
         return "resampler" in name or "vpm" in name


+_SUPPORT_VERSION = {
+    (2, 0): MiniCPMV2_0,
+    (2, 5): MiniCPMV2_5,
+    (2, 6): MiniCPMV2_6
+}
+
+
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
@@ -1016,11 +1019,9 @@ class MiniCPMV(MiniCPMVBaseModel):
         version = str(config.version).split(".")
         version = tuple([int(x) for x in version])
         # Dispatch class based on version
-        if version == (2, 0):
-            instance_class = MiniCPMV2
-        elif version == (2, 5):
-            instance_class = MiniCPMV2_5
-        else:
-            instance_class = MiniCPMVQwen2
+        instance_class = _SUPPORT_VERSION.get(version, None)
+        if instance_class is None:
+            raise ValueError(
+                "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
         return instance_class(config, multimodal_config, cache_config,
                               quant_config)
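The dispatch change above replaces an if/elif chain with a lookup table keyed by the parsed version tuple. A self-contained sketch of the same pattern, using stand-in classes rather than the real model implementations:

from typing import Dict, Tuple, Type

class _Stub:
    """Stand-in for a version-specific MiniCPM-V implementation."""

class V2_0(_Stub): ...
class V2_5(_Stub): ...
class V2_6(_Stub): ...

_SUPPORT_VERSION: Dict[Tuple[int, int], Type[_Stub]] = {
    (2, 0): V2_0,
    (2, 5): V2_5,
    (2, 6): V2_6,
}

def dispatch(version_str: str) -> _Stub:
    # Parse "2.6" into (2, 6) and look up the matching class.
    version = tuple(int(x) for x in version_str.split("."))
    cls = _SUPPORT_VERSION.get(version)
    if cls is None:
        raise ValueError(
            "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
    return cls()

print(type(dispatch("2.6")).__name__)  # V2_6

Unknown versions fail loudly with a ValueError, which mirrors the behavior the new _SUPPORT_VERSION lookup gives the real MiniCPMV wrapper class.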