[Misc] Fix import error in tensorizer tests and cleanup some code (#10349)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
3d158cdc8d
commit
b311efd0bd
@ -8,10 +8,12 @@ from unittest.mock import MagicMock, patch
|
|||||||
import openai
|
import openai
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
from tensorizer import EncryptionParams
|
from tensorizer import EncryptionParams
|
||||||
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
|
# yapf conflicts with isort for this docstring
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
|
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
|
||||||
TensorSerializer,
|
TensorSerializer,
|
||||||
@ -20,13 +22,14 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
|
|||||||
open_stream,
|
open_stream,
|
||||||
serialize_vllm_model,
|
serialize_vllm_model,
|
||||||
tensorize_vllm_model)
|
tensorize_vllm_model)
|
||||||
|
# yapf: enable
|
||||||
|
from vllm.utils import import_from_path
|
||||||
|
|
||||||
from ..conftest import VllmRunner
|
from ..conftest import VllmRunner
|
||||||
from ..utils import RemoteOpenAIServer
|
from ..utils import VLLM_PATH, RemoteOpenAIServer
|
||||||
from .conftest import retry_until_skip
|
from .conftest import retry_until_skip
|
||||||
|
|
||||||
# yapf conflicts with isort for this docstring
|
EXAMPLES_PATH = VLLM_PATH / "examples"
|
||||||
|
|
||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
@ -94,8 +97,8 @@ def test_can_deserialize_s3(vllm_runner):
|
|||||||
num_readers=1,
|
num_readers=1,
|
||||||
s3_endpoint="object.ord1.coreweave.com",
|
s3_endpoint="object.ord1.coreweave.com",
|
||||||
)) as loaded_hf_model:
|
)) as loaded_hf_model:
|
||||||
deserialized_outputs = loaded_hf_model.generate(prompts,
|
deserialized_outputs = loaded_hf_model.generate(
|
||||||
sampling_params)
|
prompts, sampling_params)
|
||||||
# noqa: E501
|
# noqa: E501
|
||||||
|
|
||||||
assert deserialized_outputs
|
assert deserialized_outputs
|
||||||
@ -111,23 +114,21 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
|
|||||||
|
|
||||||
outputs = vllm_model.generate(prompts, sampling_params)
|
outputs = vllm_model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
config_for_serializing = TensorizerConfig(
|
config_for_serializing = TensorizerConfig(tensorizer_uri=model_path,
|
||||||
tensorizer_uri=model_path,
|
encryption_keyfile=key_path)
|
||||||
encryption_keyfile=key_path
|
|
||||||
)
|
|
||||||
serialize_vllm_model(get_torch_model(vllm_model),
|
serialize_vllm_model(get_torch_model(vllm_model),
|
||||||
config_for_serializing)
|
config_for_serializing)
|
||||||
|
|
||||||
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
|
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
|
||||||
encryption_keyfile=key_path)
|
encryption_keyfile=key_path)
|
||||||
|
|
||||||
with vllm_runner(
|
with vllm_runner(model_ref,
|
||||||
model_ref,
|
|
||||||
load_format="tensorizer",
|
load_format="tensorizer",
|
||||||
model_loader_extra_config=config_for_deserializing) as loaded_vllm_model: # noqa: E501
|
model_loader_extra_config=config_for_deserializing
|
||||||
|
) as loaded_vllm_model: # noqa: E501
|
||||||
|
|
||||||
deserialized_outputs = loaded_vllm_model.generate(prompts,
|
deserialized_outputs = loaded_vllm_model.generate(
|
||||||
sampling_params)
|
prompts, sampling_params)
|
||||||
# noqa: E501
|
# noqa: E501
|
||||||
|
|
||||||
assert outputs == deserialized_outputs
|
assert outputs == deserialized_outputs
|
||||||
@ -156,14 +157,14 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
|
|||||||
|
|
||||||
|
|
||||||
def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
|
def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
|
||||||
from huggingface_hub import snapshot_download
|
multilora_inference = import_from_path(
|
||||||
|
"examples.multilora_inference",
|
||||||
from examples.multilora_inference import (create_test_prompts,
|
EXAMPLES_PATH / "multilora_inference.py",
|
||||||
process_requests)
|
)
|
||||||
|
|
||||||
model_ref = "meta-llama/Llama-2-7b-hf"
|
model_ref = "meta-llama/Llama-2-7b-hf"
|
||||||
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
|
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
|
||||||
test_prompts = create_test_prompts(lora_path)
|
test_prompts = multilora_inference.create_test_prompts(lora_path)
|
||||||
|
|
||||||
# Serialize model before deserializing and binding LoRA adapters
|
# Serialize model before deserializing and binding LoRA adapters
|
||||||
with vllm_runner(model_ref, ) as vllm_model:
|
with vllm_runner(model_ref, ) as vllm_model:
|
||||||
@ -186,7 +187,8 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
|
|||||||
max_num_seqs=50,
|
max_num_seqs=50,
|
||||||
max_model_len=1000,
|
max_model_len=1000,
|
||||||
) as loaded_vllm_model:
|
) as loaded_vllm_model:
|
||||||
process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
|
multilora_inference.process_requests(
|
||||||
|
loaded_vllm_model.model.llm_engine, test_prompts)
|
||||||
|
|
||||||
assert loaded_vllm_model
|
assert loaded_vllm_model
|
||||||
|
|
||||||
@ -217,8 +219,11 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
|
|||||||
|
|
||||||
## Start OpenAI API server
|
## Start OpenAI API server
|
||||||
openai_args = [
|
openai_args = [
|
||||||
"--dtype", "float16", "--load-format",
|
"--dtype",
|
||||||
"tensorizer", "--model-loader-extra-config",
|
"float16",
|
||||||
|
"--load-format",
|
||||||
|
"tensorizer",
|
||||||
|
"--model-loader-extra-config",
|
||||||
json.dumps(model_loader_extra_config),
|
json.dumps(model_loader_extra_config),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -251,8 +256,7 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
|
||||||
reason="Requires 2 GPUs")
|
|
||||||
def test_tensorizer_with_tp_path_without_template(vllm_runner):
|
def test_tensorizer_with_tp_path_without_template(vllm_runner):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
model_ref = "EleutherAI/pythia-1.4b"
|
model_ref = "EleutherAI/pythia-1.4b"
|
||||||
@ -271,10 +275,9 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs")
|
||||||
reason="Requires 2 GPUs")
|
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
|
||||||
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
|
vllm_runner, tmp_path):
|
||||||
tmp_path):
|
|
||||||
model_ref = "EleutherAI/pythia-1.4b"
|
model_ref = "EleutherAI/pythia-1.4b"
|
||||||
# record outputs from un-sharded un-tensorized model
|
# record outputs from un-sharded un-tensorized model
|
||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
@ -313,13 +316,12 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
|
|||||||
disable_custom_all_reduce=True,
|
disable_custom_all_reduce=True,
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
model_loader_extra_config=tensorizer_config) as loaded_vllm_model:
|
model_loader_extra_config=tensorizer_config) as loaded_vllm_model:
|
||||||
deserialized_outputs = loaded_vllm_model.generate(prompts,
|
deserialized_outputs = loaded_vllm_model.generate(
|
||||||
sampling_params)
|
prompts, sampling_params)
|
||||||
|
|
||||||
assert outputs == deserialized_outputs
|
assert outputs == deserialized_outputs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@retry_until_skip(3)
|
@retry_until_skip(3)
|
||||||
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
|
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
@ -337,8 +339,8 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
|
|||||||
with vllm_runner(model_ref,
|
with vllm_runner(model_ref,
|
||||||
load_format="tensorizer",
|
load_format="tensorizer",
|
||||||
model_loader_extra_config=config) as loaded_vllm_model:
|
model_loader_extra_config=config) as loaded_vllm_model:
|
||||||
deserialized_outputs = loaded_vllm_model.generate(prompts,
|
deserialized_outputs = loaded_vllm_model.generate(
|
||||||
sampling_params)
|
prompts, sampling_params)
|
||||||
# noqa: E501
|
# noqa: E501
|
||||||
|
|
||||||
assert outputs == deserialized_outputs
|
assert outputs == deserialized_outputs
|
||||||
|
@ -2002,9 +2002,6 @@ class LLMEngine:
|
|||||||
SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE,
|
SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE,
|
||||||
metrics.model_execute_time)
|
metrics.model_execute_time)
|
||||||
|
|
||||||
def is_encoder_decoder_model(self):
|
|
||||||
return self.input_preprocessor.is_encoder_decoder_model()
|
|
||||||
|
|
||||||
def _validate_model_inputs(self, inputs: ProcessorInputs,
|
def _validate_model_inputs(self, inputs: ProcessorInputs,
|
||||||
lora_request: Optional[LoRARequest]):
|
lora_request: Optional[LoRARequest]):
|
||||||
if is_encoder_decoder_inputs(inputs):
|
if is_encoder_decoder_inputs(inputs):
|
||||||
|
@ -964,6 +964,3 @@ class LLM:
|
|||||||
# This is necessary because some requests may be finished earlier than
|
# This is necessary because some requests may be finished earlier than
|
||||||
# its previous requests.
|
# its previous requests.
|
||||||
return sorted(outputs, key=lambda x: int(x.request_id))
|
return sorted(outputs, key=lambda x: int(x.request_id))
|
||||||
|
|
||||||
def _is_encoder_decoder_model(self):
|
|
||||||
return self.llm_engine.is_encoder_decoder_model()
|
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
import importlib
|
|
||||||
import importlib.util
|
|
||||||
import os
|
import os
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
|
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
|
||||||
@ -9,7 +7,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
|||||||
ExtractedToolCallInformation)
|
ExtractedToolCallInformation)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils import import_from_path, is_list_of
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -149,13 +147,14 @@ class ToolParserManager:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def import_tool_parser(cls, plugin_path: str) -> None:
|
def import_tool_parser(cls, plugin_path: str) -> None:
|
||||||
"""
|
"""
|
||||||
Import a user defined tool parser by the path of the tool parser define
|
Import a user-defined tool parser by the path of the tool parser define
|
||||||
file.
|
file.
|
||||||
"""
|
"""
|
||||||
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
|
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
|
||||||
spec = importlib.util.spec_from_file_location(module_name, plugin_path)
|
|
||||||
if spec is None or spec.loader is None:
|
try:
|
||||||
logger.error("load %s from %s failed.", module_name, plugin_path)
|
import_from_path(module_name, plugin_path)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to load module '%s' from %s.",
|
||||||
|
module_name, plugin_path)
|
||||||
return
|
return
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
|
@ -67,7 +67,7 @@ class InputPreprocessor:
|
|||||||
model config is unavailable.
|
model config is unavailable.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
if not self.is_encoder_decoder_model():
|
if not self.model_config.is_encoder_decoder:
|
||||||
print_warning_once("Using None for decoder start token id because "
|
print_warning_once("Using None for decoder start token id because "
|
||||||
"this is not an encoder/decoder model.")
|
"this is not an encoder/decoder model.")
|
||||||
return None
|
return None
|
||||||
@ -632,7 +632,7 @@ class InputPreprocessor:
|
|||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
) -> ProcessorInputs:
|
) -> ProcessorInputs:
|
||||||
"""Preprocess the input prompt."""
|
"""Preprocess the input prompt."""
|
||||||
if self.is_encoder_decoder_model():
|
if self.model_config.is_encoder_decoder:
|
||||||
# Encoder-decoder model requires special mapping of
|
# Encoder-decoder model requires special mapping of
|
||||||
# input prompts to encoder & decoder
|
# input prompts to encoder & decoder
|
||||||
return self._process_encoder_decoder_prompt(
|
return self._process_encoder_decoder_prompt(
|
||||||
@ -660,7 +660,7 @@ class InputPreprocessor:
|
|||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
) -> ProcessorInputs:
|
) -> ProcessorInputs:
|
||||||
"""Async version of :meth:`preprocess`."""
|
"""Async version of :meth:`preprocess`."""
|
||||||
if self.is_encoder_decoder_model():
|
if self.model_config.is_encoder_decoder:
|
||||||
# Encoder-decoder model requires special mapping of
|
# Encoder-decoder model requires special mapping of
|
||||||
# input prompts to encoder & decoder
|
# input prompts to encoder & decoder
|
||||||
return await self._process_encoder_decoder_prompt_async(
|
return await self._process_encoder_decoder_prompt_async(
|
||||||
@ -679,6 +679,3 @@ class InputPreprocessor:
|
|||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_encoder_decoder_model(self):
|
|
||||||
return self.model_config.is_encoder_decoder
|
|
||||||
|
@ -5,6 +5,7 @@ import datetime
|
|||||||
import enum
|
import enum
|
||||||
import gc
|
import gc
|
||||||
import getpass
|
import getpass
|
||||||
|
import importlib.util
|
||||||
import inspect
|
import inspect
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import os
|
import os
|
||||||
@ -1539,6 +1540,25 @@ def is_in_doc_build() -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
|
||||||
|
"""
|
||||||
|
Import a Python file according to its file path.
|
||||||
|
|
||||||
|
Based on the official recipe:
|
||||||
|
https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
|
||||||
|
"""
|
||||||
|
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||||
|
if spec is None:
|
||||||
|
raise ModuleNotFoundError(f"No module named '{module_name}'")
|
||||||
|
|
||||||
|
assert spec.loader is not None
|
||||||
|
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
sys.modules[module_name] = module
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
# create a library to hold the custom op
|
# create a library to hold the custom op
|
||||||
vllm_lib = Library("vllm", "FRAGMENT") # noqa
|
vllm_lib = Library("vllm", "FRAGMENT") # noqa
|
||||||
|
|
||||||
|
@ -163,9 +163,6 @@ class LLMEngine:
|
|||||||
def get_model_config(self):
|
def get_model_config(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def is_encoder_decoder_model(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def start_profile(self):
|
def start_profile(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user