[2/N][torch.compile] make compilation cfg part of vllm cfg (#10383)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao <youkaichao@gmail.com> (committed by GitHub)
Date: 2024-11-16 18:02:14 -08:00
Parent: 661a34fd4f
Commit: 4fd9375028
27 changed files with 359 additions and 283 deletions

View File

@ -11,8 +11,8 @@ from torch.library import Library
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.config import CompilationLevel, VllmConfig
from vllm.plugins import set_current_vllm_config
from vllm.utils import direct_register_custom_op
global_counter = 0
@ -82,7 +82,9 @@ def test_simple_piecewise_compile():
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
model = SillyModel(vllm_config=VllmConfig(), prefix='')
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='')
inputs = torch.randn(100).cuda()

View File

@ -15,12 +15,10 @@ from torch import nn
from torch.library import Library
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.config import CompilationConfig
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.plugins import set_compilation_config
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_compilation_config, set_current_vllm_config
from vllm.utils import direct_register_custom_op
# create a library to hold the custom op
@ -272,9 +270,11 @@ def run_model(llama_config,
CompilationLevel.NO_COMPILATION)
set_compilation_config(None)
model = LlamaModel(config=llama_config,
vllm_config=VllmConfig(),
prefix="").eval().cuda()
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
prefix="").eval().cuda()
B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
@ -395,9 +395,11 @@ def benchmark():
else:
set_compilation_config(None)
model = LlamaModel(config=llama_config,
vllm_config=VllmConfig(),
prefix="").eval().cuda().to(torch.bfloat16)
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
prefix="").eval().cuda().to(torch.bfloat16)
B = 256 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()

View File

@ -3,7 +3,7 @@ from typing import Dict, List, Optional
import pytest
from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel
from vllm.utils import cuda_device_count_stateless
from ..utils import compare_all_settings

View File

@ -1,6 +1,6 @@
import pytest
from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel
from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS, check_full_graph_support

View File

@ -3,10 +3,10 @@ import torch
from compressed_tensors.quantization import FP8_DTYPE
import vllm.envs as envs
from vllm.compilation.config import CompilationConfig
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
find_auto_fn_maybe)
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.config import CompilationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear)

View File

@ -3,6 +3,7 @@ from typing import Optional
import torch
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationLevel
class MyMod(torch.nn.Module):
@ -18,7 +19,8 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
def __init__(self, model):
self.model = model
compiled_callable = torch.compile(self.forward, backend="eager")
super().__init__(compiled_callable)
super().__init__(compiled_callable,
compilation_level=CompilationLevel.DYNAMO_ONCE)
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
# this is the function to be compiled

View File

@ -4,7 +4,7 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel
from vllm.platforms import current_platform
TEST_MODELS = [

View File

@ -3,11 +3,13 @@ from typing import List
import pytest
from vllm.config import CompilationConfig, VllmConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.plugins import set_current_vllm_config
# Registered subclass for test
@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation):
])
def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
default_on: bool):
os.environ["VLLM_CUSTOM_OPS"] = env
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
vllm_config = VllmConfig(compilation_config=CompilationConfig(
custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on
# Reset default_on (computed once):
CustomOp.default_on.cache_clear()
ops_enabled = [bool(x) for x in ops_enabled]
assert CustomOp.default_on() == default_on
assert RMSNorm(1024).enabled() == ops_enabled[0]
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
ops_enabled = [bool(x) for x in ops_enabled]
assert SiluAndMul().enabled() == ops_enabled[1]
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
assert RMSNorm(1024).enabled() == ops_enabled[0]
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
assert GeluAndMul().enabled() == ops_enabled[2]
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
assert SiluAndMul().enabled() == ops_enabled[1]
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
# If registered, subclasses should follow their own name
assert Relu3().enabled() == ops_enabled[3]
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
assert GeluAndMul().enabled() == ops_enabled[2]
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
# Unregistered subclass
class SiluAndMul2(SiluAndMul):
pass
# If registered, subclasses should follow their own name
assert Relu3().enabled() == ops_enabled[3]
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
# Unregistered subclass
class SiluAndMul2(SiluAndMul):
pass
# Subclasses should not require registration
assert SiluAndMul2().enabled() == SiluAndMul().enabled()
# Subclasses should not require registration
assert SiluAndMul2().enabled() == SiluAndMul().enabled()
@pytest.mark.parametrize(
"env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
def test_enabled_ops_invalid(env: str):
os.environ["VLLM_CUSTOM_OPS"] = env
CustomOp.default_on.cache_clear()
with pytest.raises(AssertionError):
RMSNorm(1024).enabled()
with pytest.raises(Exception): # noqa
vllm_config = VllmConfig(compilation_config=CompilationConfig(
custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
RMSNorm(1024).enabled()

View File

@ -5,7 +5,7 @@ import tempfile
import depyf
from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel
# disable custom dispatcher, let Dynamo takes over
# all the control

View File

@ -1,6 +1,6 @@
import os
from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel
from ..utils import compare_two_settings

View File

@ -10,13 +10,12 @@ import torch
import torch.fx as fx
import vllm.envs as envs
from vllm.config import CompilationConfig, CompilationLevel
from vllm.logger import init_logger
from vllm.utils import combine_fx_passes, weak_ref_tensors
from .config import CompilationConfig
from .counter import compilation_counter
from .fusion import FusionPass
from .levels import CompilationLevel
from .reshapes import RedundantReshapesPass
logger = init_logger(__name__)
@ -392,7 +391,10 @@ class VllmBackend:
sym_tensor_indices: List[int]
input_buffers: List[torch.Tensor]
def __init__(self, post_grad_passes: Sequence[Callable] = ()):
def __init__(
self,
compilation_configs: CompilationConfig,
):
global global_graph_pool
if global_graph_pool is None:
global_graph_pool = torch.cuda.graph_pool_handle()
@ -401,11 +403,13 @@ class VllmBackend:
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self.graph_pool = global_graph_pool
self.post_grad_passes = post_grad_passes
self.post_grad_passes = []
self.sym_tensor_indices = []
self.input_buffers = []
self.compilation_configs = compilation_configs
# `torch.compile` is JIT compiled, so we don't need to
# do anything here
@ -437,10 +441,10 @@ class VllmBackend:
assert not self._called, "VllmBackend can only be called once"
self.graph = graph
# config is read now, because only here can
# config is updated now, because only here can
# we get the sizes to capture for cudagraph
# from compilation context
self.compilation_configs = CompilationConfig.select_and_init_config()
self.compilation_configs.init_during_runtime()
self.add_passes_to_config()
self.split_gm, self.piecewise_graphs = split_graph(
@ -688,4 +692,6 @@ def select_default_backend(level: int) -> Union[str, Callable]:
return backend_str
assert level == CompilationLevel.PIECEWISE
return VllmBackend()
from vllm.plugins import get_current_vllm_config
compilation_config = get_current_vllm_config().compilation_config
return VllmBackend(compilation_config)
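Note: the piecewise backend is no longer configured from VLLM_TORCH_COMPILE_CONFIG inside the backend itself; it is built from the compilation config attached to the active VllmConfig. A minimal sketch of the selection path changed in this hunk (requires a CUDA-capable environment, since VllmBackend allocates a cudagraph pool; values are illustrative):

from vllm.compilation.backends import select_default_backend
from vllm.config import CompilationLevel, VllmConfig
from vllm.plugins import set_current_vllm_config

vllm_config = VllmConfig()  # __post_init__ attaches a CompilationConfig
with set_current_vllm_config(vllm_config):
    # for PIECEWISE, this now returns a VllmBackend constructed from
    # get_current_vllm_config().compilation_config
    backend = select_default_backend(CompilationLevel.PIECEWISE)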

View File

@ -1,159 +0,0 @@
import copy
from pathlib import Path
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, PrivateAttr
import vllm.envs as envs
from vllm.logger import init_logger
from .compile_context import get_compile_context
logger = init_logger(__name__)
class CompilationConfig(BaseModel):
"""
Configuration for compilation.
It has two parts:
- CudaGraph capture:
- use_cudagraph: whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses.
Note that this is orthogonal to the cudagraph capture out
side of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future.
- cudagraph_capture_sizes: sizes to capture cudagraph.
- None: capture sizes are inferred from compilation context.
- List[int]: capture sizes are specified.
- cudagraph_num_of_warmups: number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
Only after that, the execution will be recorded, and the recorded
cudagraph will be used for subsequent runs.
- cudagraph_copy_inputs: whether to copy input tensors for
cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False.
- Inductor compilation:
- use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for different sizes specified
in inductor_compile_sizes, using configurations
in inductor_compile_config.
- inductor_compile_sizes: sizes to compile for inductor.
- inductor_specialize_for_cudagraph_no_more_than: an optional integer
to specialize inductor for cudagraph sizes no more than the
specified size. It is useful when we want to specialize inductor
with a subset of cudagraph sizes.
- inductor_compile_config: additional configurations for inductor.
- None: use default configurations.
- inductor_passes: additional passes for inductor. It is a dictionary
from pass name to pass function qualified name. We use function
name because the config uses json format. If we pass the config
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
- Custom inductor passes:
- dump_graph_stages: list of stages for which we want to dump the graph.
Each pass defines its own stages (before, after, maybe in-between).
- dump_graph_dir: directory to dump the graph. Default is .
- enable_fusion: whether to enable the custom fusion pass.
TODO better pass enabling system.
Why we have different sizes for cudagraph and inductor:
- cudagraph: a cudagraph captured for a specific size can only be used
for the same size. We need to capture all the sizes we want to use.
- inductor: a graph compiled by inductor for a general shape can be used
for different sizes. Inductor can also compile for specific sizes,
where it can have more information to optimize the graph with fully
static shapes. However, we find the general shape compilation is
sufficient for most cases. It might be beneficial to compile for
certain small batchsizes, where inductor is good at optimizing.
"""
use_inductor: bool = True
inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None
inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict)
inductor_compile_config: Dict = Field(default_factory=dict)
inductor_passes: Dict[str, str] = Field(default_factory=dict)
use_cudagraph: bool = False
non_cudagraph_ops: List[str] = Field(default_factory=list)
cudagraph_num_of_warmups: int = 0
cudagraph_capture_sizes: Optional[List[int]] = None
cudagraph_copy_inputs: bool = False
dump_graph_stages: List[str] = Field(default_factory=list)
dump_graph_dir: Path = Field(default=Path("."))
enable_fusion: bool = True
# not configurable, computed after init
compile_sizes: List[int] = PrivateAttr
capture_sizes: List[int] = PrivateAttr
def model_post_init(self, __context: Any) -> None:
for k, v in self.inductor_passes.items():
if not isinstance(v, str):
assert callable(v), (
f"pass {k} should be a function or a qualified name")
self.inductor_compile_config[k] = v
continue
# resolve function from qualified name
names = v.split(".")
module = ".".join(names[:-1])
func_name = names[-1]
func = __import__(module).__dict__[func_name]
self.inductor_compile_config[k] = func
def init_during_runtime(self):
"""To complete the initialization of config,
we need to know the compile context, which is only available
during the first run of the model.
"""
context = get_compile_context()
context = copy.deepcopy(context) if context is not None else []
sizes_to_specialize: List[int] = context
if self.cudagraph_capture_sizes is None:
self.capture_sizes = sizes_to_specialize
else:
self.capture_sizes = self.cudagraph_capture_sizes
logger.info(("cudagraph sizes specified by model runner"
" %s is overridden by config %s"),
sizes_to_specialize, self.cudagraph_capture_sizes)
if self.inductor_specialize_for_cudagraph_no_more_than is not None:
assert self.inductor_compile_sizes is None, (
"inductor_compile_sizes should be None when "
"inductor_specialize_for_cudagraph_no_more_than is not None")
self.compile_sizes = [
x for x in self.capture_sizes
if x <= self.inductor_specialize_for_cudagraph_no_more_than
]
else:
assert self.inductor_compile_sizes is not None, (
"inductor_compile_sizes should not be None when "
"inductor_specialize_for_cudagraph_no_more_than is None")
self.compile_sizes = self.inductor_compile_sizes
@staticmethod
def select_and_init_config() -> "CompilationConfig":
"""The order of selecting config is:
1. Use the config specified in environment variable.
2. Use the config specified in plugins.
3. Use the default config.
"""
config_path = envs.VLLM_TORCH_COMPILE_CONFIG
if config_path is not None:
with open(config_path) as json_file:
config = CompilationConfig.model_validate_json(
json_file.read())
else:
from vllm.plugins import get_compilation_config
predefined_config = get_compilation_config()
config = predefined_config if predefined_config is not None else (
CompilationConfig())
config.init_during_runtime()
return config

View File

@ -3,10 +3,8 @@ from typing import Dict, List, Optional, Union
import torch
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig
from vllm.config import CompilationLevel, VllmConfig
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo
@ -126,12 +124,14 @@ def _support_torch_compile(cls: type,
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [
self.do_not_compile = \
vllm_config.compilation_config.level in [
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo()
if self.do_not_compile:
return
TorchCompileWrapperWithCustomDispatcher.__init__(self)
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level)
cls.__init__ = __init__ # type: ignore
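Note: the decorator now reads the compilation level from vllm_config.compilation_config rather than VLLM_TORCH_COMPILE_LEVEL, and forwards that level to TorchCompileWrapperWithCustomDispatcher. A hedged sketch of the calling convention it assumes (MyModel is a hypothetical new-style model class; with the default level of NO_COMPILATION the decorator is a no-op):

import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.plugins import set_current_vllm_config

@support_torch_compile
class MyModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)

vllm_config = VllmConfig()  # level comes from the compilation config
with set_current_vllm_config(vllm_config):
    model = MyModel(vllm_config=vllm_config, prefix="")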

View File

@ -6,8 +6,8 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
fwd_only, register_replacement)
from vllm.compilation.config import CompilationConfig
from vllm.compilation.inductor_pass import InductorPass
from vllm.config import CompilationConfig
from vllm.logger import init_logger
logger = init_logger(__name__)

View File

@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
import torch
from vllm.compilation.config import CompilationConfig
from vllm.config import CompilationConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (

View File

@ -1,8 +0,0 @@
# constants for the levels of the compilation process
class CompilationLevel:
NO_COMPILATION = 0
DYNAMO_AS_IS = 1
DYNAMO_ONCE = 2
PIECEWISE = 3

View File

@ -8,8 +8,7 @@ from typing import Callable, List, Optional
import torch
import vllm.envs as envs
from .levels import CompilationLevel
from vllm.config import CompilationLevel
class TorchCompileWrapperWithCustomDispatcher:
@ -25,7 +24,9 @@ class TorchCompileWrapperWithCustomDispatcher:
`torch.compile` over the forward method.
"""
def __init__(self, compiled_callable: Optional[Callable] = None):
def __init__(self,
compiled_callable: Optional[Callable] = None,
compilation_level: int = 0):
if compiled_callable is None:
# default compilation settings
@ -38,7 +39,7 @@ class TorchCompileWrapperWithCustomDispatcher:
backend = get_torch_compile_backend()
if backend is None:
from vllm.compilation.backends import select_default_backend
backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL)
backend = select_default_backend(compilation_level)
compiled_callable = torch.compile(
self.forward,
@ -54,7 +55,7 @@ class TorchCompileWrapperWithCustomDispatcher:
# subclasses can use this to switch between the custom dispatcher
# and the default Dynamo guard mechanism.
self.use_custom_dispatcher: bool = \
envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE
compilation_level >= CompilationLevel.DYNAMO_ONCE
def __call__(self, *args, **kwargs):
"""Implement the dispatch logic here, beyond the torch.compile level.

View File

@ -3,10 +3,12 @@ import enum
import json
import warnings
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List,
Literal, Mapping, Optional, Set, Tuple, Type, Union)
import torch
from pydantic import BaseModel, Field, PrivateAttr
from transformers import PretrainedConfig
import vllm.envs as envs
@ -2052,6 +2054,185 @@ class ObservabilityConfig:
f"installed. Original error:\n{otel_import_error_traceback}")
class CompilationLevel:
# constants for the levels of the compilation process
NO_COMPILATION = 0
DYNAMO_AS_IS = 1
DYNAMO_ONCE = 2
PIECEWISE = 3
class CompilationConfig(BaseModel):
"""
Configuration for compilation.
It has three parts:
- Top-level Compilation control:
- level: the level of compilation.
- 0: no compilation.
- 1: dynamo as is.
- 2: dynamo once.
- 3: piecewise compilation.
- custom_ops: fine-grained control over which custom ops to enable/disable.
Use 'all' to enable all, 'none' to disable all.
Also specify a list of custom op names to enable (prefixed with a '+'),
or disable (prefixed with a '-').
Examples:
- 'all,-op1' to enable all except op1
- 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor
and disabled when running with Inductor (compile_level >= Inductor).
- CudaGraph capture:
- use_cudagraph: whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses.
Note that this is orthogonal to the cudagraph capture out
side of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future.
- cudagraph_capture_sizes: sizes to capture cudagraph.
- None: capture sizes are inferred from compilation context.
- List[int]: capture sizes are specified.
- cudagraph_num_of_warmups: number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
Only after that, the execution will be recorded, and the recorded
cudagraph will be used for subsequent runs.
- cudagraph_copy_inputs: whether to copy input tensors for
cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False.
- Inductor compilation:
- use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for different sizes specified
in inductor_compile_sizes, using configurations
in inductor_compile_config.
- inductor_compile_sizes: sizes to compile for inductor.
- inductor_specialize_for_cudagraph_no_more_than: an optional integer
to specialize inductor for cudagraph sizes no more than the
specified size. It is useful when we want to specialize inductor
with a subset of cudagraph sizes.
- inductor_compile_config: additional configurations for inductor.
- None: use default configurations.
- inductor_passes: additional passes for inductor. It is a dictionary
from pass name to pass function qualified name. We use function
name because the config uses json format. If we pass the config
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
- custom inductor passes:
- dump_graph_stages: list of stages for which we want to dump the graph.
Each pass defines its own stages (before, after, maybe in-between).
- dump_graph_dir: directory to dump the graph. Default is .
- enable_fusion: whether to enable the custom fusion pass.
TODO better pass enabling system.
Why we have different sizes for cudagraph and inductor:
- cudagraph: a cudagraph captured for a specific size can only be used
for the same size. We need to capture all the sizes we want to use.
- inductor: a graph compiled by inductor for a general shape can be used
for different sizes. Inductor can also compile for specific sizes,
where it can have more information to optimize the graph with fully
static shapes. However, we find the general shape compilation is
sufficient for most cases. It might be beneficial to compile for
certain small batchsizes, where inductor is good at optimizing.
""" # noqa
level: int = 0
custom_ops: List[str] = Field(default_factory=list)
use_inductor: bool = True
inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None
inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict)
inductor_compile_config: Dict = Field(default_factory=dict)
inductor_passes: Dict[str, str] = Field(default_factory=dict)
use_cudagraph: bool = False
non_cudagraph_ops: List[str] = Field(default_factory=list)
cudagraph_num_of_warmups: int = 0
cudagraph_capture_sizes: Optional[List[int]] = None
cudagraph_copy_inputs: bool = False
dump_graph_stages: List[str] = Field(default_factory=list)
dump_graph_dir: Path = Field(default=Path("."))
enable_fusion: bool = True
# not configurable, computed after init
compile_sizes: List[int] = PrivateAttr
capture_sizes: List[int] = PrivateAttr
def model_post_init(self, __context: Any) -> None:
self.level = envs.VLLM_TORCH_COMPILE_LEVEL
count_none = self.custom_ops.count("none")
count_all = self.custom_ops.count("all")
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
for k, v in self.inductor_passes.items():
if not isinstance(v, str):
assert callable(v), (
f"pass {k} should be a function or a qualified name")
self.inductor_compile_config[k] = v
continue
# resolve function from qualified name
names = v.split(".")
module = ".".join(names[:-1])
func_name = names[-1]
func = __import__(module).__dict__[func_name]
self.inductor_compile_config[k] = func
def init_during_runtime(self):
"""To complete the initialization of config,
we need to know the compile context, which is only available
during the first run of the model.
"""
from vllm.compilation.compile_context import get_compile_context
context = get_compile_context()
context = copy.deepcopy(context) if context is not None else []
sizes_to_specialize: List[int] = context
if self.cudagraph_capture_sizes is None:
self.capture_sizes = sizes_to_specialize
else:
self.capture_sizes = self.cudagraph_capture_sizes
logger.info(("cudagraph sizes specified by model runner"
" %s is overridden by config %s"),
sizes_to_specialize, self.cudagraph_capture_sizes)
if self.inductor_specialize_for_cudagraph_no_more_than is not None:
assert self.inductor_compile_sizes is None, (
"inductor_compile_sizes should be None when "
"inductor_specialize_for_cudagraph_no_more_than is not None")
self.compile_sizes = [
x for x in self.capture_sizes
if x <= self.inductor_specialize_for_cudagraph_no_more_than
]
else:
assert self.inductor_compile_sizes is not None, (
"inductor_compile_sizes should not be None when "
"inductor_specialize_for_cudagraph_no_more_than is None")
self.compile_sizes = self.inductor_compile_sizes
@staticmethod
def select_and_init_config() -> "CompilationConfig":
"""The order of selecting config is:
1. Use the config specified in environment variable.
2. Use the config specified in plugins.
3. Use the default config.
"""
config_path = envs.VLLM_TORCH_COMPILE_CONFIG
if config_path is not None:
with open(config_path) as json_file:
config = CompilationConfig.model_validate_json(
json_file.read())
else:
from vllm.plugins import get_compilation_config
predefined_config = get_compilation_config()
config = predefined_config if predefined_config is not None else (
CompilationConfig())
return config
@dataclass
class VllmConfig:
"""Dataclass which contains all vllm-related configuration. This
@ -2073,6 +2254,8 @@ class VllmConfig:
observability_config: Optional[ObservabilityConfig] = None
prompt_adapter_config: Optional[PromptAdapterConfig] = None
quant_config: Optional[QuantizationConfig] = None
compilation_config: CompilationConfig = field(default=None,
init=True) # type: ignore
@staticmethod
def _get_quantization_config(
@ -2133,6 +2316,12 @@ class VllmConfig:
self.quant_config = VllmConfig._get_quantization_config(
self.model_config, self.load_config)
if self.compilation_config is None:
self.compilation_config = CompilationConfig.select_and_init_config(
)
current_platform.check_and_update_config(self)
def __str__(self):
return ("model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "

View File

@ -69,7 +69,6 @@ if TYPE_CHECKING:
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None
VLLM_CUSTOM_OPS: List[str] = []
VLLM_DISABLED_KERNELS: List[str] = []
VLLM_USE_V1: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
@ -217,18 +216,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_TORCH_COMPILE_CONFIG":
lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None),
# Fine-grained control over which custom ops to enable/disable.
# Use 'all' to enable all, 'none' to disable all.
# Also specify a list of custom op names to enable (prefixed with a '+'),
# or disable (prefixed with a '-').
# Examples:
# - 'all,-op1' to enable all except op1
# - 'none,+op1,+op2' to enable only op1 and op2
# By default, all custom ops are enabled when running without Inductor
# and disabled when running with Inductor (compile_level >= Inductor).
"VLLM_CUSTOM_OPS":
lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK":

View File

@ -1,12 +1,10 @@
from functools import lru_cache
from typing import Dict, Type
import torch.nn as nn
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.plugins import get_current_vllm_config
from vllm.utils import print_warning_once
logger = init_logger(__name__)
@ -87,6 +85,8 @@ class CustomOp(nn.Module):
@classmethod
def enabled(cls) -> bool:
# if no name, then it was not registered
compilation_config = get_current_vllm_config().compilation_config
custom_ops = compilation_config.custom_ops
if not hasattr(cls, "name"):
print_warning_once(
f"Custom op {cls.__name__} was not registered, "
@ -94,22 +94,25 @@ class CustomOp(nn.Module):
f"It will be enabled/disabled based on the global settings.")
return CustomOp.default_on()
enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS
disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS
enabled = f"+{cls.name}" in custom_ops
disabled = f"-{cls.name}" in custom_ops
assert not (enabled
and disabled), f"Cannot enable and disable {cls.name}"
return (CustomOp.default_on() or enabled) and not disabled
# On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE
# Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
@staticmethod
@lru_cache
def default_on() -> bool:
count_none = envs.VLLM_CUSTOM_OPS.count("none")
count_all = envs.VLLM_CUSTOM_OPS.count("all")
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE and \
"""
On by default if level < CompilationLevel.PIECEWISE
Specifying 'all' or 'none' in custom_op takes precedence.
"""
from vllm.config import CompilationLevel
compilation_config = get_current_vllm_config().compilation_config
custom_ops = compilation_config.custom_ops
count_none = custom_ops.count("none")
count_all = custom_ops.count("all")
return compilation_config.level < CompilationLevel.PIECEWISE and \
not count_none > 0 or count_all > 0
# Dictionary of all custom ops (classes, indexed by registered name).
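Note: custom-op dispatch now keys off compilation_config.custom_ops of the current vllm config rather than the removed VLLM_CUSTOM_OPS variable. A minimal sketch of how the resolution behaves for a registered op (values are illustrative; default_on() is cached per process, so it is cleared when the config changes, as the test above does):

from vllm.config import CompilationConfig, VllmConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.plugins import set_current_vllm_config

vllm_config = VllmConfig(compilation_config=CompilationConfig(
    custom_ops=["none", "+rms_norm"]))
with set_current_vllm_config(vllm_config):
    CustomOp.default_on.cache_clear()
    # 'none' turns custom ops off by default; '+rms_norm' re-enables this one
    assert not CustomOp.default_on()
    assert RMSNorm(1024).enabled()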

View File

@ -42,6 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import (
safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.plugins import set_current_vllm_config
from vllm.utils import is_pin_memory_available
@ -97,7 +98,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class
return model_class(vllm_config=vllm_config, prefix=prefix)
with set_current_vllm_config(vllm_config):
return model_class(vllm_config=vllm_config, prefix=prefix)
msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. "
@ -121,7 +123,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
kwargs["lora_config"] = vllm_config.lora_config
if "scheduler_config" in all_params:
kwargs["scheduler_config"] = vllm_config.scheduler_config
return model_class(**kwargs)
with set_current_vllm_config(vllm_config):
return model_class(**kwargs)
class BaseModelLoader(ABC):

View File

@ -1,10 +1,15 @@
import enum
import random
from typing import NamedTuple, Optional, Tuple, Union
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
import numpy as np
import torch
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None
class PlatformEnum(enum.Enum):
CUDA = enum.auto()
@ -129,6 +134,19 @@ class Platform:
np.random.seed(seed)
torch.manual_seed(seed)
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
"""
Check and update the configuration for the current platform.
It can raise an exception if the configuration is not compatible with
the current platform, or it can update the configuration to make it
compatible with the current platform.
The config is passed by reference, so it can be modified in place.
"""
pass
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED

View File

@ -1,18 +1,16 @@
import os
from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.plugins import set_torch_compile_backend
from .interface import Platform, PlatformEnum
if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE)
assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE,\
"TPU does not support Inductor."
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None
set_torch_compile_backend("openxla")
@ -31,3 +29,12 @@ class TpuPlatform(Platform):
@classmethod
def inference_mode(cls):
return torch.no_grad()
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
from vllm.config import CompilationLevel
compilation_config = vllm_config.compilation_config
if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
compilation_config.level = CompilationLevel.DYNAMO_ONCE
assert compilation_config.level < CompilationLevel.PIECEWISE,\
"TPU does not support Inductor."

View File

@ -1,11 +1,11 @@
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Optional, Union
import vllm.envs as envs
if TYPE_CHECKING:
from vllm.compilation.config import CompilationConfig
from vllm.config import VllmConfig
from vllm.config import CompilationConfig, VllmConfig
else:
CompilationConfig = None
VllmConfig = None
@ -72,3 +72,29 @@ def set_compilation_config(config: Optional[CompilationConfig]):
def get_compilation_config() -> Optional[CompilationConfig]:
return _compilation_config
_current_vllm_config: Optional[VllmConfig] = None
@contextmanager
def set_current_vllm_config(vllm_config: VllmConfig):
"""
Temporarily set the current VLLM config.
Used during model initialization.
We save the current VLLM config in a global variable,
so that all modules can access it, e.g. custom ops
can access the VLLM config to determine how to dispatch.
"""
global _current_vllm_config
old_vllm_config = _current_vllm_config
try:
_current_vllm_config = vllm_config
yield
finally:
_current_vllm_config = old_vllm_config
def get_current_vllm_config() -> VllmConfig:
assert _current_vllm_config is not None, "Current VLLM config is not set."
return _current_vllm_config
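Note: the new context manager and getter form the hand-off between config construction and module initialization: set the config around anything that builds modules, and read it from module code that needs compilation settings. A small usage sketch:

from vllm.config import VllmConfig
from vllm.plugins import get_current_vllm_config, set_current_vllm_config

vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
    # visible to any module constructed inside the context
    level = get_current_vllm_config().compilation_config.level
# the previous config (or None) is restored on exit; calling
# get_current_vllm_config() with no active config raises an AssertionError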

View File

@ -1,4 +1,3 @@
import os
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
@ -8,11 +7,8 @@ import torch
import torch.distributed
import torch.nn as nn
from vllm import envs
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.config import CompilationConfig
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
@ -99,7 +95,7 @@ class GPUModelRunner:
pin_memory=self.pin_memory,
)
self.use_cuda_graph = (envs.VLLM_TORCH_COMPILE_LEVEL
self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
@ -517,9 +513,9 @@ class GPUModelRunner:
# CUDA graphs do not work properly with the custom CUDA kernels.
# FIXME(woosuk): Disable inductor to reduce the compilation time
# and avoid any potential issues with the inductor.
os.environ["VLLM_CUSTOM_OPS"] = "none"
set_compilation_config(
CompilationConfig(
custom_ops=["none"],
use_cudagraph=True,
non_cudagraph_ops=["vllm.unified_v1_flash_attention"],
use_inductor=True,
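Note: the V1 runner now expresses its piecewise-cudagraph setup as a CompilationConfig registered through set_compilation_config, instead of mutating VLLM_CUSTOM_OPS; CompilationConfig.select_and_init_config later picks it up (config file from VLLM_TORCH_COMPILE_CONFIG first, then the plugin-provided config, then defaults). A hedged sketch of that plugin path, reusing the values from the hunk above:

from vllm.config import CompilationConfig, VllmConfig
from vllm.plugins import set_compilation_config

set_compilation_config(
    CompilationConfig(
        custom_ops=["none"],
        use_cudagraph=True,
        non_cudagraph_ops=["vllm.unified_v1_flash_attention"],
        use_inductor=True,
    ))
# a VllmConfig created afterwards (with VLLM_TORCH_COMPILE_CONFIG unset)
# adopts this config in __post_init__ via select_and_init_config()
vllm_config = VllmConfig()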

View File

@ -19,8 +19,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.config import CompilationLevel, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
@ -1142,8 +1141,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \
and supports_dynamo():
if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend() or "eager"
self.model = torch.compile(

View File

@ -140,7 +140,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
model = get_model(vllm_config=self.vllm_config)
model = model.eval()
xm.wait_device_ops()
self.model = ModelWrapper(model)
self.model = ModelWrapper(model, self.vllm_config)
def _dummy_run(
self,
@ -669,13 +669,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
def __init__(self, model: nn.Module):
def __init__(self, model: nn.Module, vllm_config: VllmConfig):
self.model = model
compiled_callable = torch.compile(self.forward,
backend="openxla",
fullgraph=True,
dynamic=False)
super().__init__(compiled_callable)
super().__init__(
compiled_callable,
compilation_level=vllm_config.compilation_config.level)
def __call__(self, *args, is_prompt: bool, **kwargs):
if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: