Fix broken tests (#14713)

Signed-off-by: JovanSardinha <jovan.sardinha@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Jovan Sardinha 2025-03-19 19:06:49 -07:00 committed by GitHub
parent 4cb1c05c9e
commit 70e500cad9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 19 additions and 14 deletions

View File

@ -295,6 +295,7 @@ steps:
# these tests need to be separated, cannot combine
- pytest -v -s compile/piecewise/test_simple.py
- pytest -v -s compile/piecewise/test_toy_llama.py
- pytest -v -s compile/test_pass_manager.py
- label: PyTorch Fullgraph Test # 18min
source_file_dependencies:

View File

@ -30,7 +30,7 @@ matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.5.4 # required for pixtral test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.4 # required for model evaluation test
transformers==4.48.2
transformers==4.48.2
# quantization
bitsandbytes>=0.45.3
buildkite-test-collector==0.1.9

View File

@ -4,34 +4,38 @@ import pickle
import pytest
import torch
from torch._inductor.codecache import BypassFxGraphCache
from vllm.compilation.config import CompilationConfig
from vllm.compilation.inductor_pass import (CallableInductorPass,
as_inductor_pass)
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import CompilationConfig
def simple_callable(graph: torch.fx.Graph):
pass
@as_inductor_pass(files=(__file__, ))
def callable_decorated(graph: torch.fx.Graph):
pass
callable_uuid = CallableInductorPass(simple_callable,
InductorPass.hash_source(__file__))
@pytest.mark.parametrize(
"works, callable",
[(False, simple_callable), (True, callable_decorated),
(True, CallableInductorPass(simple_callable, "simple_callable"))])
[
(False, simple_callable),
(True, callable_uuid),
(True, CallableInductorPass(simple_callable)),
],
)
def test_pass_manager(works: bool, callable):
config = CompilationConfig().pass_config
pass_manager = PostGradPassManager([callable])
pass_manager.configure(config) # Adds default passes
pass_manager = PostGradPassManager()
pass_manager.configure(config)
# Try to add the callable to the pass manager
if works:
pass_manager.add(callable)
pickle.dumps(pass_manager)
else:
with pytest.raises(BypassFxGraphCache):
pickle.dumps(pass_manager)
with pytest.raises(AssertionError):
pass_manager.add(callable)