# SPDX-License-Identifier: Apache-2.0
"""
Test the piecewise compilation with a simple model, comparing the output
with and without the piecewise compilation.

This is a tractable model: the weights and computation are specially designed
if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""
from dataclasses import dataclass
from typing import Any, Optional

import torch
from torch import nn
from torch.library import Library

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
                         set_current_vllm_config)
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT")  # noqa


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    out.copy_(q)
    out += k
    out += v


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    return

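# Register the op with a fake implementation so torch.compile can trace
# through it without running the real kernel. "silly.attention" is also the
# op listed in `splitting_ops` below, i.e. the point at which the piecewise
# backend splits the graph.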
direct_register_custom_op(
    op_name="attention",
    op_func=silly_attention,
    mutates_args=["out"],
    fake_impl=silly_attention_fake,
    target_lib=silly_lib,
)


@dataclass
class LlamaConfig:
    hidden_size: int = 128
    mlp_size: int = 256
    vocab_size: int = 128
    num_layers: int = 2
    init_value: float = 1.0
    tractable_init: bool = False
    random_seed: int = 0
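
    # NOTE: this config is passed to VllmConfig as `additional_config` in
    # run_model below; compute_hash presumably lets it contribute to vLLM's
    # compilation cache key. random_seed is excluded, likely because the seed
    # only changes the weight values, not the compiled graph.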
    def compute_hash(self) -> str:
        factors: list[Any] = []
        for k, v in self.__dict__.items():
            if k == "random_seed":
                continue
            factors.append((k, v))
        factors.sort()
        import hashlib
        return hashlib.md5(str(factors).encode()).hexdigest()

    def __post_init__(self):
        assert self.mlp_size >= self.hidden_size


class LlamaMLP(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.gate_up_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.mlp_size * 2,
            bias=False,
        )
        self.down_projection = nn.Linear(
            in_features=config.mlp_size,
            out_features=config.hidden_size,
            bias=False,
        )

        if config.tractable_init:
            nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size])
            nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:])
            nn.init.eye_(self.down_projection.weight.data)
        else:
            nn.init.xavier_normal_(self.gate_up_projection.weight.data,
                                   generator=torch.Generator().manual_seed(
                                       config.random_seed),
                                   gain=0.001)
            nn.init.xavier_normal_(self.down_projection.weight.data,
                                   generator=torch.Generator().manual_seed(
                                       config.random_seed),
                                   gain=0.001)
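
    # With tractable_init, the two halves of gate_up_projection are
    # (rectangular) identities, so gate_up(x) reproduces x in both halves;
    # for positive x, x * relu(x) == x ** 2, and the identity down_projection
    # keeps that value, hence the "elementwise-square" note in forward below.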
    def forward(self, x):
        # for tractable_init and positive input, this is
        # essentially an elementwise-square
        x = self.gate_up_projection(x)
        x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
            x[:, x.size(1) // 2:])
        x = self.down_projection(x)
        return x


class LlamaAttention(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.qkv_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size * 3,
            bias=False,
        )

        self.output_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
            bias=False,
        )

        if config.tractable_init:
            nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size])
            nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 *
                                                         config.hidden_size])
            nn.init.eye_(self.qkv_projection.weight.data[2 *
                                                         config.hidden_size:])
            nn.init.eye_(self.output_projection.weight.data)
        else:
            nn.init.xavier_normal_(self.qkv_projection.weight.data,
                                   generator=torch.Generator().manual_seed(
                                       config.random_seed),
                                   gain=0.001)
            nn.init.xavier_normal_(self.output_projection.weight.data,
                                   generator=torch.Generator().manual_seed(
                                       config.random_seed),
                                   gain=0.001)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # for tractable_init, this is:
        # output = (hidden_states * 3 + positions * 2)
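        # Derivation for the identity-initialized weights: q == k == v ==
        # hidden_states, positions are added to q and k, silly.attention
        # sums q + k + v, and the identity output_projection keeps the sum,
        # giving 3 * hidden_states + 2 * positions.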
        qkv = self.qkv_projection(hidden_states)
        hidden_size = qkv.size(-1) // 3
        q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)

        q = q + positions.unsqueeze(1)
        k = k + positions.unsqueeze(1)

        attn_output = torch.empty_like(q)
        torch.ops.silly.attention(q, k, v, attn_output)

        output = self.output_projection(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.self_attention = LlamaAttention(config)
        self.mlp = LlamaMLP(config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        For tractable computation:
        - if residual is None, the outputs are:
            - residual = (hidden_states + 1) * 3 + positions * 2 + hidden_states = hidden_states * 4 + positions * 2 + 3
            - hidden_states = (residual + 1) ** 2
        - if residual is not None, the outputs are:
            - residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
            - hidden_states = (residual + 1) ** 2
        """  # noqa
        if residual is None:
            residual = hidden_states
            hidden_states = hidden_states + 1
        else:
            hidden_states = hidden_states + residual
            residual = hidden_states
            hidden_states = hidden_states + 1

        hidden_states = self.self_attention(positions=positions,
                                            hidden_states=hidden_states)

        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states + 1
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


@support_torch_compile
class LlamaModel(nn.Module):

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 config: LlamaConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        super().__init__()
        self.embedding_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_layers)])

        # this is the initial value of the hidden states
        self.embedding_tokens.weight.data.fill_(config.init_value)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embedding_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        return hidden_states


def tractable_computation(input_ids: torch.Tensor,
                          positions: torch.Tensor,
                          config: LlamaConfig,
                          init_value: float = 1.0) -> torch.Tensor:
    hidden_states = torch.ones(input_ids.size(0),
                               config.hidden_size,
                               device=input_ids.device,
                               dtype=input_ids.dtype) * init_value
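
    # Per-layer closed form, following the LlamaDecoderLayer docstring:
    #   residual      <- (hidden_states [+ residual]) * 4 + positions * 2 + 3
    #   hidden_states <- (residual + 1) ** 2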
    # first layer
    residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
    hidden_states = (residual + 1)**2

    # following layers
    for _ in range(config.num_layers - 1):
        hidden_states = hidden_states + residual
        residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
        hidden_states = (residual + 1)**2

    return hidden_states


@torch.inference_mode
def run_model(llama_config,
              use_compile: bool,
              split_attn: bool = False) -> torch.Tensor:

    if use_compile:
        compilation_config = CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            use_cudagraph=True,
            cudagraph_capture_sizes=[1, 2],
        )
        if split_attn:
            compilation_config.splitting_ops = ["silly.attention"]
    else:
        compilation_config = CompilationConfig(
            level=CompilationLevel.NO_COMPILATION, )

    vllm_config = VllmConfig(compilation_config=compilation_config,
                             additional_config=llama_config)
    with set_current_vllm_config(vllm_config):
        model = LlamaModel(config=llama_config,
                           vllm_config=vllm_config,
                           prefix="").eval().cuda()

    B = 16  # max batch size
    input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
    positions = torch.arange(B).cuda()
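    # The first call compiles/warms up the model at the full batch size; the
    # batch-2 and batch-1 calls are then expected to capture the cudagraphs
    # for cudagraph_capture_sizes=[1, 2] when cudagraphs are enabled. Zeroing
    # the inputs in place afterwards checks that the final batch-2 call (a
    # graph replay in the cudagraph case) sees the updated input values.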
    model(input_ids, positions)
    model(input_ids[:2], positions[:2])
    model(input_ids[:1], positions[:1])

    input_ids[:2].zero_()
    output = model(input_ids[:2], positions[:2])

    output = output.cpu()

    if llama_config.tractable_init:
        expected_output = tractable_computation(input_ids[:2], positions[:2],
                                                llama_config).cpu()

        assert torch.allclose(output, expected_output)
    else:
        return output.cpu()


def test_toy_llama():
    # compare output with and without piecewise compilation

    llama_config = LlamaConfig(hidden_size=128,
                               mlp_size=256,
                               vocab_size=128,
                               num_layers=12)

    tractable_config = LlamaConfig(hidden_size=128,
                                   mlp_size=256,
                                   vocab_size=128,
                                   num_layers=2,
                                   tractable_init=True)

    outputs = []
    with compilation_counter.expect(
            num_graphs_seen=0,
            num_piecewise_graphs_seen=0,
            num_piecewise_capturable_graphs_seen=0,
            num_backend_compilations=0,
            num_cudagraph_caputured=0,
    ):
        outputs.append(run_model(llama_config, use_compile=False))
    run_model(tractable_config, use_compile=False)

    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=1,
            num_piecewise_capturable_graphs_seen=1,
            num_backend_compilations=1,  # num_piecewise_capturable_graphs_seen
            num_cudagraph_caputured=
            2,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        outputs.append(run_model(llama_config, use_compile=True))
    run_model(tractable_config, use_compile=True)
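    # With attention as a splitting op, each of the num_layers attention calls
    # becomes its own (non-capturable) piece and the surrounding code forms
    # num_layers + 1 capturable pieces, which is what the expected counts
    # below encode: 2 * num_layers + 1 pieces in total.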
    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=2 * llama_config.num_layers +
            1,  # 2 * num_layers + 1
            num_piecewise_capturable_graphs_seen=1 +
            llama_config.num_layers,  # 1 + num_layers
            num_backend_compilations=1 +
            llama_config.num_layers,  # num_piecewise_capturable_graphs_seen
            num_cudagraph_caputured=2 *
            (1 + llama_config.num_layers
             ),  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        outputs.append(
            run_model(llama_config, use_compile=True, split_attn=True))
    run_model(tractable_config, use_compile=True, split_attn=True)

    for i in range(1, len(outputs)):
        assert torch.allclose(outputs[0], outputs[i])


@torch.inference_mode
def benchmark():
    from triton.testing import do_bench

    # similar to llama 3.1-8B
    llama_config = LlamaConfig(hidden_size=4096,
                               mlp_size=14336,
                               vocab_size=128 * 1024,
                               num_layers=32)
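    # NOTE: the assignment below overwrites the 8B-like config above, so only
    # the tiny model is actually benchmarked; comment it out to benchmark the
    # larger shape instead.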
    # a tiny model to measure the overhead
    # of piecewise cudagraph
    llama_config = LlamaConfig(hidden_size=40,
                               mlp_size=80,
                               vocab_size=128,
                               num_layers=2)

    cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]

    eager_time = {}
    full_cudagraph_time = {}
    piecewise_cudagraph_time = {}
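    # one shared memory pool for all the full-cudagraph captures below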
    pool = torch.cuda.graph_pool_handle()

    for piecewise in [False, True]:
        if piecewise:
            compilation_config = CompilationConfig(
                level=CompilationLevel.PIECEWISE,
                use_cudagraph=True,
                splitting_ops=["silly.attention"],
                cudagraph_capture_sizes=cudagraph_sizes,
            )
        else:
            compilation_config = CompilationConfig(
                level=CompilationLevel.PIECEWISE,
                cudagraph_capture_sizes=cudagraph_sizes,
            )

        vllm_config = VllmConfig(compilation_config=compilation_config)
        with set_current_vllm_config(vllm_config):
            model = LlamaModel(config=llama_config,
                               vllm_config=vllm_config,
                               prefix="").eval().cuda().to(torch.bfloat16)

        B = 256  # max batch size
        input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
        positions = torch.arange(B).cuda().to(torch.bfloat16)

        graphs = {}

        model(input_ids, positions)
        for b in cudagraph_sizes[::-1]:
            if not piecewise:
                graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(graph, pool=pool):
                    output = model(input_ids[:b], positions[:b])
                graphs[b] = (graph, output)
            else:
                output = model(input_ids[:b], positions[:b])
                graphs[b] = (model, output)
        for b in cudagraph_sizes:
            if piecewise:
                # noqa is for `Function definition does not bind loop variable`
                # it will be problematic if we save the created lambda function
                # and use it later, because it will look up the name `b` in the
                # enclosing scope, and the value of `b` will always be 256.
                # it is fine here, because we only use the lambda function once.
                runtime = do_bench(lambda: graphs[b][0]  # noqa
                                   (input_ids[:b], positions[:b]))  # noqa
                piecewise_cudagraph_time[b] = runtime
            else:
                runtime = do_bench(lambda: graphs[b][0].replay())  # noqa
                eager_runtime = do_bench(
                    lambda: model(input_ids[:b], positions[:b]))  # noqa
                full_cudagraph_time[b] = runtime
                eager_time[b] = eager_runtime

    # print in tabular format
    print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
    for b in cudagraph_sizes:
        print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
              f"\t{piecewise_cudagraph_time[b]:.3f}")
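

# Running this file directly benchmarks eager mode vs. full cudagraph vs.
# piecewise cudagraph; the correctness check (test_toy_llama) is meant to be
# run under pytest.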
if __name__ == "__main__":
    benchmark()