[torch.compile] generic decorators (#9258)
commit e00c094f15 (parent a78c6ba7c8)

vllm/compilation/decorators.py
@@ -1,20 +1,54 @@
-from typing import List, Optional, Union
+import inspect
+from typing import Dict, List, Union
 
 import torch
 
 import vllm.envs as envs
-from vllm.attention import AttentionMetadata
 from vllm.compilation.levels import CompilationLevel
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.sequence import IntermediateTensors
 from vllm.utils import supports_dynamo
 
 
-def support_compile_llama_style(cls: type):
+def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
+    """
+    A decorator to add support for compiling the forward method of a class.
+
+    `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
+    dimensions of the argument. The dynamic dimensions can be either a single
+    integer or a list of integers.
+
+    Depending on the value of the argument:
+
+    - if it is a single integer, the corresponding dimension of the argument
+        will be marked as dynamic.
+    - if it is `None`, it is ignored.
+    - if it is an `IntermediateTensors`, all the tensors in the intermediate
+        tensors will be marked as dynamic.
+    - otherwise, an error is raised.
+
+    NOTE: if an argument is `None`, it should always be passed as `None` during
+    the lifetime of the model; otherwise, it cannot be captured as a single
+    computation graph.
+    """
+
+    def cls_decorator_helper(cls: type):
+        # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
+        # to avoid too much indentation for `_support_torch_compile`
+        sig = inspect.signature(cls.forward)
+        for k in dynamic_arg_dims:
+            if k not in sig.parameters:
+                raise ValueError(
+                    f"Argument {k} not found in the forward method of {cls}")
+        return _support_torch_compile(cls, dynamic_arg_dims)
+
+    return cls_decorator_helper
+
+
+def _support_torch_compile(cls: type,
+                           dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
     """
     A decorator to add support for compiling the forward method of a class.
-    If a module's **forward signature** is compatible with llama, this
-    decorator can be used to enable the compilation of the forward method.
     """
 
     # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
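The hunk above replaces the llama-specific decorator with a generic `support_torch_compile(dynamic_arg_dims=...)` factory; `cls_decorator_helper` validates the mapping against the decorated class's `forward` signature before wrapping the class. The standalone sketch below (with a hypothetical `ToyModel` and `check_dynamic_arg_dims` helper, not part of this commit) shows the kind of check that is performed:

import inspect
from typing import Dict, List, Union

import torch
from torch import nn


class ToyModel(nn.Module):
    # hypothetical module; only its forward signature matters for the check

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        return input_ids + positions


def check_dynamic_arg_dims(cls: type,
                           dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
    # mirrors the validation in cls_decorator_helper: every key of
    # dynamic_arg_dims must name a parameter of cls.forward
    sig = inspect.signature(cls.forward)
    for k in dynamic_arg_dims:
        if k not in sig.parameters:
            raise ValueError(
                f"Argument {k} not found in the forward method of {cls}")


check_dynamic_arg_dims(ToyModel, {"input_ids": 0, "positions": 0})  # passes
check_dynamic_arg_dims(ToyModel, {"hidden_states": 0})  # raises ValueError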
@@ -37,48 +71,42 @@ def support_compile_llama_style(cls: type):
 
     cls.__init__ = __init__
 
-    def __call__(
-        self,
-        input_ids: Optional[torch.Tensor],
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors],
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
+    def __call__(self, *args, **kwargs):
         # torch.compiler.is_compiling() means we are inside the compilation
         # e.g. TPU has the compilation logic in model runner, so we don't
         # need to compile the model inside.
         if torch.compiler.is_compiling():
-            return self.forward(input_ids, positions, kv_caches, attn_metadata,
-                                intermediate_tensors, inputs_embeds)
+            return self.forward(*args, **kwargs)
 
         # the first compilation needs to have dynamic shapes marked
         if len(self.compiled_codes) < 1:
-            if input_ids is not None:
-                torch._dynamo.mark_dynamic(input_ids, 0)
-            torch._dynamo.mark_dynamic(positions, 0)
-            if inputs_embeds is not None:
-                torch._dynamo.mark_dynamic(inputs_embeds, 0)
-            if intermediate_tensors is not None:
-                for tensors in intermediate_tensors.tensors.values():
-                    torch._dynamo.mark_dynamic(tensors, 0)
+            sig = inspect.signature(self.__class__.forward)
+            bound_args = sig.bind(self, *args, **kwargs)
+            bound_args.apply_defaults()
+            for k, dims in dynamic_arg_dims.items():
+                arg = bound_args.arguments.get(k)
+                if arg is not None:
+                    if isinstance(arg, torch.Tensor):
+                        torch._dynamo.mark_dynamic(arg, dims)
+                    elif isinstance(arg, IntermediateTensors):
+                        for tensor in arg.tensors.values():
+                            torch._dynamo.mark_dynamic(tensor, dims)
+                    else:
+                        raise ValueError(
+                            "Unsupported dynamic dimensions"
+                            f" {dims} for argument {k} with type {type(arg)}.")
 
         # if we don't use custom dispatcher, we can directly call the
         # compiled function and let torch.compile handle the dispatching,
         # with the overhead of guard evaluation and recompilation.
         if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
-            return self.compiled_callable(input_ids, positions, kv_caches,
-                                          attn_metadata, intermediate_tensors,
-                                          inputs_embeds)
+            return self.compiled_callable(*args, **kwargs)
 
         # usually, capturing the model once is enough, and then we can
         # dispatch to the compiled code directly, without going through
         # the Dynamo guard mechanism.
         with self.dispatch_to_code(0):
-            model_output = self.forward(input_ids, positions, kv_caches,
-                                        attn_metadata, intermediate_tensors,
-                                        inputs_embeds)
+            model_output = self.forward(*args, **kwargs)
             return model_output
 
     cls.__call__ = __call__
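The generic `__call__` above no longer hard-codes the llama argument list: on the first call it binds `*args, **kwargs` to the parameter names of `forward` and marks the configured dimensions as dynamic before compilation. A minimal sketch of that binding-and-marking mechanism, assuming a hypothetical `toy_forward` signature and example inputs (not part of this commit):

import inspect
from typing import Dict, List, Union

import torch


def toy_forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
    # hypothetical forward method, used only to demonstrate the binding
    return input_ids + positions


dynamic_arg_dims: Dict[str, Union[int, List[int]]] = {
    "input_ids": 0,
    "positions": 0,
}

args = (torch.zeros(8, dtype=torch.long),)
kwargs = {"positions": torch.arange(8)}

# bind positional and keyword arguments to parameter names, as the generic
# __call__ does before the first compilation, so each argument can be looked
# up by name regardless of how the caller passed it
sig = inspect.signature(toy_forward)
bound_args = sig.bind(None, *args, **kwargs)  # None stands in for `self`
bound_args.apply_defaults()

for k, dims in dynamic_arg_dims.items():
    arg = bound_args.arguments.get(k)
    if isinstance(arg, torch.Tensor):
        # mark dimension 0 as dynamic so Dynamo compiles a shape-generic
        # graph instead of specializing on this particular tensor size
        torch._dynamo.mark_dynamic(arg, dims)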
vllm/model_executor/models/gemma2.py
@@ -21,7 +21,7 @@ from torch import nn
 from transformers import Gemma2Config
 
 from vllm.attention import Attention, AttentionMetadata
-from vllm.compilation.decorators import support_compile_llama_style
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
@@ -239,7 +239,13 @@ class Gemma2DecoderLayer(nn.Module):
         return hidden_states, residual
 
 
-@support_compile_llama_style
+@support_torch_compile(
+    dynamic_arg_dims={
+        "input_ids": 0,
+        "positions": 0,
+        "inputs_embeds": 0,
+        "intermediate_tensors": 0,
+    })
 class Gemma2Model(nn.Module):
 
     def __init__(
vllm/model_executor/models/llama.py
@@ -28,7 +28,7 @@ from torch import nn
 from transformers import LlamaConfig
 
 from vllm.attention import Attention, AttentionMetadata
-from vllm.compilation.decorators import support_compile_llama_style
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -266,7 +266,13 @@ class LlamaDecoderLayer(nn.Module):
         return hidden_states, residual
 
 
-@support_compile_llama_style
+@support_torch_compile(
+    dynamic_arg_dims={
+        "input_ids": 0,
+        "positions": 0,
+        "inputs_embeds": 0,
+        "intermediate_tensors": 0,
+    })
 class LlamaModel(nn.Module):
 
     def __init__(
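Because the decorator is now generic, a model whose forward signature is not llama-compatible can opt in simply by declaring its own dynamic dimensions. A hypothetical sketch of such a use (the class, its layers, and its argument names are illustrative only; whether the forward is actually compiled is still gated by vLLM's compilation level settings):

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile


@support_torch_compile(dynamic_arg_dims={
    # dimension 0 is the (flattened) token dimension, which varies per batch
    "input_ids": 0,
    "positions": 0,
})
class MyTinyModel(nn.Module):
    # hypothetical model whose forward signature is not llama-compatible

    def __init__(self, hidden_size: int = 64, vocab_size: int = 1000):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        return self.embed(input_ids) + positions.unsqueeze(-1)

This is the same pattern the commit applies to `Gemma2Model` and `LlamaModel` above, just with `inputs_embeds` and `intermediate_tensors` added to the mapping.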