[v1] fix compilation cache (#11598)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Authored by youkaichao on 2024-12-30 12:24:12 +08:00; committed by GitHub.
parent 0aa38d16f5
commit 3682e33f9f
4 changed files with 69 additions and 14 deletions

The first file touched is the toy Llama model used by the piecewise compilation tests; its config gains a compute_hash() method and is passed to VllmConfig as additional_config:

@@ -7,7 +7,7 @@ if the config `tractable_init` is set to True. Otherwise, the weights are
 initialized randomly with a fixed seed.
 """
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, List, Optional, Tuple

 import torch
 from torch import nn
@@ -54,6 +54,16 @@ class LlamaConfig:
     tractable_init: bool = False
     random_seed: int = 0

+    def compute_hash(self) -> str:
+        factors: List[Any] = []
+        for k, v in self.__dict__.items():
+            if k == "random_seed":
+                continue
+            factors.append((k, v))
+        factors.sort()
+        import hashlib
+        return hashlib.md5(str(factors).encode()).hexdigest()
+
     def __post_init__(self):
         assert self.mlp_size >= self.hidden_size
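The intent of excluding random_seed: runs that differ only in the seed should map to the same compilation cache entry, while any structural change should not. A minimal sketch of that behavior (it assumes the dataclass defaults in this test file satisfy the __post_init__ assertions):

```python
# Sketch only: the hash ignores random_seed but reacts to structural fields.
# Assumes LlamaConfig with the default field values from this test file.
from dataclasses import replace

base = LlamaConfig()
reseeded = replace(base, random_seed=42)
wider = replace(base, mlp_size=base.mlp_size * 2)

assert base.compute_hash() == reseeded.compute_hash()  # seed is excluded
assert base.compute_hash() != wider.compute_hash()     # structure changes the hash
```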
@@ -263,7 +273,8 @@ def run_model(llama_config,
         compilation_config = CompilationConfig(
             level=CompilationLevel.NO_COMPILATION, )

-    vllm_config = VllmConfig(compilation_config=compilation_config)
+    vllm_config = VllmConfig(compilation_config=compilation_config,
+                             additional_config=llama_config)
     with set_current_vllm_config(vllm_config):
         model = LlamaModel(config=llama_config,
                            vllm_config=vllm_config,

The second file is the piecewise compilation backend, where the cache-saving condition is fixed and factored into check_for_ending_compilation():

@@ -619,8 +619,10 @@ class PiecewiseBackend:
         # the entries for different shapes that we need to either
         # compile or capture cudagraph
         self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
-        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.union(
-            self.capture_sizes)
+
+        # to_be_compiled_sizes tracks the remaining sizes to compile,
+        # and updates during the compilation process, so we need to copy it
+        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
         for shape in self.compile_sizes.union(self.capture_sizes):
             self.concrete_size_entries[shape] = ConcreteSizeEntry(
                 runtime_shape=shape,
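The old seed of to_be_compiled_sizes was compile_sizes.union(capture_sizes). Whenever capture_sizes contains sizes that are only cudagraph-captured and never compiled, that set can never become empty, so the inductor hash cache was never flushed to disk. The new code starts from compile_sizes alone, and copies it because entries are discarded as each shape finishes compiling. A standalone sketch of the bookkeeping pattern (illustrative names, not the vLLM API):

```python
# Sketch of the remaining-work bookkeeping under the stated assumptions.
compile_sizes = {1, 2, 4}
capture_sizes = {1, 2, 4, 8}           # 8 is cudagraph-only, never compiled

to_be_compiled = compile_sizes.copy()  # copy: this set is mutated below

for size in sorted(compile_sizes):     # pretend each size gets compiled
    to_be_compiled.discard(size)

# With the old `compile_sizes.union(capture_sizes)` seeding, 8 would still
# be in the set here and the cache would never be saved.
if not to_be_compiled:
    print("all required shapes compiled -> save inductor hash cache")
```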
@@ -628,12 +630,17 @@ class PiecewiseBackend:
                 use_cudagraph=shape in self.capture_sizes,
             )

+    def check_for_ending_compilation(self):
+        if self.is_last_graph and not self.to_be_compiled_sizes:
+            # no specific sizes to compile
+            # save the hash of the inductor graph for the next run
+            self.compilation_config.inductor_hash_cache.save_to_file()
+            end_monitoring_torch_compile(self.vllm_config)
+
     def __call__(self, *args) -> Any:
         if not self.first_run_finished:
             self.first_run_finished = True
-            # no specific sizes to compile
-            if self.is_last_graph and not self.to_be_compiled_sizes:
-                end_monitoring_torch_compile(self.vllm_config)
+            self.check_for_ending_compilation()
             return self.compiled_graph_for_general_shape(*args)

         runtime_shape = args[self.sym_shape_indices[0]]
@@ -662,10 +669,7 @@ class PiecewiseBackend:

             # finished compilations for all required shapes
             if self.is_last_graph and not self.to_be_compiled_sizes:
-
-                # save the hash of the inductor graph for the next run
-                self.compilation_config.inductor_hash_cache.save_to_file()
-                end_monitoring_torch_compile(self.vllm_config)
+                self.check_for_ending_compilation()

         if not entry.use_cudagraph:
             return entry.runnable(*args)
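The cache-saving logic now lives in one place, check_for_ending_compilation(), and is also invoked on the very first call, so the inductor hash cache file gets written even when there are no shape-specific compilations pending; previously the first-run path only stopped the compile monitoring and never saved the cache. A minimal control-flow sketch (illustrative class, not the real PiecewiseBackend):

```python
# Sketch: both exit paths, the first general-shape run and the last
# per-shape compilation, go through the same "finished?" check.
class TinyBackend:
    def __init__(self, compile_sizes, is_last_graph=True):
        self.to_be_compiled = set(compile_sizes)
        self.is_last_graph = is_last_graph
        self.first_run_finished = False
        self.cache_saved = False

    def check_for_ending_compilation(self):
        if self.is_last_graph and not self.to_be_compiled:
            self.cache_saved = True  # stands in for save_to_file()

    def __call__(self, shape):
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return "general-shape graph"
        self.to_be_compiled.discard(shape)
        self.check_for_ending_compilation()
        return f"compiled graph for {shape}"

backend = TinyBackend(compile_sizes=set())  # nothing to compile eagerly
backend(16)
assert backend.cache_saved  # the cache is now saved on the first run, too
```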

The third file is vLLM's config module, which gains a SupportsHash protocol, the additional_config field, and explicit "None" placeholders in the config hash:

@@ -9,8 +9,8 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field, replace
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
-                    Final, List, Literal, Mapping, Optional, Set, Tuple, Type,
-                    Union)
+                    Final, List, Literal, Mapping, Optional, Protocol, Set,
+                    Tuple, Type, Union)

 import torch
 from pydantic import BaseModel, Field, PrivateAttr
@@ -75,6 +75,12 @@ HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
                                              PretrainedConfig]]


+class SupportsHash(Protocol):
+
+    def compute_hash(self) -> str:
+        ...
+
+
 class ModelConfig:
     """Configuration for the model.
@@ -2969,6 +2975,10 @@ class VllmConfig:
                                               init=True)  # type: ignore
     kv_transfer_config: KVTransferConfig = field(default=None,
                                                  init=True)  # type: ignore
+    # some opaque config, only used to provide additional information
+    # for the hash computation, mainly used for testing and debugging.
+    additional_config: SupportsHash = field(default=None,
+                                            init=True)  # type: ignore

     instance_id: str = ""

     def compute_hash(self) -> str:
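SupportsHash is a typing.Protocol, so anything exposing compute_hash() -> str can be handed to additional_config without subclassing; that is exactly how the toy LlamaConfig above plugs into the cache key. A sketch of the structural typing (DebugConfig is a made-up example; only SupportsHash mirrors this PR):

```python
import hashlib
from dataclasses import dataclass
from typing import Any, List, Protocol


class SupportsHash(Protocol):

    def compute_hash(self) -> str:
        ...


@dataclass
class DebugConfig:
    note: str = "experiment-1"

    def compute_hash(self) -> str:
        factors: List[Any] = [("note", self.note)]
        return hashlib.md5(str(factors).encode()).hexdigest()


def describe(cfg: SupportsHash) -> str:
    # DebugConfig matches the protocol structurally; no inheritance needed.
    return cfg.compute_hash()


print(describe(DebugConfig()))
```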
@@ -3000,33 +3010,62 @@ class VllmConfig:
         vllm_factors.append(__version__)
         if self.model_config:
             vllm_factors.append(self.model_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.cache_config:
             vllm_factors.append(self.cache_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.parallel_config:
             vllm_factors.append(self.parallel_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.scheduler_config:
             vllm_factors.append(self.scheduler_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.device_config:
             vllm_factors.append(self.device_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.load_config:
             vllm_factors.append(self.load_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.lora_config:
             vllm_factors.append(self.lora_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.speculative_config:
             vllm_factors.append(self.speculative_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.decoding_config:
             vllm_factors.append(self.decoding_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.observability_config:
             vllm_factors.append(self.observability_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.prompt_adapter_config:
             vllm_factors.append(self.prompt_adapter_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.quant_config:
             pass  # should be captured by model_config.quantization
         if self.compilation_config:
             vllm_factors.append(self.compilation_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.kv_transfer_config:
             vllm_factors.append(self.kv_transfer_config.compute_hash())
+        else:
+            vllm_factors.append("None")
+        if self.additional_config:
+            vllm_factors.append(self.additional_config.compute_hash())
+        else:
+            vllm_factors.append("None")

         factors.append(vllm_factors)

         hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
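The new else branches presumably exist to keep every factor at a fixed position in vllm_factors: without the "None" placeholders, different combinations of present and absent sub-configs could serialize to the same factor list and therefore the same cache hash. A tiny worked example in plain Python (not the vLLM objects):

```python
# Why positional "None" placeholders matter for the hash input.
model_hash, cache_hash = "abc", "abc"

# Without placeholders, "model set / cache unset" and
# "model unset / cache set" collapse to the same factor list:
without_a = [model_hash]        # model_config present, cache_config absent
without_b = [cache_hash]        # model_config absent, cache_config present
assert without_a == without_b   # collision

# With placeholders, every slot is always filled, so they stay distinct:
with_a = [model_hash, "None"]
with_b = ["None", cache_hash]
assert with_a != with_b
```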

The last file is the Worker constructor, which now propagates the rank into parallel_config:

@@ -48,6 +48,7 @@ class Worker:
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config

+        self.parallel_config.rank = rank
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method