
Signed-off-by: JovanSardinha <jovan.sardinha@gmail.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
42 lines
1.0 KiB
Python
42 lines
1.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import pickle
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
|
from vllm.compilation.pass_manager import PostGradPassManager
|
|
from vllm.config import CompilationConfig
|
|
|
|
|
|
def simple_callable(graph: torch.fx.Graph):
    """No-op graph transform: accept an FX graph and leave it untouched.

    This module's pass-manager test uses it both as a bare function and
    wrapped in ``CallableInductorPass``.
    """
    return None
|
|
|
|
|
|
# simple_callable wrapped in CallableInductorPass with an explicitly supplied
# uuid argument: the value InductorPass.hash_source returns for this module's
# __file__ string.
callable_uuid = CallableInductorPass(
    simple_callable,
    InductorPass.hash_source(__file__),
)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "works, pass_callable",
    [
        # A bare function (not wrapped in CallableInductorPass); add() is
        # expected to reject it with an AssertionError.
        pytest.param(False, simple_callable, id="bare-function"),
        # Wrapped, with an explicitly supplied uuid argument.
        pytest.param(True, callable_uuid, id="callable-pass-explicit-uuid"),
        # Wrapped, without passing a uuid argument.
        pytest.param(True,
                     CallableInductorPass(simple_callable),
                     id="callable-pass-no-uuid"),
    ],
)
def test_pass_manager(works: bool, pass_callable):
    """Check which objects ``PostGradPassManager.add()`` accepts.

    Args:
        works: True if ``pass_callable`` should be accepted by ``add()`` and
            the resulting manager should pickle without error; False if
            ``add()`` should raise ``AssertionError``.
        pass_callable: The object handed to ``add()``. (Not named
            ``callable`` so the builtin of that name is not shadowed.)
    """
    pass_config = CompilationConfig().pass_config

    pass_manager = PostGradPassManager()
    pass_manager.configure(pass_config)

    if works:
        pass_manager.add(pass_callable)
        # Only checks that serialization does not raise; the produced bytes
        # are not inspected.
        # NOTE(review): presumably this guards that every registered pass can
        # take part in pickling the manager — confirm against
        # PostGradPassManager.
        pickle.dumps(pass_manager)
    else:
        with pytest.raises(AssertionError):
            pass_manager.add(pass_callable)
|