Remove unused kwargs from model definitions (#13555)
This commit is contained in:
parent
f61528d46d
commit
cdc1fa12eb
@ -74,8 +74,6 @@ def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
```
|
||||
|
@ -16,8 +16,6 @@ Further update the model as follows:
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
+ pixel_values: torch.Tensor,
|
||||
) -> SamplerOutput:
|
||||
```
|
||||
|
@ -644,11 +644,7 @@ def _run_encoder_attention_test(
|
||||
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
|
||||
reshaped_query = packed_qkv.query.view(
|
||||
-1, test_pt.num_heads * test_pt.head_size)
|
||||
return attn.forward(
|
||||
reshaped_query, packed_qkv.key, packed_qkv.value,
|
||||
torch.tensor([],
|
||||
dtype=torch.float32,
|
||||
device=packed_qkv.query.device), attn_metadata)
|
||||
return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)
|
||||
|
||||
|
||||
def _run_decoder_self_attention_test(
|
||||
@ -682,7 +678,6 @@ def _run_decoder_self_attention_test(
|
||||
& attn_metadata
|
||||
'''
|
||||
attn = test_rsrcs.attn
|
||||
kv_cache = test_rsrcs.kv_cache
|
||||
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
|
||||
assert packed_qkv is not None
|
||||
with set_forward_context(attn_metadata, vllm_config):
|
||||
@ -695,8 +690,7 @@ def _run_decoder_self_attention_test(
|
||||
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
|
||||
reshaped_query = packed_qkv.query.view(
|
||||
-1, test_pt.num_heads * test_pt.head_size)
|
||||
return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value,
|
||||
kv_cache, attn_metadata)
|
||||
return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)
|
||||
|
||||
|
||||
def _run_encoder_decoder_cross_attention_test(
|
||||
@ -744,7 +738,6 @@ def _run_encoder_decoder_cross_attention_test(
|
||||
assert decoder_test_params.packed_qkvo.packed_qkv is not None
|
||||
|
||||
attn = test_rsrcs.attn
|
||||
kv_cache = test_rsrcs.kv_cache
|
||||
if cross_test_params is None:
|
||||
key = None
|
||||
value = None
|
||||
@ -762,8 +755,7 @@ def _run_encoder_decoder_cross_attention_test(
|
||||
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
|
||||
reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
|
||||
-1, test_pt.num_heads * test_pt.head_size)
|
||||
return attn.forward(reshaped_query, key, value, kv_cache,
|
||||
attn_metadata)
|
||||
return attn.forward(reshaped_query, key, value)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
@ -7,7 +7,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionMetadata, AttentionType
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
@ -153,15 +153,10 @@ class Attention(nn.Module):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
# NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
|
||||
# directly, use `self.kv_cache` and
|
||||
# `get_forward_context().attn_metadata` instead.
|
||||
if self.calculate_kv_scales:
|
||||
ctx_attn_metadata = get_forward_context().attn_metadata
|
||||
if ctx_attn_metadata.enable_kv_scales_calculation:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata.enable_kv_scales_calculation:
|
||||
self.calc_kv_scales(key, value)
|
||||
if self.use_output:
|
||||
output = torch.empty_like(query)
|
||||
@ -177,14 +172,14 @@ class Attention(nn.Module):
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
if self.use_direct_call:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
ctx_attn_metadata = forward_context.attn_metadata
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self_kv_cache,
|
||||
ctx_attn_metadata,
|
||||
attn_metadata,
|
||||
output=output)
|
||||
else:
|
||||
torch.ops.vllm.unified_attention_with_output(
|
||||
@ -193,10 +188,10 @@ class Attention(nn.Module):
|
||||
else:
|
||||
if self.use_direct_call:
|
||||
forward_context = get_forward_context()
|
||||
ctx_attn_metadata = forward_context.attn_metadata
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
return self.impl.forward(self, query, key, value,
|
||||
self_kv_cache, ctx_attn_metadata)
|
||||
self_kv_cache, attn_metadata)
|
||||
else:
|
||||
return torch.ops.vllm.unified_attention(
|
||||
query, key, value, self.layer_name)
|
||||
|
@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -130,14 +131,14 @@ class MambaMixer(CustomOp):
|
||||
) if use_rms_norm else None
|
||||
|
||||
def forward_native(self, hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
conv_state: torch.Tensor, ssm_state: torch.Tensor):
|
||||
pass
|
||||
|
||||
def forward_cuda(self, hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams):
|
||||
|
||||
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
||||
hidden_states, gate = projected_states.chunk(2, dim=-2)
|
||||
|
@ -14,6 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
@ -376,17 +377,16 @@ class MambaMixer2(CustomOp):
|
||||
eps=rms_norm_eps)
|
||||
|
||||
def forward_native(self, hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
conv_state: torch.Tensor, ssm_state: torch.Tensor):
|
||||
pass
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
sequence_idx: Optional[torch.Tensor] = None,
|
||||
):
|
||||
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
||||
|
||||
seq_len, _ = hidden_states.shape
|
||||
groups_time_state_size = self.n_groups * self.ssm_state_size
|
||||
|
@ -160,7 +160,6 @@ def as_classification_model(cls: _T) -> _T:
|
||||
return cls
|
||||
|
||||
# Lazy import
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
from vllm.model_executor.layers.pooler import PoolingType
|
||||
@ -201,13 +200,10 @@ def as_classification_model(cls: _T) -> _T:
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: list[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = super().forward(input_ids, positions, kv_caches,
|
||||
attn_metadata,
|
||||
hidden_states = super().forward(input_ids, positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds)
|
||||
logits, _ = self.score(hidden_states)
|
||||
|
@ -5,7 +5,7 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
@ -283,13 +283,11 @@ class ArcticAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -336,16 +334,12 @@ class ArcticDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
residual_input = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = residual_input + hidden_states
|
||||
|
||||
@ -400,8 +394,6 @@ class ArcticModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -413,11 +405,8 @@ class ArcticModel(nn.Module):
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(positions, hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.norm(hidden_states)
|
||||
@ -458,13 +447,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -9,7 +9,6 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
|
||||
from transformers.models.aria.modeling_aria import AriaCrossAttention
|
||||
from transformers.models.aria.processing_aria import AriaProcessor
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
@ -626,8 +625,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -643,8 +640,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
hidden_states = self.language_model(
|
||||
input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
@ -20,13 +20,13 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
@ -182,14 +182,12 @@ class BaiChuanAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.W_pack(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
if self.postion_embedding != "ALIBI":
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -232,8 +230,6 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -246,8 +242,6 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -301,8 +295,6 @@ class BaiChuanModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -316,13 +308,10 @@ class BaiChuanModel(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
@ -379,13 +368,10 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -1,17 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Inference-only Bamba model."""
|
||||
# Added by the IBM Team, 2024
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import BambaConfig
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
@ -107,7 +107,6 @@ class BambaMixerDecoderLayer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
sequence_idx: Optional[torch.Tensor] = None,
|
||||
@ -120,8 +119,8 @@ class BambaMixerDecoderLayer(nn.Module):
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
hidden_states = self.mamba(hidden_states, attn_metadata,
|
||||
mamba_cache_params, sequence_idx)
|
||||
hidden_states = self.mamba(hidden_states, mamba_cache_params,
|
||||
sequence_idx)
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.pre_ff_layernorm(
|
||||
hidden_states, residual)
|
||||
@ -215,15 +214,13 @@ class BambaAttentionDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -231,8 +228,6 @@ class BambaAttentionDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
@ -246,8 +241,6 @@ class BambaAttentionDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attention(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.pre_ff_layernorm(
|
||||
@ -312,8 +305,6 @@ class BambaModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
@ -323,6 +314,7 @@ class BambaModel(nn.Module):
|
||||
# proper continuous batching computation including
|
||||
# chunked prefill
|
||||
seq_idx = None
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata.num_prefills > 0:
|
||||
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
|
||||
for i, (srt, end) in enumerate(
|
||||
@ -348,9 +340,7 @@ class BambaModel(nn.Module):
|
||||
num_attn = 0
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
kv_cache = None
|
||||
if isinstance(layer, BambaAttentionDecoderLayer):
|
||||
kv_cache = kv_caches[num_attn]
|
||||
num_attn += 1
|
||||
|
||||
layer_mamba_cache_params = None
|
||||
@ -361,8 +351,6 @@ class BambaModel(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
mamba_cache_params=layer_mamba_cache_params,
|
||||
sequence_idx=seq_idx,
|
||||
@ -440,8 +428,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs):
|
||||
@ -454,8 +440,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
|
||||
*self._get_mamba_cache_shape())
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, mamba_cache_params,
|
||||
hidden_states = self.model(input_ids, positions, mamba_cache_params,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
@ -19,14 +19,14 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch BART model."""
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import BartConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
@ -181,14 +181,13 @@ class BartEncoderAttention(nn.Module):
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=AttentionType.ENCODER)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
@ -261,14 +260,13 @@ class BartDecoderSelfAttention(nn.Module):
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=AttentionType.DECODER)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
@ -344,8 +342,6 @@ class BartCrossAttention(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
decoder_hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
@ -363,7 +359,7 @@ class BartCrossAttention(nn.Module):
|
||||
_, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
|
||||
dim=-1)
|
||||
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
@ -411,23 +407,16 @@ class BartEncoderLayer(nn.Module):
|
||||
|
||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
hidden_states
|
||||
torch.Tensor of *encoder* input embeddings.
|
||||
kv_cache:
|
||||
Layer-wise list of KV cache tensors
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
Returns:
|
||||
Encoder layer output torch.Tensor
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
@ -509,18 +498,12 @@ class BartDecoderLayer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
decoder_hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
decoder_hidden_states
|
||||
torch.Tensor of *decoder* input embeddings.
|
||||
kv_cache:
|
||||
KV cache tensor
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
encoder_hidden_states
|
||||
torch.Tensor of *encoder* input embeddings.
|
||||
Returns:
|
||||
@ -529,9 +512,7 @@ class BartDecoderLayer(nn.Module):
|
||||
residual = decoder_hidden_states
|
||||
|
||||
# Self Attention
|
||||
hidden_states = self.self_attn(hidden_states=decoder_hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
hidden_states = self.self_attn(hidden_states=decoder_hidden_states)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
@ -542,8 +523,6 @@ class BartDecoderLayer(nn.Module):
|
||||
|
||||
hidden_states = self.encoder_attn(
|
||||
decoder_hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
@ -609,9 +588,8 @@ class BartEncoder(nn.Module):
|
||||
|
||||
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||
def forward(self, input_ids: torch.Tensor,
|
||||
positions: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
input_ids
|
||||
@ -620,10 +598,6 @@ class BartEncoder(nn.Module):
|
||||
provide it.
|
||||
positions
|
||||
Positions of *encoder* input sequence tokens.
|
||||
kv_caches:
|
||||
Layer-wise list of KV cache tensors
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
Returns:
|
||||
Decoder output torch.Tensor
|
||||
"""
|
||||
@ -636,12 +610,8 @@ class BartEncoder(nn.Module):
|
||||
hidden_states = inputs_embeds + embed_pos
|
||||
hidden_states = self.layernorm_embedding(hidden_states)
|
||||
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
hidden_states = encoder_layer(
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_caches[idx],
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(hidden_states=hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -693,9 +663,7 @@ class BartDecoder(nn.Module):
|
||||
|
||||
def forward(self, decoder_input_ids: torch.Tensor,
|
||||
decoder_positions: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor],
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||
encoder_hidden_states: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
decoder_input_ids
|
||||
@ -706,10 +674,6 @@ class BartDecoder(nn.Module):
|
||||
Positions of *decoder* input sequence tokens.
|
||||
encoder_hidden_states:
|
||||
Tensor of encoder output embeddings
|
||||
kv_caches:
|
||||
Layer-wise list of KV cache tensors
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
Returns:
|
||||
Decoder output torch.Tensor
|
||||
"""
|
||||
@ -725,11 +689,9 @@ class BartDecoder(nn.Module):
|
||||
|
||||
# decoder layers
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
for decoder_layer in self.layers:
|
||||
hidden_states = decoder_layer(
|
||||
decoder_hidden_states=hidden_states,
|
||||
kv_cache=kv_caches[idx],
|
||||
attn_metadata=attn_metadata,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
@ -768,8 +730,7 @@ class BartModel(nn.Module):
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||
encoder_positions: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
input_ids
|
||||
@ -782,10 +743,6 @@ class BartModel(nn.Module):
|
||||
Indices of *encoder* input sequence tokens in the vocabulary.
|
||||
encoder_positions:
|
||||
Positions of *encoder* input sequence tokens.
|
||||
kv_caches:
|
||||
Layer-wise list of KV cache tensors
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
Returns:
|
||||
Model output torch.Tensor
|
||||
"""
|
||||
@ -796,18 +753,14 @@ class BartModel(nn.Module):
|
||||
# Run encoder attention if a non-zero number of encoder tokens
|
||||
# are provided as input
|
||||
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
|
||||
positions=encoder_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata)
|
||||
positions=encoder_positions)
|
||||
|
||||
# decoder outputs consists of
|
||||
# (dec_features, past_key_value, dec_hidden, dec_attn)
|
||||
decoder_outputs = self.decoder(
|
||||
decoder_input_ids=input_ids,
|
||||
decoder_positions=positions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata)
|
||||
encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
return decoder_outputs
|
||||
|
||||
@ -845,8 +798,6 @@ class BartForConditionalGeneration(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
*,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
@ -863,15 +814,11 @@ class BartForConditionalGeneration(nn.Module):
|
||||
torch.Tensor of *encoder* input token ids.
|
||||
encoder_positions
|
||||
torch.Tensor of *encoder* position indices
|
||||
kv_caches:
|
||||
Layer-wise list of KV cache tensors
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
Returns:
|
||||
Output torch.Tensor
|
||||
"""
|
||||
return self.model(input_ids, positions, encoder_input_ids,
|
||||
encoder_positions, kv_caches, attn_metadata)
|
||||
encoder_positions)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
|
@ -1,15 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import BertConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@ -113,12 +114,9 @@ class BertEncoder(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
for i in range(len(self.layer)):
|
||||
layer = self.layer[i]
|
||||
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
|
||||
for layer in self.layer:
|
||||
hidden_states = layer(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -152,13 +150,8 @@ class BertLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.output")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
):
|
||||
attn_output = self.attention(hidden_states, kv_cache, attn_metadata)
|
||||
def forward(self, hidden_states: torch.Tensor):
|
||||
attn_output = self.attention(hidden_states)
|
||||
intermediate_output = self.intermediate(attn_output)
|
||||
output = self.output(intermediate_output, attn_output)
|
||||
return output
|
||||
@ -191,10 +184,8 @@ class BertAttention(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
self_output = self.self(hidden_states, kv_cache, attn_metadata)
|
||||
self_output = self.self(hidden_states)
|
||||
return self.output(self_output, hidden_states)
|
||||
|
||||
|
||||
@ -246,12 +237,10 @@ class BertSelfAttention(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
output = self.attn(q, k, v)
|
||||
return output
|
||||
|
||||
|
||||
@ -343,8 +332,6 @@ class BertModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
@ -352,13 +339,14 @@ class BertModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
assert hasattr(attn_metadata, "seq_lens_tensor")
|
||||
hidden_states = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
seq_lens=attn_metadata.seq_lens_tensor,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids)
|
||||
return self.encoder(hidden_states, kv_caches, attn_metadata)
|
||||
return self.encoder(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
@ -420,17 +408,13 @@ class BertEmbeddingModel(nn.Module):
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.model(input_ids=input_ids,
|
||||
position_ids=positions,
|
||||
kv_caches=kv_caches,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
attn_metadata=attn_metadata)
|
||||
intermediate_tensors=intermediate_tensors)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
@ -519,16 +503,12 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.bert(input_ids=input_ids,
|
||||
position_ids=positions,
|
||||
kv_caches=kv_caches,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
attn_metadata=attn_metadata,
|
||||
token_type_ids=token_type_ids)
|
||||
|
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import torch
|
||||
@ -9,7 +9,6 @@ import torch.nn as nn
|
||||
from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
|
||||
apply_chunking_to_forward)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
@ -658,8 +657,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -708,8 +705,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
|
@ -18,13 +18,13 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only BLOOM model compatible with HuggingFace weights."""
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import BloomConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
@ -126,13 +126,11 @@ class BloomAttention(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
del position_ids # Unused.
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.dense(attn_output)
|
||||
return output
|
||||
|
||||
@ -193,8 +191,6 @@ class BloomBlock(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
# Layer norm at the beginning of the transformer layer.
|
||||
layernorm_output = self.input_layernorm(hidden_states)
|
||||
@ -209,8 +205,6 @@ class BloomBlock(nn.Module):
|
||||
attention_output = self.self_attention(
|
||||
position_ids=position_ids,
|
||||
hidden_states=layernorm_output,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
attention_output = attention_output + residual
|
||||
layernorm_output = self.post_attention_layernorm(attention_output)
|
||||
@ -266,8 +260,6 @@ class BloomModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -279,14 +271,8 @@ class BloomModel(nn.Module):
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
position_ids,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
for layer in self.h[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(position_ids, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
@ -322,14 +308,11 @@ class BloomForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
|
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from functools import cached_property
|
||||
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
|
||||
from typing import (Any, Dict, Iterable, Literal, Mapping, Optional, Set,
|
||||
Tuple, TypedDict, Union)
|
||||
|
||||
import torch
|
||||
@ -10,7 +10,7 @@ import torch.nn.functional as F
|
||||
from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
|
||||
ChameleonVQVAEConfig)
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
@ -310,15 +310,13 @@ class ChameleonAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self._apply_qk_norm(q, k)
|
||||
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -372,8 +370,6 @@ class ChameleonDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
@ -386,8 +382,6 @@ class ChameleonDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -447,8 +441,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
@ -456,8 +448,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@ -906,8 +896,6 @@ class ChameleonModel(nn.Module):
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -921,13 +909,10 @@ class ChameleonModel(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
@ -1028,8 +1013,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
@ -1048,8 +1031,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
hidden_states = self.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
return hidden_states
|
||||
|
@ -2,13 +2,13 @@
|
||||
# Adapted from
|
||||
# https://github.com/THUDM/ChatGLM2-6B
|
||||
"""Inference-only ChatGLM model compatible with THUDM weights."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@ -108,19 +108,11 @@ class GLMAttention(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(position_ids, q, k)
|
||||
context_layer = self.attn(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
)
|
||||
context_layer = self.attn(q, k, v)
|
||||
attn_output, _ = self.dense(context_layer)
|
||||
return attn_output
|
||||
|
||||
@ -215,8 +207,6 @@ class GLMBlock(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
# hidden_states: [num_tokens, h]
|
||||
# Layer norm at the beginning of the transformer layer.
|
||||
@ -225,8 +215,6 @@ class GLMBlock(nn.Module):
|
||||
attention_output = self.self_attention(
|
||||
hidden_states=layernorm_output,
|
||||
position_ids=position_ids,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Residual connection.
|
||||
@ -289,17 +277,10 @@ class GLMTransformer(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
hidden_states=hidden_states,
|
||||
position_ids=position_ids,
|
||||
kv_cache=kv_caches[i - self.start_layer],
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(hidden_states=hidden_states,
|
||||
position_ids=position_ids)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
@ -350,8 +331,6 @@ class ChatGLMModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -369,8 +348,6 @@ class ChatGLMModel(nn.Module):
|
||||
hidden_states = self.encoder(
|
||||
hidden_states=hidden_states,
|
||||
position_ids=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
@ -494,12 +471,9 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return hidden_states
|
||||
|
@ -21,14 +21,14 @@
|
||||
|
||||
# This file is based on the LLama model definition file in transformers
|
||||
"""PyTorch Cohere model."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from transformers import CohereConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -218,8 +218,6 @@ class CohereAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
@ -227,7 +225,7 @@ class CohereAttention(nn.Module):
|
||||
q, k = self._apply_qk_norm(q, k)
|
||||
if self.v1 or self.sliding_window:
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -255,8 +253,6 @@ class CohereDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -265,8 +261,6 @@ class CohereDecoderLayer(nn.Module):
|
||||
hidden_states_attention = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states_mlp = self.mlp(hidden_states)
|
||||
# Add everything together
|
||||
@ -311,8 +305,6 @@ class CohereModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -326,13 +318,10 @@ class CohereModel(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
@ -389,13 +378,10 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
@ -230,15 +230,13 @@ class DbrxAttention(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.Wqkv(hidden_states)
|
||||
if self.clip_qkv is not None:
|
||||
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(position_ids, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
hidden_states, _ = self.out_proj(attn_output)
|
||||
return hidden_states
|
||||
|
||||
@ -265,16 +263,12 @@ class DbrxFusedNormAttention(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm_1(hidden_states)
|
||||
x = self.attn(
|
||||
position_ids=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = residual + x
|
||||
residual = hidden_states
|
||||
@ -303,14 +297,10 @@ class DbrxBlock(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states, residual = self.norm_attn_norm(
|
||||
position_ids=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = self.ffn(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
@ -353,8 +343,6 @@ class DbrxModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -366,14 +354,8 @@ class DbrxModel(nn.Module):
|
||||
else:
|
||||
assert intermediate_tensors
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
block = self.blocks[i]
|
||||
hidden_states = block(
|
||||
position_ids,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
for block in self.blocks[self.start_layer:self.end_layer]:
|
||||
hidden_states = block(position_ids, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.norm_f(hidden_states)
|
||||
@ -415,14 +397,11 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
|
@ -22,13 +22,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Deepseek model."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@ -248,13 +248,11 @@ class DeepseekAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -309,8 +307,6 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
@ -323,8 +319,6 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -370,8 +364,6 @@ class DeepseekModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -384,11 +376,8 @@ class DeepseekModel(nn.Module):
|
||||
else:
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata, residual)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
@ -425,13 +414,10 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -1,11 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@ -69,8 +68,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_index: int = 0,
|
||||
@ -88,8 +85,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
|
||||
|
||||
hidden_states, residual = self.mtp_block(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=None)
|
||||
hidden_states = residual + hidden_states
|
||||
return self.shared_head(hidden_states)
|
||||
@ -122,8 +117,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
@ -131,8 +124,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
|
||||
return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
|
||||
input_ids,
|
||||
positions,
|
||||
kv_caches[spec_step_idx],
|
||||
attn_metadata,
|
||||
previous_hidden_states,
|
||||
inputs_embeds,
|
||||
spec_step_idx,
|
||||
@ -165,16 +156,14 @@ class DeepSeekMTP(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, previous_hidden_states,
|
||||
inputs_embeds, spec_step_idx)
|
||||
hidden_states = self.model(input_ids, positions,
|
||||
previous_hidden_states, inputs_embeds,
|
||||
spec_step_idx)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
|
@ -22,13 +22,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only DeepseekV2/DeepseekV3 model."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group,
|
||||
@ -279,8 +279,6 @@ class DeepseekV2Attention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
if self.q_lora_rank is not None:
|
||||
q = self.q_a_proj(hidden_states)[0]
|
||||
@ -313,7 +311,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
v = torch.nn.functional.pad(
|
||||
v, [0, self.qk_head_dim - self.v_head_dim],
|
||||
value=0).view(-1, self.num_local_heads * self.qk_head_dim)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
attn_output = attn_output.view(
|
||||
-1, self.num_local_heads,
|
||||
self.qk_head_dim)[..., :self.v_head_dim].reshape(
|
||||
@ -451,8 +449,6 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
if self.q_lora_rank is not None:
|
||||
ckq = self.q_a_proj(hidden_states)[0]
|
||||
@ -462,8 +458,7 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
|
||||
attn_metadata)
|
||||
return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe)
|
||||
|
||||
|
||||
class DeepseekV2DecoderLayer(nn.Module):
|
||||
@ -532,8 +527,6 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
@ -546,8 +539,6 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -608,8 +599,6 @@ class DeepseekV2Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -624,11 +613,8 @@ class DeepseekV2Model(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata, residual)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(positions, hidden_states)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
@ -665,13 +651,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -13,7 +13,6 @@ import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
@ -595,8 +594,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object):
|
||||
@ -614,8 +611,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
hidden_states = self.language_model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
|
@ -1,11 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@ -121,8 +120,6 @@ class EAGLE(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
@ -140,8 +137,6 @@ class EAGLE(nn.Module):
|
||||
input_ids=None,
|
||||
inputs_embeds=inputs_embeds,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
)
|
||||
return hidden_states
|
||||
|
@ -24,12 +24,12 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Exaone model compatible with HuggingFace weights."""
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -179,13 +179,11 @@ class ExaoneAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -225,14 +223,10 @@ class ExaoneBlockAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
return self.attention(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
|
||||
@ -288,8 +282,6 @@ class ExaoneDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -301,8 +293,6 @@ class ExaoneDecoderLayer(nn.Module):
|
||||
hidden_states = self.attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -365,8 +355,6 @@ class ExaoneModel(nn.Module):
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -381,13 +369,10 @@ class ExaoneModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.h[i]
|
||||
for layer in self.h[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
|
||||
@ -471,14 +456,11 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
model_output = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
model_output = self.transformer(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return model_output
|
||||
|
||||
def compute_logits(
|
||||
|
@ -20,14 +20,14 @@
|
||||
"""PyTorch Falcon model."""
|
||||
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
from transformers import FalconConfig as HF_FalconConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
@ -190,8 +190,6 @@ class FalconAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, bias = self.query_key_value(hidden_states)
|
||||
if bias is not None:
|
||||
@ -199,7 +197,7 @@ class FalconAttention(nn.Module):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if self.use_rotary:
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
attn_output, bias = self.dense(attn_output)
|
||||
return attn_output, bias
|
||||
|
||||
@ -291,8 +289,6 @@ class FalconDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
@ -306,8 +302,6 @@ class FalconDecoderLayer(nn.Module):
|
||||
attention_output, attention_bias = self.self_attention(
|
||||
positions=positions,
|
||||
hidden_states=attention_layernorm_out,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
if self.reduce_row_parallel_results and attention_bias is not None:
|
||||
attention_output += attention_bias
|
||||
@ -384,8 +378,6 @@ class FalconModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -396,14 +388,8 @@ class FalconModel(nn.Module):
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
else:
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
for layer in self.h[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
@ -450,14 +436,11 @@ class FalconForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
|
@ -1,12 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
@ -50,8 +49,7 @@ class Florence2LanguageModel(nn.Module):
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||
encoder_positions: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
input_ids
|
||||
@ -64,10 +62,6 @@ class Florence2LanguageModel(nn.Module):
|
||||
Indices of *encoder* input sequence tokens in the vocabulary.
|
||||
encoder_positions:
|
||||
Positions of *encoder* input sequence tokens.
|
||||
kv_caches:
|
||||
Layer-wise list of KV cache tensors
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
Returns:
|
||||
Model output torch.Tensor
|
||||
"""
|
||||
@ -78,18 +72,14 @@ class Florence2LanguageModel(nn.Module):
|
||||
# Run encoder attention if a non-zero number of encoder tokens
|
||||
# are provided as input
|
||||
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
|
||||
positions=encoder_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata)
|
||||
positions=encoder_positions)
|
||||
|
||||
# decoder outputs consists of
|
||||
# (dec_features, past_key_value, dec_hidden, dec_attn)
|
||||
decoder_outputs = self.decoder(
|
||||
decoder_input_ids=input_ids,
|
||||
decoder_positions=positions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata)
|
||||
encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
return decoder_outputs
|
||||
|
||||
@ -122,8 +112,6 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
@ -136,15 +124,11 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
|
||||
torch.Tensor of *encoder* input token ids.
|
||||
encoder_positions
|
||||
torch.Tensor of *encoder* position indices
|
||||
kv_caches:
|
||||
Layer-wise list of KV cache tensors
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
Returns:
|
||||
Output torch.Tensor
|
||||
"""
|
||||
return self.model(input_ids, positions, encoder_input_ids,
|
||||
encoder_positions, kv_caches, attn_metadata)
|
||||
encoder_positions)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
@ -213,8 +197,6 @@ class Florence2ForConditionalGeneration(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
*,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
@ -231,15 +213,11 @@ class Florence2ForConditionalGeneration(nn.Module):
|
||||
torch.Tensor of *encoder* input token ids.
|
||||
encoder_positions
|
||||
torch.Tensor of *encoder* position indices
|
||||
kv_caches:
|
||||
Layer-wise list of KV cache tensors
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
Returns:
|
||||
Output torch.Tensor
|
||||
"""
|
||||
return self.language_model(input_ids, positions, encoder_input_ids,
|
||||
encoder_positions, kv_caches, attn_metadata)
|
||||
encoder_positions)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
|
@ -25,7 +25,6 @@ import torch.nn as nn
|
||||
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
|
||||
FuyuProcessor)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
@ -351,8 +350,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -371,8 +368,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
hidden_states = self.language_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
@ -16,13 +16,13 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Gemma model compatible with HuggingFace weights."""
|
||||
from functools import cache
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GemmaConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -183,13 +183,11 @@ class GemmaAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -233,8 +231,6 @@ class GemmaDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -247,8 +243,6 @@ class GemmaDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -298,8 +292,6 @@ class GemmaModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -313,13 +305,10 @@ class GemmaModel(nn.Module):
|
||||
else:
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
@ -370,13 +359,10 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -15,13 +15,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Gemma2Config
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -164,13 +164,11 @@ class Gemma2Attention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -220,8 +218,6 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None:
|
||||
@ -233,8 +229,6 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
|
||||
@ -284,8 +278,6 @@ class Gemma2Model(nn.Module):
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -300,13 +292,10 @@ class Gemma2Model(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
@ -415,13 +404,10 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -4,7 +4,7 @@
|
||||
# https://github.com/THUDM/CogAgent
|
||||
"""Inference-only CogAgent model compatible with THUDM weights."""
|
||||
from argparse import Namespace
|
||||
from typing import List, Literal, Mapping, Optional, TypedDict, Union
|
||||
from typing import Literal, Mapping, Optional, TypedDict, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -15,7 +15,6 @@ from transformers import PreTrainedTokenizer, TensorType
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
@ -628,8 +627,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -645,8 +642,7 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
|
||||
vision_embeddings)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
@ -18,13 +18,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GPT2Config
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed.parallel_state import (
|
||||
@ -92,12 +92,10 @@ class GPT2Attention(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
attn_output, _ = self.c_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
@ -164,16 +162,10 @@ class GPT2Block(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
attn_output = self.attn(
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
attn_output = self.attn(hidden_states=hidden_states)
|
||||
# residual connection
|
||||
hidden_states = attn_output + residual
|
||||
|
||||
@ -222,8 +214,6 @@ class GPT2Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor],
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -236,11 +226,8 @@ class GPT2Model(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata)
|
||||
for layer in self.h[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(hidden_states)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
@ -279,14 +266,11 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
|
@ -19,13 +19,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GPTBigCodeConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -101,8 +101,6 @@ class GPTBigCodeAttention(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.split(
|
||||
@ -112,7 +110,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
attn_output, _ = self.c_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
@ -173,16 +171,10 @@ class GPTBigCodeBlock(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
attn_output = self.attn(
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
attn_output = self.attn(hidden_states=hidden_states, )
|
||||
# residual connection
|
||||
hidden_states = attn_output + residual
|
||||
|
||||
@ -234,8 +226,6 @@ class GPTBigCodeModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -246,11 +236,8 @@ class GPTBigCodeModel(nn.Module):
|
||||
else:
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata)
|
||||
for layer in self.h[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(hidden_states)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
@ -302,14 +289,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
|
@ -17,13 +17,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GPT-J model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GPTJConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -104,13 +104,11 @@ class GPTJAttention(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k = self.rotary_emb(position_ids, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
attn_output, _ = self.out_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
@ -167,16 +165,12 @@ class GPTJBlock(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
attn_output = self.attn(
|
||||
position_ids=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
mlp_output = self.mlp(hidden_states)
|
||||
hidden_states = attn_output + mlp_output + residual
|
||||
@ -217,8 +211,6 @@ class GPTJModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -229,14 +221,8 @@ class GPTJModel(nn.Module):
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
else:
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
position_ids,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
for layer in self.h[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(position_ids, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
@ -273,14 +259,11 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
|
@ -17,13 +17,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GPTNeoXConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -104,13 +104,11 @@ class GPTNeoXAttention(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k = self.rotary_emb(position_ids, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.dense(attn_output)
|
||||
return output
|
||||
|
||||
@ -167,15 +165,11 @@ class GPTNeoXLayer(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
attn_input = self.input_layernorm(hidden_states)
|
||||
attn_output = self.attention(
|
||||
position_ids=position_ids,
|
||||
hidden_states=attn_input,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
if self.use_parallel_residual:
|
||||
@ -230,8 +224,6 @@ class GPTNeoXModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -242,14 +234,8 @@ class GPTNeoXModel(nn.Module):
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
else:
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
position_ids,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(position_ids, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
@ -285,14 +271,11 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.gpt_neox(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
|
@ -22,13 +22,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only IBM Granite model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GraniteConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -166,13 +166,11 @@ class GraniteAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -233,8 +231,6 @@ class GraniteDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
@ -242,8 +238,6 @@ class GraniteDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = residual + hidden_states * self.residual_multiplier
|
||||
# Fully Connected
|
||||
@ -300,8 +294,6 @@ class GraniteModel(nn.Module):
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -318,14 +310,8 @@ class GraniteModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
@ -405,13 +391,10 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
model_output = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
model_output = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return model_output
|
||||
|
||||
|
@ -22,13 +22,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GraniteMoe model."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.models.granitemoe import GraniteMoeConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -173,13 +173,11 @@ class GraniteMoeAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -226,8 +224,6 @@ class GraniteMoeDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
@ -235,8 +231,6 @@ class GraniteMoeDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = residual + hidden_states * self.residual_multiplier
|
||||
residual = hidden_states
|
||||
@ -287,8 +281,6 @@ class GraniteMoeModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
@ -303,11 +295,8 @@ class GraniteMoeModel(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(positions, hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
@ -377,13 +366,10 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -1,15 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from array import array
|
||||
from typing import List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.attention.backends.xformers import XFormersImpl
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.pooler import PoolerHead
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
@ -217,13 +217,12 @@ class GritLM(LlamaForCausalLM):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
|
||||
# Change attention to non-causal for pooling tasks.
|
||||
if self.runner_type == "pooling":
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
assert attn_metadata.prefill_metadata.attn_bias is None
|
||||
attn_metadata.prefill_metadata.attn_bias = [
|
||||
BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)
|
||||
@ -232,8 +231,6 @@ class GritLM(LlamaForCausalLM):
|
||||
return super().forward(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -25,7 +25,6 @@ from torch import nn
|
||||
from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
|
||||
Idefics3Processor)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
@ -563,8 +562,6 @@ class Idefics3Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -572,8 +569,6 @@ class Idefics3Model(nn.Module):
|
||||
hidden_states = self.text_model(
|
||||
input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
@ -645,8 +640,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -664,8 +657,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
hidden_states = self.model.text_model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union,
|
||||
overload, runtime_checkable)
|
||||
from typing import (TYPE_CHECKING, Optional, Protocol, Type, Union, overload,
|
||||
runtime_checkable)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -11,7 +11,6 @@ from vllm.logger import init_logger
|
||||
from vllm.utils import supports_kw
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import PoolerOutput
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
@ -46,8 +45,6 @@ class VllmModel(Protocol[T_co]):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: "AttentionMetadata",
|
||||
) -> T_co:
|
||||
...
|
||||
|
||||
@ -62,7 +59,7 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
|
||||
if not callable(model_forward):
|
||||
return False
|
||||
|
||||
vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata")
|
||||
vllm_kws = ("input_ids", "positions")
|
||||
missing_kws = tuple(kw for kw in vllm_kws
|
||||
if not supports_kw(model_forward, kw))
|
||||
|
||||
|
@ -1,13 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
@ -175,13 +175,11 @@ class InternLM2Attention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.wqkv(hidden_states)
|
||||
q, k, v = self.split_qkv(qkv)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.wo(attn_output)
|
||||
return output
|
||||
|
||||
@ -227,8 +225,6 @@ class InternLMDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -241,8 +237,6 @@ class InternLMDecoderLayer(nn.Module):
|
||||
hidden_states = self.attention(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -290,8 +284,6 @@ class InternLM2Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -305,15 +297,8 @@ class InternLM2Model(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
@ -363,13 +348,10 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
@ -466,13 +448,10 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
logits, _ = self.v_head(hidden_states)
|
||||
return logits
|
||||
|
@ -1,12 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@ -65,8 +64,6 @@ class InternLM2VEDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
visual_token_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@ -80,8 +77,6 @@ class InternLM2VEDecoderLayer(nn.Module):
|
||||
hidden_states = self.attention(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -113,8 +108,6 @@ class InternLM2VEModel(InternLM2Model):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
visual_token_mask: Optional[torch.Tensor] = None,
|
||||
@ -129,13 +122,10 @@ class InternLM2VEModel(InternLM2Model):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
visual_token_mask=visual_token_mask,
|
||||
)
|
||||
|
@ -17,7 +17,6 @@ import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
@ -929,8 +928,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -951,8 +948,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
forward_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"positions": positions,
|
||||
"kv_caches": kv_caches,
|
||||
"attn_metadata": attn_metadata,
|
||||
"intermediate_tensors": intermediate_tensors,
|
||||
"inputs_embeds": inputs_embeds,
|
||||
}
|
||||
|
@ -21,12 +21,12 @@
|
||||
"""Inference-only Jais model compatible with HuggingFace weights."""
|
||||
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
@ -123,12 +123,10 @@ class JAISAttention(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
attn_output, _ = self.c_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
@ -200,16 +198,10 @@ class JAISBlock(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
attn_output = self.attn(
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
attn_output = self.attn(hidden_states=hidden_states, )
|
||||
# residual connection
|
||||
hidden_states = attn_output + residual
|
||||
|
||||
@ -266,8 +258,6 @@ class JAISModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[IntermediateTensors, torch.Tensor]:
|
||||
@ -285,11 +275,8 @@ class JAISModel(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata)
|
||||
for layer in self.h[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(hidden_states)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
@ -332,14 +319,11 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[IntermediateTensors, torch.Tensor]:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
|
@ -1,12 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Inference-only Jamba model."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import JambaConfig
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
@ -138,7 +137,6 @@ class JambaMambaDecoderLayer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
**kwargs,
|
||||
@ -150,8 +148,7 @@ class JambaMambaDecoderLayer(nn.Module):
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
hidden_states = self.mamba(hidden_states, attn_metadata,
|
||||
mamba_cache_params)
|
||||
hidden_states = self.mamba(hidden_states, mamba_cache_params)
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.pre_ff_layernorm(
|
||||
hidden_states, residual)
|
||||
@ -223,13 +220,11 @@ class JambaAttentionDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -237,8 +232,6 @@ class JambaAttentionDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
@ -252,8 +245,6 @@ class JambaAttentionDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attention(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.pre_ff_layernorm(
|
||||
@ -320,8 +311,6 @@ class JambaModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
@ -339,12 +328,9 @@ class JambaModel(nn.Module):
|
||||
|
||||
kv_cache_index = 0
|
||||
mamba_cache_index = 0
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
kv_cache = None
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
layer_mamba_cache_params = None
|
||||
if isinstance(layer, JambaAttentionDecoderLayer):
|
||||
kv_cache = kv_caches[kv_cache_index]
|
||||
kv_cache_index += 1
|
||||
if isinstance(layer, JambaMambaDecoderLayer):
|
||||
current_state_layer = mamba_cache_index
|
||||
@ -355,8 +341,6 @@ class JambaModel(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
mamba_cache_params=layer_mamba_cache_params)
|
||||
if not get_pp_group().is_last_rank:
|
||||
@ -429,8 +413,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs):
|
||||
@ -443,8 +425,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, mamba_cache_params,
|
||||
hidden_states = self.model(input_ids, positions, mamba_cache_params,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -22,13 +22,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -197,13 +197,11 @@ class LlamaAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -268,8 +266,6 @@ class LlamaDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -280,9 +276,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
hidden_states=hidden_states)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
@ -347,8 +341,6 @@ class LlamaModel(nn.Module):
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -363,11 +355,8 @@ class LlamaModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata, residual)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
@ -535,13 +524,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
model_output = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
model_output = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return model_output
|
||||
|
||||
|
@ -15,7 +15,6 @@ from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
from transformers.models.llava import LlavaProcessor
|
||||
from transformers.models.pixtral import PixtralProcessor
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
@ -658,8 +657,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -712,8 +709,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
|
@ -12,7 +12,6 @@ from transformers.models.llava_next.modeling_llava_next import (
|
||||
get_anyres_image_grid_shape, unpad_image)
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
@ -508,8 +507,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -571,8 +568,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
return hidden_states
|
||||
|
@ -10,7 +10,6 @@ import torch.nn as nn
|
||||
from transformers import (BatchFeature, LlavaNextVideoConfig,
|
||||
LlavaNextVideoProcessor)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
@ -443,8 +442,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -468,8 +465,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
|
@ -13,7 +13,6 @@ from transformers.models.llava_onevision.modeling_llava_onevision import (
|
||||
get_anyres_image_grid_shape, unpad_image)
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
@ -922,8 +921,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -955,8 +952,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
|
@ -1,12 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""PyTorch MAMBA model."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import MambaConfig
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
@ -64,7 +63,6 @@ class MambaDecoderLayer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
**kwargs,
|
||||
@ -75,8 +73,7 @@ class MambaDecoderLayer(nn.Module):
|
||||
else:
|
||||
hidden_states, residual = self.norm(hidden_states, residual)
|
||||
|
||||
hidden_states = self.mixer(hidden_states, attn_metadata,
|
||||
mamba_cache_params)
|
||||
hidden_states = self.mixer(hidden_states, mamba_cache_params)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@ -125,7 +122,6 @@ class MambaModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
@ -146,7 +142,6 @@ class MambaModel(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
mamba_cache_params=mamba_cache_params.at_layer_idx(
|
||||
i - self.start_layer))
|
||||
@ -208,8 +203,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs):
|
||||
@ -222,9 +215,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
|
||||
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
hidden_states = self.backbone(input_ids, positions, attn_metadata,
|
||||
mamba_cache_params, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""PyTorch MAMBA2 model."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
|
||||
@ -63,7 +64,6 @@ class Mamba2DecoderLayer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
sequence_idx: Optional[torch.Tensor],
|
||||
@ -75,8 +75,8 @@ class Mamba2DecoderLayer(nn.Module):
|
||||
else:
|
||||
hidden_states, residual = self.norm(hidden_states, residual)
|
||||
|
||||
hidden_states = self.mixer(hidden_states, attn_metadata,
|
||||
mamba_cache_params, sequence_idx)
|
||||
hidden_states = self.mixer(hidden_states, mamba_cache_params,
|
||||
sequence_idx)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@ -122,7 +122,6 @@ class Mamba2Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
@ -142,6 +141,7 @@ class Mamba2Model(nn.Module):
|
||||
# proper continuous batching computation including
|
||||
# chunked prefill
|
||||
seq_idx = None
|
||||
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
||||
if attn_metadata.num_prefills > 0:
|
||||
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
|
||||
for i, (srt, end) in enumerate(
|
||||
@ -158,7 +158,6 @@ class Mamba2Model(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
mamba_cache_params=mamba_cache_params.at_layer_idx(
|
||||
i - self.start_layer),
|
||||
@ -224,8 +223,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs):
|
||||
@ -238,9 +235,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
hidden_states = self.backbone(input_ids, positions, attn_metadata,
|
||||
mamba_cache_params, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
@ -23,13 +23,13 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
|
||||
import math
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
@ -257,8 +257,6 @@ class MiniCPMAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
@ -266,7 +264,7 @@ class MiniCPMAttention(nn.Module):
|
||||
q, k = q.float(), k.float()
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
q, k = q.to(orig_dtype), k.to(orig_dtype)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -331,8 +329,6 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -341,8 +337,6 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = residual + hidden_states * \
|
||||
(self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))
|
||||
@ -409,8 +403,6 @@ class MiniCPMModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -424,13 +416,10 @@ class MiniCPMModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
@ -579,13 +568,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -29,7 +29,7 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@ -129,8 +129,6 @@ class MiniCPM3Attention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
q, _ = self.q_a_proj(hidden_states)
|
||||
q = self.q_a_layernorm(q)
|
||||
@ -170,7 +168,7 @@ class MiniCPM3Attention(nn.Module):
|
||||
v, [0, self.qk_head_dim - self.v_head_dim],
|
||||
value=0).view(-1, self.num_local_heads * self.qk_head_dim)
|
||||
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
attn_output = attn_output.view(
|
||||
-1, self.num_local_heads,
|
||||
self.qk_head_dim)[..., :self.v_head_dim].reshape(
|
||||
|
@ -33,7 +33,6 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.whisper.modeling_whisper import (
|
||||
ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.inputs import MultiModalFieldConfig
|
||||
@ -792,8 +791,6 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor:
|
||||
@ -818,8 +815,6 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
output = self.llm.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=vlm_embeddings,
|
||||
)
|
||||
|
@ -37,7 +37,6 @@ from torch import nn
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
|
||||
@ -1030,8 +1029,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor:
|
||||
@ -1051,8 +1048,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
output = self.llm.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=vlm_embeddings,
|
||||
)
|
||||
|
@ -22,13 +22,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Mixtral model."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import MixtralConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -175,13 +175,11 @@ class MixtralAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -224,8 +222,6 @@ class MixtralDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
@ -238,8 +234,6 @@ class MixtralDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -291,8 +285,6 @@ class MixtralModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -306,11 +298,8 @@ class MixtralModel(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata, residual)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
@ -377,13 +366,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -22,7 +22,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Mixtral model."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -30,7 +30,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import MixtralConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@ -229,13 +229,11 @@ class MixtralAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -274,8 +272,6 @@ class MixtralDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
@ -288,8 +284,6 @@ class MixtralDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -333,8 +327,6 @@ class MixtralModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -348,11 +340,8 @@ class MixtralModel(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata, residual)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
@ -390,13 +379,10 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -38,7 +38,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.attention.selector import _Backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import get_pp_group, get_tp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -416,11 +417,11 @@ class MllamaVisionSdpaAttention(nn.Module):
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
model_parallel_size = get_tensor_model_parallel_world_size()
|
||||
tensor_parallel_size = get_tp_group().world_size
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.attention_heads
|
||||
self.head_dim = config.hidden_size // config.attention_heads
|
||||
self.num_local_heads = self.num_heads // model_parallel_size
|
||||
self.num_local_heads = self.num_heads // tensor_parallel_size
|
||||
self.q_size = self.num_local_heads * self.head_dim
|
||||
self.kv_size = self.num_local_heads * self.head_dim
|
||||
|
||||
@ -771,12 +772,13 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model_parallel_size = get_tensor_model_parallel_world_size()
|
||||
self.pipeline_parallel_rank = get_pp_group().rank_in_group
|
||||
self.tensor_parallel_size = get_tp_group().world_size
|
||||
self.num_heads = self.config.num_attention_heads
|
||||
self.num_local_heads = self.num_heads // self.model_parallel_size
|
||||
self.num_local_heads = self.num_heads // self.tensor_parallel_size
|
||||
self.num_key_value_heads = self.config.num_key_value_heads
|
||||
self.num_local_key_value_heads = \
|
||||
self.num_key_value_heads // self.model_parallel_size
|
||||
self.num_key_value_heads // self.tensor_parallel_size
|
||||
self.dropout = config.dropout
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_dim = config.hidden_size // self.num_heads
|
||||
@ -824,8 +826,6 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
kv_range_for_decode: Optional[List[Tuple[int, int]]],
|
||||
cross_attention_states: Optional[torch.Tensor],
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv_dec, _ = self.qkv_proj(hidden_states)
|
||||
q, _, _ = qkv_dec.split(
|
||||
@ -846,14 +846,11 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
q = self.q_norm(q)
|
||||
|
||||
if attention_mask is not None:
|
||||
output = self._attention_with_mask(q, k, v, kv_cache,
|
||||
attention_mask,
|
||||
kv_range_for_decode,
|
||||
attn_metadata)
|
||||
output = self._attention_with_mask(q, k, v, attention_mask,
|
||||
kv_range_for_decode)
|
||||
else:
|
||||
output = self.attn(
|
||||
q.view(-1, self.num_local_heads * self.head_dim), k, v,
|
||||
kv_cache, attn_metadata)
|
||||
q.view(-1, self.num_local_heads * self.head_dim), k, v)
|
||||
out, _ = self.o_proj(output)
|
||||
return out
|
||||
|
||||
@ -862,11 +859,11 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
kv_range_for_decode: List[Tuple[int, int]],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank]
|
||||
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
||||
# Skip writing kv-cache for the initial profiling run.
|
||||
if len(kv_cache.shape) > 1:
|
||||
i = torch.ones(1, dtype=torch.float32)
|
||||
@ -978,8 +975,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
|
||||
cross_attention_mask: torch.Tensor,
|
||||
kv_range_for_decode: Optional[List[Tuple[int, int]]],
|
||||
full_text_row_masked_out_mask: torch.Tensor,
|
||||
kv_cache: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@ -989,8 +984,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
|
||||
attention_mask=cross_attention_mask,
|
||||
kv_range_for_decode=kv_range_for_decode,
|
||||
cross_attention_states=cross_attention_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = full_text_row_masked_out_mask * hidden_states
|
||||
hidden_states = residual + self.cross_attn_attn_gate.tanh(
|
||||
@ -1054,14 +1047,12 @@ class MllamaTextModel(nn.Module):
|
||||
kv_range_for_decode: Optional[List[Tuple[int, int]]],
|
||||
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
|
||||
torch.Tensor]],
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
skip_cross_attention: bool,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
for decoder_layer in self.layers:
|
||||
if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
|
||||
if not skip_cross_attention:
|
||||
hidden_states = decoder_layer(
|
||||
@ -1071,15 +1062,11 @@ class MllamaTextModel(nn.Module):
|
||||
kv_range_for_decode=kv_range_for_decode,
|
||||
full_text_row_masked_out_mask=
|
||||
full_text_row_masked_out_mask,
|
||||
kv_cache=kv_caches[idx],
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
elif isinstance(decoder_layer, LlamaDecoderLayer):
|
||||
hidden_states, residual = decoder_layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_caches[idx],
|
||||
attn_metadata=attn_metadata,
|
||||
residual=None,
|
||||
)
|
||||
hidden_states = hidden_states + residual
|
||||
@ -1124,8 +1111,6 @@ class MllamaForCausalLM(nn.Module):
|
||||
kv_range_for_decode: Optional[List[Tuple[int, int]]],
|
||||
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
|
||||
torch.Tensor]],
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
skip_cross_attention: bool,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
@ -1135,8 +1120,6 @@ class MllamaForCausalLM(nn.Module):
|
||||
cross_attention_mask=cross_attention_mask,
|
||||
kv_range_for_decode=kv_range_for_decode,
|
||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
skip_cross_attention=skip_cross_attention,
|
||||
)
|
||||
return hidden_states
|
||||
@ -1353,10 +1336,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs: object,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata.num_prefill_tokens > 0 and \
|
||||
attn_metadata.num_decode_tokens > 0:
|
||||
raise ValueError("Chunk prefill not supported")
|
||||
@ -1410,8 +1392,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
cross_attention_mask=cross_attention_mask,
|
||||
kv_range_for_decode=kv_range_for_decode,
|
||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
skip_cross_attention=skip_cross_attention,
|
||||
)
|
||||
|
||||
|
@ -16,7 +16,7 @@ from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
@ -460,15 +460,13 @@ class MolmoAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if self.q_norm is not None and self.k_norm is not None:
|
||||
q, k = self._apply_qk_norm(q, k)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -580,8 +578,6 @@ class MolmoDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
# Self Attention
|
||||
@ -594,8 +590,6 @@ class MolmoDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
@ -610,8 +604,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
# Self Attention
|
||||
@ -619,8 +611,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@ -841,8 +831,6 @@ class MolmoModel(nn.Module, SupportsQuant):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
@ -858,13 +846,10 @@ class MolmoModel(nn.Module, SupportsQuant):
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
# Apply blocks one-by-one.
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
@ -1643,8 +1628,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.LongTensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -1663,8 +1646,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
|
||||
hidden_states = self.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
|
@ -2,12 +2,12 @@
|
||||
|
||||
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
@ -125,8 +125,6 @@ class MPTAttention(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
del position_ids # unused.
|
||||
qkv, _ = self.Wqkv(hidden_states)
|
||||
@ -136,7 +134,7 @@ class MPTAttention(nn.Module):
|
||||
if self.qk_ln:
|
||||
q = self.q_ln(q)
|
||||
k = self.k_ln(k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -196,15 +194,11 @@ class MPTBlock(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
x = self.norm_1(hidden_states)
|
||||
x = self.attn(
|
||||
position_ids=position_ids,
|
||||
hidden_states=x,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = hidden_states + x
|
||||
x = self.norm_2(hidden_states)
|
||||
@ -253,8 +247,6 @@ class MPTModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -267,14 +259,8 @@ class MPTModel(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
block = self.blocks[i]
|
||||
hidden_states = block(
|
||||
position_ids,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
for block in self.blocks[self.start_layer:self.end_layer]:
|
||||
hidden_states = block(position_ids, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.norm_f(hidden_states)
|
||||
@ -306,14 +292,11 @@ class MPTForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
|
@ -27,7 +27,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -204,13 +204,11 @@ class NemotronAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -269,8 +267,6 @@ class NemotronDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -283,8 +279,6 @@ class NemotronDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -343,8 +337,6 @@ class NemotronModel(nn.Module):
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -359,15 +351,8 @@ class NemotronModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
@ -444,13 +429,10 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
model_output = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
model_output = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return model_output
|
||||
|
||||
|
@ -22,13 +22,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only OLMo model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import OlmoConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -119,15 +119,13 @@ class OlmoAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
if self.clip_qkv is not None:
|
||||
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -212,14 +210,11 @@ class OlmoDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
# Attention block.
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.self_attn(positions, hidden_states, kv_cache,
|
||||
attn_metadata)
|
||||
hidden_states = self.self_attn(positions, hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
# MLP block.
|
||||
@ -263,8 +258,6 @@ class OlmoModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -281,14 +274,9 @@ class OlmoModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
|
||||
# Apply blocks one-by-one.
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
# shape: (batch_size, seq_len, d_model)
|
||||
hidden_states = self.layers[i](
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
@ -332,16 +320,12 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
@ -24,12 +24,12 @@
|
||||
"""Inference-only OLMo2 model compatible with HuggingFace weights."""
|
||||
|
||||
from functools import partial
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_gather
|
||||
@ -153,14 +153,12 @@ class Olmo2Attention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self._apply_qk_norm(q, k)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -239,13 +237,10 @@ class Olmo2DecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
# Attention block.
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn(positions, hidden_states, kv_cache,
|
||||
attn_metadata)
|
||||
hidden_states = self.self_attn(positions, hidden_states)
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
@ -287,8 +282,6 @@ class Olmo2Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
"""
|
||||
@ -307,14 +300,9 @@ class Olmo2Model(nn.Module):
|
||||
assert isinstance(hidden_states, torch.Tensor)
|
||||
|
||||
# Apply blocks one-by-one.
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
# shape: (batch_size, seq_len, d_model)
|
||||
hidden_states = self.layers[i](
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
@ -357,15 +345,11 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
)
|
||||
return hidden_states
|
||||
|
@ -12,13 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only OLMoE model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -168,14 +168,12 @@ class OlmoeAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -222,8 +220,6 @@ class OlmoeDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
@ -237,8 +233,6 @@ class OlmoeDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -283,8 +277,6 @@ class OlmoeModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -299,13 +291,10 @@ class OlmoeModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
|
||||
@ -347,13 +336,10 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -18,13 +18,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only OPT model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import OPTConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -107,12 +107,10 @@ class OPTAttention(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -164,17 +162,13 @@ class OPTDecoderLayer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
|
||||
if self.do_layer_norm_before:
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
# 350m applies layer norm AFTER attention
|
||||
if not self.do_layer_norm_before:
|
||||
@ -261,8 +255,6 @@ class OPTDecoder(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -277,11 +269,8 @@ class OPTDecoder(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(hidden_states)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
@ -317,15 +306,11 @@ class OPTModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
return self.decoder(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
@ -362,13 +347,10 @@ class OPTForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -5,13 +5,13 @@
|
||||
# Copyright (c) OrionStar Inc.
|
||||
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
|
||||
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -136,13 +136,11 @@ class OrionAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -189,8 +187,6 @@ class OrionDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
@ -198,8 +194,6 @@ class OrionDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
@ -247,8 +241,6 @@ class OrionModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -260,14 +252,8 @@ class OrionModel(nn.Module):
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
@ -303,13 +289,10 @@ class OrionForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -1,13 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PaliGemmaConfig
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
@ -288,8 +287,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object) -> Union[SamplerOutput, IntermediateTensors]:
|
||||
@ -306,8 +303,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
|
@ -21,13 +21,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only persimmon model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PersimmonConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -142,8 +142,6 @@ class PersimmonAttention(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
# [seq_length, 3 x hidden_size]
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
@ -161,7 +159,7 @@ class PersimmonAttention(nn.Module):
|
||||
k = self._merge_heads(k)
|
||||
|
||||
q, k = self.rotary_emb(position_ids, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.dense(attn_output)
|
||||
return output
|
||||
|
||||
@ -189,8 +187,6 @@ class PersimmonDecoderLayer(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
@ -200,8 +196,6 @@ class PersimmonDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
position_ids=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -248,8 +242,6 @@ class PersimmonModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -261,13 +253,8 @@ class PersimmonModel(nn.Module):
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
hidden_states = self.layers[i](
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
@ -298,16 +285,12 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
@ -36,13 +36,13 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PhiConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -126,13 +126,11 @@ class PhiAttention(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k = self.rotary_emb(position_ids, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.dense(attn_output)
|
||||
return output
|
||||
|
||||
@ -186,16 +184,12 @@ class PhiLayer(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
attn_outputs = self.self_attn(
|
||||
position_ids=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
feed_forward_hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = attn_outputs + feed_forward_hidden_states + residual
|
||||
@ -234,8 +228,6 @@ class PhiModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -247,14 +239,8 @@ class PhiModel(nn.Module):
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
@ -304,13 +290,10 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
@ -1,13 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
@ -231,8 +231,6 @@ class Phi3SmallSelfAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[Tuple[torch.Tensor]]]:
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
@ -248,7 +246,7 @@ class Phi3SmallSelfAttention(nn.Module):
|
||||
v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
|
||||
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.dense(attn_output)
|
||||
|
||||
return output
|
||||
@ -282,8 +280,6 @@ class Phi3SmallDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@ -291,8 +287,6 @@ class Phi3SmallDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -338,8 +332,6 @@ class Phi3SmallModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: Optional[torch.LongTensor],
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor],
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -354,14 +346,8 @@ class Phi3SmallModel(nn.Module):
|
||||
else:
|
||||
assert intermediate_tensors
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
@ -438,16 +424,12 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: Optional[torch.LongTensor],
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
output_hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
@ -23,7 +23,6 @@ import torch.nn as nn
|
||||
from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
|
||||
ProcessorMixin)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
@ -672,8 +671,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object):
|
||||
@ -691,8 +688,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
|
@ -22,13 +22,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only PhiMoE model."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -357,13 +357,11 @@ class PhiMoEAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -410,8 +408,6 @@ class PhiMoEDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
@ -422,8 +418,6 @@ class PhiMoEDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
@ -478,8 +472,6 @@ class PhiMoEModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -494,13 +486,10 @@ class PhiMoEModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
|
||||
@ -571,13 +560,10 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -16,7 +16,6 @@ from transformers.models.pixtral.image_processing_pixtral import (
|
||||
from transformers.models.pixtral.modeling_pixtral import (
|
||||
PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
@ -270,8 +269,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -291,8 +288,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
|
@ -15,13 +15,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only IBM/NASA Prithvi Geospatial model."""
|
||||
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Mapping, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (IsAttentionFree,
|
||||
@ -181,8 +180,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
|
@ -6,13 +6,13 @@
|
||||
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
|
||||
"""Inference-only QWen model compatible with HuggingFace weights."""
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -124,13 +124,11 @@ class QWenAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.c_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -168,8 +166,6 @@ class QWenBlock(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -181,8 +177,6 @@ class QWenBlock(nn.Module):
|
||||
hidden_states = self.attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -225,8 +219,6 @@ class QWenModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -241,13 +233,10 @@ class QWenModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.h[i]
|
||||
for layer in self.h[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
@ -373,12 +362,9 @@ class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return hidden_states
|
||||
|
@ -23,13 +23,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Qwen2Config
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -170,13 +170,11 @@ class Qwen2Attention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -233,8 +231,6 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -247,8 +243,6 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -328,8 +322,6 @@ class Qwen2Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -343,13 +335,10 @@ class Qwen2Model(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
@ -468,13 +457,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
@ -553,12 +539,9 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.model(input_ids, positions, kv_caches, attn_metadata,
|
||||
intermediate_tensors)
|
||||
return self.model(input_ids, positions, intermediate_tensors)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
|
@ -37,7 +37,6 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
|
||||
from vllm.distributed import utils as dist_utils
|
||||
@ -992,8 +991,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -1047,8 +1044,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
hidden_states = self.language_model.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
@ -22,8 +22,8 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
|
||||
from functools import cached_property
|
||||
from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
from typing import (Any, Iterable, Mapping, Optional, Set, Tuple, TypedDict,
|
||||
Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -33,7 +33,6 @@ from transformers.models.qwen2_audio import (Qwen2AudioConfig,
|
||||
Qwen2AudioProcessor)
|
||||
from transformers.models.whisper import WhisperFeatureExtractor
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
@ -380,8 +379,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -400,8 +397,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
return hidden_states
|
||||
|
@ -23,14 +23,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group,
|
||||
@ -232,13 +232,11 @@ class Qwen2MoeAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -296,8 +294,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
@ -310,8 +306,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -358,8 +352,6 @@ class Qwen2MoeModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -373,11 +365,8 @@ class Qwen2MoeModel(nn.Module):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata, residual)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
@ -416,13 +405,10 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -5,12 +5,11 @@
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
"""Inference-only Qwen2-RM model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
@ -80,13 +79,10 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
logits, _ = self.score(hidden_states)
|
||||
return logits
|
||||
|
@ -24,8 +24,8 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
||||
from functools import cached_property, partial
|
||||
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
|
||||
Set, Tuple, Type, TypedDict, Union)
|
||||
from typing import (Any, Callable, Iterable, Literal, Mapping, Optional, Set,
|
||||
Tuple, Type, TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -38,7 +38,6 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
|
||||
Qwen2VLConfig, Qwen2VLVisionConfig)
|
||||
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
|
||||
from vllm.distributed import utils as dist_utils
|
||||
@ -1302,8 +1301,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -1354,8 +1351,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
hidden_states = self.language_model.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
@ -22,7 +22,6 @@ from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer,
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -766,8 +765,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
@ -783,7 +780,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
|
||||
vision_embeddings)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return hidden_states
|
||||
|
@ -1,13 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import itertools
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import RobertaConfig
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import CrossEncodingPooler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -243,16 +242,12 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.roberta(input_ids=input_ids,
|
||||
position_ids=positions,
|
||||
kv_caches=kv_caches,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
attn_metadata=attn_metadata,
|
||||
token_type_ids=token_type_ids)
|
||||
|
@ -23,13 +23,13 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Solar model compatible with HuggingFace weights."""
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -172,13 +172,11 @@ class SolarAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -238,8 +236,6 @@ class SolarDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -252,8 +248,6 @@ class SolarDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -315,8 +309,6 @@ class SolarModel(nn.Module):
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -357,8 +349,6 @@ class SolarModel(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
|
||||
@ -438,13 +428,10 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
model_output = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
model_output = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return model_output
|
||||
|
||||
|
@ -20,13 +20,13 @@
|
||||
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
|
||||
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
|
||||
model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import StableLmConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@ -147,13 +147,11 @@ class StablelmAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -183,8 +181,6 @@ class StablelmDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
@ -192,8 +188,6 @@ class StablelmDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -241,8 +235,6 @@ class StableLMEpochModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -254,14 +246,8 @@ class StableLMEpochModel(nn.Module):
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(positions, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.norm(hidden_states)
|
||||
@ -296,13 +282,10 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -19,13 +19,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Starcoder2 model."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Starcoder2Config
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -118,13 +118,11 @@ class Starcoder2Attention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -184,8 +182,6 @@ class Starcoder2DecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
@ -193,8 +189,6 @@ class Starcoder2DecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -246,8 +240,6 @@ class Starcoder2Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -259,11 +251,8 @@ class Starcoder2Model(nn.Module):
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(positions, hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata)
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
hidden_states = self.norm(hidden_states)
|
||||
@ -306,13 +295,10 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
|
@ -22,7 +22,7 @@ from torch import nn
|
||||
from transformers import AutoModel, PreTrainedModel
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.utils import divide
|
||||
@ -59,7 +59,6 @@ def vllm_flash_attention_forward(
|
||||
# Transformers kwargs
|
||||
scaling: Optional[float] = None,
|
||||
# vLLM kwargs
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
attention_instances: Optional[list[Attention]] = None,
|
||||
**kwargs):
|
||||
self_attn = attention_instances[module.layer_idx]
|
||||
@ -68,12 +67,7 @@ def vllm_flash_attention_forward(
|
||||
hidden = query.shape[-2]
|
||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||
query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
|
||||
return self_attn.forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache=None, # argument not used
|
||||
attn_metadata=attn_metadata), None
|
||||
return self_attn.forward(query, key, value), None
|
||||
|
||||
|
||||
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
|
||||
@ -251,8 +245,6 @@ class TransformersModel(nn.Module, SupportsQuant):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: list[torch.Tensor], # argument not used
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -260,7 +252,6 @@ class TransformersModel(nn.Module, SupportsQuant):
|
||||
input_ids[None, ...],
|
||||
use_cache=False,
|
||||
position_ids=positions[None, ...],
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
attention_instances=self.attention_instances,
|
||||
return_dict=False)[0][0, ...] # we remove batch dimension for now
|
||||
|
@ -4,8 +4,8 @@
|
||||
"""PyTorch Ultravox model."""
|
||||
import math
|
||||
from functools import cached_property
|
||||
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
|
||||
Tuple, TypedDict, Union)
|
||||
from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -16,8 +16,8 @@ from transformers.models.whisper import WhisperFeatureExtractor
|
||||
from transformers.models.whisper.modeling_whisper import WhisperEncoder
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
@ -495,13 +495,13 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[NestedTensors] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
|
||||
# TODO(ywang96): remove this block after v0 is deprecated.
|
||||
if not envs.VLLM_USE_V1:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
merge_multimodal_embeddings_from_map(
|
||||
inputs_embeds, multimodal_embeddings,
|
||||
attn_metadata.multi_modal_placeholder_index_maps["audio"])
|
||||
@ -514,8 +514,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -540,17 +538,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||
elif inputs_embeds is None:
|
||||
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
|
||||
# TODO(ywang96): remove attn_metadata from get_input_embeddings
|
||||
# after v0 is deprecated
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
multimodal_embeddings,
|
||||
attn_metadata)
|
||||
multimodal_embeddings)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
return hidden_states
|
||||
|
@ -10,7 +10,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
|
||||
WhisperProcessor)
|
||||
from transformers.models.whisper.modeling_whisper import sinusoids
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
@ -134,13 +134,11 @@ class WhisperAttention(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
):
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
|
||||
output, _ = self.out_proj(attn_output)
|
||||
|
||||
@ -196,8 +194,6 @@ class WhisperCrossAttention(WhisperAttention):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor],
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
):
|
||||
q, _ = self.q_proj(hidden_states)
|
||||
|
||||
@ -209,13 +205,7 @@ class WhisperCrossAttention(WhisperAttention):
|
||||
else:
|
||||
k = v = None
|
||||
|
||||
attn_output = self.attn(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
)
|
||||
attn_output = self.attn(q, k, v)
|
||||
|
||||
output, _ = self.out_proj(attn_output)
|
||||
|
||||
@ -285,16 +275,10 @@ class WhisperEncoderLayer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
residual = hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
@ -348,14 +332,10 @@ class WhisperDecoderLayer(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor],
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
@ -363,8 +343,6 @@ class WhisperDecoderLayer(nn.Module):
|
||||
hidden_states = self.encoder_attn(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -411,12 +389,7 @@ class WhisperEncoder(nn.Module):
|
||||
self.embed_positions.weight.copy_(
|
||||
sinusoids(*self.embed_positions.weight.shape))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_features: Union[torch.Tensor, List[torch.Tensor]],
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
):
|
||||
def forward(self, input_features: Union[torch.Tensor, List[torch.Tensor]]):
|
||||
hidden_states = []
|
||||
for features in input_features:
|
||||
embeds = nn.functional.gelu(self.conv1(features))
|
||||
@ -426,12 +399,8 @@ class WhisperEncoder(nn.Module):
|
||||
hidden_states.append(embeds)
|
||||
hidden_states = torch.cat(hidden_states)
|
||||
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
hidden_states = encoder_layer(
|
||||
hidden_states,
|
||||
kv_cache=kv_caches[idx],
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(hidden_states)
|
||||
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
return hidden_states
|
||||
@ -466,19 +435,15 @@ class WhisperDecoder(nn.Module):
|
||||
input_ids,
|
||||
positions: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor],
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
):
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
positions = self.embed_positions(positions)
|
||||
hidden_states = inputs_embeds + positions
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
for decoder_layer in self.layers:
|
||||
hidden_states = decoder_layer(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
kv_cache=kv_caches[idx],
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
@ -505,36 +470,22 @@ class WhisperModel(nn.Module):
|
||||
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
encoder_outputs = self.get_encoder_outputs(
|
||||
input_features,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
encoder_outputs = self.get_encoder_outputs(input_features)
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
encoder_hidden_states=encoder_outputs,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
return decoder_outputs
|
||||
|
||||
def get_encoder_outputs(
|
||||
self,
|
||||
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
if input_features is None:
|
||||
return None
|
||||
return self.encoder(
|
||||
input_features,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
return self.encoder(input_features)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
@ -733,8 +684,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||
@ -742,31 +691,19 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
||||
input_features=audio_input["input_features"],
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
return decoder_outputs
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> Optional[NestedTensors]:
|
||||
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
||||
# TODO: This method does not obey the interface for SupportsMultiModal.
|
||||
# Refactor this once encoder/decoder support is implemented in V1.
|
||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||
return self.model.get_encoder_outputs(
|
||||
audio_input["input_features"],
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
return self.model.get_encoder_outputs(audio_input["input_features"])
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[NestedTensors] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
) -> torch.Tensor:
|
||||
# TODO: This method just returns the decoder sequence embeddings since
|
||||
# Whisper does not have encoder text tokens. Refactor this once
|
||||
|
@ -288,8 +288,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
|
||||
hidden_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=model_input.attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
|
@ -939,8 +939,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=self.kv_caches,
|
||||
attn_metadata=None,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
@ -1137,11 +1135,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
def _dummy_run(
|
||||
self,
|
||||
num_tokens: int,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
model = self.model
|
||||
if kv_caches is None:
|
||||
kv_caches = self.kv_caches
|
||||
if self.is_multimodal_model:
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:num_tokens]
|
||||
@ -1172,26 +1167,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=None,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def profile_run(self) -> None:
|
||||
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value `None`.
|
||||
# the `dtype` argument does not matter, and we use `float32` as
|
||||
# a placeholder (it has wide hardware support).
|
||||
# it is important to create tensors inside the loop, rather than
|
||||
# multiplying the list, to avoid Dynamo from treating them as
|
||||
# tensor aliasing.
|
||||
dummy_kv_caches = [
|
||||
torch.tensor((), dtype=torch.float32, device=self.device)
|
||||
for _ in range(self.num_attn_layers)
|
||||
]
|
||||
|
||||
# Profile with multimodal encoder & encoder cache.
|
||||
# TODO: handle encoder-decoder models once we support them.
|
||||
if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
|
||||
@ -1302,8 +1283,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
with self.maybe_profile_with_lora(self.lora_config,
|
||||
num_scheduled_tokens):
|
||||
# Trigger compilation for general shape.
|
||||
hidden_states = self._dummy_run(self.max_num_tokens,
|
||||
dummy_kv_caches)
|
||||
hidden_states = self._dummy_run(self.max_num_tokens)
|
||||
if get_pp_group().is_last_rank:
|
||||
hidden_states = hidden_states[logit_indices]
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
|
@ -13,11 +13,10 @@ import torch.nn as nn
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.sampling_params import SamplingType
|
||||
@ -623,7 +622,6 @@ class TPUModelRunner:
|
||||
assert self.model is not None
|
||||
selected_token_ids = self.model(prompt_data.input_tokens,
|
||||
prompt_data.input_positions,
|
||||
prompt_data.attn_metadata,
|
||||
self.kv_caches)
|
||||
|
||||
# In parallel to TPU execution, prepare the next iteration
|
||||
@ -662,7 +660,6 @@ class TPUModelRunner:
|
||||
assert self.model is not None
|
||||
selected_token_ids = self.model(decode_data.input_tokens,
|
||||
decode_data.input_positions,
|
||||
decode_data.attn_metadata,
|
||||
self.kv_caches)
|
||||
|
||||
# Transfer sampled tokens from TPU to CPU
|
||||
@ -839,7 +836,7 @@ class TPUModelRunner:
|
||||
|
||||
with set_forward_context(attn_metadata, self.vllm_config, 0):
|
||||
assert self.model is not None
|
||||
self.model(token_ids, position_ids, attn_metadata, kv_caches)
|
||||
self.model(token_ids, position_ids, kv_caches)
|
||||
|
||||
def capture_model(self) -> None:
|
||||
"""Compile the model."""
|
||||
@ -963,7 +960,6 @@ class ModelWrapperV1(nn.Module):
|
||||
self,
|
||||
token_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> torch.Tensor:
|
||||
"""Executes the forward pass of the model and samples the next token.
|
||||
@ -971,7 +967,6 @@ class ModelWrapperV1(nn.Module):
|
||||
Args:
|
||||
token_ids: The input token IDs of shape [batch_size, seq_len].
|
||||
position_ids: The input position IDs of shape [batch_size, seq_len].
|
||||
attn_metadata: The Pallas attention metadata.
|
||||
input_lens: The actual input lengths of shape [batch_size].
|
||||
t: The sampling temperature of shape [batch_size].
|
||||
p: The top-p probability of shape [batch_size].
|
||||
@ -980,7 +975,8 @@ class ModelWrapperV1(nn.Module):
|
||||
memory profiling at initialization.
|
||||
"""
|
||||
# Skip this in memory profiling at initialization.
|
||||
if attn_metadata is not None and kv_caches[0][0].numel() > 0:
|
||||
if kv_caches[0][0].numel() > 0:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
# index_copy_(slot_mapping) only works when the inserted dimension
|
||||
# is 0. However, the KV cache in the Pallas backend has the shape
|
||||
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
|
||||
@ -1001,12 +997,7 @@ class ModelWrapperV1(nn.Module):
|
||||
attn_metadata.slot_mapping = slot_mapping
|
||||
|
||||
assert self.model is not None
|
||||
hidden_states = self.model(
|
||||
token_ids,
|
||||
position_ids,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
)
|
||||
hidden_states = self.model(token_ids, position_ids)
|
||||
|
||||
hidden_states = hidden_states.flatten(0, 1)
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
|
@ -297,10 +297,6 @@ class CPUEncoderDecoderModelRunner(
|
||||
model_input.encoder_input_tokens,
|
||||
"encoder_positions":
|
||||
model_input.encoder_input_positions,
|
||||
"kv_caches":
|
||||
kv_caches,
|
||||
"attn_metadata":
|
||||
model_input.attn_metadata,
|
||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
|
||||
device=self.device),
|
||||
"intermediate_tensors":
|
||||
|
@ -654,8 +654,6 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
|
||||
hidden_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=model_input.attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**execute_model_kwargs,
|
||||
**multimodal_kwargs,
|
||||
|
@ -41,16 +41,6 @@ class CPUPoolingModelRunner(
|
||||
raise ValueError(
|
||||
"CPU worker does not support multi-step execution.")
|
||||
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
# the `dtype` argument does not matter, and we use `float32` as
|
||||
# a placeholder (it has wide hardware support).
|
||||
kv_caches = [
|
||||
torch.tensor([], dtype=torch.float32, device=self.device)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
|
||||
model_executable = self.model
|
||||
cross_enc_kwargs = {}
|
||||
if model_input.token_type_ids is not None:
|
||||
@ -60,10 +50,6 @@ class CPUPoolingModelRunner(
|
||||
model_input.input_tokens,
|
||||
"positions":
|
||||
model_input.input_positions,
|
||||
"kv_caches":
|
||||
kv_caches,
|
||||
"attn_metadata":
|
||||
model_input.attn_metadata,
|
||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
|
||||
device=self.device),
|
||||
**cross_enc_kwargs,
|
||||
|
@ -184,8 +184,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
positions=model_input.input_positions,
|
||||
encoder_input_ids=model_input.encoder_input_tokens,
|
||||
encoder_positions=model_input.encoder_input_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=model_input.attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
@ -324,21 +322,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
or encoder_dummy_data.multi_modal_placeholders)
|
||||
seqs.append(seq)
|
||||
|
||||
# Run the model with the dummy inputs.
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
# the `dtype` argument does not matter, and we use `float32` as
|
||||
# a placeholder (it has wide hardware support).
|
||||
kv_caches = [
|
||||
torch.tensor([], dtype=torch.float32, device=self.device)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
finished_requests_ids = [seq.request_id for seq in seqs]
|
||||
model_input = self.prepare_model_input(
|
||||
seqs, finished_requests_ids=finished_requests_ids)
|
||||
intermediate_tensors = None
|
||||
self.execute_model(model_input, kv_caches, intermediate_tensors)
|
||||
self.execute_model(model_input, None, intermediate_tensors)
|
||||
torch.cuda.synchronize()
|
||||
return
|
||||
|
||||
|
@ -384,11 +384,12 @@ class HpuModelAdapter:
|
||||
if 'virtual_engine' in kwargs:
|
||||
virtual_engine = kwargs.pop('virtual_engine')
|
||||
input_ids = kwargs['input_ids']
|
||||
kwargs['attn_metadata'] = self._update_metadata(
|
||||
kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
|
||||
attn_metadata = self._update_metadata(kwargs.pop('attn_metadata'),
|
||||
input_ids.size(0),
|
||||
input_ids.size(1),
|
||||
input_ids.device, self.dtype)
|
||||
LoraMask.setLoraMask(kwargs.pop('lora_mask'))
|
||||
with set_forward_context(kwargs['attn_metadata'], self.vllm_config,
|
||||
with set_forward_context(attn_metadata, self.vllm_config,
|
||||
virtual_engine):
|
||||
hidden_states = self.model(*args, **kwargs)
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
@ -1346,15 +1347,13 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
|
||||
max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
|
||||
max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
|
||||
self.scheduler_config.max_num_seqs)
|
||||
self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
|
||||
False, True)
|
||||
self.warmup_scenario(max_batch_size, max_seq_len, True, False, True)
|
||||
return
|
||||
|
||||
def warmup_scenario(self,
|
||||
batch_size,
|
||||
seq_len,
|
||||
is_prompt,
|
||||
kv_caches,
|
||||
is_pt_profiler_run=False,
|
||||
is_lora_profile_run=False) -> None:
|
||||
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
|
||||
@ -1418,7 +1417,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
|
||||
profiler.start()
|
||||
for _ in range(times):
|
||||
inputs = self.prepare_model_input(seqs)
|
||||
self.execute_model(inputs, kv_caches, warmup_mode=True)
|
||||
self.execute_model(inputs, None, warmup_mode=True)
|
||||
torch.hpu.synchronize()
|
||||
if profiler:
|
||||
profiler.step()
|
||||
@ -1470,17 +1469,16 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
|
||||
f"free_mem:{free_mem}")
|
||||
logger.info(msg)
|
||||
|
||||
def warmup_all_buckets(self, buckets, is_prompt, kv_caches):
|
||||
def warmup_all_buckets(self, buckets, is_prompt):
|
||||
for i, (batch_size, seq_len) in enumerate(reversed(buckets)):
|
||||
self.log_warmup('Prompt' if is_prompt else 'Decode', i,
|
||||
len(buckets), batch_size, seq_len)
|
||||
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
|
||||
self.warmup_scenario(batch_size, seq_len, is_prompt)
|
||||
|
||||
def warmup_graphs(self,
|
||||
strategy,
|
||||
buckets,
|
||||
is_prompt,
|
||||
kv_caches,
|
||||
available_mem,
|
||||
starting_mem=0,
|
||||
total_batch_seq=0.001):
|
||||
@ -1512,7 +1510,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
|
||||
self.graphed_buckets.add(graphed_bucket)
|
||||
self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
|
||||
with HabanaMemoryProfiler() as mem_prof:
|
||||
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
|
||||
self.warmup_scenario(batch_size, seq_len, is_prompt)
|
||||
used_mem = align_workers(mem_prof.consumed_device_memory,
|
||||
torch.distributed.ReduceOp.MAX)
|
||||
available_mem -= used_mem
|
||||
@ -1542,8 +1540,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
|
||||
graphs = graph == 't'
|
||||
if graphs:
|
||||
self.graphed_buckets.add((int(bs), int(seq_len), is_prompt))
|
||||
self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches,
|
||||
True)
|
||||
self.warmup_scenario(int(bs), int(seq_len), is_prompt, True)
|
||||
raise AssertionError("Finished profiling")
|
||||
if self.skip_warmup:
|
||||
logger.info("Skipping warmup...")
|
||||
@ -1608,9 +1605,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
|
||||
with compile_only_mode_context(
|
||||
) if can_use_compile_only_mode else contextlib.nullcontext():
|
||||
self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets,
|
||||
True, kv_caches)
|
||||
True)
|
||||
self.warmup_all_buckets(self.bucketing_global_state.decode_buckets,
|
||||
False, kv_caches)
|
||||
False)
|
||||
|
||||
if not self.enforce_eager and htorch.utils.internal.is_lazy():
|
||||
assert self.mem_margin is not None, \
|
||||
@ -1641,11 +1638,11 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
|
||||
mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
|
||||
self.warmup_graphs(
|
||||
prompt_strategy, self.bucketing_global_state.prompt_buckets,
|
||||
True, kv_caches, prompt_available_memory)
|
||||
True, prompt_available_memory)
|
||||
mem_post_decode, decode_batch_seq, decode_captured_all = \
|
||||
self.warmup_graphs(
|
||||
decode_strategy, self.bucketing_global_state.decode_buckets,
|
||||
False, kv_caches, decode_available_memory)
|
||||
False, decode_available_memory)
|
||||
|
||||
# Not all prompt buckets were captured, but all decode buckets
|
||||
# were captured and we have some free graph-allocated space
|
||||
@ -1656,7 +1653,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
|
||||
self.warmup_graphs(
|
||||
prompt_strategy,
|
||||
self.bucketing_global_state.prompt_buckets, True,
|
||||
kv_caches,
|
||||
graph_free_mem - mem_post_prompt - mem_post_decode,
|
||||
mem_post_prompt, prompt_batch_seq))
|
||||
|
||||
@ -1669,7 +1665,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
|
||||
mem_post_decode, _, _ = self.warmup_graphs(
|
||||
decode_strategy,
|
||||
self.bucketing_global_state.decode_buckets, False,
|
||||
kv_caches,
|
||||
graph_free_mem - mem_post_prompt - mem_post_decode,
|
||||
mem_post_decode, decode_batch_seq)
|
||||
|
||||
@ -1982,7 +1977,6 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
|
||||
execute_model_kwargs = {
|
||||
"input_ids": input_tokens,
|
||||
"positions": input_positions,
|
||||
"kv_caches": kv_caches,
|
||||
"attn_metadata": self.trim_attn_metadata(attn_metadata),
|
||||
"intermediate_tensors": intermediate_tensors,
|
||||
"lora_mask": lora_mask,
|
||||
|
@ -26,7 +26,7 @@ from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.distributed import get_kv_transfer_group, get_pp_group
|
||||
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
|
||||
graph_capture)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
@ -1727,8 +1727,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=model_input.attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
@ -1913,8 +1911,6 @@ class CUDAGraphRunner(nn.Module):
|
||||
self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
@ -1927,8 +1923,6 @@ class CUDAGraphRunner(nn.Module):
|
||||
output_hidden_or_intermediate_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
@ -1976,13 +1970,10 @@ class CUDAGraphRunner(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# KV caches are fixed tensors, so we don't need to copy them.
|
||||
del kv_caches
|
||||
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
||||
|
||||
# Copy the input tensors to the input buffers.
|
||||
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
|
||||
|
@ -476,7 +476,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
# path for warm up runs
|
||||
if not model_input.is_multi_step:
|
||||
return self._base_model_runner.execute_model(
|
||||
frozen_model_input, kv_caches, intermediate_tensors, num_steps)
|
||||
frozen_model_input, None, intermediate_tensors, num_steps)
|
||||
|
||||
# make sure we skip the sampler on the lask rank and only pythonize
|
||||
# if CPU is ahead.
|
||||
@ -538,7 +538,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
|
||||
# Execute the model
|
||||
output = self._base_model_runner.execute_model(frozen_model_input,
|
||||
kv_caches,
|
||||
None,
|
||||
intermediate_tensors,
|
||||
num_steps=1)
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user