[Model] Initialize Florence-2 language backbone support (#9555)
parent 2394962d70
commit 3ff57ebfca
examples/florence2_inference.py (new file, 44 lines)
@@ -0,0 +1,44 @@
'''
Demonstrate prompting of text-to-text
encoder/decoder models, specifically Florence-2
'''
# TODO(Isotr0py):
# Move to offline_inference_vision_language.py after porting vision backbone
from vllm import LLM, SamplingParams

dtype = "float"

# Create a Florence-2 encoder/decoder model instance
llm = LLM(
    model="microsoft/Florence-2-base",
    tokenizer="facebook/bart-base",
    dtype=dtype,
    trust_remote_code=True,
)

prompts = [
    "<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>",
    "<CAPTION_TO_PHRASE_GROUNDING>", "<OD>", "<DENSE_REGION_CAPTION>",
    "<REGION_PROPOSAL>", "<OCR>", "<OCR_WITH_REGION>"
]
# Create a sampling params object.
sampling_params = SamplingParams(
    temperature=0,
    top_p=1.0,
    min_tokens=0,
    max_tokens=20,
)

# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    encoder_prompt = output.encoder_prompt
    generated_text = output.outputs[0].text
    print(f"Encoder prompt: {encoder_prompt!r}, "
          f"Decoder prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")
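Note: the script above passes bare task-token strings to llm.generate(), which vLLM treats as encoder prompts for an encoder/decoder model. As a rough sketch (not part of this commit), the same request can be written as an explicit encoder/decoder prompt, mirroring what the new test file further below does:

    # Hedged sketch, not part of the diff.
    from vllm.inputs.data import ExplicitEncoderDecoderPrompt

    explicit_prompt = ExplicitEncoderDecoderPrompt(
        encoder_prompt="<CAPTION>",   # task token consumed by the encoder
        decoder_prompt=None,          # let vLLM build the default decoder start
        mm_processor_kwargs=None,
    )
    # outputs = llm.generate([explicit_prompt], sampling_params)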
@@ -253,7 +253,9 @@ class HfRunner:
         dtype: str = "half",
         *,
         model_kwargs: Optional[Dict[str, Any]] = None,
+        is_embedding_model: bool = False,
         is_sentence_transformer: bool = False,
+        skip_tokenizer_init: bool = False,
         auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
         postprocess_inputs: Callable[[BatchEncoding],
                                      BatchEncoding] = identity,
@@ -281,11 +283,12 @@ class HfRunner:
             **model_kwargs,
         ))
 
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            torch_dtype=torch_dtype,
-            trust_remote_code=True,
-        )
+        if not skip_tokenizer_init:
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                trust_remote_code=True,
+            )
 
         # don't put this import at the top level
         # it will call torch.cuda.device_count()
@@ -295,6 +298,8 @@ class HfRunner:
             torch_dtype=torch_dtype,
             trust_remote_code=True,
         )
+        if skip_tokenizer_init:
+            self.tokenizer = self.processor.tokenizer
 
         self.postprocess_inputs = postprocess_inputs
 
@@ -535,6 +540,7 @@ class HfRunner:
         encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
         max_tokens: int,
         num_logprobs: int,
+        images: Optional[PromptImageInput] = None,
         **kwargs: Any,
     ) -> List[TokensTextLogprobs]:
         '''
@@ -545,11 +551,17 @@ class HfRunner:
         all_output_ids: List[List[int]] = []
         all_output_strs: List[str] = []
 
-        for (encoder_prompt,
-             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
+        for i, (encoder_prompt, decoder_prompt) in enumerate(
+                to_enc_dec_tuple_list(encoder_decoder_prompts)):
+            processor_kwargs: Dict[str, Any] = {
+                "text": encoder_prompt,
+                "return_tensors": "pt",
+            }
+            if images is not None and images[i] is not None:
+                processor_kwargs["images"] = images[i]
+
             encoder_input_ids = self.wrap_device(
-                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids,
+                self.processor(**processor_kwargs).input_ids,
                 device=self.model.device.type,
             )
 
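The last hunk routes encoder prompts through self.processor rather than self.tokenizer, so models whose processors require (or accept) image inputs can still be exercised, and skip_tokenizer_init lets the runner reuse the tokenizer bundled with that processor. Roughly, for a Florence-2 style processor the new code path reduces to something like the following sketch (illustrative only, with a hypothetical prompt and image):

    # Hedged sketch, not part of the diff.
    from PIL import Image
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base",
                                              trust_remote_code=True)
    inputs = processor(text="<CAPTION>",
                       images=Image.new(mode="RGB", size=(2, 2)),
                       return_tensors="pt")
    encoder_input_ids = inputs.input_ids  # what wrap_device() receives above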
tests/models/encoder_decoder/vision_language/test_florence2.py (new file, 102 lines)
@@ -0,0 +1,102 @@
from functools import partial
from typing import List, Optional, Tuple, Type

import pytest
from PIL import Image

from vllm.inputs.data import ExplicitEncoderDecoderPrompt
from vllm.sequence import SampleLogprobs

from ....conftest import HfRunner, VllmRunner
from ...utils import check_logprobs_close

Florence2Prompt = partial(ExplicitEncoderDecoderPrompt,
                          decoder_prompt=None,
                          mm_processor_kwargs=None)

MODELS = ["microsoft/Florence-2-base"]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
TOKENIZER = "facebook/bart-base"
PROMPTS = [
    Florence2Prompt(encoder_prompt="<CAPTION>"),
    Florence2Prompt(encoder_prompt="<DETAILED_CAPTION>"),
    Florence2Prompt(encoder_prompt="<MORE_DETAILED_CAPTION>"),
    Florence2Prompt(encoder_prompt="<CAPTION_TO_PHRASE_GROUNDING>"),
    Florence2Prompt(encoder_prompt="<DENSE_REGION_CAPTION>"),
    Florence2Prompt(encoder_prompt="<REGION_PROPOSAL>"),
    Florence2Prompt(encoder_prompt="<OCR_WITH_REGION>"),
    Florence2Prompt(encoder_prompt="<OCR>"),
    Florence2Prompt(encoder_prompt="<OD>"),
]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]], ):
    """Sanitize vllm output to be comparable with hf output."""
    output_ids, output_str, out_logprobs = vllm_output

    hf_output_str = "</s><s>" + output_str + "</s>"

    return output_ids, hf_output_str, out_logprobs


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    prompts: List[ExplicitEncoderDecoderPrompt],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
) -> None:
    with vllm_runner(model,
                     tokenizer_name=TOKENIZER,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
            prompts, max_tokens, num_logprobs)

    # Florence-2 processors require image inputs
    dummy_image = Image.new(mode="RGB", size=(2, 2))
    with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model:
        hf_model.model.get_output_embeddings = lambda: \
            hf_model.model.language_model.lm_head
        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
            prompts,
            max_tokens,
            num_logprobs,
            images=[dummy_image] * len(prompts),
        ))

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=[
            vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs
        ],
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, model, dtype, max_tokens,
                num_logprobs) -> None:
    run_test(
        hf_runner,
        vllm_runner,
        PROMPTS,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
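Two details of the test worth spelling out: Florence2Prompt is just a partial over ExplicitEncoderDecoderPrompt with the decoder prompt left empty, and vllm_to_hf_output re-adds the special-token framing that the HF runner keeps in its detokenized string so the two backends can be compared verbatim. A small illustrative sketch (hypothetical values, not part of the test):

    # Hedged sketch, not part of the diff.
    p = Florence2Prompt(encoder_prompt="<OCR>")
    # p == {"encoder_prompt": "<OCR>", "decoder_prompt": None,
    #       "mm_processor_kwargs": None}

    output_str = "A photo"                      # hypothetical vLLM output text
    hf_style = "</s><s>" + output_str + "</s>"  # what check_logprobs_close compares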
vllm/model_executor/models/florence2.py (new file, 261 lines)
@@ -0,0 +1,261 @@
import math
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
                                             BartParallelLMHead,
                                             BartScaledWordEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .utils import AutoWeightsLoader


class Florence2LanguageModel(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
        self.encoder = BartEncoder(config,
                                   cache_config=cache_config,
                                   quant_config=quant_config)
        self.decoder = BartDecoder(config,
                                   cache_config=cache_config,
                                   quant_config=quant_config)

        if self.config.tie_word_embeddings:
            self.encoder.embed_tokens.weight = self.shared.weight
            self.decoder.embed_tokens.weight = self.shared.weight

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you
                provide it.
            positions
                Positions of *decoder* input sequence tokens.
            encoder_input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
            encoder_positions:
                Positions of *encoder* input sequence tokens.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Model output torch.Tensor
        """

        encoder_hidden_states = None

        if encoder_input_ids.numel() > 0:
            # Run encoder attention if a non-zero number of encoder tokens
            # are provided as input
            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
                                                 positions=encoder_positions,
                                                 kv_caches=kv_caches,
                                                 attn_metadata=attn_metadata)

        # decoder outputs consists of
        # (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            decoder_input_ids=input_ids,
            decoder_positions=positions,
            encoder_hidden_states=encoder_hidden_states,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata)

        return decoder_outputs


class Florence2LanguageForConditionalGeneration(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.model = Florence2LanguageModel(config,
                                            cache_config=cache_config,
                                            quant_config=quant_config)
        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0

        self.vocab_size = config.vocab_size
        self.lm_head = BartParallelLMHead(self.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)

        self.logits_processor = LogitsProcessor(self.vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Output torch.Tensor
        """
        return self.model(input_ids, positions, encoder_input_ids,
                          encoder_positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> SamplerOutput:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue

                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "final_logits_bias" in name:
                    continue
                if self.config.tie_word_embeddings and "embed_tokens" in name:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)


class Florence2ForConditionalGeneration(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        # TODO(Isotr0py): Add vision backbone
        self.language_model = Florence2LanguageForConditionalGeneration(
            config=config.text_config,
            cache_config=cache_config,
            quant_config=quant_config)

    @property
    def sampler(self):
        return self.language_model.sampler

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Output torch.Tensor
        """
        return self.language_model(input_ids, positions, encoder_input_ids,
                                   encoder_positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        skip_prefixes = [
            'image_projection', "vision_tower", "image_proj_norm",
            "image_pos_embed", "visual_temporal_embed"
        ]
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        loader.load_weights(weights)
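Two notes on weight loading above. First, stacked_params_mapping folds the checkpoint's separate q/k/v projections into vLLM's fused qkv_proj parameter; second, Florence2ForConditionalGeneration.load_weights skips the vision-tower prefixes that the language-only backbone cannot consume yet. A rough illustration of the name remapping (hypothetical checkpoint key, not part of the diff):

    # Hedged sketch, not part of the diff.
    name = "model.encoder.layers.0.self_attn.q_proj.weight"  # hypothetical HF key
    param_name, weight_name, shard_id = ("qkv_proj", "q_proj", "q")
    fused_name = name.replace(weight_name, param_name)
    # fused_name == "model.encoder.layers.0.self_attn.qkv_proj.weight";
    # its weight_loader then writes the tensor into the "q" shard of the
    # fused projection.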
@@ -85,6 +85,7 @@ _TEXT_GENERATION_MODELS = {
     # [Encoder-decoder]
     "BartModel": ("bart", "BartForConditionalGeneration"),
     "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
+    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
 }
 
 _EMBEDDING_MODELS = {
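The registry key added here is matched against the "architectures" field of the model's config.json, which is how vLLM dispatches "microsoft/Florence-2-base" to the new florence2 module. A quick, hedged way to check the wiring (illustrative only):

    # Hedged sketch, not part of the diff.
    from vllm import ModelRegistry

    assert "Florence2ForConditionalGeneration" in ModelRegistry.get_supported_archs()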