[Model] Allow loading from original Mistral format (#8168)

Co-authored-by: Michael Goin <michael@neuralmagic.com>
Authored by Patrick von Platen on 2024-09-07 01:02:05 +02:00; committed by GitHub
parent 23f322297f
commit 29f49cd6e3
7 changed files with 291 additions and 81 deletions

View File

@@ -41,3 +41,43 @@ def test_models(
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize("model", MODELS[1:])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_mistral_format(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            tokenizer_mode="auto",
+            load_format="safetensors",
+            config_format="hf",
+    ) as hf_format_model:
+        hf_format_outputs = hf_format_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            tokenizer_mode="mistral",
+            load_format="mistral",
+            config_format="mistral",
+    ) as mistral_format_model:
+        mistral_format_outputs = mistral_format_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    check_logprobs_close(
+        outputs_0_lst=hf_format_outputs,
+        outputs_1_lst=mistral_format_outputs,
+        name_0="hf",
+        name_1="mistral",
+    )
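The new test drives the same checkpoint through two paths: the HF-style files (config.json plus sharded safetensors) and the Mistral-native files (params.json plus consolidated.safetensors), and asserts the logprobs stay close. The following is a minimal offline-inference sketch of those two paths; the model name is illustrative, the snippet assumes a GPU with enough memory, and the two engines are created back to back only for brevity.

from vllm import LLM, SamplingParams

prompts = ["The capital of France is"]
params = SamplingParams(temperature=0.0, max_tokens=16)

# Path 1: HF-style loading (config.json + *.safetensors shards).
hf_llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3",
             tokenizer_mode="auto",
             load_format="safetensors",
             config_format="hf")
print(hf_llm.generate(prompts, params)[0].outputs[0].text)

# Path 2: Mistral-native loading (params.json + consolidated.safetensors).
mistral_llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3",
                  tokenizer_mode="mistral",
                  load_format="mistral",
                  config_format="mistral")
print(mistral_llm.generate(prompts, params)[0].outputs[0].text)

With greedy sampling both paths should produce effectively the same completion, which is what check_logprobs_close verifies at the logprob level.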

View File

@@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
-from vllm.transformers_utils.config import (get_config,
+from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                             get_hf_image_processor_config,
                                             get_hf_text_config)
 from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
@@ -121,35 +121,37 @@ class ModelConfig:
             override default neuron config that are specific to Neuron devices,
             this argument will be used to configure the neuron config that
             can not be gathered from the vllm arguments.
+        config_format: The config format which shall be loaded.
+            Defaults to 'auto' which defaults to 'hf'.
     """

-    def __init__(
-        self,
-        model: str,
-        tokenizer: str,
-        tokenizer_mode: str,
-        trust_remote_code: bool,
-        dtype: Union[str, torch.dtype],
-        seed: int,
-        revision: Optional[str] = None,
-        code_revision: Optional[str] = None,
-        rope_scaling: Optional[dict] = None,
-        rope_theta: Optional[float] = None,
-        tokenizer_revision: Optional[str] = None,
-        max_model_len: Optional[int] = None,
-        spec_target_max_model_len: Optional[int] = None,
-        quantization: Optional[str] = None,
-        quantization_param_path: Optional[str] = None,
-        enforce_eager: Optional[bool] = None,
-        max_context_len_to_capture: Optional[int] = None,
-        max_seq_len_to_capture: Optional[int] = None,
-        max_logprobs: int = 20,
-        disable_sliding_window: bool = False,
-        skip_tokenizer_init: bool = False,
-        served_model_name: Optional[Union[str, List[str]]] = None,
-        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
-        use_async_output_proc: bool = True,
-        override_neuron_config: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(self,
+                 model: str,
+                 tokenizer: str,
+                 tokenizer_mode: str,
+                 trust_remote_code: bool,
+                 dtype: Union[str, torch.dtype],
+                 seed: int,
+                 revision: Optional[str] = None,
+                 code_revision: Optional[str] = None,
+                 rope_scaling: Optional[dict] = None,
+                 rope_theta: Optional[float] = None,
+                 tokenizer_revision: Optional[str] = None,
+                 max_model_len: Optional[int] = None,
+                 spec_target_max_model_len: Optional[int] = None,
+                 quantization: Optional[str] = None,
+                 quantization_param_path: Optional[str] = None,
+                 enforce_eager: Optional[bool] = None,
+                 max_context_len_to_capture: Optional[int] = None,
+                 max_seq_len_to_capture: Optional[int] = None,
+                 max_logprobs: int = 20,
+                 disable_sliding_window: bool = False,
+                 skip_tokenizer_init: bool = False,
+                 served_model_name: Optional[Union[str, List[str]]] = None,
+                 limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+                 use_async_output_proc: bool = True,
+                 override_neuron_config: Optional[Dict[str, Any]] = None,
+                 config_format: ConfigFormat = ConfigFormat.AUTO) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -176,7 +178,8 @@ class ModelConfig:
         self.skip_tokenizer_init = skip_tokenizer_init

         self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision, rope_scaling, rope_theta)
+                                    code_revision, rope_scaling, rope_theta,
+                                    config_format)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.hf_image_processor_config = get_hf_image_processor_config(
             self.model, revision)
@@ -746,6 +749,7 @@ class LoadFormat(str, enum.Enum):
     SHARDED_STATE = "sharded_state"
     GGUF = "gguf"
     BITSANDBYTES = "bitsandbytes"
+    MISTRAL = "mistral"


 @dataclass

View File

@@ -8,10 +8,10 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
 import torch

 import vllm.envs as envs
-from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
-                         EngineConfig, LoadConfig, LoadFormat, LoRAConfig,
-                         ModelConfig, ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
+from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
+                         DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
+                         LoRAConfig, ModelConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig, TokenizerPoolConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
@@ -65,6 +65,7 @@ class EngineArgs:
     trust_remote_code: bool = False
     download_dir: Optional[str] = None
     load_format: str = 'auto'
+    config_format: str = 'auto'
     dtype: str = 'auto'
     kv_cache_dtype: str = 'auto'
     quantization_param_path: Optional[str] = None
@@ -234,6 +235,13 @@
             'section for more information.\n'
             '* "bitsandbytes" will load the weights using bitsandbytes '
             'quantization.\n')
+        parser.add_argument(
+            '--config-format',
+            default=EngineArgs.config_format,
+            choices=[f.value for f in ConfigFormat],
+            help='The format of the model config to load.\n\n'
+            '* "auto" will try to load the config in hf format '
+            'if available else it will try to load in mistral format ')
         parser.add_argument(
             '--dtype',
             type=str,
@@ -813,7 +821,10 @@ class EngineArgs:
             served_model_name=self.served_model_name,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             use_async_output_proc=not self.disable_async_output_proc,
-            override_neuron_config=self.override_neuron_config)
+            override_neuron_config=self.override_neuron_config,
+            config_format=self.config_format,
+        )
+
         cache_config = CacheConfig(
             block_size=self.block_size if self.device != "neuron" else
             self.max_model_len,  # neuron needs block_size = max_model_len
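The new CLI flag is --config-format, used alongside the existing --load-format and --tokenizer-mode flags. A sketch of the same wiring through EngineArgs in Python; the model name is illustrative, and create_engine_config() will download the config and tokenizer when they are not cached:

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="mistralai/Mistral-7B-Instruct-v0.3",
                  tokenizer_mode="mistral",
                  load_format="mistral",
                  config_format="mistral")

# config_format is forwarded into ModelConfig, which passes it on to
# get_config(); "auto" first looks for config.json, then for params.json.
engine_config = args.create_engine_config()
print(engine_config.model_config.hf_config.architectures)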

View File

@@ -17,6 +17,7 @@ import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
 from transformers import AutoModelForCausalLM, PretrainedConfig
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
                          LoRAConfig, ModelConfig, MultiModalConfig,
@@ -241,12 +242,17 @@ class DefaultModelLoader(BaseModelLoader):
         is_local = os.path.isdir(model_name_or_path)
         load_format = self.load_config.load_format
         use_safetensors = False
+        index_file = SAFE_WEIGHTS_INDEX_NAME
         # Some quantized models use .pt files for storing the weights.
         if load_format == LoadFormat.AUTO:
             allow_patterns = ["*.safetensors", "*.bin"]
         elif load_format == LoadFormat.SAFETENSORS:
             use_safetensors = True
             allow_patterns = ["*.safetensors"]
+        elif load_format == LoadFormat.MISTRAL:
+            use_safetensors = True
+            allow_patterns = ["consolidated*.safetensors"]
+            index_file = "consolidated.safetensors.index.json"
         elif load_format == LoadFormat.PT:
             allow_patterns = ["*.pt"]
         elif load_format == LoadFormat.NPCACHE:
@@ -284,10 +290,10 @@ class DefaultModelLoader(BaseModelLoader):
             # any files not found in the index.
             if not is_local:
                 download_safetensors_index_file_from_hf(
-                    model_name_or_path, self.load_config.download_dir,
-                    revision)
+                    model_name_or_path, index_file,
+                    self.load_config.download_dir, revision)
             hf_weights_files = filter_duplicate_safetensors_files(
-                hf_weights_files, hf_folder)
+                hf_weights_files, hf_folder, index_file)
         else:
             hf_weights_files = filter_files_not_needed_for_inference(
                 hf_weights_files)
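A standalone sketch of the file-selection behavior added above: with load_format="mistral", only consolidated*.safetensors files are considered, and a Mistral-specific index file name is used for de-duplication. The repo file names below are made up for illustration.

import fnmatch

repo_files = [
    "model-00001-of-00003.safetensors",
    "model-00002-of-00003.safetensors",
    "model-00003-of-00003.safetensors",
    "model.safetensors.index.json",
    "consolidated.safetensors",
    "params.json",
    "tokenizer.model.v3",
]

allow_patterns = ["consolidated*.safetensors"]      # LoadFormat.MISTRAL
index_file = "consolidated.safetensors.index.json"  # may be absent

selected = [name for pattern in allow_patterns
            for name in repo_files if fnmatch.fnmatch(name, pattern)]
print(selected)  # ['consolidated.safetensors']

If the index file does not exist, filter_duplicate_safetensors_files() returns the globbed list unchanged, so single-file consolidated checkpoints keep working.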

View File

@@ -16,7 +16,6 @@ import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
-from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

 from vllm.config import LoadConfig, ModelConfig
 from vllm.distributed import get_tensor_model_parallel_rank
@@ -251,6 +250,7 @@ def download_weights_from_hf(

 def download_safetensors_index_file_from_hf(
     model_name_or_path: str,
+    index_file: str,
     cache_dir: Optional[str],
     revision: Optional[str] = None,
 ) -> None:
@@ -269,36 +269,37 @@ def download_safetensors_index_file_from_hf(
         # Download the safetensors index file.
         hf_hub_download(
             repo_id=model_name_or_path,
-            filename=SAFE_WEIGHTS_INDEX_NAME,
+            filename=index_file,
             cache_dir=cache_dir,
             revision=revision,
             local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
         )
     # If file not found on remote or locally, we should not fail since
-    # only some models will have SAFE_WEIGHTS_INDEX_NAME.
+    # only some models will have index_file.
     except huggingface_hub.utils.EntryNotFoundError:
-        logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME)
+        logger.info("No %s found in remote.", index_file)
     except huggingface_hub.utils.LocalEntryNotFoundError:
-        logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME)
+        logger.info("No %s found in local cache.", index_file)


 # For models like Mistral-7B-v0.3, there are both sharded
 # safetensors files and a consolidated safetensors file.
 # Passing both of these to the weight loader functionality breaks.
-# So, we use the SAFE_WEIGHTS_INDEX_NAME to
+# So, we use the index_file to
 # look up which safetensors files should be used.
 def filter_duplicate_safetensors_files(hf_weights_files: List[str],
-                                       hf_folder: str) -> List[str]:
+                                       hf_folder: str,
+                                       index_file: str) -> List[str]:
     # model.safetensors.index.json is a mapping from keys in the
     # torch state_dict to safetensors file holding that weight.
-    index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
+    index_file_name = os.path.join(hf_folder, index_file)
     if not os.path.isfile(index_file_name):
         return hf_weights_files

     # Iterate through the weight_map (weight_name: safetensors files)
     # to identify weights that we should use.
-    with open(index_file_name) as index_file:
-        weight_map = json.load(index_file)["weight_map"]
+    with open(index_file_name, "r") as f:
+        weight_map = json.load(f)["weight_map"]

     weight_files_in_index = set()
     for weight_name in weight_map:
         weight_files_in_index.add(
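A toy reproduction of why the duplicate filtering now takes the index file name as a parameter: a repo such as Mistral-7B-v0.3 ships both sharded and consolidated weights, and only the files referenced by the chosen index should be loaded. The folder and shard names below are invented for illustration.

import json
import os
import tempfile

hf_folder = tempfile.mkdtemp()
index_file = "model.safetensors.index.json"  # i.e. SAFE_WEIGHTS_INDEX_NAME

# The HF-style index only references the sharded files.
weight_map = {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors",
}
with open(os.path.join(hf_folder, index_file), "w") as f:
    json.dump({"weight_map": weight_map}, f)

hf_weights_files = [
    os.path.join(hf_folder, name) for name in [
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
        "consolidated.safetensors",
    ]
]

# Same filtering idea as filter_duplicate_safetensors_files().
with open(os.path.join(hf_folder, index_file)) as f:
    referenced = {
        os.path.join(hf_folder, shard)
        for shard in json.load(f)["weight_map"].values()
    }
print([path for path in hf_weights_files if path in referenced])
# consolidated.safetensors is dropped; the two shards remain.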

View File

@@ -375,6 +375,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
         "gate_proj": ("gate_up_proj", 0),
         "up_proj": ("gate_up_proj", 1),
     }
+    # Mistral/Llama models can also be loaded with --load-format mistral
+    # from consolidated.safetensors checkpoints
+    mistral_mapping = {
+        "layers": "model.layers",
+        "attention": "self_attn",
+        "wq": "q_proj",
+        "wk": "k_proj",
+        "wv": "v_proj",
+        "wo": "o_proj",
+        "attention_norm": "input_layernorm",
+        "feed_forward": "mlp",
+        "w1": "gate_proj",
+        "w2": "down_proj",
+        "w3": "up_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "tok_embeddings": "model.embed_tokens",
+        "output": "lm_head",
+        "norm": "model.norm"
+    }

     def __init__(
         self,
@@ -472,6 +491,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
         ]
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)
+
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name
@@ -549,3 +570,33 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
+
+    # This function is used to remap the mistral format as
+    # used by Mistral and Llama <=2
+    def maybe_remap_mistral(
+            self, name: str,
+            loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:
+
+        def permute(w, n_heads):
+            attn_in = self.config.head_dim * n_heads
+            attn_out = self.config.hidden_size
+
+            return w.view(n_heads, attn_in // n_heads // 2, 2,
+                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)
+
+        mapping = self.mistral_mapping
+        modules = name.split(".")
+
+        # rotary embeds should be sliced
+        if "wk" in modules:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_key_value_heads)
+        elif "wq" in modules:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_attention_heads)
+
+        for item in modules:
+            if item in mapping and mapping[item] not in name:
+                name = name.replace(item, mapping[item])
+
+        return name, loaded_weight
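The remapping is purely name- and layout-based, so it can be illustrated without loading a model. The sketch below applies the same substitution loop as maybe_remap_mistral to a few made-up checkpoint keys; the additional permute step (not shown) reorders the wq/wk weight rows so that the rotary-embedding dimension layout matches what the HF-style Llama attention expects.

mistral_mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "wq": "q_proj",
    "wk": "k_proj",
    "wv": "v_proj",
    "wo": "o_proj",
    "attention_norm": "input_layernorm",
    "feed_forward": "mlp",
    "w1": "gate_proj",
    "w2": "down_proj",
    "w3": "up_proj",
    "ffn_norm": "post_attention_layernorm",
    "tok_embeddings": "model.embed_tokens",
    "output": "lm_head",
    "norm": "model.norm",
}


def remap(name: str) -> str:
    # Same loop as in maybe_remap_mistral, minus the tensor permutation.
    for item in name.split("."):
        if item in mistral_mapping and mistral_mapping[item] not in name:
            name = name.replace(item, mistral_mapping[item])
    return name


print(remap("layers.0.attention.wq.weight"))
# -> model.layers.0.self_attn.q_proj.weight
print(remap("layers.0.feed_forward.w1.weight"))
# -> model.layers.0.mlp.gate_proj.weight
print(remap("tok_embeddings.weight"))
# -> model.embed_tokens.weight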

View File

@@ -1,12 +1,16 @@
 import contextlib
+import enum
+import json
 from pathlib import Path
 from typing import Any, Dict, Optional, Type, Union

+from huggingface_hub import file_exists, hf_hub_download
 from transformers import GenerationConfig, PretrainedConfig
 from transformers.models.auto.image_processing_auto import (
     get_image_processor_config)
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
+from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME

 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
@@ -27,6 +31,8 @@ if VLLM_USE_MODELSCOPE:
 else:
     from transformers import AutoConfig

+MISTRAL_CONFIG_NAME = "params.json"
+
 logger = init_logger(__name__)

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
@@ -53,6 +59,20 @@
     AutoConfig.register(name, cls)


+class ConfigFormat(str, enum.Enum):
+    AUTO = "auto"
+    HF = "hf"
+    MISTRAL = "mistral"
+
+
+def file_or_path_exists(model: Union[str, Path], config_name, revision,
+                        token) -> bool:
+    if Path(model).exists():
+        return (Path(model) / config_name).is_file()
+
+    return file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)
+
+
 def get_config(
     model: Union[str, Path],
     trust_remote_code: bool,
@@ -60,45 +80,68 @@ def get_config(
     code_revision: Optional[str] = None,
     rope_scaling: Optional[dict] = None,
     rope_theta: Optional[float] = None,
+    config_format: ConfigFormat = ConfigFormat.AUTO,
     **kwargs,
 ) -> PretrainedConfig:
     # Separate model folder from file path for GGUF models
     is_gguf = check_gguf_file(model)
     if is_gguf:
         kwargs["gguf_file"] = Path(model).name
         model = Path(model).parent

-    config_dict, _ = PretrainedConfig.get_config_dict(
-        model, revision=revision, code_revision=code_revision, **kwargs)
-
-    # Use custom model class if it's in our registry
-    model_type = config_dict.get("model_type")
-    if model_type in _CONFIG_REGISTRY:
-        config_class = _CONFIG_REGISTRY[model_type]
-        config = config_class.from_pretrained(model,
-                                              revision=revision,
-                                              code_revision=code_revision)
+    if config_format == ConfigFormat.AUTO:
+        if is_gguf or file_or_path_exists(model,
+                                          HF_CONFIG_NAME,
+                                          revision=revision,
+                                          token=kwargs.get("token")):
+            config_format = ConfigFormat.HF
+        elif file_or_path_exists(model,
+                                 MISTRAL_CONFIG_NAME,
+                                 revision=revision,
+                                 token=kwargs.get("token")):
+            config_format = ConfigFormat.MISTRAL
+        else:
+            raise ValueError(f"No supported config format found in {model}")
+
+    if config_format == ConfigFormat.HF:
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model, revision=revision, code_revision=code_revision, **kwargs)
+
+        # Use custom model class if it's in our registry
+        model_type = config_dict.get("model_type")
+        if model_type in _CONFIG_REGISTRY:
+            config_class = _CONFIG_REGISTRY[model_type]
+            config = config_class.from_pretrained(model,
+                                                  revision=revision,
+                                                  code_revision=code_revision)
+        else:
+            try:
+                config = AutoConfig.from_pretrained(
+                    model,
+                    trust_remote_code=trust_remote_code,
+                    revision=revision,
+                    code_revision=code_revision,
+                    **kwargs,
+                )
+            except ValueError as e:
+                if (not trust_remote_code
+                        and "requires you to execute the configuration file"
+                        in str(e)):
+                    err_msg = (
+                        "Failed to load the model config. If the model "
+                        "is a custom model not yet available in the "
+                        "HuggingFace transformers library, consider setting "
+                        "`trust_remote_code=True` in LLM or using the "
+                        "`--trust-remote-code` flag in the CLI.")
+                    raise RuntimeError(err_msg) from e
+                else:
+                    raise e
+
+    elif config_format == ConfigFormat.MISTRAL:
+        config = load_params_config(model, revision)
     else:
-        try:
-            config = AutoConfig.from_pretrained(
-                model,
-                trust_remote_code=trust_remote_code,
-                revision=revision,
-                code_revision=code_revision,
-                **kwargs)
-        except ValueError as e:
-            if (not trust_remote_code
-                    and "requires you to execute the configuration file"
-                    in str(e)):
-                err_msg = (
-                    "Failed to load the model config. If the model is a custom "
-                    "model not yet available in the HuggingFace transformers "
-                    "library, consider setting `trust_remote_code=True` in LLM "
-                    "or using the `--trust-remote-code` flag in the CLI.")
-                raise RuntimeError(err_msg) from e
-            else:
-                raise e
+        raise ValueError(f"Unsupported config format: {config_format}")

     # Special architecture mapping check for GGUF models
     if is_gguf:
@@ -108,16 +151,70 @@ def get_config(
         model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
         config.update({"architectures": [model_type]})

-    for key, value in [("rope_scaling", rope_scaling),
-                       ("rope_theta", rope_theta)]:
+    for key, value in [
+        ("rope_scaling", rope_scaling),
+        ("rope_theta", rope_theta),
+    ]:
         if value is not None:
-            logger.info("Updating %s from %r to %r", key,
-                        getattr(config, key, None), value)
+            logger.info(
+                "Updating %s from %r to %r",
+                key,
+                getattr(config, key, None),
+                value,
+            )
             config.update({key: value})

     return config


+def load_params_config(model, revision) -> PretrainedConfig:
+    # This function loads a params.json config which
+    # should be used when loading models in mistral format
+
+    config_file_name = "params.json"
+
+    config_path = Path(model) / config_file_name
+
+    if not config_path.is_file():
+        config_path = Path(
+            hf_hub_download(model, config_file_name, revision=revision))
+
+    with open(config_path, "r") as file:
+        config_dict = json.load(file)
+
+    config_mapping = {
+        "dim": "hidden_size",
+        "norm_eps": "rms_norm_eps",
+        "n_kv_heads": "num_key_value_heads",
+        "n_layers": "num_hidden_layers",
+        "n_heads": "num_attention_heads",
+        "hidden_dim": "intermediate_size",
+    }
+
+    def recurse_elems(elem: Any):
+        if isinstance(elem, dict):
+            config_dict = {}
+            for key, value in elem.items():
+                key = config_mapping.get(key, key)
+                config_dict[key] = recurse_elems(value)
+            return PretrainedConfig(**config_dict)
+        else:
+            return elem
+
+    config_dict["model_type"] = config_dict.get("model_type", "transformer")
+    config_dict["hidden_act"] = config_dict.get("activation", "silu")
+    config_dict["tie_word_embeddings"] = config_dict.get(
+        "tie_embeddings", False)
+
+    if config_dict["model_type"] == "transformer":
+        if "moe" in config_dict:
+            config_dict["architectures"] = ["MixtralForCausalLM"]
+        else:
+            config_dict["architectures"] = ["MistralForCausalLM"]
+
+    return recurse_elems(config_dict)
+
+
 def get_hf_image_processor_config(
     model: Union[str, Path],
     revision: Optional[str] = None,
@@ -134,7 +231,7 @@ def get_hf_image_processor_config(
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
-        No op for pure text models.
+    No op for pure text models.
     """

     if hasattr(config, "text_config"):
         # The code operates under the assumption that text_config should have
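To make the params.json translation above concrete, here is a toy version of the same key mapping applied to a hand-written, truncated params.json-style dict. The values are illustrative; real files can also contain nested sections (for example a moe block), which load_params_config handles via recurse_elems.

from transformers import PretrainedConfig

params = {
    "dim": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "n_kv_heads": 8,
    "hidden_dim": 14336,
    "norm_eps": 1e-05,
    "vocab_size": 32768,
    "rope_theta": 1000000.0,
}

config_mapping = {
    "dim": "hidden_size",
    "norm_eps": "rms_norm_eps",
    "n_kv_heads": "num_key_value_heads",
    "n_layers": "num_hidden_layers",
    "n_heads": "num_attention_heads",
    "hidden_dim": "intermediate_size",
}

# Rename Mistral-native keys to their HF equivalents and fill the defaults
# load_params_config applies before building the PretrainedConfig.
hf_style = {config_mapping.get(key, key): value
            for key, value in params.items()}
hf_style.setdefault("model_type", "transformer")
hf_style.setdefault("hidden_act", "silu")
hf_style.setdefault("tie_word_embeddings", False)
hf_style["architectures"] = (["MixtralForCausalLM"]
                             if "moe" in params else ["MistralForCausalLM"])

config = PretrainedConfig(**hf_style)
print(config.hidden_size, config.num_key_value_heads, config.architectures)
# 4096 8 ['MistralForCausalLM']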