Adds method to read the pooling types from model's files (#9506)
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>

This commit is contained in:
  parent e036e527a0
  commit aa9078fa03
@@ -230,7 +230,7 @@ def quantize_model(model, quant_cfg, calib_dataloader=None):
 
 def main(args):
     if not torch.cuda.is_available():
-        raise EnvironmentError("GPU is required for inference.")
+        raise OSError("GPU is required for inference.")
 
     random.seed(RAND_SEED)
     np.random.seed(RAND_SEED)
@@ -314,7 +314,7 @@ def main(args):
 
         # Workaround for wo quantization
         if args.qformat in ["int8_wo", "int4_wo", "full_prec"]:
-            with open(f"{export_path}/config.json", 'r') as f:
+            with open(f"{export_path}/config.json") as f:
                 tensorrt_llm_config = json.load(f)
             if args.qformat == "int8_wo":
                 tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16'
@@ -30,6 +30,13 @@ def test_limit_mm_per_prompt_parser(arg, expected):
     assert args.limit_mm_per_prompt == expected
 
 
+def test_valid_pooling_config():
+    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
+    args = parser.parse_args(["--pooling-type=MEAN"])
+    engine_args = EngineArgs.from_cli_args(args=args)
+    assert engine_args.pooling_type == 'MEAN'
+
+
 @pytest.mark.parametrize(
     ("arg"),
     [
tests/model_executor/test_model_load_with_params.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+import os
+
+import pytest
+
+from vllm.model_executor.layers.pooler import PoolingType
+from vllm.model_executor.models.bert import BertEmbeddingModel
+from vllm.platforms import current_platform
+
+MAX_MODEL_LEN = 128
+MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5")
+REVISION = os.environ.get("REVISION", "main")
+
+
+@pytest.mark.skipif(current_platform.is_rocm(),
+                    reason="Xformers backend is not supported on ROCm.")
+def test_model_loading_with_params(vllm_runner):
+    """
+    Test parameter weight loading with tp>1.
+    """
+    with vllm_runner(model_name=MODEL_NAME,
+                     revision=REVISION,
+                     dtype="float16",
+                     max_model_len=MAX_MODEL_LEN) as model:
+        output = model.encode("Write a short story about a robot that"
+                              " dreams for the first time.\n")
+
+        model_config = model.model.llm_engine.model_config
+
+        model_tokenizer = model.model.llm_engine.tokenizer
+
+        # asserts on the bert model config file
+        assert model_config.encoder_config["max_seq_length"] == 512
+        assert model_config.encoder_config["do_lower_case"]
+
+        # asserts on the pooling config files
+        assert model_config.pooler_config.pooling_type == PoolingType.CLS.name
+        assert model_config.pooler_config.pooling_norm
+
+        # asserts on the tokenizer loaded
+        assert model_tokenizer.tokenizer_id == "BAAI/bge-base-en-v1.5"
+        assert model_tokenizer.tokenizer_config["do_lower_case"]
+        assert model_tokenizer.tokenizer.model_max_length == 512
+
+        model = model.model.llm_engine.model_executor \
+            .driver_worker.model_runner.model
+        assert isinstance(model, BertEmbeddingModel)
+        assert model._pooler.pooling_type == PoolingType.CLS
+        assert model._pooler.normalize
+        # assert output
+        assert output
@@ -1,6 +1,8 @@
 import pytest
 
 from vllm.config import ModelConfig
+from vllm.model_executor.layers.pooler import PoolingType
+from vllm.platforms import current_platform
 
 
 @pytest.mark.parametrize(("model_id", "expected_task"), [
@@ -102,6 +104,76 @@ def test_get_sliding_window():
     assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
 
 
+@pytest.mark.skipif(current_platform.is_rocm(),
+                    reason="Xformers backend is not supported on ROCm.")
+def test_get_pooling_config():
+    model_id = "sentence-transformers/all-MiniLM-L12-v2"
+    minilm_model_config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+    )
+
+    minilm_pooling_config = minilm_model_config._init_pooler_config(
+        pooling_type=None,
+        pooling_norm=None,
+        pooling_returned_token_ids=None,
+        pooling_softmax=None,
+        pooling_step_tag_id=None)
+
+    assert minilm_pooling_config.pooling_norm
+    assert minilm_pooling_config.pooling_type == PoolingType.MEAN.name
+
+
+@pytest.mark.skipif(current_platform.is_rocm(),
+                    reason="Xformers backend is not supported on ROCm.")
+def test_get_pooling_config_from_args():
+    model_id = "sentence-transformers/all-MiniLM-L12-v2"
+    minilm_model_config = ModelConfig(model_id,
+                                      task="auto",
+                                      tokenizer=model_id,
+                                      tokenizer_mode="auto",
+                                      trust_remote_code=False,
+                                      seed=0,
+                                      dtype="float16",
+                                      revision=None)
+
+    minilm_pooling_config = minilm_model_config._init_pooler_config(
+        pooling_type='CLS',
+        pooling_norm=True,
+        pooling_returned_token_ids=None,
+        pooling_softmax=None,
+        pooling_step_tag_id=None)
+
+    assert minilm_pooling_config.pooling_norm
+    assert minilm_pooling_config.pooling_type == PoolingType.CLS.name
+
+
+@pytest.mark.skipif(current_platform.is_rocm(),
+                    reason="Xformers backend is not supported on ROCm.")
+def test_get_bert_tokenization_sentence_transformer_config():
+    bge_model_config = ModelConfig(
+        model="BAAI/bge-base-en-v1.5",
+        task="auto",
+        tokenizer="BAAI/bge-base-en-v1.5",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+    )
+
+    bert_bge_model_config = bge_model_config._get_encoder_config()
+
+    assert bert_bge_model_config["max_seq_length"] == 512
+    assert bert_bge_model_config["do_lower_case"]
+
+
 def test_rope_customization():
     TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
     TEST_ROPE_THETA = 16_000_000.0
@@ -15,6 +15,7 @@ import openai
 import pytest
 import requests
 import torch
+import torch.nn.functional as F
 from openai.types.completion import Completion
 from typing_extensions import ParamSpec
 
@@ -515,13 +516,14 @@ def compare_all_settings(model: str,
         ref_result = copy.deepcopy(ref_result)
         compare_result = copy.deepcopy(compare_result)
         if "embedding" in ref_result and method == "encode":
-            ref_embedding = torch.tensor(ref_result["embedding"])
-            compare_embedding = torch.tensor(
-                compare_result["embedding"])
-            mse = ((ref_embedding - compare_embedding)**2).mean()
-            assert mse < 1e-6, (
+            sim = F.cosine_similarity(
+                torch.tensor(ref_result["embedding"]),
+                torch.tensor(compare_result["embedding"]),
+                dim=0,
+            )
+            assert sim >= 0.999, (
                 f"Embedding for {model=} are not the same.\n"
-                f"mse={mse}\n")
+                f"cosine_similarity={sim}\n")
             del ref_result["embedding"]
             del compare_result["embedding"]
             assert ref_result == compare_result, (
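The new check compares the direction of the two embeddings instead of their absolute error. A minimal standalone sketch of what the assertion measures (the embedding values here are made up for illustration):

# Illustrative only: how cosine similarity behaves versus MSE on two
# nearly identical embedding vectors.
import torch
import torch.nn.functional as F

ref = torch.tensor([0.12, -0.48, 0.88])        # hypothetical reference embedding
cmp = torch.tensor([0.1201, -0.4799, 0.8801])  # almost identical embedding

sim = F.cosine_similarity(ref, cmp, dim=0)     # scalar in [-1, 1]
mse = ((ref - cmp) ** 2).mean()

# Cosine similarity is insensitive to the embedding's overall magnitude,
# so small numerical drift still passes; MSE scales with the magnitude.
assert sim >= 0.999
print(f"cosine_similarity={sim.item():.6f}, mse={mse.item():.2e}")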
@@ -13,10 +13,10 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
-from vllm.transformers_utils.config import (ConfigFormat, get_config,
-                                            get_hf_image_processor_config,
-                                            get_hf_text_config,
-                                            is_encoder_decoder, uses_mrope)
+from vllm.transformers_utils.config import (
+    ConfigFormat, get_config, get_hf_image_processor_config,
+    get_hf_text_config, get_pooling_config,
+    get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
                         print_warning_once)
 
@@ -197,6 +197,7 @@ class ModelConfig:
                                     code_revision, rope_scaling, rope_theta,
                                     config_format)
         self.hf_text_config = get_hf_text_config(self.hf_config)
+        self.encoder_config = self._get_encoder_config()
         self.hf_image_processor_config = get_hf_image_processor_config(
             self.model, revision)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
@@ -229,7 +230,8 @@ class ModelConfig:
             max_model_len=max_model_len,
             disable_sliding_window=self.disable_sliding_window,
             sliding_window_len=self.get_hf_config_sliding_window(),
-            spec_target_max_model_len=spec_target_max_model_len)
+            spec_target_max_model_len=spec_target_max_model_len,
+            encoder_config=self.encoder_config)
         self.served_model_name = get_served_model_name(model,
                                                        served_model_name)
         self.multimodal_config = self._init_multimodal_config(
@@ -273,6 +275,10 @@ class ModelConfig:
 
         return None
 
+    def _get_encoder_config(self):
+        return get_sentence_transformer_tokenizer_config(
+            self.model, self.revision)
+
     def _init_pooler_config(
         self,
         pooling_type: Optional[str] = None,
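The encoder config read here is the sentence-transformers tokenization file handled by `get_sentence_transformer_tokenizer_config` further down. As a rough illustration (the keys are the ones asserted in the new tests; actual file contents vary per model), the returned dict looks like:

# Hypothetical return value of ModelConfig._get_encoder_config() for a
# BERT-style sentence-transformers model such as BAAI/bge-base-en-v1.5,
# based on the keys asserted by the new tests.
encoder_config = {
    "max_seq_length": 512,   # later caps derived_max_model_len
    "do_lower_case": True,   # later forwarded to the tokenizer init kwargs
}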
@@ -282,6 +288,14 @@ class ModelConfig:
         pooling_returned_token_ids: Optional[List[int]] = None
     ) -> Optional["PoolerConfig"]:
         if self.task == "embedding":
+            pooling_config = get_pooling_config(self.model, self.revision)
+            if pooling_config is not None:
+                # override only if the user does not
+                # specify pooling_type and/or pooling_norm
+                if pooling_type is None:
+                    pooling_type = pooling_config["pooling_type"]
+                if pooling_norm is None:
+                    pooling_norm = pooling_config["normalize"]
             return PoolerConfig(
                 pooling_type=pooling_type,
                 pooling_norm=pooling_norm,
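Put differently, file-derived values only fill in what the user left unset; explicit engine arguments always win. A simplified standalone sketch of that precedence (outside the ModelConfig class, names chosen for illustration):

# Simplified sketch of the pooling-config precedence introduced above.
from typing import Optional


def resolve_pooling(pooling_type: Optional[str],
                    pooling_norm: Optional[bool],
                    file_config: Optional[dict]) -> tuple:
    # Values from the model's pooling files only apply where the user
    # passed nothing; anything given explicitly is kept as-is.
    if file_config is not None:
        if pooling_type is None:
            pooling_type = file_config["pooling_type"]
        if pooling_norm is None:
            pooling_norm = file_config["normalize"]
    return pooling_type, pooling_norm


file_cfg = {"pooling_type": "MEAN", "normalize": True}
# Model ships MEAN pooling and the user passed nothing:
assert resolve_pooling(None, None, file_cfg) == ("MEAN", True)
# Explicit --pooling-type=CLS still takes priority:
assert resolve_pooling("CLS", None, file_cfg) == ("CLS", True)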
@@ -1795,6 +1809,7 @@ def _get_and_verify_max_len(
     disable_sliding_window: bool,
     sliding_window_len: Optional[Union[int, List[Optional[int]]]],
     spec_target_max_model_len: Optional[int] = None,
+    encoder_config: Optional[Any] = None,
 ) -> int:
     """Get and verify the model's maximum length."""
     derived_max_model_len = float("inf")
@@ -1877,6 +1892,9 @@ def _get_and_verify_max_len(
                 "original_max_position_embeddings"]
             derived_max_model_len *= scaling_factor
 
+    if encoder_config and "max_seq_length" in encoder_config:
+        derived_max_model_len = encoder_config["max_seq_length"]
+
     # If the user specified a max length, make sure it is smaller than the
     # derived length from the HF model config.
     if max_model_len is None:
@@ -16,6 +16,7 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
                          VllmConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
+from vllm.model_executor.layers.pooler import PoolingType
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
@@ -863,7 +864,7 @@ class EngineArgs:
 
         parser.add_argument(
             '--pooling-type',
-            choices=['LAST', 'ALL', 'CLS', 'STEP'],
+            choices=[pt.name for pt in PoolingType],
             default=None,
             help='Used to configure the pooling method in the embedding model.'
         )
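Because the choices now come straight from the enum, newly added members such as MEAN are accepted without touching the argparse definition. A small sketch mirroring the new test in this commit (import paths assume this commit's tree):

# Sketch: the accepted --pooling-type values now track PoolingType directly.
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.pooler import PoolingType
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(["--pooling-type=MEAN"])
engine_args = EngineArgs.from_cli_args(args)

assert engine_args.pooling_type == PoolingType.MEAN.name
# Expected to list every enum member, e.g. ['LAST', 'ALL', 'CLS', 'STEP', 'MEAN']
print([pt.name for pt in PoolingType])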
@@ -16,6 +16,7 @@ class PoolingType(IntEnum):
     ALL = 1
     CLS = 2
     STEP = 3
+    MEAN = 4
 
 
 class Pooler(nn.Module):
@@ -27,7 +28,7 @@ class Pooler(nn.Module):
     3. Returns structured results as `PoolerOutput`.
 
     Attributes:
-        pooling_type: The type of pooling to use (LAST, ALL, CLS).
+        pooling_type: The type of pooling to use.
         normalize: Whether to normalize the pooled data.
     """
 
@@ -97,6 +98,17 @@ class Pooler(nn.Module):
             for prompt_len in prompt_lens:
                 pooled_data.append(hidden_states[offset:offset + prompt_len])
                 offset += prompt_len
+        elif self.pooling_type == PoolingType.MEAN:
+            # Calculate mean pooling
+            cumsum = torch.cumsum(hidden_states, dim=0)
+            start_indices = torch.cat([
+                torch.tensor([0], device=hidden_states.device),
+                torch.cumsum(prompt_lens[:-1], dim=0)
+            ])
+            end_indices = torch.cumsum(prompt_lens, dim=0)
+            pooled_data = (
+                cumsum[end_indices - 1] - cumsum[start_indices] +
+                hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
        elif self.pooling_type == PoolingType.STEP:
            if self.returned_token_ids is not None and len(
                    self.returned_token_ids) > 0:
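The MEAN branch computes per-prompt means without a Python loop: with a cumulative sum over the flattened hidden states, each prompt's token sum is cumsum[end - 1] - cumsum[start] + hidden_states[start]. A self-contained check of that identity (shapes and lengths chosen arbitrarily):

# Standalone check that the cumsum trick above equals a per-prompt mean.
import torch

torch.manual_seed(0)
prompt_lens = torch.tensor([3, 5, 2])                    # three prompts of arbitrary length
hidden_states = torch.randn(int(prompt_lens.sum()), 8)   # flattened [num_tokens, hidden_size]

cumsum = torch.cumsum(hidden_states, dim=0)
start_indices = torch.cat([
    torch.tensor([0]),
    torch.cumsum(prompt_lens[:-1], dim=0)
])
end_indices = torch.cumsum(prompt_lens, dim=0)
pooled = (cumsum[end_indices - 1] - cumsum[start_indices] +
          hidden_states[start_indices]) / prompt_lens.unsqueeze(1)

# Reference: loop over prompts and average their token vectors directly.
expected = torch.stack([
    hidden_states[s:e].mean(dim=0)
    for s, e in zip(start_indices.tolist(), end_indices.tolist())
])
assert torch.allclose(pooled, expected, atol=1e-5)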
@@ -6,6 +6,9 @@ from typing import Any, Dict, Optional, Type, Union
 import huggingface_hub
 from huggingface_hub import (file_exists, hf_hub_download,
                              try_to_load_from_cache)
+from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError,
+                                   RepositoryNotFoundError,
+                                   RevisionNotFoundError)
 from transformers import GenerationConfig, PretrainedConfig
 from transformers.models.auto.image_processing_auto import (
     get_image_processor_config)
@@ -213,7 +216,7 @@ def get_config(
             raise e
 
     elif config_format == ConfigFormat.MISTRAL:
-        config = load_params_config(model, revision)
+        config = load_params_config(model, revision, token=kwargs.get("token"))
     else:
         raise ValueError(f"Unsupported config format: {config_format}")
 
@@ -243,6 +246,158 @@ def get_config(
     return config
 
 
+def get_hf_file_to_dict(file_name: str,
+                        model: Union[str, Path],
+                        revision: Optional[str] = 'main',
+                        token: Optional[str] = None):
+    """
+    Downloads a file from the Hugging Face Hub and returns
+    its contents as a dictionary.
+
+    Parameters:
+    - file_name (str): The name of the file to download.
+    - model (str): The name of the model on the Hugging Face Hub.
+    - revision (str): The specific version of the model.
+    - token (str): The Hugging Face authentication token.
+
+    Returns:
+    - config_dict (dict): A dictionary containing
+      the contents of the downloaded file.
+    """
+    file_path = Path(model) / file_name
+
+    if file_or_path_exists(model=model,
+                           config_name=file_name,
+                           revision=revision,
+                           token=token):
+
+        if not file_path.is_file():
+            try:
+                hf_hub_file = hf_hub_download(model,
+                                              file_name,
+                                              revision=revision)
+            except (RepositoryNotFoundError, RevisionNotFoundError,
+                    EntryNotFoundError, LocalEntryNotFoundError) as e:
+                logger.debug("File or repository not found in hf_hub_download",
+                             e)
+                return None
+            file_path = Path(hf_hub_file)
+
+        with open(file_path) as file:
+            return json.load(file)
+    return None
+
+
+def get_pooling_config(model: str,
+                       revision: Optional[str] = 'main',
+                       token: Optional[str] = None):
+    """
+    This function gets the pooling and normalize
+    config from the model - only applies to
+    sentence-transformers models.
+
+    Args:
+        model (str): The name of the Hugging Face model.
+        revision (str, optional): The specific version
+            of the model to use. Defaults to 'main'.
+
+    Returns:
+        dict: A dictionary containing the pooling
+        type and whether normalization is used.
+    """
+    modules_file_name = "modules.json"
+    modules_dict = get_hf_file_to_dict(modules_file_name, model, revision,
+                                       token)
+
+    if modules_dict is None:
+        return None
+
+    pooling = next((item for item in modules_dict
+                    if item["type"] == "sentence_transformers.models.Pooling"),
+                   None)
+    normalize = bool(
+        next((item for item in modules_dict
+              if item["type"] == "sentence_transformers.models.Normalize"),
+             False))
+
+    if pooling:
+
+        pooling_file_name = "{}/config.json".format(pooling["path"])
+        pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision,
+                                           token)
+        pooling_type_name = next(
+            (item for item, val in pooling_dict.items() if val is True), None)
+
+        if pooling_type_name is not None:
+            pooling_type_name = get_pooling_config_name(pooling_type_name)
+
+        return {"pooling_type": pooling_type_name, "normalize": normalize}
+
+    return None
+
+
+def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
+    if "pooling_mode_" in pooling_name:
+        pooling_name = pooling_name.replace("pooling_mode_", "")
+
+    if "_" in pooling_name:
+        pooling_name = pooling_name.split("_")[0]
+
+    if "lasttoken" in pooling_name:
+        pooling_name = "last"
+
+    supported_pooling_types = ['LAST', 'ALL', 'CLS', 'STEP', 'MEAN']
+    pooling_type_name = pooling_name.upper()
+
+    try:
+        if pooling_type_name in supported_pooling_types:
+            return pooling_type_name
+    except NotImplementedError as e:
+        logger.debug("Pooling type not supported", e)
+        return None
+    return None
+
+
+def get_sentence_transformer_tokenizer_config(model: str,
+                                              revision: Optional[str] = 'main',
+                                              token: Optional[str] = None):
+    """
+    Returns the tokenization configuration dictionary for a
+    given Sentence Transformer BERT model.
+
+    Parameters:
+    - model (str): The name of the Sentence Transformer
+      BERT model.
+    - revision (str, optional): The revision of the model
+      to use. Defaults to 'main'.
+    - token (str): A Hugging Face access token.
+
+    Returns:
+    - dict: A dictionary containing the configuration parameters
+      for the Sentence Transformer BERT model.
+    """
+    for config_name in [
+            "sentence_bert_config.json",
+            "sentence_roberta_config.json",
+            "sentence_distilbert_config.json",
+            "sentence_camembert_config.json",
+            "sentence_albert_config.json",
+            "sentence_xlm-roberta_config.json",
+            "sentence_xlnet_config.json",
+    ]:
+        encoder_dict = get_hf_file_to_dict(config_name, model, revision, token)
+        if encoder_dict:
+            break
+
+    if not encoder_dict:
+        return None
+
+    if all(k in encoder_dict for k in ("max_seq_length", "do_lower_case")):
+        return encoder_dict
+    return None
+
+
 def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
     """Try to register HF model configuration class to serialize by value
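For reference, these helpers read the standard sentence-transformers file layout. A hedged illustration of how modules.json and the pooling module's config.json map onto the returned dict (the file contents below are typical examples, not copied from any particular checkpoint):

# Illustrative sentence-transformers metadata and how get_pooling_config()
# above would interpret it.
modules_json = [
    {"idx": 0, "name": "0", "path": "",
     "type": "sentence_transformers.models.Transformer"},
    {"idx": 1, "name": "1", "path": "1_Pooling",
     "type": "sentence_transformers.models.Pooling"},
    {"idx": 2, "name": "2", "path": "2_Normalize",
     "type": "sentence_transformers.models.Normalize"},
]

# The Pooling entry points at "<path>/config.json", whose boolean flags
# name the active pooling mode:
pooling_config_json = {
    "word_embedding_dimension": 384,
    "pooling_mode_cls_token": False,
    "pooling_mode_mean_tokens": True,
    "pooling_mode_max_tokens": False,
}

# Same mapping as get_pooling_config_name(): strip the prefix, keep the
# first word, and upper-case it to match the PoolingType member names.
active_flag = next(k for k, v in pooling_config_json.items() if v is True)
pooling_type = active_flag.replace("pooling_mode_", "").split("_")[0].upper()
normalize = any(m["type"] == "sentence_transformers.models.Normalize"
                for m in modules_json)

assert pooling_type == "MEAN"
assert normalize
# get_pooling_config() would return {"pooling_type": "MEAN", "normalize": True}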
@@ -305,20 +460,15 @@ def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
             exc_info=e)
 
 
-def load_params_config(model, revision) -> PretrainedConfig:
+def load_params_config(model: Union[str, Path],
+                       revision: Optional[str],
+                       token: Optional[str] = None) -> PretrainedConfig:
     # This function loads a params.json config which
     # should be used when loading models in mistral format
 
     config_file_name = "params.json"
 
-    config_path = Path(model) / config_file_name
-
-    if not config_path.is_file():
-        config_path = Path(
-            hf_hub_download(model, config_file_name, revision=revision))
-
-    with open(config_path) as file:
-        config_dict = json.load(file)
+    config_dict = get_hf_file_to_dict(config_file_name, model, revision, token)
 
     config_mapping = {
         "dim": "hidden_size",
@@ -25,6 +25,11 @@ def init_tokenizer_from_configs(model_config: ModelConfig,
         trust_remote_code=model_config.trust_remote_code,
         revision=model_config.tokenizer_revision)
 
+    if (model_config.encoder_config is not None
+            and "do_lower_case" in model_config.encoder_config):
+        init_kwargs["do_lower_case"] = model_config.encoder_config[
+            "do_lower_case"]
+
     return get_tokenizer_group(parallel_config.tokenizer_pool_config,
                                **init_kwargs)
 