[plugin][torch.compile] allow to add custom compile backend (#8445)
This commit is contained in:
parent
ecd7a1d5b6
commit
0a4806f0a9
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import vllm.envs as envs
|
||||
|
||||
@ -29,3 +30,15 @@ def load_general_plugins():
|
||||
except Exception:
|
||||
logger.exception("Failed to load general plugin: %s",
|
||||
plugin.name)
|
||||
|
||||
|
||||
_torch_compile_backend: Optional[Union[Callable, str]] = None
|
||||
|
||||
|
||||
def set_torch_compile_backend(backend: Union[Callable, str]):
|
||||
global _torch_compile_backend
|
||||
_torch_compile_backend = backend
|
||||
|
||||
|
||||
def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
|
||||
return _torch_compile_backend
|
||||
|
@ -1064,10 +1064,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
"This may lead to less accurate results!")
|
||||
|
||||
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
|
||||
from vllm.plugins import get_torch_compile_backend
|
||||
backend = get_torch_compile_backend() or "eager"
|
||||
self.model = torch.compile(
|
||||
self.model,
|
||||
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
backend="eager")
|
||||
backend=backend)
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
|
Loading…
x
Reference in New Issue
Block a user