[Bugfix] Fix tensorizer memory profiling bug during testing (#6881)
This commit is contained in:
parent
5895b24677
commit
052b6f8ca4
@ -1,6 +1,5 @@
|
|||||||
# isort: skip_file
|
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import functools
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -12,34 +11,38 @@ from vllm.distributed import (destroy_distributed_environment,
|
|||||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
def cleanup():
|
def cleanup():
|
||||||
destroy_model_parallel()
|
destroy_model_parallel()
|
||||||
destroy_distributed_environment()
|
destroy_distributed_environment()
|
||||||
with contextlib.suppress(AssertionError):
|
with contextlib.suppress(AssertionError):
|
||||||
torch.distributed.destroy_process_group()
|
torch.distributed.destroy_process_group()
|
||||||
|
ray.shutdown()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
ray.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
def retry_until_skip(n):
|
||||||
def should_do_global_cleanup_after_test(request) -> bool:
|
|
||||||
"""Allow subdirectories to skip global cleanup by overriding this fixture.
|
|
||||||
This can provide a ~10x speedup for non-GPU unit tests since they don't need
|
|
||||||
to initialize torch.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return True
|
def decorator_retry(func):
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper_retry(*args, **kwargs):
|
||||||
|
for i in range(n):
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except AssertionError:
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
if i == n - 1:
|
||||||
|
pytest.skip("Skipping test after attempts..")
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
return wrapper_retry
|
||||||
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
|
|
||||||
yield
|
return decorator_retry
|
||||||
if should_do_global_cleanup_after_test:
|
|
||||||
cleanup()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def tensorizer_config():
|
def tensorizer_config():
|
||||||
config = TensorizerConfig(tensorizer_uri="vllm")
|
config = TensorizerConfig(tensorizer_uri="vllm")
|
||||||
return config
|
return config
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import gc
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
@ -20,13 +21,13 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
|
|||||||
serialize_vllm_model,
|
serialize_vllm_model,
|
||||||
tensorize_vllm_model)
|
tensorize_vllm_model)
|
||||||
|
|
||||||
from ..conftest import VllmRunner, cleanup
|
from ..conftest import VllmRunner
|
||||||
from ..utils import RemoteOpenAIServer
|
from ..utils import RemoteOpenAIServer
|
||||||
|
from .conftest import retry_until_skip
|
||||||
|
|
||||||
# yapf conflicts with isort for this docstring
|
# yapf conflicts with isort for this docstring
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
"The president of the United States is",
|
"The president of the United States is",
|
||||||
@ -40,6 +41,7 @@ model_ref = "facebook/opt-125m"
|
|||||||
tensorize_model_for_testing_script = os.path.join(
|
tensorize_model_for_testing_script = os.path.join(
|
||||||
os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")
|
os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")
|
||||||
|
|
||||||
|
|
||||||
def is_curl_installed():
|
def is_curl_installed():
|
||||||
try:
|
try:
|
||||||
subprocess.check_call(['curl', '--version'])
|
subprocess.check_call(['curl', '--version'])
|
||||||
@ -47,14 +49,16 @@ def is_curl_installed():
|
|||||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_torch_model(vllm_runner: VllmRunner):
|
def get_torch_model(vllm_runner: VllmRunner):
|
||||||
return vllm_runner \
|
return vllm_runner \
|
||||||
.model \
|
.model \
|
||||||
.llm_engine \
|
.llm_engine \
|
||||||
.model_executor \
|
.model_executor \
|
||||||
.driver_worker \
|
.driver_worker \
|
||||||
.model_runner \
|
.model_runner \
|
||||||
.model
|
.model
|
||||||
|
|
||||||
|
|
||||||
def write_keyfile(keyfile_path: str):
|
def write_keyfile(keyfile_path: str):
|
||||||
encryption_params = EncryptionParams.random()
|
encryption_params = EncryptionParams.random()
|
||||||
@ -63,7 +67,6 @@ def write_keyfile(keyfile_path: str):
|
|||||||
f.write(encryption_params.key)
|
f.write(encryption_params.key)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
|
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
|
||||||
def test_load_with_tensorizer(mock_agent, tensorizer_config):
|
def test_load_with_tensorizer(mock_agent, tensorizer_config):
|
||||||
mock_linear_method = MagicMock()
|
mock_linear_method = MagicMock()
|
||||||
@ -85,14 +88,15 @@ def test_can_deserialize_s3(vllm_runner):
|
|||||||
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
|
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
|
||||||
|
|
||||||
with vllm_runner(model_ref,
|
with vllm_runner(model_ref,
|
||||||
load_format="tensorizer",
|
load_format="tensorizer",
|
||||||
model_loader_extra_config=TensorizerConfig(
|
model_loader_extra_config=TensorizerConfig(
|
||||||
tensorizer_uri=tensorized_path,
|
tensorizer_uri=tensorized_path,
|
||||||
num_readers=1,
|
num_readers=1,
|
||||||
s3_endpoint="object.ord1.coreweave.com",
|
s3_endpoint="object.ord1.coreweave.com",
|
||||||
)) as loaded_hf_model:
|
)) as loaded_hf_model:
|
||||||
|
deserialized_outputs = loaded_hf_model.generate(prompts,
|
||||||
deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params) # noqa: E501
|
sampling_params)
|
||||||
|
# noqa: E501
|
||||||
|
|
||||||
assert deserialized_outputs
|
assert deserialized_outputs
|
||||||
|
|
||||||
@ -100,7 +104,6 @@ def test_can_deserialize_s3(vllm_runner):
|
|||||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
||||||
def test_deserialized_encrypted_vllm_model_has_same_outputs(
|
def test_deserialized_encrypted_vllm_model_has_same_outputs(
|
||||||
vllm_runner, tmp_path):
|
vllm_runner, tmp_path):
|
||||||
cleanup()
|
|
||||||
with vllm_runner(model_ref) as vllm_model:
|
with vllm_runner(model_ref) as vllm_model:
|
||||||
model_path = tmp_path / (model_ref + ".tensors")
|
model_path = tmp_path / (model_ref + ".tensors")
|
||||||
key_path = tmp_path / (model_ref + ".key")
|
key_path = tmp_path / (model_ref + ".key")
|
||||||
@ -113,18 +116,19 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
|
|||||||
encryption_keyfile=key_path
|
encryption_keyfile=key_path
|
||||||
)
|
)
|
||||||
serialize_vllm_model(get_torch_model(vllm_model),
|
serialize_vllm_model(get_torch_model(vllm_model),
|
||||||
config_for_serializing)
|
config_for_serializing)
|
||||||
|
|
||||||
|
|
||||||
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
|
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
|
||||||
encryption_keyfile=key_path)
|
encryption_keyfile=key_path)
|
||||||
|
|
||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
model_ref,
|
model_ref,
|
||||||
load_format="tensorizer",
|
load_format="tensorizer",
|
||||||
model_loader_extra_config=config_for_deserializing) as loaded_vllm_model: # noqa: E501
|
model_loader_extra_config=config_for_deserializing) as loaded_vllm_model: # noqa: E501
|
||||||
|
|
||||||
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501
|
deserialized_outputs = loaded_vllm_model.generate(prompts,
|
||||||
|
sampling_params)
|
||||||
|
# noqa: E501
|
||||||
|
|
||||||
assert outputs == deserialized_outputs
|
assert outputs == deserialized_outputs
|
||||||
|
|
||||||
@ -140,12 +144,11 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
|
|||||||
serializer.write_module(hf_model.model)
|
serializer.write_module(hf_model.model)
|
||||||
|
|
||||||
with vllm_runner(model_ref,
|
with vllm_runner(model_ref,
|
||||||
load_format="tensorizer",
|
load_format="tensorizer",
|
||||||
model_loader_extra_config=TensorizerConfig(
|
model_loader_extra_config=TensorizerConfig(
|
||||||
tensorizer_uri=model_path,
|
tensorizer_uri=model_path,
|
||||||
num_readers=1,
|
num_readers=1,
|
||||||
)) as loaded_hf_model:
|
)) as loaded_hf_model:
|
||||||
|
|
||||||
deserialized_outputs = loaded_hf_model.generate_greedy(
|
deserialized_outputs = loaded_hf_model.generate_greedy(
|
||||||
prompts, max_tokens=max_tokens)
|
prompts, max_tokens=max_tokens)
|
||||||
|
|
||||||
@ -167,21 +170,21 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
|
|||||||
model_path = tmp_path / (model_ref + ".tensors")
|
model_path = tmp_path / (model_ref + ".tensors")
|
||||||
|
|
||||||
serialize_vllm_model(get_torch_model(vllm_model),
|
serialize_vllm_model(get_torch_model(vllm_model),
|
||||||
TensorizerConfig(tensorizer_uri=model_path))
|
TensorizerConfig(tensorizer_uri=model_path))
|
||||||
|
|
||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
model_ref,
|
model_ref,
|
||||||
load_format="tensorizer",
|
load_format="tensorizer",
|
||||||
model_loader_extra_config=TensorizerConfig(
|
model_loader_extra_config=TensorizerConfig(
|
||||||
tensorizer_uri=model_path,
|
tensorizer_uri=model_path,
|
||||||
num_readers=1,
|
num_readers=1,
|
||||||
),
|
),
|
||||||
enable_lora=True,
|
enable_lora=True,
|
||||||
max_loras=1,
|
max_loras=1,
|
||||||
max_lora_rank=8,
|
max_lora_rank=8,
|
||||||
max_cpu_loras=2,
|
max_cpu_loras=2,
|
||||||
max_num_seqs=50,
|
max_num_seqs=50,
|
||||||
max_model_len=1000,
|
max_model_len=1000,
|
||||||
) as loaded_vllm_model:
|
) as loaded_vllm_model:
|
||||||
process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
|
process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
|
||||||
|
|
||||||
@ -189,10 +192,14 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
|
|||||||
|
|
||||||
|
|
||||||
def test_load_without_tensorizer_load_format(vllm_runner):
|
def test_load_without_tensorizer_load_format(vllm_runner):
|
||||||
|
model = None
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
vllm_runner(
|
model = vllm_runner(
|
||||||
model_ref,
|
model_ref,
|
||||||
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
|
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
|
||||||
|
del model
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
||||||
@ -202,7 +209,7 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
|
|||||||
model_path = tmp_path / (model_ref + ".tensors")
|
model_path = tmp_path / (model_ref + ".tensors")
|
||||||
|
|
||||||
serialize_vllm_model(get_torch_model(vllm_model),
|
serialize_vllm_model(get_torch_model(vllm_model),
|
||||||
TensorizerConfig(tensorizer_uri=model_path))
|
TensorizerConfig(tensorizer_uri=model_path))
|
||||||
|
|
||||||
model_loader_extra_config = {
|
model_loader_extra_config = {
|
||||||
"tensorizer_uri": str(model_path),
|
"tensorizer_uri": str(model_path),
|
||||||
@ -220,9 +227,9 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
|
|||||||
|
|
||||||
client = server.get_client()
|
client = server.get_client()
|
||||||
completion = client.completions.create(model=model_ref,
|
completion = client.completions.create(model=model_ref,
|
||||||
prompt="Hello, my name is",
|
prompt="Hello, my name is",
|
||||||
max_tokens=5,
|
max_tokens=5,
|
||||||
temperature=0.0)
|
temperature=0.0)
|
||||||
|
|
||||||
assert completion.id is not None
|
assert completion.id is not None
|
||||||
assert len(completion.choices) == 1
|
assert len(completion.choices) == 1
|
||||||
@ -233,11 +240,15 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
|
|||||||
|
|
||||||
|
|
||||||
def test_raise_value_error_on_invalid_load_format(vllm_runner):
|
def test_raise_value_error_on_invalid_load_format(vllm_runner):
|
||||||
|
model = None
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
vllm_runner(
|
model = vllm_runner(
|
||||||
model_ref,
|
model_ref,
|
||||||
load_format="safetensors",
|
load_format="safetensors",
|
||||||
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
|
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
|
||||||
|
del model
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||||
@ -259,22 +270,20 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner):
|
|||||||
disable_custom_all_reduce=True,
|
disable_custom_all_reduce=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||||
reason="Requires 2 GPUs")
|
reason="Requires 2 GPUs")
|
||||||
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
|
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
|
||||||
tmp_path):
|
tmp_path):
|
||||||
model_ref = "EleutherAI/pythia-1.4b"
|
model_ref = "EleutherAI/pythia-1.4b"
|
||||||
# record outputs from un-sharded un-tensorized model
|
# record outputs from un-sharded un-tensorized model
|
||||||
base_model = vllm_runner(
|
with vllm_runner(
|
||||||
model_ref,
|
model_ref,
|
||||||
disable_custom_all_reduce=True,
|
disable_custom_all_reduce=True,
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
)
|
) as base_model:
|
||||||
outputs = base_model.generate(prompts, sampling_params)
|
outputs = base_model.generate(prompts, sampling_params)
|
||||||
|
base_model.model.llm_engine.model_executor.shutdown()
|
||||||
base_model.model.llm_engine.model_executor.shutdown()
|
|
||||||
del base_model
|
|
||||||
cleanup()
|
|
||||||
|
|
||||||
# load model with two shards and serialize with encryption
|
# load model with two shards and serialize with encryption
|
||||||
model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
|
model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
|
||||||
@ -287,32 +296,34 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
|
|||||||
|
|
||||||
tensorize_vllm_model(
|
tensorize_vllm_model(
|
||||||
engine_args=EngineArgs(
|
engine_args=EngineArgs(
|
||||||
model=model_ref,
|
model=model_ref,
|
||||||
tensor_parallel_size=2,
|
tensor_parallel_size=2,
|
||||||
disable_custom_all_reduce=True,
|
disable_custom_all_reduce=True,
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
),
|
),
|
||||||
tensorizer_config=tensorizer_config,
|
tensorizer_config=tensorizer_config,
|
||||||
)
|
)
|
||||||
assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
|
assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
|
||||||
assert os.path.isfile(model_path % 1), "Serialization subprocess failed"
|
assert os.path.isfile(model_path % 1), "Serialization subprocess failed"
|
||||||
cleanup()
|
|
||||||
|
|
||||||
loaded_vllm_model = vllm_runner(
|
with vllm_runner(
|
||||||
model_ref,
|
model_ref,
|
||||||
tensor_parallel_size=2,
|
tensor_parallel_size=2,
|
||||||
load_format="tensorizer",
|
load_format="tensorizer",
|
||||||
disable_custom_all_reduce=True,
|
disable_custom_all_reduce=True,
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
model_loader_extra_config=tensorizer_config)
|
model_loader_extra_config=tensorizer_config) as loaded_vllm_model:
|
||||||
|
deserialized_outputs = loaded_vllm_model.generate(prompts,
|
||||||
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
|
sampling_params)
|
||||||
|
|
||||||
assert outputs == deserialized_outputs
|
assert outputs == deserialized_outputs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@retry_until_skip(3)
|
||||||
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
|
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
|
||||||
cleanup()
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
model_ref = "facebook/opt-125m"
|
model_ref = "facebook/opt-125m"
|
||||||
model_path = tmp_path / (model_ref + ".tensors")
|
model_path = tmp_path / (model_ref + ".tensors")
|
||||||
config = TensorizerConfig(tensorizer_uri=str(model_path))
|
config = TensorizerConfig(tensorizer_uri=str(model_path))
|
||||||
@ -324,8 +335,10 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
|
|||||||
assert is_vllm_tensorized(config)
|
assert is_vllm_tensorized(config)
|
||||||
|
|
||||||
with vllm_runner(model_ref,
|
with vllm_runner(model_ref,
|
||||||
load_format="tensorizer",
|
load_format="tensorizer",
|
||||||
model_loader_extra_config=config) as loaded_vllm_model:
|
model_loader_extra_config=config) as loaded_vllm_model:
|
||||||
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501
|
deserialized_outputs = loaded_vllm_model.generate(prompts,
|
||||||
|
sampling_params)
|
||||||
|
# noqa: E501
|
||||||
|
|
||||||
assert outputs == deserialized_outputs
|
assert outputs == deserialized_outputs
|
||||||
|
Loading…
x
Reference in New Issue
Block a user