[v1] fix compilation cache (#11598)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-12-30 12:24:12 +08:00 committed by GitHub
parent 0aa38d16f5
commit 3682e33f9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 69 additions and 14 deletions

View File

@ -7,7 +7,7 @@ if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, List, Optional, Tuple
import torch
from torch import nn
@ -54,6 +54,16 @@ class LlamaConfig:
tractable_init: bool = False
random_seed: int = 0
def compute_hash(self) -> str:
factors: List[Any] = []
for k, v in self.__dict__.items():
if k == "random_seed":
continue
factors.append((k, v))
factors.sort()
import hashlib
return hashlib.md5(str(factors).encode()).hexdigest()
def __post_init__(self):
assert self.mlp_size >= self.hidden_size
@ -263,7 +273,8 @@ def run_model(llama_config,
compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION, )
vllm_config = VllmConfig(compilation_config=compilation_config)
vllm_config = VllmConfig(compilation_config=compilation_config,
additional_config=llama_config)
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,

View File

@ -619,8 +619,10 @@ class PiecewiseBackend:
# the entries for different shapes that we need to either
# compile or capture cudagraph
self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
self.to_be_compiled_sizes: Set[int] = self.compile_sizes.union(
self.capture_sizes)
# to_be_compiled_sizes tracks the remaining sizes to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
for shape in self.compile_sizes.union(self.capture_sizes):
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape,
@ -628,12 +630,17 @@ class PiecewiseBackend:
use_cudagraph=shape in self.capture_sizes,
)
def check_for_ending_compilation(self):
if self.is_last_graph and not self.to_be_compiled_sizes:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self.compilation_config.inductor_hash_cache.save_to_file()
end_monitoring_torch_compile(self.vllm_config)
def __call__(self, *args) -> Any:
if not self.first_run_finished:
self.first_run_finished = True
# no specific sizes to compile
if self.is_last_graph and not self.to_be_compiled_sizes:
end_monitoring_torch_compile(self.vllm_config)
self.check_for_ending_compilation()
return self.compiled_graph_for_general_shape(*args)
runtime_shape = args[self.sym_shape_indices[0]]
@ -662,10 +669,7 @@ class PiecewiseBackend:
# finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes:
# save the hash of the inductor graph for the next run
self.compilation_config.inductor_hash_cache.save_to_file()
end_monitoring_torch_compile(self.vllm_config)
self.check_for_ending_compilation()
if not entry.use_cudagraph:
return entry.runnable(*args)

View File

@ -9,8 +9,8 @@ from contextlib import contextmanager
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
Final, List, Literal, Mapping, Optional, Set, Tuple, Type,
Union)
Final, List, Literal, Mapping, Optional, Protocol, Set,
Tuple, Type, Union)
import torch
from pydantic import BaseModel, Field, PrivateAttr
@ -75,6 +75,12 @@ HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
PretrainedConfig]]
class SupportsHash(Protocol):
def compute_hash(self) -> str:
...
class ModelConfig:
"""Configuration for the model.
@ -2969,6 +2975,10 @@ class VllmConfig:
init=True) # type: ignore
kv_transfer_config: KVTransferConfig = field(default=None,
init=True) # type: ignore
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing and debugging.
additional_config: SupportsHash = field(default=None,
init=True) # type: ignore
instance_id: str = ""
def compute_hash(self) -> str:
@ -3000,33 +3010,62 @@ class VllmConfig:
vllm_factors.append(__version__)
if self.model_config:
vllm_factors.append(self.model_config.compute_hash())
else:
vllm_factors.append("None")
if self.cache_config:
vllm_factors.append(self.cache_config.compute_hash())
else:
vllm_factors.append("None")
if self.parallel_config:
vllm_factors.append(self.parallel_config.compute_hash())
else:
vllm_factors.append("None")
if self.scheduler_config:
vllm_factors.append(self.scheduler_config.compute_hash())
else:
vllm_factors.append("None")
if self.device_config:
vllm_factors.append(self.device_config.compute_hash())
else:
vllm_factors.append("None")
if self.load_config:
vllm_factors.append(self.load_config.compute_hash())
else:
vllm_factors.append("None")
if self.lora_config:
vllm_factors.append(self.lora_config.compute_hash())
else:
vllm_factors.append("None")
if self.speculative_config:
vllm_factors.append(self.speculative_config.compute_hash())
else:
vllm_factors.append("None")
if self.decoding_config:
vllm_factors.append(self.decoding_config.compute_hash())
else:
vllm_factors.append("None")
if self.observability_config:
vllm_factors.append(self.observability_config.compute_hash())
else:
vllm_factors.append("None")
if self.prompt_adapter_config:
vllm_factors.append(self.prompt_adapter_config.compute_hash())
else:
vllm_factors.append("None")
if self.quant_config:
pass # should be captured by model_config.quantization
if self.compilation_config:
vllm_factors.append(self.compilation_config.compute_hash())
else:
vllm_factors.append("None")
if self.kv_transfer_config:
vllm_factors.append(self.kv_transfer_config.compute_hash())
else:
vllm_factors.append("None")
if self.additional_config:
vllm_factors.append(self.additional_config.compute_hash())
else:
vllm_factors.append("None")
factors.append(vllm_factors)
hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]

View File

@ -48,6 +48,7 @@ class Worker:
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method