[Misc] Fix import error in tensorizer tests and cleanup some code (#10349)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung authored on 2024-11-15 17:34:17 +08:00 · committed by GitHub
parent 3d158cdc8d
commit b311efd0bd
7 changed files with 67 additions and 58 deletions

View File

@@ -8,10 +8,12 @@ from unittest.mock import MagicMock, patch
 import openai
 import pytest
 import torch
+from huggingface_hub import snapshot_download
 from tensorizer import EncryptionParams

 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
+# yapf conflicts with isort for this docstring
 # yapf: disable
 from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                          TensorSerializer,
@@ -20,13 +22,14 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                          open_stream,
                                                          serialize_vllm_model,
                                                          tensorize_vllm_model)
+# yapf: enable
+from vllm.utils import import_from_path

 from ..conftest import VllmRunner
-from ..utils import RemoteOpenAIServer
+from ..utils import VLLM_PATH, RemoteOpenAIServer
 from .conftest import retry_until_skip

-# yapf conflicts with isort for this docstring
+EXAMPLES_PATH = VLLM_PATH / "examples"
+
 prompts = [
     "Hello, my name is",
@@ -94,8 +97,8 @@ def test_can_deserialize_s3(vllm_runner):
                          num_readers=1,
                          s3_endpoint="object.ord1.coreweave.com",
                      )) as loaded_hf_model:
-        deserialized_outputs = loaded_hf_model.generate(prompts,
-                                                        sampling_params)
+        deserialized_outputs = loaded_hf_model.generate(
+            prompts, sampling_params)
         # noqa: E501

         assert deserialized_outputs
@@ -111,23 +114,21 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
         outputs = vllm_model.generate(prompts, sampling_params)

-    config_for_serializing = TensorizerConfig(
-        tensorizer_uri=model_path,
-        encryption_keyfile=key_path
-    )
+    config_for_serializing = TensorizerConfig(tensorizer_uri=model_path,
+                                              encryption_keyfile=key_path)

     serialize_vllm_model(get_torch_model(vllm_model),
                          config_for_serializing)

     config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
                                                 encryption_keyfile=key_path)

-    with vllm_runner(
-        model_ref,
-        load_format="tensorizer",
-        model_loader_extra_config=config_for_deserializing) as loaded_vllm_model:  # noqa: E501
+    with vllm_runner(model_ref,
+                     load_format="tensorizer",
+                     model_loader_extra_config=config_for_deserializing
+                     ) as loaded_vllm_model:  # noqa: E501

-        deserialized_outputs = loaded_vllm_model.generate(prompts,
-                                                          sampling_params)
+        deserialized_outputs = loaded_vllm_model.generate(
+            prompts, sampling_params)
         # noqa: E501

         assert outputs == deserialized_outputs
@@ -156,14 +157,14 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,

 def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
-    from huggingface_hub import snapshot_download
-
-    from examples.multilora_inference import (create_test_prompts,
-                                              process_requests)
+    multilora_inference = import_from_path(
+        "examples.multilora_inference",
+        EXAMPLES_PATH / "multilora_inference.py",
+    )

     model_ref = "meta-llama/Llama-2-7b-hf"
     lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
-    test_prompts = create_test_prompts(lora_path)
+    test_prompts = multilora_inference.create_test_prompts(lora_path)

     # Serialize model before deserializing and binding LoRA adapters
     with vllm_runner(model_ref, ) as vllm_model:
@@ -186,7 +187,8 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
                      max_num_seqs=50,
                      max_model_len=1000,
                      ) as loaded_vllm_model:
-        process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
+        multilora_inference.process_requests(
+            loaded_vllm_model.model.llm_engine, test_prompts)

         assert loaded_vllm_model
@@ -217,8 +219,11 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
     ## Start OpenAI API server
     openai_args = [
-        "--dtype", "float16", "--load-format",
-        "tensorizer", "--model-loader-extra-config",
+        "--dtype",
+        "float16",
+        "--load-format",
+        "tensorizer",
+        "--model-loader-extra-config",
         json.dumps(model_loader_extra_config),
     ]
@@ -251,8 +256,7 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner):
         torch.cuda.empty_cache()


-@pytest.mark.skipif(torch.cuda.device_count() < 2,
-                    reason="Requires 2 GPUs")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
 def test_tensorizer_with_tp_path_without_template(vllm_runner):
     with pytest.raises(ValueError):
         model_ref = "EleutherAI/pythia-1.4b"
@@ -271,10 +275,9 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner):
         )


-@pytest.mark.skipif(torch.cuda.device_count() < 2,
-                    reason="Requires 2 GPUs")
-def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
-                                                                    tmp_path):
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
+def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
+        vllm_runner, tmp_path):
     model_ref = "EleutherAI/pythia-1.4b"
     # record outputs from un-sharded un-tensorized model
     with vllm_runner(
@@ -313,13 +316,12 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
                      disable_custom_all_reduce=True,
                      enforce_eager=True,
                      model_loader_extra_config=tensorizer_config) as loaded_vllm_model:
-        deserialized_outputs = loaded_vllm_model.generate(prompts,
-                                                          sampling_params)
+        deserialized_outputs = loaded_vllm_model.generate(
+            prompts, sampling_params)

         assert outputs == deserialized_outputs


 @retry_until_skip(3)
 def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
     gc.collect()
@@ -337,8 +339,8 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
     with vllm_runner(model_ref,
                      load_format="tensorizer",
                      model_loader_extra_config=config) as loaded_vllm_model:
-        deserialized_outputs = loaded_vllm_model.generate(prompts,
-                                                          sampling_params)
+        deserialized_outputs = loaded_vllm_model.generate(
+            prompts, sampling_params)
         # noqa: E501

         assert outputs == deserialized_outputs
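
Note on the test fix above: the package-style import `from examples.multilora_inference import ...` can fail because the repository's `examples/` directory is not an importable package when the tests run, so the test now loads the script by file path through the new `vllm.utils.import_from_path` helper (with `EXAMPLES_PATH` derived from `VLLM_PATH` in tests/utils.py). A minimal sketch of the pattern; the paths below are placeholders, not values from this diff:

from vllm.utils import import_from_path

# Load the example script directly from its file location; "examples" does not
# need to be installed or on sys.path for this to work.
multilora_inference = import_from_path(
    "examples.multilora_inference",
    "/path/to/vllm/examples/multilora_inference.py",  # placeholder path
)
test_prompts = multilora_inference.create_test_prompts("/path/to/lora_adapter")  # placeholder path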

View File

@@ -2002,9 +2002,6 @@ class LLMEngine:
                     SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE,
                     metrics.model_execute_time)

-    def is_encoder_decoder_model(self):
-        return self.input_preprocessor.is_encoder_decoder_model()
-
     def _validate_model_inputs(self, inputs: ProcessorInputs,
                                lora_request: Optional[LoRARequest]):
         if is_encoder_decoder_inputs(inputs):

View File

@@ -964,6 +964,3 @@ class LLM:
         # This is necessary because some requests may be finished earlier than
         # its previous requests.
         return sorted(outputs, key=lambda x: int(x.request_id))
-
-    def _is_encoder_decoder_model(self):
-        return self.llm_engine.is_encoder_decoder_model()

View File

@@ -1,5 +1,3 @@
-import importlib
-import importlib.util
 import os
 from functools import cached_property
 from typing import Callable, Dict, List, Optional, Sequence, Type, Union
@@ -9,7 +7,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ExtractedToolCallInformation)
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import is_list_of
+from vllm.utils import import_from_path, is_list_of

 logger = init_logger(__name__)
@@ -149,13 +147,14 @@ class ToolParserManager:
     @classmethod
     def import_tool_parser(cls, plugin_path: str) -> None:
         """
-        Import a user defined tool parser by the path of the tool parser define
+        Import a user-defined tool parser by the path of the tool parser define
         file.
         """
         module_name = os.path.splitext(os.path.basename(plugin_path))[0]
-        spec = importlib.util.spec_from_file_location(module_name, plugin_path)
-        if spec is None or spec.loader is None:
-            logger.error("load %s from %s failed.", module_name, plugin_path)
+
+        try:
+            import_from_path(module_name, plugin_path)
+        except Exception:
+            logger.exception("Failed to load module '%s' from %s.",
+                             module_name, plugin_path)
             return
-        module = importlib.util.module_from_spec(spec)
-        spec.loader.exec_module(module)
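
With this change, `import_tool_parser` delegates the file import to `vllm.utils.import_from_path` and logs the full traceback via `logger.exception` when the plugin fails to load, instead of a bare error message. A hedged usage sketch; the plugin path below is illustrative, and it assumes `ToolParserManager` is re-exported from `vllm.entrypoints.openai.tool_parsers`:

from vllm.entrypoints.openai.tool_parsers import ToolParserManager

# Import a user-defined tool parser from a standalone .py file; any exception
# raised while executing the plugin is caught and logged with its traceback.
ToolParserManager.import_tool_parser("/opt/vllm_plugins/my_tool_parser.py")  # placeholder path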

View File

@@ -67,7 +67,7 @@ class InputPreprocessor:
         model config is unavailable.
         '''

-        if not self.is_encoder_decoder_model():
+        if not self.model_config.is_encoder_decoder:
             print_warning_once("Using None for decoder start token id because "
                                "this is not an encoder/decoder model.")
             return None
@@ -632,7 +632,7 @@ class InputPreprocessor:
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> ProcessorInputs:
         """Preprocess the input prompt."""
-        if self.is_encoder_decoder_model():
+        if self.model_config.is_encoder_decoder:
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return self._process_encoder_decoder_prompt(
@@ -660,7 +660,7 @@ class InputPreprocessor:
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> ProcessorInputs:
         """Async version of :meth:`preprocess`."""
-        if self.is_encoder_decoder_model():
+        if self.model_config.is_encoder_decoder:
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return await self._process_encoder_decoder_prompt_async(
@@ -679,6 +679,3 @@ class InputPreprocessor:
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
         )
-
-    def is_encoder_decoder_model(self):
-        return self.model_config.is_encoder_decoder
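
The `is_encoder_decoder_model()` removed here, like the engine- and LLM-level helpers removed above, was a thin passthrough; call sites now read the flag straight from the model config. A minimal sketch of the resulting pattern (the `llm_engine` variable is illustrative, not part of this diff):

# ModelConfig exposes is_encoder_decoder as a property, so no wrapper method
# on the engine or preprocessor is needed.
model_config = llm_engine.get_model_config()
if model_config.is_encoder_decoder:
    # encoder-decoder prompts get the special encoder/decoder mapping
    ...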

View File

@@ -5,6 +5,7 @@ import datetime
 import enum
 import gc
 import getpass
+import importlib.util
 import inspect
 import ipaddress
 import os
@@ -1539,6 +1540,25 @@ def is_in_doc_build() -> bool:
     return False


+def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
+    """
+    Import a Python file according to its file path.
+
+    Based on the official recipe:
+    https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
+    """
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    if spec is None:
+        raise ModuleNotFoundError(f"No module named '{module_name}'")
+
+    assert spec.loader is not None
+
+    module = importlib.util.module_from_spec(spec)
+    sys.modules[module_name] = module
+    spec.loader.exec_module(module)
+    return module
+
+
 # create a library to hold the custom op
 vllm_lib = Library("vllm", "FRAGMENT")  # noqa
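
A short usage sketch for the new helper (module name and file path below are placeholders): `import_from_path` loads a Python source file, registers it in `sys.modules` under the given name, returns the module object, and raises `ModuleNotFoundError` when no import spec can be created for the file.

import sys

from vllm.utils import import_from_path

# Import an arbitrary script by file path; the module is also registered in
# sys.modules under the given name, so later imports by name will find it.
helpers = import_from_path("my_helpers", "/tmp/my_helpers.py")  # placeholder path
assert sys.modules["my_helpers"] is helpers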

View File

@@ -163,9 +163,6 @@ class LLMEngine:
     def get_model_config(self):
         pass

-    def is_encoder_decoder_model(self):
-        pass
-
     def start_profile(self):
         pass