[Model] Allow loading from original Mistral format (#8168)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
parent 23f322297f
commit 29f49cd6e3
@@ -41,3 +41,43 @@ def test_models(
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS[1:])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_mistral_format(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    with vllm_runner(
            model,
            dtype=dtype,
            tokenizer_mode="auto",
            load_format="safetensors",
            config_format="hf",
    ) as hf_format_model:
        hf_format_outputs = hf_format_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(
            model,
            dtype=dtype,
            tokenizer_mode="mistral",
            load_format="mistral",
            config_format="mistral",
    ) as mistral_format_model:
        mistral_format_outputs = mistral_format_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_format_outputs,
        outputs_1_lst=mistral_format_outputs,
        name_0="hf",
        name_1="mistral",
    )
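For orientation, the user-facing effect of the three options exercised above looks roughly like this; a minimal sketch, assuming a checkpoint distributed in the original Mistral layout (params.json plus consolidated*.safetensors), such as mistralai/Mistral-7B-Instruct-v0.3:

from vllm import LLM, SamplingParams

# Sketch only: the model id is just an example; any repository that ships
# the original Mistral files should work the same way.
llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.3",
    tokenizer_mode="mistral",   # tokenizer from the Mistral tokenizer files
    load_format="mistral",      # weights from consolidated*.safetensors
    config_format="mistral",    # config built from params.json
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)

The same three options are exposed on the command line as --tokenizer-mode, --load-format and the new --config-format argument added below.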
@@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (get_config,
from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                            get_hf_image_processor_config,
                                            get_hf_text_config)
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
@@ -121,35 +121,37 @@ class ModelConfig:
            override default neuron config that are specific to Neuron devices,
            this argument will be used to configure the neuron config that
            can not be gathered from the vllm arguments.
        config_format: The config format which shall be loaded.
            Defaults to 'auto' which defaults to 'hf'.
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        rope_scaling: Optional[dict] = None,
        rope_theta: Optional[float] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        spec_target_max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        quantization_param_path: Optional[str] = None,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: Optional[int] = None,
        max_logprobs: int = 20,
        disable_sliding_window: bool = False,
        skip_tokenizer_init: bool = False,
        served_model_name: Optional[Union[str, List[str]]] = None,
        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
        use_async_output_proc: bool = True,
        override_neuron_config: Optional[Dict[str, Any]] = None) -> None:
    def __init__(self,
                 model: str,
                 tokenizer: str,
                 tokenizer_mode: str,
                 trust_remote_code: bool,
                 dtype: Union[str, torch.dtype],
                 seed: int,
                 revision: Optional[str] = None,
                 code_revision: Optional[str] = None,
                 rope_scaling: Optional[dict] = None,
                 rope_theta: Optional[float] = None,
                 tokenizer_revision: Optional[str] = None,
                 max_model_len: Optional[int] = None,
                 spec_target_max_model_len: Optional[int] = None,
                 quantization: Optional[str] = None,
                 quantization_param_path: Optional[str] = None,
                 enforce_eager: Optional[bool] = None,
                 max_context_len_to_capture: Optional[int] = None,
                 max_seq_len_to_capture: Optional[int] = None,
                 max_logprobs: int = 20,
                 disable_sliding_window: bool = False,
                 skip_tokenizer_init: bool = False,
                 served_model_name: Optional[Union[str, List[str]]] = None,
                 limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
                 use_async_output_proc: bool = True,
                 override_neuron_config: Optional[Dict[str, Any]] = None,
                 config_format: ConfigFormat = ConfigFormat.AUTO) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
@@ -176,7 +178,8 @@ class ModelConfig:
        self.skip_tokenizer_init = skip_tokenizer_init

        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision, rope_scaling, rope_theta)
                                    code_revision, rope_scaling, rope_theta,
                                    config_format)
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.hf_image_processor_config = get_hf_image_processor_config(
            self.model, revision)
@@ -746,6 +749,7 @@ class LoadFormat(str, enum.Enum):
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"
    MISTRAL = "mistral"


@dataclass
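In short, ModelConfig just threads the new option through to get_config. A condensed sketch of that call path (illustrative only; the model id is a placeholder):

from vllm.transformers_utils.config import ConfigFormat, get_config

# ConfigFormat.AUTO picks the HF path when config.json is available and
# otherwise falls back to the Mistral params.json; passing MISTRAL
# explicitly skips the detection.
hf_config = get_config("mistralai/Mistral-7B-Instruct-v0.3",
                       trust_remote_code=False,
                       config_format=ConfigFormat.MISTRAL)
print(type(hf_config).__name__)  # a transformers PretrainedConfig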
@@ -8,10 +8,10 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
import torch

import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                         EngineConfig, LoadConfig, LoadFormat, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig,
from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
                         DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
                         LoRAConfig, ModelConfig, ObservabilityConfig,
                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
                         SpeculativeConfig, TokenizerPoolConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
@@ -65,6 +65,7 @@ class EngineArgs:
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    config_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    quantization_param_path: Optional[str] = None
@@ -234,6 +235,13 @@ class EngineArgs:
            'section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
        parser.add_argument(
            '--dtype',
            type=str,
@@ -813,7 +821,10 @@ class EngineArgs:
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            use_async_output_proc=not self.disable_async_output_proc,
            override_neuron_config=self.override_neuron_config)
            override_neuron_config=self.override_neuron_config,
            config_format=self.config_format,
        )

        cache_config = CacheConfig(
            block_size=self.block_size if self.device != "neuron" else
            self.max_model_len,  # neuron needs block_size = max_model_len
@@ -17,6 +17,7 @@ import torch
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
                         LoRAConfig, ModelConfig, MultiModalConfig,
@@ -241,12 +242,17 @@ class DefaultModelLoader(BaseModelLoader):
        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME
        # Some quantized models use .pt files for storing the weights.
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.MISTRAL:
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
@@ -284,10 +290,10 @@ class DefaultModelLoader(BaseModelLoader):
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path, self.load_config.download_dir,
                    revision)
                    model_name_or_path, index_file,
                    self.load_config.download_dir, revision)
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder)
                hf_weights_files, hf_folder, index_file)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)
@@ -16,7 +16,6 @@ import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import LoadConfig, ModelConfig
from vllm.distributed import get_tensor_model_parallel_rank
@@ -251,6 +250,7 @@ def download_weights_from_hf(

def download_safetensors_index_file_from_hf(
        model_name_or_path: str,
        index_file: str,
        cache_dir: Optional[str],
        revision: Optional[str] = None,
) -> None:
@@ -269,36 +269,37 @@ def download_safetensors_index_file_from_hf(
        # Download the safetensors index file.
        hf_hub_download(
            repo_id=model_name_or_path,
            filename=SAFE_WEIGHTS_INDEX_NAME,
            filename=index_file,
            cache_dir=cache_dir,
            revision=revision,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        )
    # If file not found on remote or locally, we should not fail since
    # only some models will have SAFE_WEIGHTS_INDEX_NAME.
    # only some models will have index_file.
    except huggingface_hub.utils.EntryNotFoundError:
        logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME)
        logger.info("No %s found in remote.", index_file)
    except huggingface_hub.utils.LocalEntryNotFoundError:
        logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME)
        logger.info("No %s found in local cache.", index_file)


# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# So, we use the SAFE_WEIGHTS_INDEX_NAME to
# So, we use the index_file to
# look up which safetensors files should be used.
def filter_duplicate_safetensors_files(hf_weights_files: List[str],
                                       hf_folder: str) -> List[str]:
                                       hf_folder: str,
                                       index_file: str) -> List[str]:
    # model.safetensors.index.json is a mapping from keys in the
    # torch state_dict to safetensors file holding that weight.
    index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
    index_file_name = os.path.join(hf_folder, index_file)
    if not os.path.isfile(index_file_name):
        return hf_weights_files

    # Iterate through the weight_map (weight_name: safetensors files)
    # to identify weights that we should use.
    with open(index_file_name) as index_file:
        weight_map = json.load(index_file)["weight_map"]
    with open(index_file_name, "r") as f:
        weight_map = json.load(f)["weight_map"]
    weight_files_in_index = set()
    for weight_name in weight_map:
        weight_files_in_index.add(
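For illustration, consolidated.safetensors.index.json is read the same way as model.safetensors.index.json: a "weight_map" from parameter names to the shard that stores them. A toy example (file and tensor names invented):

index = {
    "weight_map": {
        "tok_embeddings.weight": "consolidated-00001-of-00002.safetensors",
        "layers.0.attention.wq.weight": "consolidated-00001-of-00002.safetensors",
        "output.weight": "consolidated-00002-of-00002.safetensors",
    }
}
# filter_duplicate_safetensors_files keeps only the files referenced in
# weight_map, so a repository that also ships sharded model-*.safetensors
# files does not feed duplicate weights to the loader.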
@@ -375,6 +375,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }
    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm"
    }

    def __init__(
        self,
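A few concrete renames this mapping produces via maybe_remap_mistral (added below); the names on the left are what a consolidated checkpoint uses, the names on the right are the HF-style names the rest of load_weights expects (illustrative, not exhaustive):

consolidated_to_vllm = {
    "layers.0.attention.wq.weight": "model.layers.0.self_attn.q_proj.weight",
    "layers.0.feed_forward.w1.weight": "model.layers.0.mlp.gate_proj.weight",
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "output.weight": "lm_head.weight",
    "norm.weight": "model.norm.weight",
}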
@@ -472,6 +491,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)

            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
@@ -549,3 +570,33 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
        else:
            raise RuntimeError("Self attention has no KV cache scaling "
                               "factor attribute!")

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
            self, name: str,
            loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:

        def permute(w, n_heads):
            attn_in = self.config.head_dim * n_heads
            attn_out = self.config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        for item in modules:
            if item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight
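The permute helper reorders each attention head's rows from the interleaved rotary layout used by the original checkpoints to the half-split layout the HF-style weights assume, without changing the tensor's shape. A quick self-contained shape check under assumed toy sizes (4 heads, head_dim 8, hidden_size 32):

import torch

n_heads, head_dim, hidden_size = 4, 8, 32   # assumed toy dimensions
attn_in = n_heads * head_dim
w = torch.arange(attn_in * hidden_size,
                 dtype=torch.float32).view(attn_in, hidden_size)

permuted = w.view(n_heads, attn_in // n_heads // 2, 2,
                  hidden_size).transpose(1, 2).reshape(attn_in, hidden_size)

# Shape is unchanged; within each head the rows are regrouped from the
# interleaved order (0, 1, 2, 3, ...) to the half-split order (0, 2, 4, 6, 1, 3, 5, 7).
assert permuted.shape == w.shape
assert torch.equal(permuted[1], w[2])   # head 0: new row 1 == old row 2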
@@ -1,12 +1,16 @@
import contextlib
import enum
import json
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union

from huggingface_hub import file_exists, hf_hub_download
from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import (
    get_image_processor_config)
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME

from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
@@ -27,6 +31,8 @@ if VLLM_USE_MODELSCOPE:
else:
    from transformers import AutoConfig

MISTRAL_CONFIG_NAME = "params.json"

logger = init_logger(__name__)

_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
@@ -53,6 +59,20 @@ for name, cls in _CONFIG_REGISTRY.items():
    AutoConfig.register(name, cls)


class ConfigFormat(str, enum.Enum):
    AUTO = "auto"
    HF = "hf"
    MISTRAL = "mistral"


def file_or_path_exists(model: Union[str, Path], config_name, revision,
                        token) -> bool:
    if Path(model).exists():
        return (Path(model) / config_name).is_file()

    return file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)


def get_config(
    model: Union[str, Path],
    trust_remote_code: bool,
@@ -60,45 +80,68 @@ def get_config(
    code_revision: Optional[str] = None,
    rope_scaling: Optional[dict] = None,
    rope_theta: Optional[float] = None,
    config_format: ConfigFormat = ConfigFormat.AUTO,
    **kwargs,
) -> PretrainedConfig:

    # Separate model folder from file path for GGUF models

    is_gguf = check_gguf_file(model)
    if is_gguf:
        kwargs["gguf_file"] = Path(model).name
        model = Path(model).parent

    config_dict, _ = PretrainedConfig.get_config_dict(
        model, revision=revision, code_revision=code_revision, **kwargs)
    if config_format == ConfigFormat.AUTO:
        if is_gguf or file_or_path_exists(model,
                                          HF_CONFIG_NAME,
                                          revision=revision,
                                          token=kwargs.get("token")):
            config_format = ConfigFormat.HF
        elif file_or_path_exists(model,
                                 MISTRAL_CONFIG_NAME,
                                 revision=revision,
                                 token=kwargs.get("token")):
            config_format = ConfigFormat.MISTRAL
        else:
            raise ValueError(f"No supported config format found in {model}")

    # Use custom model class if it's in our registry
    model_type = config_dict.get("model_type")
    if model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[model_type]
        config = config_class.from_pretrained(model,
                                              revision=revision,
                                              code_revision=code_revision)
    if config_format == ConfigFormat.HF:
        config_dict, _ = PretrainedConfig.get_config_dict(
            model, revision=revision, code_revision=code_revision, **kwargs)

        # Use custom model class if it's in our registry
        model_type = config_dict.get("model_type")
        if model_type in _CONFIG_REGISTRY:
            config_class = _CONFIG_REGISTRY[model_type]
            config = config_class.from_pretrained(model,
                                                  revision=revision,
                                                  code_revision=code_revision)
        else:
            try:
                config = AutoConfig.from_pretrained(
                    model,
                    trust_remote_code=trust_remote_code,
                    revision=revision,
                    code_revision=code_revision,
                    **kwargs,
                )
            except ValueError as e:
                if (not trust_remote_code
                        and "requires you to execute the configuration file"
                        in str(e)):
                    err_msg = (
                        "Failed to load the model config. If the model "
                        "is a custom model not yet available in the "
                        "HuggingFace transformers library, consider setting "
                        "`trust_remote_code=True` in LLM or using the "
                        "`--trust-remote-code` flag in the CLI.")
                    raise RuntimeError(err_msg) from e
                else:
                    raise e

    elif config_format == ConfigFormat.MISTRAL:
        config = load_params_config(model, revision)
    else:
        try:
            config = AutoConfig.from_pretrained(
                model,
                trust_remote_code=trust_remote_code,
                revision=revision,
                code_revision=code_revision,
                **kwargs)
        except ValueError as e:
            if (not trust_remote_code
                    and "requires you to execute the configuration file"
                    in str(e)):
                err_msg = (
                    "Failed to load the model config. If the model is a custom "
                    "model not yet available in the HuggingFace transformers "
                    "library, consider setting `trust_remote_code=True` in LLM "
                    "or using the `--trust-remote-code` flag in the CLI.")
                raise RuntimeError(err_msg) from e
            else:
                raise e
        raise ValueError(f"Unsupported config format: {config_format}")

    # Special architecture mapping check for GGUF models
    if is_gguf:
@@ -108,16 +151,70 @@ def get_config(
        model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
        config.update({"architectures": [model_type]})

    for key, value in [("rope_scaling", rope_scaling),
                       ("rope_theta", rope_theta)]:
    for key, value in [
            ("rope_scaling", rope_scaling),
            ("rope_theta", rope_theta),
    ]:
        if value is not None:
            logger.info("Updating %s from %r to %r", key,
                        getattr(config, key, None), value)
            logger.info(
                "Updating %s from %r to %r",
                key,
                getattr(config, key, None),
                value,
            )
            config.update({key: value})

    return config


def load_params_config(model, revision) -> PretrainedConfig:
    # This function loads a params.json config which
    # should be used when loading models in mistral format

    config_file_name = "params.json"

    config_path = Path(model) / config_file_name

    if not config_path.is_file():
        config_path = Path(
            hf_hub_download(model, config_file_name, revision=revision))

    with open(config_path, "r") as file:
        config_dict = json.load(file)

    config_mapping = {
        "dim": "hidden_size",
        "norm_eps": "rms_norm_eps",
        "n_kv_heads": "num_key_value_heads",
        "n_layers": "num_hidden_layers",
        "n_heads": "num_attention_heads",
        "hidden_dim": "intermediate_size",
    }

    def recurse_elems(elem: Any):
        if isinstance(elem, dict):
            config_dict = {}
            for key, value in elem.items():
                key = config_mapping.get(key, key)
                config_dict[key] = recurse_elems(value)
            return PretrainedConfig(**config_dict)
        else:
            return elem

    config_dict["model_type"] = config_dict.get("model_type", "transformer")
    config_dict["hidden_act"] = config_dict.get("activation", "silu")
    config_dict["tie_word_embeddings"] = config_dict.get(
        "tie_embeddings", False)

    if config_dict["model_type"] == "transformer":
        if "moe" in config_dict:
            config_dict["architectures"] = ["MixtralForCausalLM"]
        else:
            config_dict["architectures"] = ["MistralForCausalLM"]

    return recurse_elems(config_dict)


def get_hf_image_processor_config(
    model: Union[str, Path],
    revision: Optional[str] = None,
@@ -134,7 +231,7 @@ get_hf_image_processor_config(

def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to llm for multi modal models.
    No op for pure text models.
    No op for pure text models.
    """
    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
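To make load_params_config concrete, here is a toy params.json (values invented) next to the HF-style attribute names it gets mapped to:

params = {
    "dim": 4096,          # becomes hidden_size
    "n_layers": 32,       # becomes num_hidden_layers
    "n_heads": 32,        # becomes num_attention_heads
    "n_kv_heads": 8,      # becomes num_key_value_heads
    "hidden_dim": 14336,  # becomes intermediate_size
    "norm_eps": 1e-05,    # becomes rms_norm_eps
}
# With no "moe" entry the resulting PretrainedConfig gets
# architectures=["MistralForCausalLM"]; hidden_act defaults to "silu" and
# tie_word_embeddings to False, mirroring the defaults applied above.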