Robert Shaw d4d93db2c5
[V1] V1 Enablement Oracle (#13726)
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
2025-03-14 22:02:20 -07:00

58 lines
1.4 KiB
Python

# SPDX-License-Identifier: Apache-2.0
import functools
import gc
from typing import Callable, TypeVar
import pytest
import torch
from typing_extensions import ParamSpec
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@pytest.fixture(autouse=True, scope="function")
def use_v0_only(monkeypatch):
    """Set ``VLLM_USE_V1=0`` in the environment for each test function.

    Tensorizer has only been tested on V0 so far. The variable is set
    through ``monkeypatch``, and the fixture is autouse, so tests do not
    need to request it by name.
    """
    monkeypatch.setenv("VLLM_USE_V1", "0")
@pytest.fixture(autouse=True)
def cleanup() -> None:
    """Call ``cleanup_dist_env_and_memory(shutdown_ray=True)`` per test.

    The fixture is autouse and has no ``yield``, so the call happens
    during test setup, i.e. before the test body runs.
    """
    cleanup_dist_env_and_memory(shutdown_ray=True)
_P = ParamSpec("_P")
_R = TypeVar("_R")
def retry_until_skip(n: int):
def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]:
@functools.wraps(func)
def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R:
for i in range(n):
try:
return func(*args, **kwargs)
except AssertionError:
gc.collect()
torch.cuda.empty_cache()
if i == n - 1:
pytest.skip(f"Skipping test after {n} attempts.")
raise AssertionError("Code should not be reached")
return wrapper_retry
return decorator_retry
@pytest.fixture(autouse=True)
def tensorizer_config() -> TensorizerConfig:
    """Provide a ``TensorizerConfig`` built with ``tensorizer_uri="vllm"``.

    The fixture is autouse, so a new config object is constructed for
    each test whether or not the test requests it by name.
    """
    return TensorizerConfig(tensorizer_uri="vllm")