[CI/Build] Move test_utils.py to tests/utils.py (#4425)

Since #4335 was merged, I've noticed that the ServerRunner class defined in its tests is the same as the one in the OpenAI API server test. I have moved the class into the shared test utilities to avoid code duplication. (Although it has only been duplicated twice so far, I will add another similar test suite in #4200, which would repeat the code a third time.)
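To illustrate the intended usage (a minimal sketch only: the fixture arguments, model flags, and assertion below are placeholders of my own, not code from this PR), a test module in a subpackage such as tests/entrypoints/ can now pull in the shared actor with a relative import instead of redefining it:

```python
# Hypothetical test module, e.g. tests/entrypoints/test_example_server.py.
# ServerRunner comes from the shared tests/utils.py introduced in this PR;
# everything else (CLI args, model, assertion) is illustrative.
import openai
import pytest
import ray

from ..utils import ServerRunner

MODEL_NAME = "facebook/opt-125m"  # any model with a chat template should work


@pytest.fixture(scope="module")
def server():
    ray.init()
    server_runner = ServerRunner.remote([
        "--model",
        MODEL_NAME,
        "--dtype",
        "float16",
    ])
    ray.get(server_runner.ready.remote())  # blocks until the health check passes
    yield server_runner
    ray.shutdown()


def test_single_completion(server):
    client = openai.OpenAI(base_url="http://localhost:8000/v1",
                           api_key="EMPTY")
    completion = client.completions.create(model=MODEL_NAME,
                                           prompt="Hello, my name is",
                                           max_tokens=5)
    assert len(completion.choices) == 1
```

Because the actor starts the API server in its constructor and ready() only returns once the health check passes, the module-scoped fixture keeps a single server alive for all tests in the file.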

Also, I have moved the test utilities file (test_utils.py) into the tests directory (tests/utils.py), since none of its code is actually used by the main package. Note that I have added __init__.py to each test subpackage and updated the ray.init() call in the test utilities file so that tests/utils.py can be imported via relative imports.
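For the distributed helpers, the relevant detail is that ray.init() in tests/utils.py now passes runtime_env={"working_dir": VLLM_PATH}, so each Ray worker receives a copy of the repository and can resolve the tests package (and its relative imports) on the remote side. Below is a minimal sketch of how a test under tests/distributed/ would use these helpers, loosely modeled on the existing comm-ops tests; the worker body, tensor size, and parametrization are illustrative, not taken from this PR:

```python
# Hypothetical tests/distributed/test_example.py; only the ..utils helpers
# are from this PR, the rest is an illustrative all-reduce check.
import os

import pytest
import ray
import torch

from vllm.distributed import tensor_model_parallel_all_reduce

from ..utils import (init_test_distributed_environment,
                     multi_process_tensor_parallel)


@ray.remote(num_gpus=1, max_calls=1)
def all_reduce_worker(tp_size: int, pp_size: int, rank: int,
                      distributed_init_port: str):
    # Let the worker see all GPUs and pick its device by rank, mirroring the
    # pattern used by the existing comm-ops tests.
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
    torch.cuda.set_device(f"cuda:{rank}")
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)
    t = torch.ones(16, dtype=torch.float32, device="cuda")
    expected = torch.full((16, ), float(tp_size), device="cuda")
    # All-reducing ones across tp_size ranks should yield tp_size everywhere.
    assert torch.allclose(tensor_model_parallel_all_reduce(t), expected)


@pytest.mark.parametrize("tp_size", [2])
def test_all_reduce(tp_size: int):
    multi_process_tensor_parallel(tp_size, 1, all_reduce_worker)
```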
Cyrus Leung 2024-05-13 22:50:09 +08:00 committed by GitHub
parent 702bee461f
commit 350f9e107f
34 changed files with 138 additions and 164 deletions

@@ -24,28 +24,26 @@ steps:
command: pytest -v -s core
- label: Distributed Comm Ops Test
command: pytest -v -s test_comm_ops.py
working_dir: "/vllm-workspace/tests/distributed"
command: pytest -v -s distributed/test_comm_ops.py
working_dir: "/vllm-workspace/tests"
num_gpus: 2
- label: Distributed Tests
working_dir: "/vllm-workspace/tests/distributed"
num_gpus: 2 # only support 1 or 2 for now.
working_dir: "/vllm-workspace/tests"
num_gpus: 2
mirror_hardwares: [amd]
commands:
- pytest -v -s test_pynccl_library.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
- pytest -v -s distributed/test_pynccl_library.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_chunked_prefill_distributed.py
- label: Distributed Tests (Multiple Groups)
working_dir: "/vllm-workspace/tests/distributed"
working_dir: "/vllm-workspace/tests"
num_gpus: 4
commands:
- pytest -v -s test_pynccl.py
- pytest -v -s distributed/test_pynccl.py
- label: Engine Test
#mirror_hardwares: [amd]

@@ -1,61 +1,16 @@
# imports for guided decoding tests
import os
import subprocess
import sys
import time
import openai # use the official client for correctness check
import pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
import requests
MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds
from ..utils import ServerRunner
# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"
@ray.remote(num_gpus=1)
class ServerRunner:

    def __init__(self, args):
        env = os.environ.copy()
        env["PYTHONUNBUFFERED"] = "1"
        self.proc = subprocess.Popen(
            ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        self._wait_for_server()

    def ready(self):
        return True

    def _wait_for_server(self):
        # run health check
        start = time.time()
        while True:
            try:
                if requests.get(
                        "http://localhost:8000/health").status_code == 200:
                    break
            except Exception as err:
                if self.proc.poll() is not None:
                    raise RuntimeError("Server exited unexpectedly.") from err

                time.sleep(0.5)
                if time.time() - start > MAX_SERVER_START_WAIT_S:
                    raise RuntimeError(
                        "Server failed to start in time.") from err

    def __del__(self):
        if hasattr(self, "proc"):
            self.proc.terminate()


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def server():
    ray.init()
    server_runner = ServerRunner.remote([

@@ -1,9 +1,10 @@
import pytest
from tests.conftest import cleanup
from vllm import LLM
from vllm.model_executor.utils import set_random_seed
from ....conftest import cleanup
@pytest.fixture
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,

@@ -4,10 +4,12 @@ by one. The solution is to pass arguments (model name) by environment
variables.
Run:
```sh
cd $VLLM_PATH/tests
TEST_DIST_MODEL=facebook/opt-125m pytest \
test_basic_distributed_correctness.py
distributed/test_basic_distributed_correctness.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \
test_basic_distributed_correctness.py
distributed/test_basic_distributed_correctness.py
```
"""
import os

@@ -11,8 +11,9 @@ import torch
from vllm.distributed import (broadcast_tensor_dict,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
from ..utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
@ray.remote(num_gpus=1, max_calls=1)

@@ -10,8 +10,9 @@ from vllm.distributed.communication_op import ( # noqa
graph_capture, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
get_tp_ca_communicator)
from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
from ..utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
random.seed(42)
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]

tests/engine/__init__.py (new empty file)
@@ -4,7 +4,6 @@ from unittest.mock import MagicMock
import pytest
from transformers import PreTrainedTokenizer
from tests.core.utils import create_seq_group
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
from vllm.engine.output_processor.stop_checker import StopChecker
@@ -14,6 +13,8 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
from ...core.utils import create_seq_group
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [1, 12])

@@ -1,10 +1,6 @@
# imports for guided decoding tests
import json
import os
import re
import subprocess
import sys
import time
import jsonschema
import openai # use the official client for correctness check
@@ -12,7 +8,6 @@ import pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
import requests
import torch
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
@@ -20,7 +15,8 @@ from openai import BadRequestError
from vllm.transformers_utils.tokenizer import get_tokenizer
MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds
from ..utils import ServerRunner
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
@@ -78,45 +74,6 @@ TEST_CHOICE = [
pytestmark = pytest.mark.asyncio
@ray.remote(num_gpus=1)
class ServerRunner:

    def __init__(self, args):
        env = os.environ.copy()
        env["PYTHONUNBUFFERED"] = "1"
        self.proc = subprocess.Popen(
            ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        self._wait_for_server()

    def ready(self):
        return True

    def _wait_for_server(self):
        # run health check
        start = time.time()
        while True:
            try:
                if requests.get(
                        "http://localhost:8000/health").status_code == 200:
                    break
            except Exception as err:
                if self.proc.poll() is not None:
                    raise RuntimeError("Server exited unexpectedly.") from err

                time.sleep(0.5)
                if time.time() - start > MAX_SERVER_START_WAIT_S:
                    raise RuntimeError(
                        "Server failed to start in time.") from err

    def __del__(self):
        if hasattr(self, "proc"):
            self.proc.terminate()


@pytest.fixture(scope="session")
def zephyr_lora_files():
    return snapshot_download(repo_id=LORA_NAME)

@@ -2,11 +2,12 @@ from typing import Type
import pytest
import torch
from allclose_default import get_default_atol, get_default_rtol
from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
NewGELU, SiluAndMul)
from .allclose_default import get_default_atol, get_default_rtol
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
D = [512, 4096, 5120, 13824] # Arbitrary values for testing

@@ -3,13 +3,14 @@ from typing import List, Optional, Tuple
import pytest
import torch
from allclose_default import get_default_atol, get_default_rtol
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip
from .allclose_default import get_default_atol, get_default_rtol
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer

@@ -3,10 +3,11 @@ from typing import List, Optional
import pytest
import torch
from allclose_default import get_default_atol, get_default_rtol
from vllm.model_executor.layers.rotary_embedding import get_rope
from .allclose_default import get_default_atol, get_default_rtol
IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 256]

tests/models/__init__.py (new empty file)

@@ -13,9 +13,10 @@ import os
import pytest
import torch
from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from .utils import check_logprobs_close
os.environ["TOKENIZERS_PARALLELISM"] = "true"
MAX_MODEL_LEN = 1024

@@ -15,9 +15,10 @@ from dataclasses import dataclass
import pytest
import torch
from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from .utils import check_logprobs_close
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
marlin_not_supported = (capability <

@@ -4,7 +4,7 @@ Run `pytest tests/models/test_mistral.py`.
"""
import pytest
from tests.models.utils import check_logprobs_close
from .utils import check_logprobs_close
MODELS = [
"mistralai/Mistral-7B-Instruct-v0.1",

@@ -1,9 +1,10 @@
import pytest
import torch
from tests.conftest import VllmRunner
from vllm import SamplingParams
from ..conftest import VllmRunner
MODELS = ["facebook/opt-125m"]

@@ -9,7 +9,6 @@ import torch
from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
nvmlInit)
from tests.conftest import cleanup
from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -21,6 +20,8 @@ from vllm.sequence import Logprob, MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, random_uuid
from ...conftest import cleanup
class AsyncLLM:
"""AsyncLLM

@@ -9,12 +9,13 @@ import pytest
import ray
import torch
from tests.entrypoints.test_openai_server import ServerRunner
from vllm import SamplingParams
from vllm.model_executor.model_loader.tensorizer import (
EncryptionParams, TensorizerConfig, TensorSerializer,
is_vllm_serialized_tensorizer, load_with_tensorizer, open_stream)
from ..utils import ServerRunner
prompts = [
"Hello, my name is",
"The president of the United States is",

@@ -1,9 +1,10 @@
import pytest
from tests.core.utils import create_dummy_prompt
from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput,
SequenceData, SequenceOutput)
from .core.utils import create_dummy_prompt
@pytest.fixture
def sample_outputs():

tests/utils.py (new file, 89 lines)
@@ -0,0 +1,89 @@
import os
import subprocess
import sys
import time

import ray
import requests

from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.utils import get_open_port

# Path to root of repository so that utilities can be imported by ray workers
VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))


@ray.remote(num_gpus=1)
class ServerRunner:
    MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 600 seconds

    def __init__(self, args):
        env = os.environ.copy()
        env["PYTHONUNBUFFERED"] = "1"
        self.proc = subprocess.Popen(
            ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        self._wait_for_server()

    def ready(self):
        return True

    def _wait_for_server(self):
        # run health check
        start = time.time()
        while True:
            try:
                if requests.get(
                        "http://localhost:8000/health").status_code == 200:
                    break
            except Exception as err:
                if self.proc.poll() is not None:
                    raise RuntimeError("Server exited unexpectedly.") from err

                time.sleep(0.5)
                if time.time() - start > self.MAX_SERVER_START_WAIT_S:
                    raise RuntimeError(
                        "Server failed to start in time.") from err

    def __del__(self):
        if hasattr(self, "proc"):
            self.proc.terminate()


def init_test_distributed_environment(
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
    local_rank: int = -1,
) -> None:
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    init_distributed_environment(
        world_size=pp_size * tp_size,
        rank=rank,
        distributed_init_method=distributed_init_method,
        local_rank=local_rank)
    ensure_model_parallel_initialized(tp_size, pp_size)


def multi_process_tensor_parallel(
    tp_size: int,
    pp_size: int,
    test_target,
) -> None:
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    ray.init(runtime_env={"working_dir": VLLM_PATH})

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(tp_size * pp_size):
        refs.append(
            test_target.remote(tp_size, pp_size, rank, distributed_init_port))
    ray.get(refs)

    ray.shutdown()

@@ -1,40 +0,0 @@
import ray

from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.utils import get_open_port


def init_test_distributed_environment(
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
    local_rank: int = -1,
) -> None:
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    init_distributed_environment(
        world_size=pp_size * tp_size,
        rank=rank,
        distributed_init_method=distributed_init_method,
        local_rank=local_rank)
    ensure_model_parallel_initialized(tp_size, pp_size)


def multi_process_tensor_parallel(
    tp_size: int,
    pp_size: int,
    test_target,
) -> None:
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    ray.init()

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(tp_size * pp_size):
        refs.append(
            test_target.remote(tp_size, pp_size, rank, distributed_init_port))
    ray.get(refs)

    ray.shutdown()