[torch.compile] store inductor compiled Python file (#12182)
Signed-off-by: youkaichao <youkaichao@gmail.com>
commit e66faf4809 (parent 630eb5b5ce)
@@ -25,23 +25,30 @@ from .pass_manager import PostGradPassManager
 logger = init_logger(__name__)
 
 
+@dataclasses.dataclass
+class InductorArtifact:
+    hash_str: str = ""
+    file_path: str = ""
+
+
 class InductorHashCache:
     """
     Disk format: a Python list of tuples, each tuple is
-    (runtime_shape, graph_index, hash_str)
+    (runtime_shape, graph_index, hash_str, file_path)
     We use list of tuple for readability.
 
     In-memory format: a defaultdict of dict, where the key is
     runtime_shape, and the value is a dict of graph_index to hash_str.
 
-    The data is essentially `Dict[Optional[int], Dict[int, str]]`,
+    The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`,
    we don't use json here because json doesn't support int as key.
 
     TODO: better off-the-shelf solution to serialize the data?
     """
 
     def __init__(self, cache_dir: str, disabled: bool = False):
-        self.cache: defaultdict = defaultdict(dict)
+        self.cache: Dict[Optional[int],
+                         Dict[int, InductorArtifact]] = defaultdict(dict)
         self.disabled = disabled
         self.cache_dir = cache_dir
         self.cache_file_path = os.path.join(cache_dir,
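Note on the serialization choice the docstring describes: the cache is written as a pretty-printed Python literal and read back with `ast.literal_eval`, because JSON cannot represent `Optional[int]` dict keys. A minimal round-trip sketch (the entries are made up for illustration):

```python
import ast
import pprint

# Illustrative cache entries in the new on-disk format:
# (runtime_shape, graph_index, hash_str, file_path).
# runtime_shape is None for the general dynamic-shape graph.
entries = [
    (None, 0, "hash-aaa", "/tmp/torchinductor_cache/aaa.py"),
    (16, 0, "hash-bbb", "/tmp/torchinductor_cache/bbb.py"),
]

# Serialize: a human-readable Python literal.
text = pprint.PrettyPrinter(indent=4).pformat(entries)

# Deserialize: literal_eval parses literals only; unlike eval(), it
# cannot execute arbitrary code from a tampered cache file.
assert ast.literal_eval(text) == entries
```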
@@ -66,14 +73,25 @@ class InductorHashCache:
         # because it is a safe way to parse Python literals.
         # do not use eval(), it is unsafe.
         list_data = ast.literal_eval(data)
-        for runtime_shape, graph_index, hash_str in list_data:
-            self.cache[runtime_shape][graph_index] = hash_str
+        for item in list_data:
+            runtime_shape = item[0]
+            graph_index = item[1]
+            hash_str = item[2]
+            # for compatibility of old version,
+            # where we don't have file_path.
+            # NOTE: after running the new code, the file_path
+            # will be updated.
+            file_path = "" if len(item) == 3 else item[3]
+            self.cache[runtime_shape][graph_index] = InductorArtifact(
+                hash_str=hash_str, file_path=file_path)
 
     def serialize(self) -> str:
         data = []
-        for runtime_shape, graph_index_to_hash_str in self.cache.items():
-            for graph_index, hash_str in graph_index_to_hash_str.items():
-                data.append((runtime_shape, graph_index, hash_str))
+        for runtime_shape, value in self.cache.items():
+            for graph_index, inductor_artifact in value.items():
+                data.append(
+                    (runtime_shape, graph_index, inductor_artifact.hash_str,
+                     inductor_artifact.file_path))
         printer = pprint.PrettyPrinter(indent=4)
         return printer.pformat(data)
 
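`deserialize` now indexes into each tuple instead of unpacking it, so cache files written before this change (3-tuples without `file_path`) still load; the missing path gets filled in on a later run, as the NOTE above says. A quick standalone demonstration of that compatibility branch:

```python
# A hypothetical mixed-version cache: one old 3-tuple entry,
# one new 4-tuple entry.
list_data = [
    (None, 0, "hash-old"),
    (16, 0, "hash-new", "/tmp/torchinductor_cache/new.py"),
]

for item in list_data:
    hash_str = item[2]
    # old-format entries simply get an empty file_path
    file_path = "" if len(item) == 3 else item[3]
    print(item[0], item[1], hash_str, repr(file_path))
# None 0 hash-old ''
# 16 0 hash-new '/tmp/torchinductor_cache/new.py'
```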
@@ -90,13 +108,14 @@ class InductorHashCache:
         return runtime_shape in self.cache and graph_index in self.cache[
             runtime_shape]
 
-    def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
+    def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact:
         if self.disabled:
             raise KeyError("cannot read from disabled cache")
         runtime_shape, graph_index = key
         return self.cache[runtime_shape][graph_index]
 
-    def __setitem__(self, key: Tuple[Optional[int], int], value: str):
+    def __setitem__(self, key: Tuple[Optional[int], int],
+                    value: InductorArtifact):
         # setitem for disabled cache is fine, because we
         # don't actually write to the disk
         runtime_shape, graph_index = key
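With these two methods, the cache behaves like a dict keyed by `(runtime_shape, graph_index)` tuples. A hedged usage sketch, assuming a fresh writable directory so the constructor starts from an empty cache:

```python
import tempfile

cache = InductorHashCache(cache_dir=tempfile.mkdtemp())
key = (None, 0)  # general-shape run, first piecewise graph
cache[key] = InductorArtifact(hash_str="hash-aaa",
                              file_path="/tmp/torchinductor_cache/aaa.py")
assert key in cache
assert cache[key].hash_str == "hash-aaa"
```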
@@ -181,7 +200,8 @@ def wrap_inductor(graph: fx.GraphModule,
     if (runtime_shape, graph_index) in cache_data:
         # we compiled this graph before
         # so we can directly lookup the compiled graph via hash
-        hash_str = cache_data[(runtime_shape, graph_index)]
+        inductor_artifact = cache_data[(runtime_shape, graph_index)]
+        hash_str = inductor_artifact.hash_str
         if graph_index == 0:
             # adds some info logging for the first graph
             logger.info(
@@ -199,6 +219,7 @@ def wrap_inductor(graph: fx.GraphModule,
                 "Inductor cache lookup failed. Please remove"
                 f"the cache file {cache_data.cache_file_path} and try again."  # noqa
             )
+        inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
 
         # Inductor calling convention (function signature):
         # f(list) -> tuple
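This added line is the core trick of the PR: Inductor's `current_callable` is an ordinary Python function loaded from a generated source file, so `__code__.co_filename` yields the path of that compiled Python file. Refreshing `file_path` on every cache hit is also what upgrades old 3-tuple entries. The introspection itself works on any function:

```python
def f():
    pass

# For a function Inductor loaded from its generated source, this is the
# path of the generated .py file; for this toy function it is simply the
# path of the current script.
print(f.__code__.co_filename)
```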
@@ -224,19 +245,20 @@ def wrap_inductor(graph: fx.GraphModule,
         # the assumption is that we don't have nested Inductor compilation.
         # compiled_fx_graph_hash will only be called once, and we can hook
         # it to get the hash of the compiled graph directly.
-        from torch._inductor.codecache import compiled_fx_graph_hash
+
+        inductor_artifact = InductorArtifact()
+        from torch._inductor.codecache import (FxGraphCache,
+                                               compiled_fx_graph_hash)
+        original_load = FxGraphCache.load
+
+        def hijack_load(*args, **kwargs):
+            inductor_compiled_graph = original_load(*args, **kwargs)
+            inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
+            return inductor_compiled_graph
 
         def hijack_compiled_fx_graph_hash(*args, **kwargs):
             out = compiled_fx_graph_hash(*args, **kwargs)
-            # store the hash in the cache
-            nonlocal cache_data
-            cache_data[(runtime_shape, graph_index)] = out[0]
-            if graph_index == 0:
-                # adds some info logging for the first graph
-                logger.info("Cache the graph of shape %s for later use",
-                            str(runtime_shape))
-            logger.debug("store the %s-th graph for shape %s via hash %s",
-                         graph_index, str(runtime_shape), out[0])
+            inductor_artifact.hash_str = out[0]
             return out
 
         def _check_can_cache(*args, **kwargs):
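`hijack_load` follows a plain wrap-and-delegate pattern: keep a handle to the original `FxGraphCache.load`, forward the call, record a side effect on the enclosing `inductor_artifact`, and return the result untouched. A self-contained sketch of the same pattern against a stand-in class (all names here are hypothetical):

```python
from contextlib import ExitStack
from unittest.mock import patch

class Loader:
    """Stand-in for torch._inductor.codecache.FxGraphCache."""

    @staticmethod
    def load():
        return "compiled-graph"

original_load = Loader.load  # captured before patching

def hijack_load(*args, **kwargs):
    out = original_load(*args, **kwargs)
    print("observed:", out)  # side effect, then pass through unchanged
    return out

with ExitStack() as stack:
    # patch swaps the attribute only while the stack is open; the real
    # code enters several such patches on one ExitStack.
    stack.enter_context(patch.object(Loader, "load", hijack_load))
    assert Loader.load() == "compiled-graph"  # routed via hijack_load
```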
@@ -255,6 +277,11 @@ def wrap_inductor(graph: fx.GraphModule,
         if not cache_data.disabled:
             # compilation cache is enabled, patch several functions
 
+            # hijack to get the compiled graph itself
+            stack.enter_context(
+                patch("torch._inductor.codecache.FxGraphCache.load",
+                      hijack_load))
+
             # for hijacking the hash of the compiled graph
             stack.enter_context(
                 patch("torch._inductor.codecache.compiled_fx_graph_hash",
@@ -275,7 +302,16 @@ def wrap_inductor(graph: fx.GraphModule,
             compiled_graph = compile_fx(graph,
                                         example_inputs,
                                         config_patches=current_config)
+        # store the inductor_artifact in the cache
+        cache_data[(runtime_shape, graph_index)] = inductor_artifact
+        if graph_index == 0:
+            # adds some info logging for the first graph
+            logger.info("Cache the graph of shape %s for later use",
+                        str(runtime_shape))
+        logger.debug(
+            "store the %s-th graph for shape %s via hash %s from file %s",
+            graph_index, str(runtime_shape), inductor_artifact.hash_str,
+            inductor_artifact.file_path)
         # after compiling the last graph, record the end time
         if graph_index == num_graphs - 1:
             now = time.time()
@@ -2862,17 +2862,8 @@ class CompilationConfig(BaseModel):
                 "vllm.unified_attention_with_output",
             ]
         else:
-            # v0 can use full graph compilation without splitting,
-            # splitting is optional.
-            # right now we still need it. kv cache shape
-            # will be included in the graph if we don't split
-            # the graph.
-            # TODO: hide kv cache in static forward context
-            # so that inductor does not see it.
-            self.splitting_ops = [
-                "vllm.unified_attention",
-                "vllm.unified_attention_with_output",
-            ]
+            # v0 uses full graph compilation
+            self.splitting_ops = []
 
         for k, v in self.inductor_passes.items():
             if not isinstance(v, str):