[torch.compile] allow candidate compile sizes (#10984)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-12-08 03:05:21 -08:00 committed by GitHub
parent 7be15d9356
commit fd57d2b534
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 28 additions and 35 deletions

View File

@@ -50,12 +50,12 @@ def test_compilation_config():
args = parser.parse_args(["-O=3"]) args = parser.parse_args(["-O=3"])
assert args.compilation_config.level == 3 assert args.compilation_config.level == 3
# set to json # set to string form of a dict
args = parser.parse_args(["--compilation-config", '{"level": 3}']) args = parser.parse_args(["--compilation-config", "{'level': 3}"])
assert args.compilation_config.level == 3 assert args.compilation_config.level == 3
# set to json # set to string form of a dict
args = parser.parse_args(['--compilation-config={"level": 3}']) args = parser.parse_args(["--compilation-config={'level': 3}"])
assert args.compilation_config.level == 3 assert args.compilation_config.level == 3

View File

@@ -1,3 +1,4 @@
import ast
import copy import copy
import enum import enum
import hashlib import hashlib
@@ -2191,14 +2192,10 @@ class CompilationConfig(BaseModel):
- use_inductor: whether to use inductor compilation. - use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager. - False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape - True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for different sizes specified is compiled. In addition, compile for cudagraph sizes that are
in inductor_compile_sizes, using configurations in candidate_compile_sizes, using configurations
in inductor_compile_config. in inductor_compile_config.
- inductor_compile_sizes: sizes to compile for inductor. - candidate_compile_sizes: sizes to compile for inductor.
- inductor_specialize_for_cudagraph_no_more_than: an optional integer
to specialize inductor for cudagraph sizes no more than the
specified size. It is useful when we want to specialize inductor
with a subset of cudagraph sizes.
- inductor_compile_config: additional configurations for inductor. - inductor_compile_config: additional configurations for inductor.
- None: use default configurations. - None: use default configurations.
- inductor_passes: additional passes for inductor. It is a dictionary - inductor_passes: additional passes for inductor. It is a dictionary
@@ -2227,8 +2224,7 @@ class CompilationConfig(BaseModel):
]) ])
use_inductor: bool = True use_inductor: bool = True
inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None candidate_compile_sizes: Optional[List[int]] = Field(default=None)
inductor_compile_sizes: Optional[List[int]] = Field(default=None)
inductor_compile_config: Dict = Field(default_factory=dict) inductor_compile_config: Dict = Field(default_factory=dict)
inductor_passes: Dict[str, str] = Field(default_factory=dict) inductor_passes: Dict[str, str] = Field(default_factory=dict)
@@ -2294,7 +2290,9 @@ class CompilationConfig(BaseModel):
"""Parse the CLI value for the compilation config.""" """Parse the CLI value for the compilation config."""
if cli_value in ["0", "1", "2", "3"]: if cli_value in ["0", "1", "2", "3"]:
return cls(level=int(cli_value)) return cls(level=int(cli_value))
return CompilationConfig.model_validate_json(cli_value) # do not use `eval`, it is dangerous and can execute arbitrary code
dict_value = ast.literal_eval(cli_value)
return CompilationConfig.model_validate(dict_value)
def model_post_init(self, __context: Any) -> None: def model_post_init(self, __context: Any) -> None:
@@ -2355,18 +2353,20 @@ class CompilationConfig(BaseModel):
logger.info(("cudagraph sizes specified by model runner" logger.info(("cudagraph sizes specified by model runner"
" %s is overridden by config %s"), " %s is overridden by config %s"),
sizes_to_specialize, self.cudagraph_capture_sizes) sizes_to_specialize, self.cudagraph_capture_sizes)
if self.inductor_specialize_for_cudagraph_no_more_than is not None:
assert self.inductor_compile_sizes is None, ( if self.candidate_compile_sizes is None:
"inductor_compile_sizes should be None when " self.candidate_compile_sizes = []
"inductor_specialize_for_cudagraph_no_more_than is not None")
self.compile_sizes = [ self.compile_sizes = [
x for x in self.capture_sizes x for x in self.candidate_compile_sizes if x in self.capture_sizes
if x <= self.inductor_specialize_for_cudagraph_no_more_than
] ]
else: ignored_sizes = [
if self.inductor_compile_sizes is None: x for x in self.candidate_compile_sizes
self.inductor_compile_sizes = [] if x not in self.capture_sizes
self.compile_sizes = self.inductor_compile_sizes ]
if ignored_sizes:
logger.warning(("candidate_compile_sizes %s are ignored "
"because they are not cudagraph capture sizes."),
ignored_sizes)
# sort to make sure cudagraph capture sizes are in descending order # sort to make sure cudagraph capture sizes are in descending order
self.capture_sizes.sort(reverse=True) self.capture_sizes.sort(reverse=True)

View File

@@ -209,12 +209,9 @@ class EngineArgs:
# support `EngineArgs(compilation_config={...})` # support `EngineArgs(compilation_config={...})`
# without having to manually construct a # without having to manually construct a
# CompilationConfig object # CompilationConfig object
if isinstance(self.compilation_config, (int)): if isinstance(self.compilation_config, (int, dict)):
self.compilation_config = CompilationConfig.from_cli( self.compilation_config = CompilationConfig.from_cli(
str(self.compilation_config)) str(self.compilation_config))
elif isinstance(self.compilation_config, (dict)):
self.compilation_config = CompilationConfig.from_cli(
json.dumps(self.compilation_config))
# Setup plugins # Setup plugins
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins

View File

@@ -1,5 +1,4 @@
import itertools import itertools
import json
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
@@ -186,12 +185,9 @@ class LLM:
kwargs["disable_log_stats"] = True kwargs["disable_log_stats"] = True
if compilation_config is not None: if compilation_config is not None:
if isinstance(compilation_config, (int)): if isinstance(compilation_config, (int, dict)):
compilation_config_instance = CompilationConfig.from_cli( compilation_config_instance = CompilationConfig.from_cli(
str(compilation_config)) str(compilation_config))
elif isinstance(compilation_config, (dict)):
compilation_config_instance = CompilationConfig.from_cli(
json.dumps(compilation_config))
else: else:
compilation_config_instance = compilation_config compilation_config_instance = compilation_config
else: else: