[Bugfix] Fix tensorizer memory profiling bug during testing (#6881)
This commit is contained in:
parent
5895b24677
commit
052b6f8ca4
@ -1,6 +1,5 @@
|
||||
# isort: skip_file
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
@ -12,34 +11,38 @@ from vllm.distributed import (destroy_distributed_environment,
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup():
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
with contextlib.suppress(AssertionError):
|
||||
torch.distributed.destroy_process_group()
|
||||
ray.shutdown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def should_do_global_cleanup_after_test(request) -> bool:
|
||||
"""Allow subdirectories to skip global cleanup by overriding this fixture.
|
||||
This can provide a ~10x speedup for non-GPU unit tests since they don't need
|
||||
to initialize torch.
|
||||
"""
|
||||
def retry_until_skip(n):
|
||||
|
||||
return True
|
||||
def decorator_retry(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper_retry(*args, **kwargs):
|
||||
for i in range(n):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except AssertionError:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
if i == n - 1:
|
||||
pytest.skip("Skipping test after attempts..")
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
|
||||
yield
|
||||
if should_do_global_cleanup_after_test:
|
||||
cleanup()
|
||||
return wrapper_retry
|
||||
|
||||
return decorator_retry
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def tensorizer_config():
|
||||
config = TensorizerConfig(tensorizer_uri="vllm")
|
||||
return config
|
||||
return config
|
||||
|
@ -1,3 +1,4 @@
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
@ -20,13 +21,13 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
|
||||
serialize_vllm_model,
|
||||
tensorize_vllm_model)
|
||||
|
||||
from ..conftest import VllmRunner, cleanup
|
||||
from ..conftest import VllmRunner
|
||||
from ..utils import RemoteOpenAIServer
|
||||
from .conftest import retry_until_skip
|
||||
|
||||
# yapf conflicts with isort for this docstring
|
||||
|
||||
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
@ -40,6 +41,7 @@ model_ref = "facebook/opt-125m"
|
||||
tensorize_model_for_testing_script = os.path.join(
|
||||
os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")
|
||||
|
||||
|
||||
def is_curl_installed():
|
||||
try:
|
||||
subprocess.check_call(['curl', '--version'])
|
||||
@ -47,14 +49,16 @@ def is_curl_installed():
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
return False
|
||||
|
||||
|
||||
def get_torch_model(vllm_runner: VllmRunner):
|
||||
return vllm_runner \
|
||||
.model \
|
||||
.llm_engine \
|
||||
.model_executor \
|
||||
.driver_worker \
|
||||
.model_runner \
|
||||
.model
|
||||
.model \
|
||||
.llm_engine \
|
||||
.model_executor \
|
||||
.driver_worker \
|
||||
.model_runner \
|
||||
.model
|
||||
|
||||
|
||||
def write_keyfile(keyfile_path: str):
|
||||
encryption_params = EncryptionParams.random()
|
||||
@ -63,7 +67,6 @@ def write_keyfile(keyfile_path: str):
|
||||
f.write(encryption_params.key)
|
||||
|
||||
|
||||
|
||||
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
|
||||
def test_load_with_tensorizer(mock_agent, tensorizer_config):
|
||||
mock_linear_method = MagicMock()
|
||||
@ -85,14 +88,15 @@ def test_can_deserialize_s3(vllm_runner):
|
||||
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
|
||||
|
||||
with vllm_runner(model_ref,
|
||||
load_format="tensorizer",
|
||||
model_loader_extra_config=TensorizerConfig(
|
||||
tensorizer_uri=tensorized_path,
|
||||
num_readers=1,
|
||||
s3_endpoint="object.ord1.coreweave.com",
|
||||
)) as loaded_hf_model:
|
||||
|
||||
deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params) # noqa: E501
|
||||
load_format="tensorizer",
|
||||
model_loader_extra_config=TensorizerConfig(
|
||||
tensorizer_uri=tensorized_path,
|
||||
num_readers=1,
|
||||
s3_endpoint="object.ord1.coreweave.com",
|
||||
)) as loaded_hf_model:
|
||||
deserialized_outputs = loaded_hf_model.generate(prompts,
|
||||
sampling_params)
|
||||
# noqa: E501
|
||||
|
||||
assert deserialized_outputs
|
||||
|
||||
@ -100,7 +104,6 @@ def test_can_deserialize_s3(vllm_runner):
|
||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
||||
def test_deserialized_encrypted_vllm_model_has_same_outputs(
|
||||
vllm_runner, tmp_path):
|
||||
cleanup()
|
||||
with vllm_runner(model_ref) as vllm_model:
|
||||
model_path = tmp_path / (model_ref + ".tensors")
|
||||
key_path = tmp_path / (model_ref + ".key")
|
||||
@ -113,18 +116,19 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
|
||||
encryption_keyfile=key_path
|
||||
)
|
||||
serialize_vllm_model(get_torch_model(vllm_model),
|
||||
config_for_serializing)
|
||||
|
||||
config_for_serializing)
|
||||
|
||||
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
|
||||
encryption_keyfile=key_path)
|
||||
|
||||
with vllm_runner(
|
||||
model_ref,
|
||||
load_format="tensorizer",
|
||||
model_loader_extra_config=config_for_deserializing) as loaded_vllm_model: # noqa: E501
|
||||
model_ref,
|
||||
load_format="tensorizer",
|
||||
model_loader_extra_config=config_for_deserializing) as loaded_vllm_model: # noqa: E501
|
||||
|
||||
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501
|
||||
deserialized_outputs = loaded_vllm_model.generate(prompts,
|
||||
sampling_params)
|
||||
# noqa: E501
|
||||
|
||||
assert outputs == deserialized_outputs
|
||||
|
||||
@ -140,12 +144,11 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
|
||||
serializer.write_module(hf_model.model)
|
||||
|
||||
with vllm_runner(model_ref,
|
||||
load_format="tensorizer",
|
||||
model_loader_extra_config=TensorizerConfig(
|
||||
tensorizer_uri=model_path,
|
||||
num_readers=1,
|
||||
)) as loaded_hf_model:
|
||||
|
||||
load_format="tensorizer",
|
||||
model_loader_extra_config=TensorizerConfig(
|
||||
tensorizer_uri=model_path,
|
||||
num_readers=1,
|
||||
)) as loaded_hf_model:
|
||||
deserialized_outputs = loaded_hf_model.generate_greedy(
|
||||
prompts, max_tokens=max_tokens)
|
||||
|
||||
@ -167,21 +170,21 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
|
||||
model_path = tmp_path / (model_ref + ".tensors")
|
||||
|
||||
serialize_vllm_model(get_torch_model(vllm_model),
|
||||
TensorizerConfig(tensorizer_uri=model_path))
|
||||
TensorizerConfig(tensorizer_uri=model_path))
|
||||
|
||||
with vllm_runner(
|
||||
model_ref,
|
||||
load_format="tensorizer",
|
||||
model_loader_extra_config=TensorizerConfig(
|
||||
tensorizer_uri=model_path,
|
||||
num_readers=1,
|
||||
),
|
||||
enable_lora=True,
|
||||
max_loras=1,
|
||||
max_lora_rank=8,
|
||||
max_cpu_loras=2,
|
||||
max_num_seqs=50,
|
||||
max_model_len=1000,
|
||||
model_ref,
|
||||
load_format="tensorizer",
|
||||
model_loader_extra_config=TensorizerConfig(
|
||||
tensorizer_uri=model_path,
|
||||
num_readers=1,
|
||||
),
|
||||
enable_lora=True,
|
||||
max_loras=1,
|
||||
max_lora_rank=8,
|
||||
max_cpu_loras=2,
|
||||
max_num_seqs=50,
|
||||
max_model_len=1000,
|
||||
) as loaded_vllm_model:
|
||||
process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
|
||||
|
||||
@ -189,10 +192,14 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
|
||||
|
||||
|
||||
def test_load_without_tensorizer_load_format(vllm_runner):
|
||||
model = None
|
||||
with pytest.raises(ValueError):
|
||||
vllm_runner(
|
||||
model = vllm_runner(
|
||||
model_ref,
|
||||
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
||||
@ -202,7 +209,7 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
|
||||
model_path = tmp_path / (model_ref + ".tensors")
|
||||
|
||||
serialize_vllm_model(get_torch_model(vllm_model),
|
||||
TensorizerConfig(tensorizer_uri=model_path))
|
||||
TensorizerConfig(tensorizer_uri=model_path))
|
||||
|
||||
model_loader_extra_config = {
|
||||
"tensorizer_uri": str(model_path),
|
||||
@ -220,9 +227,9 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
|
||||
|
||||
client = server.get_client()
|
||||
completion = client.completions.create(model=model_ref,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
|
||||
assert completion.id is not None
|
||||
assert len(completion.choices) == 1
|
||||
@ -233,11 +240,15 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
|
||||
|
||||
|
||||
def test_raise_value_error_on_invalid_load_format(vllm_runner):
|
||||
model = None
|
||||
with pytest.raises(ValueError):
|
||||
vllm_runner(
|
||||
model = vllm_runner(
|
||||
model_ref,
|
||||
load_format="safetensors",
|
||||
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
@ -259,22 +270,20 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner):
|
||||
disable_custom_all_reduce=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Requires 2 GPUs")
|
||||
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
|
||||
tmp_path):
|
||||
model_ref = "EleutherAI/pythia-1.4b"
|
||||
# record outputs from un-sharded un-tensorized model
|
||||
base_model = vllm_runner(
|
||||
model_ref,
|
||||
disable_custom_all_reduce=True,
|
||||
enforce_eager=True,
|
||||
)
|
||||
outputs = base_model.generate(prompts, sampling_params)
|
||||
|
||||
base_model.model.llm_engine.model_executor.shutdown()
|
||||
del base_model
|
||||
cleanup()
|
||||
with vllm_runner(
|
||||
model_ref,
|
||||
disable_custom_all_reduce=True,
|
||||
enforce_eager=True,
|
||||
) as base_model:
|
||||
outputs = base_model.generate(prompts, sampling_params)
|
||||
base_model.model.llm_engine.model_executor.shutdown()
|
||||
|
||||
# load model with two shards and serialize with encryption
|
||||
model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
|
||||
@ -287,32 +296,34 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
|
||||
|
||||
tensorize_vllm_model(
|
||||
engine_args=EngineArgs(
|
||||
model=model_ref,
|
||||
tensor_parallel_size=2,
|
||||
disable_custom_all_reduce=True,
|
||||
enforce_eager=True,
|
||||
),
|
||||
model=model_ref,
|
||||
tensor_parallel_size=2,
|
||||
disable_custom_all_reduce=True,
|
||||
enforce_eager=True,
|
||||
),
|
||||
tensorizer_config=tensorizer_config,
|
||||
)
|
||||
assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
|
||||
assert os.path.isfile(model_path % 1), "Serialization subprocess failed"
|
||||
cleanup()
|
||||
|
||||
loaded_vllm_model = vllm_runner(
|
||||
model_ref,
|
||||
tensor_parallel_size=2,
|
||||
load_format="tensorizer",
|
||||
disable_custom_all_reduce=True,
|
||||
enforce_eager=True,
|
||||
model_loader_extra_config=tensorizer_config)
|
||||
|
||||
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
|
||||
with vllm_runner(
|
||||
model_ref,
|
||||
tensor_parallel_size=2,
|
||||
load_format="tensorizer",
|
||||
disable_custom_all_reduce=True,
|
||||
enforce_eager=True,
|
||||
model_loader_extra_config=tensorizer_config) as loaded_vllm_model:
|
||||
deserialized_outputs = loaded_vllm_model.generate(prompts,
|
||||
sampling_params)
|
||||
|
||||
assert outputs == deserialized_outputs
|
||||
|
||||
|
||||
|
||||
@retry_until_skip(3)
|
||||
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
|
||||
cleanup()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
model_ref = "facebook/opt-125m"
|
||||
model_path = tmp_path / (model_ref + ".tensors")
|
||||
config = TensorizerConfig(tensorizer_uri=str(model_path))
|
||||
@ -324,8 +335,10 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
|
||||
assert is_vllm_tensorized(config)
|
||||
|
||||
with vllm_runner(model_ref,
|
||||
load_format="tensorizer",
|
||||
model_loader_extra_config=config) as loaded_vllm_model:
|
||||
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501
|
||||
load_format="tensorizer",
|
||||
model_loader_extra_config=config) as loaded_vllm_model:
|
||||
deserialized_outputs = loaded_vllm_model.generate(prompts,
|
||||
sampling_params)
|
||||
# noqa: E501
|
||||
|
||||
assert outputs == deserialized_outputs
|
||||
|
Loading…
x
Reference in New Issue
Block a user