[torch.compile] avoid Dynamo guard evaluation overhead (#7898)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
commit ce6bf3a2cf (parent 3cdfe1f38b)
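Overview: this commit adds TorchCompileWrapperWithCustomDispacther, a wrapper around torch.compile that records the bytecode Dynamo produces for each compilation of `forward` and can later execute a chosen compiled version directly by swapping the function's `__code__` object, skipping Dynamo's per-call guard evaluation. The TPU model runner adopts the wrapper: it compiles three variants of forward (profiling, prompt, decode) and dispatches on `is_prompt`. A new VLLM_DYNAMO_USE_CUSTOM_DISPATCHER environment variable (default on) allows falling back to the standard Dynamo guard mechanism, and unit plus end-to-end tests are wired into CI.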
@@ -12,4 +12,4 @@ remove_docker_container
 # For HF_TOKEN.
 source /etc/environment
 # Run a simple end-to-end example.
-docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
+docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
@@ -173,6 +173,7 @@ steps:
 - vllm/
 commands:
 - pytest -v -s ./compile/test_full_graph.py
+- pytest -v -s ./compile/test_wrapper.py

 - label: Vision Language Models Test # 42min
tests/compile/test_wrapper.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from typing import Optional

import torch

from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther


class MyMod(torch.nn.Module):

    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        if cache is not None:
            return x + cache
        return x * 2


class MyWrapper(TorchCompileWrapperWithCustomDispacther):

    def __init__(self, model):
        self.model = model
        compiled_callable = torch.compile(self.forward, backend="eager")
        super().__init__(compiled_callable)

    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        # This is the function to be compiled.
        return self.model(x, cache)

    def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        # Let torch.compile compile twice (once without cache, once with)
        # before switching to direct dispatch.
        if len(self.compiled_codes) == 2:
            dispatch_id = 0 if cache is None else 1
            with self.dispatch_to_code(dispatch_id):
                return self.forward(x, cache)
        else:
            return self.compiled_callable(x, cache)


def test_torch_compile_wrapper():
    mod = MyMod()
    wrappers = []
    for i in range(3):
        torch._dynamo.reset()
        wrapper = MyWrapper(mod)
        wrappers.append(wrapper)
        x = torch.tensor([1])
        wrapper(x, None)  # profile run, compile
        # Create a cache tensor.
        cache = torch.tensor([2])
        wrapper(x, cache)  # warm up with cache, recompile

        # For new input, dispatch to the compiled code directly.
        new_x = torch.tensor([3])
        assert wrapper(new_x,
                       None).item() == 6  # dispatch to the first compiled code
        assert wrapper(
            new_x, cache).item() == 5  # dispatch to the second compiled code

    for wrapper in wrappers:
        # Make sure they have independent compiled codes.
        assert len(wrapper.compiled_codes) == 2
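Note on the assertions: the first compiled code corresponds to the no-cache branch, so wrapper(new_x, None) computes 3 * 2 == 6; the second corresponds to the cache branch, so wrapper(new_x, cache) computes 3 + 2 == 5. The loop builds three wrappers to confirm that each instance records only its own bytecode, which works because bytecode_hook (in vllm/compilation/wrapper.py below) checks that the compiled frame's `self` is the current wrapper.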
tests/tpu/__init__.py (new, empty file)
tests/tpu/test_custom_dispatcher.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from ..utils import compare_two_settings


def test_custom_dispatcher():
    compare_two_settings("google/gemma-2b",
                         arg1=["--enforce-eager"],
                         arg2=["--enforce-eager"],
                         env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"},
                         env2={})
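Usage note: both runs use identical model arguments; env1 disables the custom dispatcher while env2 leaves the variable unset, so it defaults to enabled (see the envs change below). Assuming compare_two_settings runs the model under both settings and compares the generations, this is an A/B check that direct bytecode dispatch does not change model outputs.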
vllm/compilation/__init__.py (new, empty file)
vllm/compilation/wrapper.py (new file, 81 lines)
@@ -0,0 +1,81 @@
import os
import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Callable, List

import torch

import vllm.envs as envs


class TorchCompileWrapperWithCustomDispacther:
    """
    A wrapper class for torch.compile, with custom dispatch logic.
    Subclasses should:
    1. Implement the forward method.
    2. Implement the dispatch logic in the __call__ method.
       It can use `self.compiled_codes` to access the compiled bytecode,
       and `with self.dispatch_to_code(index):` to dispatch to
       the compiled code.
    3. Implement the `__init__` method to determine how to call
       `torch.compile` over the forward method.
    """

    def __init__(self, compiled_callable: Callable):
        self.compiled_callable = compiled_callable
        self.original_code_object = self.__class__.forward.__code__
        self.compiled_codes: List[CodeType] = []
        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

        # Read the env var to determine whether to use the custom dispatcher.
        # Subclasses can use this to switch between the custom dispatcher
        # and the default Dynamo guard mechanism.
        self.use_custom_dispatcher: bool = \
            envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER

    def __call__(self, *args, **kwargs):
        """Implement the dispatch logic here, beyond the torch.compile level.
        NOTE: this function can have additional arguments beyond the forward
        method, for directly dispatching to the compiled code.
        """
        return self.compiled_callable(*args, **kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs):
        ...

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution."""
        if old_code is not self.original_code_object:
            return
        # Code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        frame = sys._getframe()
        while True:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        if frame.f_locals["self"] is not self:
            return

        self.compiled_codes.append(new_code)

    @contextmanager
    def dispatch_to_code(self, index: int):
        """Context manager to dispatch to the compiled code.
        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
        """  # noqa
        self.__class__.forward.__code__ = self.compiled_codes[index]
        yield
        self.__class__.forward.__code__ = self.original_code_object
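The trick `dispatch_to_code` exploits can be seen without Dynamo at all. A minimal sketch in plain CPython (function names are made up for illustration): swapping `__code__` reroutes a call to different bytecode, provided both code objects agree on arguments, cell variables, and free variables, exactly the property Dynamo guarantees for its transformed bytecode.

def original(x):
    return x + 1        # stands in for the uncompiled forward


def replacement(x):
    return x + 1000     # stands in for one saved compiled version


saved = original.__code__
original.__code__ = replacement.__code__  # swap in the other bytecode
assert original(1) == 1001                # the call now runs replacement's code
original.__code__ = saved                 # restore, as the context manager does on exit
assert original(1) == 2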
@@ -196,6 +196,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # Internal flag to enable Dynamo graph capture
     "VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
     lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
+    "VLLM_DYNAMO_USE_CUSTOM_DISPATCHER":
+    lambda:
+    (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
+     ("true", "1")),

     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
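The flag defaults to enabled; any value other than a case-insensitive "true" or "1" disables it. A small sketch of opting out for one process (assuming the variable is set before the attribute is first read, since the lambdas above are evaluated on access):

import os

# Must happen before vllm reads the variable.
os.environ["VLLM_DYNAMO_USE_CUSTOM_DISPATCHER"] = "0"

import vllm.envs as envs
assert envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER is False  # falls back to Dynamo guards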
@@ -10,6 +10,7 @@ import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr

 from vllm.attention import AttentionMetadata, get_attn_backend
+from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 from vllm.logger import init_logger
@@ -144,11 +145,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         )
         model = model.eval()
         xm.wait_device_ops()
-        model = ModelWrapper(model)
-        self.model = torch.compile(model,
-                                   backend="openxla",
-                                   fullgraph=True,
-                                   dynamic=False)
+        self.model = ModelWrapper(model)

     def _dummy_run(
         self,
@@ -235,8 +232,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         torch._dynamo.mark_dynamic(t, 0)
         torch._dynamo.mark_dynamic(p, 0)
         # Dummy run.
-        self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
-                   num_samples, kv_caches)
+        self.model(token_ids,
+                   position_ids,
+                   attn_metadata,
+                   input_lens,
+                   t,
+                   p,
+                   num_samples,
+                   kv_caches,
+                   is_prompt=is_prompt)

     def warmup_model(
         self,
@@ -530,7 +534,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             if getattr(arg, "context_lens", None) is not None:
                 arg.context_lens = arg.context_lens.to(self.device)
             new_args.append(arg)
-        return self.model(*new_args)
+        return self.model(*new_args, is_prompt=is_prompt)

         num_prefills = model_input.attn_metadata.num_prefills
         is_prompt = num_prefills > 0
@@ -601,11 +605,32 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         return [SamplerOutput(sampler_outputs)]


-class ModelWrapper(nn.Module):
+class ModelWrapper(TorchCompileWrapperWithCustomDispacther):

     def __init__(self, model: nn.Module):
-        super().__init__()
         self.model = model
+        compiled_callable = torch.compile(self.forward,
+                                          backend="openxla",
+                                          fullgraph=True,
+                                          dynamic=False)
+        super().__init__(compiled_callable)
+
+    def __call__(self, *args, is_prompt: bool, **kwargs):
+        if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher:
+            # Not fully compiled yet, or not using the custom dispatcher:
+            # let PyTorch handle it.
+            return self.compiled_callable(*args, **kwargs)
+        # The three compiled codes are:
+        # 0: for profiling
+        # 1: for prompt
+        # 2: for decode
+        # Dispatch to the compiled code directly, skipping PyTorch.
+        if is_prompt:
+            with self.dispatch_to_code(1):
+                return self.forward(*args, **kwargs)
+        else:
+            with self.dispatch_to_code(2):
+                return self.forward(*args, **kwargs)

     def forward(
         self,
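For intuition on the overhead being avoided: each call through `self.compiled_callable` re-enters Dynamo, which evaluates the guards of every cached compilation before selecting one, while `dispatch_to_code` jumps straight to the saved bytecode. A rough timing sketch reusing MyMod and MyWrapper from tests/compile/test_wrapper.py above (illustrative only, not a benchmark from this commit):

import time

import torch

mod = MyMod()
wrapper = MyWrapper(mod)
x, cache = torch.tensor([1]), torch.tensor([2])
wrapper(x, None)   # first compilation: no-cache branch
wrapper(x, cache)  # second compilation: cache branch

start = time.perf_counter()
for _ in range(10_000):
    wrapper.compiled_callable(x, cache)  # Dynamo guard evaluation on every call
guarded = time.perf_counter() - start

start = time.perf_counter()
for _ in range(10_000):
    wrapper(x, cache)                    # direct __code__ dispatch
direct = time.perf_counter() - start

print(f"guarded: {guarded:.3f}s  direct: {direct:.3f}s")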