[torch.compile] avoid Dynamo guard evaluation overhead (#7898)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent 3cdfe1f38b
commit ce6bf3a2cf
@@ -12,4 +12,4 @@ remove_docker_container
 # For HF_TOKEN.
 source /etc/environment
 # Run a simple end-to-end example.
-docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
+docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
@@ -173,6 +173,7 @@ steps:
   - vllm/
   commands:
   - pytest -v -s ./compile/test_full_graph.py
+  - pytest -v -s ./compile/test_wrapper.py

 - label: Vision Language Models Test # 42min
tests/compile/test_wrapper.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from typing import Optional

import torch

from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther


class MyMod(torch.nn.Module):

    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        if cache is not None:
            return x + cache
        return x * 2


class MyWrapper(TorchCompileWrapperWithCustomDispacther):

    def __init__(self, model):
        self.model = model
        compiled_callable = torch.compile(self.forward, backend="eager")
        super().__init__(compiled_callable)

    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        # this is the function to be compiled
        return self.model(x, cache)

    def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        # let torch.compile compile twice
        if len(self.compiled_codes) == 2:
            dispatch_id = 0 if cache is None else 1
            with self.dispatch_to_code(dispatch_id):
                return self.forward(x, cache)
        else:
            return self.compiled_callable(x, cache)


def test_torch_compile_wrapper():
    mod = MyMod()
    wrappers = []
    for i in range(3):
        torch._dynamo.reset()
        wrapper = MyWrapper(mod)
        wrappers.append(wrapper)
        x = torch.tensor([1])
        wrapper(x, None)  # profile run, compile
        # create a cache tensor
        cache = torch.tensor([2])
        wrapper(x, cache)  # warm up with cache, recompile

        # for new input, dispatch to the compiled code directly
        new_x = torch.tensor([3])
        assert wrapper(new_x,
                       None).item() == 6  # dispatch to the first compiled code
        assert wrapper(
            new_x, cache).item() == 5  # dispatch to the second compiled code

    for wrapper in wrappers:
        # make sure they have independent compiled codes
        assert len(wrapper.compiled_codes) == 2
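A rough, illustrative micro-benchmark (not part of this commit) that reuses the MyMod and MyWrapper classes from the test above. It compares the per-call cost of going through torch.compile's normal entry point, where Dynamo evaluates guards on every call, against dispatching straight to the saved bytecode via dispatch_to_code. The helper names _bench and bench_dispatch_overhead and the iteration count are made up for illustration; absolute timings are machine-dependent and only the relative gap is meant to be informative.

import time

import torch


def _bench(fn, *args, iters=10000):
    fn(*args)  # one warm-up call outside the timed loop
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - start) / iters


def bench_dispatch_overhead():
    torch._dynamo.reset()
    wrapper = MyWrapper(MyMod())
    x = torch.tensor([1])
    cache = torch.tensor([2])
    wrapper(x, None)  # first compilation (cache is None)
    wrapper(x, cache)  # second compilation (with cache)

    # Path 1: the default entry point; Dynamo checks guards on every call.
    t_guarded = _bench(wrapper.compiled_callable, x, cache)

    # Path 2: swap in the second compiled code object and call it directly.
    def direct(a, b):
        with wrapper.dispatch_to_code(1):
            return wrapper.forward(a, b)

    t_direct = _bench(direct, x, cache)
    print(f"guarded: {t_guarded * 1e6:.2f} us/call, "
          f"direct: {t_direct * 1e6:.2f} us/call")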
tests/tpu/__init__.py (new file, empty)

tests/tpu/test_custom_dispatcher.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from ..utils import compare_two_settings


def test_custom_dispatcher():
    compare_two_settings("google/gemma-2b",
                         arg1=["--enforce-eager"],
                         arg2=["--enforce-eager"],
                         env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"},
                         env2={})
vllm/compilation/__init__.py (new file, empty)

vllm/compilation/wrapper.py (new file, 81 lines)
@@ -0,0 +1,81 @@
import os
import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Callable, List

import torch

import vllm.envs as envs


class TorchCompileWrapperWithCustomDispacther:
    """
    A wrapper class for torch.compile, with a custom dispatch logic.
    Subclasses should:
    1. Implement the forward method
    2. Implement the dispatch logic in the __call__ method
    It can use `self.compiled_codes` to access the compiled bytecode,
    and `with self.dispatch_to_code(index):` to dispatch to
    the compiled code.
    3. Implement the `__init__` method to determine how to call
    `torch.compile` over the forward method.
    """

    def __init__(self, compiled_callable: Callable):
        self.compiled_callable = compiled_callable
        self.original_code_object = self.__class__.forward.__code__
        self.compiled_codes: List[CodeType] = []
        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

        # read the env var to determine whether to use the custom dispatcher
        # subclasses can use this to switch between the custom dispatcher
        # and the default Dynamo guard mechanism.
        self.use_custom_dispatcher: bool = \
            envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER

    def __call__(self, *args, **kwargs):
        """Implement the dispatch logic here, beyond the torch.compile level.
        NOTE: this function can have additional arguments beyond the forward
        method, for directly dispatching to the compiled code.
        """
        return self.compiled_callable(*args, **kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs):
        ...

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution."""
        if old_code is not self.original_code_object:
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        frame = sys._getframe()
        while True:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        if frame.f_locals["self"] is not self:
            return

        self.compiled_codes.append(new_code)

    @contextmanager
    def dispatch_to_code(self, index: int):
        """Context manager to dispatch to the compiled code.
        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
        """  # noqa
        self.__class__.forward.__code__ = self.compiled_codes[index]
        yield
        self.__class__.forward.__code__ = self.original_code_object
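A minimal, standalone sketch (not part of this commit) of the trick that dispatch_to_code relies on: a plain Python function's behavior can be changed by replacing its __code__ object, as long as the replacement is signature-compatible. Dynamo guarantees that compatibility for the bytecode it produces, which is what makes the direct dispatch above safe. The function names below are hypothetical.

from types import CodeType


def original(x):
    return x + 1


def replacement(x):
    return x * 10


assert original(3) == 4
saved: CodeType = original.__code__
original.__code__ = replacement.__code__  # swap in different bytecode
assert original(3) == 30                  # same function object, new behavior
original.__code__ = saved                 # restore, like exiting the context manager
assert original(3) == 4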
@@ -196,6 +196,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # Internal flag to enable Dynamo graph capture
     "VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
     lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
+    "VLLM_DYNAMO_USE_CUSTOM_DISPATCHER":
+    lambda:
+    (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
+     ("true", "1")),
 
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
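A small illustration (not from the commit) of how the new flag behaves: the lambda added above treats "true"/"1" (case-insensitive) as enabled and everything else as disabled, with the default being enabled. It assumes vllm.envs resolves variables lazily on attribute access, so setting the environment variable before the access takes effect.

import os

# Disable the custom dispatcher and fall back to Dynamo's guard mechanism.
os.environ["VLLM_DYNAMO_USE_CUSTOM_DISPATCHER"] = "0"

import vllm.envs as envs

assert not envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER  # "0" not in ("true", "1")

# "true"/"True"/"1" re-enable it; unset defaults to enabled.
os.environ["VLLM_DYNAMO_USE_CUSTOM_DISPATCHER"] = "True"
assert envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER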
@@ -10,6 +10,7 @@ import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr
 
 from vllm.attention import AttentionMetadata, get_attn_backend
+from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 from vllm.logger import init_logger
@@ -144,11 +145,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         )
         model = model.eval()
         xm.wait_device_ops()
-        model = ModelWrapper(model)
-        self.model = torch.compile(model,
-                                   backend="openxla",
-                                   fullgraph=True,
-                                   dynamic=False)
+        self.model = ModelWrapper(model)
 
     def _dummy_run(
         self,
@@ -235,8 +232,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             torch._dynamo.mark_dynamic(t, 0)
             torch._dynamo.mark_dynamic(p, 0)
         # Dummy run.
-        self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
-                   num_samples, kv_caches)
+        self.model(token_ids,
+                   position_ids,
+                   attn_metadata,
+                   input_lens,
+                   t,
+                   p,
+                   num_samples,
+                   kv_caches,
+                   is_prompt=is_prompt)
 
     def warmup_model(
         self,
@@ -530,7 +534,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
                 if getattr(arg, "context_lens", None) is not None:
                     arg.context_lens = arg.context_lens.to(self.device)
                 new_args.append(arg)
-            return self.model(*new_args)
+            return self.model(*new_args, is_prompt=is_prompt)
 
         num_prefills = model_input.attn_metadata.num_prefills
         is_prompt = num_prefills > 0
@@ -601,11 +605,32 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         return [SamplerOutput(sampler_outputs)]
 
 
-class ModelWrapper(nn.Module):
+class ModelWrapper(TorchCompileWrapperWithCustomDispacther):
 
     def __init__(self, model: nn.Module):
-        super().__init__()
         self.model = model
+        compiled_callable = torch.compile(self.forward,
+                                          backend="openxla",
+                                          fullgraph=True,
+                                          dynamic=False)
+        super().__init__(compiled_callable)
+
+    def __call__(self, *args, is_prompt: bool, **kwargs):
+        if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher:
+            # not fully compiled yet, or not using the custom dispatcher,
+            # let PyTorch handle it
+            return self.compiled_callable(*args, **kwargs)
+        # the 3 compiled codes are:
+        # 0: for profiling
+        # 1: for prompt
+        # 2: for decode
+        # dispatch to the compiled code directly, skip PyTorch
+        if is_prompt:
+            with self.dispatch_to_code(1):
+                return self.forward(*args, **kwargs)
+        else:
+            with self.dispatch_to_code(2):
+                return self.forward(*args, **kwargs)
 
     def forward(
         self,