[torch.compile] rework compile control with piecewise cudagraph (#9715)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Parent: 7b0365efef
Commit: ff5ed6e1bc
@@ -229,6 +229,9 @@ steps:
   - tests/compile
   commands:
   - pytest -v -s compile/test_basic_correctness.py
+  # these tests need to be separated, cannot combine
+  - pytest -v -s compile/piecewise/test_simple.py
+  - pytest -v -s compile/piecewise/test_toy_llama.py

 - label: "PyTorch Fullgraph Test" # 18min
   source_file_dependencies:
tests/compile/piecewise/__init__.py (new, empty file)

tests/compile/piecewise/piecewise_compilation_config.json (new file, 4 lines; the filename is inferred from the path used in test_simple.py below)

{
    "use_cudagraph": true,
    "non_cudagraph_ops": ["silly.attention"]
}

tests/compile/piecewise/test_simple.py (new file, 96 lines)
"""
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
import os

import torch
from torch import nn

from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel

os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)

global_counter = 0


@torch.library.custom_op("silly::attention", mutates_args=["out"])
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    global global_counter
    global_counter += 1
    print(f"{global_counter=}")
    out.copy_(q)
    out[0] += 1


@silly_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
      out: torch.Tensor) -> None:
    return


@support_torch_compile
class SillyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overall effect:
        x += 1
        x[0] += 2
        global_counter += 2
        """
        x = x + 1
        x = x + 2
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x - 2
        x = x - 1
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x + 1
        return x


def test_simple_piecewise_compile():

    model = SillyModel()

    directory = os.path.dirname(__file__)
    config = os.path.join(directory, "piecewise_compilation_config.json")
    os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config

    input_buffer = torch.randn(100).cuda()

    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
            num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
            num_inductor_compilations=3,  # num_piecewise_capturable_graphs_seen
            num_cudagraph_caputured=
            6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):

        with set_compile_context([1, 2]):
            model(input_buffer)

            model(input_buffer[:2])
            model(input_buffer[:1])

        input_buffer[:2].zero_()
        global global_counter
        global_counter = 0
        output = model(input_buffer[:2])
        assert global_counter == 2
        assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))

    # clean up to avoid side effects for other tests
    del os.environ["VLLM_TORCH_COMPILE_CONFIG"]
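The counter values asserted above follow from how the graph is split around the two `silly.attention` calls; a minimal sketch of that arithmetic (added for this writeup, not part of the commit; the variable names are illustrative):

# Two splitting ops cut the forward graph into 5 pieces: 3 capturable pieces
# plus the 2 attention subgraphs that stay outside cudagraph.
num_attention_ops = 2                                     # splitting ops in SillyModel.forward
num_piecewise_graphs = 2 * num_attention_ops + 1          # 5
num_capturable = num_attention_ops + 1                    # 3
num_cudagraph_sizes = 2                                   # set_compile_context([1, 2])
num_cudagraphs = num_cudagraph_sizes * num_capturable     # 6
assert (num_piecewise_graphs, num_capturable, num_cudagraphs) == (5, 3, 6)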
tests/compile/piecewise/test_toy_llama.py (new file, 334 lines)
"""
Test the piecewise compilation with a simple model, comparing the output
with and without the piecewise compilation.
"""
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn

from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.config import CompilationConfig
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.plugins import set_compilation_config


@torch.library.custom_op("silly::attention", mutates_args=["out"])
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    out.copy_(q)
    out += k
    out += v


@silly_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
      out: torch.Tensor) -> None:
    return


@dataclass
class LlamaConfig:
    hidden_size: int = 128
    mlp_size: int = 256
    vocab_size: int = 128
    num_layers: int = 2


class LlamaMLP(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.gate_up_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.mlp_size * 2,
            bias=False,
        )
        self.down_projection = nn.Linear(
            in_features=config.mlp_size,
            out_features=config.hidden_size,
            bias=False,
        )

        self.gate_up_projection.weight.data.fill_(0.0)
        self.down_projection.weight.data.fill_(0.0)

    def forward(self, x):
        x = self.gate_up_projection(x)
        x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
            x[:, x.size(1) // 2:])
        x = self.down_projection(x)
        return x


class LlamaAttention(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.qkv_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size * 3,
        )

        self.output_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
        )

        self.qkv_projection.weight.data.fill_(0.0)
        self.output_projection.weight.data.fill_(0.0)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv = self.qkv_projection(hidden_states)
        hidden_size = qkv.size(-1) // 3
        q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)

        q = q + positions.unsqueeze(1)
        k = k + positions.unsqueeze(1)

        attn_output = torch.empty_like(q)
        torch.ops.silly.attention(q, k, v, attn_output)

        output = self.output_projection(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.self_attention = LlamaAttention(config)
        self.mlp = LlamaMLP(config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = hidden_states / 2
        else:
            hidden_states = hidden_states + residual
            residual = hidden_states
            hidden_states = hidden_states / 2

        hidden_states = self.self_attention(positions=positions,
                                            hidden_states=hidden_states)

        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states / 2
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


class LlamaModel(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.embedding_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_layers)])

        self.embedding_tokens.weight.data.fill_(0.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embedding_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        return hidden_states


@torch.inference_mode
def run_model(llama_config,
              use_compile: bool,
              split_attn: bool = False) -> torch.Tensor:

    if use_compile:
        os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
            CompilationLevel.PIECEWISE)

        if split_attn:
            set_compilation_config(
                CompilationConfig(
                    use_cudagraph=True,
                    non_cudagraph_ops=["silly.attention"],
                ))
        else:
            set_compilation_config(CompilationConfig(use_cudagraph=True, ))
    else:
        os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
            CompilationLevel.NO_COMPILATION)
        set_compilation_config(None)

    cls = LlamaModel
    if use_compile:
        cls = support_torch_compile(LlamaModel)
    model = cls(llama_config).eval().cuda()

    B = 16  # max batch size
    input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
    positions = torch.arange(B).cuda()

    with set_compile_context([1, 2]):
        model(input_ids, positions)
        model(input_ids[:2], positions[:2])
        model(input_ids[:1], positions[:1])

    input_ids[:2].zero_()
    output = model(input_ids[:2], positions[:2])

    # manual cleanup
    del os.environ["VLLM_TORCH_COMPILE_LEVEL"]
    set_compilation_config(None)

    return output.cpu()


def test_toy_llama():
    # compare output with and without piecewise compilation

    llama_config = LlamaConfig(hidden_size=128,
                               mlp_size=256,
                               vocab_size=128,
                               num_layers=2)

    outputs = []
    with compilation_counter.expect(
            num_graphs_seen=0,
            num_piecewise_graphs_seen=0,
            num_piecewise_capturable_graphs_seen=0,
            num_inductor_compilations=0,
            num_cudagraph_caputured=0,
    ):
        outputs.append(run_model(llama_config, use_compile=False))
    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=1,
            num_piecewise_capturable_graphs_seen=1,
            num_inductor_compilations=1,  # num_piecewise_capturable_graphs_seen
            num_cudagraph_caputured=
            2,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        outputs.append(run_model(llama_config, use_compile=True))

    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=2 * llama_config.num_layers +
            1,  # 2 * num_layers + 1
            num_piecewise_capturable_graphs_seen=1 +
            llama_config.num_layers,  # 1 + num_layers
            num_inductor_compilations=1 +
            llama_config.num_layers,  # num_piecewise_capturable_graphs_seen
            num_cudagraph_caputured=2 *
            (1 + llama_config.num_layers
             ),  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        outputs.append(
            run_model(llama_config, use_compile=True, split_attn=True))

    for i in range(1, len(outputs)):
        assert torch.allclose(outputs[0], outputs[i])


@torch.inference_mode
def benchmark():
    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
    from triton.testing import do_bench
    cls = support_torch_compile(LlamaModel)

    # similar to llama 3.1-8B
    llama_config = LlamaConfig(hidden_size=4096,
                               mlp_size=14336,
                               vocab_size=128 * 1024,
                               num_layers=32)

    # a tiny model to measure the overhead
    # of piecewise cudagraph
    llama_config = LlamaConfig(hidden_size=40,
                               mlp_size=80,
                               vocab_size=128,
                               num_layers=2)

    cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]

    eager_time = {}
    full_cudagraph_time = {}
    piecewise_cudagraph_time = {}

    pool = torch.cuda.graph_pool_handle()

    for piecewise in [False, True]:
        if piecewise:
            set_compilation_config(
                CompilationConfig(
                    use_cudagraph=True,
                    non_cudagraph_ops=["silly.attention"],
                ))
        else:
            set_compilation_config(None)

        model = cls(llama_config).eval().cuda().to(torch.bfloat16)

        B = 256  # max batch size
        input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
        positions = torch.arange(B).cuda().to(torch.bfloat16)

        graphs = {}

        with set_compile_context(cudagraph_sizes):
            model(input_ids, positions)
            for b in cudagraph_sizes[::-1]:
                if not piecewise:
                    graph = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(graph, pool=pool):
                        output = model(input_ids[:b], positions[:b])
                    graphs[b] = (graph, output)
                else:
                    output = model(input_ids[:b], positions[:b])
                    graphs[b] = (model, output)
        for b in cudagraph_sizes:
            if piecewise:
                # noqa is for `Function definition does not bind loop variable`
                # it will be problematic if we save the created lambda function
                # and use it later, because it will look up the name `b` in the
                # enclosing scope, and the value of `b` will always be 256.
                # it is fine here, because we only use the lambda function once.
                runtime = do_bench(lambda: graphs[b][0]  # noqa
                                   (input_ids[:b], positions[:b]))  # noqa
                piecewise_cudagraph_time[b] = runtime
            else:
                runtime = do_bench(lambda: graphs[b][0].replay())  # noqa
                eager_runtime = do_bench(
                    lambda: model(input_ids[:b], positions[:b]))  # noqa
                full_cudagraph_time[b] = runtime
                eager_time[b] = eager_runtime

    # print in tabular format
    print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
    for b in cudagraph_sizes:
        print((f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
               f"\t{piecewise_cudagraph_time[b]:.3f}"))


if __name__ == "__main__":
    benchmark()
@@ -9,7 +9,7 @@ from .utils import TEST_MODELS, check_full_graph_support
 @pytest.mark.parametrize("model_info", TEST_MODELS)
 @pytest.mark.parametrize(
     "optimization_level",
-    [CompilationLevel.DYNAMO_ONCE, CompilationLevel.INDUCTOR])
+    [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE])
 @fork_new_process_for_each_test
 def test_full_graph(model_info, optimization_level):
     model = model_info[0]
@@ -9,17 +9,19 @@ from vllm.platforms import current_platform

 TEST_MODELS = [
     ("facebook/opt-125m", {}),
-    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
-        "dtype": torch.float16,
-        "quantization": "compressed-tensors"
-    }),
+    # TODO: add fake implementation for compressed-tensors
+    # ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
+    #     "dtype": torch.float16,
+    #     "quantization": "compressed-tensors"
+    # }),
     ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", {
         "dtype": torch.float16,
         "quantization": "fp8"
     }),
-    ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
-        "quantization": "compressed-tensors"
-    }),
+    # TODO: add fake implementation for compressed-tensors
+    # ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
+    #     "quantization": "compressed-tensors"
+    # }),
     ("meta-llama/Meta-Llama-3-8B", {}),
 ]

@@ -73,7 +75,7 @@ def check_full_graph_support(model,
     # much memory.
     quantization = model_kwargs.get("quantization")
     if ((quantization == "fp8" or model == "meta-llama/Meta-Llama-3-8B")
-            and optimization_level >= CompilationLevel.INDUCTOR):
+            and optimization_level >= CompilationLevel.PIECEWISE):
         return

     prompts = [
@@ -1,13 +1,16 @@
 import copy
+import dataclasses
 import operator
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

 import torch
 import torch.fx as fx

 from vllm.logger import init_logger
+from vllm.utils import weak_ref_tensors

-from .compile_context import get_compile_context
+from .config import CompilationConfig
+from .counter import compilation_counter
 from .levels import CompilationLevel

 logger = init_logger(__name__)
@@ -157,113 +160,326 @@ def fix_functionalization(graph: fx.Graph):
     # print(graph.python_code(root_module="self", verbose=True).src, file=f)


-def wrap_inductor(graph, example_inputs, additional_inductor_config):
+def wrap_inductor(graph,
+                  example_inputs,
+                  additional_inductor_config,
+                  do_logging=False,
+                  runtime_shape: Optional[int] = None,
+                  use_inductor: bool = True):
+    if not use_inductor:
+        return graph
+
+    compilation_counter.num_inductor_compilations += 1
+
+    if do_logging:
+        if runtime_shape is None:
+            logger.info("Compiling a graph for general shape")
+        else:
+            logger.info("Compiling a graph for shape %s", runtime_shape)
+
     from torch._inductor import config
     current_config = config.shallow_copy_dict()
     from torch._inductor.compile_fx import compile_fx

     if additional_inductor_config is not None:
         current_config.update(additional_inductor_config)
-    if current_config['post_grad_custom_post_pass'] is not None:
-        logger.warning(
-            "post_grad_custom_post_pass is already set in the config. "
-            "Overwriting it with the fix_functionalization")
-    current_config['post_grad_custom_post_pass'] = fix_functionalization
+
+    # inductor can inplace modify the graph, so we need to copy it
+    # see https://github.com/pytorch/pytorch/issues/138980
+    graph = copy.deepcopy(graph)
     return compile_fx(graph, example_inputs, config_patches=current_config)


-def vllm_backend(
-        graph,
-        example_inputs,
-        additional_inductor_config: Optional[Dict] = None) -> Callable:
-
-    context = get_compile_context()
-    context = copy.deepcopy(context) if context is not None else []
-    sizes_to_specialize: List[int] = context
-
-    # flags for all the seen shapes, whether we need to specialize
-    runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {}
-
-    # if we need to specialize, the compiled graph for that shape
-    runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {}
-
-    # this is the first compilation, we will compile a graph with
-    # dynamic shape, as the caller will mark first dimension as dynamic
-    logger.info("Compiling a graph for general shapes")
-    graph_for_symbolic_shape = wrap_inductor(graph, example_inputs,
-                                             additional_inductor_config)
-
-    # TODO: Dynamo does not pass all dynamic shapes.
-    # Need to investigate why. It works now because all the dynamic
-    # shapes have the same value, and either of them can be used.
-    sym_shape_indices = [
-        i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt)
-    ]
-
-    first_run = True
-
-    # this is the function we return to Dynamo to run finally
-    def compiled_graph_wrapper(*args):
-
-        runtime_shapes: Tuple[int,
-                              ...] = tuple(args[i] for i in sym_shape_indices)
-
-        nonlocal first_run
-        nonlocal runtime_shapes_to_compile_flags
-        nonlocal runtime_shapes_to_compiled_graph
-
-        if first_run:
-            # the first compilation is for profiling, we directly run it
-            first_run = False
-            return graph_for_symbolic_shape(*args)
-
-        if runtime_shapes not in runtime_shapes_to_compile_flags:
-            # we haven't seen this shape before
-            # query if we need to specialize for this shape
-            # we only specialize for the first dimension.
-            # TODO: investigate if any model needs to specialize
-            # beyond the first dimension
-            runtime_shapes_to_compile_flags[runtime_shapes] = runtime_shapes[
-                0] in sizes_to_specialize
-
-        if not runtime_shapes_to_compile_flags[runtime_shapes]:
-            # we don't need to specialize for this shape
-            return graph_for_symbolic_shape(*args)
-
-        if runtime_shapes not in runtime_shapes_to_compiled_graph:
-            # we need to specialize for this shape, and we haven't compiled
-            # compile the graph for this shape
-            logger.info("Compiling a graph for shapes %s", runtime_shapes)
-            runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor(
-                graph, args, additional_inductor_config)
-
-        return runtime_shapes_to_compiled_graph[runtime_shapes](*args)
-
-    return compiled_graph_wrapper
+@dataclasses.dataclass
+class SplitItem:
+    submod_name: str
+    is_splitting_graph: bool
+    graph: fx.GraphModule
+
+
+def split_graph(graph: fx.GraphModule,
+                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
+    # split graph by ops
+    subgraph_id = 0
+    node_to_subgraph_id = {}
+    split_op_graphs = []
+    for node in graph.graph.nodes:
+        if node.op in ("output", "placeholder"):
+            continue
+        if node.op == 'call_function' and str(node.target) in ops:
+            subgraph_id += 1
+            node_to_subgraph_id[node] = subgraph_id
+            split_op_graphs.append(subgraph_id)
+            subgraph_id += 1
+        else:
+            node_to_subgraph_id[node] = subgraph_id
+
+    # `keep_original_order` is important!
+    # otherwise pytorch might reorder the nodes and
+    # the semantics of the graph will change when we
+    # have mutations in the graph
+    split_gm = torch.fx.passes.split_module.split_module(
+        graph,
+        None,
+        lambda node: node_to_subgraph_id[node],
+        keep_original_order=True)
+
+    outputs = []
+
+    # sort the names to make sure the order is deterministic
+    names = [name for (name, module) in split_gm.named_modules()]
+    names.sort()
+
+    for name in names:
+        if "." in name or name == "":
+            # recursive child module or the root module
+            continue
+
+        module = getattr(split_gm, name)
+
+        graph_id = int(name.replace("submod_", ""))
+        outputs.append(SplitItem(name, graph_id in split_op_graphs, module))
+
+    return split_gm, outputs
+
+
+class VllmBackend:
+    """The compilation backend for `torch.compile` with VLLM.
+    It is used for compilation level of `CompilationLevel.PIECEWISE`,
+    where we customize the compilation.
+
+    The major work of this backend is to split the graph into
+    piecewise graphs, and pass them to the piecewise backend.
+    """
+
+    compilation_configs: CompilationConfig
+    graph_pool: Any
+    _called: bool = False
+    # the graph we compiled
+    graph: fx.GraphModule
+    # the stiching graph module for all the piecewise graphs
+    split_gm: fx.GraphModule
+    piecewise_graphs: List[SplitItem]
+    returned_callable: Callable
+
+    def __init__(self, ):
+        # every instance of VllmBackend has its own graph pool
+        self.graph_pool = torch.cuda.graph_pool_handle()
+
+        # `torch.compile` is JIT compiled, so we don't need to
+        # do anything here
+
+    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
+
+        compilation_counter.num_graphs_seen += 1
+
+        # we control the compilation process, each instance can only be
+        # called once
+        assert not self._called, "VllmBackend can only be called once"
+
+        self.graph = graph
+        # config is read now, because only here can
+        # we get the sizes to capture for cudagraph
+        # from compilation context
+        self.compilation_configs = CompilationConfig.select_and_init_config()
+
+        self.split_gm, self.piecewise_graphs = split_graph(
+            graph, self.compilation_configs.non_cudagraph_ops)
+
+        returned_callable: Callable  # type: ignore
+
+        if len(self.piecewise_graphs) == 0:
+            compilation_counter.num_piecewise_graphs_seen += 1
+            compilation_counter.num_piecewise_capturable_graphs_seen += 1
+            returned_callable = PiecewiseBackend(graph,
+                                                 self.compilation_configs,
+                                                 self.graph_pool,
+                                                 is_first_graph=True)
+        else:
+            from torch._dynamo.utils import lazy_format_graph_code
+            logger.debug(
+                "%s", lazy_format_graph_code("stiching module", self.split_gm))
+
+            is_first_graph = True
+
+            for item in self.piecewise_graphs:
+                compilation_counter.num_piecewise_graphs_seen += 1
+                compilation_counter.num_piecewise_capturable_graphs_seen += not item.is_splitting_graph  # noqa
+                if not item.is_splitting_graph:
+                    # cannot setattr to a module, so we need to set
+                    # the attribute in the __dict__
+                    self.split_gm.__dict__[
+                        item.submod_name] = PiecewiseBackend(
+                            item.graph, self.compilation_configs,
+                            self.graph_pool, is_first_graph)
+                    is_first_graph = False
+            returned_callable = self.split_gm
+
+        self.returned_callable = returned_callable
+        # trigger the first compilation
+        # code borrowed from https://github.com/pytorch/pytorch/blob/4e3e08b71171fa34172b2362ff668553fac75f27/torch/_dynamo/backends/distributed.py#L206 # noqa
+        # to turn the inputs into fake tensors
+        import torch._guards
+        from torch._guards import detect_fake_mode
+        fake_mode = detect_fake_mode(example_inputs)
+        fake_args = []
+        for arg in example_inputs:
+            if isinstance(arg, torch.Tensor) and not isinstance(
+                    arg, torch._subclasses.FakeTensor):
+                fake_args.append(
+                    torch._dynamo.utils.to_fake_tensor(arg, fake_mode))
+            else:
+                fake_args.append(arg)
+        self.returned_callable(*fake_args)
+
+        self._called = True
+
+        return self.returned_callable
+
+
+@dataclasses.dataclass
+class ConcreteSizeEntry:
+    runtime_shape: int
+    need_to_compile: bool  # the size is in compile_sizes
+    use_cudagraph: bool  # the size is in capture_sizes
+
+    compiled: bool = False
+    runnable: Callable = None  # type: ignore
+    num_finished_warmup: int = 0
+    cudagraph: Optional[torch.cuda.CUDAGraph] = None
+    output: Optional[Any] = None
+
+
+class PiecewiseBackend:
+
+    def __init__(self,
+                 graph: fx.GraphModule,
+                 compilation_configs: CompilationConfig,
+                 graph_pool: Any,
+                 is_first_graph: bool = False):
+        """
+        The backend for piecewise compilation.
+        It mainly handles the compilation and cudagraph capturing.
+
+        We will compile `self.graph` once for the general shape,
+        and then compile for different shapes specified in
+        `compilation_configs.compile_sizes`.
+
+        Independently, we will capture cudagraph for different shapes.
+
+        If a shape needs both compilation and cudagraph, we will
+        compile it first, and then capture cudagraph.
+        """
+        self.graph = graph
+        self.compilation_configs = compilation_configs
+        self.graph_pool = graph_pool
+        self.is_first_graph = is_first_graph
+
+        self.compile_sizes: Set[int] = set(
+            self.compilation_configs.compile_sizes)
+        self.capture_sizes: Set[int] = set(
+            self.compilation_configs.capture_sizes
+        ) if self.compilation_configs.use_cudagraph else set()
+
+        self.compile_finished = False
+        self.first_run_finished = False
+
+        self.compiled_graph_for_general_shape: Callable = None  # type: ignore
+
+        self.sym_shape_indices: List[int] = []
+
+        # the entries for different shapes that we need to either
+        # compile or capture cudagraph
+        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
+        for shape in self.compile_sizes.union(self.capture_sizes):
+            self.concrete_size_entries[shape] = ConcreteSizeEntry(
+                runtime_shape=shape,
+                need_to_compile=shape in self.compile_sizes,
+                use_cudagraph=shape in self.capture_sizes,
+            )
+
+    def __call__(self, *args) -> Any:
+
+        if not self.compile_finished:
+            self.compile_finished = True
+
+            # this is the first compilation, we will compile a graph with
+            # dynamic shape, as the caller will mark first dimension as dynamic
+
+            self.sym_shape_indices = [
+                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
+            ]
+
+            self.compiled_graph_for_general_shape = wrap_inductor(
+                self.graph,
+                args,
+                self.compilation_configs.inductor_compile_config,
+                runtime_shape=None,
+                do_logging=self.is_first_graph,
+                use_inductor=self.compilation_configs.use_inductor)
+
+            return self.graph(*args)
+
+        if not self.first_run_finished:
+            self.first_run_finished = True
+            return self.compiled_graph_for_general_shape(*args)
+
+        runtime_shape = args[self.sym_shape_indices[0]]
+        if runtime_shape not in self.concrete_size_entries:
+            # we don't need to do anything for this shape
+            return self.compiled_graph_for_general_shape(*args)
+
+        entry = self.concrete_size_entries[runtime_shape]
+
+        if entry.runnable is None:
+            entry.runnable = self.compiled_graph_for_general_shape
+
+        if entry.need_to_compile and not entry.compiled:
+            entry.compiled = True
+            # args are real arguments
+            entry.runnable = wrap_inductor(
+                self.graph,
+                args,
+                self.compilation_configs.inductor_compile_config,
+                runtime_shape=runtime_shape,
+                do_logging=self.is_first_graph,
+                use_inductor=self.compilation_configs.use_inductor)
+
+        if not entry.use_cudagraph:
+            return entry.runnable(*args)
+
+        if entry.cudagraph is None:
+            if entry.num_finished_warmup < self.compilation_configs.cudagraph_num_of_warmups:  # noqa
+                entry.num_finished_warmup += 1
+                if self.is_first_graph:
+                    logger.debug(
+                        "Warming up %s/%s for shape %s",
+                        entry.num_finished_warmup,
+                        self.compilation_configs.cudagraph_num_of_warmups,
+                        runtime_shape)
+                return entry.runnable(*args)
+
+            if self.is_first_graph:
+                logger.info("Capturing a cudagraph for shape %s",
+                            runtime_shape)
+
+            cudagraph = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(cudagraph, pool=self.graph_pool):
+                entry.output = weak_ref_tensors(entry.runnable(*args))
+
+            compilation_counter.num_cudagraph_caputured += 1
+
+            entry.cudagraph = cudagraph
+            return entry.output
+
+        entry.cudagraph.replay()
+        return entry.output


 def select_default_backend(level: int) -> Union[str, Callable]:
     if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
         backend_str = "eager"
         return backend_str
-    assert level in [
-        CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
-    ], f"Invalid level {level}"
-
-    from vllm.compilation.backends import vllm_backend
-    from vllm.plugins import get_inductor_additional_configs
-    additional_configs = get_inductor_additional_configs()
-
-    if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE:
-        if "max_autotune" in additional_configs and not additional_configs[
-                "max_autotune"]:
-            logger.warning(
-                "max_autotune is disabled, but is overridden by level %s",
-                CompilationLevel.INDUCTOR_MAX_AUTOTUNE)
-        additional_configs['max_autotune'] = True
-
-    from functools import partial
-    backend = partial(vllm_backend,
-                      additional_inductor_config=additional_configs)
-
-    return backend
+    assert level == CompilationLevel.PIECEWISE
+
+    return VllmBackend()
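For orientation, a minimal sketch (added for this writeup, not part of the commit) of how the new backend object plugs into `torch.compile`. Inside vLLM this wiring is done by `support_torch_compile` / `TorchCompileWrapperWithCustomDispatcher` rather than by user code, and actually executing the compiled function also needs a CUDA device and a compilation config/context set up as in the tests above; the toy function here is purely illustrative.

import torch

from vllm.compilation.backends import select_default_backend
from vllm.compilation.levels import CompilationLevel

# select_default_backend(PIECEWISE) now returns a VllmBackend instance, which is
# a callable with the (graph, example_inputs) signature torch.compile expects
# from a custom backend.
backend = select_default_backend(CompilationLevel.PIECEWISE)


def toy(x: torch.Tensor) -> torch.Tensor:
    return x + 1


# Compilation is lazy: nothing is traced or compiled until the first call.
compiled_toy = torch.compile(toy, backend=backend)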
vllm/compilation/config.py (new file, 154 lines)
import copy
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, PrivateAttr

import vllm.envs as envs
from vllm.logger import init_logger

from .compile_context import get_compile_context

logger = init_logger(__name__)


class CompilationConfig(BaseModel):
    """
    Configuration for compilation.
    It has two parts:
    - CudaGraph capture:
        - use_cudagraph: whether to use cudagraph inside compilation.
            - False: cudagraph inside compilation is not used.
            - True: cudagraph inside compilation is used. It requires
                that all input buffers have fixed addresses.
                Note that this is orthogonal to the cudagraph capture out
                side of compilation.
                TODO: move outside cudagraph logic into compilation.
                torch.compile will handle cudagraph capture logic in the future.
        - cudagraph_capture_sizes: sizes to capture cudagraph.
            - None: capture sizes are inferred from compilation context.
            - List[int]: capture sizes are specified.
        - cudagraph_num_of_warmups: number of warmup runs for cudagraph.
            It means the first several runs will be treated as warmup runs.
            Only after that, the execution will be recorded, and the recorded
            cudagraph will be used for subsequent runs.
    - Inductor compilation:
        - use_inductor: whether to use inductor compilation.
            - False: inductor compilation is not used. graph runs in eager.
            - True: inductor compilation is used. one graph for symbolic shape
                is compiled. In addition, compile for different sizes specified
                in inductor_compile_sizes, using configurations
                in inductor_compile_config.
        - inductor_compile_sizes: sizes to compile for inductor.
        - inductor_specialize_for_cudagraph_no_more_than: an optional integer
            to specialize inductor for cudagraph sizes no more than the
            specified size. It is useful when we want to specialize inductor
            with a subset of cudagraph sizes.
        - inductor_compile_config: additional configurations for inductor.
            - None: use default configurations.
        - inductor_passes: additional passes for inductor. It is a dictionary
            from pass name to pass function qualified name. We use function
            name because the config uses json format. If we pass the config
            from Python, functions can also be passed directly via Python object
            constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`

    Why we have different sizes for cudagraph and inductor:
    - cudagraph: a cudagraph captured for a specific size can only be used
        for the same size. We need to capture all the sizes we want to use.
    - inductor: a graph compiled by inductor for a general shape can be used
        for different sizes. Inductor can also compile for specific sizes,
        where it can have more information to optimize the graph with fully
        static shapes. However, we find the general shape compilation is
        sufficient for most cases. It might be beneficial to compile for
        certain small batchsizes, where inductor is good at optimizing.
    """
    use_inductor: bool = True
    inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None
    inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict)
    inductor_compile_config: Dict = Field(default_factory=dict)
    inductor_passes: Dict[str, str] = Field(default_factory=dict)

    use_cudagraph: bool = False
    non_cudagraph_ops: List[str] = Field(default_factory=list)
    cudagraph_num_of_warmups: int = 0
    cudagraph_capture_sizes: Optional[List[int]] = None

    # not configurable, computed after init
    compile_sizes: List[int] = PrivateAttr
    capture_sizes: List[int] = PrivateAttr

    def model_post_init(self, __context: Any) -> None:
        for k, v in self.inductor_passes.items():
            if not isinstance(v, str):
                assert callable(v), (
                    f"pass {k} should be a function or a qualified name")
                self.inductor_passes[k] = v
                continue

            # resolve function from qualified name
            names = v.split(".")
            module = ".".join(names[:-1])
            func_name = names[-1]
            func = __import__(module).__dict__[func_name]
            self.inductor_compile_config[k] = func

        from vllm.compilation.backends import fix_functionalization
        from vllm.utils import combine_fx_passes
        if "post_grad_custom_post_pass" in self.inductor_compile_config:
            self.inductor_compile_config[
                "post_grad_custom_post_pass"] = combine_fx_passes(
                    fix_functionalization,
                    self.inductor_compile_config["post_grad_custom_post_pass"],
                )
        else:
            self.inductor_compile_config[
                "post_grad_custom_post_pass"] = fix_functionalization

    def init_during_runtime(self):
        """To complete the initialization of config,
        we need to know the compile context, which is only available
        during the first run of the model.
        """
        context = get_compile_context()
        context = copy.deepcopy(context) if context is not None else []
        sizes_to_specialize: List[int] = context
        if self.cudagraph_capture_sizes is None:
            self.capture_sizes = sizes_to_specialize
        else:
            self.capture_sizes = self.cudagraph_capture_sizes
            logger.info(("cudagraph sizes specified by model runner"
                         " %s is overridden by config %s"),
                        sizes_to_specialize, self.cudagraph_capture_sizes)
        if self.inductor_specialize_for_cudagraph_no_more_than is not None:
            assert self.inductor_compile_sizes is None, (
                "inductor_compile_sizes should be None when "
                "inductor_specialize_for_cudagraph_no_more_than is not None")
            self.compile_sizes = [
                x for x in self.capture_sizes
                if x <= self.inductor_specialize_for_cudagraph_no_more_than
            ]
        else:
            assert self.inductor_compile_sizes is not None, (
                "inductor_compile_sizes should not be None when "
                "inductor_specialize_for_cudagraph_no_more_than is None")
            self.compile_sizes = self.inductor_compile_sizes

    @staticmethod
    def select_and_init_config() -> "CompilationConfig":
        """The order of selecting config is:
        1. Use the config specified in environment variable.
        2. Use the config specified in plugins.
        3. Use the default config.
        """
        config_path = envs.VLLM_TORCH_COMPILE_CONFIG
        if config_path is not None:
            with open(config_path) as json_file:
                config = CompilationConfig.model_validate_json(
                    json_file.read())
        else:
            from vllm.plugins import get_compilation_config
            predefined_config = get_compilation_config()
            config = predefined_config if predefined_config is not None else (
                CompilationConfig())

        config.init_during_runtime()
        return config
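A short sketch (added for this writeup, not part of the commit) of the two ways a config reaches `CompilationConfig.select_and_init_config()`; the JSON file path below is hypothetical.

import os

from vllm.compilation.config import CompilationConfig
from vllm.plugins import set_compilation_config

# 1) From Python, through the plugin hook (this is what test_toy_llama.py does):
set_compilation_config(
    CompilationConfig(use_cudagraph=True,
                      non_cudagraph_ops=["silly.attention"]))

# 2) From a JSON file, through the environment variable that
#    select_and_init_config() checks first (this is what test_simple.py does);
#    the file content mirrors piecewise_compilation_config.json above.
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = "/path/to/compile_config.json"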
vllm/compilation/counter.py (new file, 30 lines)
import copy
import dataclasses
from contextlib import contextmanager


@dataclasses.dataclass
class CompilationCounter:
    num_graphs_seen: int = 0
    # including the splitting ops
    num_piecewise_graphs_seen: int = 0
    # not including the splitting ops
    num_piecewise_capturable_graphs_seen: int = 0
    num_inductor_compilations: int = 0
    num_cudagraph_caputured: int = 0

    def clone(self) -> "CompilationCounter":
        return copy.deepcopy(self)

    @contextmanager
    def expect(self, **kwargs):
        old = self.clone()
        yield
        for k, v in kwargs.items():
            assert getattr(self, k) - getattr(old, k) == v, (
                f"{k} not as expected, before it is {getattr(old, k)}"
                f", after it is {getattr(self, k)}, "
                f"expected diff is {v}")


compilation_counter = CompilationCounter()
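A minimal usage example of the counter (added for this writeup, not part of the commit): `expect()` snapshots the counters on entry and asserts that each keyword argument matches the delta observed on exit.

from vllm.compilation.counter import compilation_counter

with compilation_counter.expect(num_graphs_seen=0, num_inductor_compilations=0):
    pass  # nothing is compiled inside the block, so both deltas are zero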
@@ -121,7 +121,10 @@ def _support_torch_compile(cls: type,
     # take care of method resolution order
     # make sure super().__init__ is called on the base class
     # other than TorchCompileWrapperWithCustomDispatcher
-    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
+    if TorchCompileWrapperWithCustomDispatcher not in cls.__bases__:
+        # support decorating multiple times
+        cls.__bases__ = cls.__bases__ + (
+            TorchCompileWrapperWithCustomDispatcher, )

     old_init = cls.__init__  # type: ignore

@@ -160,6 +163,11 @@ def _support_torch_compile(cls: type,
         # compiled function and let torch.compile handle the dispatching,
         # with the overhead of guard evaluation and recompilation.
         if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
+            # it seems Dynamo reuse the compilation across instances,
+            # while we need to make sure the compiled code is not reused.
+            # we need to control all the compilation of the model.
+            torch._dynamo.eval_frame.remove_from_cache(
+                self.original_code_object)
             return self.compiled_callable(*args, **kwargs)

         # usually, capturing the model once is enough, and then we can
@@ -5,5 +5,4 @@ class CompilationLevel:
     NO_COMPILATION = 0
     DYNAMO_AS_IS = 1
     DYNAMO_ONCE = 2
-    INDUCTOR = 3
-    INDUCTOR_MAX_AUTOTUNE = 4
+    PIECEWISE = 3
@@ -209,6 +209,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
     "VLLM_TORCH_COMPILE_LEVEL":
     lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),
+
+    # Path to the config file for torch compile
+    "VLLM_TORCH_COMPILE_CONFIG":
+    lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None),
+
     # Fine-grained control over which custom ops to enable/disable.
     # Use 'all' to enable all, 'none' to disable all.
     # Also specify a list of custom op names to enable (prefixed with a '+'),
@@ -100,7 +100,7 @@ class CustomOp(nn.Module):

         return (CustomOp.default_on() or enabled) and not disabled

-    # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR
+    # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE
     # Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
     @staticmethod
     @lru_cache()
@@ -108,7 +108,7 @@ class CustomOp(nn.Module):
         count_none = envs.VLLM_CUSTOM_OPS.count("none")
         count_all = envs.VLLM_CUSTOM_OPS.count("all")
         assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
-        return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR and \
+        return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE and \
             not count_none > 0 or count_all > 0

     # Dictionary of all custom ops (classes, indexed by registered name).
@@ -11,7 +11,7 @@ from .interface import Platform, PlatformEnum
 if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
     os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE)

-assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR,\
+assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE,\
     "TPU does not support Inductor."

 set_torch_compile_backend("openxla")
@@ -1,7 +1,8 @@
 import logging
-from typing import Callable, Dict, Optional, Union
+from typing import Callable, Optional, Union

 import vllm.envs as envs
+from vllm.compilation.config import CompilationConfig

 logger = logging.getLogger(__name__)

@@ -44,13 +45,13 @@ def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
     return _torch_compile_backend


-_inductor_additional_configs: Dict = {}
+_compilation_config: Optional[CompilationConfig] = None


-def set_inductor_additional_configs(configs: Dict):
-    global _inductor_additional_configs
-    _inductor_additional_configs = configs
+def set_compilation_config(config: Optional[CompilationConfig]):
+    global _compilation_config
+    _compilation_config = config


-def get_inductor_additional_configs() -> Dict:
-    return _inductor_additional_configs
+def get_compilation_config() -> Optional[CompilationConfig]:
+    return _compilation_config
@@ -1479,6 +1479,15 @@ class LazyDict(Mapping, Generic[T]):
         return len(self._factory)


+def combine_fx_passes(passes: List[Callable]) -> Callable:
+
+    def combined_fx(graph) -> None:
+        for fx in passes:
+            fx(graph)
+
+    return combined_fx
+
+
 def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
     """
     Create a weak reference to a tensor.
@@ -1486,3 +1495,19 @@ def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
     but will not keep the original tensor alive.
     """
     return torch.ops._C.weak_ref_tensor(tensor)
+
+
+def weak_ref_tensors(
+    tensors: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
+) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]:
+    """
+    Convenience function to create weak references to tensors,
+    for single tensor, list of tensors or tuple of tensors.
+    """
+    if isinstance(tensors, torch.Tensor):
+        return weak_ref_tensor(tensors)
+    if isinstance(tensors, list):
+        return [weak_ref_tensor(t) for t in tensors]
+    if isinstance(tensors, tuple):
+        return tuple(weak_ref_tensor(t) for t in tensors)
+    raise ValueError("Invalid type for tensors")