[Model] Allow loading from original Mistral format (#8168)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
parent 23f322297f
commit 29f49cd6e3
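Note: the snippet below is a usage sketch added for this write-up, not part of the commit. It shows how the options introduced here (config_format="mistral" together with load_format="mistral" and tokenizer_mode="mistral", or the new --config-format CLI flag) might be used to load a checkpoint that ships params.json and consolidated*.safetensors instead of an HF-style config.json. The model name is a placeholder, and it assumes the LLM entrypoint forwards these engine arguments.

# Usage sketch (assumed API surface; model name is a placeholder).
from vllm import LLM, SamplingParams

llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.3",  # placeholder repo with params.json + consolidated*.safetensors
    tokenizer_mode="mistral",  # use the mistral tokenizer files from the repo
    config_format="mistral",   # read params.json rather than config.json
    load_format="mistral",     # load consolidated*.safetensors weights
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)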
@@ -41,3 +41,43 @@ def test_models(
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize("model", MODELS[1:])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_mistral_format(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            tokenizer_mode="auto",
+            load_format="safetensors",
+            config_format="hf",
+    ) as hf_format_model:
+        hf_format_outputs = hf_format_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            tokenizer_mode="mistral",
+            load_format="mistral",
+            config_format="mistral",
+    ) as mistral_format_model:
+        mistral_format_outputs = mistral_format_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    check_logprobs_close(
+        outputs_0_lst=hf_format_outputs,
+        outputs_1_lst=mistral_format_outputs,
+        name_0="hf",
+        name_1="mistral",
+    )
@@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
-from vllm.transformers_utils.config import (get_config,
+from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                             get_hf_image_processor_config,
                                             get_hf_text_config)
 from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
@@ -121,10 +121,11 @@ class ModelConfig:
             override default neuron config that are specific to Neuron devices,
             this argument will be used to configure the neuron config that
             can not be gathered from the vllm arguments.
+        config_format: The config format which shall be loaded.
+            Defaults to 'auto' which defaults to 'hf'.
     """
 
-    def __init__(
-        self,
+    def __init__(self,
                  model: str,
                  tokenizer: str,
                  tokenizer_mode: str,
@@ -149,7 +150,8 @@ class ModelConfig:
                 served_model_name: Optional[Union[str, List[str]]] = None,
                 limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
                 use_async_output_proc: bool = True,
-                override_neuron_config: Optional[Dict[str, Any]] = None) -> None:
+                override_neuron_config: Optional[Dict[str, Any]] = None,
+                config_format: ConfigFormat = ConfigFormat.AUTO) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -176,7 +178,8 @@ class ModelConfig:
         self.skip_tokenizer_init = skip_tokenizer_init
 
         self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision, rope_scaling, rope_theta)
+                                    code_revision, rope_scaling, rope_theta,
+                                    config_format)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.hf_image_processor_config = get_hf_image_processor_config(
             self.model, revision)
@@ -746,6 +749,7 @@ class LoadFormat(str, enum.Enum):
     SHARDED_STATE = "sharded_state"
     GGUF = "gguf"
     BITSANDBYTES = "bitsandbytes"
+    MISTRAL = "mistral"
 
 
 @dataclass
@@ -8,10 +8,10 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
 import torch
 
 import vllm.envs as envs
-from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
-                         EngineConfig, LoadConfig, LoadFormat, LoRAConfig,
-                         ModelConfig, ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
+from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
+                         DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
+                         LoRAConfig, ModelConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig, TokenizerPoolConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
@@ -65,6 +65,7 @@ class EngineArgs:
     trust_remote_code: bool = False
     download_dir: Optional[str] = None
     load_format: str = 'auto'
+    config_format: str = 'auto'
     dtype: str = 'auto'
     kv_cache_dtype: str = 'auto'
     quantization_param_path: Optional[str] = None
@@ -234,6 +235,13 @@ class EngineArgs:
             'section for more information.\n'
             '* "bitsandbytes" will load the weights using bitsandbytes '
             'quantization.\n')
+        parser.add_argument(
+            '--config-format',
+            default=EngineArgs.config_format,
+            choices=[f.value for f in ConfigFormat],
+            help='The format of the model config to load.\n\n'
+            '* "auto" will try to load the config in hf format '
+            'if available else it will try to load in mistral format ')
         parser.add_argument(
             '--dtype',
             type=str,
@@ -813,7 +821,10 @@ class EngineArgs:
             served_model_name=self.served_model_name,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             use_async_output_proc=not self.disable_async_output_proc,
-            override_neuron_config=self.override_neuron_config)
+            override_neuron_config=self.override_neuron_config,
+            config_format=self.config_format,
+        )
+
         cache_config = CacheConfig(
             block_size=self.block_size if self.device != "neuron" else
             self.max_model_len,  # neuron needs block_size = max_model_len
@@ -17,6 +17,7 @@ import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
 from transformers import AutoModelForCausalLM, PretrainedConfig
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
                          LoRAConfig, ModelConfig, MultiModalConfig,
@@ -241,12 +242,17 @@ class DefaultModelLoader(BaseModelLoader):
         is_local = os.path.isdir(model_name_or_path)
         load_format = self.load_config.load_format
         use_safetensors = False
+        index_file = SAFE_WEIGHTS_INDEX_NAME
         # Some quantized models use .pt files for storing the weights.
         if load_format == LoadFormat.AUTO:
             allow_patterns = ["*.safetensors", "*.bin"]
         elif load_format == LoadFormat.SAFETENSORS:
             use_safetensors = True
             allow_patterns = ["*.safetensors"]
+        elif load_format == LoadFormat.MISTRAL:
+            use_safetensors = True
+            allow_patterns = ["consolidated*.safetensors"]
+            index_file = "consolidated.safetensors.index.json"
         elif load_format == LoadFormat.PT:
             allow_patterns = ["*.pt"]
         elif load_format == LoadFormat.NPCACHE:
@@ -284,10 +290,10 @@ class DefaultModelLoader(BaseModelLoader):
             # any files not found in the index.
             if not is_local:
                 download_safetensors_index_file_from_hf(
-                    model_name_or_path, self.load_config.download_dir,
-                    revision)
+                    model_name_or_path, index_file,
+                    self.load_config.download_dir, revision)
             hf_weights_files = filter_duplicate_safetensors_files(
-                hf_weights_files, hf_folder)
+                hf_weights_files, hf_folder, index_file)
         else:
             hf_weights_files = filter_files_not_needed_for_inference(
                 hf_weights_files)
@@ -16,7 +16,6 @@ import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
-from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from vllm.config import LoadConfig, ModelConfig
 from vllm.distributed import get_tensor_model_parallel_rank
@@ -251,6 +250,7 @@ def download_weights_from_hf(
 
 def download_safetensors_index_file_from_hf(
     model_name_or_path: str,
+    index_file: str,
     cache_dir: Optional[str],
     revision: Optional[str] = None,
 ) -> None:
@@ -269,36 +269,37 @@ def download_safetensors_index_file_from_hf(
         # Download the safetensors index file.
         hf_hub_download(
             repo_id=model_name_or_path,
-            filename=SAFE_WEIGHTS_INDEX_NAME,
+            filename=index_file,
             cache_dir=cache_dir,
             revision=revision,
             local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
         )
     # If file not found on remote or locally, we should not fail since
-    # only some models will have SAFE_WEIGHTS_INDEX_NAME.
+    # only some models will have index_file.
     except huggingface_hub.utils.EntryNotFoundError:
-        logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME)
+        logger.info("No %s found in remote.", index_file)
     except huggingface_hub.utils.LocalEntryNotFoundError:
-        logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME)
+        logger.info("No %s found in local cache.", index_file)
 
 
 # For models like Mistral-7B-v0.3, there are both sharded
 # safetensors files and a consolidated safetensors file.
 # Passing both of these to the weight loader functionality breaks.
-# So, we use the SAFE_WEIGHTS_INDEX_NAME to
+# So, we use the index_file to
 # look up which safetensors files should be used.
 def filter_duplicate_safetensors_files(hf_weights_files: List[str],
-                                       hf_folder: str) -> List[str]:
+                                       hf_folder: str,
+                                       index_file: str) -> List[str]:
     # model.safetensors.index.json is a mapping from keys in the
     # torch state_dict to safetensors file holding that weight.
-    index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
+    index_file_name = os.path.join(hf_folder, index_file)
     if not os.path.isfile(index_file_name):
         return hf_weights_files
 
     # Iterate through the weight_map (weight_name: safetensors files)
     # to identify weights that we should use.
-    with open(index_file_name) as index_file:
-        weight_map = json.load(index_file)["weight_map"]
+    with open(index_file_name, "r") as f:
+        weight_map = json.load(f)["weight_map"]
     weight_files_in_index = set()
     for weight_name in weight_map:
         weight_files_in_index.add(
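Note: the following is an illustrative, standalone sketch (not part of the commit) of the index-based filtering that filter_duplicate_safetensors_files performs in the hunk above; the helper name keep_indexed_files is made up for the example.

import json
import os
from typing import List


def keep_indexed_files(weight_files: List[str], folder: str,
                       index_file: str) -> List[str]:
    # If there is no index file in the folder, keep everything.
    index_path = os.path.join(folder, index_file)
    if not os.path.isfile(index_path):
        return weight_files
    # Otherwise keep only the shards that the index's weight_map refers to,
    # so duplicate (e.g. sharded vs. consolidated) files are not loaded twice.
    with open(index_path) as fp:
        weight_map = json.load(fp)["weight_map"]
    indexed = {os.path.join(folder, shard) for shard in weight_map.values()}
    return [path for path in weight_files if path in indexed]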
@@ -375,6 +375,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
         "gate_proj": ("gate_up_proj", 0),
         "up_proj": ("gate_up_proj", 1),
     }
+    # Mistral/Llama models can also be loaded with --load-format mistral
+    # from consolidated.safetensors checkpoints
+    mistral_mapping = {
+        "layers": "model.layers",
+        "attention": "self_attn",
+        "wq": "q_proj",
+        "wk": "k_proj",
+        "wv": "v_proj",
+        "wo": "o_proj",
+        "attention_norm": "input_layernorm",
+        "feed_forward": "mlp",
+        "w1": "gate_proj",
+        "w2": "down_proj",
+        "w3": "up_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "tok_embeddings": "model.embed_tokens",
+        "output": "lm_head",
+        "norm": "model.norm"
+    }
 
     def __init__(
         self,
@@ -472,6 +491,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
         ]
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)
+
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name
@@ -549,3 +570,33 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
         else:
             raise RuntimeError("Self attention has no KV cache scaling "
                                "factor attribute!")
+
+    # This function is used to remap the mistral format as
+    # used by Mistral and Llama <=2
+    def maybe_remap_mistral(
+            self, name: str,
+            loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:
+
+        def permute(w, n_heads):
+            attn_in = self.config.head_dim * n_heads
+            attn_out = self.config.hidden_size
+
+            return w.view(n_heads, attn_in // n_heads // 2, 2,
+                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)
+
+        mapping = self.mistral_mapping
+        modules = name.split(".")
+
+        # rotary embeds should be sliced
+        if "wk" in modules:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_key_value_heads)
+        elif "wq" in modules:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_attention_heads)
+
+        for item in modules:
+            if item in mapping and mapping[item] not in name:
+                name = name.replace(item, mapping[item])
+
+        return name, loaded_weight
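Note: a standalone illustration (not part of the commit) of how the mistral_mapping and maybe_remap_mistral logic in the hunk above translate consolidated-checkpoint weight names into the HF/vLLM naming scheme; the wq/wk permutation step is omitted and the names MISTRAL_TO_HF and remap_name are made up for the example.

# The mapping values are copied from mistral_mapping in the diff above.
MISTRAL_TO_HF = {
    "layers": "model.layers",
    "attention": "self_attn",
    "wq": "q_proj",
    "wk": "k_proj",
    "wv": "v_proj",
    "wo": "o_proj",
    "attention_norm": "input_layernorm",
    "feed_forward": "mlp",
    "w1": "gate_proj",
    "w2": "down_proj",
    "w3": "up_proj",
    "ffn_norm": "post_attention_layernorm",
    "tok_embeddings": "model.embed_tokens",
    "output": "lm_head",
    "norm": "model.norm",
}


def remap_name(name: str) -> str:
    # Replace each dotted component that has a mapping, mirroring the loop
    # in maybe_remap_mistral (the weight tensors themselves are untouched here).
    for part in name.split("."):
        target = MISTRAL_TO_HF.get(part)
        if target is not None and target not in name:
            name = name.replace(part, target)
    return name


assert (remap_name("layers.0.attention.wq.weight") ==
        "model.layers.0.self_attn.q_proj.weight")
assert remap_name("tok_embeddings.weight") == "model.embed_tokens.weight"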
@@ -1,12 +1,16 @@
 import contextlib
+import enum
+import json
 from pathlib import Path
 from typing import Any, Dict, Optional, Type, Union
 
+from huggingface_hub import file_exists, hf_hub_download
 from transformers import GenerationConfig, PretrainedConfig
 from transformers.models.auto.image_processing_auto import (
     get_image_processor_config)
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
+from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
 
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
@@ -27,6 +31,8 @@ if VLLM_USE_MODELSCOPE:
 else:
     from transformers import AutoConfig
 
+MISTRAL_CONFIG_NAME = "params.json"
+
 logger = init_logger(__name__)
 
 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
@@ -53,6 +59,20 @@ for name, cls in _CONFIG_REGISTRY.items():
     AutoConfig.register(name, cls)
 
 
+class ConfigFormat(str, enum.Enum):
+    AUTO = "auto"
+    HF = "hf"
+    MISTRAL = "mistral"
+
+
+def file_or_path_exists(model: Union[str, Path], config_name, revision,
+                        token) -> bool:
+    if Path(model).exists():
+        return (Path(model) / config_name).is_file()
+
+    return file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)
+
+
 def get_config(
     model: Union[str, Path],
     trust_remote_code: bool,
@@ -60,15 +80,31 @@ def get_config(
     code_revision: Optional[str] = None,
     rope_scaling: Optional[dict] = None,
     rope_theta: Optional[float] = None,
+    config_format: ConfigFormat = ConfigFormat.AUTO,
     **kwargs,
 ) -> PretrainedConfig:
 
     # Separate model folder from file path for GGUF models
 
     is_gguf = check_gguf_file(model)
     if is_gguf:
         kwargs["gguf_file"] = Path(model).name
         model = Path(model).parent
 
+    if config_format == ConfigFormat.AUTO:
+        if is_gguf or file_or_path_exists(model,
+                                          HF_CONFIG_NAME,
+                                          revision=revision,
+                                          token=kwargs.get("token")):
+            config_format = ConfigFormat.HF
+        elif file_or_path_exists(model,
+                                 MISTRAL_CONFIG_NAME,
+                                 revision=revision,
+                                 token=kwargs.get("token")):
+            config_format = ConfigFormat.MISTRAL
+        else:
+            raise ValueError(f"No supported config format found in {model}")
+
+    if config_format == ConfigFormat.HF:
         config_dict, _ = PretrainedConfig.get_config_dict(
             model, revision=revision, code_revision=code_revision, **kwargs)
@@ -86,20 +122,27 @@ def get_config(
                 trust_remote_code=trust_remote_code,
                 revision=revision,
                 code_revision=code_revision,
-                **kwargs)
+                **kwargs,
+            )
         except ValueError as e:
             if (not trust_remote_code
                     and "requires you to execute the configuration file"
                     in str(e)):
                 err_msg = (
-                    "Failed to load the model config. If the model is a custom "
-                    "model not yet available in the HuggingFace transformers "
-                    "library, consider setting `trust_remote_code=True` in LLM "
-                    "or using the `--trust-remote-code` flag in the CLI.")
+                    "Failed to load the model config. If the model "
+                    "is a custom model not yet available in the "
+                    "HuggingFace transformers library, consider setting "
+                    "`trust_remote_code=True` in LLM or using the "
+                    "`--trust-remote-code` flag in the CLI.")
                 raise RuntimeError(err_msg) from e
             else:
                 raise e
 
+    elif config_format == ConfigFormat.MISTRAL:
+        config = load_params_config(model, revision)
+    else:
+        raise ValueError(f"Unsupported config format: {config_format}")
+
     # Special architecture mapping check for GGUF models
     if is_gguf:
         if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
@@ -108,16 +151,70 @@ def get_config(
             model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
             config.update({"architectures": [model_type]})
 
-    for key, value in [("rope_scaling", rope_scaling),
-                       ("rope_theta", rope_theta)]:
+    for key, value in [
+        ("rope_scaling", rope_scaling),
+        ("rope_theta", rope_theta),
+    ]:
         if value is not None:
-            logger.info("Updating %s from %r to %r", key,
-                        getattr(config, key, None), value)
+            logger.info(
+                "Updating %s from %r to %r",
+                key,
+                getattr(config, key, None),
+                value,
+            )
             config.update({key: value})
 
     return config
 
 
+def load_params_config(model, revision) -> PretrainedConfig:
+    # This function loads a params.json config which
+    # should be used when loading models in mistral format
+
+    config_file_name = "params.json"
+
+    config_path = Path(model) / config_file_name
+
+    if not config_path.is_file():
+        config_path = Path(
+            hf_hub_download(model, config_file_name, revision=revision))
+
+    with open(config_path, "r") as file:
+        config_dict = json.load(file)
+
+    config_mapping = {
+        "dim": "hidden_size",
+        "norm_eps": "rms_norm_eps",
+        "n_kv_heads": "num_key_value_heads",
+        "n_layers": "num_hidden_layers",
+        "n_heads": "num_attention_heads",
+        "hidden_dim": "intermediate_size",
+    }
+
+    def recurse_elems(elem: Any):
+        if isinstance(elem, dict):
+            config_dict = {}
+            for key, value in elem.items():
+                key = config_mapping.get(key, key)
+                config_dict[key] = recurse_elems(value)
+            return PretrainedConfig(**config_dict)
+        else:
+            return elem
+
+    config_dict["model_type"] = config_dict.get("model_type", "transformer")
+    config_dict["hidden_act"] = config_dict.get("activation", "silu")
+    config_dict["tie_word_embeddings"] = config_dict.get(
+        "tie_embeddings", False)
+
+    if config_dict["model_type"] == "transformer":
+        if "moe" in config_dict:
+            config_dict["architectures"] = ["MixtralForCausalLM"]
+        else:
+            config_dict["architectures"] = ["MistralForCausalLM"]
+
+    return recurse_elems(config_dict)
+
+
 def get_hf_image_processor_config(
     model: Union[str, Path],
     revision: Optional[str] = None,