[torch.compile] Inductor code caching fix (#10273)

Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: Luka Govedic <luka.govedic@gmail.com>
Luka Govedič 2024-11-21 00:44:57 -05:00 committed by GitHub
parent 9d827170a3
commit 8b0fe06c89
14 changed files with 602 additions and 286 deletions

View File

@@ -1,7 +1,9 @@
from copy import deepcopy
from typing import Callable
from typing import Callable, Union
import torch
from torch import fx
from vllm.compilation.inductor_pass import InductorPass
class TestBackend:
@@ -11,19 +13,21 @@ class TestBackend:
It also saves the graph before and after the custom passes for inspection.
"""
def __init__(self, *args: Callable[[torch.fx.Graph], None]):
self.custom_passes = args
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
None]]):
self.custom_passes = list(passes)
from torch._inductor import config
self.current_config = config.shallow_copy_dict()
self.current_config['force_disable_caches'] = True
self.current_config['post_grad_custom_post_pass'] = self.post_pass
def __call__(self, graph: torch.fx.GraphModule, example_inputs):
def __call__(self, graph: fx.GraphModule, example_inputs):
from torch._inductor.compile_fx import compile_fx
return compile_fx(graph,
example_inputs,
config_patches=self.current_config)
def post_pass(self, graph: torch.fx.Graph):
def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph)
for pass_ in self.custom_passes:
pass_(graph)
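
A minimal sketch of how this backend is meant to be driven, assuming the tests/compile/backend.py module path and using an invented no-op pass:

import torch

from tests.compile.backend import TestBackend


def noop_pass(graph: torch.fx.Graph):
    # Do-nothing custom pass, just to exercise the plumbing.
    pass


backend = TestBackend(noop_pass)
compiled = torch.compile(torch.nn.Linear(16, 16), backend=backend, fullgraph=True)
compiled(torch.randn(4, 16))
# backend.graph_pre_pass now holds the post-grad FX graph captured before the
# custom passes ran (graph_post_pass, used by the tests below, is set analogously).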

View File

@@ -0,0 +1,95 @@
import pytest
import torch
import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
find_auto_fn_maybe)
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.compilation.vllm_inductor_pass import is_func
from vllm.config import CompilationConfig
from .backend import TestBackend
OPS_IN_MODEL = [
torch.ops._C.rotary_embedding.default,
torch.ops._C.fused_add_rms_norm.default,
torch.ops._C.silu_and_mul.default,
]
RMS_OP = torch.ops._C.rms_norm.default
RMS_QUANT_OPS = {
"static_fp8": [
torch.ops._C.rms_norm_static_fp8_quant.default,
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
],
}
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
@pytest.mark.parametrize("model",
["nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"])
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
def test_fix_functionalization(model: str, do_fusion: bool):
torch.set_default_device("cuda")
config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
enable_reshape=True)
reshape_pass = RedundantReshapesPass(config)
fusion_pass = FusionPass.instance(config)
passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass]
func_pass = FixFunctionalizationPass(config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
# instantiate a full engine and manually compile the model 2x
# (with and without FixFunctionalizationPass)
llm = LLM(model=model, enforce_eager=True)
model_runner = llm.llm_engine.model_executor.driver_worker.model_runner
orig_model = model_runner.model
# TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
# Can only do that by using the decorator but then we'd have to instantiate
# 2 LLM instances.
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
model_runner.model = torch.compile(orig_model,
fullgraph=True,
backend=backend_func)
gen_func = llm.generate(prompts, sampling_params)
model_runner.model = torch.compile(orig_model,
fullgraph=True,
backend=backend_no_func)
gen_no_func = llm.generate(prompts, sampling_params)
for output_func, output_no_func in zip(gen_func, gen_no_func):
assert output_func.outputs[0].text == output_no_func.outputs[0].text
# OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
# and replaced by fused quantized ops in RMS_QUANT_OPS.
ops = OPS_IN_MODEL + (RMS_QUANT_OPS["static_fp8"]
if do_fusion else [RMS_OP])
for op in ops:
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
op) is None # noqa: E501
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in ops:
if is_func(node, op):
found[op] = True
assert all(found[op] for op in ops)
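
The assertions above boil down to one structural distinction: while functionalized, an op only appears wrapped as auto_functionalized(op, ...) with operator.getitem users, whereas after FixFunctionalizationPass it is called directly. A hedged sketch of that check (the helper name is invented here):

import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized


def count_calls(graph: torch.fx.Graph, op):
    """Count direct vs. functionalized call_function nodes for `op`."""
    direct = wrapped = 0
    for node in graph.nodes:
        if node.op != "call_function":
            continue
        if node.target == op:
            direct += 1
        elif node.target == auto_functionalized and node.args[0] == op:
            wrapped += 1
    return direct, wrapped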

View File

@@ -38,12 +38,6 @@ class TestModel(torch.nn.Module):
return y3
# Init does pattern registration, which can only happen once
config = CompilationConfig(enable_fusion=True)
reshape_pass = RedundantReshapesPass(config)
fusion_pass = FusionPass.instance(config)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@@ -58,6 +52,11 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
pytest.skip("Only test eps=1e-5 for now")
# Reshape pass is needed for the fusion pass to work
config = CompilationConfig.PassConfig(enable_fusion=True,
enable_reshape=True)
reshape_pass = RedundantReshapesPass(config)
fusion_pass = FusionPass.instance(config)
backend = TestBackend(reshape_pass, fusion_pass)
model = TestModel(hidden_size, eps)

View File

@@ -0,0 +1,35 @@
import pickle
import pytest
import torch
from torch._inductor.codecache import BypassFxGraphCache
from vllm.compilation.config import CompilationConfig
from vllm.compilation.inductor_pass import (CallableInductorPass,
as_inductor_pass)
from vllm.compilation.pass_manager import PostGradPassManager
def simple_callable(graph: torch.fx.Graph):
pass
@as_inductor_pass(files=(__file__, ))
def callable_decorated(graph: torch.fx.Graph):
pass
@pytest.mark.parametrize(
"works, callable",
[(False, simple_callable), (True, callable_decorated),
(True, CallableInductorPass(simple_callable, "simple_callable"))])
def test_pass_manager(works: bool, callable):
config = CompilationConfig().pass_config
pass_manager = PostGradPassManager([callable])
pass_manager.configure(config) # Adds default passes
if works:
pickle.dumps(pass_manager)
else:
with pytest.raises(BypassFxGraphCache):
pickle.dumps(pass_manager)

View File

@@ -1,6 +1,5 @@
import copy
import dataclasses
import operator
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from unittest.mock import patch
@@ -11,205 +10,15 @@ import torch.fx as fx
import vllm.envs as envs
from vllm.config import CompilationConfig
from vllm.logger import init_logger
from vllm.utils import combine_fx_passes, weak_ref_tensors
from vllm.utils import weak_ref_tensors
from .counter import compilation_counter
from .fusion import FusionPass
from .reshapes import RedundantReshapesPass
from .inductor_pass import InductorPass
from .pass_manager import PostGradPassManager
logger = init_logger(__name__)
def fix_functionalization(graph: fx.Graph):
"""
Rewrite the graph module to replace the pattern involving
torch._higher_order_ops.auto_functionalize.auto_functionalized
with a direct call to the inplace custom op.
# TODO: check if PyTorch nightly has fixed this issue
"""
# debug code, if we want to see the graph before the transformation
# with open("before.py", "w") as f:
# print(graph.python_code(root_module="self", verbose=True).src, file=f)
nodes_to_remove = []
for node in graph.nodes:
# Identify the auto_functionalized node
if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized: # noqa
if node.args[0] == torch.ops._C.rotary_embedding.default:
# manual replace for rotary_embedding
# Now, collect the arguments
kwargs = node.kwargs
query = kwargs['query']
mm_node = query.args[0].args[0]
# Create a new call to torch.ops._C.rotary_embedding.default
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(torch.ops._C.rotary_embedding.default,
kwargs=kwargs)
# Remove the auto_functionalized node
# Since the node may have outputs, we need to handle its users
# Replace uses of the outputs (getitem nodes) with mm_node
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
# Remove the getitem node
for getitem_user in list(user.users):
if (getitem_user.op == 'call_function'
and getitem_user.target
== torch.ops.aten.slice_scatter.default):
# Replace the uses of slice_scatter node
# with mm_node
getitem_user.replace_all_uses_with(mm_node)
nodes_to_remove.append(getitem_user)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
elif node.args[0] == torch.ops._C.fused_add_rms_norm.default:
# manual replace for fused_add_rms_norm
# this is the most effective optimization for llama
# failing to do this will result in many unnecessary copies
kwargs = node.kwargs
input = kwargs['input']
residual = kwargs['residual']
# Create a new call to torch.ops._C.fused_add_rms_norm.default
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(
torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs)
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
# Remove the getitem node
if user.args[1] == 1:
replace_node = input
elif user.args[1] == 2:
replace_node = residual
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
elif (node.args[0] ==
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default):
# manual replace for fused_add_rms_norm_static_fp8_quant
# this is the most effective optimization for llama
# failing to do this will result in many unnecessary copies
kwargs = node.kwargs
result = kwargs['result']
residual = kwargs['residual']
# Create a new call to
# torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(
torch.ops._C.fused_add_rms_norm_static_fp8_quant.
default,
kwargs=kwargs)
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
# Remove the getitem node
if user.args[1] == 1:
replace_node = result
elif user.args[1] == 2:
replace_node = residual
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
elif node.args[0] == torch.ops._C.rms_norm.default:
# manual replace for rms_norm
kwargs = node.kwargs
replace_node = kwargs['result']
# Create a new call to torch.ops._C.rms_norm.default
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(torch.ops._C.rms_norm.default,
kwargs=kwargs)
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
elif node.args[
0] == torch.ops._C.rms_norm_static_fp8_quant.default: # noqa
# manual replace for rms_norm_static_fp8_quant
kwargs = node.kwargs
replace_node = kwargs['result']
# Create a new call to torch.ops._C.rms_norm_static_fp8_quant.default # noqa
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(
torch.ops._C.rms_norm_static_fp8_quant.default,
kwargs=kwargs)
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
elif node.args[0] == torch.ops._C.silu_and_mul.default:
# manual replace for silu_and_mul
kwargs = node.kwargs
input = kwargs['input']
out = kwargs['out']
# Create a new call to torch.ops._C.silu_and_mul.default
# cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(
torch.ops._C.silu_and_mul.default,
args=(out, input),
)
replace_node = out
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
# Remove the nodes all at once
for node in nodes_to_remove:
graph.erase_node(node)
# debug code, if we want to see the graph after the transformation
# with open("after.py", "w") as f:
# print(graph.python_code(root_module="self", verbose=True).src, file=f)
def wrap_inductor(graph,
example_inputs,
additional_inductor_config,
@@ -368,12 +177,8 @@ class VllmBackend:
The major work of this backend is to split the graph into
piecewise graphs, and pass them to the piecewise backend.
This backend also handles custom passes and adds them to Inductor config.
The order of the post-grad post-passes is:
1. post_grad_passes (constructor parameter)
2. config["post_grad_custom_post_pass"]
3. fix_functionalization
This way, all passes operate on a functionalized graph.
This backend also adds the PostGradPassManager to Inductor config,
which handles the post-grad passes.
"""
compilation_configs: CompilationConfig
@@ -402,7 +207,9 @@ class VllmBackend:
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self.graph_pool = global_graph_pool
self.post_grad_passes = []
# Passes to run on the graph post-grad.
self.post_grad_pass_manager = PostGradPassManager()
self.sym_tensor_indices = []
self.input_buffers = []
@@ -412,24 +219,19 @@
# `torch.compile` is JIT compiled, so we don't need to
# do anything here
def add_passes_to_config(self):
def configure_post_pass(self):
config = self.compilation_configs
passes = list(self.post_grad_passes)
passes = passes + [RedundantReshapesPass(config)]
if config.enable_fusion:
passes = passes + [FusionPass.instance(config)]
self.post_grad_pass_manager.configure(config.pass_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass
# hook. If a pass for that hook exists, add it to the pass manager.
inductor_config = config.inductor_compile_config
if "post_grad_custom_post_pass" in inductor_config:
passes = passes + [inductor_config["post_grad_custom_post_pass"]]
# add the fix_functionalization pass last, so that all other
# passes operate on a functionalized graph
passes = passes + [fix_functionalization]
combined_pass = combine_fx_passes(passes)
inductor_config["post_grad_custom_post_pass"] = combined_pass
PASS_KEY = "post_grad_custom_post_pass"
if PASS_KEY in inductor_config:
# Config should automatically wrap all inductor passes
assert isinstance(inductor_config[PASS_KEY], InductorPass)
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
inductor_config[PASS_KEY] = self.post_grad_pass_manager
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
@@ -444,7 +246,7 @@ class VllmBackend:
# we get the sizes to capture for cudagraph
# from compilation context
self.compilation_configs.init_during_runtime()
self.add_passes_to_config()
self.configure_post_pass()
self.split_gm, self.piecewise_graphs = split_graph(
graph, self.compilation_configs.splitting_ops)
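
Condensed, configure_post_pass leaves Inductor with a single post-grad hook; a rough sketch of its effect (inductor_config and compilation_configs refer to the VllmBackend attributes used above, the pre-existing user pass is hypothetical):

# Sketch of what configure_post_pass sets up.
pass_manager = PostGradPassManager()
pass_manager.configure(compilation_configs.pass_config)  # adds reshape/fusion passes

existing = inductor_config.get("post_grad_custom_post_pass")
if existing is not None:
    # CompilationConfig wraps plain callables, so this is already an InductorPass.
    pass_manager.add(existing)
inductor_config["post_grad_custom_post_pass"] = pass_manager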

View File

@@ -0,0 +1,177 @@
import operator
from typing import Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.logger import init_logger
from .vllm_inductor_pass import VllmInductorPass, is_func
logger = init_logger(__name__)
class FixFunctionalizationPass(VllmInductorPass):
"""
This pass defunctionalizes certain nodes to avoid redundant tensor copies.
After this pass, DCE (dead-code elimination) should never be run,
as de-functionalized nodes may appear as dead code.
To add new nodes to defunctionalize, add to the if-elif chain in __call__.
"""
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before_fix_functionalization")
self.nodes_to_remove: List[torch.fx.Node] = []
count = 0
for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue # Avoid deep if-elif nesting
kwargs = node.kwargs
at_target = node.args[0]
if at_target == torch.ops._C.rotary_embedding.default:
query = kwargs['query']
mm_node = query.args[0].args[0]
# rotary_embedding is a special case: the two mutating inputs
# are query and key, which are slices of mm_node.
# While functionalized, results at[1] and at[2] are scattered
# back into mm_node. After de-functionalization, we can just
# use mm_node directly.
for idx, user in self.getitem_users(node).items():
for user_of_getitem in user.users:
if is_func(user_of_getitem,
torch.ops.aten.slice_scatter.default):
user_of_getitem.replace_all_uses_with(mm_node)
self._remove(user_of_getitem)
self._remove(user)
self.insert_defunctionalized(graph, node)
self._remove(node)
# These 2 replacements avoid the most copies for LLaMa.
elif at_target == torch.ops._C.fused_add_rms_norm.default:
mutated_args = {1: 'input', 2: 'residual'}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
mutated_args = {1: 'result', 2: 'residual'}
self.defunctionalize(graph, node, mutated_args)
elif at_target in [
torch.ops._C.rms_norm.default,
torch.ops._C.rms_norm_static_fp8_quant.default
]:
mutated_args = {1: 'result'}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.silu_and_mul.default:
mutated_args = {1: 'out'}
# Because we have an 'out', need to specify args directly
self.defunctionalize(graph,
node,
mutated_args,
args=('out', 'input'))
else:
continue # skip the count
count += 1
self.dump_graph(graph, "before_fix_functionalization_cleanup")
# Remove the nodes all at once
count_removed = len(self.nodes_to_remove)
for node in self.nodes_to_remove:
graph.erase_node(node)
logger.debug("De-functionalized %s nodes, removed %s nodes", count,
count_removed)
self.dump_graph(graph, "after_fix_functionalization")
self.end_and_log()
def _remove(self, node_or_nodes: Union[torch.fx.Node,
Iterable[torch.fx.Node]]):
"""
Stage a node (or nodes) for removal at the end of the pass.
"""
if isinstance(node_or_nodes, torch.fx.Node):
self.nodes_to_remove.append(node_or_nodes)
else:
self.nodes_to_remove.extend(node_or_nodes)
def defunctionalize(self,
graph: torch.fx.Graph,
node: torch.fx.Node,
mutated_args: Dict[int, Union[torch.fx.Node, str]],
args: Optional[Tuple[Union[torch.fx.Node, str],
...]] = None):
"""
De-functionalize a node by replacing it with a call to the original.
It also replaces the getitem users with the mutated arguments.
See replace_users_with_mutated_args and insert_defunctionalized.
"""
self.replace_users_with_mutated_args(node, mutated_args)
self.insert_defunctionalized(graph, node, args=args)
self._remove(node)
def replace_users_with_mutated_args(self, node: torch.fx.Node,
mutated_args: Dict[int,
Union[torch.fx.Node,
str]]):
"""
Replace all getitem users of the auto-functionalized node with the
mutated arguments.
:param node: The auto-functionalized node
:param mutated_args: The mutated arguments, indexed by getitem index.
If the value of an arg is a string, `node.kwargs[arg]` is used.
"""
for idx, user in self.getitem_users(node).items():
arg = mutated_args[idx]
arg = node.kwargs[arg] if isinstance(arg, str) else arg
user.replace_all_uses_with(arg)
self._remove(user)
def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]:
"""
Returns the operator.getitem users of the auto-functionalized node,
indexed by the index they are getting.
"""
users = {}
for user in node.users:
if is_func(user, operator.getitem):
idx = user.args[1]
users[idx] = user
return users
def insert_defunctionalized(self,
graph: torch.fx.Graph,
node: torch.fx.Node,
args: Optional[Tuple[Union[torch.fx.Node, str],
...]] = None):
"""
Insert a new defunctionalized node into the graph before node.
If one of the kwargs is 'out', provide args directly,
as node.kwargs cannot be used.
See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
:param graph: Graph to insert the defunctionalized node into
:param node: The auto-functionalized node to defunctionalize
:param args: If we cannot use kwargs, specify args directly.
If an arg is a string, `node.kwargs[arg]` is used.
""" # noqa: E501
assert is_func(node, auto_functionalized), \
f"node must be auto-functionalized, is {node} instead"
# Create a new call to the original function
with graph.inserting_before(node):
function = node.args[0]
if args is None:
graph.call_function(function, kwargs=node.kwargs)
else:
# Args passed as strings refer to items in node.kwargs
args = tuple(node.kwargs[arg] if isinstance(arg, str) else arg
for arg in args)
graph.call_function(function, args=args)
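
For a concrete picture of the rewrite, here is the rms_norm case spelled out as an illustration (kwarg names follow the custom op schema; the variable names are invented):

# Functionalized form, as produced ahead of this pass:
#   at = auto_functionalized(torch.ops._C.rms_norm.default,
#                            result=buf, input=x, weight=w, epsilon=eps)
#   out = operator.getitem(at, 1)        # the mutated 'result'
#
# After FixFunctionalizationPass (mutated_args = {1: 'result'}):
#   torch.ops._C.rms_norm.default(result=buf, input=x, weight=w, epsilon=eps)
#   # every use of `out` now points at `buf`; the auto_functionalized node and
#   # its getitem users are erased in the final cleanup loop.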

View File

@@ -6,10 +6,11 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
fwd_only, register_replacement)
from vllm.compilation.inductor_pass import InductorPass
from vllm.config import CompilationConfig
from vllm.logger import init_logger
from .vllm_inductor_pass import VllmInductorPass, is_func
logger = init_logger(__name__)
@@ -90,8 +91,6 @@ def empty_fp32(*args, **kwargs):
# Utilities for post-processing multi-output matches
def is_func(node: torch.fx.Node, target) -> bool:
return node.op == "call_function" and node.target == target
# Returns the first auto_functionalized node with the given op (if it exists)
@@ -127,7 +126,7 @@ def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node:
return ret
class FusionPass(InductorPass):
class FusionPass(VllmInductorPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
@@ -142,7 +141,7 @@ class FusionPass(InductorPass):
_instance: 'Optional[FusionPass]' = None
@classmethod
def instance(cls, config: CompilationConfig):
def instance(cls, config: CompilationConfig.PassConfig):
"""
Get the singleton instance of the FusionPass.
If the instance exists, the config is updated but
@@ -154,7 +153,7 @@ class FusionPass(InductorPass):
cls._instance.config = config
return cls._instance
def __init__(self, config: CompilationConfig):
def __init__(self, config: CompilationConfig.PassConfig):
assert self.__class__._instance is None, \
"FusionPass singleton instance already exists"
super().__init__(config)
@@ -278,6 +277,7 @@ class FusionPass(InductorPass):
for node in match.nodes)
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before_fusion")
count = self.patterns.apply(graph)
@@ -289,3 +289,4 @@ class FusionPass(InductorPass):
logger.debug("Post-processed %s matches", len(self.matches))
self.dump_graph(graph, "after_fusion")
self.matches.clear()
self.end_and_log()
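
With this change the pass is constructed from the narrower PassConfig instead of the full CompilationConfig, mirroring the updated tests:

from vllm.compilation.fusion import FusionPass
from vllm.config import CompilationConfig

pass_config = CompilationConfig.PassConfig(enable_fusion=True, enable_reshape=True)
fusion_pass = FusionPass.instance(pass_config)  # singleton; reuse updates the config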

View File

@@ -1,38 +1,84 @@
import hashlib
import inspect
import types
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Union
import torch
from vllm.config import CompilationConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (
get_tensor_model_parallel_world_size as get_tp_world_size)
from vllm.distributed import model_parallel_is_initialized as p_is_init
# yapf: enable
from vllm.logger import init_logger
logger = init_logger(__name__)
from torch import fx
class InductorPass(ABC):
"""
General custom inductor pass interface.
TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass
"""
@abstractmethod
def __call__(self, graph: torch.fx.Graph):
"""
Execute the pass on the given graph.
"""
raise NotImplementedError
def __init__(self, config: CompilationConfig):
self.config = config
def uuid(self) -> Any:
"""
Provide a unique identifier for the pass, used in Inductor code cache.
This should depend on the pass implementation, so that changes to the
pass result in recompilation.
By default, the object source is hashed.
"""
return InductorPass.hash_source(self)
def dump_graph(self, graph: torch.fx.Graph, stage: str):
if stage in self.config.dump_graph_stages:
# Make sure filename includes rank in the distributed setting
parallel = p_is_init() and get_tp_world_size() > 1
rank = f"-{get_tp_rank()}" if parallel else ""
filepath = self.config.dump_graph_dir / f"{stage}{rank}.py"
@staticmethod
def hash_source(*srcs: Union[str, Any]):
"""
Utility method to hash the sources of functions or objects.
:param srcs: strings or objects to add to the hash.
Objects and functions have their source inspected.
:return: a SHA-256 digest of the concatenated sources.
"""
hasher = hashlib.sha256()
for src in srcs:
if isinstance(src, str):
src_str = src
elif isinstance(src, types.FunctionType):
src_str = inspect.getsource(src)
else:
src_str = inspect.getsource(src.__class__)
hasher.update(src_str.encode("utf-8"))
return hasher.digest()
logger.info("Printing graph to %s", filepath)
with open(filepath, "w") as f:
src = graph.python_code(root_module="self", verbose=True).src
# Add imports so it's not full of errors
print("import torch; from torch import device", file=f)
print(src, file=f)
class CallableInductorPass(InductorPass):
"""
This class is a wrapper for a callable that automatically provides an
implementation of the UUID.
"""
def __init__(self,
callable: Callable[[fx.Graph], None],
uuid: Optional[Any] = None):
self.callable = callable
if uuid is None:
uuid = InductorPass.hash_source(callable)
self._uuid = uuid
def __call__(self, graph: torch.fx.Graph):
self.callable(graph)
def uuid(self) -> Any:
return self._uuid
def __getstate__(self):
"""
Pickling occurs in the Inductor code cache if a pass is not given to
the pass manager but is instead directly added to config as a pass.
See PostGradPassManager for more.
TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
"""
return self._uuid
def __setstate__(self, state):
raise ValueError("Cannot unpickle CallableInductorPass")

View File

@@ -0,0 +1,77 @@
from typing import List
from torch import fx as fx
from vllm.config import CompilationConfig
from vllm.logger import init_logger
from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .inductor_pass import InductorPass
from .reshapes import RedundantReshapesPass
logger = init_logger(__name__)
class PostGradPassManager:
"""
The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes.
It also supports pickling, which is used by the Inductor code cache.
TODO(torch==2.6), use CustomGraphPass
(torch._inductor.custom_graph_pass.CustomGraphPass)
The order of the post-grad post-passes is:
1. passes (constructor parameter)
2. default passes (RedundantReshapesPass, FusionPass)
3. config["post_grad_custom_post_pass"] (if it exists)
4. fix_functionalization
This way, all passes operate on a functionalized graph.
"""
def __init__(self):
self.passes: List[InductorPass] = []
def __call__(self, graph: fx.Graph):
for pass_ in self.passes:
pass_(graph)
# always run fix_functionalization last
self.fix_functionalization(graph)
def configure(self, pass_config: CompilationConfig.PassConfig):
self.pass_config = pass_config
if pass_config.enable_reshape:
self.passes += [RedundantReshapesPass(pass_config)]
if pass_config.enable_fusion:
self.passes += [FusionPass.instance(pass_config)]
self.fix_functionalization = FixFunctionalizationPass(pass_config)
def add(self, pass_: InductorPass):
assert isinstance(pass_, InductorPass)
self.passes.append(pass_)
def __getstate__(self):
"""
Custom pickling for the pass manager, as some passes cannot be pickled.
Pickling occurs because the pass manager is set as the value of
`config["post_grad_custom_post_pass"]` in the Inductor config.
The config is pickled to act as a key in the Inductor code cache.
Any other passes in the config are pickled as well.
TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
"""
state = {"pass_config": self.pass_config.uuid(), "passes": []}
for pass_ in self.passes:
state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid())
return state
def __setstate__(self, state):
"""
Do not allow unpickling of the pass manager.
If this is needed in the future, it should properly pickle the passes.
"""
raise ValueError("Cannot unpickle PostGradPassManager")

View File

@@ -3,14 +3,14 @@ from typing import Union
import torch.fx
from torch import SymInt
from vllm.compilation.fusion import is_func
from vllm.compilation.inductor_pass import InductorPass
from vllm.logger import init_logger
from .vllm_inductor_pass import VllmInductorPass, is_func
logger = init_logger(__name__)
class RedundantReshapesPass(InductorPass):
class RedundantReshapesPass(VllmInductorPass):
"""
This is an inductor pass that removes redundant reshape operations.
It is required for RMSNorm-quant fusion to work properly.
@@ -31,6 +31,7 @@ class RedundantReshapesPass(InductorPass):
"""
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before_reshapes")
count = 0
# Remove no-op reshapes/views:
@@ -56,6 +57,7 @@ class RedundantReshapesPass(InductorPass):
logger.debug("Removed %s no-op reshapes", count)
self.dump_graph(graph, "after_reshapes")
self.end_and_log()
def dims_equivalent(self, dim: Union[int, torch.fx.Node],
i_dim: Union[int, SymInt]) -> bool:

View File

@@ -0,0 +1,53 @@
import time
import torch
from vllm.config import CompilationConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (
get_tensor_model_parallel_world_size as get_tp_world_size)
from vllm.distributed import model_parallel_is_initialized as p_is_init
# yapf: enable
from vllm.logger import init_logger
from .inductor_pass import InductorPass
logger = init_logger(__name__)
def is_func(node: torch.fx.Node, target) -> bool:
return node.op == "call_function" and node.target == target
class VllmInductorPass(InductorPass):
"""
An inductor pass with access to vLLM PassConfig.
It provides timing, logging, and dumping utilities.
"""
def __init__(self, config: CompilationConfig.PassConfig):
self.config = config
self.pass_name = self.__class__.__name__
def dump_graph(self, graph: torch.fx.Graph, stage: str):
if stage in self.config.dump_graph_stages:
# Make sure filename includes rank in the distributed setting
parallel = p_is_init() and get_tp_world_size() > 1
rank = f"-{get_tp_rank()}" if parallel else ""
filepath = self.config.dump_graph_dir / f"{stage}{rank}.py"
logger.info("%s printing graph to %s", self.pass_name, filepath)
with open(filepath, "w") as f:
src = graph.python_code(root_module="self", verbose=True).src
# Add imports so it's not full of errors
print("import torch; from torch import device", file=f)
print(src, file=f)
def begin(self):
self._start_time = time.perf_counter_ns()
def end_and_log(self):
self._end_time = time.perf_counter_ns()
duration_ms = float(self._end_time - self._start_time) / 1.0e6
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)

View File

@@ -1,5 +1,6 @@
import copy
import enum
import hashlib
import json
import warnings
from dataclasses import dataclass, field, replace
@@ -13,6 +14,7 @@ from pydantic import BaseModel, Field, PrivateAttr
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
get_quantization_config)
@@ -2120,12 +2122,7 @@ class CompilationConfig(BaseModel):
name because the config uses json format. If we pass the config
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
- custom inductor passes:
- dump_graph_stages: list of stages for which we want to dump the graph.
Each pass defines its own stages (before, after, maybe in-between).
- dump_graph_dir: directory to dump the graph. Default is .
- enable_fusion: whether to enable the custom fusion pass.
TODO better pass enabling system.
- custom inductor passes: see PassConfig for more details
Why we have different sizes for cudagraph and inductor:
- cudagraph: a cudagraph captured for a specific size can only be used
@@ -2157,9 +2154,43 @@
cudagraph_capture_sizes: Optional[List[int]] = None
cudagraph_copy_inputs: bool = False
dump_graph_stages: List[str] = Field(default_factory=list)
dump_graph_dir: Path = Field(default=Path("."))
enable_fusion: bool = True
class PassConfig(BaseModel):
"""
Configuration for custom Inductor passes.
This is separate from general CompilationConfig so that inductor passes
don't all have access to full configuration - that would create a cycle
as the PassManager is set as a property of config.
- dump_graph_stages: list of stages for which we want to dump the graph.
Each pass defines its own stages (before, after, maybe in-between).
- dump_graph_dir: directory to dump the graphs. Default is .
- enable_fusion: whether to enable the custom fusion pass.
- enable_reshape: whether to enable the custom reshape elimination pass.
TODO better pass enabling system.
"""
dump_graph_stages: List[str] = Field(default_factory=list)
dump_graph_dir: Path = Field(default=Path("."))
enable_fusion: bool = True
enable_reshape: bool = True
def uuid(self):
"""
Produces a hash unique to the pass configuration.
Any new fields that affect compilation should be added to the hash.
Do not include dump_graph_* in the hash - they don't affect
compilation.
"""
dict_ = self.model_dump(
include={"enable_fusion", "enable_reshape"})
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).digest()
def model_post_init(self, __context: Any) -> None:
if not self.enable_reshape and self.enable_fusion:
print_warning_once(
"Fusion enabled but reshape elimination disabled."
"RMSNorm + quant (fp8) fusion might not work")
pass_config: PassConfig = Field(default_factory=PassConfig)
# not configurable, computed after init
compile_sizes: List[int] = PrivateAttr
@@ -2185,8 +2216,9 @@
for k, v in self.inductor_passes.items():
if not isinstance(v, str):
assert callable(v), (
f"pass {k} should be a function or a qualified name")
self.inductor_compile_config[k] = v
f"pass {k} should be callable or a qualified name")
self.inductor_compile_config[k] = v if isinstance(
v, InductorPass) else CallableInductorPass(v)
continue
# resolve function from qualified name
@@ -2194,7 +2226,8 @@
module = ".".join(names[:-1])
func_name = names[-1]
func = __import__(module).__dict__[func_name]
self.inductor_compile_config[k] = func
self.inductor_compile_config[k] = func if isinstance(
func, InductorPass) else CallableInductorPass(func)
self.enabled_custom_ops = Counter()
self.disabled_custom_ops = Counter()
@@ -2344,7 +2377,8 @@ class VllmConfig:
self.compilation_config.custom_ops = ["none"]
self.compilation_config.use_cudagraph = True
self.compilation_config.use_inductor = True
self.compilation_config.enable_fusion = False
self.compilation_config.pass_config.enable_fusion = False
self.compilation_config.pass_config.enable_reshape = False
current_platform.check_and_update_config(self)
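
As a quick illustration of the new knobs and of the cache-relevant hash (the chosen values are only an example):

from vllm.config import CompilationConfig

pc = CompilationConfig.PassConfig(enable_fusion=True, enable_reshape=False)
key = pc.uuid()  # sha256 over {"enable_fusion": true, "enable_reshape": false};
                 # the dump_graph_* fields are deliberately left out of the hash
# Note: this particular combination also triggers the warning in model_post_init,
# since the RMSNorm + quant fusion relies on reshape elimination.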

View File

@@ -1501,15 +1501,6 @@ class LazyDict(Mapping, Generic[T]):
return len(self._factory)
def combine_fx_passes(passes: List[Callable]) -> Callable:
def combined_fx(graph) -> None:
for fx in passes:
fx(graph)
return combined_fx
def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""
Create a weak reference to a tensor.

View File

@@ -548,7 +548,7 @@ class GPUModelRunner:
if not self.use_cuda_graph:
logger.warning(
"Skipping CUDA graph capture. Please add "
"-O 3 to use CUDA graphs.", CompilationLevel.PIECEWISE)
"-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
return
start_time = time.perf_counter()