[torch.compile] generic decorators (#9258)
This commit is contained in:
parent
a78c6ba7c8
commit
e00c094f15
@ -1,20 +1,54 @@
|
||||
from typing import List, Optional, Union
|
||||
import inspect
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.compilation.levels import CompilationLevel
|
||||
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import supports_dynamo
|
||||
|
||||
|
||||
def support_compile_llama_style(cls: type):
|
||||
def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
|
||||
`dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
|
||||
dimensions of the argument. The dynamic dimensions can be either a single
|
||||
integer or a list of integers.
|
||||
|
||||
Depending on the value of arguments:
|
||||
|
||||
- if it is a single integer, the corresponding dimension of the argument
|
||||
will be marked as dynamic.
|
||||
- if it is `None`, ignored.
|
||||
- if it is `IntermediateTensors`, all the tensors in the intermediate
|
||||
tensors will be marked as dynamic.
|
||||
- otherwise, it will raise an error.
|
||||
|
||||
NOTE: if an argument is `None`, it should always be passed as `None` during
|
||||
the lifetime of the model, otherwise, it cannot be captured as a single
|
||||
computation graph.
|
||||
"""
|
||||
|
||||
def cls_decorator_helper(cls: type):
|
||||
# helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
|
||||
# to avoid too much indentation for `_support_torch_compile``
|
||||
sig = inspect.signature(cls.forward)
|
||||
for k in dynamic_arg_dims:
|
||||
if k not in sig.parameters:
|
||||
raise ValueError(
|
||||
f"Argument {k} not found in the forward method of {cls}")
|
||||
return _support_torch_compile(cls, dynamic_arg_dims)
|
||||
|
||||
return cls_decorator_helper
|
||||
|
||||
|
||||
def _support_torch_compile(cls: type,
|
||||
dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
If a module's **forward signature** is compatible with llama, this
|
||||
decorator can be used to enable the compilation of the forward method.
|
||||
"""
|
||||
|
||||
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
|
||||
@ -37,48 +71,42 @@ def support_compile_llama_style(cls: type):
|
||||
|
||||
cls.__init__ = __init__
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
def __call__(self, *args, **kwargs):
|
||||
# torch.compiler.is_compiling() means we are inside the compilation
|
||||
# e.g. TPU has the compilation logic in model runner, so we don't
|
||||
# need to compile the model inside.
|
||||
if torch.compiler.is_compiling():
|
||||
return self.forward(input_ids, positions, kv_caches, attn_metadata,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
# the first compilation needs to have dynamic shapes marked
|
||||
if len(self.compiled_codes) < 1:
|
||||
if input_ids is not None:
|
||||
torch._dynamo.mark_dynamic(input_ids, 0)
|
||||
torch._dynamo.mark_dynamic(positions, 0)
|
||||
if inputs_embeds is not None:
|
||||
torch._dynamo.mark_dynamic(inputs_embeds, 0)
|
||||
if intermediate_tensors is not None:
|
||||
for tensors in intermediate_tensors.tensors.values():
|
||||
torch._dynamo.mark_dynamic(tensors, 0)
|
||||
sig = inspect.signature(self.__class__.forward)
|
||||
bound_args = sig.bind(self, *args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
for k, dims in dynamic_arg_dims.items():
|
||||
arg = bound_args.arguments.get(k)
|
||||
if arg is not None:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
torch._dynamo.mark_dynamic(arg, dims)
|
||||
elif isinstance(arg, IntermediateTensors):
|
||||
for tensor in arg.tensors.values():
|
||||
torch._dynamo.mark_dynamic(tensor, dims)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported dynamic dimensions"
|
||||
f" {dims} for argument {k} with type {type(arg)}.")
|
||||
|
||||
# if we don't use custom dispatcher, we can directly call the
|
||||
# compiled function and let torch.compile handle the dispatching,
|
||||
# with the overhead of guard evaluation and recompilation.
|
||||
if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
|
||||
return self.compiled_callable(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return self.compiled_callable(*args, **kwargs)
|
||||
|
||||
# usually, capturing the model once is enough, and then we can
|
||||
# dispatch to the compiled code directly, without going through
|
||||
# the Dynamo guard mechanism.
|
||||
with self.dispatch_to_code(0):
|
||||
model_output = self.forward(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
model_output = self.forward(*args, **kwargs)
|
||||
return model_output
|
||||
|
||||
cls.__call__ = __call__
|
||||
|
@ -21,7 +21,7 @@ from torch import nn
|
||||
from transformers import Gemma2Config
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_compile_llama_style
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
@ -239,7 +239,13 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_compile_llama_style
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
"positions": 0,
|
||||
"inputs_embeds": 0,
|
||||
"intermediate_tensors": 0,
|
||||
})
|
||||
class Gemma2Model(nn.Module):
|
||||
|
||||
def __init__(
|
||||
|
@ -28,7 +28,7 @@ from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_compile_llama_style
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
@ -266,7 +266,13 @@ class LlamaDecoderLayer(nn.Module):
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_compile_llama_style
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
"positions": 0,
|
||||
"inputs_embeds": 0,
|
||||
"intermediate_tensors": 0,
|
||||
})
|
||||
class LlamaModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
|
Loading…
x
Reference in New Issue
Block a user