Bump transformers version for Llama 3.1 hotfix and patch Chameleon (#6690)

This commit is contained in:
Roger Wang 2024-07-23 13:47:48 -07:00 committed by GitHub
parent 507ef787d8
commit 1bedf210e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 32 additions and 177 deletions

View File

@ -6,7 +6,7 @@ numpy < 2.0.0
requests requests
tqdm tqdm
py-cpuinfo py-cpuinfo
transformers >= 4.42.4 # Required for Gemma 2 and for additional chat template parameters. transformers >= 4.43.1 # Required for Chameleon and Llama 3.1 hotfox.
tokenizers >= 0.19.1 # Required for Llama 3. tokenizers >= 0.19.1 # Required for Llama 3.
fastapi fastapi
aiohttp aiohttp

View File

@ -64,9 +64,8 @@ def test_get_sliding_window():
def test_rope_customization(): def test_rope_customization():
TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0} TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
TEST_ROPE_THETA = 16_000_000.0 TEST_ROPE_THETA = 16_000_000.0
LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}
llama_model_config = ModelConfig( llama_model_config = ModelConfig(
"meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Meta-Llama-3-8B-Instruct",
@ -96,27 +95,29 @@ def test_rope_customization():
None) == TEST_ROPE_THETA None) == TEST_ROPE_THETA
assert llama_model_config.max_model_len == 16384 assert llama_model_config.max_model_len == 16384
longchat_model_config = ModelConfig( # TODO: add these back when the rope configs are fixed
"lmsys/longchat-13b-16k", # LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0}
"lmsys/longchat-13b-16k", # longchat_model_config = ModelConfig(
tokenizer_mode="auto", # "lmsys/longchat-13b-16k",
trust_remote_code=False, # "lmsys/longchat-13b-16k",
dtype="float16", # tokenizer_mode="auto",
seed=0, # trust_remote_code=False,
) # dtype="float16",
assert getattr(longchat_model_config.hf_config, "rope_scaling", # seed=0,
None) == LONGCHAT_ROPE_SCALING # )
assert longchat_model_config.max_model_len == 16384 # assert getattr(longchat_model_config.hf_config, "rope_scaling",
# None) == LONGCHAT_ROPE_SCALING
# assert longchat_model_config.max_model_len == 16384
longchat_model_config = ModelConfig( # longchat_model_config = ModelConfig(
"lmsys/longchat-13b-16k", # "lmsys/longchat-13b-16k",
"lmsys/longchat-13b-16k", # "lmsys/longchat-13b-16k",
tokenizer_mode="auto", # tokenizer_mode="auto",
trust_remote_code=False, # trust_remote_code=False,
dtype="float16", # dtype="float16",
seed=0, # seed=0,
rope_scaling=TEST_ROPE_SCALING, # rope_scaling=TEST_ROPE_SCALING,
) # )
assert getattr(longchat_model_config.hf_config, "rope_scaling", # assert getattr(longchat_model_config.hf_config, "rope_scaling",
None) == TEST_ROPE_SCALING # None) == TEST_ROPE_SCALING
assert longchat_model_config.max_model_len == 4096 # assert longchat_model_config.max_model_len == 4096

View File

@ -16,8 +16,6 @@ _GENERATION_MODELS = {
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
#TODO(ywang96): remove this when huggingface fixes the model repo
"ChameleonForCausalLM": ("chameleon", "ChameleonForConditionalGeneration"),
"ChameleonForConditionalGeneration": "ChameleonForConditionalGeneration":
("chameleon", "ChameleonForConditionalGeneration"), ("chameleon", "ChameleonForConditionalGeneration"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),

View File

@ -6,6 +6,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from PIL import Image from PIL import Image
from torch import nn from torch import nn
from transformers import ChameleonConfig, ChameleonVQVAEConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
@ -30,8 +31,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_tokenizer, from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens) repeat_and_pad_image_tokens)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from vllm.transformers_utils.configs import (ChameleonConfig,
ChameleonVQVAEConfig)
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsVision from .interfaces import SupportsVision

View File

@ -5,10 +5,10 @@ from transformers import GenerationConfig, PretrainedConfig
from vllm.envs import VLLM_USE_MODELSCOPE from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.configs import (ChameleonConfig, ChatGLMConfig, from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
DbrxConfig, JAISConfig, JAISConfig, MedusaConfig,
MedusaConfig, MLPSpeculatorConfig, MLPSpeculatorConfig, MPTConfig,
MPTConfig, RWConfig) RWConfig)
if VLLM_USE_MODELSCOPE: if VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig from modelscope import AutoConfig
@ -18,7 +18,6 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"chameleon": ChameleonConfig,
"chatglm": ChatGLMConfig, "chatglm": ChatGLMConfig,
"dbrx": DbrxConfig, "dbrx": DbrxConfig,
"mpt": MPTConfig, "mpt": MPTConfig,

View File

@ -1,5 +1,3 @@
from vllm.transformers_utils.configs.chameleon import (ChameleonConfig,
ChameleonVQVAEConfig)
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.dbrx import DbrxConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and # RWConfig is for the original tiiuae/falcon-40b(-instruct) and
@ -12,8 +10,6 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.mpt import MPTConfig
__all__ = [ __all__ = [
"ChameleonConfig",
"ChameleonVQVAEConfig",
"ChatGLMConfig", "ChatGLMConfig",
"DbrxConfig", "DbrxConfig",
"MPTConfig", "MPTConfig",

View File

@ -1,138 +0,0 @@
from typing import List, Optional
from transformers import PretrainedConfig
#TODO (ywang96): Remove this file and import it from
# transformers once the new release with Chameleon support
# is available.
class ChameleonConfig(PretrainedConfig):
model_type = "chameleon"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=65536,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
model_parallel_size=1,
swin_norm=False,
vq_config=None,
vocabulary_map=None,
mlp_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.mlp_bias = mlp_bias
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.model_parallel_size = model_parallel_size
self.swin_norm = swin_norm
if vq_config is None:
vq_config = {}
self.vq_config = ChameleonVQVAEConfig(**vq_config)
self.vocabulary_map = vocabulary_map
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling,
dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, "
f"`type` and `factor`, got {self.rope_scaling}")
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in [
"linear", "dynamic"
]:
raise ValueError(
"`rope_scaling`'s type field must be one of ['linear', "
f"'dynamic'], got {rope_scaling_type}")
if rope_scaling_factor is None or not isinstance(
rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(
"`rope_scaling`'s factor field must be a float > 1, "
f"got {rope_scaling_factor}")
class ChameleonVQVAEConfig(PretrainedConfig):
model_type = "chameleon_vqgan"
def __init__(
self,
embed_dim: int = 256,
num_embeddings: int = 8192,
double_latent: bool = False,
latent_channels: int = 256,
resolution: int = 512,
in_channels: int = 3,
base_channels: int = 128,
channel_multiplier: List[int] = [1, 1, 2, 2, 4], #noqa
num_res_blocks: int = 2,
attn_resolutions: Optional[List[int]] = None,
dropout: float = 0.0,
attn_type: str = "vanilla",
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.num_embeddings = num_embeddings
self.double_latent = double_latent
self.latent_channels = latent_channels
self.resolution = resolution
self.in_channels = in_channels
self.base_channels = base_channels
self.channel_multiplier = channel_multiplier
self.num_res_blocks = num_res_blocks
self.attn_resolutions = attn_resolutions
self.dropout = dropout
self.attn_type = attn_type
self.initializer_range = initializer_range