[torch.compile] Inductor code caching fix (#10273)

Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: Luka Govedic <luka.govedic@gmail.com>

Parent: 9d827170a3
Commit: 8b0fe06c89
tests/compile/backend.py
@@ -1,7 +1,9 @@
 from copy import deepcopy
-from typing import Callable
+from typing import Callable, Union

 import torch
+from torch import fx
+
+from vllm.compilation.inductor_pass import InductorPass


 class TestBackend:
@@ -11,19 +13,21 @@ class TestBackend:
     It also saves the graph before and after the custom passes for inspection.
     """

-    def __init__(self, *args: Callable[[torch.fx.Graph], None]):
-        self.custom_passes = args
+    def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
+                                                              None]]):
+        self.custom_passes = list(passes)
         from torch._inductor import config
         self.current_config = config.shallow_copy_dict()
+        self.current_config['force_disable_caches'] = True
         self.current_config['post_grad_custom_post_pass'] = self.post_pass

-    def __call__(self, graph: torch.fx.GraphModule, example_inputs):
+    def __call__(self, graph: fx.GraphModule, example_inputs):
         from torch._inductor.compile_fx import compile_fx
         return compile_fx(graph,
                           example_inputs,
                           config_patches=self.current_config)

-    def post_pass(self, graph: torch.fx.Graph):
+    def post_pass(self, graph: fx.Graph):
         self.graph_pre_pass = deepcopy(graph)
         for pass_ in self.custom_passes:
             pass_(graph)
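Note (illustrative, not part of the diff): how this helper is driven, with a hypothetical no-op pass; the tests below exercise it through full models.

import torch

from tests.compile.backend import TestBackend  # run from the vLLM repo root

def noop_pass(graph: torch.fx.Graph) -> None:
    pass  # a real pass would inspect or rewrite the post-grad graph

backend = TestBackend(noop_pass)
compiled = torch.compile(torch.nn.Linear(8, 8), backend=backend)
compiled(torch.randn(4, 8))
# The graphs around the custom passes are kept for inspection:
# backend.graph_pre_pass (and backend.graph_post_pass after the passes run).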
tests/compile/test_functionalization.py (new file, 95 lines)
@@ -0,0 +1,95 @@
import pytest
import torch

import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
                                     find_auto_fn_maybe)
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.compilation.vllm_inductor_pass import is_func
from vllm.config import CompilationConfig

from .backend import TestBackend

OPS_IN_MODEL = [
    torch.ops._C.rotary_embedding.default,
    torch.ops._C.fused_add_rms_norm.default,
    torch.ops._C.silu_and_mul.default,
]

RMS_OP = torch.ops._C.rms_norm.default

RMS_QUANT_OPS = {
    "static_fp8": [
        torch.ops._C.rms_norm_static_fp8_quant.default,
        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
    ],
}

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


@pytest.mark.parametrize("model",
                         ["nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"])
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
def test_fix_functionalization(model: str, do_fusion: bool):
    torch.set_default_device("cuda")

    config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
                                          enable_reshape=True)
    reshape_pass = RedundantReshapesPass(config)
    fusion_pass = FusionPass.instance(config)

    passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass]
    func_pass = FixFunctionalizationPass(config)
    backend_func = TestBackend(*passes, func_pass)
    backend_no_func = TestBackend(*passes)

    # instantiate a full engine and manually compile the model 2x
    # (with and without FixFunctionalizationPass)
    llm = LLM(model=model, enforce_eager=True)
    model_runner = llm.llm_engine.model_executor.driver_worker.model_runner
    orig_model = model_runner.model
    # TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
    # Can only do that by using the decorator but then we'd have to instantiate
    # 2 LLM instances.

    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
    model_runner.model = torch.compile(orig_model,
                                       fullgraph=True,
                                       backend=backend_func)
    gen_func = llm.generate(prompts, sampling_params)

    model_runner.model = torch.compile(orig_model,
                                       fullgraph=True,
                                       backend=backend_no_func)
    gen_no_func = llm.generate(prompts, sampling_params)

    for output_func, output_no_func in zip(gen_func, gen_no_func):
        assert output_func.outputs[0].text == output_no_func.outputs[0].text

    # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
    # and replaced by fused quantized ops in RMS_QUANT_OPS.
    ops = OPS_IN_MODEL + (RMS_QUANT_OPS["static_fp8"]
                          if do_fusion else [RMS_OP])

    for op in ops:
        find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
        assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
                                  op) is None  # noqa: E501

    # make sure the ops were all de-functionalized
    found = dict()
    for node in backend_func.graph_post_pass.nodes:
        for op in ops:
            if is_func(node, op):
                found[op] = True
    assert all(found[op] for op in ops)
tests/compile/test_fusion.py
@@ -38,12 +38,6 @@ class TestModel(torch.nn.Module):
         return y3


-# Init does pattern registration, which can only happen once
-config = CompilationConfig(enable_fusion=True)
-reshape_pass = RedundantReshapesPass(config)
-fusion_pass = FusionPass.instance(config)
-
-
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@@ -58,6 +52,11 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
         pytest.skip("Only test eps=1e-5 for now")

+    # Reshape pass is needed for the fusion pass to work
+    config = CompilationConfig.PassConfig(enable_fusion=True,
+                                          enable_reshape=True)
+    reshape_pass = RedundantReshapesPass(config)
+    fusion_pass = FusionPass.instance(config)
+
     backend = TestBackend(reshape_pass, fusion_pass)
     model = TestModel(hidden_size, eps)
tests/compile/test_pass_manager.py (new file, 35 lines)
@@ -0,0 +1,35 @@
import pickle

import pytest
import torch
from torch._inductor.codecache import BypassFxGraphCache

from vllm.compilation.config import CompilationConfig
from vllm.compilation.inductor_pass import (CallableInductorPass,
                                            as_inductor_pass)
from vllm.compilation.pass_manager import PostGradPassManager


def simple_callable(graph: torch.fx.Graph):
    pass


@as_inductor_pass(files=(__file__, ))
def callable_decorated(graph: torch.fx.Graph):
    pass


@pytest.mark.parametrize(
    "works, callable",
    [(False, simple_callable), (True, callable_decorated),
     (True, CallableInductorPass(simple_callable, "simple_callable"))])
def test_pass_manager(works: bool, callable):
    config = CompilationConfig().pass_config
    pass_manager = PostGradPassManager([callable])
    pass_manager.configure(config)  # Adds default passes

    if works:
        pickle.dumps(pass_manager)
    else:
        with pytest.raises(BypassFxGraphCache):
            pickle.dumps(pass_manager)
vllm/compilation/backends.py
@@ -1,6 +1,5 @@
 import copy
 import dataclasses
-import operator
 from contextlib import ExitStack
 from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
 from unittest.mock import patch
@@ -11,205 +10,15 @@ import torch.fx as fx
 import vllm.envs as envs
 from vllm.config import CompilationConfig
 from vllm.logger import init_logger
-from vllm.utils import combine_fx_passes, weak_ref_tensors
+from vllm.utils import weak_ref_tensors

 from .counter import compilation_counter
-from .fusion import FusionPass
-from .reshapes import RedundantReshapesPass
+from .inductor_pass import InductorPass
+from .pass_manager import PostGradPassManager

 logger = init_logger(__name__)


-def fix_functionalization(graph: fx.Graph):
-    """
-    Rewrite the graph module to replace the pattern involving
-    torch._higher_order_ops.auto_functionalize.auto_functionalized
-    with a direct call to the inplace custom op.
-
-    # TODO: check if PyTorch nightly has fixed this issue
-    """
-
-    # debug code, if we want to see the graph before the transformation
-    # with open("before.py", "w") as f:
-    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)
-
-    nodes_to_remove = []
-
-    for node in graph.nodes:
-        # Identify the auto_functionalized node
-        if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized:  # noqa
-            if node.args[0] == torch.ops._C.rotary_embedding.default:
-                # manual replace for rotary_embedding
-
-                # Now, collect the arguments
-                kwargs = node.kwargs
-
-                query = kwargs['query']
-                mm_node = query.args[0].args[0]
-
-                # Create a new call to torch.ops._C.rotary_embedding.default
-                with graph.inserting_before(node):
-                    # just insert the call to the custom op
-                    # NOTE: don't run dead code elimination,
-                    # otherwise this op will be removed
-                    graph.call_function(torch.ops._C.rotary_embedding.default,
-                                        kwargs=kwargs)
-
-                # Remove the auto_functionalized node
-                # Since the node may have outputs, we need to handle its users
-                # Replace uses of the outputs (getitem nodes) with mm_node
-                for user in list(node.users):
-                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
-                        # Remove the getitem node
-                        for getitem_user in list(user.users):
-                            if (getitem_user.op == 'call_function'
-                                    and getitem_user.target
-                                    == torch.ops.aten.slice_scatter.default):
-                                # Replace the uses of slice_scatter node
-                                # with mm_node
-                                getitem_user.replace_all_uses_with(mm_node)
-                                nodes_to_remove.append(getitem_user)
-                        nodes_to_remove.append(user)
-                nodes_to_remove.append(node)
-
-            elif node.args[0] == torch.ops._C.fused_add_rms_norm.default:
-                # manual replace for fused_add_rms_norm
-                # this is the most effective optimization for llama
-                # failing to do this will result in many unnecessary copies
-
-                kwargs = node.kwargs
-
-                input = kwargs['input']
-                residual = kwargs['residual']
-
-                # Create a new call to torch.ops._C.fused_add_rms_norm.default
-                with graph.inserting_before(node):
-                    # just insert the call to the custom op
-                    # NOTE: don't run dead code elimination,
-                    # otherwise this op will be removed
-                    graph.call_function(
-                        torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs)
-
-                for user in list(node.users):
-                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
-                        # Remove the getitem node
-                        if user.args[1] == 1:
-                            replace_node = input
-                        elif user.args[1] == 2:
-                            replace_node = residual
-                        user.replace_all_uses_with(replace_node)
-                        nodes_to_remove.append(user)
-                nodes_to_remove.append(node)
-            elif (node.args[0] ==
-                  torch.ops._C.fused_add_rms_norm_static_fp8_quant.default):
-                # manual replace for fused_add_rms_norm_static_fp8_quant
-                # this is the most effective optimization for llama
-                # failing to do this will result in many unnecessary copies
-
-                kwargs = node.kwargs
-
-                result = kwargs['result']
-                residual = kwargs['residual']
-
-                # Create a new call to
-                # torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
-                with graph.inserting_before(node):
-                    # just insert the call to the custom op
-                    # NOTE: don't run dead code elimination,
-                    # otherwise this op will be removed
-                    graph.call_function(
-                        torch.ops._C.fused_add_rms_norm_static_fp8_quant.
-                        default,
-                        kwargs=kwargs)
-
-                for user in list(node.users):
-                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
-                        # Remove the getitem node
-                        if user.args[1] == 1:
-                            replace_node = result
-                        elif user.args[1] == 2:
-                            replace_node = residual
-                        user.replace_all_uses_with(replace_node)
-                        nodes_to_remove.append(user)
-                nodes_to_remove.append(node)
-
-            elif node.args[0] == torch.ops._C.rms_norm.default:
-                # manual replace for rms_norm
-
-                kwargs = node.kwargs
-
-                replace_node = kwargs['result']
-                # Create a new call to torch.ops._C.rms_norm.default
-                with graph.inserting_before(node):
-                    # just insert the call to the custom op
-                    # NOTE: don't run dead code elimination,
-                    # otherwise this op will be removed
-                    graph.call_function(torch.ops._C.rms_norm.default,
-                                        kwargs=kwargs)
-
-                for user in list(node.users):
-                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
-                        user.replace_all_uses_with(replace_node)
-                        nodes_to_remove.append(user)
-                nodes_to_remove.append(node)
-
-            elif node.args[
-                    0] == torch.ops._C.rms_norm_static_fp8_quant.default:  # noqa
-                # manual replace for rms_norm_static_fp8_quant
-
-                kwargs = node.kwargs
-
-                replace_node = kwargs['result']
-                # Create a new call to torch.ops._C.rms_norm_static_fp8_quant.default  # noqa
-                with graph.inserting_before(node):
-                    # just insert the call to the custom op
-                    # NOTE: don't run dead code elimination,
-                    # otherwise this op will be removed
-                    graph.call_function(
-                        torch.ops._C.rms_norm_static_fp8_quant.default,
-                        kwargs=kwargs)
-
-                for user in list(node.users):
-                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
-                        user.replace_all_uses_with(replace_node)
-                        nodes_to_remove.append(user)
-                nodes_to_remove.append(node)
-
-            elif node.args[0] == torch.ops._C.silu_and_mul.default:
-                # manual replace for silu_and_mul
-
-                kwargs = node.kwargs
-
-                input = kwargs['input']
-                out = kwargs['out']
-
-                # Create a new call to torch.ops._C.silu_and_mul.default
-                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351  # noqa
-                with graph.inserting_before(node):
-                    # just insert the call to the custom op
-                    # NOTE: don't run dead code elimination,
-                    # otherwise this op will be removed
-                    graph.call_function(
-                        torch.ops._C.silu_and_mul.default,
-                        args=(out, input),
-                    )
-                replace_node = out
-
-                for user in list(node.users):
-                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
-                        user.replace_all_uses_with(replace_node)
-                        nodes_to_remove.append(user)
-                nodes_to_remove.append(node)
-
-    # Remove the nodes all at once
-    for node in nodes_to_remove:
-        graph.erase_node(node)
-
-    # debug code, if we want to see the graph after the transformation
-    # with open("after.py", "w") as f:
-    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)
-
-
 def wrap_inductor(graph,
                   example_inputs,
                   additional_inductor_config,
@@ -368,12 +177,8 @@ class VllmBackend:
     The major work of this backend is to split the graph into
     piecewise graphs, and pass them to the piecewise backend.

-    This backend also handles custom passes and adds them to Inductor config.
-    The order of the post-grad post-passes is:
-    1. post_grad_passes (constructor parameter)
-    2. config["post_grad_custom_post_pass"]
-    3. fix_functionalization
-    This way, all passes operate on a functionalized graph.
+    This backend also adds the PostGradPassManager to Inductor config,
+    which handles the post-grad passes.
     """

     compilation_configs: CompilationConfig
@@ -402,7 +207,9 @@ class VllmBackend:
         # streams, it might not be safe to share a global pool.
         # only investigate this when we use multiple streams
         self.graph_pool = global_graph_pool
-        self.post_grad_passes = []
+
+        # Passes to run on the graph post-grad.
+        self.post_grad_pass_manager = PostGradPassManager()

         self.sym_tensor_indices = []
         self.input_buffers = []
@@ -412,24 +219,19 @@ class VllmBackend:
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here

-    def add_passes_to_config(self):
+    def configure_post_pass(self):
         config = self.compilation_configs
-        passes = list(self.post_grad_passes)
-
-        passes = passes + [RedundantReshapesPass(config)]
-
-        if config.enable_fusion:
-            passes = passes + [FusionPass.instance(config)]
+        self.post_grad_pass_manager.configure(config.pass_config)

+        # Post-grad custom passes are run using the post_grad_custom_post_pass
+        # hook. If a pass for that hook exists, add it to the pass manager.
         inductor_config = config.inductor_compile_config
-        if "post_grad_custom_post_pass" in inductor_config:
-            passes = passes + [inductor_config["post_grad_custom_post_pass"]]
-
-        # add the fix_functionalization pass last, so that all other
-        # passes operate on a functionalized graph
-        passes = passes + [fix_functionalization]
-        combined_pass = combine_fx_passes(passes)
-        inductor_config["post_grad_custom_post_pass"] = combined_pass
+        PASS_KEY = "post_grad_custom_post_pass"
+        if PASS_KEY in inductor_config:
+            # Config should automatically wrap all inductor passes
+            assert isinstance(inductor_config[PASS_KEY], InductorPass)
+            self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
+        inductor_config[PASS_KEY] = self.post_grad_pass_manager

     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
@@ -444,7 +246,7 @@ class VllmBackend:
         # we get the sizes to capture for cudagraph
         # from compilation context
         self.compilation_configs.init_during_runtime()
-        self.add_passes_to_config()
+        self.configure_post_pass()

         self.split_gm, self.piecewise_graphs = split_graph(
             graph, self.compilation_configs.splitting_ops)
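Note (illustrative, not part of the diff): the hook must be a single picklable object because Inductor's FX graph cache serializes `post_grad_custom_post_pass` when building its cache key; an arbitrary closure (like the old `combine_fx_passes` result) cannot be pickled stably, which forfeits caching. A sketch of that constraint, not Inductor's actual code:

import pickle

def fx_cache_key_material(inductor_config: dict) -> bytes:
    # A PostGradPassManager reduces to stable digests (see its __getstate__
    # later in this diff); a plain closure would fail here, and Inductor
    # bypasses its cache (torch._inductor.codecache.BypassFxGraphCache).
    return pickle.dumps(inductor_config["post_grad_custom_post_pass"])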
vllm/compilation/fix_functionalization.py (new file, 177 lines)
@@ -0,0 +1,177 @@
import operator
from typing import Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized

from vllm.logger import init_logger

from .vllm_inductor_pass import VllmInductorPass, is_func

logger = init_logger(__name__)


class FixFunctionalizationPass(VllmInductorPass):
    """
    This pass defunctionalizes certain nodes to avoid redundant tensor copies.
    After this pass, DCE (dead-code elimination) should never be run,
    as de-functionalized nodes may appear as dead code.

    To add new nodes to defunctionalize, add to the if-elif chain in __call__.
    """

    def __call__(self, graph: torch.fx.Graph):
        self.begin()
        self.dump_graph(graph, "before_fix_functionalization")

        self.nodes_to_remove: List[torch.fx.Node] = []
        count = 0
        for node in graph.nodes:
            if not is_func(node, auto_functionalized):
                continue  # Avoid deep if-elif nesting

            kwargs = node.kwargs
            at_target = node.args[0]

            if at_target == torch.ops._C.rotary_embedding.default:
                query = kwargs['query']
                mm_node = query.args[0].args[0]

                # rotary_embedding is a special case: the two mutating inputs
                # are query and key, which are slices of mm_node.
                # While functionalized, results at[1] and at[2] are scattered
                # back into mm_node. After de-functionalization, we can just
                # use mm_node directly.
                for idx, user in self.getitem_users(node).items():
                    for user_of_getitem in user.users:
                        if is_func(user_of_getitem,
                                   torch.ops.aten.slice_scatter.default):
                            user_of_getitem.replace_all_uses_with(mm_node)
                            self._remove(user_of_getitem)
                    self._remove(user)

                self.insert_defunctionalized(graph, node)
                self._remove(node)

            # These 2 replacements avoid the most copies for LLaMa.
            elif at_target == torch.ops._C.fused_add_rms_norm.default:
                mutated_args = {1: 'input', 2: 'residual'}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default:  # noqa: E501
                mutated_args = {1: 'result', 2: 'residual'}
                self.defunctionalize(graph, node, mutated_args)

            elif at_target in [
                    torch.ops._C.rms_norm.default,
                    torch.ops._C.rms_norm_static_fp8_quant.default
            ]:
                mutated_args = {1: 'result'}
                self.defunctionalize(graph, node, mutated_args)

            elif at_target == torch.ops._C.silu_and_mul.default:
                mutated_args = {1: 'out'}
                # Because we have an 'out', need to specify args directly
                self.defunctionalize(graph,
                                     node,
                                     mutated_args,
                                     args=('out', 'input'))
            else:
                continue  # skip the count

            count += 1

        self.dump_graph(graph, "before_fix_functionalization_cleanup")

        # Remove the nodes all at once
        count_removed = len(self.nodes_to_remove)
        for node in self.nodes_to_remove:
            graph.erase_node(node)

        logger.debug("De-functionalized %s nodes, removed %s nodes", count,
                     count_removed)
        self.dump_graph(graph, "after_fix_functionalization")
        self.end_and_log()

    def _remove(self, node_or_nodes: Union[torch.fx.Node,
                                           Iterable[torch.fx.Node]]):
        """
        Stage a node (or nodes) for removal at the end of the pass.
        """
        if isinstance(node_or_nodes, torch.fx.Node):
            self.nodes_to_remove.append(node_or_nodes)
        else:
            self.nodes_to_remove.extend(node_or_nodes)

    def defunctionalize(self,
                        graph: torch.fx.Graph,
                        node: torch.fx.Node,
                        mutated_args: Dict[int, Union[torch.fx.Node, str]],
                        args: Optional[Tuple[Union[torch.fx.Node, str],
                                             ...]] = None):
        """
        De-functionalize a node by replacing it with a call to the original.
        It also replaces the getitem users with the mutated arguments.
        See replace_users_with_mutated_args and insert_defunctionalized.
        """
        self.replace_users_with_mutated_args(node, mutated_args)
        self.insert_defunctionalized(graph, node, args=args)
        self._remove(node)

    def replace_users_with_mutated_args(self, node: torch.fx.Node,
                                        mutated_args: Dict[int,
                                                           Union[torch.fx.Node,
                                                                 str]]):
        """
        Replace all getitem users of the auto-functionalized node with the
        mutated arguments.
        :param node: The auto-functionalized node
        :param mutated_args: The mutated arguments, indexed by getitem index.
        If the value of an arg is a string, `node.kwargs[arg]` is used.
        """
        for idx, user in self.getitem_users(node).items():
            arg = mutated_args[idx]
            arg = node.kwargs[arg] if isinstance(arg, str) else arg
            user.replace_all_uses_with(arg)
            self._remove(user)

    def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]:
        """
        Returns the operator.getitem users of the auto-functionalized node,
        indexed by the index they are getting.
        """
        users = {}
        for user in node.users:
            if is_func(user, operator.getitem):
                idx = user.args[1]
                users[idx] = user
        return users

    def insert_defunctionalized(self,
                                graph: torch.fx.Graph,
                                node: torch.fx.Node,
                                args: Optional[Tuple[Union[torch.fx.Node, str],
                                                     ...]] = None):
        """
        Insert a new defunctionalized node into the graph before node.
        If one of the kwargs is 'out', provide args directly,
        as node.kwargs cannot be used.
        See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351

        :param graph: Graph to insert the defunctionalized node into
        :param node: The auto-functionalized node to defunctionalize
        :param args: If we cannot use kwargs, specify args directly.
        If an arg is a string, `node.kwargs[arg]` is used.
        """  # noqa: E501
        assert is_func(node, auto_functionalized), \
            f"node must be auto-functionalized, is {node} instead"

        # Create a new call to the original function
        with graph.inserting_before(node):
            function = node.args[0]
            if args is None:
                graph.call_function(function, kwargs=node.kwargs)
            else:
                # Args passed as strings refer to items in node.kwargs
                args = tuple(node.kwargs[arg] if isinstance(arg, str) else arg
                             for arg in args)
                graph.call_function(function, args=args)
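Note (illustrative, not part of the diff): conceptually, for an op like rms_norm the pass turns the functionalized pattern into a direct in-place call. In FX-level pseudocode (kwarg names as in the branches above, not literal pass output):

# before: the functionalized form produced during tracing
at = auto_functionalized(torch.ops._C.rms_norm.default,
                         result=result, input=x, weight=w, epsilon=1e-5)
out = at[1]  # getitem node: the "fresh" output tensor (implies a copy)

# after FixFunctionalizationPass (mutated_args = {1: 'result'}):
torch.ops._C.rms_norm.default(result=result, input=x, weight=w, epsilon=1e-5)
out = result  # getitem users rewired to the mutated kwarg; no copy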
vllm/compilation/fusion.py
@@ -6,10 +6,11 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
                                              fwd_only, register_replacement)

-from vllm.compilation.inductor_pass import InductorPass
 from vllm.config import CompilationConfig
 from vllm.logger import init_logger

+from .vllm_inductor_pass import VllmInductorPass, is_func
+
 logger = init_logger(__name__)

@@ -90,8 +91,6 @@ def empty_fp32(*args, **kwargs):


 # Utilities for post-processing multi-output matches
-def is_func(node: torch.fx.Node, target) -> bool:
-    return node.op == "call_function" and node.target == target
-
-
 # Returns the first auto_functionalized node with the given op (if it exists)
@@ -127,7 +126,7 @@ def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node:
     return ret


-class FusionPass(InductorPass):
+class FusionPass(VllmInductorPass):
     """
     This pass fuses a pre-defined set of custom ops into fused ops.
     It uses the torch pattern matcher to find the patterns and replace them.
@@ -142,7 +141,7 @@ class FusionPass(InductorPass):
     _instance: 'Optional[FusionPass]' = None

     @classmethod
-    def instance(cls, config: CompilationConfig):
+    def instance(cls, config: CompilationConfig.PassConfig):
         """
         Get the singleton instance of the FusionPass.
         If the instance exists, the config is updated but
@@ -154,7 +153,7 @@ class FusionPass(InductorPass):
         cls._instance.config = config
         return cls._instance

-    def __init__(self, config: CompilationConfig):
+    def __init__(self, config: CompilationConfig.PassConfig):
         assert self.__class__._instance is None, \
             "FusionPass singleton instance already exists"
         super().__init__(config)
@@ -278,6 +277,7 @@ class FusionPass(InductorPass):
                    for node in match.nodes)

     def __call__(self, graph: torch.fx.Graph):
+        self.begin()
         self.dump_graph(graph, "before_fusion")

         count = self.patterns.apply(graph)
@@ -289,3 +289,4 @@ class FusionPass(InductorPass):
         logger.debug("Post-processed %s matches", len(self.matches))
         self.dump_graph(graph, "after_fusion")
         self.matches.clear()
+        self.end_and_log()
vllm/compilation/inductor_pass.py
@@ -1,38 +1,84 @@
+import hashlib
+import inspect
+import types
 from abc import ABC, abstractmethod
+from typing import Any, Callable, Optional, Union

 import torch
-
-from vllm.config import CompilationConfig
-# yapf: disable
-from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size as get_tp_world_size)
-from vllm.distributed import model_parallel_is_initialized as p_is_init
-# yapf: enable
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
+from torch import fx


 class InductorPass(ABC):
+    """
+    General custom inductor pass interface.
+    TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass
+    """

     @abstractmethod
     def __call__(self, graph: torch.fx.Graph):
+        """
+        Execute the pass on the given graph.
+        """
         raise NotImplementedError

-    def __init__(self, config: CompilationConfig):
-        self.config = config
+    def uuid(self) -> Any:
+        """
+        Provide a unique identifier for the pass, used in Inductor code cache.
+        This should depend on the pass implementation, so that changes to the
+        pass result in recompilation.
+        By default, the object source is hashed.
+        """
+        return InductorPass.hash_source(self)

-    def dump_graph(self, graph: torch.fx.Graph, stage: str):
-        if stage in self.config.dump_graph_stages:
-            # Make sure filename includes rank in the distributed setting
-            parallel = p_is_init() and get_tp_world_size() > 1
-            rank = f"-{get_tp_rank()}" if parallel else ""
-            filepath = self.config.dump_graph_dir / f"{stage}{rank}.py"
-
-            logger.info("Printing graph to %s", filepath)
-            with open(filepath, "w") as f:
-                src = graph.python_code(root_module="self", verbose=True).src
-                # Add imports so it's not full of errors
-                print("import torch; from torch import device", file=f)
-                print(src, file=f)
+    @staticmethod
+    def hash_source(*srcs: Union[str, Any]):
+        """
+        Utility method to hash the sources of functions or objects.
+        :param srcs: strings or objects to add to the hash.
+        Objects and functions have their source inspected.
+        :return:
+        """
+        hasher = hashlib.sha256()
+        for src in srcs:
+            if isinstance(src, str):
+                src_str = src
+            elif isinstance(src, types.FunctionType):
+                src_str = inspect.getsource(src)
+            else:
+                src_str = inspect.getsource(src.__class__)
+            hasher.update(src_str.encode("utf-8"))
+        return hasher.digest()
+
+
+class CallableInductorPass(InductorPass):
+    """
+    This class is a wrapper for a callable that automatically provides an
+    implementation of the UUID.
+    """
+
+    def __init__(self,
+                 callable: Callable[[fx.Graph], None],
+                 uuid: Optional[Any] = None):
+        self.callable = callable
+        if uuid is None:
+            uuid = InductorPass.hash_source(callable)
+        self._uuid = uuid
+
+    def __call__(self, graph: torch.fx.Graph):
+        self.callable(graph)
+
+    def uuid(self) -> Any:
+        return self._uuid
+
+    def __getstate__(self):
+        """
+        Pickling occurs in the Inductor code cache if a pass is not given to
+        the pass manager but is instead directly added to config as a pass.
+        See PostGradPassManager for more.
+
+        TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
+        """
+        return self._uuid
+
+    def __setstate__(self, state):
+        raise ValueError("Cannot unpickle CallableInductorPass")
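Note (illustrative, not part of the diff): how these uuids behave, assuming the classes above:

import torch

from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass

def noop(graph: torch.fx.Graph) -> None:
    pass  # hypothetical pass body

p = CallableInductorPass(noop)
# The default uuid is a digest of the callable's source, so editing the
# body of `noop` changes the uuid and invalidates cached compilations.
assert p.uuid() == InductorPass.hash_source(noop)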
vllm/compilation/pass_manager.py (new file, 77 lines)
@@ -0,0 +1,77 @@
from typing import List

from torch import fx as fx

from vllm.config import CompilationConfig
from vllm.logger import init_logger

from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .inductor_pass import InductorPass
from .reshapes import RedundantReshapesPass

logger = init_logger(__name__)


class PostGradPassManager:
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.
    It also supports pickling, which is used by the Inductor code cache.
    TODO(torch==2.6), use CustomGraphPass
    (torch._inductor.custom_graph_pass.CustomGraphPass)

    The order of the post-grad post-passes is:
    1. passes (constructor parameter)
    2. default passes (RedundantReshapesPass, FusionPass)
    3. config["post_grad_custom_post_pass"] (if it exists)
    4. fix_functionalization
    This way, all passes operate on a functionalized graph.
    """

    def __init__(self):
        self.passes: List[InductorPass] = []

    def __call__(self, graph: fx.Graph):
        for pass_ in self.passes:
            pass_(graph)

        # always run fix_functionalization last
        self.fix_functionalization(graph)

    def configure(self, pass_config: CompilationConfig.PassConfig):
        self.pass_config = pass_config
        if pass_config.enable_reshape:
            self.passes += [RedundantReshapesPass(pass_config)]

        if pass_config.enable_fusion:
            self.passes += [FusionPass.instance(pass_config)]

        self.fix_functionalization = FixFunctionalizationPass(pass_config)

    def add(self, pass_: InductorPass):
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def __getstate__(self):
        """
        Custom pickling for the pass manager, as some passes cannot be pickled.
        Pickling occurs because the pass manager is set as the value of
        `config["post_grad_custom_post_pass"]` in the Inductor config.
        The config is pickled to act as a key in the Inductor code cache.
        Any other passes in the config are pickled as well.

        TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
        """
        state = {"pass_config": self.pass_config.uuid(), "passes": []}
        for pass_ in self.passes:
            state["passes"].append(pass_.uuid())
        state["passes"].append(self.fix_functionalization.uuid())
        return state

    def __setstate__(self, state):
        """
        Do not allow unpickling of the pass manager.
        If this is needed in the future, it should properly pickle the passes.
        """
        raise ValueError("Cannot unpickle PostGradPassManager")
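Note (illustrative, not part of the diff): roughly what this reduction yields, assuming the classes in this diff:

from vllm.config import CompilationConfig
from vllm.compilation.pass_manager import PostGradPassManager

pm = PostGradPassManager()
pm.configure(CompilationConfig.PassConfig())

state = pm.__getstate__()
# state == {"pass_config": <sha256 digest>, "passes": [<digest>, ...]}
# Identical pass sources plus identical pass config give identical state,
# so the Inductor cache key is stable across processes; unpickling is
# forbidden (__setstate__ raises) since only the key matters, not the object.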
vllm/compilation/reshapes.py
@@ -3,14 +3,14 @@ from typing import Union
 import torch.fx
 from torch import SymInt

-from vllm.compilation.fusion import is_func
-from vllm.compilation.inductor_pass import InductorPass
 from vllm.logger import init_logger

+from .vllm_inductor_pass import VllmInductorPass, is_func
+
 logger = init_logger(__name__)


-class RedundantReshapesPass(InductorPass):
+class RedundantReshapesPass(VllmInductorPass):
     """
     This is an inductor pass that removes redundant reshape operations.
     It is required for RMSNorm-quant fusion to work properly.
@@ -31,6 +31,7 @@ class RedundantReshapesPass(InductorPass):
     """

     def __call__(self, graph: torch.fx.Graph):
+        self.begin()
         self.dump_graph(graph, "before_reshapes")
         count = 0
         # Remove no-op reshapes/views:
@@ -56,6 +57,7 @@ class RedundantReshapesPass(InductorPass):
         logger.debug("Removed %s no-op reshapes", count)

         self.dump_graph(graph, "after_reshapes")
+        self.end_and_log()

     def dims_equivalent(self, dim: Union[int, torch.fx.Node],
                         i_dim: Union[int, SymInt]) -> bool:
vllm/compilation/vllm_inductor_pass.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import time

import torch

from vllm.config import CompilationConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (
    get_tensor_model_parallel_world_size as get_tp_world_size)
from vllm.distributed import model_parallel_is_initialized as p_is_init
# yapf: enable
from vllm.logger import init_logger

from .inductor_pass import InductorPass

logger = init_logger(__name__)


def is_func(node: torch.fx.Node, target) -> bool:
    return node.op == "call_function" and node.target == target


class VllmInductorPass(InductorPass):
    """
    An inductor pass with access to vLLM PassConfig.
    It provides timing, logging, and dumping utilities.
    """

    def __init__(self, config: CompilationConfig.PassConfig):
        self.config = config
        self.pass_name = self.__class__.__name__

    def dump_graph(self, graph: torch.fx.Graph, stage: str):
        if stage in self.config.dump_graph_stages:
            # Make sure filename includes rank in the distributed setting
            parallel = p_is_init() and get_tp_world_size() > 1
            rank = f"-{get_tp_rank()}" if parallel else ""
            filepath = self.config.dump_graph_dir / f"{stage}{rank}.py"

            logger.info("%s printing graph to %s", self.pass_name, filepath)
            with open(filepath, "w") as f:
                src = graph.python_code(root_module="self", verbose=True).src
                # Add imports so it's not full of errors
                print("import torch; from torch import device", file=f)
                print(src, file=f)

    def begin(self):
        self._start_time = time.perf_counter_ns()

    def end_and_log(self):
        self._end_time = time.perf_counter_ns()
        duration_ms = float(self._end_time - self._start_time) / 1.0e6
        logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
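Note (illustrative, not part of the diff): a minimal custom pass built on these utilities, mirroring the structure of RedundantReshapesPass and FusionPass in this diff:

import torch

from vllm.compilation.vllm_inductor_pass import VllmInductorPass


class CountNodesPass(VllmInductorPass):
    """Hypothetical example pass: inspects the graph, mutates nothing."""

    def __call__(self, graph: torch.fx.Graph):
        self.begin()
        self.dump_graph(graph, "before_count_nodes")
        num_nodes = len(list(graph.nodes))  # any analysis/rewrites go here
        self.dump_graph(graph, "after_count_nodes")
        self.end_and_log()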
vllm/config.py
@@ -1,5 +1,6 @@
 import copy
 import enum
+import hashlib
 import json
 import warnings
 from dataclasses import dataclass, field, replace
@@ -13,6 +14,7 @@ from pydantic import BaseModel, Field, PrivateAttr
 from transformers import PretrainedConfig

 import vllm.envs as envs
+from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                      get_quantization_config)
@@ -2120,12 +2122,7 @@ class CompilationConfig(BaseModel):
         name because the config uses json format. If we pass the config
         from Python, functions can also be passed directly via Python object
         constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
-    - custom inductor passes:
-    - dump_graph_stages: list of stages for which we want to dump the graph.
-        Each pass defines its own stages (before, after, maybe in-between).
-    - dump_graph_dir: directory to dump the graph. Default is .
-    - enable_fusion: whether to enable the custom fusion pass.
-        TODO better pass enabling system.
+    - custom inductor passes: see PassConfig for more details

     Why we have different sizes for cudagraph and inductor:
     - cudagraph: a cudagraph captured for a specific size can only be used
@@ -2157,9 +2154,43 @@ class CompilationConfig(BaseModel):
     cudagraph_capture_sizes: Optional[List[int]] = None
     cudagraph_copy_inputs: bool = False

-    dump_graph_stages: List[str] = Field(default_factory=list)
-    dump_graph_dir: Path = Field(default=Path("."))
-    enable_fusion: bool = True
+    class PassConfig(BaseModel):
+        """
+        Configuration for custom Inductor passes.
+        This is separate from general CompilationConfig so that inductor
+        passes don't all have access to full configuration - that would
+        create a cycle as the PassManager is set as a property of config.
+        - dump_graph_stages: list of stages for which we want to dump the
+            graph. Each pass defines its own stages
+            (before, after, maybe in-between).
+        - dump_graph_dir: directory to dump the graphs. Default is .
+        - enable_fusion: whether to enable the custom fusion pass.
+        - enable_reshape: whether to enable the custom reshape elimination
+            pass.
+        TODO better pass enabling system.
+        """
+        dump_graph_stages: List[str] = Field(default_factory=list)
+        dump_graph_dir: Path = Field(default=Path("."))
+        enable_fusion: bool = True
+        enable_reshape: bool = True
+
+        def uuid(self):
+            """
+            Produces a hash unique to the pass configuration.
+            Any new fields that affect compilation should be added to the
+            hash. Do not include dump_graph_* in the hash - they don't
+            affect compilation.
+            """
+            dict_ = self.model_dump(
+                include={"enable_fusion", "enable_reshape"})
+            encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
+            return hashlib.sha256(encoded).digest()
+
+        def model_post_init(self, __context: Any) -> None:
+            if not self.enable_reshape and self.enable_fusion:
+                print_warning_once(
+                    "Fusion enabled but reshape elimination disabled."
+                    "RMSNorm + quant (fp8) fusion might not work")
+
+    pass_config: PassConfig = Field(default_factory=PassConfig)

     # not configurable, computed after init
     compile_sizes: List[int] = PrivateAttr
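Note (illustrative, not part of the diff): the uuid deliberately ignores the dump_graph_* fields, so debug dumping never forces a recompile:

from vllm.config import CompilationConfig

a = CompilationConfig.PassConfig(enable_fusion=True, enable_reshape=True)
b = CompilationConfig.PassConfig(enable_fusion=True, enable_reshape=True,
                                 dump_graph_stages=["before_fusion"])
assert a.uuid() == b.uuid()  # only the enable_* fields enter the hash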
@@ -2185,8 +2216,9 @@ class CompilationConfig(BaseModel):
         for k, v in self.inductor_passes.items():
             if not isinstance(v, str):
                 assert callable(v), (
-                    f"pass {k} should be a function or a qualified name")
-                self.inductor_compile_config[k] = v
+                    f"pass {k} should be callable or a qualified name")
+                self.inductor_compile_config[k] = v if isinstance(
+                    v, InductorPass) else CallableInductorPass(v)
                 continue

             # resolve function from qualified name
@@ -2194,7 +2226,8 @@ class CompilationConfig(BaseModel):
             module = ".".join(names[:-1])
             func_name = names[-1]
             func = __import__(module).__dict__[func_name]
-            self.inductor_compile_config[k] = func
+            self.inductor_compile_config[k] = func if isinstance(
+                func, InductorPass) else CallableInductorPass(func)

         self.enabled_custom_ops = Counter()
         self.disabled_custom_ops = Counter()
@@ -2344,7 +2377,8 @@ class VllmConfig:
             self.compilation_config.custom_ops = ["none"]
             self.compilation_config.use_cudagraph = True
             self.compilation_config.use_inductor = True
-            self.compilation_config.enable_fusion = False
+            self.compilation_config.pass_config.enable_fusion = False
+            self.compilation_config.pass_config.enable_reshape = False

             current_platform.check_and_update_config(self)
vllm/utils.py
@@ -1501,15 +1501,6 @@ class LazyDict(Mapping, Generic[T]):
         return len(self._factory)


-def combine_fx_passes(passes: List[Callable]) -> Callable:
-
-    def combined_fx(graph) -> None:
-        for fx in passes:
-            fx(graph)
-
-    return combined_fx
-
-
 def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
     """
     Create a weak reference to a tensor.
vllm/v1/worker/gpu_model_runner.py
@@ -548,7 +548,7 @@ class GPUModelRunner:
         if not self.use_cuda_graph:
             logger.warning(
                 "Skipping CUDA graph capture. Please add "
-                "-O 3 to use CUDA graphs.", CompilationLevel.PIECEWISE)
+                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
             return

         start_time = time.perf_counter()