[v1] fix compilation cache (#11598)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 0aa38d16f5
commit 3682e33f9f
@@ -7,7 +7,7 @@ if the config `tractable_init` is set to True. Otherwise, the weights are
 initialized randomly with a fixed seed.
 """
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -54,6 +54,16 @@ class LlamaConfig:
     tractable_init: bool = False
     random_seed: int = 0
 
+    def compute_hash(self) -> str:
+        factors: List[Any] = []
+        for k, v in self.__dict__.items():
+            if k == "random_seed":
+                continue
+            factors.append((k, v))
+        factors.sort()
+        import hashlib
+        return hashlib.md5(str(factors).encode()).hexdigest()
+
     def __post_init__(self):
         assert self.mlp_size >= self.hidden_size
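The intent of skipping `random_seed` is that two toy configs differing only in their seed hash identically and can share a compiled artifact, while any structural change produces a new hash. A minimal standalone sketch of the same pattern (the `ToyConfig` class and its fields are illustrative, not part of this diff):

import hashlib
from dataclasses import dataclass
from typing import Any, List


@dataclass
class ToyConfig:
    hidden_size: int = 128
    num_layers: int = 2
    random_seed: int = 0  # deliberately excluded from the hash

    def compute_hash(self) -> str:
        # every field except the seed, sorted for a stable ordering
        factors: List[Any] = [(k, v) for k, v in self.__dict__.items()
                              if k != "random_seed"]
        factors.sort()
        return hashlib.md5(str(factors).encode()).hexdigest()


# same architecture, different seed -> same hash (cache entry can be reused)
assert ToyConfig(random_seed=0).compute_hash() == ToyConfig(random_seed=1).compute_hash()
# different architecture -> different hash (cache entry is invalidated)
assert ToyConfig(hidden_size=256).compute_hash() != ToyConfig().compute_hash()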
@@ -263,7 +273,8 @@ def run_model(llama_config,
     compilation_config = CompilationConfig(
         level=CompilationLevel.NO_COMPILATION, )
 
-    vllm_config = VllmConfig(compilation_config=compilation_config)
+    vllm_config = VllmConfig(compilation_config=compilation_config,
+                             additional_config=llama_config)
     with set_current_vllm_config(vllm_config):
         model = LlamaModel(config=llama_config,
                            vllm_config=vllm_config,
@@ -619,8 +619,10 @@ class PiecewiseBackend:
         # the entries for different shapes that we need to either
         # compile or capture cudagraph
         self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
-        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.union(
-            self.capture_sizes)
+
+        # to_be_compiled_sizes tracks the remaining sizes to compile,
+        # and updates during the compilation process, so we need to copy it
+        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
         for shape in self.compile_sizes.union(self.capture_sizes):
             self.concrete_size_entries[shape] = ConcreteSizeEntry(
                 runtime_shape=shape,
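The new comment spells out why a copy is needed: to_be_compiled_sizes is drained as shapes finish compiling, and draining an aliased set would silently mutate compile_sizes as well. A toy illustration of the aliasing hazard (values are arbitrary, not the vLLM sets):

compile_sizes = {1, 2, 4, 8}

aliased = compile_sizes           # same object: removing from it would also shrink compile_sizes
drained = compile_sizes.copy()    # independent snapshot, safe to drain

drained.remove(1)
assert compile_sizes == {1, 2, 4, 8}   # original set is untouched
assert aliased is compile_sizes        # the alias, by contrast, is the very same set object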
@@ -628,12 +630,17 @@ class PiecewiseBackend:
                 use_cudagraph=shape in self.capture_sizes,
             )
 
+    def check_for_ending_compilation(self):
+        if self.is_last_graph and not self.to_be_compiled_sizes:
+            # no specific sizes to compile
+            # save the hash of the inductor graph for the next run
+            self.compilation_config.inductor_hash_cache.save_to_file()
+            end_monitoring_torch_compile(self.vllm_config)
+
     def __call__(self, *args) -> Any:
         if not self.first_run_finished:
             self.first_run_finished = True
-            # no specific sizes to compile
-            if self.is_last_graph and not self.to_be_compiled_sizes:
-                end_monitoring_torch_compile(self.vllm_config)
+            self.check_for_ending_compilation()
             return self.compiled_graph_for_general_shape(*args)
 
         runtime_shape = args[self.sym_shape_indices[0]]
@@ -662,10 +669,7 @@ class PiecewiseBackend:
 
             # finished compilations for all required shapes
             if self.is_last_graph and not self.to_be_compiled_sizes:
-                # save the hash of the inductor graph for the next run
-                self.compilation_config.inductor_hash_cache.save_to_file()
-                end_monitoring_torch_compile(self.vllm_config)
+                self.check_for_ending_compilation()
 
         if not entry.use_cudagraph:
             return entry.runnable(*args)
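Factoring the check into check_for_ending_compilation() means both call sites, the first-run fast path (no specific sizes to compile) and the path taken after the last required shape finishes compiling, now persist the inductor hash cache before monitoring ends. A stripped-down sketch of that control flow, with a fake cache object standing in for the real inductor hash cache (all names below are illustrative, not vLLM code):

from typing import Any, Callable, Set


class FakeHashCache:
    """Stands in for the inductor hash cache; only records whether it was saved."""

    def __init__(self) -> None:
        self.saved = False

    def save_to_file(self) -> None:
        self.saved = True


class PiecewiseRunnerSketch:
    """Toy version of the 'save once everything is compiled' control flow."""

    def __init__(self, compile_sizes: Set[int], is_last_graph: bool,
                 cache: FakeHashCache, general_fn: Callable[..., Any]) -> None:
        self.to_be_compiled_sizes = compile_sizes.copy()  # drained as shapes finish
        self.is_last_graph = is_last_graph
        self.cache = cache
        self.general_fn = general_fn
        self.first_run_finished = False

    def check_for_ending_compilation(self) -> None:
        # only the last graph, once nothing is left to compile, persists the cache
        if self.is_last_graph and not self.to_be_compiled_sizes:
            self.cache.save_to_file()

    def __call__(self, shape: int, *args: Any) -> Any:
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()  # covers the "no compile sizes" case
            return self.general_fn(*args)
        if shape in self.to_be_compiled_sizes:
            self.to_be_compiled_sizes.discard(shape)  # pretend this shape just compiled
            self.check_for_ending_compilation()
        return self.general_fn(*args)


cache = FakeHashCache()
runner = PiecewiseRunnerSketch(set(), is_last_graph=True, cache=cache,
                               general_fn=lambda *a: None)
runner(16)            # first run with nothing to compile -> cache is saved right away
assert cache.saved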
@@ -9,8 +9,8 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field, replace
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
-                    Final, List, Literal, Mapping, Optional, Set, Tuple, Type,
-                    Union)
+                    Final, List, Literal, Mapping, Optional, Protocol, Set,
+                    Tuple, Type, Union)
 
 import torch
 from pydantic import BaseModel, Field, PrivateAttr
@@ -75,6 +75,12 @@ HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
                                              PretrainedConfig]]
 
 
+class SupportsHash(Protocol):
+
+    def compute_hash(self) -> str:
+        ...
+
+
 class ModelConfig:
     """Configuration for the model.
 
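SupportsHash is a typing.Protocol, so any object exposing a compute_hash() -> str method satisfies the annotation structurally, without inheriting from it; this is what lets the toy LlamaConfig from the test ride along as additional_config. A small illustration (the ExampleConfig class and describe function are hypothetical):

from typing import Protocol


class SupportsHash(Protocol):

    def compute_hash(self) -> str:
        ...


class ExampleConfig:
    # no inheritance from SupportsHash needed; the method alone satisfies the protocol
    def compute_hash(self) -> str:
        return "d41d8cd98f00b204e9800998ecf8427e"


def describe(cfg: SupportsHash) -> str:
    return cfg.compute_hash()[:10]


print(describe(ExampleConfig()))  # type-checks and runs thanks to structural typing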
@@ -2969,6 +2975,10 @@ class VllmConfig:
                                                  init=True)  # type: ignore
     kv_transfer_config: KVTransferConfig = field(default=None,
                                                  init=True)  # type: ignore
+    # some opaque config, only used to provide additional information
+    # for the hash computation, mainly used for testing and debugging.
+    additional_config: SupportsHash = field(default=None,
+                                            init=True)  # type: ignore
     instance_id: str = ""
 
     def compute_hash(self) -> str:
@@ -3000,33 +3010,62 @@ class VllmConfig:
         vllm_factors.append(__version__)
         if self.model_config:
             vllm_factors.append(self.model_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.cache_config:
             vllm_factors.append(self.cache_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.parallel_config:
             vllm_factors.append(self.parallel_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.scheduler_config:
             vllm_factors.append(self.scheduler_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.device_config:
             vllm_factors.append(self.device_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.load_config:
             vllm_factors.append(self.load_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.lora_config:
             vllm_factors.append(self.lora_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.speculative_config:
             vllm_factors.append(self.speculative_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.decoding_config:
             vllm_factors.append(self.decoding_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.observability_config:
             vllm_factors.append(self.observability_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.prompt_adapter_config:
             vllm_factors.append(self.prompt_adapter_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.quant_config:
             pass  # should be captured by model_config.quantization
         if self.compilation_config:
             vllm_factors.append(self.compilation_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.kv_transfer_config:
             vllm_factors.append(self.kv_transfer_config.compute_hash())
+        else:
+            vllm_factors.append("None")
+        if self.additional_config:
+            vllm_factors.append(self.additional_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         factors.append(vllm_factors)
 
         hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
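One way to read the new else branches: every sub-config now contributes exactly one entry to vllm_factors, with the literal string "None" standing in for an absent config, so each factor keeps a fixed position and different absence patterns cannot collapse to the same factor list. A toy illustration of that property (not vLLM code):

import hashlib
from typing import List, Optional


def factor_list(configs: List[Optional[str]]) -> List[str]:
    # always emit one entry per slot; "None" marks an absent config
    return [c if c is not None else "None" for c in configs]


# with placeholders, "first slot missing" and "second slot missing" stay distinguishable
a = factor_list(["aaa", None])
b = factor_list([None, "aaa"])
assert a != b

# naively skipping absent configs would collapse both cases to ["aaa"]
assert [c for c in ["aaa", None] if c] == [c for c in [None, "aaa"] if c]

print(hashlib.md5(str(a).encode()).hexdigest()[:10])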
@@ -48,6 +48,7 @@ class Worker:
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
 
+        self.parallel_config.rank = rank
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method