Adding "torch compile" annotations to moe models (#9758)
parent 5f8d8075f9
commit aa0addb397
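Each file below receives the same two-line change: import support_torch_compile from vllm.compilation.decorators and apply it as a class decorator on the top-level *Model class, so that model's forward pass can be dispatched through torch.compile. As a rough, self-contained illustration of the intended effect (a minimal sketch using plain torch.compile on a toy module, not the vLLM decorator itself; the TinyBlock class is hypothetical and not part of this commit):

import torch
from torch import nn


class TinyBlock(nn.Module):
    # Hypothetical stand-in for one decoder block; not vLLM code.
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.up = nn.Linear(hidden_size, 4 * hidden_size)
        self.down = nn.Linear(4 * hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.down(nn.functional.silu(self.up(hidden_states)))


block = TinyBlock()
# Conceptually what the @support_torch_compile annotation enables for the
# decorated model classes: compiling the forward pass with torch.compile.
compiled_block = torch.compile(block)
out = compiled_block(torch.randn(8, 64))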
@@ -5,6 +5,7 @@ import torch
 from torch import nn
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -360,6 +361,7 @@ class ArcticDecoderLayer(nn.Module):
         return hidden_states
 
 
+@support_torch_compile
 class ArcticModel(nn.Module):
 
     def __init__(
@@ -28,6 +28,7 @@ from torch import nn
 from transformers import MixtralConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -245,6 +246,7 @@ class MixtralDecoderLayer(nn.Module):
         return hidden_states, residual
 
 
+@support_torch_compile
 class MixtralModel(nn.Module):
 
     def __init__(
@@ -17,6 +17,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -239,6 +240,7 @@ class OlmoeDecoderLayer(nn.Module):
         return hidden_states, residual
 
 
+@support_torch_compile
 class OlmoeModel(nn.Module):
 
     def __init__(
@@ -28,6 +28,7 @@ from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -429,6 +430,7 @@ class PhiMoEDecoderLayer(nn.Module):
         return hidden_states, residual
 
 
+@support_torch_compile
 class PhiMoEModel(nn.Module):
 
     def __init__(