[Bugfix]: serialize config by value for --trust-remote-code (#6751)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

parent 76a5e13270
commit b729901139
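In short: with --trust-remote-code, the model's config class is generated into the HF transformers_modules cache at runtime, so spawned or remote workers cannot import it and deserialization of the engine config fails. The fix registers that module for by-value serialization with cloudpickle (and with Ray's vendored cloudpickle), and adds a multiprocessing reducer that routes ModelConfig pickling through cloudpickle. The pipeline-parallel tests are also refactored to bundle their options into a PPTestOptions NamedTuple and gain a multi-node trust-remote-code case.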
tests/distributed/test_pipeline_parallel.py

@@ -28,19 +28,25 @@ class ParallelSetup(NamedTuple):
     chunked_prefill: bool
 
 
+class PPTestOptions(NamedTuple):
+    multi_node_only: bool
+    trust_remote_code: bool
+    tokenizer_mode: Optional[str]
+
+
 @dataclass
 class PPTestSettings:
     parallel_setups: List[ParallelSetup]
     distributed_backends: List[str]
     task: TaskOption
-    trust_remote_code: bool
-    tokenizer_mode: Optional[str]
+    test_options: PPTestOptions
 
     @staticmethod
     def detailed(
         *,
         tp_base: int = 1,
         pp_base: int = 2,
+        multi_node_only: bool = False,
         task: TaskOption = "auto",
         trust_remote_code: bool = False,
         tokenizer_mode: Optional[str] = None,
@@ -70,8 +76,9 @@ class PPTestSettings:
             ],
             distributed_backends=["mp", "ray"],
             task=task,
-            trust_remote_code=trust_remote_code,
-            tokenizer_mode=tokenizer_mode,
+            test_options=PPTestOptions(multi_node_only=multi_node_only,
+                                       trust_remote_code=trust_remote_code,
+                                       tokenizer_mode=tokenizer_mode),
         )
 
     @staticmethod
@@ -80,6 +87,7 @@ class PPTestSettings:
         tp_base: int = 1,
         pp_base: int = 2,
         task: TaskOption = "auto",
+        multi_node_only: bool = False,
         trust_remote_code: bool = False,
         tokenizer_mode: Optional[str] = None,
     ):
@@ -92,15 +100,18 @@ class PPTestSettings:
             ],
             distributed_backends=["mp"],
             task=task,
-            trust_remote_code=trust_remote_code,
-            tokenizer_mode=tokenizer_mode,
+            test_options=PPTestOptions(multi_node_only=multi_node_only,
+                                       trust_remote_code=trust_remote_code,
+                                       tokenizer_mode=tokenizer_mode),
         )
 
     def iter_params(self, model_name: str):
+        opts = self.test_options
+
         for parallel_setup in self.parallel_setups:
             for distributed_backend in self.distributed_backends:
                 yield (model_name, parallel_setup, distributed_backend,
-                       self.task, self.trust_remote_code, self.tokenizer_mode)
+                       self.task, opts)
 
 
 # NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
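The refactor works because a NamedTuple travels through pytest.mark.parametrize as a single value yet still unpacks positionally where it is consumed. A minimal standalone sketch of that pattern (illustrative only, not part of the commit):

from typing import NamedTuple, Optional


class PPTestOptions(NamedTuple):
    multi_node_only: bool
    trust_remote_code: bool
    tokenizer_mode: Optional[str]


opts = PPTestOptions(multi_node_only=True,
                     trust_remote_code=True,
                     tokenizer_mode=None)

# A NamedTuple is still a tuple, so the consumer can unpack it by position,
# exactly as _compare_tp does below.
multi_node_only, trust_remote_code, tokenizer_mode = opts
assert multi_node_only and opts.trust_remote_code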
@@ -110,6 +121,7 @@ class PPTestSettings:
 GENERATION_MODEL_SETTINGS = {
     # [DETAILED TESTS]
     "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
+    "microsoft/Phi-3-mini-4k-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True),  # noqa: E501
     # [FAST TESTS]
     # Uses Llama
     # "BAAI/AquilaChat-7B": PPTestSettings.fast(),
@@ -151,10 +163,8 @@ GENERATION_MODEL_SETTINGS = {
     "facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
     "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
     "microsoft/phi-2": PPTestSettings.fast(),
-    "microsoft/Phi-3-mini-4k-instruct": PPTestSettings.fast(),
     "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
-    # FIXME: https://github.com/vllm-project/vllm/issues/8553
-    # "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
+    "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
     "adept/persimmon-8b-chat": PPTestSettings.fast(),
     "Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
     "Qwen/Qwen2-beta-7B-Chat": PPTestSettings.fast(),
@@ -205,6 +215,7 @@ TEST_MODELS = [
     # [LANGUAGE GENERATION]
     "meta-llama/Meta-Llama-3-8B",
     "ibm/PowerLM-3b",
+    "microsoft/Phi-3-mini-4k-instruct",
     # [LANGUAGE EMBEDDING]
     "intfloat/e5-mistral-7b-instruct",
     "BAAI/bge-multilingual-gemma2",
@@ -220,19 +231,21 @@ def _compare_tp(
     parallel_setup: ParallelSetup,
     distributed_backend: str,
     task: TaskOption,
-    trust_remote_code: bool,
-    tokenizer_mode: Optional[str],
+    test_options: PPTestOptions,
     num_gpus_available: int,
     *,
-    method: Literal["generate", "encode"] = "encode",
+    method: Literal["generate", "encode"],
 ):
     tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
+    multi_node_only, trust_remote_code, tokenizer_mode = test_options
 
     if num_gpus_available < tp_size * pp_size:
         pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
     if VLLM_MULTI_NODE and distributed_backend == "mp":
         pytest.skip("Skipping multi-node pipeline parallel test for "
                     "multiprocessing distributed backend")
+    if multi_node_only and not VLLM_MULTI_NODE:
+        pytest.skip("Not in multi-node setting")
 
     common_args = [
         # use half precision for speed and memory savings in CI environment
@@ -307,7 +320,7 @@ def _compare_tp(
 
 @pytest.mark.parametrize(
     ("model_name", "parallel_setup", "distributed_backend", "task",
-     "trust_remote_code", "tokenizer_mode"),
+     "test_options"),
     [
         params for model_name, settings in GENERATION_MODEL_SETTINGS.items()
         for params in settings.iter_params(model_name)
@@ -320,23 +333,21 @@ def test_tp_language_generation(
     parallel_setup: ParallelSetup,
     distributed_backend: str,
     task: TaskOption,
-    trust_remote_code: bool,
-    tokenizer_mode: Optional[str],
+    test_options: PPTestOptions,
     num_gpus_available,
 ):
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
                 task,
-                trust_remote_code,
-                tokenizer_mode,
+                test_options,
                 num_gpus_available,
                 method="generate")
 
 
 @pytest.mark.parametrize(
     ("model_name", "parallel_setup", "distributed_backend", "task",
-     "trust_remote_code", "tokenizer_mode"),
+     "test_options"),
     [
         params for model_name, settings in EMBEDDING_MODEL_SETTINGS.items()
         for params in settings.iter_params(model_name)
@@ -349,23 +360,21 @@ def test_tp_language_embedding(
     parallel_setup: ParallelSetup,
     distributed_backend: str,
     task: TaskOption,
-    trust_remote_code: bool,
-    tokenizer_mode: Optional[str],
+    test_options: PPTestOptions,
     num_gpus_available,
 ):
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
                 task,
-                trust_remote_code,
-                tokenizer_mode,
+                test_options,
                 num_gpus_available,
                 method="encode")
 
 
 @pytest.mark.parametrize(
     ("model_name", "parallel_setup", "distributed_backend", "task",
-     "trust_remote_code", "tokenizer_mode"),
+     "test_options"),
     [
         params for model_name, settings in MULTIMODAL_MODEL_SETTINGS.items()
        for params in settings.iter_params(model_name)
@@ -378,15 +387,13 @@ def test_tp_multimodal_generation(
     parallel_setup: ParallelSetup,
     distributed_backend: str,
     task: TaskOption,
-    trust_remote_code: bool,
-    tokenizer_mode: Optional[str],
+    test_options: PPTestOptions,
     num_gpus_available,
 ):
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
                 task,
-                trust_remote_code,
-                tokenizer_mode,
+                test_options,
                 num_gpus_available,
                 method="generate")
vllm/engine/arg_utils.py

@@ -16,6 +16,8 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.transformers_utils.config import (
+    maybe_register_config_serialize_by_value)
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import FlexibleArgumentParser
 
@@ -924,6 +926,8 @@ class EngineArgs:
                 "supported for multimodal models and has been disabled.")
             self.enable_prefix_caching = False
 
+        maybe_register_config_serialize_by_value(self.trust_remote_code)
+
         cache_config = CacheConfig(
             # neuron needs block_size = max_model_len
             block_size=self.block_size if self.device != "neuron" else
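For orientation, a usage sketch (not part of the diff; the model choice is illustrative and assumes it actually ships remote code, and the hook appears to sit in the engine-config build path of this vLLM version): any entry point that builds the engine configuration with trust_remote_code=True now triggers the registration in the driver process, before any "mp" or Ray workers are spawned.

from vllm.engine.arg_utils import EngineArgs

# Building the engine config now registers by-value serialization for the
# transformers_modules cache as a side effect, so the resulting ModelConfig
# can be shipped to spawned workers on this node or to remote nodes.
engine_args = EngineArgs(model="microsoft/Phi-3-small-8k-instruct",
                         trust_remote_code=True,
                         pipeline_parallel_size=2)
engine_config = engine_args.create_engine_config()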
vllm/transformers_utils/config.py

@@ -232,6 +232,68 @@ def get_config(
     return config
 
 
+def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
+    """Try to register HF model configuration class to serialize by value
+
+    With trust_remote_code, the config class is typically an instance of a
+    custom class imported from the HF modules cache. The class will not be
+    importable in spawned workers by default (and won't exist at all on
+    other nodes), which breaks serialization of the config.
+
+    In this function we tell the cloudpickle serialization library to pass
+    instances of these generated classes by value instead of by reference,
+    i.e. the class definition is serialized along with its data so that the
+    class module does not need to be importable on the receiving end. This
+    registration only works if the modules cache has already been
+    initialized.
+
+    See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs
+    """
+    if not trust_remote_code:
+        return
+
+    try:
+        import transformers_modules
+    except ImportError:
+        logger.debug("Could not import transformers_modules used for remote"
+                     " code. If remote code is not needed remove"
+                     " `--trust-remote-code`.")
+        return
+
+    try:
+        import cloudpickle
+        cloudpickle.register_pickle_by_value(transformers_modules)
+
+        # ray vendors its own version of cloudpickle
+        from vllm.executor.ray_utils import ray
+        if ray:
+            ray.cloudpickle.register_pickle_by_value(transformers_modules)
+
+        # multiprocessing uses pickle to serialize arguments when using spawn
+        # Here we get pickle to use cloudpickle to serialize ModelConfig objects
+        # that contain instances of the custom config class to avoid
+        # serialization problems if the generated module (and model) has a `.`
+        # in its name
+        import multiprocessing
+        import pickle
+
+        from vllm.config import ModelConfig
+
+        def _reduce_modelconfig(mc: ModelConfig):
+            return (pickle.loads, (cloudpickle.dumps(mc), ))
+
+        multiprocessing.reducer.register(ModelConfig, _reduce_modelconfig)
+
+    except Exception as e:
+        logger.warning(
+            "Unable to register remote classes used by"
+            " trust_remote_code with by-value serialization. This may"
+            " lead to a later error. If remote code is not needed"
+            " remove `--trust-remote-code`",
+            exc_info=e)
+
+
 def load_params_config(model, revision) -> PretrainedConfig:
     # This function loads a params.json config which
     # should be used when loading models in mistral format
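To see what register_pickle_by_value buys here, a self-contained sketch (module and class names are invented for the demo; this is not vLLM code): plain pickle stores only a reference to the defining module, so an instance of a runtime-generated class cannot be round-tripped unless the receiver can import that module, while cloudpickle with by-value registration embeds the class definition in the payload.

import pickle
import sys
import types

import cloudpickle

# Simulate what the HF modules cache does: create a module at runtime.
mod = types.ModuleType("transformers_modules_demo")
exec("class RemoteConfig:\n    n_layers = 2", mod.__dict__)

try:
    # By-reference pickling needs to import the defining module by name;
    # that fails here, just as it would in a spawned worker or on a
    # remote node that lacks the modules cache.
    pickle.dumps(mod.RemoteConfig())
except pickle.PicklingError as exc:
    print(f"plain pickle: {exc}")

# Make the module importable locally (the driver's situation), then tell
# cloudpickle to serialize its contents by value rather than by reference.
sys.modules[mod.__name__] = mod
cloudpickle.register_pickle_by_value(mod)

payload = cloudpickle.dumps(mod.RemoteConfig())
restored = pickle.loads(payload)  # no import of the demo module required
print(type(restored).__name__, restored.n_layers)  # RemoteConfig 2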
vllm/utils.py

@@ -968,6 +968,8 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
     return [item for sublist in lists for item in sublist]
 
 
+# TODO: This function can be removed if transformer_modules classes are
+# serialized by value when communicating between processes
 def init_cached_hf_modules() -> None:
     """
     Lazy initialization of the Hugging Face modules.
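The TODO above points back at the same mechanism: once transformers_modules classes travel by value, eagerly materializing the HF modules cache in every worker becomes unnecessary. The reducer registered in maybe_register_config_serialize_by_value is the piece covering spawn-based multiprocessing; a standalone sketch of that trick (class and field names invented, not vLLM code):

import multiprocessing
import pickle

import cloudpickle


class Config:
    """Stands in for ModelConfig: may hold instances of runtime classes."""

    def __init__(self, hf_config):
        self.hf_config = hf_config


def _reduce_config(cfg: Config):
    # The (callable, args) pair pickle expects from a reducer: the child
    # process rebuilds the object by running pickle.loads on a cloudpickle
    # payload, so the dynamically generated modules need not be importable
    # there.
    return (pickle.loads, (cloudpickle.dumps(cfg), ))


# multiprocessing.reducer exposes the reduction machinery of the default
# context; register() makes every spawn-time hand-off of Config use the
# reducer above (this mirrors the ModelConfig registration in the commit).
multiprocessing.reducer.register(Config, _reduce_config)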