[torch.compile] rework compile control with piecewise cudagraph (#9715)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao authored on 2024-10-29 23:03:49 -07:00; committed by GitHub
parent 7b0365efef
commit ff5ed6e1bc
17 changed files with 979 additions and 102 deletions

View File

@@ -229,6 +229,9 @@ steps:
   - tests/compile
   commands:
   - pytest -v -s compile/test_basic_correctness.py
+  # these tests need to be separated, cannot combine
+  - pytest -v -s compile/piecewise/test_simple.py
+  - pytest -v -s compile/piecewise/test_toy_llama.py

 - label: "PyTorch Fullgraph Test" # 18min
   source_file_dependencies:

View File

View File

@@ -0,0 +1,4 @@
{
"use_cudagraph": true,
"non_cudagraph_ops": ["silly.attention"]
}
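For context, a minimal sketch of how a config file like this is consumed in the new flow (the path below is hypothetical; the mechanism mirrors the compile/piecewise/test_simple.py test referenced in the pipeline change above): the piecewise level is selected via VLLM_TORCH_COMPILE_LEVEL, and VLLM_TORCH_COMPILE_CONFIG points at the JSON, which CompilationConfig.select_and_init_config() reads on the first compiled run.

# hypothetical usage sketch, not part of this commit
import os
from vllm.compilation.levels import CompilationLevel

# must be set before the decorated model is constructed and run
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = "/path/to/piecewise_compilation_config.json"
# any @support_torch_compile model run after this point is split at
# "silly.attention" and cudagraphs are captured for the remaining pieces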

View File

@@ -0,0 +1,96 @@
"""
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
import os
import torch
from torch import nn
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
global_counter = 0
@torch.library.custom_op("silly::attention", mutates_args=["out"])
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
global global_counter
global_counter += 1
print(f"{global_counter=}")
out.copy_(q)
out[0] += 1
@silly_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return
@support_torch_compile
class SillyModel(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overall effect:
x += 1
x[0] += 2
global_counter += 2
"""
x = x + 1
x = x + 2
out = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, out)
x = out
x = x - 2
x = x - 1
out = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, out)
x = out
x = x + 1
return x
def test_simple_piecewise_compile():
model = SillyModel()
directory = os.path.dirname(__file__)
config = os.path.join(directory, "piecewise_compilation_config.json")
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
input_buffer = torch.randn(100).cuda()
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=5, # 2 * num_layers + 1
num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
num_inductor_compilations=3, # num_piecewise_capturable_graphs_seen
num_cudagraph_caputured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
with set_compile_context([1, 2]):
model(input_buffer)
model(input_buffer[:2])
model(input_buffer[:1])
input_buffer[:2].zero_()
global global_counter
global_counter = 0
output = model(input_buffer[:2])
assert global_counter == 2
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
# clean up to avoid side effects for other tests
del os.environ["VLLM_TORCH_COMPILE_CONFIG"]

View File

@@ -0,0 +1,334 @@
"""
Test the piecewise compilation with a simple model, comparing the output
with and without the piecewise compilation.
"""
import os
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import nn
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.config import CompilationConfig
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.plugins import set_compilation_config
@torch.library.custom_op("silly::attention", mutates_args=["out"])
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
out.copy_(q)
out += k
out += v
@silly_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return
@dataclass
class LlamaConfig:
hidden_size: int = 128
mlp_size: int = 256
vocab_size: int = 128
num_layers: int = 2
class LlamaMLP(nn.Module):
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.gate_up_projection = nn.Linear(
in_features=config.hidden_size,
out_features=config.mlp_size * 2,
bias=False,
)
self.down_projection = nn.Linear(
in_features=config.mlp_size,
out_features=config.hidden_size,
bias=False,
)
self.gate_up_projection.weight.data.fill_(0.0)
self.down_projection.weight.data.fill_(0.0)
def forward(self, x):
x = self.gate_up_projection(x)
x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
x[:, x.size(1) // 2:])
x = self.down_projection(x)
return x
class LlamaAttention(nn.Module):
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.qkv_projection = nn.Linear(
in_features=config.hidden_size,
out_features=config.hidden_size * 3,
)
self.output_projection = nn.Linear(
in_features=config.hidden_size,
out_features=config.hidden_size,
)
self.qkv_projection.weight.data.fill_(0.0)
self.output_projection.weight.data.fill_(0.0)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv = self.qkv_projection(hidden_states)
hidden_size = qkv.size(-1) // 3
q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)
q = q + positions.unsqueeze(1)
k = k + positions.unsqueeze(1)
attn_output = torch.empty_like(q)
torch.ops.silly.attention(q, k, v, attn_output)
output = self.output_projection(attn_output)
return output
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.self_attention = LlamaAttention(config)
self.mlp = LlamaMLP(config)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = hidden_states / 2
else:
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = hidden_states / 2
hidden_states = self.self_attention(positions=positions,
hidden_states=hidden_states)
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = hidden_states / 2
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class LlamaModel(nn.Module):
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.embedding_tokens = nn.Embedding(
num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
)
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config) for _ in range(config.num_layers)])
self.embedding_tokens.weight.data.fill_(0.0)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.embedding_tokens(input_ids)
residual = None
for layer in self.layers:
hidden_states, residual = layer(positions, hidden_states, residual)
return hidden_states
@torch.inference_mode
def run_model(llama_config,
use_compile: bool,
split_attn: bool = False) -> torch.Tensor:
if use_compile:
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
CompilationLevel.PIECEWISE)
if split_attn:
set_compilation_config(
CompilationConfig(
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
))
else:
set_compilation_config(CompilationConfig(use_cudagraph=True, ))
else:
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
CompilationLevel.NO_COMPILATION)
set_compilation_config(None)
cls = LlamaModel
if use_compile:
cls = support_torch_compile(LlamaModel)
model = cls(llama_config).eval().cuda()
B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
positions = torch.arange(B).cuda()
with set_compile_context([1, 2]):
model(input_ids, positions)
model(input_ids[:2], positions[:2])
model(input_ids[:1], positions[:1])
input_ids[:2].zero_()
output = model(input_ids[:2], positions[:2])
# manual cleanup
del os.environ["VLLM_TORCH_COMPILE_LEVEL"]
set_compilation_config(None)
return output.cpu()
def test_toy_llama():
# compare output with and without piecewise compilation
llama_config = LlamaConfig(hidden_size=128,
mlp_size=256,
vocab_size=128,
num_layers=2)
outputs = []
with compilation_counter.expect(
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_inductor_compilations=0,
num_cudagraph_caputured=0,
):
outputs.append(run_model(llama_config, use_compile=False))
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=1,
num_piecewise_capturable_graphs_seen=1,
num_inductor_compilations=1, # num_piecewise_capturable_graphs_seen
num_cudagraph_caputured=
2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
outputs.append(run_model(llama_config, use_compile=True))
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=2 * llama_config.num_layers +
1, # 2 * num_layers + 1
num_piecewise_capturable_graphs_seen=1 +
llama_config.num_layers, # 1 + num_layers
num_inductor_compilations=1 +
llama_config.num_layers, # num_piecewise_capturable_graphs_seen
num_cudagraph_caputured=2 *
(1 + llama_config.num_layers
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
outputs.append(
run_model(llama_config, use_compile=True, split_attn=True))
for i in range(1, len(outputs)):
assert torch.allclose(outputs[0], outputs[i])
@torch.inference_mode
def benchmark():
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
from triton.testing import do_bench
cls = support_torch_compile(LlamaModel)
# similar to llama 3.1-8B
llama_config = LlamaConfig(hidden_size=4096,
mlp_size=14336,
vocab_size=128 * 1024,
num_layers=32)
# a tiny model to measure the overhead
# of piecewise cudagraph
llama_config = LlamaConfig(hidden_size=40,
mlp_size=80,
vocab_size=128,
num_layers=2)
cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]
eager_time = {}
full_cudagraph_time = {}
piecewise_cudagraph_time = {}
pool = torch.cuda.graph_pool_handle()
for piecewise in [False, True]:
if piecewise:
set_compilation_config(
CompilationConfig(
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
))
else:
set_compilation_config(None)
model = cls(llama_config).eval().cuda().to(torch.bfloat16)
B = 256 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
positions = torch.arange(B).cuda().to(torch.bfloat16)
graphs = {}
with set_compile_context(cudagraph_sizes):
model(input_ids, positions)
for b in cudagraph_sizes[::-1]:
if not piecewise:
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=pool):
output = model(input_ids[:b], positions[:b])
graphs[b] = (graph, output)
else:
output = model(input_ids[:b], positions[:b])
graphs[b] = (model, output)
for b in cudagraph_sizes:
if piecewise:
# noqa is for `Function definition does not bind loop variable`
# it will be problematic if we save the created lambda function
# and use it later, because it will look up the name `b` in the
# enclosing scope, and the value of `b` will always be 256.
# it is fine here, because we only use the lambda function once.
runtime = do_bench(lambda: graphs[b][0] # noqa
(input_ids[:b], positions[:b])) # noqa
piecewise_cudagraph_time[b] = runtime
else:
runtime = do_bench(lambda: graphs[b][0].replay()) # noqa
eager_runtime = do_bench(
lambda: model(input_ids[:b], positions[:b])) # noqa
full_cudagraph_time[b] = runtime
eager_time[b] = eager_runtime
# print in tabular format
print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
for b in cudagraph_sizes:
print((f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
f"\t{piecewise_cudagraph_time[b]:.3f}"))
if __name__ == "__main__":
benchmark()

View File

@@ -9,7 +9,7 @@ from .utils import TEST_MODELS, check_full_graph_support
 @pytest.mark.parametrize("model_info", TEST_MODELS)
 @pytest.mark.parametrize(
     "optimization_level",
-    [CompilationLevel.DYNAMO_ONCE, CompilationLevel.INDUCTOR])
+    [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE])
 @fork_new_process_for_each_test
 def test_full_graph(model_info, optimization_level):
     model = model_info[0]

View File

@@ -9,17 +9,19 @@ from vllm.platforms import current_platform
 TEST_MODELS = [
     ("facebook/opt-125m", {}),
-    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
-        "dtype": torch.float16,
-        "quantization": "compressed-tensors"
-    }),
+    # TODO: add fake implementation for compressed-tensors
+    # ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
+    #     "dtype": torch.float16,
+    #     "quantization": "compressed-tensors"
+    # }),
     ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", {
         "dtype": torch.float16,
         "quantization": "fp8"
     }),
-    ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
-        "quantization": "compressed-tensors"
-    }),
+    # TODO: add fake implementation for compressed-tensors
+    # ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
+    #     "quantization": "compressed-tensors"
+    # }),
     ("meta-llama/Meta-Llama-3-8B", {}),
 ]

@@ -73,7 +75,7 @@ def check_full_graph_support(model,
     # much memory.
     quantization = model_kwargs.get("quantization")
     if ((quantization == "fp8" or model == "meta-llama/Meta-Llama-3-8B")
-            and optimization_level >= CompilationLevel.INDUCTOR):
+            and optimization_level >= CompilationLevel.PIECEWISE):
         return

     prompts = [

View File

@@ -1,13 +1,16 @@
 import copy
+import dataclasses
 import operator
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

 import torch
 import torch.fx as fx

 from vllm.logger import init_logger
+from vllm.utils import weak_ref_tensors

-from .compile_context import get_compile_context
+from .config import CompilationConfig
+from .counter import compilation_counter
 from .levels import CompilationLevel

 logger = init_logger(__name__)
@@ -157,113 +160,326 @@ def fix_functionalization(graph: fx.Graph):
     # print(graph.python_code(root_module="self", verbose=True).src, file=f)


-def wrap_inductor(graph, example_inputs, additional_inductor_config):
+def wrap_inductor(graph,
+                  example_inputs,
+                  additional_inductor_config,
+                  do_logging=False,
+                  runtime_shape: Optional[int] = None,
+                  use_inductor: bool = True):
+    if not use_inductor:
+        return graph
+
+    compilation_counter.num_inductor_compilations += 1
+
+    if do_logging:
+        if runtime_shape is None:
+            logger.info("Compiling a graph for general shape")
+        else:
+            logger.info("Compiling a graph for shape %s", runtime_shape)
+
     from torch._inductor import config
     current_config = config.shallow_copy_dict()
     from torch._inductor.compile_fx import compile_fx

     if additional_inductor_config is not None:
         current_config.update(additional_inductor_config)
-    if current_config['post_grad_custom_post_pass'] is not None:
-        logger.warning(
-            "post_grad_custom_post_pass is already set in the config. "
-            "Overwriting it with the fix_functionalization")
-    current_config['post_grad_custom_post_pass'] = fix_functionalization
+
+    # inductor can inplace modify the graph, so we need to copy it
+    # see https://github.com/pytorch/pytorch/issues/138980
+    graph = copy.deepcopy(graph)
     return compile_fx(graph, example_inputs, config_patches=current_config)


-def vllm_backend(
-        graph,
-        example_inputs,
-        additional_inductor_config: Optional[Dict] = None) -> Callable:
-
-    context = get_compile_context()
-    context = copy.deepcopy(context) if context is not None else []
-    sizes_to_specialize: List[int] = context
-
-    # flags for all the seen shapes, whether we need to specialize
-    runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {}
-
-    # if we need to specialize, the compiled graph for that shape
-    runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {}
-
-    # this is the first compilation, we will compile a graph with
-    # dynamic shape, as the caller will mark first dimension as dynamic
-    logger.info("Compiling a graph for general shapes")
-    graph_for_symbolic_shape = wrap_inductor(graph, example_inputs,
-                                             additional_inductor_config)
-
-    # TODO: Dynamo does not pass all dynamic shapes.
-    # Need to investigate why. It works now because all the dynamic
-    # shapes have the same value, and either of them can be used.
-    sym_shape_indices = [
-        i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt)
-    ]
-
-    first_run = True
-
-    # this is the function we return to Dynamo to run finally
-    def compiled_graph_wrapper(*args):
-
-        runtime_shapes: Tuple[int,
-                              ...] = tuple(args[i] for i in sym_shape_indices)
-
-        nonlocal first_run
-        nonlocal runtime_shapes_to_compile_flags
-        nonlocal runtime_shapes_to_compiled_graph
-
-        if first_run:
-            # the first compilation is for profiling, we directly run it
-            first_run = False
-            return graph_for_symbolic_shape(*args)
-
-        if runtime_shapes not in runtime_shapes_to_compile_flags:
-            # we haven't seen this shape before
-            # query if we need to specialize for this shape
-            # we only specialize for the first dimension.
-            # TODO: investigate if any model needs to specialize
-            # beyond the first dimension
-            runtime_shapes_to_compile_flags[runtime_shapes] = runtime_shapes[
-                0] in sizes_to_specialize
-
-        if not runtime_shapes_to_compile_flags[runtime_shapes]:
-            # we don't need to specialize for this shape
-            return graph_for_symbolic_shape(*args)
-
-        if runtime_shapes not in runtime_shapes_to_compiled_graph:
-            # we need to specialize for this shape, and we haven't compiled
-            # compile the graph for this shape
-            logger.info("Compiling a graph for shapes %s", runtime_shapes)
-            runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor(
-                graph, args, additional_inductor_config)
-
-        return runtime_shapes_to_compiled_graph[runtime_shapes](*args)
-
-    return compiled_graph_wrapper
+@dataclasses.dataclass
+class SplitItem:
+    submod_name: str
+    is_splitting_graph: bool
+    graph: fx.GraphModule
+
+
+def split_graph(graph: fx.GraphModule,
+                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
+    # split graph by ops
+    subgraph_id = 0
+    node_to_subgraph_id = {}
+    split_op_graphs = []
+    for node in graph.graph.nodes:
+        if node.op in ("output", "placeholder"):
+            continue
+        if node.op == 'call_function' and str(node.target) in ops:
+            subgraph_id += 1
+            node_to_subgraph_id[node] = subgraph_id
+            split_op_graphs.append(subgraph_id)
+            subgraph_id += 1
+        else:
+            node_to_subgraph_id[node] = subgraph_id
+
+    # `keep_original_order` is important!
+    # otherwise pytorch might reorder the nodes and
+    # the semantics of the graph will change when we
+    # have mutations in the graph
+    split_gm = torch.fx.passes.split_module.split_module(
+        graph,
+        None,
+        lambda node: node_to_subgraph_id[node],
+        keep_original_order=True)
+
+    outputs = []
+
+    # sort the names to make sure the order is deterministic
+    names = [name for (name, module) in split_gm.named_modules()]
+    names.sort()
+
+    for name in names:
+        if "." in name or name == "":
+            # recursive child module or the root module
+            continue
+        module = getattr(split_gm, name)
+        graph_id = int(name.replace("submod_", ""))
+        outputs.append(SplitItem(name, graph_id in split_op_graphs, module))
+
+    return split_gm, outputs
+
+
+class VllmBackend:
+    """The compilation backend for `torch.compile` with VLLM.
+    It is used for compilation level of `CompilationLevel.PIECEWISE`,
+    where we customize the compilation.
+
+    The major work of this backend is to split the graph into
+    piecewise graphs, and pass them to the piecewise backend.
+    """
+
+    compilation_configs: CompilationConfig
+    graph_pool: Any
+    _called: bool = False
+    # the graph we compiled
+    graph: fx.GraphModule
+    # the stiching graph module for all the piecewise graphs
+    split_gm: fx.GraphModule
+    piecewise_graphs: List[SplitItem]
+    returned_callable: Callable
+
+    def __init__(self, ):
+        # every instance of VllmBackend has its own graph pool
+        self.graph_pool = torch.cuda.graph_pool_handle()
+
+        # `torch.compile` is JIT compiled, so we don't need to
+        # do anything here
+
+    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
+
+        compilation_counter.num_graphs_seen += 1
+
+        # we control the compilation process, each instance can only be
+        # called once
+        assert not self._called, "VllmBackend can only be called once"
+
+        self.graph = graph
+        # config is read now, because only here can
+        # we get the sizes to capture for cudagraph
+        # from compilation context
+        self.compilation_configs = CompilationConfig.select_and_init_config()
+
+        self.split_gm, self.piecewise_graphs = split_graph(
+            graph, self.compilation_configs.non_cudagraph_ops)
+
+        returned_callable: Callable  # type: ignore
+
+        if len(self.piecewise_graphs) == 0:
+            compilation_counter.num_piecewise_graphs_seen += 1
+            compilation_counter.num_piecewise_capturable_graphs_seen += 1
+            returned_callable = PiecewiseBackend(graph,
+                                                 self.compilation_configs,
+                                                 self.graph_pool,
+                                                 is_first_graph=True)
+        else:
+            from torch._dynamo.utils import lazy_format_graph_code
+            logger.debug(
+                "%s", lazy_format_graph_code("stiching module", self.split_gm))
+
+            is_first_graph = True
+
+            for item in self.piecewise_graphs:
+                compilation_counter.num_piecewise_graphs_seen += 1
+                compilation_counter.num_piecewise_capturable_graphs_seen += not item.is_splitting_graph  # noqa
+                if not item.is_splitting_graph:
+                    # cannot setattr to a module, so we need to set
+                    # the attribute in the __dict__
+                    self.split_gm.__dict__[
+                        item.submod_name] = PiecewiseBackend(
+                            item.graph, self.compilation_configs,
+                            self.graph_pool, is_first_graph)
+                    is_first_graph = False
+
+            returned_callable = self.split_gm
+
+        self.returned_callable = returned_callable
+        # trigger the first compilation
+        # code borrowed from https://github.com/pytorch/pytorch/blob/4e3e08b71171fa34172b2362ff668553fac75f27/torch/_dynamo/backends/distributed.py#L206 # noqa
+        # to turn the inputs into fake tensors
+        import torch._guards
+        from torch._guards import detect_fake_mode
+        fake_mode = detect_fake_mode(example_inputs)
+        fake_args = []
+        for arg in example_inputs:
+            if isinstance(arg, torch.Tensor) and not isinstance(
+                    arg, torch._subclasses.FakeTensor):
+                fake_args.append(
+                    torch._dynamo.utils.to_fake_tensor(arg, fake_mode))
+            else:
+                fake_args.append(arg)
+        self.returned_callable(*fake_args)
+
+        self._called = True
+
+        return self.returned_callable
+
+
+@dataclasses.dataclass
+class ConcreteSizeEntry:
+    runtime_shape: int
+    need_to_compile: bool  # the size is in compile_sizes
+    use_cudagraph: bool  # the size is in capture_sizes
+
+    compiled: bool = False
+    runnable: Callable = None  # type: ignore
+    num_finished_warmup: int = 0
+    cudagraph: Optional[torch.cuda.CUDAGraph] = None
+    output: Optional[Any] = None
+
+
+class PiecewiseBackend:
+
+    def __init__(self,
+                 graph: fx.GraphModule,
+                 compilation_configs: CompilationConfig,
+                 graph_pool: Any,
+                 is_first_graph: bool = False):
+        """
+        The backend for piecewise compilation.
+        It mainly handles the compilation and cudagraph capturing.
+
+        We will compile `self.graph` once for the general shape,
+        and then compile for different shapes specified in
+        `compilation_configs.compile_sizes`.
+
+        Independently, we will capture cudagraph for different shapes.
+
+        If a shape needs both compilation and cudagraph, we will
+        compile it first, and then capture cudagraph.
+        """
+        self.graph = graph
+        self.compilation_configs = compilation_configs
+        self.graph_pool = graph_pool
+        self.is_first_graph = is_first_graph
+
+        self.compile_sizes: Set[int] = set(
+            self.compilation_configs.compile_sizes)
+        self.capture_sizes: Set[int] = set(
+            self.compilation_configs.capture_sizes
+        ) if self.compilation_configs.use_cudagraph else set()
+
+        self.compile_finished = False
+        self.first_run_finished = False
+
+        self.compiled_graph_for_general_shape: Callable = None  # type: ignore
+
+        self.sym_shape_indices: List[int] = []
+
+        # the entries for different shapes that we need to either
+        # compile or capture cudagraph
+        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
+        for shape in self.compile_sizes.union(self.capture_sizes):
+            self.concrete_size_entries[shape] = ConcreteSizeEntry(
+                runtime_shape=shape,
+                need_to_compile=shape in self.compile_sizes,
+                use_cudagraph=shape in self.capture_sizes,
+            )
+
+    def __call__(self, *args) -> Any:
+
+        if not self.compile_finished:
+            self.compile_finished = True
+
+            # this is the first compilation, we will compile a graph with
+            # dynamic shape, as the caller will mark first dimension as dynamic
+
+            self.sym_shape_indices = [
+                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
+            ]
+
+            self.compiled_graph_for_general_shape = wrap_inductor(
+                self.graph,
+                args,
+                self.compilation_configs.inductor_compile_config,
+                runtime_shape=None,
+                do_logging=self.is_first_graph,
+                use_inductor=self.compilation_configs.use_inductor)
+
+            return self.graph(*args)
+
+        if not self.first_run_finished:
+            self.first_run_finished = True
+            return self.compiled_graph_for_general_shape(*args)
+
+        runtime_shape = args[self.sym_shape_indices[0]]
+        if runtime_shape not in self.concrete_size_entries:
+            # we don't need to do anything for this shape
+            return self.compiled_graph_for_general_shape(*args)
+
+        entry = self.concrete_size_entries[runtime_shape]
+
+        if entry.runnable is None:
+            entry.runnable = self.compiled_graph_for_general_shape
+
+        if entry.need_to_compile and not entry.compiled:
+            entry.compiled = True
+            # args are real arguments
+            entry.runnable = wrap_inductor(
+                self.graph,
+                args,
+                self.compilation_configs.inductor_compile_config,
+                runtime_shape=runtime_shape,
+                do_logging=self.is_first_graph,
+                use_inductor=self.compilation_configs.use_inductor)
+
+        if not entry.use_cudagraph:
+            return entry.runnable(*args)
+
+        if entry.cudagraph is None:
+            if entry.num_finished_warmup < self.compilation_configs.cudagraph_num_of_warmups:  # noqa
+                entry.num_finished_warmup += 1
+                if self.is_first_graph:
+                    logger.debug(
+                        "Warming up %s/%s for shape %s",
+                        entry.num_finished_warmup,
+                        self.compilation_configs.cudagraph_num_of_warmups,
+                        runtime_shape)
+                return entry.runnable(*args)
+
+            if self.is_first_graph:
+                logger.info("Capturing a cudagraph for shape %s",
+                            runtime_shape)
+
+            cudagraph = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(cudagraph, pool=self.graph_pool):
+                entry.output = weak_ref_tensors(entry.runnable(*args))
+
+            compilation_counter.num_cudagraph_caputured += 1
+
+            entry.cudagraph = cudagraph
+            return entry.output
+
+        entry.cudagraph.replay()
+        return entry.output


 def select_default_backend(level: int) -> Union[str, Callable]:
     if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
         backend_str = "eager"
         return backend_str
-    assert level in [
-        CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
-    ], f"Invalid level {level}"
-
-    from vllm.compilation.backends import vllm_backend
-    from vllm.plugins import get_inductor_additional_configs
-    additional_configs = get_inductor_additional_configs()
-
-    if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE:
-        if "max_autotune" in additional_configs and not additional_configs[
-                "max_autotune"]:
-            logger.warning(
-                "max_autotune is disabled, but is overridden by level %s",
-                CompilationLevel.INDUCTOR_MAX_AUTOTUNE)
-        additional_configs['max_autotune'] = True
-
-    from functools import partial
-    backend = partial(vllm_backend,
-                      additional_inductor_config=additional_configs)
-    return backend
+    assert level == CompilationLevel.PIECEWISE
+
+    return VllmBackend()

vllm/compilation/config.py

View File

@@ -0,0 +1,154 @@
import copy
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, PrivateAttr
import vllm.envs as envs
from vllm.logger import init_logger
from .compile_context import get_compile_context
logger = init_logger(__name__)
class CompilationConfig(BaseModel):
"""
Configuration for compilation.
It has two parts:
- CudaGraph capture:
- use_cudagraph: whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses.
Note that this is orthogonal to the cudagraph capture out
side of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future.
- cudagraph_capture_sizes: sizes to capture cudagraph.
- None: capture sizes are inferred from compilation context.
- List[int]: capture sizes are specified.
- cudagraph_num_of_warmups: number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
Only after that, the execution will be recorded, and the recorded
cudagraph will be used for subsequent runs.
- Inductor compilation:
- use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for different sizes specified
in inductor_compile_sizes, using configurations
in inductor_compile_config.
- inductor_compile_sizes: sizes to compile for inductor.
- inductor_specialize_for_cudagraph_no_more_than: an optional integer
to specialize inductor for cudagraph sizes no more than the
specified size. It is useful when we want to specialize inductor
with a subset of cudagraph sizes.
- inductor_compile_config: additional configurations for inductor.
- None: use default configurations.
- inductor_passes: additional passes for inductor. It is a dictionary
from pass name to pass function qualified name. We use function
name because the config uses json format. If we pass the config
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
Why we have different sizes for cudagraph and inductor:
- cudagraph: a cudagraph captured for a specific size can only be used
for the same size. We need to capture all the sizes we want to use.
- inductor: a graph compiled by inductor for a general shape can be used
for different sizes. Inductor can also compile for specific sizes,
where it can have more information to optimize the graph with fully
static shapes. However, we find the general shape compilation is
sufficient for most cases. It might be beneficial to compile for
certain small batchsizes, where inductor is good at optimizing.
"""
use_inductor: bool = True
inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None
inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict)
inductor_compile_config: Dict = Field(default_factory=dict)
inductor_passes: Dict[str, str] = Field(default_factory=dict)
use_cudagraph: bool = False
non_cudagraph_ops: List[str] = Field(default_factory=list)
cudagraph_num_of_warmups: int = 0
cudagraph_capture_sizes: Optional[List[int]] = None
# not configurable, computed after init
compile_sizes: List[int] = PrivateAttr
capture_sizes: List[int] = PrivateAttr
def model_post_init(self, __context: Any) -> None:
for k, v in self.inductor_passes.items():
if not isinstance(v, str):
assert callable(v), (
f"pass {k} should be a function or a qualified name")
self.inductor_passes[k] = v
continue
# resolve function from qualified name
names = v.split(".")
module = ".".join(names[:-1])
func_name = names[-1]
func = __import__(module).__dict__[func_name]
self.inductor_compile_config[k] = func
from vllm.compilation.backends import fix_functionalization
from vllm.utils import combine_fx_passes
if "post_grad_custom_post_pass" in self.inductor_compile_config:
self.inductor_compile_config[
"post_grad_custom_post_pass"] = combine_fx_passes(
fix_functionalization,
self.inductor_compile_config["post_grad_custom_post_pass"],
)
else:
self.inductor_compile_config[
"post_grad_custom_post_pass"] = fix_functionalization
def init_during_runtime(self):
"""To complete the initialization of config,
we need to know the compile context, which is only available
during the first run of the model.
"""
context = get_compile_context()
context = copy.deepcopy(context) if context is not None else []
sizes_to_specialize: List[int] = context
if self.cudagraph_capture_sizes is None:
self.capture_sizes = sizes_to_specialize
else:
self.capture_sizes = self.cudagraph_capture_sizes
logger.info(("cudagraph sizes specified by model runner"
" %s is overridden by config %s"),
sizes_to_specialize, self.cudagraph_capture_sizes)
if self.inductor_specialize_for_cudagraph_no_more_than is not None:
assert self.inductor_compile_sizes is None, (
"inductor_compile_sizes should be None when "
"inductor_specialize_for_cudagraph_no_more_than is not None")
self.compile_sizes = [
x for x in self.capture_sizes
if x <= self.inductor_specialize_for_cudagraph_no_more_than
]
else:
assert self.inductor_compile_sizes is not None, (
"inductor_compile_sizes should not be None when "
"inductor_specialize_for_cudagraph_no_more_than is None")
self.compile_sizes = self.inductor_compile_sizes
@staticmethod
def select_and_init_config() -> "CompilationConfig":
"""The order of selecting config is:
1. Use the config specified in environment variable.
2. Use the config specified in plugins.
3. Use the default config.
"""
config_path = envs.VLLM_TORCH_COMPILE_CONFIG
if config_path is not None:
with open(config_path) as json_file:
config = CompilationConfig.model_validate_json(
json_file.read())
else:
from vllm.plugins import get_compilation_config
predefined_config = get_compilation_config()
config = predefined_config if predefined_config is not None else (
CompilationConfig())
config.init_during_runtime()
return config
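As an aside, the same configuration can be supplied programmatically instead of via the JSON file; a minimal sketch (mirroring run_model in the toy-llama test above; the op name reuses the test's toy "silly.attention" op and the warmup count is illustrative):

# sketch: programmatic route, equivalent to setting VLLM_TORCH_COMPILE_CONFIG
from vllm.compilation.config import CompilationConfig
from vllm.plugins import set_compilation_config

set_compilation_config(
    CompilationConfig(
        use_cudagraph=True,
        non_cudagraph_ops=["silly.attention"],
        cudagraph_num_of_warmups=1,
    ))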

View File

@@ -0,0 +1,30 @@
import copy
import dataclasses
from contextlib import contextmanager
@dataclasses.dataclass
class CompilationCounter:
num_graphs_seen: int = 0
# including the splitting ops
num_piecewise_graphs_seen: int = 0
# not including the splitting ops
num_piecewise_capturable_graphs_seen: int = 0
num_inductor_compilations: int = 0
num_cudagraph_caputured: int = 0
def clone(self) -> "CompilationCounter":
return copy.deepcopy(self)
@contextmanager
def expect(self, **kwargs):
old = self.clone()
yield
for k, v in kwargs.items():
assert getattr(self, k) - getattr(old, k) == v, (
f"{k} not as expected, before it is {getattr(old, k)}"
f", after it is {getattr(self, k)}, "
f"expected diff is {v}")
compilation_counter = CompilationCounter()
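For reference, the counter is used as a context manager that checks deltas, as in the tests above; a minimal sketch (the expected numbers here are illustrative):

# sketch: assert how many compilations the enclosed code triggers
from vllm.compilation.counter import compilation_counter

with compilation_counter.expect(num_graphs_seen=1, num_inductor_compilations=1):
    ...  # run the compiled model here; deltas are verified on exit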

View File

@@ -121,7 +121,10 @@ def _support_torch_compile(cls: type,
     # take care of method resolution order
     # make sure super().__init__ is called on the base class
     # other than TorchCompileWrapperWithCustomDispatcher
-    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
+    if TorchCompileWrapperWithCustomDispatcher not in cls.__bases__:
+        # support decorating multiple times
+        cls.__bases__ = cls.__bases__ + (
+            TorchCompileWrapperWithCustomDispatcher, )

     old_init = cls.__init__  # type: ignore

@@ -160,6 +163,11 @@
         # compiled function and let torch.compile handle the dispatching,
         # with the overhead of guard evaluation and recompilation.
         if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
+            # it seems Dynamo reuse the compilation across instances,
+            # while we need to make sure the compiled code is not reused.
+            # we need to control all the compilation of the model.
+            torch._dynamo.eval_frame.remove_from_cache(
+                self.original_code_object)
             return self.compiled_callable(*args, **kwargs)

         # usually, capturing the model once is enough, and then we can

View File

@@ -5,5 +5,4 @@ class CompilationLevel:
     NO_COMPILATION = 0
     DYNAMO_AS_IS = 1
     DYNAMO_ONCE = 2
-    INDUCTOR = 3
-    INDUCTOR_MAX_AUTOTUNE = 4
+    PIECEWISE = 3

View File

@@ -209,6 +209,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
     "VLLM_TORCH_COMPILE_LEVEL":
     lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),

+    # Path to the config file for torch compile
+    "VLLM_TORCH_COMPILE_CONFIG":
+    lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None),
+
     # Fine-grained control over which custom ops to enable/disable.
     # Use 'all' to enable all, 'none' to disable all.
     # Also specify a list of custom op names to enable (prefixed with a '+'),

View File

@@ -100,7 +100,7 @@ class CustomOp(nn.Module):
         return (CustomOp.default_on() or enabled) and not disabled

-    # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR
+    # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE
     # Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
     @staticmethod
     @lru_cache()
@@ -108,7 +108,7 @@ class CustomOp(nn.Module):
         count_none = envs.VLLM_CUSTOM_OPS.count("none")
         count_all = envs.VLLM_CUSTOM_OPS.count("all")
         assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
-        return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR and \
+        return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE and \
             not count_none > 0 or count_all > 0

     # Dictionary of all custom ops (classes, indexed by registered name).

View File

@@ -11,7 +11,7 @@ from .interface import Platform, PlatformEnum
 if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
     os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE)
-assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR,\
+assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE,\
     "TPU does not support Inductor."

 set_torch_compile_backend("openxla")

View File

@@ -1,7 +1,8 @@
 import logging
-from typing import Callable, Dict, Optional, Union
+from typing import Callable, Optional, Union

 import vllm.envs as envs
+from vllm.compilation.config import CompilationConfig

 logger = logging.getLogger(__name__)

@@ -44,13 +45,13 @@ def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
     return _torch_compile_backend


-_inductor_additional_configs: Dict = {}
+_compilation_config: Optional[CompilationConfig] = None


-def set_inductor_additional_configs(configs: Dict):
-    global _inductor_additional_configs
-    _inductor_additional_configs = configs
+def set_compilation_config(config: Optional[CompilationConfig]):
+    global _compilation_config
+    _compilation_config = config


-def get_inductor_additional_configs() -> Dict:
-    return _inductor_additional_configs
+def get_compilation_config() -> Optional[CompilationConfig]:
+    return _compilation_config

View File

@@ -1479,6 +1479,15 @@ class LazyDict(Mapping, Generic[T]):
         return len(self._factory)


+def combine_fx_passes(passes: List[Callable]) -> Callable:
+
+    def combined_fx(graph) -> None:
+        for fx in passes:
+            fx(graph)
+
+    return combined_fx
+
+
 def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
     """
     Create a weak reference to a tensor.
@@ -1486,3 +1495,19 @@ def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
     but will not keep the original tensor alive.
     """
     return torch.ops._C.weak_ref_tensor(tensor)
+
+
+def weak_ref_tensors(
+    tensors: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
+) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]:
+    """
+    Convenience function to create weak references to tensors,
+    for single tensor, list of tensors or tuple of tensors.
+    """
+    if isinstance(tensors, torch.Tensor):
+        return weak_ref_tensor(tensors)
+    if isinstance(tensors, list):
+        return [weak_ref_tensor(t) for t in tensors]
+    if isinstance(tensors, tuple):
+        return tuple(weak_ref_tensor(t) for t in tensors)
+    raise ValueError("Invalid type for tensors")