2025-02-02 14:58:18 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
2025-03-23 21:54:07 -04:00
|
|
|
import copy
|
2024-11-21 00:44:57 -05:00
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
2025-03-19 19:06:49 -07:00
|
|
|
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
2024-11-21 00:44:57 -05:00
|
|
|
from vllm.compilation.pass_manager import PostGradPassManager
|
2025-03-19 19:06:49 -07:00
|
|
|
from vllm.config import CompilationConfig
|
2024-11-21 00:44:57 -05:00
|
|
|
|
|
|
|
|
2025-03-23 21:54:07 -04:00
|
|
|
# Dummy custom pass: a plain function, not an InductorPass subclass.
def simple_callable(graph: torch.fx.Graph):
    """No-op pass body; intentionally leaves *graph* untouched."""
|
|
|
|
|
|
|
|
|
2025-03-23 21:54:07 -04:00
|
|
|
# A bare function must be rejected when added directly to the pass manager.
def test_bad_callable():
    """Adding a non-InductorPass callable raises AssertionError.

    ``simple_callable`` is a plain function; PostGradPassManager.add is
    expected to refuse it.
    """
    pass_manager = PostGradPassManager()
    pass_manager.configure(CompilationConfig().pass_config)

    with pytest.raises(AssertionError):
        pass_manager.add(simple_callable)  # noqa, type wrong on purpose
|
|
|
|
|
|
|
|
|
|
|
|
# Pass that inherits from InductorPass
class ProperPass(InductorPass):
    """No-op pass that satisfies the InductorPass interface."""

    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """Leave *graph* unchanged."""
|
2024-11-21 00:44:57 -05:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    # Named ``inductor_pass`` rather than ``callable`` so the builtin
    # ``callable()`` is not shadowed inside the test body.
    "inductor_pass",
    [
        ProperPass(),
        # Plain callables can be wrapped in CallableInductorPass to comply
        # with the InductorPass interface...
        CallableInductorPass(simple_callable),
        # ...optionally with an explicit second argument, here computed
        # by InductorPass.hash_source(__file__).
        CallableInductorPass(simple_callable,
                             InductorPass.hash_source(__file__)),
    ],
    ids=["proper_pass", "wrapped_callable", "wrapped_callable_explicit"],
)
def test_pass_manager_uuid(inductor_pass: InductorPass):
    """Check how PostGradPassManager.uuid() reacts to passes and config.

    - Adding the same pass a second time changes the uuid.
    - A manager configured and populated the same way reproduces the uuid.
    - Changing the pass config (``enable_fusion``) changes the uuid.
    """
    config = CompilationConfig().pass_config

    pass_manager = PostGradPassManager()
    pass_manager.configure(config)

    # Check that UUID is different if the same pass is added 2x
    pass_manager.add(inductor_pass)
    uuid1 = pass_manager.uuid()
    pass_manager.add(inductor_pass)
    uuid2 = pass_manager.uuid()
    assert uuid1 != uuid2

    # UUID should match the original one, since this manager is built the
    # same way (same config, same single pass).
    pass_manager2 = PostGradPassManager()
    pass_manager2.configure(config)
    pass_manager2.add(inductor_pass)
    assert uuid1 == pass_manager2.uuid()

    # UUID should differ due to a config change. Deep-copy so the shared
    # ``config`` object used above is not mutated.
    config2 = copy.deepcopy(config)
    config2.enable_fusion = not config2.enable_fusion
    pass_manager3 = PostGradPassManager()
    pass_manager3.configure(config2)
    pass_manager3.add(inductor_pass)
    assert uuid1 != pass_manager3.uuid()