From aa9078fa035abfac54179cbdca8b741e49c8cd0b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fl=C3=A1via=20B=C3=A9o?= <119421251+flaviabeo@users.noreply.github.com>
Date: Thu, 7 Nov 2024 05:42:40 -0300
Subject: [PATCH] Adds method to read the pooling types from model's files
 (#9506)

Signed-off-by: Flavia Beo
Signed-off-by: Max de Bayser
Co-authored-by: Max de Bayser
---
 examples/fp8/quantizer/quantize.py            |   4 +-
 tests/engine/test_arg_utils.py                |   7 +
 .../test_model_load_with_params.py            |  50 ++++++
 tests/test_config.py                          |  72 ++++++++
 tests/utils.py                                |  14 +-
 vllm/config.py                                |  28 ++-
 vllm/engine/arg_utils.py                      |   3 +-
 vllm/model_executor/layers/pooler.py          |  14 +-
 vllm/transformers_utils/config.py             | 170 ++++++++++++++--
 .../tokenizer_group/__init__.py               |   5 +
 10 files changed, 342 insertions(+), 25 deletions(-)
 create mode 100644 tests/model_executor/test_model_load_with_params.py

diff --git a/examples/fp8/quantizer/quantize.py b/examples/fp8/quantizer/quantize.py
index 15f1a06b..d75cc8b3 100644
--- a/examples/fp8/quantizer/quantize.py
+++ b/examples/fp8/quantizer/quantize.py
@@ -230,7 +230,7 @@ def quantize_model(model, quant_cfg, calib_dataloader=None):
 
 def main(args):
     if not torch.cuda.is_available():
-        raise EnvironmentError("GPU is required for inference.")
+        raise OSError("GPU is required for inference.")
 
     random.seed(RAND_SEED)
     np.random.seed(RAND_SEED)
@@ -314,7 +314,7 @@ def main(args):
 
     # Workaround for wo quantization
     if args.qformat in ["int8_wo", "int4_wo", "full_prec"]:
-        with open(f"{export_path}/config.json", 'r') as f:
+        with open(f"{export_path}/config.json") as f:
             tensorrt_llm_config = json.load(f)
         if args.qformat == "int8_wo":
             tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16'
diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py
index f7dc167f..e92e2588 100644
--- a/tests/engine/test_arg_utils.py
+++ b/tests/engine/test_arg_utils.py
@@ -30,6 +30,13 @@ def test_limit_mm_per_prompt_parser(arg, expected):
     assert args.limit_mm_per_prompt == expected
 
 
+def test_valid_pooling_config():
+    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
+    args = parser.parse_args(["--pooling-type=MEAN"])
+    engine_args = EngineArgs.from_cli_args(args=args)
+    assert engine_args.pooling_type == 'MEAN'
+
+
 @pytest.mark.parametrize(
     ("arg"),
     [
diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py
new file mode 100644
index 00000000..7e5e2780
--- /dev/null
+++ b/tests/model_executor/test_model_load_with_params.py
@@ -0,0 +1,50 @@
+import os
+
+import pytest
+
+from vllm.model_executor.layers.pooler import PoolingType
+from vllm.model_executor.models.bert import BertEmbeddingModel
+from vllm.platforms import current_platform
+
+MAX_MODEL_LEN = 128
+MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5")
+REVISION = os.environ.get("REVISION", "main")
+
+
+@pytest.mark.skipif(current_platform.is_rocm(),
+                    reason="Xformers backend is not supported on ROCm.")
+def test_model_loading_with_params(vllm_runner):
+    """
+    Test parameter weight loading with tp>1.
+    """
+    with vllm_runner(model_name=MODEL_NAME,
+                     revision=REVISION,
+                     dtype="float16",
+                     max_model_len=MAX_MODEL_LEN) as model:
+        output = model.encode("Write a short story about a robot that"
+                              " dreams for the first time.\n")
+
+        model_config = model.model.llm_engine.model_config
+
+        model_tokenizer = model.model.llm_engine.tokenizer
+
+        # asserts on the bert model config file
+        assert model_config.encoder_config["max_seq_length"] == 512
+        assert model_config.encoder_config["do_lower_case"]
+
+        # asserts on the pooling config files
+        assert model_config.pooler_config.pooling_type == PoolingType.CLS.name
+        assert model_config.pooler_config.pooling_norm
+
+        # asserts on the tokenizer loaded
+        assert model_tokenizer.tokenizer_id == "BAAI/bge-base-en-v1.5"
+        assert model_tokenizer.tokenizer_config["do_lower_case"]
+        assert model_tokenizer.tokenizer.model_max_length == 512
+
+        model = model.model.llm_engine.model_executor\
+            .driver_worker.model_runner.model
+        assert isinstance(model, BertEmbeddingModel)
+        assert model._pooler.pooling_type == PoolingType.CLS
+        assert model._pooler.normalize
+        # assert output
+        assert output
diff --git a/tests/test_config.py b/tests/test_config.py
index 5211049b..66bdb883 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,6 +1,8 @@
 import pytest
 
 from vllm.config import ModelConfig
+from vllm.model_executor.layers.pooler import PoolingType
+from vllm.platforms import current_platform
 
 
 @pytest.mark.parametrize(("model_id", "expected_task"), [
@@ -102,6 +104,76 @@ def test_get_sliding_window():
     assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
 
 
+@pytest.mark.skipif(current_platform.is_rocm(),
+                    reason="Xformers backend is not supported on ROCm.")
+def test_get_pooling_config():
+    model_id = "sentence-transformers/all-MiniLM-L12-v2"
+    minilm_model_config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+    )
+
+    minilm_pooling_config = minilm_model_config._init_pooler_config(
+        pooling_type=None,
+        pooling_norm=None,
+        pooling_returned_token_ids=None,
+        pooling_softmax=None,
+        pooling_step_tag_id=None)
+
+    assert minilm_pooling_config.pooling_norm
+    assert minilm_pooling_config.pooling_type == PoolingType.MEAN.name
+
+
+@pytest.mark.skipif(current_platform.is_rocm(),
+                    reason="Xformers backend is not supported on ROCm.")
+def test_get_pooling_config_from_args():
+    model_id = "sentence-transformers/all-MiniLM-L12-v2"
+    minilm_model_config = ModelConfig(model_id,
+                                      task="auto",
+                                      tokenizer=model_id,
+                                      tokenizer_mode="auto",
+                                      trust_remote_code=False,
+                                      seed=0,
+                                      dtype="float16",
+                                      revision=None)
+
+    minilm_pooling_config = minilm_model_config._init_pooler_config(
+        pooling_type='CLS',
+        pooling_norm=True,
+        pooling_returned_token_ids=None,
+        pooling_softmax=None,
+        pooling_step_tag_id=None)
+
+    assert minilm_pooling_config.pooling_norm
+    assert minilm_pooling_config.pooling_type == PoolingType.CLS.name
+
+
+@pytest.mark.skipif(current_platform.is_rocm(),
+                    reason="Xformers backend is not supported on ROCm.")
+def test_get_bert_tokenization_sentence_transformer_config():
+    bge_model_config = ModelConfig(
+        model="BAAI/bge-base-en-v1.5",
+        task="auto",
+        tokenizer="BAAI/bge-base-en-v1.5",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+    )
+
+    bert_bge_model_config = bge_model_config._get_encoder_config()
+
+    assert bert_bge_model_config["max_seq_length"] == 512
+    assert bert_bge_model_config["do_lower_case"]
+
+
 def test_rope_customization():
     TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
     TEST_ROPE_THETA = 16_000_000.0
diff --git a/tests/utils.py b/tests/utils.py
index 00c7dabe..a893667e 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,6 +15,7 @@ import openai
 import pytest
 import requests
 import torch
+import torch.nn.functional as F
 from openai.types.completion import Completion
 from typing_extensions import ParamSpec
 
@@ -515,13 +516,14 @@ def compare_all_settings(model: str,
                 ref_result = copy.deepcopy(ref_result)
                 compare_result = copy.deepcopy(compare_result)
                 if "embedding" in ref_result and method == "encode":
-                    ref_embedding = torch.tensor(ref_result["embedding"])
-                    compare_embedding = torch.tensor(
-                        compare_result["embedding"])
-                    mse = ((ref_embedding - compare_embedding)**2).mean()
-                    assert mse < 1e-6, (
+                    sim = F.cosine_similarity(
+                        torch.tensor(ref_result["embedding"]),
+                        torch.tensor(compare_result["embedding"]),
+                        dim=0,
+                    )
+                    assert sim >= 0.999, (
                         f"Embedding for {model=} are not the same.\n"
-                        f"mse={mse}\n")
+                        f"cosine_similarity={sim}\n")
                     del ref_result["embedding"]
                     del compare_result["embedding"]
                 assert ref_result == compare_result, (
diff --git a/vllm/config.py b/vllm/config.py
index c7fad3a2..e844a46b 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -13,10 +13,10 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
-from vllm.transformers_utils.config import (ConfigFormat, get_config,
-                                            get_hf_image_processor_config,
-                                            get_hf_text_config,
-                                            is_encoder_decoder, uses_mrope)
+from vllm.transformers_utils.config import (
+    ConfigFormat, get_config, get_hf_image_processor_config,
+    get_hf_text_config, get_pooling_config,
+    get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless,
                         get_cpu_memory, print_warning_once)
 
@@ -197,6 +197,7 @@ class ModelConfig:
             code_revision, rope_scaling, rope_theta, config_format)
         self.hf_text_config = get_hf_text_config(self.hf_config)
+        self.encoder_config = self._get_encoder_config()
         self.hf_image_processor_config = get_hf_image_processor_config(
             self.model, revision)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
@@ -229,7 +230,8 @@ class ModelConfig:
             max_model_len=max_model_len,
             disable_sliding_window=self.disable_sliding_window,
             sliding_window_len=self.get_hf_config_sliding_window(),
-            spec_target_max_model_len=spec_target_max_model_len)
+            spec_target_max_model_len=spec_target_max_model_len,
+            encoder_config=self.encoder_config)
         self.served_model_name = get_served_model_name(model,
                                                        served_model_name)
         self.multimodal_config = self._init_multimodal_config(
@@ -273,6 +275,10 @@ class ModelConfig:
 
         return None
 
+    def _get_encoder_config(self):
+        return get_sentence_transformer_tokenizer_config(
+            self.model, self.revision)
+
     def _init_pooler_config(
         self,
         pooling_type: Optional[str] = None,
@@ -282,6 +288,14 @@ class ModelConfig:
         pooling_returned_token_ids: Optional[List[int]] = None
     ) -> Optional["PoolerConfig"]:
         if self.task == "embedding":
+            pooling_config = get_pooling_config(self.model, self.revision)
+            if pooling_config is not None:
+                # override only if the user does not
+                # specify pooling_type and/or pooling_norm
+                if pooling_type is None:
+                    pooling_type = pooling_config["pooling_type"]
+                if pooling_norm is None:
+                    pooling_norm = pooling_config["normalize"]
             return PoolerConfig(
                 pooling_type=pooling_type,
                 pooling_norm=pooling_norm,
@@ -1795,6 +1809,7 @@ def _get_and_verify_max_len(
     disable_sliding_window: bool,
     sliding_window_len: Optional[Union[int, List[Optional[int]]]],
     spec_target_max_model_len: Optional[int] = None,
+    encoder_config: Optional[Any] = None,
 ) -> int:
     """Get and verify the model's maximum length."""
     derived_max_model_len = float("inf")
@@ -1877,6 +1892,9 @@ def _get_and_verify_max_len(
                 "original_max_position_embeddings"]
             derived_max_model_len *= scaling_factor
 
+    if encoder_config and "max_seq_length" in encoder_config:
+        derived_max_model_len = encoder_config["max_seq_length"]
+
     # If the user specified a max length, make sure it is smaller than the
     # derived length from the HF model config.
     if max_model_len is None:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index b556c0ee..8c5b442e 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -16,6 +16,7 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
                          VllmConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
+from vllm.model_executor.layers.pooler import PoolingType
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
@@ -863,7 +864,7 @@ class EngineArgs:
 
         parser.add_argument(
             '--pooling-type',
-            choices=['LAST', 'ALL', 'CLS', 'STEP'],
+            choices=[pt.name for pt in PoolingType],
             default=None,
             help='Used to configure the pooling method in the embedding model.'
         )
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index 1c9772b4..024badbc 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -16,6 +16,7 @@ class PoolingType(IntEnum):
     ALL = 1
     CLS = 2
     STEP = 3
+    MEAN = 4
 
 
 class Pooler(nn.Module):
@@ -27,7 +28,7 @@ class Pooler(nn.Module):
     3. Returns structured results as `PoolerOutput`.
 
     Attributes:
-        pooling_type: The type of pooling to use (LAST, ALL, CLS).
+        pooling_type: The type of pooling to use.
         normalize: Whether to normalize the pooled data.
     """
 
@@ -97,6 +98,17 @@ class Pooler(nn.Module):
             for prompt_len in prompt_lens:
                 pooled_data.append(hidden_states[offset:offset + prompt_len])
                 offset += prompt_len
+        elif self.pooling_type == PoolingType.MEAN:
+            # Calculate mean pooling
+            cumsum = torch.cumsum(hidden_states, dim=0)
+            start_indices = torch.cat([
+                torch.tensor([0], device=hidden_states.device),
+                torch.cumsum(prompt_lens[:-1], dim=0)
+            ])
+            end_indices = torch.cumsum(prompt_lens, dim=0)
+            pooled_data = (
+                cumsum[end_indices - 1] - cumsum[start_indices] +
+                hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
         elif self.pooling_type == PoolingType.STEP:
             if self.returned_token_ids is not None and len(
                     self.returned_token_ids) > 0:
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 415d8bf7..6b38ee31 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -6,6 +6,9 @@ from typing import Any, Dict, Optional, Type, Union
 import huggingface_hub
 from huggingface_hub import (file_exists, hf_hub_download,
                              try_to_load_from_cache)
+from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError,
+                                   RepositoryNotFoundError,
+                                   RevisionNotFoundError)
 from transformers import GenerationConfig, PretrainedConfig
 from transformers.models.auto.image_processing_auto import (
     get_image_processor_config)
@@ -213,7 +216,7 @@ def get_config(
             raise e
 
     elif config_format == ConfigFormat.MISTRAL:
-        config = load_params_config(model, revision)
+        config = load_params_config(model, revision, token=kwargs.get("token"))
     else:
         raise ValueError(f"Unsupported config format: {config_format}")
 
@@ -243,6 +246,158 @@ def get_config(
     return config
 
 
+def get_hf_file_to_dict(file_name: str,
+                        model: Union[str, Path],
+                        revision: Optional[str] = 'main',
+                        token: Optional[str] = None):
+    """
+    Downloads a file from the Hugging Face Hub and returns
+    its contents as a dictionary.
+
+    Parameters:
+    - file_name (str): The name of the file to download.
+    - model (str): The name of the model on the Hugging Face Hub.
+    - revision (str): The specific version of the model.
+    - token (str): The Hugging Face authentication token.
+
+    Returns:
+    - config_dict (dict): A dictionary containing
+    the contents of the downloaded file.
+    """
+    file_path = Path(model) / file_name
+
+    if file_or_path_exists(model=model,
+                           config_name=file_name,
+                           revision=revision,
+                           token=token):
+
+        if not file_path.is_file():
+            try:
+                hf_hub_file = hf_hub_download(model,
+                                              file_name,
+                                              revision=revision)
+            except (RepositoryNotFoundError, RevisionNotFoundError,
+                    EntryNotFoundError, LocalEntryNotFoundError) as e:
+                logger.debug(
+                    "File or repository not found in hf_hub_download: %s", e)
+                return None
+            file_path = Path(hf_hub_file)
+
+        with open(file_path) as file:
+            return json.load(file)
+    return None
+
+
+def get_pooling_config(model: str,
+                       revision: Optional[str] = 'main',
+                       token: Optional[str] = None):
+    """
+    This function gets the pooling and normalize
+    config from the model - only applies to
+    sentence-transformers models.
+
+    Args:
+        model (str): The name of the Hugging Face model.
+        revision (str, optional): The specific version
+        of the model to use. Defaults to 'main'.
+
+    Returns:
+        dict: A dictionary containing the pooling
+        type and whether normalization is used.
+    """
+
+    modules_file_name = "modules.json"
+    modules_dict = get_hf_file_to_dict(modules_file_name, model, revision,
+                                       token)
+
+    if modules_dict is None:
+        return None
+
+    pooling = next((item for item in modules_dict
+                    if item["type"] == "sentence_transformers.models.Pooling"),
+                   None)
+    normalize = bool(
+        next((item for item in modules_dict
+              if item["type"] == "sentence_transformers.models.Normalize"),
+             False))
+
+    if pooling:
+
+        pooling_file_name = "{}/config.json".format(pooling["path"])
+        pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision,
+                                           token)
+        pooling_type_name = next(
+            (item for item, val in pooling_dict.items() if val is True), None)
+
+        if pooling_type_name is not None:
+            pooling_type_name = get_pooling_config_name(pooling_type_name)
+
+        return {"pooling_type": pooling_type_name, "normalize": normalize}
+
+    return None
+
+
+def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
+    if "pooling_mode_" in pooling_name:
+        pooling_name = pooling_name.replace("pooling_mode_", "")
+
+    if "_" in pooling_name:
+        pooling_name = pooling_name.split("_")[0]
+
+    if "lasttoken" in pooling_name:
+        pooling_name = "last"
+
+    supported_pooling_types = ['LAST', 'ALL', 'CLS', 'STEP', 'MEAN']
+    pooling_type_name = pooling_name.upper()
+
+    if pooling_type_name in supported_pooling_types:
+        return pooling_type_name
+
+    logger.debug("Pooling type %s is not supported; ignoring it",
+                 pooling_type_name)
+
+    return None
+
+
+def get_sentence_transformer_tokenizer_config(model: str,
+                                              revision: Optional[str] = 'main',
+                                              token: Optional[str] = None):
+    """
+    Returns the tokenization configuration dictionary for a
+    given Sentence Transformer BERT model.
+
+    Parameters:
+    - model (str): The name of the Sentence Transformer
+    BERT model.
+    - revision (str, optional): The revision of the
+    model to use. Defaults to 'main'.
+    - token (str): A Hugging Face access token.
+
+    Returns:
+    - dict: A dictionary containing the configuration parameters
+    for the Sentence Transformer BERT model.
+    """
+    for config_name in [
+            "sentence_bert_config.json",
+            "sentence_roberta_config.json",
+            "sentence_distilbert_config.json",
+            "sentence_camembert_config.json",
+            "sentence_albert_config.json",
+            "sentence_xlm-roberta_config.json",
+            "sentence_xlnet_config.json",
+    ]:
+        encoder_dict = get_hf_file_to_dict(config_name, model, revision,
+                                           token)
+        if encoder_dict:
+            break
+
+    if not encoder_dict:
+        return None
+
+    if all(k in encoder_dict for k in ("max_seq_length", "do_lower_case")):
+        return encoder_dict
+    return None
+
+
 def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
     """Try to register HF model configuration class to serialize by value
 
@@ -305,20 +460,15 @@ def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
                     exc_info=e)
 
 
-def load_params_config(model, revision) -> PretrainedConfig:
+def load_params_config(model: Union[str, Path],
+                       revision: Optional[str],
+                       token: Optional[str] = None) -> PretrainedConfig:
     # This function loads a params.json config which
     # should be used when loading models in mistral format
 
     config_file_name = "params.json"
 
-    config_path = Path(model) / config_file_name
-
-    if not config_path.is_file():
-        config_path = Path(
-            hf_hub_download(model, config_file_name, revision=revision))
-
-    with open(config_path) as file:
-        config_dict = json.load(file)
+    config_dict = get_hf_file_to_dict(config_file_name, model, revision,
+                                      token)
 
     config_mapping = {
         "dim": "hidden_size",
diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py
index 9a414925..6a114b51 100644
--- a/vllm/transformers_utils/tokenizer_group/__init__.py
+++ b/vllm/transformers_utils/tokenizer_group/__init__.py
@@ -25,6 +25,11 @@ def init_tokenizer_from_configs(model_config: ModelConfig,
         trust_remote_code=model_config.trust_remote_code,
         revision=model_config.tokenizer_revision)
 
+    if (model_config.encoder_config is not None
+            and "do_lower_case" in model_config.encoder_config):
+        init_kwargs["do_lower_case"] = model_config.encoder_config[
+            "do_lower_case"]
+
    return get_tokenizer_group(parallel_config.tokenizer_pool_config,
                               **init_kwargs)
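
Note (not part of the patch): the PoolingType.MEAN branch added to pooler.py
computes every prompt's mean in one vectorized pass over the packed
hidden_states instead of slicing per prompt. For a prompt occupying rows
[s, e), the row sum equals cumsum[e - 1] - cumsum[s] + hidden_states[s].
The standalone sketch below checks that arithmetic against a naive
per-prompt mean; it assumes prompt_lens is a 1-D integer tensor, matching
how the pooler receives prompt lengths.

import torch


def mean_pool_cumsum(hidden_states: torch.Tensor,
                     prompt_lens: torch.Tensor) -> torch.Tensor:
    # Mirrors the PoolingType.MEAN branch in pooler.py.
    cumsum = torch.cumsum(hidden_states, dim=0)
    start_indices = torch.cat([
        torch.tensor([0], device=hidden_states.device),
        torch.cumsum(prompt_lens[:-1], dim=0)
    ])
    end_indices = torch.cumsum(prompt_lens, dim=0)
    # sum over rows [start, end) == cumsum[end - 1] - cumsum[start] + row[start]
    return (cumsum[end_indices - 1] - cumsum[start_indices] +
            hidden_states[start_indices]) / prompt_lens.unsqueeze(1)


def mean_pool_naive(hidden_states: torch.Tensor,
                    prompt_lens: torch.Tensor) -> torch.Tensor:
    # Reference implementation: slice out each prompt's rows and average them.
    out, offset = [], 0
    for n in prompt_lens.tolist():
        out.append(hidden_states[offset:offset + n].mean(dim=0))
        offset += n
    return torch.stack(out)


hidden = torch.randn(10, 8)     # 10 packed tokens, hidden size 8
lens = torch.tensor([3, 2, 5])  # three prompts laid out back to back
assert torch.allclose(mean_pool_cumsum(hidden, lens),
                      mean_pool_naive(hidden, lens), atol=1e-6)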