[Bugfix] Fix tensorizer memory profiling bug during testing (#6881)

This commit is contained in:
Sanger Steel 2024-07-30 14:48:50 -04:00 committed by GitHub
parent 5895b24677
commit 052b6f8ca4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 110 additions and 94 deletions

View File

@ -1,6 +1,5 @@
# isort: skip_file
import contextlib import contextlib
import functools
import gc import gc
import pytest import pytest
@ -12,34 +11,38 @@ from vllm.distributed import (destroy_distributed_environment,
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@pytest.fixture(autouse=True)
def cleanup(): def cleanup():
destroy_model_parallel() destroy_model_parallel()
destroy_distributed_environment() destroy_distributed_environment()
with contextlib.suppress(AssertionError): with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group() torch.distributed.destroy_process_group()
ray.shutdown()
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
ray.shutdown()
@pytest.fixture() def retry_until_skip(n):
def should_do_global_cleanup_after_test(request) -> bool:
"""Allow subdirectories to skip global cleanup by overriding this fixture.
This can provide a ~10x speedup for non-GPU unit tests since they don't need
to initialize torch.
"""
return True def decorator_retry(func):
@functools.wraps(func)
def wrapper_retry(*args, **kwargs):
for i in range(n):
try:
return func(*args, **kwargs)
except AssertionError:
gc.collect()
torch.cuda.empty_cache()
if i == n - 1:
pytest.skip("Skipping test after attempts..")
@pytest.fixture(autouse=True) return wrapper_retry
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
yield return decorator_retry
if should_do_global_cleanup_after_test:
cleanup()
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def tensorizer_config(): def tensorizer_config():
config = TensorizerConfig(tensorizer_uri="vllm") config = TensorizerConfig(tensorizer_uri="vllm")
return config return config

View File

@ -1,3 +1,4 @@
import gc
import json import json
import os import os
import pathlib import pathlib
@ -20,13 +21,13 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
serialize_vllm_model, serialize_vllm_model,
tensorize_vllm_model) tensorize_vllm_model)
from ..conftest import VllmRunner, cleanup from ..conftest import VllmRunner
from ..utils import RemoteOpenAIServer from ..utils import RemoteOpenAIServer
from .conftest import retry_until_skip
# yapf conflicts with isort for this docstring # yapf conflicts with isort for this docstring
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",
@ -40,6 +41,7 @@ model_ref = "facebook/opt-125m"
tensorize_model_for_testing_script = os.path.join( tensorize_model_for_testing_script = os.path.join(
os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py") os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")
def is_curl_installed(): def is_curl_installed():
try: try:
subprocess.check_call(['curl', '--version']) subprocess.check_call(['curl', '--version'])
@ -47,14 +49,16 @@ def is_curl_installed():
except (subprocess.CalledProcessError, FileNotFoundError): except (subprocess.CalledProcessError, FileNotFoundError):
return False return False
def get_torch_model(vllm_runner: VllmRunner): def get_torch_model(vllm_runner: VllmRunner):
return vllm_runner \ return vllm_runner \
.model \ .model \
.llm_engine \ .llm_engine \
.model_executor \ .model_executor \
.driver_worker \ .driver_worker \
.model_runner \ .model_runner \
.model .model
def write_keyfile(keyfile_path: str): def write_keyfile(keyfile_path: str):
encryption_params = EncryptionParams.random() encryption_params = EncryptionParams.random()
@ -63,7 +67,6 @@ def write_keyfile(keyfile_path: str):
f.write(encryption_params.key) f.write(encryption_params.key)
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent') @patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config): def test_load_with_tensorizer(mock_agent, tensorizer_config):
mock_linear_method = MagicMock() mock_linear_method = MagicMock()
@ -85,14 +88,15 @@ def test_can_deserialize_s3(vllm_runner):
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
with vllm_runner(model_ref, with vllm_runner(model_ref,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=TensorizerConfig( model_loader_extra_config=TensorizerConfig(
tensorizer_uri=tensorized_path, tensorizer_uri=tensorized_path,
num_readers=1, num_readers=1,
s3_endpoint="object.ord1.coreweave.com", s3_endpoint="object.ord1.coreweave.com",
)) as loaded_hf_model: )) as loaded_hf_model:
deserialized_outputs = loaded_hf_model.generate(prompts,
deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params) # noqa: E501 sampling_params)
# noqa: E501
assert deserialized_outputs assert deserialized_outputs
@ -100,7 +104,6 @@ def test_can_deserialize_s3(vllm_runner):
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs( def test_deserialized_encrypted_vllm_model_has_same_outputs(
vllm_runner, tmp_path): vllm_runner, tmp_path):
cleanup()
with vllm_runner(model_ref) as vllm_model: with vllm_runner(model_ref) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors") model_path = tmp_path / (model_ref + ".tensors")
key_path = tmp_path / (model_ref + ".key") key_path = tmp_path / (model_ref + ".key")
@ -113,18 +116,19 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
encryption_keyfile=key_path encryption_keyfile=key_path
) )
serialize_vllm_model(get_torch_model(vllm_model), serialize_vllm_model(get_torch_model(vllm_model),
config_for_serializing) config_for_serializing)
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path, config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
encryption_keyfile=key_path) encryption_keyfile=key_path)
with vllm_runner( with vllm_runner(
model_ref, model_ref,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=config_for_deserializing) as loaded_vllm_model: # noqa: E501 model_loader_extra_config=config_for_deserializing) as loaded_vllm_model: # noqa: E501
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501 deserialized_outputs = loaded_vllm_model.generate(prompts,
sampling_params)
# noqa: E501
assert outputs == deserialized_outputs assert outputs == deserialized_outputs
@ -140,12 +144,11 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
serializer.write_module(hf_model.model) serializer.write_module(hf_model.model)
with vllm_runner(model_ref, with vllm_runner(model_ref,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=TensorizerConfig( model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path, tensorizer_uri=model_path,
num_readers=1, num_readers=1,
)) as loaded_hf_model: )) as loaded_hf_model:
deserialized_outputs = loaded_hf_model.generate_greedy( deserialized_outputs = loaded_hf_model.generate_greedy(
prompts, max_tokens=max_tokens) prompts, max_tokens=max_tokens)
@ -167,21 +170,21 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
model_path = tmp_path / (model_ref + ".tensors") model_path = tmp_path / (model_ref + ".tensors")
serialize_vllm_model(get_torch_model(vllm_model), serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path)) TensorizerConfig(tensorizer_uri=model_path))
with vllm_runner( with vllm_runner(
model_ref, model_ref,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=TensorizerConfig( model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path, tensorizer_uri=model_path,
num_readers=1, num_readers=1,
), ),
enable_lora=True, enable_lora=True,
max_loras=1, max_loras=1,
max_lora_rank=8, max_lora_rank=8,
max_cpu_loras=2, max_cpu_loras=2,
max_num_seqs=50, max_num_seqs=50,
max_model_len=1000, max_model_len=1000,
) as loaded_vllm_model: ) as loaded_vllm_model:
process_requests(loaded_vllm_model.model.llm_engine, test_prompts) process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
@ -189,10 +192,14 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
def test_load_without_tensorizer_load_format(vllm_runner): def test_load_without_tensorizer_load_format(vllm_runner):
model = None
with pytest.raises(ValueError): with pytest.raises(ValueError):
vllm_runner( model = vllm_runner(
model_ref, model_ref,
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
del model
gc.collect()
torch.cuda.empty_cache()
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
@ -202,7 +209,7 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
model_path = tmp_path / (model_ref + ".tensors") model_path = tmp_path / (model_ref + ".tensors")
serialize_vllm_model(get_torch_model(vllm_model), serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path)) TensorizerConfig(tensorizer_uri=model_path))
model_loader_extra_config = { model_loader_extra_config = {
"tensorizer_uri": str(model_path), "tensorizer_uri": str(model_path),
@ -220,9 +227,9 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
client = server.get_client() client = server.get_client()
completion = client.completions.create(model=model_ref, completion = client.completions.create(model=model_ref,
prompt="Hello, my name is", prompt="Hello, my name is",
max_tokens=5, max_tokens=5,
temperature=0.0) temperature=0.0)
assert completion.id is not None assert completion.id is not None
assert len(completion.choices) == 1 assert len(completion.choices) == 1
@ -233,11 +240,15 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
def test_raise_value_error_on_invalid_load_format(vllm_runner): def test_raise_value_error_on_invalid_load_format(vllm_runner):
model = None
with pytest.raises(ValueError): with pytest.raises(ValueError):
vllm_runner( model = vllm_runner(
model_ref, model_ref,
load_format="safetensors", load_format="safetensors",
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
del model
gc.collect()
torch.cuda.empty_cache()
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(torch.cuda.device_count() < 2,
@ -259,22 +270,20 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner):
disable_custom_all_reduce=True, disable_custom_all_reduce=True,
) )
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Requires 2 GPUs") reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner, def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
tmp_path): tmp_path):
model_ref = "EleutherAI/pythia-1.4b" model_ref = "EleutherAI/pythia-1.4b"
# record outputs from un-sharded un-tensorized model # record outputs from un-sharded un-tensorized model
base_model = vllm_runner( with vllm_runner(
model_ref, model_ref,
disable_custom_all_reduce=True, disable_custom_all_reduce=True,
enforce_eager=True, enforce_eager=True,
) ) as base_model:
outputs = base_model.generate(prompts, sampling_params) outputs = base_model.generate(prompts, sampling_params)
base_model.model.llm_engine.model_executor.shutdown()
base_model.model.llm_engine.model_executor.shutdown()
del base_model
cleanup()
# load model with two shards and serialize with encryption # load model with two shards and serialize with encryption
model_path = str(tmp_path / (model_ref + "-%02d.tensors")) model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
@ -287,32 +296,34 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
tensorize_vllm_model( tensorize_vllm_model(
engine_args=EngineArgs( engine_args=EngineArgs(
model=model_ref, model=model_ref,
tensor_parallel_size=2, tensor_parallel_size=2,
disable_custom_all_reduce=True, disable_custom_all_reduce=True,
enforce_eager=True, enforce_eager=True,
), ),
tensorizer_config=tensorizer_config, tensorizer_config=tensorizer_config,
) )
assert os.path.isfile(model_path % 0), "Serialization subprocess failed" assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
assert os.path.isfile(model_path % 1), "Serialization subprocess failed" assert os.path.isfile(model_path % 1), "Serialization subprocess failed"
cleanup()
loaded_vllm_model = vllm_runner( with vllm_runner(
model_ref, model_ref,
tensor_parallel_size=2, tensor_parallel_size=2,
load_format="tensorizer", load_format="tensorizer",
disable_custom_all_reduce=True, disable_custom_all_reduce=True,
enforce_eager=True, enforce_eager=True,
model_loader_extra_config=tensorizer_config) model_loader_extra_config=tensorizer_config) as loaded_vllm_model:
deserialized_outputs = loaded_vllm_model.generate(prompts,
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) sampling_params)
assert outputs == deserialized_outputs assert outputs == deserialized_outputs
@retry_until_skip(3)
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
cleanup() gc.collect()
torch.cuda.empty_cache()
model_ref = "facebook/opt-125m" model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors") model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path)) config = TensorizerConfig(tensorizer_uri=str(model_path))
@ -324,8 +335,10 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
assert is_vllm_tensorized(config) assert is_vllm_tensorized(config)
with vllm_runner(model_ref, with vllm_runner(model_ref,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=config) as loaded_vllm_model: model_loader_extra_config=config) as loaded_vllm_model:
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501 deserialized_outputs = loaded_vllm_model.generate(prompts,
sampling_params)
# noqa: E501
assert outputs == deserialized_outputs assert outputs == deserialized_outputs