[Model] Remove hardcoded image tokens ids from Pixtral (#11582)
Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent
d34be24bb1
commit
b7dcc003dc
@ -45,13 +45,6 @@ try:
|
||||
except ImportError:
|
||||
USE_XFORMERS_OPS = False
|
||||
|
||||
# These token ids cannot be retrieved from model config
|
||||
# so we hardcode them here.
|
||||
PIXTRAL_12B_IMAGE_BREAK_ID = 12
|
||||
PIXTRAL_12B_IMAGE_END_ID = 13
|
||||
PIXTRAL_LARGE_IMAGE_BREAK_ID = 14
|
||||
PIXTRAL_LARGE_IMAGE_END_ID = 15
|
||||
|
||||
|
||||
def get_max_pixtral_image_tokens(ctx: InputContext):
|
||||
tokenizer = cached_get_tokenizer(
|
||||
@ -201,6 +194,13 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
if key in dataclass_fields
|
||||
}
|
||||
|
||||
if not ("image_break_token_id" in vision_args
|
||||
and "image_end_token_id" in vision_args):
|
||||
raise ValueError(
|
||||
"'image_break_token_id' and 'image_end_token_id' not found "
|
||||
"in the vision_encoder arguments. Please download the latest "
|
||||
"version of 'params.json' from the model repository.")
|
||||
|
||||
self.vision_args = VisionEncoderArgs(**vision_args)
|
||||
|
||||
# init MistralForCausalLM
|
||||
@ -240,9 +240,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
# NOTE: Image embeddings are split into separate tensors for each image
|
||||
# by the indices of `[IMG_END]` token.
|
||||
image_end_condition = (image_tokens == PIXTRAL_12B_IMAGE_END_ID) | (
|
||||
image_tokens == PIXTRAL_LARGE_IMAGE_END_ID)
|
||||
split_indices = torch.where(image_end_condition)[0] + 1
|
||||
image_end_mask = image_tokens == self.vision_args.image_end_token_id
|
||||
split_indices = torch.where(image_end_mask)[0] + 1
|
||||
if len(split_indices) <= 1:
|
||||
# Do not split, return as tensor of shape [1, fs, hs]
|
||||
return image_embeds.unsqueeze(0)
|
||||
@ -265,10 +264,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings, [
|
||||
self.vision_args.image_token_id,
|
||||
PIXTRAL_12B_IMAGE_END_ID,
|
||||
PIXTRAL_12B_IMAGE_BREAK_ID,
|
||||
PIXTRAL_LARGE_IMAGE_BREAK_ID,
|
||||
PIXTRAL_LARGE_IMAGE_END_ID,
|
||||
self.vision_args.image_break_token_id,
|
||||
self.vision_args.image_end_token_id,
|
||||
])
|
||||
return inputs_embeds
|
||||
|
||||
@ -409,6 +406,8 @@ class VisionEncoderArgs:
|
||||
num_attention_heads: int
|
||||
rope_theta: float # for rope-2D
|
||||
image_token_id: int
|
||||
image_break_token_id: int
|
||||
image_end_token_id: int
|
||||
adapter_bias: bool = True
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user