Remove unused kwargs from model definitions (#13555)

Harry Mellor 2025-02-25 01:13:52 +00:00 committed by GitHub
parent f61528d46d
commit cdc1fa12eb
104 changed files with 436 additions and 1654 deletions
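The pattern behind the change, in brief: model `forward()` signatures no longer accept `kv_caches`/`attn_metadata`; the attention layer reads its KV cache from `self.kv_cache` and the metadata from the per-step forward context. Below is a minimal, self-contained sketch of that context-passing idea (illustrative only, not the actual vLLM implementation; the stand-in `AttnMetadata`, `set_forward_context`, and `AttentionLayer` names are hypothetical):

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional


@dataclass
class AttnMetadata:
    """Stand-in for vLLM's AttentionMetadata."""
    num_prefills: int = 0


_current_metadata: Optional[AttnMetadata] = None


@contextmanager
def set_forward_context(metadata: AttnMetadata):
    """Install metadata for the duration of one model step."""
    global _current_metadata
    prev, _current_metadata = _current_metadata, metadata
    try:
        yield
    finally:
        _current_metadata = prev


def get_forward_context() -> AttnMetadata:
    """What a layer calls instead of receiving attn_metadata as an argument."""
    assert _current_metadata is not None, "forward context not set"
    return _current_metadata


class AttentionLayer:
    """Toy layer: note that forward() takes only q, k, v."""

    def forward(self, q, k, v):
        attn_metadata = get_forward_context()  # fetched, not passed in
        # ... attention math using attn_metadata would go here ...
        return q


if __name__ == "__main__":
    layer = AttentionLayer()
    with set_forward_context(AttnMetadata(num_prefills=1)):
        out = layer.forward(1.0, 2.0, 3.0)
    print(out)
```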

View File

@@ -74,8 +74,6 @@ def forward(
     self,
     input_ids: torch.Tensor,
     positions: torch.Tensor,
-    kv_caches: List[torch.Tensor],
-    attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
     ...
 ```

View File

@@ -16,8 +16,6 @@ Further update the model as follows:
      self,
      input_ids: torch.Tensor,
      positions: torch.Tensor,
-     kv_caches: List[torch.Tensor],
-     attn_metadata: AttentionMetadata,
 +    pixel_values: torch.Tensor,
  ) -> SamplerOutput:
 ```

View File

@@ -644,11 +644,7 @@ def _run_encoder_attention_test(
         # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
         reshaped_query = packed_qkv.query.view(
             -1, test_pt.num_heads * test_pt.head_size)
-        return attn.forward(
-            reshaped_query, packed_qkv.key, packed_qkv.value,
-            torch.tensor([],
-                         dtype=torch.float32,
-                         device=packed_qkv.query.device), attn_metadata)
+        return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)


 def _run_decoder_self_attention_test(
@@ -682,7 +678,6 @@ def _run_decoder_self_attention_test(
       & attn_metadata
     '''
     attn = test_rsrcs.attn
-    kv_cache = test_rsrcs.kv_cache
     packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
     assert packed_qkv is not None
     with set_forward_context(attn_metadata, vllm_config):
@@ -695,8 +690,7 @@ def _run_decoder_self_attention_test(
         # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
         reshaped_query = packed_qkv.query.view(
             -1, test_pt.num_heads * test_pt.head_size)
-        return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value,
-                            kv_cache, attn_metadata)
+        return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)


 def _run_encoder_decoder_cross_attention_test(
@@ -744,7 +738,6 @@ def _run_encoder_decoder_cross_attention_test(
     assert decoder_test_params.packed_qkvo.packed_qkv is not None
     attn = test_rsrcs.attn
-    kv_cache = test_rsrcs.kv_cache
     if cross_test_params is None:
         key = None
         value = None
@@ -762,8 +755,7 @@ def _run_encoder_decoder_cross_attention_test(
         # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
         reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
             -1, test_pt.num_heads * test_pt.head_size)
-        return attn.forward(reshaped_query, key, value, kv_cache,
-                            attn_metadata)
+        return attn.forward(reshaped_query, key, value)


 @pytest.fixture(autouse=True)

View File

@@ -7,7 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import vllm.envs as envs
-from vllm.attention import AttentionMetadata, AttentionType
+from vllm.attention import AttentionType
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import CacheConfig, get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
@@ -153,15 +153,10 @@ class Attention(nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        # NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
-        # directly, use `self.kv_cache` and
-        # `get_forward_context().attn_metadata` instead.
         if self.calculate_kv_scales:
-            ctx_attn_metadata = get_forward_context().attn_metadata
-            if ctx_attn_metadata.enable_kv_scales_calculation:
+            attn_metadata = get_forward_context().attn_metadata
+            if attn_metadata.enable_kv_scales_calculation:
                 self.calc_kv_scales(key, value)
         if self.use_output:
             output = torch.empty_like(query)
@@ -177,14 +172,14 @@ class Attention(nn.Module):
                value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
-               ctx_attn_metadata = forward_context.attn_metadata
+               attn_metadata = forward_context.attn_metadata
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(self,
                                  query,
                                  key,
                                  value,
                                  self_kv_cache,
-                                 ctx_attn_metadata,
+                                 attn_metadata,
                                  output=output)
            else:
                torch.ops.vllm.unified_attention_with_output(
@@ -193,10 +188,10 @@ class Attention(nn.Module):
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
-               ctx_attn_metadata = forward_context.attn_metadata
+               attn_metadata = forward_context.attn_metadata
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(self, query, key, value,
-                                        self_kv_cache, ctx_attn_metadata)
+                                        self_kv_cache, attn_metadata)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name)
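For model code, the practical effect of the layer change above is that attention modules drop two parameters and two call-site arguments, which is exactly what the per-model diffs below do. A hedged before/after sketch (a toy module, not taken verbatim from any one model file):

```python
import torch
from torch import nn


class ToyAttention(nn.Module):
    """Toy module illustrating the signature change applied across model files."""

    def __init__(self, attn: nn.Module):
        super().__init__()
        self.attn = attn  # a vllm.attention.Attention instance in real models

    # Old convention:
    #     def forward(self, positions, hidden_states, kv_cache, attn_metadata):
    #         ...
    #         return self.attn(q, k, v, kv_cache, attn_metadata)
    #
    # New convention: kv_cache / attn_metadata are neither accepted nor forwarded.
    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        q = k = v = hidden_states  # real models compute these via QKV projections
        return self.attn(q, k, v)
```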

View File

@@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -130,14 +131,14 @@ class MambaMixer(CustomOp):
         ) if use_rms_norm else None

     def forward_native(self, hidden_states: torch.Tensor,
-                       attn_metadata: AttentionMetadata,
                        conv_state: torch.Tensor, ssm_state: torch.Tensor):
         pass

     def forward_cuda(self, hidden_states: torch.Tensor,
-                     attn_metadata: AttentionMetadata,
                      mamba_cache_params: MambaCacheParams):
+        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
         hidden_states, gate = projected_states.chunk(2, dim=-2)

View File

@@ -14,6 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -376,17 +377,16 @@ class MambaMixer2(CustomOp):
                                        eps=rms_norm_eps)

     def forward_native(self, hidden_states: torch.Tensor,
-                       attn_metadata: AttentionMetadata,
                        conv_state: torch.Tensor, ssm_state: torch.Tensor):
         pass

     def forward_cuda(
         self,
         hidden_states: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
         sequence_idx: Optional[torch.Tensor] = None,
     ):
+        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
         seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size

View File

@@ -160,7 +160,6 @@ def as_classification_model(cls: _T) -> _T:
         return cls

     # Lazy import
-    from vllm.attention import AttentionMetadata
     from vllm.config import VllmConfig
     from vllm.model_executor.layers.linear import RowParallelLinear
     from vllm.model_executor.layers.pooler import PoolingType
@@ -201,13 +200,10 @@ def as_classification_model(cls: _T) -> _T:
             self,
             input_ids: torch.Tensor,
             positions: torch.Tensor,
-            kv_caches: list[torch.Tensor],
-            attn_metadata: AttentionMetadata,
             intermediate_tensors: Optional[IntermediateTensors] = None,
             inputs_embeds: Optional[torch.Tensor] = None,
         ) -> torch.Tensor:
-            hidden_states = super().forward(input_ids, positions, kv_caches,
-                                            attn_metadata,
+            hidden_states = super().forward(input_ids, positions,
                                             intermediate_tensors,
                                             inputs_embeds)
             logits, _ = self.score(hidden_states)

View File

@@ -5,7 +5,7 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
 import torch
 from torch import nn

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -283,13 +283,11 @@ class ArcticAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
@@ -336,16 +334,12 @@ class ArcticDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         residual_input = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         hidden_states = residual_input + hidden_states
@@ -400,8 +394,6 @@ class ArcticModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -413,11 +405,8 @@ class ArcticModel(nn.Module):
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states = layer(positions, hidden_states,
-                                  kv_caches[i - self.start_layer],
-                                  attn_metadata)
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states = layer(positions, hidden_states)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.norm(hidden_states)
@@ -458,13 +447,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states

View File

@@ -9,7 +9,6 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
 from transformers.models.aria.modeling_aria import AriaCrossAttention
 from transformers.models.aria.processing_aria import AriaProcessor

-from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.activation import get_act_fn
@@ -626,8 +625,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
@@ -643,8 +640,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         hidden_states = self.language_model(
             input_ids,
             positions,
-            kv_caches,
-            attn_metadata,
             intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )

View File

@@ -20,13 +20,13 @@
 # limitations under the License.
 """Inference-only BaiChuan model compatible with HuggingFace weights."""
 import math
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import PretrainedConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -182,14 +182,12 @@ class BaiChuanAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         if self.postion_embedding != "ALIBI":
             q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
@@ -232,8 +230,6 @@ class BaiChuanDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -246,8 +242,6 @@ class BaiChuanDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )

         # Fully Connected
@@ -301,8 +295,6 @@ class BaiChuanModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -316,13 +308,10 @@ class BaiChuanModel(nn.Module):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
+        for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
                 residual,
             )
         if not get_pp_group().is_last_rank:
@@ -379,13 +368,10 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states

View File

@@ -1,17 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 """Inference-only Bamba model."""
 # Added by the IBM Team, 2024
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Iterable, Optional, Set, Tuple

 import torch
 from torch import nn
 from transformers import BambaConfig

-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -107,7 +107,6 @@ class BambaMixerDecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
         sequence_idx: Optional[torch.Tensor] = None,
@@ -120,8 +119,8 @@ class BambaMixerDecoderLayer(nn.Module):
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)

-        hidden_states = self.mamba(hidden_states, attn_metadata,
-                                   mamba_cache_params, sequence_idx)
+        hidden_states = self.mamba(hidden_states, mamba_cache_params,
+                                   sequence_idx)
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
             hidden_states, residual)
@@ -215,15 +214,13 @@ class BambaAttentionDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         **kwargs,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
@@ -231,8 +228,6 @@ class BambaAttentionDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         **kwargs,
     ):
@@ -246,8 +241,6 @@ class BambaAttentionDecoderLayer(nn.Module):
         hidden_states = self.self_attention(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
@@ -312,8 +305,6 @@ class BambaModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
@@ -323,6 +314,7 @@ class BambaModel(nn.Module):
         # proper continuous batching computation including
         # chunked prefill
         seq_idx = None
+        attn_metadata = get_forward_context().attn_metadata
         if attn_metadata.num_prefills > 0:
             seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
             for i, (srt, end) in enumerate(
@@ -348,9 +340,7 @@ class BambaModel(nn.Module):
         num_attn = 0
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            kv_cache = None
             if isinstance(layer, BambaAttentionDecoderLayer):
-                kv_cache = kv_caches[num_attn]
                 num_attn += 1

             layer_mamba_cache_params = None
@@ -361,8 +351,6 @@ class BambaModel(nn.Module):
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
-                kv_cache=kv_cache,
-                attn_metadata=attn_metadata,
                 residual=residual,
                 mamba_cache_params=layer_mamba_cache_params,
                 sequence_idx=seq_idx,
@@ -440,8 +428,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
-                kv_caches: List[KVCache],
-                attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
@@ -454,8 +440,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                 self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
                 *self._get_mamba_cache_shape())
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, mamba_cache_params,
+        hidden_states = self.model(input_ids, positions, mamba_cache_params,
                                    intermediate_tensors, inputs_embeds)
         return hidden_states

View File

@@ -19,14 +19,14 @@
 # limitations under the License.
 """PyTorch BART model."""
 import math
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, Optional, Tuple

 import torch
 from torch import nn
 from transformers import BartConfig
 from transformers.utils import logging

-from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.attention import Attention, AttentionType
 from vllm.config import CacheConfig, LoRAConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -181,14 +181,13 @@ class BartEncoderAttention(nn.Module):
                               prefix=f"{prefix}.attn",
                               attn_type=AttentionType.ENCODER)

-    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
-                attn_metadata: AttentionMetadata) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """Input shape: Batch x Time x Channel"""
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.out_proj(attn_output)
         return output
@@ -261,14 +260,13 @@ class BartDecoderSelfAttention(nn.Module):
                               prefix=f"{prefix}.attn",
                               attn_type=AttentionType.DECODER)

-    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
-                attn_metadata: AttentionMetadata) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """Input shape: Batch x Time x Channel"""
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.out_proj(attn_output)
         return output
@@ -344,8 +342,6 @@ class BartCrossAttention(nn.Module):
     def forward(
         self,
         decoder_hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         encoder_hidden_states: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Input shape: Batch x Time x Channel"""
@@ -363,7 +359,7 @@ class BartCrossAttention(nn.Module):
             _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
                                     dim=-1)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.out_proj(attn_output)
         return output
@@ -411,23 +407,16 @@ class BartEncoderLayer(nn.Module):
         self.final_layer_norm = nn.LayerNorm(self.embed_dim)

-    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
-                attn_metadata: AttentionMetadata) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
             hidden_states
                 torch.Tensor of *encoder* input embeddings.
-            kv_cache:
-                Layer-wise list of KV cache tensors
-            attn_metadata:
-                vLLM Attention metadata structure
         Returns:
             Encoder layer output torch.Tensor
         """
         residual = hidden_states
-        hidden_states = self.self_attn(hidden_states=hidden_states,
-                                       kv_cache=kv_cache,
-                                       attn_metadata=attn_metadata)
+        hidden_states = self.self_attn(hidden_states=hidden_states)

         hidden_states = residual + hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -509,18 +498,12 @@ class BartDecoderLayer(nn.Module):
     def forward(
         self,
         decoder_hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         encoder_hidden_states: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         r"""
         Args:
             decoder_hidden_states
                 torch.Tensor of *decoder* input embeddings.
-            kv_cache:
-                KV cache tensor
-            attn_metadata:
-                vLLM Attention metadata structure
             encoder_hidden_states
                 torch.Tensor of *encoder* input embeddings.
         Returns:
@@ -529,9 +512,7 @@ class BartDecoderLayer(nn.Module):
         residual = decoder_hidden_states

         # Self Attention
-        hidden_states = self.self_attn(hidden_states=decoder_hidden_states,
-                                       kv_cache=kv_cache,
-                                       attn_metadata=attn_metadata)
+        hidden_states = self.self_attn(hidden_states=decoder_hidden_states)

         hidden_states = residual + hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -542,8 +523,6 @@ class BartDecoderLayer(nn.Module):
         hidden_states = self.encoder_attn(
             decoder_hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
             encoder_hidden_states=encoder_hidden_states,
         )
@@ -609,9 +588,8 @@ class BartEncoder(nn.Module):
         self.layernorm_embedding = nn.LayerNorm(embed_dim)

-    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
-                kv_caches: List[torch.Tensor],
-                attn_metadata: AttentionMetadata) -> torch.Tensor:
+    def forward(self, input_ids: torch.Tensor,
+                positions: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
             input_ids
@@ -620,10 +598,6 @@ class BartEncoder(nn.Module):
                 provide it.
             positions
                 Positions of *encoder* input sequence tokens.
-            kv_caches:
-                Layer-wise list of KV cache tensors
-            attn_metadata:
-                vLLM Attention metadata structure
         Returns:
             Decoder output torch.Tensor
         """
@@ -636,12 +610,8 @@ class BartEncoder(nn.Module):
         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)

-        for idx, encoder_layer in enumerate(self.layers):
-            hidden_states = encoder_layer(
-                hidden_states=hidden_states,
-                kv_cache=kv_caches[idx],
-                attn_metadata=attn_metadata,
-            )
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(hidden_states=hidden_states)

         return hidden_states
@@ -693,9 +663,7 @@ class BartDecoder(nn.Module):
     def forward(self, decoder_input_ids: torch.Tensor,
                 decoder_positions: torch.Tensor,
-                encoder_hidden_states: Optional[torch.Tensor],
-                kv_caches: List[torch.Tensor],
-                attn_metadata: AttentionMetadata) -> torch.Tensor:
+                encoder_hidden_states: Optional[torch.Tensor]) -> torch.Tensor:
         r"""
         Args:
             decoder_input_ids
@@ -706,10 +674,6 @@ class BartDecoder(nn.Module):
                 Positions of *decoder* input sequence tokens.
             encoder_hidden_states:
                 Tensor of encoder output embeddings
-            kv_caches:
-                Layer-wise list of KV cache tensors
-            attn_metadata:
-                vLLM Attention metadata structure
         Returns:
             Decoder output torch.Tensor
         """
@@ -725,11 +689,9 @@ class BartDecoder(nn.Module):
         # decoder layers
-        for idx, decoder_layer in enumerate(self.layers):
+        for decoder_layer in self.layers:
             hidden_states = decoder_layer(
                 decoder_hidden_states=hidden_states,
-                kv_cache=kv_caches[idx],
-                attn_metadata=attn_metadata,
                 encoder_hidden_states=encoder_hidden_states,
             )
@@ -768,8 +730,7 @@ class BartModel(nn.Module):
     def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                 encoder_input_ids: torch.Tensor,
-                encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
-                attn_metadata: AttentionMetadata) -> torch.Tensor:
+                encoder_positions: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
             input_ids
@@ -782,10 +743,6 @@ class BartModel(nn.Module):
                 Indices of *encoder* input sequence tokens in the vocabulary.
             encoder_positions:
                 Positions of *encoder* input sequence tokens.
-            kv_caches:
-                Layer-wise list of KV cache tensors
-            attn_metadata:
-                vLLM Attention metadata structure
         Returns:
             Model output torch.Tensor
         """
@@ -796,18 +753,14 @@ class BartModel(nn.Module):
             # Run encoder attention if a non-zero number of encoder tokens
            # are provided as input
            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
-                                                positions=encoder_positions,
-                                                kv_caches=kv_caches,
-                                                attn_metadata=attn_metadata)
+                                                positions=encoder_positions)

        # decoder outputs consists of
        # (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            decoder_input_ids=input_ids,
            decoder_positions=positions,
-           encoder_hidden_states=encoder_hidden_states,
-           kv_caches=kv_caches,
-           attn_metadata=attn_metadata)
+           encoder_hidden_states=encoder_hidden_states)

        return decoder_outputs
@@ -845,8 +798,6 @@ class BartForConditionalGeneration(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         *,
         encoder_input_ids: torch.Tensor,
@@ -863,15 +814,11 @@ class BartForConditionalGeneration(nn.Module):
                 torch.Tensor of *encoder* input token ids.
             encoder_positions
                 torch.Tensor of *encoder* position indices
-            kv_caches:
-                Layer-wise list of KV cache tensors
-            attn_metadata:
-                vLLM Attention metadata structure
         Returns:
             Output torch.Tensor
         """
         return self.model(input_ids, positions, encoder_input_ids,
-                          encoder_positions, kv_caches, attn_metadata)
+                          encoder_positions)

     def compute_logits(
         self,

View File

@ -1,15 +1,16 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import BertConfig from transformers import BertConfig
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
@ -113,12 +114,9 @@ class BertEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
for i in range(len(self.layer)): for layer in self.layer:
layer = self.layer[i] hidden_states = layer(hidden_states)
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
return hidden_states return hidden_states
@ -152,13 +150,8 @@ class BertLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.output") prefix=f"{prefix}.output")
def forward( def forward(self, hidden_states: torch.Tensor):
self, attn_output = self.attention(hidden_states)
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
):
attn_output = self.attention(hidden_states, kv_cache, attn_metadata)
intermediate_output = self.intermediate(attn_output) intermediate_output = self.intermediate(attn_output)
output = self.output(intermediate_output, attn_output) output = self.output(intermediate_output, attn_output)
return output return output
@ -191,10 +184,8 @@ class BertAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
self_output = self.self(hidden_states, kv_cache, attn_metadata) self_output = self.self(hidden_states)
return self.output(self_output, hidden_states) return self.output(self_output, hidden_states)
@ -246,12 +237,10 @@ class BertSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
output = self.attn(q, k, v, kv_cache, attn_metadata) output = self.attn(q, k, v)
return output return output
@ -343,8 +332,6 @@ class BertModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None,
@ -352,13 +339,14 @@ class BertModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
attn_metadata = get_forward_context().attn_metadata
assert hasattr(attn_metadata, "seq_lens_tensor") assert hasattr(attn_metadata, "seq_lens_tensor")
hidden_states = self.embeddings( hidden_states = self.embeddings(
input_ids=input_ids, input_ids=input_ids,
seq_lens=attn_metadata.seq_lens_tensor, seq_lens=attn_metadata.seq_lens_tensor,
position_ids=position_ids, position_ids=position_ids,
token_type_ids=token_type_ids) token_type_ids=token_type_ids)
return self.encoder(hidden_states, kv_caches, attn_metadata) return self.encoder(hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
@ -420,17 +408,13 @@ class BertEmbeddingModel(nn.Module):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.model(input_ids=input_ids, return self.model(input_ids=input_ids,
position_ids=positions, position_ids=positions,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors)
attn_metadata=attn_metadata)
def pooler( def pooler(
self, self,
@ -519,16 +503,12 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.bert(input_ids=input_ids, return self.bert(input_ids=input_ids,
position_ids=positions, position_ids=positions,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
attn_metadata=attn_metadata,
token_type_ids=token_type_ids) token_type_ids=token_type_ids)

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
+from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)

 import torch
@@ -9,7 +9,6 @@ import torch.nn as nn
 from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
                           apply_chunking_to_forward)

-from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -658,8 +657,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
@@ -708,8 +705,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
-                                                  kv_caches,
-                                                  attn_metadata,
                                                   intermediate_tensors,
                                                   inputs_embeds=inputs_embeds)

View File

@@ -18,13 +18,13 @@
 # limitations under the License.
 """Inference-only BLOOM model compatible with HuggingFace weights."""
 import math
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import BloomConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -126,13 +126,11 @@ class BloomAttention(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         del position_ids  # Unused.
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.dense(attn_output)
         return output
@@ -193,8 +191,6 @@ class BloomBlock(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
@@ -209,8 +205,6 @@ class BloomBlock(nn.Module):
         attention_output = self.self_attention(
             position_ids=position_ids,
             hidden_states=layernorm_output,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         attention_output = attention_output + residual
         layernorm_output = self.post_attention_layernorm(attention_output)
@@ -266,8 +260,6 @@ class BloomModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -279,14 +271,8 @@ class BloomModel(nn.Module):
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.h[i]
-            hidden_states = layer(
-                position_ids,
-                hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
-            )
+        for layer in self.h[self.start_layer:self.end_layer]:
+            hidden_states = layer(position_ids, hidden_states)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.ln_f(hidden_states)
@@ -322,14 +308,11 @@ class BloomForCausalLM(nn.Module, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors,
-                                         inputs_embeds)
+        hidden_states = self.transformer(input_ids, positions,
+                                         intermediate_tensors, inputs_embeds)
         return hidden_states

     def compute_logits(

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from functools import cached_property from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, from typing import (Any, Dict, Iterable, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union) Tuple, TypedDict, Union)
import torch import torch
@ -10,7 +10,7 @@ import torch.nn.functional as F
from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor, from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
ChameleonVQVAEConfig) ChameleonVQVAEConfig)
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
@ -310,15 +310,13 @@ class ChameleonAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k) q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -372,8 +370,6 @@ class ChameleonDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
@ -386,8 +382,6 @@ class ChameleonDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
@ -447,8 +441,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
@ -456,8 +448,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
@ -906,8 +896,6 @@ class ChameleonModel(nn.Module):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -921,13 +909,10 @@ class ChameleonModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
@ -1028,8 +1013,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
@ -1048,8 +1031,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.model(input_ids, hidden_states = self.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
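
With the per-layer `kv_caches[i - self.start_layer]` lookup gone, the model loops above also switch from index arithmetic to iterating directly over the pipeline-parallel slice of layers. A minimal, self-contained sketch of that pattern (toy modules, not vLLM classes):

```
# Toy illustration of `for layer in self.layers[self.start_layer:self.end_layer]`.
import torch
from torch import nn


class TinyStack(nn.Module):

    def __init__(self, num_layers: int = 4, hidden_size: int = 8):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)])
        # This rank's slice of the stack; a single rank owns all layers here.
        self.start_layer, self.end_layer = 0, num_layers

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states = layer(hidden_states)
        return hidden_states


print(TinyStack()(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```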

View File

@ -2,13 +2,13 @@
# Adapted from # Adapted from
# https://github.com/THUDM/ChatGLM2-6B # https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights.""" """Inference-only ChatGLM model compatible with THUDM weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
@ -108,19 +108,11 @@ class GLMAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn( context_layer = self.attn(q, k, v)
q,
k,
v,
kv_cache,
attn_metadata,
)
attn_output, _ = self.dense(context_layer) attn_output, _ = self.dense(context_layer)
return attn_output return attn_output
@ -215,8 +207,6 @@ class GLMBlock(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# hidden_states: [num_tokens, h] # hidden_states: [num_tokens, h]
# Layer norm at the beginning of the transformer layer. # Layer norm at the beginning of the transformer layer.
@ -225,8 +215,6 @@ class GLMBlock(nn.Module):
attention_output = self.self_attention( attention_output = self.self_attention(
hidden_states=layernorm_output, hidden_states=layernorm_output,
position_ids=position_ids, position_ids=position_ids,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Residual connection. # Residual connection.
@ -289,17 +277,10 @@ class GLMTransformer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states = layer(hidden_states=hidden_states,
hidden_states = layer( position_ids=position_ids)
hidden_states=hidden_states,
position_ids=position_ids,
kv_cache=kv_caches[i - self.start_layer],
attn_metadata=attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
@ -350,8 +331,6 @@ class ChatGLMModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
@ -369,8 +348,6 @@ class ChatGLMModel(nn.Module):
hidden_states = self.encoder( hidden_states = self.encoder(
hidden_states=hidden_states, hidden_states=hidden_states,
position_ids=positions, position_ids=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
) )
return hidden_states return hidden_states
@ -494,12 +471,9 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states

View File

@ -21,14 +21,14 @@
# This file is based on the LLama model definition file in transformers # This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model.""" """PyTorch Cohere model."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from transformers import CohereConfig from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -218,8 +218,6 @@ class CohereAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@ -227,7 +225,7 @@ class CohereAttention(nn.Module):
q, k = self._apply_qk_norm(q, k) q, k = self._apply_qk_norm(q, k)
if self.v1 or self.sliding_window: if self.v1 or self.sliding_window:
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -255,8 +253,6 @@ class CohereDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
@ -265,8 +261,6 @@ class CohereDecoderLayer(nn.Module):
hidden_states_attention = self.self_attn( hidden_states_attention = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states_mlp = self.mlp(hidden_states) hidden_states_mlp = self.mlp(hidden_states)
# Add everything together # Add everything together
@ -311,8 +305,6 @@ class CohereModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -326,13 +318,10 @@ class CohereModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
@ -389,13 +378,10 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
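
Command-R's decoder layer above is a parallel block: attention and the MLP both consume the same normalized input, and their outputs are summed back into the residual ("add everything together"). A toy stand-in for that wiring, with linear layers in place of the real attention and MLP:

```
# Hypothetical stand-ins; only the parallel residual wiring is the point.
import torch
from torch import nn

hidden_size = 8
hidden_states = torch.randn(2, hidden_size)
residual = hidden_states
input_layernorm = nn.LayerNorm(hidden_size)
self_attn = nn.Linear(hidden_size, hidden_size)  # stand-in for self_attn(...)
mlp = nn.Linear(hidden_size, hidden_size)        # stand-in for mlp(...)

normed = input_layernorm(hidden_states)
hidden_states = residual + self_attn(normed) + mlp(normed)
print(hidden_states.shape)  # torch.Size([2, 8])
```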

View File

@ -1,11 +1,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
@ -230,15 +230,13 @@ class DbrxAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states) qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None: if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
hidden_states, _ = self.out_proj(attn_output) hidden_states, _ = self.out_proj(attn_output)
return hidden_states return hidden_states
@ -265,16 +263,12 @@ class DbrxFusedNormAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.norm_1(hidden_states) hidden_states = self.norm_1(hidden_states)
x = self.attn( x = self.attn(
position_ids=position_ids, position_ids=position_ids,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + x hidden_states = residual + x
residual = hidden_states residual = hidden_states
@ -303,14 +297,10 @@ class DbrxBlock(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states, residual = self.norm_attn_norm( hidden_states, residual = self.norm_attn_norm(
position_ids=position_ids, position_ids=position_ids,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = self.ffn(hidden_states) hidden_states = self.ffn(hidden_states)
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
@ -353,8 +343,6 @@ class DbrxModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -366,14 +354,8 @@ class DbrxModel(nn.Module):
else: else:
assert intermediate_tensors assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for block in self.blocks[self.start_layer:self.end_layer]:
block = self.blocks[i] hidden_states = block(position_ids, hidden_states)
hidden_states = block(
position_ids,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm_f(hidden_states) hidden_states = self.norm_f(hidden_states)
@ -415,14 +397,11 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
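
The attention bodies that remain all follow the same fused-QKV recipe: one projection, an optional clamp (DBRX's `clip_qkv`), a split into q/k/v, rotary embedding, then `self.attn(q, k, v)`. A stand-alone illustration of the split step with made-up sizes:

```
# Sizes are arbitrary; q_size/kv_size mirror the attributes used above.
import torch

num_tokens, q_size, kv_size = 3, 32, 16
qkv = torch.randn(num_tokens, q_size + 2 * kv_size)
qkv.clamp_(min=-8.0, max=8.0)  # analogous to the clip_qkv guard
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
assert q.shape == (num_tokens, q_size)
assert k.shape == (num_tokens, kv_size) and v.shape == (num_tokens, kv_size)
```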

View File

@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Deepseek model.""" """Inference-only Deepseek model."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
@ -248,13 +248,11 @@ class DeepseekAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -309,8 +307,6 @@ class DeepseekDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
@ -323,8 +319,6 @@ class DeepseekDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
@ -370,8 +364,6 @@ class DeepseekModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -384,11 +376,8 @@ class DeepseekModel(nn.Module):
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
@ -425,13 +414,10 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states

View File

@ -1,11 +1,10 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -69,8 +68,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0, spec_step_index: int = 0,
@ -88,8 +85,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
hidden_states, residual = self.mtp_block(positions=positions, hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=None) residual=None)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
return self.shared_head(hidden_states) return self.shared_head(hidden_states)
@ -122,8 +117,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0, spec_step_idx: int = 0,
@ -131,8 +124,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]( return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
input_ids, input_ids,
positions, positions,
kv_caches[spec_step_idx],
attn_metadata,
previous_hidden_states, previous_hidden_states,
inputs_embeds, inputs_embeds,
spec_step_idx, spec_step_idx,
@ -165,16 +156,14 @@ class DeepSeekMTP(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0, spec_step_idx: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions,
attn_metadata, previous_hidden_states, previous_hidden_states, inputs_embeds,
inputs_embeds, spec_step_idx) spec_step_idx)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
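
The multi-token-prediction head above keeps its predictor layers in a mapping (likely an `nn.ModuleDict`) keyed by the stringified absolute layer index, so a speculative step simply offsets into it. A small sketch of that lookup; the start index here is arbitrary:

```
# Toy lookup mirroring layers[str(mtp_start_layer_idx + spec_step_idx)].
import torch
from torch import nn

mtp_start_layer_idx = 61  # arbitrary value for illustration
layers = nn.ModuleDict({
    str(mtp_start_layer_idx + i): nn.Linear(4, 4)
    for i in range(2)
})
spec_step_idx = 1
layer = layers[str(mtp_start_layer_idx + spec_step_idx)]
print(layer(torch.randn(1, 4)).shape)  # torch.Size([1, 4])
```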

View File

@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only DeepseekV2/DeepseekV3 model.""" """Inference-only DeepseekV2/DeepseekV3 model."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group, from vllm.distributed import (get_pp_group,
@ -279,8 +279,6 @@ class DeepseekV2Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0] q = self.q_a_proj(hidden_states)[0]
@ -313,7 +311,7 @@ class DeepseekV2Attention(nn.Module):
v = torch.nn.functional.pad( v = torch.nn.functional.pad(
v, [0, self.qk_head_dim - self.v_head_dim], v, [0, self.qk_head_dim - self.v_head_dim],
value=0).view(-1, self.num_local_heads * self.qk_head_dim) value=0).view(-1, self.num_local_heads * self.qk_head_dim)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
attn_output = attn_output.view( attn_output = attn_output.view(
-1, self.num_local_heads, -1, self.num_local_heads,
self.qk_head_dim)[..., :self.v_head_dim].reshape( self.qk_head_dim)[..., :self.v_head_dim].reshape(
@ -451,8 +449,6 @@ class DeepseekV2MLAAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0] ckq = self.q_a_proj(hidden_states)[0]
@ -462,8 +458,7 @@ class DeepseekV2MLAAttention(nn.Module):
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache, return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe)
attn_metadata)
class DeepseekV2DecoderLayer(nn.Module): class DeepseekV2DecoderLayer(nn.Module):
@ -532,8 +527,6 @@ class DeepseekV2DecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
@ -546,8 +539,6 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
@ -608,8 +599,6 @@ class DeepseekV2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -624,11 +613,8 @@ class DeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
@ -665,13 +651,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
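
In the non-MLA attention path above, `v` is zero-padded from `v_head_dim` up to `qk_head_dim` so q, k and v share one head size for the unified `self.attn(q, k, v)` call, and the padding is sliced off again afterwards. A stand-alone round-trip check with toy shapes:

```
# Toy shapes; the pad/slice mirrors the DeepseekV2Attention code above.
import torch
import torch.nn.functional as F

num_tokens, num_heads, v_head_dim, qk_head_dim = 2, 4, 128, 192
v = torch.randn(num_tokens, num_heads, v_head_dim)
v_padded = F.pad(v, [0, qk_head_dim - v_head_dim],
                 value=0).view(num_tokens, num_heads * qk_head_dim)
recovered = v_padded.view(num_tokens, num_heads,
                          qk_head_dim)[..., :v_head_dim]
assert torch.equal(recovered, v)
```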

View File

@ -13,7 +13,6 @@ import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from transformers import BatchFeature from transformers import BatchFeature
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
@ -595,8 +594,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object): **kwargs: object):
@ -614,8 +611,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
hidden_states = self.language_model(input_ids, hidden_states = self.language_model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)

View File

@ -1,11 +1,10 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Tuple from typing import Iterable, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -121,8 +120,6 @@ class EAGLE(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
@ -140,8 +137,6 @@ class EAGLE(nn.Module):
input_ids=None, input_ids=None,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
) )
return hidden_states return hidden_states

View File

@ -24,12 +24,12 @@
# limitations under the License. # limitations under the License.
"""Inference-only Exaone model compatible with HuggingFace weights.""" """Inference-only Exaone model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -179,13 +179,11 @@ class ExaoneAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
return output return output
@ -225,14 +223,10 @@ class ExaoneBlockAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
return self.attention( return self.attention(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
@ -288,8 +282,6 @@ class ExaoneDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
@ -301,8 +293,6 @@ class ExaoneDecoderLayer(nn.Module):
hidden_states = self.attn( hidden_states = self.attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
@ -365,8 +355,6 @@ class ExaoneModel(nn.Module):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -381,13 +369,10 @@ class ExaoneModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.h[self.start_layer:self.end_layer]:
layer = self.h[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
@ -471,14 +456,11 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.transformer(input_ids, positions, kv_caches, model_output = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return model_output return model_output
def compute_logits( def compute_logits(

View File

@ -20,14 +20,14 @@
"""PyTorch Falcon model.""" """PyTorch Falcon model."""
import math import math
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig from transformers import FalconConfig as HF_FalconConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@ -190,8 +190,6 @@ class FalconAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, bias = self.query_key_value(hidden_states) qkv, bias = self.query_key_value(hidden_states)
if bias is not None: if bias is not None:
@ -199,7 +197,7 @@ class FalconAttention(nn.Module):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_rotary: if self.use_rotary:
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
attn_output, bias = self.dense(attn_output) attn_output, bias = self.dense(attn_output)
return attn_output, bias return attn_output, bias
@ -291,8 +289,6 @@ class FalconDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
@ -306,8 +302,6 @@ class FalconDecoderLayer(nn.Module):
attention_output, attention_bias = self.self_attention( attention_output, attention_bias = self.self_attention(
positions=positions, positions=positions,
hidden_states=attention_layernorm_out, hidden_states=attention_layernorm_out,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
if self.reduce_row_parallel_results and attention_bias is not None: if self.reduce_row_parallel_results and attention_bias is not None:
attention_output += attention_bias attention_output += attention_bias
@ -384,8 +378,6 @@ class FalconModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -396,14 +388,8 @@ class FalconModel(nn.Module):
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.get_input_embeddings(input_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.h[self.start_layer:self.end_layer]:
layer = self.h[i] hidden_states = layer(positions, hidden_states)
hidden_states = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
@ -450,14 +436,11 @@ class FalconForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(

View File

@ -1,12 +1,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@ -50,8 +49,7 @@ class Florence2LanguageModel(nn.Module):
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor, encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], encoder_positions: torch.Tensor) -> torch.Tensor:
attn_metadata: AttentionMetadata) -> torch.Tensor:
r""" r"""
Args: Args:
input_ids input_ids
@ -64,10 +62,6 @@ class Florence2LanguageModel(nn.Module):
Indices of *encoder* input sequence tokens in the vocabulary. Indices of *encoder* input sequence tokens in the vocabulary.
encoder_positions: encoder_positions:
Positions of *encoder* input sequence tokens. Positions of *encoder* input sequence tokens.
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns: Returns:
Model output torch.Tensor Model output torch.Tensor
""" """
@ -78,18 +72,14 @@ class Florence2LanguageModel(nn.Module):
# Run encoder attention if a non-zero number of encoder tokens # Run encoder attention if a non-zero number of encoder tokens
# are provided as input # are provided as input
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
positions=encoder_positions, positions=encoder_positions)
kv_caches=kv_caches,
attn_metadata=attn_metadata)
# decoder outputs consists of # decoder outputs consists of
# (dec_features, past_key_value, dec_hidden, dec_attn) # (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder( decoder_outputs = self.decoder(
decoder_input_ids=input_ids, decoder_input_ids=input_ids,
decoder_positions=positions, decoder_positions=positions,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states)
kv_caches=kv_caches,
attn_metadata=attn_metadata)
return decoder_outputs return decoder_outputs
@ -122,8 +112,6 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor, encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor, encoder_positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
r""" r"""
@ -136,15 +124,11 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
torch.Tensor of *encoder* input token ids. torch.Tensor of *encoder* input token ids.
encoder_positions encoder_positions
torch.Tensor of *encoder* position indices torch.Tensor of *encoder* position indices
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns: Returns:
Output torch.Tensor Output torch.Tensor
""" """
return self.model(input_ids, positions, encoder_input_ids, return self.model(input_ids, positions, encoder_input_ids,
encoder_positions, kv_caches, attn_metadata) encoder_positions)
def compute_logits( def compute_logits(
self, self,
@ -213,8 +197,6 @@ class Florence2ForConditionalGeneration(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
*, *,
encoder_input_ids: torch.Tensor, encoder_input_ids: torch.Tensor,
@ -231,15 +213,11 @@ class Florence2ForConditionalGeneration(nn.Module):
torch.Tensor of *encoder* input token ids. torch.Tensor of *encoder* input token ids.
encoder_positions encoder_positions
torch.Tensor of *encoder* position indices torch.Tensor of *encoder* position indices
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns: Returns:
Output torch.Tensor Output torch.Tensor
""" """
return self.language_model(input_ids, positions, encoder_input_ids, return self.language_model(input_ids, positions, encoder_input_ids,
encoder_positions, kv_caches, attn_metadata) encoder_positions)
def compute_logits( def compute_logits(
self, self,
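
Florence-2 keeps the usual encoder-decoder control flow; it simply no longer threads cache arguments through it: run the encoder once, then hand `encoder_hidden_states` to the decoder. A toy, framework-free sketch of that order (hypothetical modules, not the vLLM classes):

```
# Purely illustrative encode-then-decode wiring.
import torch
from torch import nn


class ToyEncoderDecoder(nn.Module):

    def __init__(self, vocab_size: int = 100, hidden_size: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.encoder = nn.Linear(hidden_size, hidden_size)
        self.decoder = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self, input_ids: torch.Tensor,
                encoder_input_ids: torch.Tensor) -> torch.Tensor:
        encoder_hidden_states = self.encoder(self.embed(encoder_input_ids))
        decoder_in = torch.cat(
            [self.embed(input_ids), encoder_hidden_states], dim=-1)
        return self.decoder(decoder_in)


out = ToyEncoderDecoder()(torch.tensor([[1, 2, 3]]), torch.tensor([[4, 5, 6]]))
print(out.shape)  # torch.Size([1, 3, 16])
```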

View File

@ -25,7 +25,6 @@ import torch.nn as nn
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor, from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
FuyuProcessor) FuyuProcessor)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
@ -351,8 +350,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
@ -371,8 +368,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
hidden_states = self.language_model( hidden_states = self.language_model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )

View File

@ -16,13 +16,13 @@
# limitations under the License. # limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights.""" """Inference-only Gemma model compatible with HuggingFace weights."""
from functools import cache from functools import cache
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import GemmaConfig from transformers import GemmaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -183,13 +183,11 @@ class GemmaAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -233,8 +231,6 @@ class GemmaDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
@ -247,8 +243,6 @@ class GemmaDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
@ -298,8 +292,6 @@ class GemmaModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -313,13 +305,10 @@ class GemmaModel(nn.Module):
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
@ -370,13 +359,10 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states

View File

@ -15,13 +15,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import Gemma2Config from transformers import Gemma2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -164,13 +164,11 @@ class Gemma2Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
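
The context lines above also show the fused-QKV pattern these attention layers rely on: a single projection emits `[q | k | v]`, which is then split by head counts (with grouped-query attention, the KV heads are fewer than the query heads). A tiny illustration with made-up sizes:

```
# Illustrative fused-QKV split; sizes are invented for the example.
import torch

num_heads, num_kv_heads, head_size = 8, 2, 16
q_size = num_heads * head_size        # 128
kv_size = num_kv_heads * head_size    # 32

qkv = torch.randn(4, q_size + 2 * kv_size)  # output of a fused qkv projection
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
assert q.shape[-1] == q_size and k.shape[-1] == kv_size
```
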
@ -220,8 +218,6 @@ class Gemma2DecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if residual is None: if residual is None:
@ -233,8 +229,6 @@ class Gemma2DecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.post_attention_layernorm(hidden_states)
@ -284,8 +278,6 @@ class Gemma2Model(nn.Module):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -300,13 +292,10 @@ class Gemma2Model(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
@ -415,13 +404,10 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states

View File

@ -4,7 +4,7 @@
# https://github.com/THUDM/CogAgent # https://github.com/THUDM/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights.""" """Inference-only CogAgent model compatible with THUDM weights."""
from argparse import Namespace from argparse import Namespace
from typing import List, Literal, Mapping, Optional, TypedDict, Union from typing import Literal, Mapping, Optional, TypedDict, Union
import torch import torch
from torch import nn from torch import nn
@ -15,7 +15,6 @@ from transformers import PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
from vllm.attention import AttentionMetadata
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
@ -628,8 +627,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
@ -645,8 +642,7 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
vision_embeddings) vision_embeddings)
input_ids = None input_ids = None
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states

View File

@ -18,13 +18,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights.""" """Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import GPT2Config from transformers import GPT2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
@ -92,12 +92,10 @@ class GPT2Attention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
attn_output, _ = self.c_proj(attn_output) attn_output, _ = self.c_proj(attn_output)
return attn_output return attn_output
@ -164,16 +162,10 @@ class GPT2Block(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.ln_1(hidden_states) hidden_states = self.ln_1(hidden_states)
attn_output = self.attn( attn_output = self.attn(hidden_states=hidden_states)
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# residual connection # residual connection
hidden_states = attn_output + residual hidden_states = attn_output + residual
@ -222,8 +214,6 @@ class GPT2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor], inputs_embeds: Optional[torch.Tensor],
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -236,11 +226,8 @@ class GPT2Model(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.h[self.start_layer:self.end_layer]:
layer = self.h[i] hidden_states = layer(hidden_states)
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
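
With the per-layer `kv_caches[i - self.start_layer]` lookup gone, the pipeline-parallel loop in the GPT-2 hunk above no longer needs index arithmetic and can slice the `nn.ModuleList` directly. A self-contained sketch of that pattern, with toy module names and a made-up layer range:

```
# Toy illustration of slicing an nn.ModuleList for one pipeline-parallel rank.
# ToyBlock, start_layer and end_layer are invented names for this sketch.
import torch
from torch import nn


class ToyBlock(nn.Module):

    def __init__(self, hidden_size: int):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Trivial residual block standing in for a transformer layer.
        return hidden_states + self.ln(hidden_states)


h = nn.ModuleList(ToyBlock(16) for _ in range(8))
start_layer, end_layer = 2, 6  # the slice of layers owned by this rank
hidden_states = torch.randn(3, 16)

# nn.ModuleList supports slicing, so the loop needs no explicit indices.
for layer in h[start_layer:end_layer]:
    hidden_states = layer(hidden_states)
```
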
@ -279,14 +266,11 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(

View File

@ -19,13 +19,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" """Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import GPTBigCodeConfig from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -101,8 +101,6 @@ class GPTBigCodeAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split( q, k, v = qkv.split(
@ -112,7 +110,7 @@ class GPTBigCodeAttention(nn.Module):
], ],
dim=-1, dim=-1,
) )
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
attn_output, _ = self.c_proj(attn_output) attn_output, _ = self.c_proj(attn_output)
return attn_output return attn_output
@ -173,16 +171,10 @@ class GPTBigCodeBlock(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.ln_1(hidden_states) hidden_states = self.ln_1(hidden_states)
attn_output = self.attn( attn_output = self.attn(hidden_states=hidden_states, )
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# residual connection # residual connection
hidden_states = attn_output + residual hidden_states = attn_output + residual
@ -234,8 +226,6 @@ class GPTBigCodeModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -246,11 +236,8 @@ class GPTBigCodeModel(nn.Module):
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.h[self.start_layer:self.end_layer]:
layer = self.h[i] hidden_states = layer(hidden_states)
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
@ -302,14 +289,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(

View File

@ -17,13 +17,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights.""" """Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import GPTJConfig from transformers import GPTJConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -104,13 +104,11 @@ class GPTJAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
attn_output, _ = self.out_proj(attn_output) attn_output, _ = self.out_proj(attn_output)
return attn_output return attn_output
@ -167,16 +165,12 @@ class GPTJBlock(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.ln_1(hidden_states) hidden_states = self.ln_1(hidden_states)
attn_output = self.attn( attn_output = self.attn(
position_ids=position_ids, position_ids=position_ids,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
mlp_output = self.mlp(hidden_states) mlp_output = self.mlp(hidden_states)
hidden_states = attn_output + mlp_output + residual hidden_states = attn_output + mlp_output + residual
@ -217,8 +211,6 @@ class GPTJModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -229,14 +221,8 @@ class GPTJModel(nn.Module):
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.get_input_embeddings(input_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.h[self.start_layer:self.end_layer]:
layer = self.h[i] hidden_states = layer(position_ids, hidden_states)
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
@ -273,14 +259,11 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(

View File

@ -17,13 +17,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import GPTNeoXConfig from transformers import GPTNeoXConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -104,13 +104,11 @@ class GPTNeoXAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.dense(attn_output) output, _ = self.dense(attn_output)
return output return output
@ -167,15 +165,11 @@ class GPTNeoXLayer(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
attn_input = self.input_layernorm(hidden_states) attn_input = self.input_layernorm(hidden_states)
attn_output = self.attention( attn_output = self.attention(
position_ids=position_ids, position_ids=position_ids,
hidden_states=attn_input, hidden_states=attn_input,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
if self.use_parallel_residual: if self.use_parallel_residual:
@ -230,8 +224,6 @@ class GPTNeoXModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -242,14 +234,8 @@ class GPTNeoXModel(nn.Module):
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.get_input_embeddings(input_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states = layer(position_ids, hidden_states)
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
@ -285,14 +271,11 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches, hidden_states = self.gpt_neox(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(

View File

@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only IBM Granite model compatible with HuggingFace weights.""" """Inference-only IBM Granite model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import GraniteConfig from transformers import GraniteConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -166,13 +166,11 @@ class GraniteAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -233,8 +231,6 @@ class GraniteDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
@ -242,8 +238,6 @@ class GraniteDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states * self.residual_multiplier hidden_states = residual + hidden_states * self.residual_multiplier
# Fully Connected # Fully Connected
@ -300,8 +294,6 @@ class GraniteModel(nn.Module):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -318,14 +310,8 @@ class GraniteModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states = layer(positions, hidden_states)
hidden_states = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
@ -405,13 +391,10 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches, model_output = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return model_output return model_output

View File

@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GraniteMoe model.""" """Inference-only GraniteMoe model."""
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers.models.granitemoe import GraniteMoeConfig from transformers.models.granitemoe import GraniteMoeConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -173,13 +173,11 @@ class GraniteMoeAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -226,8 +224,6 @@ class GraniteMoeDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
@ -235,8 +231,6 @@ class GraniteMoeDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states * self.residual_multiplier hidden_states = residual + hidden_states * self.residual_multiplier
residual = hidden_states residual = hidden_states
@ -287,8 +281,6 @@ class GraniteMoeModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
@ -303,11 +295,8 @@ class GraniteMoeModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states = layer(positions, hidden_states)
hidden_states = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
@ -377,13 +366,10 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states

View File

@ -1,15 +1,15 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from array import array from array import array
from typing import List, Optional, Union from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
from vllm.attention import AttentionMetadata
from vllm.attention.backends.xformers import XFormersImpl from vllm.attention.backends.xformers import XFormersImpl
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolerHead from vllm.model_executor.layers.pooler import PoolerHead
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
@ -217,13 +217,12 @@ class GritLM(LlamaForCausalLM):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs, **kwargs,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
# Change attention to non-causal for pooling tasks. # Change attention to non-causal for pooling tasks.
if self.runner_type == "pooling": if self.runner_type == "pooling":
attn_metadata = get_forward_context().attn_metadata
assert attn_metadata.prefill_metadata.attn_bias is None assert attn_metadata.prefill_metadata.attn_bias is None
attn_metadata.prefill_metadata.attn_bias = [ attn_metadata.prefill_metadata.attn_bias = [
BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens) BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)
@ -232,8 +231,6 @@ class GritLM(LlamaForCausalLM):
return super().forward( return super().forward(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
**kwargs, **kwargs,
) )
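
GritLM is the one model in these hunks that still needs the attention metadata (to make prefill attention non-causal for pooling). Instead of receiving it as an argument, it now fetches it inside `forward()` from the forward context, as the added import and assignment above show. A condensed sketch of just that retrieval, assuming it runs inside a forward pass where vLLM has already set the context for the step:

```
# Condensed from the GritLM change above; only meaningful inside a forward
# pass where the forward context has been set for this step.
from vllm.forward_context import get_forward_context
from xformers.ops.fmha.attn_bias import BlockDiagonalMask


def apply_pooling_bias() -> None:
    """Make prefill attention non-causal via a block-diagonal mask."""
    attn_metadata = get_forward_context().attn_metadata
    assert attn_metadata.prefill_metadata.attn_bias is None
    attn_metadata.prefill_metadata.attn_bias = [
        BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)
    ]
```
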

View File

@ -25,7 +25,6 @@ from torch import nn
from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor, from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
Idefics3Processor) Idefics3Processor)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
@ -563,8 +562,6 @@ class Idefics3Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -572,8 +569,6 @@ class Idefics3Model(nn.Module):
hidden_states = self.text_model( hidden_states = self.text_model(
input_ids, input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
@ -645,8 +640,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
@ -664,8 +657,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.model.text_model(input_ids, hidden_states = self.model.text_model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union, from typing import (TYPE_CHECKING, Optional, Protocol, Type, Union, overload,
overload, runtime_checkable) runtime_checkable)
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -11,7 +11,6 @@ from vllm.logger import init_logger
from vllm.utils import supports_kw from vllm.utils import supports_kw
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolerOutput from vllm.model_executor.layers.pooler import PoolerOutput
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
@ -46,8 +45,6 @@ class VllmModel(Protocol[T_co]):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: "AttentionMetadata",
) -> T_co: ) -> T_co:
... ...
@ -62,7 +59,7 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
if not callable(model_forward): if not callable(model_forward):
return False return False
vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata") vllm_kws = ("input_ids", "positions")
missing_kws = tuple(kw for kw in vllm_kws missing_kws = tuple(kw for kw in vllm_kws
if not supports_kw(model_forward, kw)) if not supports_kw(model_forward, kw))
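
After this change the interface check only requires that a model's `forward()` accept the `input_ids` and `positions` keywords. For reference, a simplified stand-in for that kind of keyword-support check; the real helper is `vllm.utils.supports_kw`, and the version below only illustrates the idea:

```
# Simplified sketch of a keyword-support check, not vllm.utils.supports_kw.
import inspect
from typing import Callable


def supports_kw(callable_: Callable, kw_name: str) -> bool:
    """Return True if `callable_` accepts `kw_name` as a keyword argument."""
    params = inspect.signature(callable_).parameters
    if kw_name in params:
        kind = params[kw_name].kind
        return kind not in (inspect.Parameter.POSITIONAL_ONLY,
                            inspect.Parameter.VAR_POSITIONAL)
    # A **kwargs catch-all also accepts arbitrary keyword names.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD
               for p in params.values())


def looks_like_vllm_forward(model_forward: Callable) -> bool:
    """Mirror of the trimmed check above: only these keywords are required."""
    vllm_kws = ("input_ids", "positions")
    missing_kws = tuple(kw for kw in vllm_kws
                        if not supports_kw(model_forward, kw))
    return len(missing_kws) == 0
```
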

View File

@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from functools import partial from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@ -175,13 +175,11 @@ class InternLM2Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.wqkv(hidden_states) qkv, _ = self.wqkv(hidden_states)
q, k, v = self.split_qkv(qkv) q, k, v = self.split_qkv(qkv)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.wo(attn_output) output, _ = self.wo(attn_output)
return output return output
@ -227,8 +225,6 @@ class InternLMDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
@ -241,8 +237,6 @@ class InternLMDecoderLayer(nn.Module):
hidden_states = self.attention( hidden_states = self.attention(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
@ -290,8 +284,6 @@ class InternLM2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -305,15 +297,8 @@ class InternLM2Model(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
@ -363,13 +348,10 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
@ -466,13 +448,10 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
logits, _ = self.v_head(hidden_states) logits, _ = self.v_head(hidden_states)
return logits return logits

View File

@ -1,12 +1,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -65,8 +64,6 @@ class InternLM2VEDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
visual_token_mask: Optional[torch.Tensor] = None, visual_token_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
@ -80,8 +77,6 @@ class InternLM2VEDecoderLayer(nn.Module):
hidden_states = self.attention( hidden_states = self.attention(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
@ -113,8 +108,6 @@ class InternLM2VEModel(InternLM2Model):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
visual_token_mask: Optional[torch.Tensor] = None, visual_token_mask: Optional[torch.Tensor] = None,
@ -129,13 +122,10 @@ class InternLM2VEModel(InternLM2Model):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
visual_token_mask=visual_token_mask, visual_token_mask=visual_token_mask,
) )

View File

@ -17,7 +17,6 @@ import torchvision.transforms as T
from PIL import Image from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
@ -929,8 +928,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
@ -951,8 +948,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
forward_kwargs = { forward_kwargs = {
"input_ids": input_ids, "input_ids": input_ids,
"positions": positions, "positions": positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
"intermediate_tensors": intermediate_tensors, "intermediate_tensors": intermediate_tensors,
"inputs_embeds": inputs_embeds, "inputs_embeds": inputs_embeds,
} }
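
For multimodal wrappers such as InternVL, the call into the language backbone is assembled as a kwargs dict, and that dict simply loses the two removed entries. A toy sketch of the dispatch pattern, with hypothetical names; the conditional merge at the end stands in for whatever extra inputs a given model adds:

```
# Toy sketch (hypothetical names) of assembling backbone kwargs after the
# removal: only arguments the backbone still accepts are included.
from typing import Any, Dict, Optional

import torch


def build_forward_kwargs(
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[Any],
    inputs_embeds: Optional[torch.Tensor],
    visual_token_mask: Optional[torch.Tensor] = None,
) -> Dict[str, Any]:
    forward_kwargs: Dict[str, Any] = {
        "input_ids": input_ids,
        "positions": positions,
        "intermediate_tensors": intermediate_tensors,
        "inputs_embeds": inputs_embeds,
    }
    # Model-specific extras are merged in only when present, keeping the
    # backbone signature minimal.
    if visual_token_mask is not None:
        forward_kwargs["visual_token_mask"] = visual_token_mask
    return forward_kwargs
```
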

View File

@ -21,12 +21,12 @@
"""Inference-only Jais model compatible with HuggingFace weights.""" """Inference-only Jais model compatible with HuggingFace weights."""
import math import math
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@ -123,12 +123,10 @@ class JAISAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
attn_output, _ = self.c_proj(attn_output) attn_output, _ = self.c_proj(attn_output)
return attn_output return attn_output
@ -200,16 +198,10 @@ class JAISBlock(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.ln_1(hidden_states) hidden_states = self.ln_1(hidden_states)
attn_output = self.attn( attn_output = self.attn(hidden_states=hidden_states, )
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# residual connection # residual connection
hidden_states = attn_output + residual hidden_states = attn_output + residual
@ -266,8 +258,6 @@ class JAISModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[IntermediateTensors, torch.Tensor]: ) -> Union[IntermediateTensors, torch.Tensor]:
@ -285,11 +275,8 @@ class JAISModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.h[self.start_layer:self.end_layer]:
layer = self.h[i] hidden_states = layer(hidden_states)
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
@ -332,14 +319,11 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[IntermediateTensors, torch.Tensor]: ) -> Union[IntermediateTensors, torch.Tensor]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(

View File

@ -1,12 +1,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Inference-only Jamba model.""" """Inference-only Jamba model."""
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import JambaConfig from transformers import JambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
@ -138,7 +137,6 @@ class JambaMambaDecoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
**kwargs, **kwargs,
@ -150,8 +148,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
hidden_states = self.mamba(hidden_states, attn_metadata, hidden_states = self.mamba(hidden_states, mamba_cache_params)
mamba_cache_params)
# Fully Connected # Fully Connected
hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual) hidden_states, residual)
@ -223,13 +220,11 @@ class JambaAttentionDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -237,8 +232,6 @@ class JambaAttentionDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
**kwargs, **kwargs,
): ):
@ -252,8 +245,6 @@ class JambaAttentionDecoderLayer(nn.Module):
hidden_states = self.self_attention( hidden_states = self.self_attention(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual = self.pre_ff_layernorm(
@ -320,8 +311,6 @@ class JambaModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
@ -339,12 +328,9 @@ class JambaModel(nn.Module):
kv_cache_index = 0 kv_cache_index = 0
mamba_cache_index = 0 mamba_cache_index = 0
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
kv_cache = None
layer_mamba_cache_params = None layer_mamba_cache_params = None
if isinstance(layer, JambaAttentionDecoderLayer): if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache = kv_caches[kv_cache_index]
kv_cache_index += 1 kv_cache_index += 1
if isinstance(layer, JambaMambaDecoderLayer): if isinstance(layer, JambaMambaDecoderLayer):
current_state_layer = mamba_cache_index current_state_layer = mamba_cache_index
@@ -355,8 +341,6 @@ class JambaModel(nn.Module):
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=residual, residual=residual,
mamba_cache_params=layer_mamba_cache_params) mamba_cache_params=layer_mamba_cache_params)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
@@ -429,8 +413,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
@@ -443,8 +425,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, mamba_cache_params,
attn_metadata, mamba_cache_params,
intermediate_tensors, inputs_embeds) intermediate_tensors, inputs_embeds)
return hidden_states return hidden_states
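
The Jamba hunks above all follow the same pattern: the per-layer `kv_caches` list and the `attn_metadata` argument disappear from every signature, and the layer loop iterates the modules directly instead of indexing by position. A condensed sketch of the updated `JambaModel.forward` loop, assembled from the hunks above; the `at_layer_idx` selection is assumed by analogy with the Mamba hunks further down, everything else appears verbatim in the diff:

```
for layer in self.layers[self.start_layer:self.end_layer]:
    layer_mamba_cache_params = None
    if isinstance(layer, JambaAttentionDecoderLayer):
        # No kv_caches[kv_cache_index] lookup anymore; the Attention module
        # owns its cache. The counter survives only as bookkeeping.
        kv_cache_index += 1
    if isinstance(layer, JambaMambaDecoderLayer):
        current_state_layer = mamba_cache_index
        # Assumed: the Mamba state is selected as in the mamba model below.
        layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
            current_state_layer)
        mamba_cache_index += 1
    hidden_states, residual = layer(
        positions=positions,
        hidden_states=hidden_states,
        residual=residual,
        mamba_cache_params=layer_mamba_cache_params)
```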


@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
import torch import torch
from torch import nn from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -197,13 +197,11 @@ class LlamaAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@@ -268,8 +266,6 @@ class LlamaDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
@@ -280,9 +276,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
hidden_states = self.self_attn(positions=positions, hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states)
kv_cache=kv_cache,
attn_metadata=attn_metadata)
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
@@ -347,8 +341,6 @@ class LlamaModel(nn.Module):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -363,11 +355,8 @@ class LlamaModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
@@ -535,13 +524,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches, model_output = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return model_output return model_output
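
The LLaMA hunks are the canonical example of the new call path: `Attention` is invoked with only the query, key and value tensors, and neither the decoder layers nor `LlamaModel`/`LlamaForCausalLM` accept `kv_caches` or `attn_metadata` any longer. Reassembled from the hunks above, the updated `LlamaAttention.forward` reads:

```
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    # Cache and metadata are resolved inside the Attention layer itself.
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output
```

The model loop shrinks accordingly to `for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer(positions, hidden_states, residual)`, with no per-layer cache indexing.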


@@ -15,7 +15,6 @@ from transformers import __version__ as TRANSFORMERS_VERSION
from transformers.models.llava import LlavaProcessor from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor from transformers.models.pixtral import PixtralProcessor
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
@@ -658,8 +657,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
@@ -712,8 +709,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)


@@ -12,7 +12,6 @@ from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image) get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -508,8 +507,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
@@ -571,8 +568,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states


@@ -10,7 +10,6 @@ import torch.nn as nn
from transformers import (BatchFeature, LlavaNextVideoConfig, from transformers import (BatchFeature, LlavaNextVideoConfig,
LlavaNextVideoProcessor) LlavaNextVideoProcessor)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -443,8 +442,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
@@ -468,8 +465,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)


@@ -13,7 +13,6 @@ from transformers.models.llava_onevision.modeling_llava_onevision import (
get_anyres_image_grid_shape, unpad_image) get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -922,8 +921,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
@@ -955,8 +952,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)


@@ -1,12 +1,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA model.""" """PyTorch MAMBA model."""
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import MambaConfig from transformers import MambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
@@ -64,7 +63,6 @@ class MambaDecoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
**kwargs, **kwargs,
@@ -75,8 +73,7 @@ class MambaDecoderLayer(nn.Module):
else: else:
hidden_states, residual = self.norm(hidden_states, residual) hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, attn_metadata, hidden_states = self.mixer(hidden_states, mamba_cache_params)
mamba_cache_params)
return hidden_states, residual return hidden_states, residual
@@ -125,7 +122,6 @@ class MambaModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
@@ -146,7 +142,6 @@ class MambaModel(nn.Module):
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual, residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx( mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer)) i - self.start_layer))
@@ -208,8 +203,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
@@ -222,9 +215,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.backbone(input_ids, positions, attn_metadata, hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
mamba_cache_params, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states


@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA2 model.""" """PyTorch MAMBA2 model."""
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
@@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import ( from vllm.model_executor.layers.mamba.mamba_mixer2 import (
@@ -63,7 +64,6 @@ class Mamba2DecoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor], sequence_idx: Optional[torch.Tensor],
@@ -75,8 +75,8 @@ class Mamba2DecoderLayer(nn.Module):
else: else:
hidden_states, residual = self.norm(hidden_states, residual) hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, attn_metadata, hidden_states = self.mixer(hidden_states, mamba_cache_params,
mamba_cache_params, sequence_idx) sequence_idx)
return hidden_states, residual return hidden_states, residual
@@ -122,7 +122,6 @@ class Mamba2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
@@ -142,6 +141,7 @@ class Mamba2Model(nn.Module):
# proper continuous batching computation including # proper continuous batching computation including
# chunked prefill # chunked prefill
seq_idx = None seq_idx = None
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
if attn_metadata.num_prefills > 0: if attn_metadata.num_prefills > 0:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate( for i, (srt, end) in enumerate(
@@ -158,7 +158,6 @@ class Mamba2Model(nn.Module):
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual, residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx( mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer), i - self.start_layer),
@@ -224,8 +223,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
@@ -238,9 +235,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.backbone(input_ids, positions, attn_metadata, hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
mamba_cache_params, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
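
Models that still need the metadata themselves (here Mamba2, which builds a `seq_idx` tensor for chunked prefill) now fetch it from the forward context instead of taking it as a parameter. A condensed sketch of that portion of `Mamba2Model.forward`, based on the hunks above; the per-request population of `seq_idx` is abbreviated:

```
from vllm.forward_context import get_forward_context

# Inside Mamba2Model.forward (condensed):
seq_idx = None
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
if attn_metadata.num_prefills > 0:
    seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
    # ... fill seq_idx per prefill request, as in the original loop ...
```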


@@ -23,13 +23,13 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights.""" """Inference-only MiniCPM model compatible with HuggingFace weights."""
import math import math
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -257,8 +257,6 @@ class MiniCPMAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -266,7 +264,7 @@ class MiniCPMAttention(nn.Module):
q, k = q.float(), k.float() q, k = q.float(), k.float()
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
q, k = q.to(orig_dtype), k.to(orig_dtype) q, k = q.to(orig_dtype), k.to(orig_dtype)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@@ -331,8 +329,6 @@ class MiniCPMDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
@@ -341,8 +337,6 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states * \ hidden_states = residual + hidden_states * \
(self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))
@@ -409,8 +403,6 @@ class MiniCPMModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -424,13 +416,10 @@ class MiniCPMModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
@@ -579,13 +568,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states


@@ -29,7 +29,7 @@ import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@@ -129,8 +129,6 @@ class MiniCPM3Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
q, _ = self.q_a_proj(hidden_states) q, _ = self.q_a_proj(hidden_states)
q = self.q_a_layernorm(q) q = self.q_a_layernorm(q)
@@ -170,7 +168,7 @@ class MiniCPM3Attention(nn.Module):
v, [0, self.qk_head_dim - self.v_head_dim], v, [0, self.qk_head_dim - self.v_head_dim],
value=0).view(-1, self.num_local_heads * self.qk_head_dim) value=0).view(-1, self.num_local_heads * self.qk_head_dim)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
attn_output = attn_output.view( attn_output = attn_output.view(
-1, self.num_local_heads, -1, self.num_local_heads,
self.qk_head_dim)[..., :self.v_head_dim].reshape( self.qk_head_dim)[..., :self.v_head_dim].reshape(


@@ -33,7 +33,6 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.whisper.modeling_whisper import ( from transformers.models.whisper.modeling_whisper import (
ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder) ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.multimodal.inputs import MultiModalFieldConfig
@@ -792,8 +791,6 @@ class MiniCPMO(MiniCPMV2_6):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
@@ -818,8 +815,6 @@ class MiniCPMO(MiniCPMV2_6):
output = self.llm.model( output = self.llm.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=vlm_embeddings, inputs_embeds=vlm_embeddings,
) )


@@ -37,7 +37,6 @@ from torch import nn
from transformers import BatchFeature, PretrainedConfig from transformers import BatchFeature, PretrainedConfig
from typing_extensions import TypeVar from typing_extensions import TypeVar
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
@@ -1030,8 +1029,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
@@ -1051,8 +1048,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
output = self.llm.model( output = self.llm.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=vlm_embeddings, inputs_embeds=vlm_embeddings,
) )


@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import MixtralConfig from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -175,13 +175,11 @@ class MixtralAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@@ -224,8 +222,6 @@ class MixtralDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
@@ -238,8 +234,6 @@ class MixtralDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
@@ -291,8 +285,6 @@ class MixtralModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -306,11 +298,8 @@ class MixtralModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
@@ -377,13 +366,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states


@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@@ -30,7 +30,7 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import MixtralConfig from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
@@ -229,13 +229,11 @@ class MixtralAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@@ -274,8 +272,6 @@ class MixtralDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
@@ -288,8 +284,6 @@ class MixtralDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
@@ -333,8 +327,6 @@ class MixtralModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -348,11 +340,8 @@ class MixtralModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
@@ -390,13 +379,10 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states


@@ -38,7 +38,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.selector import _Backend from vllm.attention.selector import _Backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -416,11 +417,11 @@ class MllamaVisionSdpaAttention(nn.Module):
prefix: str = ""): prefix: str = ""):
super().__init__() super().__init__()
model_parallel_size = get_tensor_model_parallel_world_size() tensor_parallel_size = get_tp_group().world_size
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.num_heads = config.attention_heads self.num_heads = config.attention_heads
self.head_dim = config.hidden_size // config.attention_heads self.head_dim = config.hidden_size // config.attention_heads
self.num_local_heads = self.num_heads // model_parallel_size self.num_local_heads = self.num_heads // tensor_parallel_size
self.q_size = self.num_local_heads * self.head_dim self.q_size = self.num_local_heads * self.head_dim
self.kv_size = self.num_local_heads * self.head_dim self.kv_size = self.num_local_heads * self.head_dim
@@ -771,12 +772,13 @@ class MllamaTextCrossAttention(nn.Module):
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.model_parallel_size = get_tensor_model_parallel_world_size() self.pipeline_parallel_rank = get_pp_group().rank_in_group
self.tensor_parallel_size = get_tp_group().world_size
self.num_heads = self.config.num_attention_heads self.num_heads = self.config.num_attention_heads
self.num_local_heads = self.num_heads // self.model_parallel_size self.num_local_heads = self.num_heads // self.tensor_parallel_size
self.num_key_value_heads = self.config.num_key_value_heads self.num_key_value_heads = self.config.num_key_value_heads
self.num_local_key_value_heads = \ self.num_local_key_value_heads = \
self.num_key_value_heads // self.model_parallel_size self.num_key_value_heads // self.tensor_parallel_size
self.dropout = config.dropout self.dropout = config.dropout
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.head_dim = config.hidden_size // self.num_heads self.head_dim = config.hidden_size // self.num_heads
@@ -824,8 +826,6 @@ class MllamaTextCrossAttention(nn.Module):
attention_mask: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]], kv_range_for_decode: Optional[List[Tuple[int, int]]],
cross_attention_states: Optional[torch.Tensor], cross_attention_states: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv_dec, _ = self.qkv_proj(hidden_states) qkv_dec, _ = self.qkv_proj(hidden_states)
q, _, _ = qkv_dec.split( q, _, _ = qkv_dec.split(
@@ -846,14 +846,11 @@ class MllamaTextCrossAttention(nn.Module):
q = self.q_norm(q) q = self.q_norm(q)
if attention_mask is not None: if attention_mask is not None:
output = self._attention_with_mask(q, k, v, kv_cache, output = self._attention_with_mask(q, k, v, attention_mask,
attention_mask, kv_range_for_decode)
kv_range_for_decode,
attn_metadata)
else: else:
output = self.attn( output = self.attn(
q.view(-1, self.num_local_heads * self.head_dim), k, v, q.view(-1, self.num_local_heads * self.head_dim), k, v)
kv_cache, attn_metadata)
out, _ = self.o_proj(output) out, _ = self.o_proj(output)
return out return out
@@ -862,11 +859,11 @@ class MllamaTextCrossAttention(nn.Module):
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
kv_cache: torch.Tensor,
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
kv_range_for_decode: List[Tuple[int, int]], kv_range_for_decode: List[Tuple[int, int]],
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank]
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
# Skip writing kv-cache for the initial profiling run. # Skip writing kv-cache for the initial profiling run.
if len(kv_cache.shape) > 1: if len(kv_cache.shape) > 1:
i = torch.ones(1, dtype=torch.float32) i = torch.ones(1, dtype=torch.float32)
@@ -978,8 +975,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
cross_attention_mask: torch.Tensor, cross_attention_mask: torch.Tensor,
kv_range_for_decode: Optional[List[Tuple[int, int]]], kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: torch.Tensor, full_text_row_masked_out_mask: torch.Tensor,
kv_cache: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
@@ -989,8 +984,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
attention_mask=cross_attention_mask, attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode, kv_range_for_decode=kv_range_for_decode,
cross_attention_states=cross_attention_states, cross_attention_states=cross_attention_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = full_text_row_masked_out_mask * hidden_states hidden_states = full_text_row_masked_out_mask * hidden_states
hidden_states = residual + self.cross_attn_attn_gate.tanh( hidden_states = residual + self.cross_attn_attn_gate.tanh(
@@ -1054,14 +1047,12 @@ class MllamaTextModel(nn.Module):
kv_range_for_decode: Optional[List[Tuple[int, int]]], kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
torch.Tensor]], torch.Tensor]],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
skip_cross_attention: bool, skip_cross_attention: bool,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds hidden_states = inputs_embeds
for idx, decoder_layer in enumerate(self.layers): for decoder_layer in self.layers:
if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer): if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
if not skip_cross_attention: if not skip_cross_attention:
hidden_states = decoder_layer( hidden_states = decoder_layer(
@@ -1071,15 +1062,11 @@ class MllamaTextModel(nn.Module):
kv_range_for_decode=kv_range_for_decode, kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask= full_text_row_masked_out_mask=
full_text_row_masked_out_mask, full_text_row_masked_out_mask,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
) )
elif isinstance(decoder_layer, LlamaDecoderLayer): elif isinstance(decoder_layer, LlamaDecoderLayer):
hidden_states, residual = decoder_layer( hidden_states, residual = decoder_layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
residual=None, residual=None,
) )
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
@@ -1124,8 +1111,6 @@ class MllamaForCausalLM(nn.Module):
kv_range_for_decode: Optional[List[Tuple[int, int]]], kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
torch.Tensor]], torch.Tensor]],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
skip_cross_attention: bool, skip_cross_attention: bool,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model( hidden_states = self.model(
@@ -1135,8 +1120,6 @@ class MllamaForCausalLM(nn.Module):
cross_attention_mask=cross_attention_mask, cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode, kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=full_text_row_masked_out_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
skip_cross_attention=skip_cross_attention, skip_cross_attention=skip_cross_attention,
) )
return hidden_states return hidden_states
@@ -1353,10 +1336,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs: object, **kwargs: object,
) -> Union[Tuple, CausalLMOutputWithPast]: ) -> Union[Tuple, CausalLMOutputWithPast]:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.num_prefill_tokens > 0 and \ if attn_metadata.num_prefill_tokens > 0 and \
attn_metadata.num_decode_tokens > 0: attn_metadata.num_decode_tokens > 0:
raise ValueError("Chunk prefill not supported") raise ValueError("Chunk prefill not supported")
@@ -1410,8 +1392,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
cross_attention_mask=cross_attention_mask, cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode, kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=full_text_row_masked_out_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
skip_cross_attention=skip_cross_attention, skip_cross_attention=skip_cross_attention,
) )
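
Mllama shows the escape hatch for layers that genuinely need the raw cache tensor: `_attention_with_mask` now reads the cache registered on its own `Attention` module (indexed with `self.pipeline_parallel_rank`, set from `get_pp_group().rank_in_group` above) and pulls the metadata from the forward context, rather than receiving either as an argument. A condensed sketch of the updated method header, taken from the hunks above; the masked-attention body itself is unchanged and omitted:

```
def _attention_with_mask(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attention_mask: torch.Tensor,
    kv_range_for_decode: List[Tuple[int, int]],
) -> torch.Tensor:
    # The cache now comes from the layer's own Attention module,
    # the metadata from the forward context.
    kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank]
    attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
    # Skip writing kv-cache for the initial profiling run.
    if len(kv_cache.shape) > 1:
        ...  # write k/v into the cache and run masked attention as before
```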


@@ -16,7 +16,7 @@ from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
from transformers.image_utils import ImageInput from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layer import MultiHeadAttention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
@@ -460,15 +460,13 @@ class MolmoAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.q_norm is not None and self.k_norm is not None: if self.q_norm is not None and self.k_norm is not None:
q, k = self._apply_qk_norm(q, k) q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@@ -580,8 +578,6 @@ class MolmoDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Self Attention # Self Attention
@@ -594,8 +590,6 @@ class MolmoDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
@@ -610,8 +604,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Self Attention # Self Attention
@@ -619,8 +611,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
@@ -841,8 +831,6 @@ class MolmoModel(nn.Module, SupportsQuant):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
@@ -858,13 +846,10 @@ class MolmoModel(nn.Module, SupportsQuant):
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
# Apply blocks one-by-one. # Apply blocks one-by-one.
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
@@ -1643,8 +1628,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
positions: torch.LongTensor, positions: torch.LongTensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
@@ -1663,8 +1646,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
hidden_states = self.model(input_ids, hidden_states = self.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)


@@ -2,12 +2,12 @@
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math import math
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -125,8 +125,6 @@ class MPTAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
del position_ids # unused. del position_ids # unused.
qkv, _ = self.Wqkv(hidden_states) qkv, _ = self.Wqkv(hidden_states)
@@ -136,7 +134,7 @@ class MPTAttention(nn.Module):
if self.qk_ln: if self.qk_ln:
q = self.q_ln(q) q = self.q_ln(q)
k = self.k_ln(k) k = self.k_ln(k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
return output return output
@@ -196,15 +194,11 @@ class MPTBlock(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
x = self.norm_1(hidden_states) x = self.norm_1(hidden_states)
x = self.attn( x = self.attn(
position_ids=position_ids, position_ids=position_ids,
hidden_states=x, hidden_states=x,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = hidden_states + x hidden_states = hidden_states + x
x = self.norm_2(hidden_states) x = self.norm_2(hidden_states)
@ -253,8 +247,6 @@ class MPTModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -267,14 +259,8 @@ class MPTModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for block in self.blocks[self.start_layer:self.end_layer]:
block = self.blocks[i] hidden_states = block(position_ids, hidden_states)
hidden_states = block(
position_ids,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm_f(hidden_states) hidden_states = self.norm_f(hidden_states)
@ -306,14 +292,11 @@ class MPTForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
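Across the attention modules in this diff the per-layer change is the same: the qkv projection and, where present, the rotary embedding are untouched, and only the call `attn(q, k, v, kv_cache, attn_metadata)` shrinks to `attn(q, k, v)`. A minimal sketch of the resulting shape, with an illustrative module name and the tensor-parallel and quantization wiring omitted:

```
import torch
from torch import nn


class SketchSelfAttention(nn.Module):
    """Illustrative only: shows the new call shape, not a real vLLM layer."""

    def __init__(self, hidden_size: int, attn: nn.Module):
        super().__init__()
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        # `attn` stands in for vllm.attention.Attention, which now resolves
        # its KV cache and attention metadata from the forward context.
        self.attn = attn

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        # Previously: self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output = self.attn(q, k, v)
        output = self.out_proj(attn_output)
        return output
```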
View File
@ -27,7 +27,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -204,13 +204,11 @@ class NemotronAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -269,8 +267,6 @@ class NemotronDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
@ -283,8 +279,6 @@ class NemotronDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
@ -343,8 +337,6 @@ class NemotronModel(nn.Module):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -359,15 +351,8 @@ class NemotronModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
@ -444,13 +429,10 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches, model_output = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return model_output return model_output
View File
@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights.""" """Inference-only OLMo model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import OlmoConfig from transformers import OlmoConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -119,15 +119,13 @@ class OlmoAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
if self.clip_qkv is not None: if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -212,14 +210,11 @@ class OlmoDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Attention block. # Attention block.
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(positions, hidden_states, kv_cache, hidden_states = self.self_attn(positions, hidden_states)
attn_metadata)
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
# MLP block. # MLP block.
@ -263,8 +258,6 @@ class OlmoModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -281,14 +274,9 @@ class OlmoModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
# Apply blocks one-by-one. # Apply blocks one-by-one.
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
# shape: (batch_size, seq_len, d_model) # shape: (batch_size, seq_len, d_model)
hidden_states = self.layers[i]( hidden_states = layer(positions, hidden_states)
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
@ -332,16 +320,12 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
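The layer-loop rewrites are equally mechanical: once the per-layer `kv_caches` list is gone there is nothing left to index, so `for i in range(self.start_layer, self.end_layer)` collapses into a direct slice of the module list. A runnable toy sketch (hypothetical class, plain linear layers standing in for decoder blocks) showing the equivalent loop:

```
import torch
from torch import nn


class TinyStack(nn.Module):
    """Hypothetical stand-in for a decoder stack split across pipeline ranks."""

    def __init__(self, num_layers: int = 4, start: int = 1, end: int = 3):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(num_layers))
        self.start_layer = start
        self.end_layer = end

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Old form: for i in range(self.start_layer, self.end_layer):
        #     hidden_states = self.layers[i](hidden_states, kv_caches[i - self.start_layer], ...)
        # New form: nn.ModuleList supports slicing, so iterate the stage's slice directly.
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states = layer(hidden_states)
        return hidden_states


print(TinyStack()(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```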
View File
@ -24,12 +24,12 @@
"""Inference-only OLMo2 model compatible with HuggingFace weights.""" """Inference-only OLMo2 model compatible with HuggingFace weights."""
from functools import partial from functools import partial
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.communication_op import tensor_model_parallel_all_gather from vllm.distributed.communication_op import tensor_model_parallel_all_gather
@ -153,14 +153,12 @@ class Olmo2Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k) q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -239,13 +237,10 @@ class Olmo2DecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# Attention block. # Attention block.
residual = hidden_states residual = hidden_states
hidden_states = self.self_attn(positions, hidden_states, kv_cache, hidden_states = self.self_attn(positions, hidden_states)
attn_metadata)
hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
@ -287,8 +282,6 @@ class Olmo2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
""" """
@ -307,14 +300,9 @@ class Olmo2Model(nn.Module):
assert isinstance(hidden_states, torch.Tensor) assert isinstance(hidden_states, torch.Tensor)
# Apply blocks one-by-one. # Apply blocks one-by-one.
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
# shape: (batch_size, seq_len, d_model) # shape: (batch_size, seq_len, d_model)
hidden_states = self.layers[i]( hidden_states = layer(positions, hidden_states)
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
@ -357,15 +345,11 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
) )
return hidden_states return hidden_states
View File
@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights.""" """Inference-only OLMoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -168,14 +168,12 @@ class OlmoeAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous()) q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -222,8 +220,6 @@ class OlmoeDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
@ -237,8 +233,6 @@ class OlmoeDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
@ -283,8 +277,6 @@ class OlmoeModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -299,13 +291,10 @@ class OlmoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
@ -347,13 +336,10 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
View File
@ -18,13 +18,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights.""" """Inference-only OPT model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import OPTConfig from transformers import OPTConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -107,12 +107,10 @@ class OPTAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
return output return output
@ -164,17 +162,13 @@ class OPTDecoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before: if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states, hidden_states = self.self_attn(hidden_states=hidden_states)
kv_cache=kv_cache,
attn_metadata=attn_metadata)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention # 350m applies layer norm AFTER attention
if not self.do_layer_norm_before: if not self.do_layer_norm_before:
@ -261,8 +255,6 @@ class OPTDecoder(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -277,11 +269,8 @@ class OPTDecoder(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states = layer(hidden_states)
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
@ -317,15 +306,11 @@ class OPTModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
return self.decoder(input_ids, return self.decoder(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
@ -362,13 +347,10 @@ class OPTForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
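With `kv_caches` and `attn_metadata` gone from every `forward()` signature, the per-step attention metadata has to reach the attention layers another way: the runner establishes vLLM's forward context before invoking the model, and the `Attention` layers read from it. A hedged caller-side sketch (model, metadata, and config setup elided; assumes `set_forward_context(attn_metadata, vllm_config)` from `vllm.forward_context`):

```
from vllm.forward_context import set_forward_context


def run_step(model, input_ids, positions, attn_metadata, vllm_config):
    # KV caches are bound to the attention layers outside of forward();
    # per-step metadata travels via the context, not the call signature.
    with set_forward_context(attn_metadata, vllm_config):
        return model(input_ids, positions)
```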
View File
@ -5,13 +5,13 @@
# Copyright (c) OrionStar Inc. # Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights.""" """Inference-only Orion-14B model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -136,13 +136,11 @@ class OrionAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -189,8 +187,6 @@ class OrionDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
@ -198,8 +194,6 @@ class OrionDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
@ -247,8 +241,6 @@ class OrionModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -260,14 +252,8 @@ class OrionModel(nn.Module):
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states = layer(positions, hidden_states)
hidden_states = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
@ -303,13 +289,10 @@ class OrionForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
View File
@ -1,13 +1,12 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import torch import torch
from torch import nn from torch import nn
from transformers import PaliGemmaConfig from transformers import PaliGemmaConfig
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
@ -288,8 +287,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]:
@ -306,8 +303,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
View File
@ -21,13 +21,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only persimmon model compatible with HuggingFace weights.""" """Inference-only persimmon model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PersimmonConfig from transformers import PersimmonConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -142,8 +142,6 @@ class PersimmonAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# [seq_length, 3 x hidden_size] # [seq_length, 3 x hidden_size]
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
@ -161,7 +159,7 @@ class PersimmonAttention(nn.Module):
k = self._merge_heads(k) k = self._merge_heads(k)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.dense(attn_output) output, _ = self.dense(attn_output)
return output return output
@ -189,8 +187,6 @@ class PersimmonDecoderLayer(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
@ -200,8 +196,6 @@ class PersimmonDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
position_ids=position_ids, position_ids=position_ids,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
@ -248,8 +242,6 @@ class PersimmonModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -261,13 +253,8 @@ class PersimmonModel(nn.Module):
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states = self.layers[i]( hidden_states = layer(positions, hidden_states)
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.final_layernorm(hidden_states) hidden_states = self.final_layernorm(hidden_states)
@ -298,16 +285,12 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
): ):
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
View File
@ -36,13 +36,13 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights.""" """Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PhiConfig from transformers import PhiConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -126,13 +126,11 @@ class PhiAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.dense(attn_output) output, _ = self.dense(attn_output)
return output return output
@ -186,16 +184,12 @@ class PhiLayer(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
attn_outputs = self.self_attn( attn_outputs = self.self_attn(
position_ids=position_ids, position_ids=position_ids,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
feed_forward_hidden_states = self.mlp(hidden_states) feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = attn_outputs + feed_forward_hidden_states + residual hidden_states = attn_outputs + feed_forward_hidden_states + residual
@ -234,8 +228,6 @@ class PhiModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -247,14 +239,8 @@ class PhiModel(nn.Module):
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states = layer(positions, hidden_states)
hidden_states = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
@ -304,13 +290,10 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
View File
@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
@ -231,8 +231,6 @@ class Phi3SmallSelfAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]: Optional[Tuple[torch.Tensor]]]:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
@ -248,7 +246,7 @@ class Phi3SmallSelfAttention(nn.Module):
v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion) v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.dense(attn_output) output, _ = self.dense(attn_output)
return output return output
@ -282,8 +280,6 @@ class Phi3SmallDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
@ -291,8 +287,6 @@ class Phi3SmallDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
@ -338,8 +332,6 @@ class Phi3SmallModel(nn.Module):
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
positions: Optional[torch.LongTensor], positions: Optional[torch.LongTensor],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor], inputs_embeds: Optional[torch.Tensor],
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -354,14 +346,8 @@ class Phi3SmallModel(nn.Module):
else: else:
assert intermediate_tensors assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states = layer(positions, hidden_states)
hidden_states = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.final_layernorm(hidden_states) hidden_states = self.final_layernorm(hidden_states)
@ -438,16 +424,12 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
positions: Optional[torch.LongTensor], positions: Optional[torch.LongTensor],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
output_hidden_states = self.model( output_hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
View File
@ -23,7 +23,6 @@ import torch.nn as nn
from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig, from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
ProcessorMixin) ProcessorMixin)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
@ -672,8 +671,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object): **kwargs: object):
@ -691,8 +688,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
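The multimodal wrappers (PaliGemma and Phi-3V above, Pixtral below) follow the same shape: the vision path is unchanged and the wrapper simply forwards to its language model without the two removed arguments. An illustrative sketch (hypothetical wrapper, vision encoding and embedding merge elided):

```
from typing import Optional

import torch
from torch import nn


class SketchVLMWrapper(nn.Module):
    """Hypothetical multimodal wrapper; only the forwarding call matters here."""

    def __init__(self, language_model: nn.Module):
        super().__init__()
        self.language_model = language_model

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors=None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object):
        # Vision inputs would be merged into inputs_embeds here (elided).
        # Previously kv_caches and attn_metadata were threaded through as well.
        return self.language_model(input_ids,
                                   positions,
                                   intermediate_tensors,
                                   inputs_embeds=inputs_embeds)
```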
View File
@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only PhiMoE model.""" """Inference-only PhiMoE model."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -357,13 +357,11 @@ class PhiMoEAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -410,8 +408,6 @@ class PhiMoEDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
@ -422,8 +418,6 @@ class PhiMoEDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
@ -478,8 +472,6 @@ class PhiMoEModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -494,13 +486,10 @@ class PhiMoEModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
@ -571,13 +560,10 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
View File
@ -16,7 +16,6 @@ from transformers.models.pixtral.image_processing_pixtral import (
from transformers.models.pixtral.modeling_pixtral import ( from transformers.models.pixtral.modeling_pixtral import (
PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid) PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
@ -270,8 +269,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
@ -291,8 +288,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
View File
@ -15,13 +15,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model.""" """Inference-only IBM/NASA Prithvi Geospatial model."""
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union from typing import Iterable, Mapping, Optional, Set, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import BatchFeature from transformers import BatchFeature
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (IsAttentionFree, from vllm.model_executor.models.interfaces import (IsAttentionFree,
@ -181,8 +180,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
View File
@ -6,13 +6,13 @@
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights.""" """Inference-only QWen model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -124,13 +124,11 @@ class QWenAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.c_proj(attn_output) output, _ = self.c_proj(attn_output)
return output return output
@ -168,8 +166,6 @@ class QWenBlock(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
@ -181,8 +177,6 @@ class QWenBlock(nn.Module):
hidden_states = self.attn( hidden_states = self.attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
@ -225,8 +219,6 @@ class QWenModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -241,13 +233,10 @@ class QWenModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.h[self.start_layer:self.end_layer]:
layer = self.h[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
@ -373,12 +362,9 @@ class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
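The QWen hunks above capture the core calling-convention change of this commit: vLLM's `Attention` layer is now invoked as `self.attn(q, k, v)`, with the KV cache and attention metadata resolved inside the layer from the per-step forward context instead of being threaded through every `forward` signature. A minimal sketch of that pattern follows; `ToyForwardContext`, `set_context` and `ToyAttention` are toy stand-ins, not the real vLLM classes:

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn


@dataclass
class ToyForwardContext:
    kv_cache: torch.Tensor   # per-layer cache slab owned by the runner
    is_prefill: bool         # stand-in for real attention metadata


_CURRENT_CONTEXT: Optional[ToyForwardContext] = None


@contextmanager
def set_context(ctx: ToyForwardContext):
    """Toy analogue of vllm.forward_context.set_forward_context."""
    global _CURRENT_CONTEXT
    _CURRENT_CONTEXT = ctx
    try:
        yield
    finally:
        _CURRENT_CONTEXT = None


class ToyAttention(nn.Module):
    """Takes only (q, k, v); cache and metadata come from the ambient context."""

    def forward(self, q: torch.Tensor, k: torch.Tensor,
                v: torch.Tensor) -> torch.Tensor:
        ctx = _CURRENT_CONTEXT
        assert ctx is not None, "must be called inside set_context(...)"
        # A real backend would append k/v to ctx.kv_cache and dispatch on the
        # metadata; here plain SDPA stands in for the paged-attention kernel.
        return nn.functional.scaled_dot_product_attention(q, k, v)


attn = ToyAttention()
q = k = v = torch.randn(1, 4, 8, 16)   # (batch, heads, seq, head_dim)
with set_context(ToyForwardContext(kv_cache=torch.empty(0), is_prefill=True)):
    out = attn(q, k, v)                # no kv_cache / attn_metadata arguments
print(out.shape)
```

The payoff is that model code like `QWenAttention.forward` above no longer needs to know how caches are allocated or which backend metadata is in flight; the runner establishes the context once per step and every attention layer picks it up implicitly.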

View File

@ -23,13 +23,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" """Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import Qwen2Config from transformers import Qwen2Config
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -170,13 +170,11 @@ class Qwen2Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -233,8 +231,6 @@ class Qwen2DecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
@ -247,8 +243,6 @@ class Qwen2DecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
@ -328,8 +322,6 @@ class Qwen2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -343,13 +335,10 @@ class Qwen2Model(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
@ -468,13 +457,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
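The `Qwen2Model` hunk above also simplifies the decoder loop: with the per-layer `kv_caches[i - self.start_layer]` indexing gone, a pipeline-parallel stage can iterate directly over its slice of the layer stack. A small illustrative sketch, with toy `nn.Linear` layers standing in for decoder layers:

```python
import torch
import torch.nn as nn

# Toy decoder stack; a pipeline-parallel stage owns layers [start, end).
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(8)])
start_layer, end_layer = 2, 6
hidden_states = torch.randn(4, 16)

# Before, an index was needed solely to pick kv_caches[i - start_layer];
# with caches bound inside each layer, a plain slice of the ModuleList works.
for layer in layers[start_layer:end_layer]:
    hidden_states = layer(hidden_states)
print(hidden_states.shape)
```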
@ -553,12 +539,9 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.model(input_ids, positions, kv_caches, attn_metadata, return self.model(input_ids, positions, intermediate_tensors)
intermediate_tensors)
def pooler( def pooler(
self, self,

View File

@ -37,7 +37,6 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
@ -992,8 +991,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
@ -1047,8 +1044,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.language_model.model( hidden_states = self.language_model.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )

View File

@ -22,8 +22,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from functools import cached_property from functools import cached_property
from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple, from typing import (Any, Iterable, Mapping, Optional, Set, Tuple, TypedDict,
TypedDict, Union) Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -33,7 +33,6 @@ from transformers.models.qwen2_audio import (Qwen2AudioConfig,
Qwen2AudioProcessor) Qwen2AudioProcessor)
from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper import WhisperFeatureExtractor
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -380,8 +379,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
@ -400,8 +397,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states

View File

@ -23,14 +23,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, from vllm.distributed import (get_pp_group,
@ -232,13 +232,11 @@ class Qwen2MoeAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -296,8 +294,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
@ -310,8 +306,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
@ -358,8 +352,6 @@ class Qwen2MoeModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -373,11 +365,8 @@ class Qwen2MoeModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
@ -416,13 +405,10 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states

View File

@ -5,12 +5,11 @@
# Copyright 2024 The Qwen team. # Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
"""Inference-only Qwen2-RM model compatible with HuggingFace weights.""" """Inference-only Qwen2-RM model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
@ -80,13 +79,10 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
logits, _ = self.score(hidden_states) logits, _ = self.score(hidden_states)
return logits return logits

View File

@ -24,8 +24,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import cached_property, partial from functools import cached_property, partial
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, from typing import (Any, Callable, Iterable, Literal, Mapping, Optional, Set,
Set, Tuple, Type, TypedDict, Union) Tuple, Type, TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -38,7 +38,6 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
Qwen2VLConfig, Qwen2VLVisionConfig) Qwen2VLConfig, Qwen2VLVisionConfig)
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
@ -1302,8 +1301,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
@ -1354,8 +1351,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.language_model.model( hidden_states = self.language_model.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )

View File

@ -22,7 +22,6 @@ from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer,
from transformers.image_utils import ImageInput from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -766,8 +765,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
@ -783,7 +780,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
vision_embeddings) vision_embeddings)
input_ids = None input_ids = None
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states

View File

@ -1,13 +1,12 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import itertools import itertools
from typing import Iterable, List, Optional, Tuple from typing import Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import RobertaConfig from transformers import RobertaConfig
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import CrossEncodingPooler from vllm.model_executor.layers.pooler import CrossEncodingPooler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -243,16 +242,12 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.roberta(input_ids=input_ids, return self.roberta(input_ids=input_ids,
position_ids=positions, position_ids=positions,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
attn_metadata=attn_metadata,
token_type_ids=token_type_ids) token_type_ids=token_type_ids)

View File

@ -23,13 +23,13 @@
# limitations under the License. # limitations under the License.
"""Inference-only Solar model compatible with HuggingFace weights.""" """Inference-only Solar model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -172,13 +172,11 @@ class SolarAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -238,8 +236,6 @@ class SolarDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
@ -252,8 +248,6 @@ class SolarDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
@ -315,8 +309,6 @@ class SolarModel(nn.Module):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -357,8 +349,6 @@ class SolarModel(nn.Module):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
@ -438,13 +428,10 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches, model_output = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return model_output return model_output

View File

@ -20,13 +20,13 @@
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) """Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights.""" model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import StableLmConfig from transformers import StableLmConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
@ -147,13 +147,11 @@ class StablelmAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -183,8 +181,6 @@ class StablelmDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
@ -192,8 +188,6 @@ class StablelmDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
@ -241,8 +235,6 @@ class StableLMEpochModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -254,14 +246,8 @@ class StableLMEpochModel(nn.Module):
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states)
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
@ -296,13 +282,10 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states

View File

@ -19,13 +19,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" PyTorch Starcoder2 model.""" """ PyTorch Starcoder2 model."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import Starcoder2Config from transformers import Starcoder2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -118,13 +118,11 @@ class Starcoder2Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -184,8 +182,6 @@ class Starcoder2DecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
@ -193,8 +189,6 @@ class Starcoder2DecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
@ -246,8 +240,6 @@ class Starcoder2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -259,11 +251,8 @@ class Starcoder2Model(nn.Module):
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states = layer(positions, hidden_states)
hidden_states = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
@ -306,13 +295,10 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states

View File

@ -22,7 +22,7 @@ from torch import nn
from transformers import AutoModel, PreTrainedModel from transformers import AutoModel, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide from vllm.distributed.utils import divide
@ -59,7 +59,6 @@ def vllm_flash_attention_forward(
# Transformers kwargs # Transformers kwargs
scaling: Optional[float] = None, scaling: Optional[float] = None,
# vLLM kwargs # vLLM kwargs
attn_metadata: Optional[AttentionMetadata] = None,
attention_instances: Optional[list[Attention]] = None, attention_instances: Optional[list[Attention]] = None,
**kwargs): **kwargs):
self_attn = attention_instances[module.layer_idx] self_attn = attention_instances[module.layer_idx]
@ -68,12 +67,7 @@ def vllm_flash_attention_forward(
hidden = query.shape[-2] hidden = query.shape[-2]
query, key, value = (x.transpose(1, 2) for x in (query, key, value)) query, key, value = (x.transpose(1, 2) for x in (query, key, value))
query, key, value = (x.reshape(hidden, -1) for x in (query, key, value)) query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
return self_attn.forward( return self_attn.forward(query, key, value), None
query,
key,
value,
kv_cache=None, # argument not used
attn_metadata=attn_metadata), None
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
@ -251,8 +245,6 @@ class TransformersModel(nn.Module, SupportsQuant):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: list[torch.Tensor], # argument not used
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
@ -260,7 +252,6 @@ class TransformersModel(nn.Module, SupportsQuant):
input_ids[None, ...], input_ids[None, ...],
use_cache=False, use_cache=False,
position_ids=positions[None, ...], position_ids=positions[None, ...],
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
attention_instances=self.attention_instances, attention_instances=self.attention_instances,
return_dict=False)[0][0, ...] # we remove batch dimension for now return_dict=False)[0][0, ...] # we remove batch dimension for now
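The `TransformersModel` hunk trims `vllm_flash_attention_forward` down to `self_attn.forward(query, key, value), None` and keeps its registration under the `"vllm"` key of `ALL_ATTENTION_FUNCTIONS`, so Hugging Face's pluggable attention hook now receives a function that neither threads a KV cache nor returns attention weights. A rough sketch of the adapter shape, using an illustrative local registry and plain SDPA rather than vLLM's paged attention (the exact `ALL_ATTENTION_FUNCTIONS` contract is not reproduced here):

```python
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.nn as nn

# Illustrative local registry; the real code registers into
# transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS under the "vllm" key.
ATTENTION_BACKENDS: Dict[str, Callable[..., Tuple[torch.Tensor,
                                                  Optional[torch.Tensor]]]] = {}


def toy_vllm_attention_forward(
        module: nn.Module,            # the HF attention module (unused here)
        query: torch.Tensor,          # (batch, heads, seq, head_dim)
        key: torch.Tensor,
        value: torch.Tensor,
        **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Flatten to (tokens, hidden) the way the hunk above does before handing
    # off to the attention layer; single-head SDPA stands in for paged attention.
    tokens = query.shape[-2]
    q, k, v = (x.transpose(1, 2).reshape(tokens, -1)
               for x in (query, key, value))
    out = nn.functional.scaled_dot_product_attention(
        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)).squeeze(0)
    return out, None                  # attention weights are never materialised


ATTENTION_BACKENDS["vllm"] = toy_vllm_attention_forward

q = k = v = torch.randn(1, 4, 8, 16)
out, weights = ATTENTION_BACKENDS["vllm"](nn.Identity(), q, k, v)
print(out.shape, weights)
```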

View File

@ -4,8 +4,8 @@
"""PyTorch Ultravox model.""" """PyTorch Ultravox model."""
import math import math
from functools import cached_property from functools import cached_property
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set, from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple,
Tuple, TypedDict, Union) TypedDict, Union)
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@ -16,8 +16,8 @@ from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm import envs from vllm import envs
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@ -495,13 +495,13 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[NestedTensors] = None,
attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
# TODO(ywang96): remove this block after v0 is deprecated. # TODO(ywang96): remove this block after v0 is deprecated.
if not envs.VLLM_USE_V1: if not envs.VLLM_USE_V1:
attn_metadata = get_forward_context().attn_metadata
merge_multimodal_embeddings_from_map( merge_multimodal_embeddings_from_map(
inputs_embeds, multimodal_embeddings, inputs_embeds, multimodal_embeddings,
attn_metadata.multi_modal_placeholder_index_maps["audio"]) attn_metadata.multi_modal_placeholder_index_maps["audio"])
@ -514,8 +514,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[torch.Tensor] = None, intermediate_tensors: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs) -> Union[torch.Tensor, IntermediateTensors]: **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
@ -540,17 +538,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
elif inputs_embeds is None: elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
# TODO(ywang96): remove attn_metadata from get_input_embeddings
# after v0 is deprecated
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings, multimodal_embeddings)
attn_metadata)
input_ids = None input_ids = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
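The Ultravox hunks show where the removed `attn_metadata` parameter went: in the V0 fallback path the model now calls `get_forward_context().attn_metadata` at the point of use. A small sketch of that pattern, assuming it runs inside an engine-driven forward pass; the helper name and argument layout below are hypothetical:

```python
# Assumes execution within a set_forward_context(...) scope established by
# the engine; outside of it, attn_metadata will be None.
from vllm.forward_context import get_forward_context


def merge_audio_embeddings_v0(inputs_embeds, multimodal_embeddings, merge_fn):
    """Illustrative helper: merge_fn and the argument layout are hypothetical."""
    attn_metadata = get_forward_context().attn_metadata
    index_map = attn_metadata.multi_modal_placeholder_index_maps["audio"]
    merge_fn(inputs_embeds, multimodal_embeddings, index_map)
    return inputs_embeds
```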

View File

@ -10,7 +10,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
WhisperProcessor) WhisperProcessor)
from transformers.models.whisper.modeling_whisper import sinusoids from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionType
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
@ -134,13 +134,11 @@ class WhisperAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
): ):
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
@ -196,8 +194,6 @@ class WhisperCrossAttention(WhisperAttention):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor], encoder_hidden_states: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
): ):
q, _ = self.q_proj(hidden_states) q, _ = self.q_proj(hidden_states)
@ -209,13 +205,7 @@ class WhisperCrossAttention(WhisperAttention):
else: else:
k = v = None k = v = None
attn_output = self.attn( attn_output = self.attn(q, k, v)
q,
k,
v,
kv_cache,
attn_metadata,
)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
@ -285,16 +275,10 @@ class WhisperEncoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
): ):
residual = hidden_states residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn( hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
@ -348,14 +332,10 @@ class WhisperDecoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor], encoder_hidden_states: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
): ):
residual = hidden_states residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states, hidden_states = self.self_attn(hidden_states=hidden_states)
kv_cache=kv_cache,
attn_metadata=attn_metadata)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states
@ -363,8 +343,6 @@ class WhisperDecoderLayer(nn.Module):
hidden_states = self.encoder_attn( hidden_states = self.encoder_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
@ -411,12 +389,7 @@ class WhisperEncoder(nn.Module):
self.embed_positions.weight.copy_( self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape)) sinusoids(*self.embed_positions.weight.shape))
def forward( def forward(self, input_features: Union[torch.Tensor, List[torch.Tensor]]):
self,
input_features: Union[torch.Tensor, List[torch.Tensor]],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
):
hidden_states = [] hidden_states = []
for features in input_features: for features in input_features:
embeds = nn.functional.gelu(self.conv1(features)) embeds = nn.functional.gelu(self.conv1(features))
@ -426,12 +399,8 @@ class WhisperEncoder(nn.Module):
hidden_states.append(embeds) hidden_states.append(embeds)
hidden_states = torch.cat(hidden_states) hidden_states = torch.cat(hidden_states)
for idx, encoder_layer in enumerate(self.layers): for encoder_layer in self.layers:
hidden_states = encoder_layer( hidden_states = encoder_layer(hidden_states)
hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
)
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
return hidden_states return hidden_states
@ -466,19 +435,15 @@ class WhisperDecoder(nn.Module):
input_ids, input_ids,
positions: torch.Tensor, positions: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor], encoder_hidden_states: Optional[torch.Tensor],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
): ):
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.get_input_embeddings(input_ids)
positions = self.embed_positions(positions) positions = self.embed_positions(positions)
hidden_states = inputs_embeds + positions hidden_states = inputs_embeds + positions
for idx, decoder_layer in enumerate(self.layers): for decoder_layer in self.layers:
hidden_states = decoder_layer( hidden_states = decoder_layer(
hidden_states, hidden_states,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
) )
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
@ -505,36 +470,22 @@ class WhisperModel(nn.Module):
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
encoder_outputs = self.get_encoder_outputs( encoder_outputs = self.get_encoder_outputs(input_features)
input_features,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
decoder_outputs = self.decoder( decoder_outputs = self.decoder(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
encoder_hidden_states=encoder_outputs, encoder_hidden_states=encoder_outputs,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
) )
return decoder_outputs return decoder_outputs
def get_encoder_outputs( def get_encoder_outputs(
self, self,
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
if input_features is None: if input_features is None:
return None return None
return self.encoder( return self.encoder(input_features)
input_features,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
@ -733,8 +684,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
@ -742,31 +691,19 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
input_features=audio_input["input_features"], input_features=audio_input["input_features"],
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
) )
return decoder_outputs return decoder_outputs
def get_multimodal_embeddings( def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
self,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs,
) -> Optional[NestedTensors]:
# TODO: This method does not obey the interface for SupportsMultiModal. # TODO: This method does not obey the interface for SupportsMultiModal.
# Refactor this once encoder/decoder support is implemented in V1. # Refactor this once encoder/decoder support is implemented in V1.
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
return self.model.get_encoder_outputs( return self.model.get_encoder_outputs(audio_input["input_features"])
audio_input["input_features"],
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[NestedTensors] = None,
attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# TODO: This method just returns the decoder sequence embeddings since # TODO: This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens. Refactor this once # Whisper does not have encoder text tokens. Refactor this once
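With the cache plumbing gone, Whisper's encoder/decoder interfaces reduce to plain tensor-in/tensor-out calls: `get_encoder_outputs(input_features)` returns `None` for decode-only steps, and the decoder consumes the (possibly `None`) encoder states. A toy skeleton of that control flow (all classes here are stand-ins, not vLLM's Whisper implementation):

```python
from typing import List, Optional, Union

import torch
import torch.nn as nn


class ToyEncoderDecoder(nn.Module):

    def __init__(self, hidden: int = 16, vocab: int = 32):
        super().__init__()
        self.encoder = nn.Linear(hidden, hidden)
        self.embed = nn.Embedding(vocab, hidden)

    def get_encoder_outputs(
        self,
        input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
    ) -> Optional[torch.Tensor]:
        # Decode-only steps pass no features and skip the encoder entirely.
        if input_features is None:
            return None
        return self.encoder(input_features)

    def forward(self, input_features, input_ids, positions):
        encoder_outputs = self.get_encoder_outputs(input_features)
        hidden = self.embed(input_ids)      # positions omitted in this toy
        if encoder_outputs is not None:
            # Stand-in for cross-attention over the encoder states.
            hidden = hidden + encoder_outputs.mean(dim=0)
        return hidden


m = ToyEncoderDecoder()
feats, ids = torch.randn(10, 16), torch.randint(0, 32, (4,))
print(m(feats, ids, torch.arange(4)).shape)   # prefill with audio features
print(m(None, ids, torch.arange(4)).shape)    # decode-only step
```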

View File

@ -288,8 +288,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
hidden_states = model_executable( hidden_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device), device=self.device),

View File

@ -939,8 +939,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=self.kv_caches,
attn_metadata=None,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
@ -1137,11 +1135,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _dummy_run( def _dummy_run(
self, self,
num_tokens: int, num_tokens: int,
kv_caches: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
model = self.model model = self.model
if kv_caches is None:
kv_caches = self.kv_caches
if self.is_multimodal_model: if self.is_multimodal_model:
input_ids = None input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens] inputs_embeds = self.inputs_embeds[:num_tokens]
@ -1172,26 +1167,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states = model( hidden_states = model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=None,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
return hidden_states return hidden_states
def profile_run(self) -> None: def profile_run(self) -> None:
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value `None`.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
# it is important to create tensors inside the loop, rather than
# multiplying the list, to avoid Dynamo from treating them as
# tensor aliasing.
dummy_kv_caches = [
torch.tensor((), dtype=torch.float32, device=self.device)
for _ in range(self.num_attn_layers)
]
# Profile with multimodal encoder & encoder cache. # Profile with multimodal encoder & encoder cache.
# TODO: handle encoder-decoder models once we support them. # TODO: handle encoder-decoder models once we support them.
if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
@ -1302,8 +1283,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with self.maybe_profile_with_lora(self.lora_config, with self.maybe_profile_with_lora(self.lora_config,
num_scheduled_tokens): num_scheduled_tokens):
# Trigger compilation for general shape. # Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens, hidden_states = self._dummy_run(self.max_num_tokens)
dummy_kv_caches)
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
hidden_states = hidden_states[logit_indices] hidden_states = hidden_states[logit_indices]
logits = self.model.compute_logits(hidden_states, None) logits = self.model.compute_logits(hidden_states, None)
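Note: the comment deleted from `profile_run` explains why a list of empty placeholder tensors used to stand in for the KV caches: passing `None` would make Dynamo specialize the compiled graph on that value. With `kv_caches` no longer a forward argument, the placeholders are unnecessary. A hypothetical standalone illustration of that specialization behaviour (not vLLM code, assumes torch >= 2.0 for `torch.compile`):

```python
# Hypothetical illustration (not vLLM code) of the behaviour the deleted
# comments describe: Dynamo specializes a compiled function on whether an
# argument is None, so runners used to pass empty placeholder tensors instead
# of None to keep a single tensor-shaped code path.
import torch


@torch.compile
def forward_sketch(x: torch.Tensor, cache) -> torch.Tensor:
    if cache is None:  # this branch gets baked into a None-specialized graph
        return x * 2
    return x + cache.sum()


x = torch.ones(4)
forward_sketch(x, None)                                    # graph specialized on cache=None
forward_sketch(x, torch.tensor([], dtype=torch.float32))   # separate graph for the tensor case
```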

View File

@ -13,11 +13,10 @@ import torch.nn as nn
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
@ -623,7 +622,6 @@ class TPUModelRunner:
assert self.model is not None assert self.model is not None
selected_token_ids = self.model(prompt_data.input_tokens, selected_token_ids = self.model(prompt_data.input_tokens,
prompt_data.input_positions, prompt_data.input_positions,
prompt_data.attn_metadata,
self.kv_caches) self.kv_caches)
# In parallel to TPU execution, prepare the next iteration # In parallel to TPU execution, prepare the next iteration
@ -662,7 +660,6 @@ class TPUModelRunner:
assert self.model is not None assert self.model is not None
selected_token_ids = self.model(decode_data.input_tokens, selected_token_ids = self.model(decode_data.input_tokens,
decode_data.input_positions, decode_data.input_positions,
decode_data.attn_metadata,
self.kv_caches) self.kv_caches)
# Transfer sampled tokens from TPU to CPU # Transfer sampled tokens from TPU to CPU
@ -839,7 +836,7 @@ class TPUModelRunner:
with set_forward_context(attn_metadata, self.vllm_config, 0): with set_forward_context(attn_metadata, self.vllm_config, 0):
assert self.model is not None assert self.model is not None
self.model(token_ids, position_ids, attn_metadata, kv_caches) self.model(token_ids, position_ids, kv_caches)
def capture_model(self) -> None: def capture_model(self) -> None:
"""Compile the model.""" """Compile the model."""
@ -963,7 +960,6 @@ class ModelWrapperV1(nn.Module):
self, self,
token_ids: torch.Tensor, token_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor: ) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token. """Executes the forward pass of the model and samples the next token.
@ -971,7 +967,6 @@ class ModelWrapperV1(nn.Module):
Args: Args:
token_ids: The input token IDs of shape [batch_size, seq_len]. token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len]. position_ids: The input position IDs of shape [batch_size, seq_len].
attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size]. input_lens: The actual input lengths of shape [batch_size].
t: The sampling temperature of shape [batch_size]. t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size]. p: The top-p probability of shape [batch_size].
@ -980,7 +975,8 @@ class ModelWrapperV1(nn.Module):
memory profiling at initialization. memory profiling at initialization.
""" """
# Skip this in memory profiling at initialization. # Skip this in memory profiling at initialization.
if attn_metadata is not None and kv_caches[0][0].numel() > 0: if kv_caches[0][0].numel() > 0:
attn_metadata = get_forward_context().attn_metadata
# index_copy_(slot_mapping) only works when the inserted dimension # index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape # is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it # [num_kv_heads, num_blocks, block_size, head_size]. To make it
@ -1001,12 +997,7 @@ class ModelWrapperV1(nn.Module):
attn_metadata.slot_mapping = slot_mapping attn_metadata.slot_mapping = slot_mapping
assert self.model is not None assert self.model is not None
hidden_states = self.model( hidden_states = self.model(token_ids, position_ids)
token_ids,
position_ids,
kv_caches,
attn_metadata,
)
hidden_states = hidden_states.flatten(0, 1) hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, None) logits = self.model.compute_logits(hidden_states, None)
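Note: with this change the TPU wrapper reads the attention metadata from the ambient forward context (`get_forward_context().attn_metadata`) that the runner installs via `set_forward_context(...)` around the model call. Below is a self-contained toy sketch of that pattern; the real implementation lives in `vllm.forward_context`, so take the names and structure here as illustrative only.

```python
# Toy sketch of the set/get forward-context pattern used above; not vLLM's
# actual vllm.forward_context module, just the shape of the mechanism.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Iterator, Optional


@dataclass
class ForwardContext:
    attn_metadata: Any


_forward_context: Optional[ForwardContext] = None


@contextmanager
def set_forward_context(attn_metadata: Any) -> Iterator[None]:
    """Install per-step metadata for the duration of one model call."""
    global _forward_context
    prev, _forward_context = _forward_context, ForwardContext(attn_metadata)
    try:
        yield
    finally:
        _forward_context = prev


def get_forward_context() -> ForwardContext:
    assert _forward_context is not None, "no forward context is active"
    return _forward_context


class ToyModel:
    def __call__(self, token_ids):
        # Layer side: read the metadata instead of receiving it as an argument.
        attn_metadata = get_forward_context().attn_metadata
        return token_ids, attn_metadata


# Runner side: wrap the model call.
model = ToyModel()
with set_forward_context({"slot_mapping": [0, 1, 2]}):
    print(model([101, 102, 103]))
```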

View File

@ -297,10 +297,6 @@ class CPUEncoderDecoderModelRunner(
model_input.encoder_input_tokens, model_input.encoder_input_tokens,
"encoder_positions": "encoder_positions":
model_input.encoder_input_positions, model_input.encoder_input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device), device=self.device),
"intermediate_tensors": "intermediate_tensors":

View File

@ -654,8 +654,6 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
hidden_states = model_executable( hidden_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**execute_model_kwargs, **execute_model_kwargs,
**multimodal_kwargs, **multimodal_kwargs,

View File

@ -41,16 +41,6 @@ class CPUPoolingModelRunner(
raise ValueError( raise ValueError(
"CPU worker does not support multi-step execution.") "CPU worker does not support multi-step execution.")
num_layers = self.model_config.get_num_layers(self.parallel_config)
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(num_layers)
]
model_executable = self.model model_executable = self.model
cross_enc_kwargs = {} cross_enc_kwargs = {}
if model_input.token_type_ids is not None: if model_input.token_type_ids is not None:
@ -60,10 +50,6 @@ class CPUPoolingModelRunner(
model_input.input_tokens, model_input.input_tokens,
"positions": "positions":
model_input.input_positions, model_input.input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device), device=self.device),
**cross_enc_kwargs, **cross_enc_kwargs,

View File

@ -184,8 +184,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
positions=model_input.input_positions, positions=model_input.input_positions,
encoder_input_ids=model_input.encoder_input_tokens, encoder_input_ids=model_input.encoder_input_tokens,
encoder_positions=model_input.encoder_input_positions, encoder_positions=model_input.encoder_input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device), device=self.device),
@ -324,21 +322,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
or encoder_dummy_data.multi_modal_placeholders) or encoder_dummy_data.multi_modal_placeholders)
seqs.append(seq) seqs.append(seq)
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(num_layers)
]
finished_requests_ids = [seq.request_id for seq in seqs] finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input( model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids) seqs, finished_requests_ids=finished_requests_ids)
intermediate_tensors = None intermediate_tensors = None
self.execute_model(model_input, kv_caches, intermediate_tensors) self.execute_model(model_input, None, intermediate_tensors)
torch.cuda.synchronize() torch.cuda.synchronize()
return return

View File

@ -384,11 +384,12 @@ class HpuModelAdapter:
if 'virtual_engine' in kwargs: if 'virtual_engine' in kwargs:
virtual_engine = kwargs.pop('virtual_engine') virtual_engine = kwargs.pop('virtual_engine')
input_ids = kwargs['input_ids'] input_ids = kwargs['input_ids']
kwargs['attn_metadata'] = self._update_metadata( attn_metadata = self._update_metadata(kwargs.pop('attn_metadata'),
kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), input_ids.size(0),
input_ids.size(1),
input_ids.device, self.dtype) input_ids.device, self.dtype)
LoraMask.setLoraMask(kwargs.pop('lora_mask')) LoraMask.setLoraMask(kwargs.pop('lora_mask'))
with set_forward_context(kwargs['attn_metadata'], self.vllm_config, with set_forward_context(attn_metadata, self.vllm_config,
virtual_engine): virtual_engine):
hidden_states = self.model(*args, **kwargs) hidden_states = self.model(*args, **kwargs)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
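Note: in the HPU adapter the `attn_metadata` kwarg is now popped out of `kwargs`, so it never reaches the wrapped model as an argument; it is only used to seed `set_forward_context`. A short sketch of that caller-side counterpart, reusing the toy `set_forward_context` helper from the sketch above (toy names, not the real `HpuModelAdapter`):

```python
# Caller-side counterpart of the forward-context sketch above (toy names):
# the adapter consumes the attn_metadata kwarg so the wrapped model never sees
# it, and only uses it to install the ambient context around the call.
class ToyAdapter:
    def __init__(self, model):
        self.model = model

    def __call__(self, *args, **kwargs):
        attn_metadata = kwargs.pop("attn_metadata")  # not forwarded to the model
        with set_forward_context(attn_metadata):     # toy helper defined earlier
            return self.model(*args, **kwargs)
```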
@ -1346,15 +1347,13 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1] max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
max_batch_size = min(self.max_num_batched_tokens // max_seq_len, max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
self.scheduler_config.max_num_seqs) self.scheduler_config.max_num_seqs)
self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, self.warmup_scenario(max_batch_size, max_seq_len, True, False, True)
False, True)
return return
def warmup_scenario(self, def warmup_scenario(self,
batch_size, batch_size,
seq_len, seq_len,
is_prompt, is_prompt,
kv_caches,
is_pt_profiler_run=False, is_pt_profiler_run=False,
is_lora_profile_run=False) -> None: is_lora_profile_run=False) -> None:
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
@ -1418,7 +1417,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
profiler.start() profiler.start()
for _ in range(times): for _ in range(times):
inputs = self.prepare_model_input(seqs) inputs = self.prepare_model_input(seqs)
self.execute_model(inputs, kv_caches, warmup_mode=True) self.execute_model(inputs, None, warmup_mode=True)
torch.hpu.synchronize() torch.hpu.synchronize()
if profiler: if profiler:
profiler.step() profiler.step()
@ -1470,17 +1469,16 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
f"free_mem:{free_mem}") f"free_mem:{free_mem}")
logger.info(msg) logger.info(msg)
def warmup_all_buckets(self, buckets, is_prompt, kv_caches): def warmup_all_buckets(self, buckets, is_prompt):
for i, (batch_size, seq_len) in enumerate(reversed(buckets)): for i, (batch_size, seq_len) in enumerate(reversed(buckets)):
self.log_warmup('Prompt' if is_prompt else 'Decode', i, self.log_warmup('Prompt' if is_prompt else 'Decode', i,
len(buckets), batch_size, seq_len) len(buckets), batch_size, seq_len)
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) self.warmup_scenario(batch_size, seq_len, is_prompt)
def warmup_graphs(self, def warmup_graphs(self,
strategy, strategy,
buckets, buckets,
is_prompt, is_prompt,
kv_caches,
available_mem, available_mem,
starting_mem=0, starting_mem=0,
total_batch_seq=0.001): total_batch_seq=0.001):
@ -1512,7 +1510,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self.graphed_buckets.add(graphed_bucket) self.graphed_buckets.add(graphed_bucket)
self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
with HabanaMemoryProfiler() as mem_prof: with HabanaMemoryProfiler() as mem_prof:
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) self.warmup_scenario(batch_size, seq_len, is_prompt)
used_mem = align_workers(mem_prof.consumed_device_memory, used_mem = align_workers(mem_prof.consumed_device_memory,
torch.distributed.ReduceOp.MAX) torch.distributed.ReduceOp.MAX)
available_mem -= used_mem available_mem -= used_mem
@ -1542,8 +1540,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
graphs = graph == 't' graphs = graph == 't'
if graphs: if graphs:
self.graphed_buckets.add((int(bs), int(seq_len), is_prompt)) self.graphed_buckets.add((int(bs), int(seq_len), is_prompt))
self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, self.warmup_scenario(int(bs), int(seq_len), is_prompt, True)
True)
raise AssertionError("Finished profiling") raise AssertionError("Finished profiling")
if self.skip_warmup: if self.skip_warmup:
logger.info("Skipping warmup...") logger.info("Skipping warmup...")
@ -1608,9 +1605,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
with compile_only_mode_context( with compile_only_mode_context(
) if can_use_compile_only_mode else contextlib.nullcontext(): ) if can_use_compile_only_mode else contextlib.nullcontext():
self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets, self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets,
True, kv_caches) True)
self.warmup_all_buckets(self.bucketing_global_state.decode_buckets, self.warmup_all_buckets(self.bucketing_global_state.decode_buckets,
False, kv_caches) False)
if not self.enforce_eager and htorch.utils.internal.is_lazy(): if not self.enforce_eager and htorch.utils.internal.is_lazy():
assert self.mem_margin is not None, \ assert self.mem_margin is not None, \
@ -1641,11 +1638,11 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
self.warmup_graphs( self.warmup_graphs(
prompt_strategy, self.bucketing_global_state.prompt_buckets, prompt_strategy, self.bucketing_global_state.prompt_buckets,
True, kv_caches, prompt_available_memory) True, prompt_available_memory)
mem_post_decode, decode_batch_seq, decode_captured_all = \ mem_post_decode, decode_batch_seq, decode_captured_all = \
self.warmup_graphs( self.warmup_graphs(
decode_strategy, self.bucketing_global_state.decode_buckets, decode_strategy, self.bucketing_global_state.decode_buckets,
False, kv_caches, decode_available_memory) False, decode_available_memory)
# Not all prompt buckets were captured, but all decode buckets # Not all prompt buckets were captured, but all decode buckets
# were captured and we have some free graph-allocated space # were captured and we have some free graph-allocated space
@ -1656,7 +1653,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self.warmup_graphs( self.warmup_graphs(
prompt_strategy, prompt_strategy,
self.bucketing_global_state.prompt_buckets, True, self.bucketing_global_state.prompt_buckets, True,
kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode, graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_prompt, prompt_batch_seq)) mem_post_prompt, prompt_batch_seq))
@ -1669,7 +1665,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
mem_post_decode, _, _ = self.warmup_graphs( mem_post_decode, _, _ = self.warmup_graphs(
decode_strategy, decode_strategy,
self.bucketing_global_state.decode_buckets, False, self.bucketing_global_state.decode_buckets, False,
kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode, graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_decode, decode_batch_seq) mem_post_decode, decode_batch_seq)
@ -1982,7 +1977,6 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
execute_model_kwargs = { execute_model_kwargs = {
"input_ids": input_tokens, "input_ids": input_tokens,
"positions": input_positions, "positions": input_positions,
"kv_caches": kv_caches,
"attn_metadata": self.trim_attn_metadata(attn_metadata), "attn_metadata": self.trim_attn_metadata(attn_metadata),
"intermediate_tensors": intermediate_tensors, "intermediate_tensors": intermediate_tensors,
"lora_mask": lora_mask, "lora_mask": lora_mask,

View File

@ -26,7 +26,7 @@ from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_kv_transfer_group, get_pp_group from vllm.distributed import get_kv_transfer_group, get_pp_group
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
graph_capture) graph_capture)
from vllm.forward_context import set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
@ -1727,8 +1727,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device), device=self.device),
@ -1913,8 +1911,6 @@ class CUDAGraphRunner(nn.Module):
self.model( self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_inputs, intermediate_tensors=intermediate_inputs,
**kwargs, **kwargs,
) )
@ -1927,8 +1923,6 @@ class CUDAGraphRunner(nn.Module):
output_hidden_or_intermediate_states = self.model( output_hidden_or_intermediate_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_inputs, intermediate_tensors=intermediate_inputs,
**kwargs, **kwargs,
) )
@ -1976,13 +1970,10 @@ class CUDAGraphRunner(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
# KV caches are fixed tensors, so we don't need to copy them. attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
del kv_caches
# Copy the input tensors to the input buffers. # Copy the input tensors to the input buffers.
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)

View File

@ -476,7 +476,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# path for warm up runs # path for warm up runs
if not model_input.is_multi_step: if not model_input.is_multi_step:
return self._base_model_runner.execute_model( return self._base_model_runner.execute_model(
frozen_model_input, kv_caches, intermediate_tensors, num_steps) frozen_model_input, None, intermediate_tensors, num_steps)
# make sure we skip the sampler on the lask rank and only pythonize # make sure we skip the sampler on the lask rank and only pythonize
# if CPU is ahead. # if CPU is ahead.
@ -538,7 +538,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# Execute the model # Execute the model
output = self._base_model_runner.execute_model(frozen_model_input, output = self._base_model_runner.execute_model(frozen_model_input,
kv_caches, None,
intermediate_tensors, intermediate_tensors,
num_steps=1) num_steps=1)

Some files were not shown because too many files have changed in this diff.