[Model] Remove hardcoded image token IDs from Pixtral (#11582)

Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Roger Wang 2024-12-28 02:54:23 -08:00 committed by GitHub
parent d34be24bb1
commit b7dcc003dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -45,13 +45,6 @@ try:
except ImportError:
USE_XFORMERS_OPS = False
# These token ids cannot be retrieved from model config
# so we hardcode them here.
PIXTRAL_12B_IMAGE_BREAK_ID = 12
PIXTRAL_12B_IMAGE_END_ID = 13
PIXTRAL_LARGE_IMAGE_BREAK_ID = 14
PIXTRAL_LARGE_IMAGE_END_ID = 15
def get_max_pixtral_image_tokens(ctx: InputContext):
tokenizer = cached_get_tokenizer(
@ -201,6 +194,13 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
if key in dataclass_fields
}
if not ("image_break_token_id" in vision_args
and "image_end_token_id" in vision_args):
raise ValueError(
"'image_break_token_id' and 'image_end_token_id' not found "
"in the vision_encoder arguments. Please download the latest "
"version of 'params.json' from the model repository.")
self.vision_args = VisionEncoderArgs(**vision_args)
# init MistralForCausalLM
@ -240,9 +240,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: Image embeddings are split into a separate tensor for each image,
# cutting immediately after the index of each `[IMG_END]` token.
image_end_condition = (image_tokens == PIXTRAL_12B_IMAGE_END_ID) | (
image_tokens == PIXTRAL_LARGE_IMAGE_END_ID)
split_indices = torch.where(image_end_condition)[0] + 1
image_end_mask = image_tokens == self.vision_args.image_end_token_id
split_indices = torch.where(image_end_mask)[0] + 1
if len(split_indices) <= 1:
# Do not split, return as tensor of shape [1, fs, hs]
return image_embeds.unsqueeze(0)
@ -265,10 +264,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [
self.vision_args.image_token_id,
PIXTRAL_12B_IMAGE_END_ID,
PIXTRAL_12B_IMAGE_BREAK_ID,
PIXTRAL_LARGE_IMAGE_BREAK_ID,
PIXTRAL_LARGE_IMAGE_END_ID,
self.vision_args.image_break_token_id,
self.vision_args.image_end_token_id,
])
return inputs_embeds
@ -409,6 +406,8 @@ class VisionEncoderArgs:
num_attention_heads: int
rope_theta: float # for rope-2D
image_token_id: int
image_break_token_id: int
image_end_token_id: int
adapter_bias: bool = True