[Bugfix]: serialize config by value for --trust-remote-code (#6751)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent 76a5e13270
commit b729901139
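With --trust-remote-code, the HF config object is an instance of a class generated into the local transformers_modules cache, which spawned workers (and other nodes) cannot import, so pickling the config by reference fails. This change registers that module with cloudpickle so the generated classes are serialized by value. The snippet below is a minimal standalone sketch of that mechanism only, not code from this commit; "fake_remote_code" and "RemoteConfig" are hypothetical stand-ins for the generated module and config class.

# Minimal sketch of cloudpickle's pickle-by-value registration (hypothetical names).
import pickle
import sys
import types

import cloudpickle

fake_remote_code = types.ModuleType("fake_remote_code")
exec("class RemoteConfig:\n    num_layers = 4", fake_remote_code.__dict__)
sys.modules["fake_remote_code"] = fake_remote_code  # importable in this process only

# Ask cloudpickle to serialize everything defined in this module by value, so the
# class definition travels inside the pickle instead of being re-imported by name.
cloudpickle.register_pickle_by_value(fake_remote_code)

payload = cloudpickle.dumps(fake_remote_code.RemoteConfig())
del sys.modules["fake_remote_code"]       # simulate a worker that lacks the module
print(pickle.loads(payload).num_layers)   # 4, no import of fake_remote_code needed

The diff below applies the same registration to Ray's vendored cloudpickle and adds a multiprocessing reducer so ModelConfig objects holding such classes survive the spawn path.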
@@ -28,19 +28,25 @@ class ParallelSetup(NamedTuple):
    chunked_prefill: bool


class PPTestOptions(NamedTuple):
    multi_node_only: bool
    trust_remote_code: bool
    tokenizer_mode: Optional[str]


@dataclass
class PPTestSettings:
    parallel_setups: List[ParallelSetup]
    distributed_backends: List[str]
    task: TaskOption
    trust_remote_code: bool
    tokenizer_mode: Optional[str]
    test_options: PPTestOptions

    @staticmethod
    def detailed(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        multi_node_only: bool = False,
        task: TaskOption = "auto",
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
@@ -70,8 +76,9 @@ class PPTestSettings:
            ],
            distributed_backends=["mp", "ray"],
            task=task,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       trust_remote_code=trust_remote_code,
                                       tokenizer_mode=tokenizer_mode,
                                       tokenizer_mode=tokenizer_mode),
        )

    @staticmethod
@@ -80,6 +87,7 @@ class PPTestSettings:
        tp_base: int = 1,
        pp_base: int = 2,
        task: TaskOption = "auto",
        multi_node_only: bool = False,
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
    ):
@@ -92,15 +100,18 @@ class PPTestSettings:
            ],
            distributed_backends=["mp"],
            task=task,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       trust_remote_code=trust_remote_code,
                                       tokenizer_mode=tokenizer_mode,
                                       tokenizer_mode=tokenizer_mode),
        )

    def iter_params(self, model_name: str):
        opts = self.test_options

        for parallel_setup in self.parallel_setups:
            for distributed_backend in self.distributed_backends:
                yield (model_name, parallel_setup, distributed_backend,
                       self.task, self.trust_remote_code, self.tokenizer_mode)
                       self.task, opts)


# NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
@@ -110,6 +121,7 @@ class PPTestSettings:
GENERATION_MODEL_SETTINGS = {
    # [DETAILED TESTS]
    "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
    "microsoft/Phi-3-mini-4k-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True),  # noqa: E501
    # [FAST TESTS]
    # Uses Llama
    # "BAAI/AquilaChat-7B": PPTestSettings.fast(),
@@ -151,10 +163,8 @@ GENERATION_MODEL_SETTINGS = {
    "facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
    "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
    "microsoft/phi-2": PPTestSettings.fast(),
    "microsoft/Phi-3-mini-4k-instruct": PPTestSettings.fast(),
    "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    # FIXME: https://github.com/vllm-project/vllm/issues/8553
    # "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    "adept/persimmon-8b-chat": PPTestSettings.fast(),
    "Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
    "Qwen/Qwen2-beta-7B-Chat": PPTestSettings.fast(),
@@ -205,6 +215,7 @@ TEST_MODELS = [
    # [LANGUAGE GENERATION]
    "meta-llama/Meta-Llama-3-8B",
    "ibm/PowerLM-3b",
    "microsoft/Phi-3-mini-4k-instruct",
    # [LANGUAGE EMBEDDING]
    "intfloat/e5-mistral-7b-instruct",
    "BAAI/bge-multilingual-gemma2",
@@ -220,19 +231,21 @@ def _compare_tp(
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    trust_remote_code: bool,
    tokenizer_mode: Optional[str],
    test_options: PPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate", "encode"] = "encode",
    method: Literal["generate", "encode"],
):
    tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
    multi_node_only, trust_remote_code, tokenizer_mode = test_options

    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip("Skipping multi-node pipeline parallel test for "
                    "multiprocessing distributed backend")
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")

    common_args = [
        # use half precision for speed and memory savings in CI environment
@@ -307,7 +320,7 @@ def _compare_tp(

@pytest.mark.parametrize(
    ("model_name", "parallel_setup", "distributed_backend", "task",
     "trust_remote_code", "tokenizer_mode"),
     "test_options"),
    [
        params for model_name, settings in GENERATION_MODEL_SETTINGS.items()
        for params in settings.iter_params(model_name)
@@ -320,23 +333,21 @@ def test_tp_language_generation(
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    trust_remote_code: bool,
    tokenizer_mode: Optional[str],
    test_options: PPTestOptions,
    num_gpus_available,
):
    _compare_tp(model_name,
                parallel_setup,
                distributed_backend,
                task,
                trust_remote_code,
                tokenizer_mode,
                test_options,
                num_gpus_available,
                method="generate")


@pytest.mark.parametrize(
    ("model_name", "parallel_setup", "distributed_backend", "task",
     "trust_remote_code", "tokenizer_mode"),
     "test_options"),
    [
        params for model_name, settings in EMBEDDING_MODEL_SETTINGS.items()
        for params in settings.iter_params(model_name)
@@ -349,23 +360,21 @@ def test_tp_language_embedding(
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    trust_remote_code: bool,
    tokenizer_mode: Optional[str],
    test_options: PPTestOptions,
    num_gpus_available,
):
    _compare_tp(model_name,
                parallel_setup,
                distributed_backend,
                task,
                trust_remote_code,
                tokenizer_mode,
                test_options,
                num_gpus_available,
                method="encode")


@pytest.mark.parametrize(
    ("model_name", "parallel_setup", "distributed_backend", "task",
     "trust_remote_code", "tokenizer_mode"),
     "test_options"),
    [
        params for model_name, settings in MULTIMODAL_MODEL_SETTINGS.items()
        for params in settings.iter_params(model_name)
@@ -378,15 +387,13 @@ def test_tp_multimodal_generation(
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    trust_remote_code: bool,
    tokenizer_mode: Optional[str],
    test_options: PPTestOptions,
    num_gpus_available,
):
    _compare_tp(model_name,
                parallel_setup,
                distributed_backend,
                task,
                trust_remote_code,
                tokenizer_mode,
                test_options,
                num_gpus_available,
                method="generate")

@@ -16,6 +16,8 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import FlexibleArgumentParser

@@ -924,6 +926,8 @@ class EngineArgs:
                "supported for multimodal models and has been disabled.")
            self.enable_prefix_caching = False

        maybe_register_config_serialize_by_value(self.trust_remote_code)

        cache_config = CacheConfig(
            # neuron needs block_size = max_model_len
            block_size=self.block_size if self.device != "neuron" else
@@ -232,6 +232,68 @@ def get_config(
    return config


def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
    """Try to register HF model configuration class to serialize by value

    With trust_remote_code, the config class is typically an instance of a
    custom class imported from the HF modules cache. The class will not be
    importable in spawned workers by default (and won't exist at all on
    other nodes), which breaks serialization of the config.

    In this function we tell the cloudpickle serialization library to pass
    instances of these generated classes by value instead of by reference,
    i.e. the class definition is serialized along with its data so that the
    class module does not need to be importable on the receiving end. This
    registration only works if the modules cache has already been
    initialized.

    See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs
    """
    if not trust_remote_code:
        return

    try:
        import transformers_modules
    except ImportError:
        logger.debug("Could not import transformers_modules used for remote"
                     " code. If remote code is not needed remove"
                     " `--trust-remote-code`.")
        return

    try:
        import cloudpickle
        cloudpickle.register_pickle_by_value(transformers_modules)

        # ray vendors its own version of cloudpickle
        from vllm.executor.ray_utils import ray
        if ray:
            ray.cloudpickle.register_pickle_by_value(transformers_modules)

        # multiprocessing uses pickle to serialize arguments when using spawn
        # Here we get pickle to use cloudpickle to serialize ModelConfig objects
        # that contain instances of the custom config class to avoid
        # serialization problems if the generated module (and model) has a `.`
        # in its name
        import multiprocessing
        import pickle

        from vllm.config import ModelConfig

        def _reduce_modelconfig(mc: ModelConfig):
            return (pickle.loads, (cloudpickle.dumps(mc), ))

        multiprocessing.reducer.register(ModelConfig, _reduce_modelconfig)

    except Exception as e:
        logger.warning(
            "Unable to register remote classes used by"
            " trust_remote_code with by-value serialization. This may"
            " lead to a later error. If remote code is not needed"
            " remove `--trust-remote-code`",
            exc_info=e)


def load_params_config(model, revision) -> PretrainedConfig:
    # This function loads a params.json config which
    # should be used when loading models in mistral format
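To make the docstring above concrete, here is a small self-contained illustration (not part of this commit) of why by-reference pickling fails for classes that exist only in a per-process dynamic module, and how cloudpickle's by-value serialization recovers. The module and class names are hypothetical stand-ins for a generated remote-code config module and class.

# Standalone illustration only; module/class names are hypothetical stand-ins.
import pickle
import types

import cloudpickle

dynamic_mod = types.ModuleType("transformers_modules_demo.configuration_demo")
exec("class DemoConfig:\n    hidden_size = 128", dynamic_mod.__dict__)
# The module is deliberately not importable here, mimicking a spawned worker or
# remote node whose HF modules cache was never populated.

try:
    pickle.dumps(dynamic_mod.DemoConfig())   # stdlib pickle stores the class by reference
except pickle.PicklingError as err:
    print("by-reference pickling failed:", err)

# cloudpickle falls back to by-value serialization for classes it cannot look up
# by import path, so the definition travels with the data and pickle.loads works
# without the module being present.
restored = pickle.loads(cloudpickle.dumps(dynamic_mod.DemoConfig()))
print(restored.hidden_size)                  # 128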
@@ -968,6 +968,8 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
    return [item for sublist in lists for item in sublist]


# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None:
    """
    Lazy initialization of the Hugging Face modules.