Remove unused kwargs from model definitions (#13555)
parent f61528d46d
commit cdc1fa12eb
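In broad terms, this change drops the per-call `kv_caches` / `attn_metadata` parameters from model `forward()` definitions and from the `self.attn(...)` calls; the vLLM `Attention` layer and the forward context now own that state. Below is a minimal, hedged sketch of what a model-side block looks like after the change, written in plain PyTorch with invented names and sizes and a stubbed attention body (no vLLM imports; rotary embeddings and KV caching are intentionally omitted):

```python
import torch
from torch import nn


class TinySelfAttention(nn.Module):
    """Illustrative block only; names and sizes are invented, not from the diff."""

    def __init__(self, hidden_size: int = 64, num_heads: int = 4) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        # After this commit there is no kv_cache / attn_metadata parameter here;
        # in vLLM those are now owned by the Attention layer and the forward
        # context. `positions` is kept only to mirror the signature (rotary
        # embeddings are omitted in this sketch).
        num_tokens = hidden_states.shape[0]
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(3, dim=-1)
        q, k, v = (t.view(num_tokens, self.num_heads, self.head_dim)
                   .transpose(0, 1) for t in (q, k, v))
        scores = q @ k.transpose(-1, -2) / self.head_dim**0.5
        attn = torch.softmax(scores, dim=-1) @ v
        attn = attn.transpose(0, 1).reshape(num_tokens, -1)
        return self.o_proj(attn)


if __name__ == "__main__":
    block = TinySelfAttention()
    out = block(torch.arange(5), torch.randn(5, 64))
    print(out.shape)  # torch.Size([5, 64])
```

The point of the sketch is the signature: `positions` and `hidden_states` go in, and nothing about caches or attention metadata is threaded through the call any more.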
@@ -74,8 +74,6 @@ def forward(
     self,
     input_ids: torch.Tensor,
     positions: torch.Tensor,
-    kv_caches: List[torch.Tensor],
-    attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
     ...
 ```

@@ -16,8 +16,6 @@ Further update the model as follows:
     self,
     input_ids: torch.Tensor,
     positions: torch.Tensor,
-    kv_caches: List[torch.Tensor],
-    attn_metadata: AttentionMetadata,
 +   pixel_values: torch.Tensor,
 ) -> SamplerOutput:
 ```

@@ -644,11 +644,7 @@ def _run_encoder_attention_test(
     # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
     reshaped_query = packed_qkv.query.view(
         -1, test_pt.num_heads * test_pt.head_size)
-    return attn.forward(
-        reshaped_query, packed_qkv.key, packed_qkv.value,
-        torch.tensor([],
-                     dtype=torch.float32,
-                     device=packed_qkv.query.device), attn_metadata)
+    return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)


 def _run_decoder_self_attention_test(
@@ -682,7 +678,6 @@ def _run_decoder_self_attention_test(
     & attn_metadata
     '''
     attn = test_rsrcs.attn
-    kv_cache = test_rsrcs.kv_cache
     packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
     assert packed_qkv is not None
     with set_forward_context(attn_metadata, vllm_config):
@@ -695,8 +690,7 @@ def _run_decoder_self_attention_test(
     # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
     reshaped_query = packed_qkv.query.view(
         -1, test_pt.num_heads * test_pt.head_size)
-    return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value,
-                        kv_cache, attn_metadata)
+    return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)


 def _run_encoder_decoder_cross_attention_test(
@@ -744,7 +738,6 @@ def _run_encoder_decoder_cross_attention_test(
     assert decoder_test_params.packed_qkvo.packed_qkv is not None

     attn = test_rsrcs.attn
-    kv_cache = test_rsrcs.kv_cache
     if cross_test_params is None:
         key = None
         value = None
@@ -762,8 +755,7 @@ def _run_encoder_decoder_cross_attention_test(
     # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
     reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
         -1, test_pt.num_heads * test_pt.head_size)
-    return attn.forward(reshaped_query, key, value, kv_cache,
-                        attn_metadata)
+    return attn.forward(reshaped_query, key, value)


 @pytest.fixture(autouse=True)

@@ -7,7 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 import vllm.envs as envs
-from vllm.attention import AttentionMetadata, AttentionType
+from vllm.attention import AttentionType
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import CacheConfig, get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
@@ -153,15 +153,10 @@ class Attention(nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        # NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
-        # directly, use `self.kv_cache` and
-        # `get_forward_context().attn_metadata` instead.
         if self.calculate_kv_scales:
-            ctx_attn_metadata = get_forward_context().attn_metadata
-            if ctx_attn_metadata.enable_kv_scales_calculation:
+            attn_metadata = get_forward_context().attn_metadata
+            if attn_metadata.enable_kv_scales_calculation:
                 self.calc_kv_scales(key, value)
         if self.use_output:
             output = torch.empty_like(query)
@@ -177,14 +172,14 @@ class Attention(nn.Module):
             value = value.view(-1, self.num_kv_heads, self.head_size)
             if self.use_direct_call:
                 forward_context: ForwardContext = get_forward_context()
-                ctx_attn_metadata = forward_context.attn_metadata
+                attn_metadata = forward_context.attn_metadata
                 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                 self.impl.forward(self,
                                   query,
                                   key,
                                   value,
                                   self_kv_cache,
-                                  ctx_attn_metadata,
+                                  attn_metadata,
                                   output=output)
             else:
                 torch.ops.vllm.unified_attention_with_output(
@@ -193,10 +188,10 @@ class Attention(nn.Module):
         else:
             if self.use_direct_call:
                 forward_context = get_forward_context()
-                ctx_attn_metadata = forward_context.attn_metadata
+                attn_metadata = forward_context.attn_metadata
                 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                 return self.impl.forward(self, query, key, value,
-                                         self_kv_cache, ctx_attn_metadata)
+                                         self_kv_cache, attn_metadata)
             else:
                 return torch.ops.vllm.unified_attention(
                     query, key, value, self.layer_name)

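The attention-layer diff above is the other half of the pattern: rather than receiving `kv_cache` and `attn_metadata` as arguments, `Attention.forward` reads them from the ambient forward context. The following self-contained toy analogue (not vLLM code: `set_forward_context` / `get_forward_context` are re-implemented locally here, and the real helpers in `vllm.forward_context` take different arguments) shows the idea of swapping explicit parameter threading for a per-step context:

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass
class ToyAttnMetadata:
    """Stand-in for vLLM's per-step attention metadata."""
    num_prefills: int


_CURRENT_META: Optional[ToyAttnMetadata] = None


@contextmanager
def set_forward_context(meta: ToyAttnMetadata) -> Iterator[None]:
    """Install metadata for the duration of one forward pass (toy version)."""
    global _CURRENT_META
    prev, _CURRENT_META = _CURRENT_META, meta
    try:
        yield
    finally:
        _CURRENT_META = prev


def get_forward_context() -> ToyAttnMetadata:
    """What a layer calls instead of taking attn_metadata as an argument."""
    assert _CURRENT_META is not None, "forward context is not set"
    return _CURRENT_META


class ToyAttention:

    def forward(self, q, k, v):  # note: no kv_cache / attn_metadata parameters
        meta = get_forward_context()
        return f"attended {len(q)} tokens (num_prefills={meta.num_prefills})"


if __name__ == "__main__":
    layer = ToyAttention()
    with set_forward_context(ToyAttnMetadata(num_prefills=1)):
        print(layer.forward([0.1, 0.2, 0.3], [0.1] * 3, [0.2] * 3))
```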
@@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -130,14 +131,14 @@ class MambaMixer(CustomOp):
         ) if use_rms_norm else None

     def forward_native(self, hidden_states: torch.Tensor,
-                       attn_metadata: AttentionMetadata,
                        conv_state: torch.Tensor, ssm_state: torch.Tensor):
         pass

     def forward_cuda(self, hidden_states: torch.Tensor,
-                     attn_metadata: AttentionMetadata,
                      mamba_cache_params: MambaCacheParams):

+        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
+
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
         hidden_states, gate = projected_states.chunk(2, dim=-2)

@@ -14,6 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -376,17 +377,16 @@ class MambaMixer2(CustomOp):
             eps=rms_norm_eps)

     def forward_native(self, hidden_states: torch.Tensor,
-                       attn_metadata: AttentionMetadata,
                        conv_state: torch.Tensor, ssm_state: torch.Tensor):
         pass

     def forward_cuda(
         self,
         hidden_states: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
         sequence_idx: Optional[torch.Tensor] = None,
     ):
+        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

         seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size

@@ -160,7 +160,6 @@ def as_classification_model(cls: _T) -> _T:
         return cls

     # Lazy import
-    from vllm.attention import AttentionMetadata
     from vllm.config import VllmConfig
     from vllm.model_executor.layers.linear import RowParallelLinear
     from vllm.model_executor.layers.pooler import PoolingType
@@ -201,13 +200,10 @@ def as_classification_model(cls: _T) -> _T:
             self,
             input_ids: torch.Tensor,
             positions: torch.Tensor,
-            kv_caches: list[torch.Tensor],
-            attn_metadata: AttentionMetadata,
             intermediate_tensors: Optional[IntermediateTensors] = None,
             inputs_embeds: Optional[torch.Tensor] = None,
         ) -> torch.Tensor:
-            hidden_states = super().forward(input_ids, positions, kv_caches,
-                                            attn_metadata,
-                                            intermediate_tensors,
+            hidden_states = super().forward(input_ids, positions,
                                             intermediate_tensors,
                                             inputs_embeds)
             logits, _ = self.score(hidden_states)

@@ -5,7 +5,7 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
 import torch
 from torch import nn

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -283,13 +283,11 @@ class ArcticAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@@ -336,16 +334,12 @@ class ArcticDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         residual_input = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         hidden_states = residual_input + hidden_states

@@ -400,8 +394,6 @@ class ArcticModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -413,11 +405,8 @@ class ArcticModel(nn.Module):
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states = layer(positions, hidden_states,
-                                  kv_caches[i - self.start_layer],
-                                  attn_metadata)
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states = layer(positions, hidden_states)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.norm(hidden_states)
@@ -458,13 +447,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states

@@ -9,7 +9,6 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
 from transformers.models.aria.modeling_aria import AriaCrossAttention
 from transformers.models.aria.processing_aria import AriaProcessor

-from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.activation import get_act_fn
@@ -626,8 +625,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
@@ -643,8 +640,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         hidden_states = self.language_model(
             input_ids,
             positions,
-            kv_caches,
-            attn_metadata,
             intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )

@@ -20,13 +20,13 @@
 # limitations under the License.
 """Inference-only BaiChuan model compatible with HuggingFace weights."""
 import math
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import PretrainedConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -182,14 +182,12 @@ class BaiChuanAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         if self.postion_embedding != "ALIBI":
             q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@@ -232,8 +230,6 @@ class BaiChuanDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -246,8 +242,6 @@ class BaiChuanDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )

         # Fully Connected
@@ -301,8 +295,6 @@ class BaiChuanModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -316,13 +308,10 @@ class BaiChuanModel(nn.Module):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
+        for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
                 residual,
             )
         if not get_pp_group().is_last_rank:
@@ -379,13 +368,10 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states

@@ -1,17 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 """Inference-only Bamba model."""
 # Added by the IBM Team, 2024
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Iterable, Optional, Set, Tuple

 import torch
 from torch import nn
 from transformers import BambaConfig

-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -107,7 +107,6 @@ class BambaMixerDecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
         sequence_idx: Optional[torch.Tensor] = None,
@@ -120,8 +119,8 @@ class BambaMixerDecoderLayer(nn.Module):
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)

-        hidden_states = self.mamba(hidden_states, attn_metadata,
-                                   mamba_cache_params, sequence_idx)
+        hidden_states = self.mamba(hidden_states, mamba_cache_params,
+                                   sequence_idx)
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
             hidden_states, residual)
@@ -215,15 +214,13 @@ class BambaAttentionDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         **kwargs,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@@ -231,8 +228,6 @@ class BambaAttentionDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         **kwargs,
     ):
@@ -246,8 +241,6 @@ class BambaAttentionDecoderLayer(nn.Module):
         hidden_states = self.self_attention(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
@@ -312,8 +305,6 @@ class BambaModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
@@ -323,6 +314,7 @@ class BambaModel(nn.Module):
         # proper continuous batching computation including
         # chunked prefill
         seq_idx = None
+        attn_metadata = get_forward_context().attn_metadata
         if attn_metadata.num_prefills > 0:
             seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
             for i, (srt, end) in enumerate(
@@ -348,9 +340,7 @@ class BambaModel(nn.Module):
         num_attn = 0
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            kv_cache = None
             if isinstance(layer, BambaAttentionDecoderLayer):
-                kv_cache = kv_caches[num_attn]
                 num_attn += 1

             layer_mamba_cache_params = None
@@ -361,8 +351,6 @@ class BambaModel(nn.Module):
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
-                kv_cache=kv_cache,
-                attn_metadata=attn_metadata,
                 residual=residual,
                 mamba_cache_params=layer_mamba_cache_params,
                 sequence_idx=seq_idx,
@@ -440,8 +428,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
-                kv_caches: List[KVCache],
-                attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
@@ -454,8 +440,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
             *self._get_mamba_cache_shape())
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, mamba_cache_params,
+        hidden_states = self.model(input_ids, positions, mamba_cache_params,
                                    intermediate_tensors, inputs_embeds)

         return hidden_states

@@ -19,14 +19,14 @@
 # limitations under the License.
 """PyTorch BART model."""
 import math
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, Optional, Tuple

 import torch
 from torch import nn
 from transformers import BartConfig
 from transformers.utils import logging

-from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.attention import Attention, AttentionType
 from vllm.config import CacheConfig, LoRAConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -181,14 +181,13 @@ class BartEncoderAttention(nn.Module):
                               prefix=f"{prefix}.attn",
                               attn_type=AttentionType.ENCODER)

-    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
-                attn_metadata: AttentionMetadata) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """Input shape: Batch x Time x Channel"""

         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)

         output, _ = self.out_proj(attn_output)
         return output
@@ -261,14 +260,13 @@ class BartDecoderSelfAttention(nn.Module):
                               prefix=f"{prefix}.attn",
                               attn_type=AttentionType.DECODER)

-    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
-                attn_metadata: AttentionMetadata) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """Input shape: Batch x Time x Channel"""

         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)

         output, _ = self.out_proj(attn_output)
         return output
@@ -344,8 +342,6 @@ class BartCrossAttention(nn.Module):
     def forward(
         self,
         decoder_hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         encoder_hidden_states: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Input shape: Batch x Time x Channel"""
@@ -363,7 +359,7 @@ class BartCrossAttention(nn.Module):
             _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
                                     dim=-1)

-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)

         output, _ = self.out_proj(attn_output)
         return output
@@ -411,23 +407,16 @@ class BartEncoderLayer(nn.Module):

         self.final_layer_norm = nn.LayerNorm(self.embed_dim)

-    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
-                attn_metadata: AttentionMetadata) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
             hidden_states
                 torch.Tensor of *encoder* input embeddings.
-            kv_cache:
-                Layer-wise list of KV cache tensors
-            attn_metadata:
-                vLLM Attention metadata structure
         Returns:
             Encoder layer output torch.Tensor
         """
         residual = hidden_states
-        hidden_states = self.self_attn(hidden_states=hidden_states,
-                                       kv_cache=kv_cache,
-                                       attn_metadata=attn_metadata)
+        hidden_states = self.self_attn(hidden_states=hidden_states)

         hidden_states = residual + hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -509,18 +498,12 @@ class BartDecoderLayer(nn.Module):
     def forward(
         self,
         decoder_hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         encoder_hidden_states: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         r"""
         Args:
             decoder_hidden_states
                 torch.Tensor of *decoder* input embeddings.
-            kv_cache:
-                KV cache tensor
-            attn_metadata:
-                vLLM Attention metadata structure
             encoder_hidden_states
                 torch.Tensor of *encoder* input embeddings.
         Returns:
@@ -529,9 +512,7 @@ class BartDecoderLayer(nn.Module):
         residual = decoder_hidden_states

         # Self Attention
-        hidden_states = self.self_attn(hidden_states=decoder_hidden_states,
-                                       kv_cache=kv_cache,
-                                       attn_metadata=attn_metadata)
+        hidden_states = self.self_attn(hidden_states=decoder_hidden_states)

         hidden_states = residual + hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -542,8 +523,6 @@ class BartDecoderLayer(nn.Module):

         hidden_states = self.encoder_attn(
             decoder_hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
             encoder_hidden_states=encoder_hidden_states,
         )

@@ -609,9 +588,8 @@ class BartEncoder(nn.Module):

         self.layernorm_embedding = nn.LayerNorm(embed_dim)

-    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
-                kv_caches: List[torch.Tensor],
-                attn_metadata: AttentionMetadata) -> torch.Tensor:
+    def forward(self, input_ids: torch.Tensor,
+                positions: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
             input_ids
@@ -620,10 +598,6 @@ class BartEncoder(nn.Module):
                 provide it.
             positions
                 Positions of *encoder* input sequence tokens.
-            kv_caches:
-                Layer-wise list of KV cache tensors
-            attn_metadata:
-                vLLM Attention metadata structure
         Returns:
             Decoder output torch.Tensor
         """
@@ -636,12 +610,8 @@ class BartEncoder(nn.Module):
         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)

-        for idx, encoder_layer in enumerate(self.layers):
-            hidden_states = encoder_layer(
-                hidden_states=hidden_states,
-                kv_cache=kv_caches[idx],
-                attn_metadata=attn_metadata,
-            )
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(hidden_states=hidden_states)

         return hidden_states

@@ -693,9 +663,7 @@ class BartDecoder(nn.Module):

     def forward(self, decoder_input_ids: torch.Tensor,
                 decoder_positions: torch.Tensor,
-                encoder_hidden_states: Optional[torch.Tensor],
-                kv_caches: List[torch.Tensor],
-                attn_metadata: AttentionMetadata) -> torch.Tensor:
+                encoder_hidden_states: Optional[torch.Tensor]) -> torch.Tensor:
         r"""
         Args:
             decoder_input_ids
@@ -706,10 +674,6 @@ class BartDecoder(nn.Module):
                 Positions of *decoder* input sequence tokens.
             encoder_hidden_states:
                 Tensor of encoder output embeddings
-            kv_caches:
-                Layer-wise list of KV cache tensors
-            attn_metadata:
-                vLLM Attention metadata structure
         Returns:
             Decoder output torch.Tensor
         """
@@ -725,11 +689,9 @@ class BartDecoder(nn.Module):

         # decoder layers

-        for idx, decoder_layer in enumerate(self.layers):
+        for decoder_layer in self.layers:
             hidden_states = decoder_layer(
                 decoder_hidden_states=hidden_states,
-                kv_cache=kv_caches[idx],
-                attn_metadata=attn_metadata,
                 encoder_hidden_states=encoder_hidden_states,
             )

@@ -768,8 +730,7 @@ class BartModel(nn.Module):

     def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                 encoder_input_ids: torch.Tensor,
-                encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
-                attn_metadata: AttentionMetadata) -> torch.Tensor:
+                encoder_positions: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
             input_ids
@@ -782,10 +743,6 @@ class BartModel(nn.Module):
                 Indices of *encoder* input sequence tokens in the vocabulary.
             encoder_positions:
                 Positions of *encoder* input sequence tokens.
-            kv_caches:
-                Layer-wise list of KV cache tensors
-            attn_metadata:
-                vLLM Attention metadata structure
         Returns:
             Model output torch.Tensor
         """
@@ -796,18 +753,14 @@ class BartModel(nn.Module):
         # Run encoder attention if a non-zero number of encoder tokens
         # are provided as input
         encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
-                                             positions=encoder_positions,
-                                             kv_caches=kv_caches,
-                                             attn_metadata=attn_metadata)
+                                             positions=encoder_positions)

         # decoder outputs consists of
         # (dec_features, past_key_value, dec_hidden, dec_attn)
         decoder_outputs = self.decoder(
             decoder_input_ids=input_ids,
             decoder_positions=positions,
-            encoder_hidden_states=encoder_hidden_states,
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata)
+            encoder_hidden_states=encoder_hidden_states)

         return decoder_outputs

@@ -845,8 +798,6 @@ class BartForConditionalGeneration(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         *,
         encoder_input_ids: torch.Tensor,
@@ -863,15 +814,11 @@ class BartForConditionalGeneration(nn.Module):
                 torch.Tensor of *encoder* input token ids.
             encoder_positions
                 torch.Tensor of *encoder* position indices
-            kv_caches:
-                Layer-wise list of KV cache tensors
-            attn_metadata:
-                vLLM Attention metadata structure
         Returns:
             Output torch.Tensor
         """
         return self.model(input_ids, positions, encoder_input_ids,
-                          encoder_positions, kv_caches, attn_metadata)
+                          encoder_positions)

     def compute_logits(
         self,

|
@ -1,15 +1,16 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Iterable, List, Optional, Set, Tuple
|
from typing import Iterable, Optional, Set, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import BertConfig
|
from transformers import BertConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
from vllm.attention import Attention, AttentionType
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
|
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
@ -113,12 +114,9 @@ class BertEncoder(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
kv_caches: List[torch.Tensor],
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
for i in range(len(self.layer)):
|
for layer in self.layer:
|
||||||
layer = self.layer[i]
|
hidden_states = layer(hidden_states)
|
||||||
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@ -152,13 +150,8 @@ class BertLayer(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.output")
|
prefix=f"{prefix}.output")
|
||||||
|
|
||||||
def forward(
|
def forward(self, hidden_states: torch.Tensor):
|
||||||
self,
|
attn_output = self.attention(hidden_states)
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
kv_cache: Optional[torch.Tensor],
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
):
|
|
||||||
attn_output = self.attention(hidden_states, kv_cache, attn_metadata)
|
|
||||||
intermediate_output = self.intermediate(attn_output)
|
intermediate_output = self.intermediate(attn_output)
|
||||||
output = self.output(intermediate_output, attn_output)
|
output = self.output(intermediate_output, attn_output)
|
||||||
return output
|
return output
|
||||||
@ -191,10 +184,8 @@ class BertAttention(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
self_output = self.self(hidden_states, kv_cache, attn_metadata)
|
self_output = self.self(hidden_states)
|
||||||
return self.output(self_output, hidden_states)
|
return self.output(self_output, hidden_states)
|
||||||
|
|
||||||
|
|
||||||
@ -246,12 +237,10 @@ class BertSelfAttention(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
qkv, _ = self.qkv_proj(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
output = self.attn(q, k, v, kv_cache, attn_metadata)
|
output = self.attn(q, k, v)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@ -343,8 +332,6 @@ class BertModel(nn.Module):
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
kv_caches: List[torch.Tensor],
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
token_type_ids: Optional[torch.Tensor] = None,
|
token_type_ids: Optional[torch.Tensor] = None,
|
||||||
@ -352,13 +339,14 @@ class BertModel(nn.Module):
|
|||||||
if inputs_embeds is not None:
|
if inputs_embeds is not None:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
else:
|
else:
|
||||||
|
attn_metadata = get_forward_context().attn_metadata
|
||||||
assert hasattr(attn_metadata, "seq_lens_tensor")
|
assert hasattr(attn_metadata, "seq_lens_tensor")
|
||||||
hidden_states = self.embeddings(
|
hidden_states = self.embeddings(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
seq_lens=attn_metadata.seq_lens_tensor,
|
seq_lens=attn_metadata.seq_lens_tensor,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
token_type_ids=token_type_ids)
|
token_type_ids=token_type_ids)
|
||||||
return self.encoder(hidden_states, kv_caches, attn_metadata)
|
return self.encoder(hidden_states)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str,
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
torch.Tensor]]) -> Set[str]:
|
torch.Tensor]]) -> Set[str]:
|
||||||
@ -420,17 +408,13 @@ class BertEmbeddingModel(nn.Module):
|
|||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.Tensor],
|
input_ids: Optional[torch.Tensor],
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: List[torch.Tensor],
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self.model(input_ids=input_ids,
|
return self.model(input_ids=input_ids,
|
||||||
position_ids=positions,
|
position_ids=positions,
|
||||||
kv_caches=kv_caches,
|
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors)
|
||||||
attn_metadata=attn_metadata)
|
|
||||||
|
|
||||||
def pooler(
|
def pooler(
|
||||||
self,
|
self,
|
||||||
@ -519,16 +503,12 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
|||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.Tensor],
|
input_ids: Optional[torch.Tensor],
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: List[torch.Tensor],
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
token_type_ids: Optional[torch.Tensor] = None,
|
token_type_ids: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self.bert(input_ids=input_ids,
|
return self.bert(input_ids=input_ids,
|
||||||
position_ids=positions,
|
position_ids=positions,
|
||||||
kv_caches=kv_caches,
|
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
attn_metadata=attn_metadata,
|
|
||||||
token_type_ids=token_type_ids)
|
token_type_ids=token_type_ids)
|
||||||

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from functools import cached_property
- from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
+ from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)

import torch
@@ -9,7 +9,6 @@ import torch.nn as nn
from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
apply_chunking_to_forward)

- from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -658,8 +657,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
@@ -708,8 +705,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

hidden_states = self.language_model.model(input_ids,
positions,
- kv_caches,
- attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)


@@ -18,13 +18,13 @@
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
- from typing import Iterable, List, Optional, Set, Tuple, Union
+ from typing import Iterable, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import BloomConfig

- from vllm.attention import Attention, AttentionMetadata
+ from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -126,13 +126,11 @@ class BloomAttention(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
) -> torch.Tensor:
del position_ids  # Unused.
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
- attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+ attn_output = self.attn(q, k, v)
output, _ = self.dense(attn_output)
return output

@@ -193,8 +191,6 @@ class BloomBlock(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
@@ -209,8 +205,6 @@ class BloomBlock(nn.Module):
attention_output = self.self_attention(
position_ids=position_ids,
hidden_states=layernorm_output,
- kv_cache=kv_cache,
- attn_metadata=attn_metadata,
)
attention_output = attention_output + residual
layernorm_output = self.post_attention_layernorm(attention_output)
@@ -266,8 +260,6 @@ class BloomModel(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
@@ -279,14 +271,8 @@ class BloomModel(nn.Module):
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
- for i in range(self.start_layer, self.end_layer):
- layer = self.h[i]
- hidden_states = layer(
- position_ids,
- hidden_states,
- kv_caches[i - self.start_layer],
- attn_metadata,
- )
+ for layer in self.h[self.start_layer:self.end_layer]:
+ hidden_states = layer(position_ids, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
@@ -322,14 +308,11 @@ class BloomForCausalLM(nn.Module, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
- hidden_states = self.transformer(input_ids, positions, kv_caches,
- attn_metadata, intermediate_tensors,
- inputs_embeds)
+ hidden_states = self.transformer(input_ids, positions,
+ intermediate_tensors, inputs_embeds)
return hidden_states

def compute_logits(

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from functools import cached_property
- from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
+ from typing import (Any, Dict, Iterable, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)

import torch
@@ -10,7 +10,7 @@ import torch.nn.functional as F
from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
ChameleonVQVAEConfig)

- from vllm.attention import Attention, AttentionMetadata
+ from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
@@ -310,15 +310,13 @@ class ChameleonAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k)

q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+ attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output

@@ -372,8 +370,6 @@ class ChameleonDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:

@@ -386,8 +382,6 @@ class ChameleonDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
- kv_cache=kv_cache,
- attn_metadata=attn_metadata,
)

# Fully Connected
@@ -447,8 +441,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:

@@ -456,8 +448,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
- kv_cache=kv_cache,
- attn_metadata=attn_metadata,
)

hidden_states = self.input_layernorm(hidden_states)
@@ -906,8 +896,6 @@ class ChameleonModel(nn.Module):
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
@@ -921,13 +909,10 @@ class ChameleonModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
- for i in range(self.start_layer, self.end_layer):
- layer = self.layers[i]
+ for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions,
hidden_states,
- kv_caches[i - self.start_layer],
- attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
@@ -1028,8 +1013,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
@@ -1048,8 +1031,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,

hidden_states = self.model(input_ids,
positions,
- kv_caches,
- attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states

@@ -2,13 +2,13 @@
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
- from typing import Iterable, List, Optional, Set, Tuple, Union
+ from typing import Iterable, Optional, Set, Tuple, Union

import torch
from torch import nn
from torch.nn import LayerNorm

- from vllm.attention import Attention, AttentionMetadata
+ from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
@@ -108,19 +108,11 @@ class GLMAttention(nn.Module):
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
- context_layer = self.attn(
- q,
- k,
- v,
- kv_cache,
- attn_metadata,
- )
+ context_layer = self.attn(q, k, v)
attn_output, _ = self.dense(context_layer)
return attn_output

@@ -215,8 +207,6 @@ class GLMBlock(nn.Module):
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# hidden_states: [num_tokens, h]
# Layer norm at the beginning of the transformer layer.
@@ -225,8 +215,6 @@ class GLMBlock(nn.Module):
attention_output = self.self_attention(
hidden_states=layernorm_output,
position_ids=position_ids,
- kv_cache=kv_cache,
- attn_metadata=attn_metadata,
)

# Residual connection.
@@ -289,17 +277,10 @@ class GLMTransformer(nn.Module):
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
) -> Union[torch.Tensor, IntermediateTensors]:
- for i in range(self.start_layer, self.end_layer):
- layer = self.layers[i]
- hidden_states = layer(
- hidden_states=hidden_states,
- position_ids=position_ids,
- kv_cache=kv_caches[i - self.start_layer],
- attn_metadata=attn_metadata,
- )
+ for layer in self.layers[self.start_layer:self.end_layer]:
+ hidden_states = layer(hidden_states=hidden_states,
+ position_ids=position_ids)

if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
@@ -350,8 +331,6 @@ class ChatGLMModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
@@ -369,8 +348,6 @@ class ChatGLMModel(nn.Module):
hidden_states = self.encoder(
hidden_states=hidden_states,
position_ids=positions,
- kv_caches=kv_caches,
- attn_metadata=attn_metadata,
)

return hidden_states
@@ -494,12 +471,9 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
- hidden_states = self.transformer(input_ids, positions, kv_caches,
- attn_metadata, intermediate_tensors,
- inputs_embeds)
+ hidden_states = self.transformer(input_ids, positions,
+ intermediate_tensors, inputs_embeds)
return hidden_states

@@ -21,14 +21,14 @@

# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
- from typing import Iterable, List, Optional, Set, Tuple, Union
+ from typing import Iterable, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers import CohereConfig

- from vllm.attention import Attention, AttentionMetadata
+ from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -218,8 +218,6 @@ class CohereAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -227,7 +225,7 @@ class CohereAttention(nn.Module):
q, k = self._apply_qk_norm(q, k)
if self.v1 or self.sliding_window:
q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+ attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output

@@ -255,8 +253,6 @@ class CohereDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -265,8 +261,6 @@ class CohereDecoderLayer(nn.Module):
hidden_states_attention = self.self_attn(
positions=positions,
hidden_states=hidden_states,
- kv_cache=kv_cache,
- attn_metadata=attn_metadata,
)
hidden_states_mlp = self.mlp(hidden_states)
# Add everything together
@@ -311,8 +305,6 @@ class CohereModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
@@ -326,13 +318,10 @@ class CohereModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
- for i in range(self.start_layer, self.end_layer):
- layer = self.layers[i]
+ for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions,
hidden_states,
- kv_caches[i - self.start_layer],
- attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
@@ -389,13 +378,10 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
- hidden_states = self.model(input_ids, positions, kv_caches,
- attn_metadata, intermediate_tensors,
+ hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states


@@ -1,11 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

- from typing import Iterable, List, Optional, Set, Tuple, Union
+ from typing import Iterable, Optional, Set, Tuple, Union

import torch
import torch.nn as nn

- from vllm.attention import Attention, AttentionMetadata
+ from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
@@ -230,15 +230,13 @@ class DbrxAttention(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
- attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+ attn_output = self.attn(q, k, v)
hidden_states, _ = self.out_proj(attn_output)
return hidden_states

@@ -265,16 +263,12 @@ class DbrxFusedNormAttention(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm_1(hidden_states)
x = self.attn(
position_ids=position_ids,
hidden_states=hidden_states,
- kv_cache=kv_cache,
- attn_metadata=attn_metadata,
)
hidden_states = residual + x
residual = hidden_states
@@ -303,14 +297,10 @@ class DbrxBlock(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states, residual = self.norm_attn_norm(
position_ids=position_ids,
hidden_states=hidden_states,
- kv_cache=kv_cache,
- attn_metadata=attn_metadata,
)
hidden_states = self.ffn(hidden_states)
hidden_states = hidden_states + residual
@@ -353,8 +343,6 @@ class DbrxModel(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
@@ -366,14 +354,8 @@ class DbrxModel(nn.Module):
else:
assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"]
- for i in range(self.start_layer, self.end_layer):
- block = self.blocks[i]
- hidden_states = block(
- position_ids,
- hidden_states,
- kv_caches[i - self.start_layer],
- attn_metadata,
- )
+ for block in self.blocks[self.start_layer:self.end_layer]:
+ hidden_states = block(position_ids, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm_f(hidden_states)
@@ -415,14 +397,11 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
- hidden_states = self.transformer(input_ids, positions, kv_caches,
- attn_metadata, intermediate_tensors,
- inputs_embeds)
+ hidden_states = self.transformer(input_ids, positions,
+ intermediate_tensors, inputs_embeds)
return hidden_states

def compute_logits(

@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
- from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+ from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

- from vllm.attention import Attention, AttentionMetadata
+ from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -248,13 +248,11 @@ class DeepseekAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+ attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output

@@ -309,8 +307,6 @@ class DeepseekDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
@@ -323,8 +319,6 @@ class DeepseekDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
- kv_cache=kv_cache,
- attn_metadata=attn_metadata,
)

# Fully Connected
@@ -370,8 +364,6 @@ class DeepseekModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
@@ -384,11 +376,8 @@ class DeepseekModel(nn.Module):
else:
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
- for i in range(self.start_layer, self.end_layer):
- layer = self.layers[i]
- hidden_states, residual = layer(positions, hidden_states,
- kv_caches[i - self.start_layer],
- attn_metadata, residual)
+ for layer in self.layers[self.start_layer:self.end_layer]:
+ hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
@@ -425,13 +414,10 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
- hidden_states = self.model(input_ids, positions, kv_caches,
- attn_metadata, intermediate_tensors,
+ hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states


@@ -1,11 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
- from typing import Iterable, List, Optional, Set, Tuple
+ from typing import Iterable, Optional, Set, Tuple

import torch
import torch.nn as nn
from transformers import PretrainedConfig

- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -69,8 +68,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
@@ -88,8 +85,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):

hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
- kv_cache=kv_cache,
- attn_metadata=attn_metadata,
residual=None)
hidden_states = residual + hidden_states
return self.shared_head(hidden_states)
@@ -122,8 +117,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
@@ -131,8 +124,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
input_ids,
positions,
- kv_caches[spec_step_idx],
- attn_metadata,
previous_hidden_states,
inputs_embeds,
spec_step_idx,
@@ -165,16 +156,14 @@ class DeepSeekMTP(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, kv_caches,
- attn_metadata, previous_hidden_states,
- inputs_embeds, spec_step_idx)
+ hidden_states = self.model(input_ids, positions,
+ previous_hidden_states, inputs_embeds,
+ spec_step_idx)
return hidden_states

def compute_logits(

@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeepseekV2/DeepseekV3 model."""
- from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+ from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

- from vllm.attention import Attention, AttentionMetadata
+ from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
@@ -279,8 +279,6 @@ class DeepseekV2Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
@@ -313,7 +311,7 @@ class DeepseekV2Attention(nn.Module):
v = torch.nn.functional.pad(
v, [0, self.qk_head_dim - self.v_head_dim],
value=0).view(-1, self.num_local_heads * self.qk_head_dim)
- attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+ attn_output = self.attn(q, k, v)
attn_output = attn_output.view(
-1, self.num_local_heads,
self.qk_head_dim)[..., :self.v_head_dim].reshape(
@@ -451,8 +449,6 @@ class DeepseekV2MLAAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
) -> torch.Tensor:
if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0]
@@ -462,8 +458,7 @@ class DeepseekV2MLAAttention(nn.Module):
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
- return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
- attn_metadata)
+ return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe)


class DeepseekV2DecoderLayer(nn.Module):
@@ -532,8 +527,6 @@ class DeepseekV2DecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
@@ -546,8 +539,6 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
- kv_cache=kv_cache,
- attn_metadata=attn_metadata,
)

# Fully Connected
@@ -608,8 +599,6 @@ class DeepseekV2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
@@ -624,11 +613,8 @@ class DeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

- for i in range(self.start_layer, self.end_layer):
- layer = self.layers[i]
- hidden_states, residual = layer(positions, hidden_states,
- kv_caches[i - self.start_layer],
- attn_metadata, residual)
+ for layer in self.layers[self.start_layer:self.end_layer]:
+ hidden_states, residual = layer(positions, hidden_states)

if not get_pp_group().is_last_rank:
return IntermediateTensors({
@@ -665,13 +651,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
- hidden_states = self.model(input_ids, positions, kv_caches,
- attn_metadata, intermediate_tensors,
+ hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states


@@ -13,7 +13,6 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import BatchFeature

- from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
@@ -595,8 +594,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object):
@@ -614,8 +611,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

hidden_states = self.language_model(input_ids,
positions,
- kv_caches,
- attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)


@@ -1,11 +1,10 @@
# SPDX-License-Identifier: Apache-2.0

- from typing import Iterable, List, Optional, Tuple
+ from typing import Iterable, Optional, Tuple

import torch
import torch.nn as nn

- from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -121,8 +120,6 @@ class EAGLE(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
@@ -140,8 +137,6 @@ class EAGLE(nn.Module):
input_ids=None,
inputs_embeds=inputs_embeds,
positions=positions,
- kv_caches=kv_caches,
- attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
)
return hidden_states

@@ -24,12 +24,12 @@
# limitations under the License.
"""Inference-only Exaone model compatible with HuggingFace weights."""

- from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+ from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union

import torch
from torch import nn

- from vllm.attention import Attention, AttentionMetadata
+ from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -179,13 +179,11 @@ class ExaoneAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+ attn_output = self.attn(q, k, v)
output, _ = self.out_proj(attn_output)
return output

@@ -225,14 +223,10 @@ class ExaoneBlockAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
) -> torch.Tensor:
return self.attention(
positions=positions,
hidden_states=hidden_states,
- kv_cache=kv_cache,
- attn_metadata=attn_metadata,
)


@@ -288,8 +282,6 @@ class ExaoneDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -301,8 +293,6 @@ class ExaoneDecoderLayer(nn.Module):
hidden_states = self.attn(
positions=positions,
hidden_states=hidden_states,
- kv_cache=kv_cache,
- attn_metadata=attn_metadata,
)

# Fully Connected
@@ -365,8 +355,6 @@ class ExaoneModel(nn.Module):
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
@@ -381,13 +369,10 @@ class ExaoneModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

- for i in range(self.start_layer, self.end_layer):
- layer = self.h[i]
+ for layer in self.h[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions,
hidden_states,
- kv_caches[i - self.start_layer],
- attn_metadata,
residual,
)

@@ -471,14 +456,11 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
- model_output = self.transformer(input_ids, positions, kv_caches,
- attn_metadata, intermediate_tensors,
- inputs_embeds)
+ model_output = self.transformer(input_ids, positions,
+ intermediate_tensors, inputs_embeds)
return model_output

def compute_logits(
|
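The hunks above repeat one pattern: every attention module loses its explicit `kv_cache`/`attn_metadata` parameters and calls `self.attn(q, k, v)` directly, on the assumption that the attention layer resolves its cache and metadata internally (from the runner-managed forward context) rather than from per-call arguments. The sketch below is an editorial illustration only, with toy stand-in classes rather than vLLM code, just to mirror the trimmed per-layer signature.

```python
# Minimal, self-contained sketch of the new per-layer signature. ToyAttention
# is a hypothetical placeholder for an attention layer that manages its own
# KV-cache state; it is not vllm.attention.Attention.
import torch
from torch import nn


class ToyAttention(nn.Module):

    def forward(self, q: torch.Tensor, k: torch.Tensor,
                v: torch.Tensor) -> torch.Tensor:
        # A real layer would consult its KV cache and attention metadata here;
        # this toy version just computes plain scaled dot-product attention.
        scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
        return scores.softmax(dim=-1) @ v


class ToyDecoderAttention(nn.Module):

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.attn = ToyAttention()

    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        # Note the trimmed signature: no kv_cache or attn_metadata arguments.
        q, k, v = self.qkv_proj(hidden_states).chunk(3, dim=-1)
        attn_output = self.attn(q, k, v)  # cache/metadata handled internally
        return self.out_proj(attn_output)


layer = ToyDecoderAttention(hidden_size=64)
tokens = torch.randn(10, 64)           # [num_tokens, hidden_size]
out = layer(torch.arange(10), tokens)  # -> shape [10, 64]
```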
@ -20,14 +20,14 @@
 """PyTorch Falcon model."""

 import math
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from torch.nn import LayerNorm
 from transformers import FalconConfig as HF_FalconConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@ -190,8 +190,6 @@ class FalconAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, bias = self.query_key_value(hidden_states)
         if bias is not None:
@ -199,7 +197,7 @@ class FalconAttention(nn.Module):
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_rotary:
             q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         attn_output, bias = self.dense(attn_output)
         return attn_output, bias

@ -291,8 +289,6 @@ class FalconDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         residual = hidden_states

@ -306,8 +302,6 @@ class FalconDecoderLayer(nn.Module):
         attention_output, attention_bias = self.self_attention(
             positions=positions,
             hidden_states=attention_layernorm_out,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         if self.reduce_row_parallel_results and attention_bias is not None:
             attention_output += attention_bias
@ -384,8 +378,6 @@ class FalconModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@ -396,14 +388,8 @@ class FalconModel(nn.Module):
                 hidden_states = self.get_input_embeddings(input_ids)
         else:
             hidden_states = intermediate_tensors["hidden_states"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.h[i]
-            hidden_states = layer(
-                positions,
-                hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
-            )
+        for layer in self.h[self.start_layer:self.end_layer]:
+            hidden_states = layer(positions, hidden_states)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.ln_f(hidden_states)
@ -450,14 +436,11 @@ class FalconForCausalLM(nn.Module, SupportsPP):
         self,
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors,
-                                         inputs_embeds)
+        hidden_states = self.transformer(input_ids, positions,
+                                         intermediate_tensors, inputs_embeds)
         return hidden_states

     def compute_logits(
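A second recurring change, visible in the FalconModel hunk above and in most of the model classes below, is the pipeline-parallel layer loop: with no per-layer KV cache left to index, the loop can iterate over a slice of the `nn.ModuleList` instead of tracking integer indices. The sketch below uses hypothetical `ToyBlock`/`ToyModel` names (not vLLM code) to show that refactor in isolation.

```python
# Illustrative sketch of the loop rewrite; the old pattern is kept as comments.
from typing import Optional

import torch
from torch import nn


class ToyBlock(nn.Module):

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states + torch.tanh(self.mlp(hidden_states))


class ToyModel(nn.Module):

    def __init__(self, hidden_size: int, num_layers: int,
                 start_layer: int = 0,
                 end_layer: Optional[int] = None) -> None:
        super().__init__()
        self.h = nn.ModuleList(
            ToyBlock(hidden_size) for _ in range(num_layers))
        # start_layer/end_layer bound the layers owned by this pipeline rank.
        self.start_layer = start_layer
        self.end_layer = num_layers if end_layer is None else end_layer

    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        # Before: indices were needed to pick the matching KV cache, e.g.
        #   for i in range(self.start_layer, self.end_layer):
        #       layer = self.h[i]
        #       hidden_states = layer(positions, hidden_states,
        #                             kv_caches[i - self.start_layer],
        #                             attn_metadata)
        # After: iterate directly over this rank's slice of layers.
        for layer in self.h[self.start_layer:self.end_layer]:
            hidden_states = layer(positions, hidden_states)
        return hidden_states


model = ToyModel(hidden_size=32, num_layers=4)
hidden = model(torch.arange(6), torch.randn(6, 32))  # -> shape [6, 32]
```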
@ -1,12 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0

 import math
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Iterable, Optional, Set, Tuple

 import torch
 import torch.nn as nn

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@ -50,8 +49,7 @@ class Florence2LanguageModel(nn.Module):

     def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                 encoder_input_ids: torch.Tensor,
-                encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
-                attn_metadata: AttentionMetadata) -> torch.Tensor:
+                encoder_positions: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
             input_ids
@ -64,10 +62,6 @@ class Florence2LanguageModel(nn.Module):
                 Indices of *encoder* input sequence tokens in the vocabulary.
             encoder_positions:
                 Positions of *encoder* input sequence tokens.
-            kv_caches:
-                Layer-wise list of KV cache tensors
-            attn_metadata:
-                vLLM Attention metadata structure
         Returns:
             Model output torch.Tensor
         """
@ -78,18 +72,14 @@ class Florence2LanguageModel(nn.Module):
             # Run encoder attention if a non-zero number of encoder tokens
             # are provided as input
             encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
-                                                 positions=encoder_positions,
-                                                 kv_caches=kv_caches,
-                                                 attn_metadata=attn_metadata)
+                                                 positions=encoder_positions)

         # decoder outputs consists of
         # (dec_features, past_key_value, dec_hidden, dec_attn)
         decoder_outputs = self.decoder(
             decoder_input_ids=input_ids,
             decoder_positions=positions,
-            encoder_hidden_states=encoder_hidden_states,
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata)
+            encoder_hidden_states=encoder_hidden_states)

         return decoder_outputs

@ -122,8 +112,6 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
         positions: torch.Tensor,
         encoder_input_ids: torch.Tensor,
         encoder_positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         **kwargs,
     ) -> torch.Tensor:
         r"""
@ -136,15 +124,11 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
                 torch.Tensor of *encoder* input token ids.
             encoder_positions
                 torch.Tensor of *encoder* position indices
-            kv_caches:
-                Layer-wise list of KV cache tensors
-            attn_metadata:
-                vLLM Attention metadata structure
         Returns:
             Output torch.Tensor
         """
         return self.model(input_ids, positions, encoder_input_ids,
-                          encoder_positions, kv_caches, attn_metadata)
+                          encoder_positions)

     def compute_logits(
         self,
@ -213,8 +197,6 @@ class Florence2ForConditionalGeneration(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         *,
         encoder_input_ids: torch.Tensor,
@ -231,15 +213,11 @@ class Florence2ForConditionalGeneration(nn.Module):
                 torch.Tensor of *encoder* input token ids.
             encoder_positions
                 torch.Tensor of *encoder* position indices
-            kv_caches:
-                Layer-wise list of KV cache tensors
-            attn_metadata:
-                vLLM Attention metadata structure
         Returns:
             Output torch.Tensor
         """
         return self.language_model(input_ids, positions, encoder_input_ids,
-                                   encoder_positions, kv_caches, attn_metadata)
+                                   encoder_positions)

     def compute_logits(
         self,
@ -25,7 +25,6 @@ import torch.nn as nn
 from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
                           FuyuProcessor)

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.layers.sampler import SamplerOutput
@ -351,8 +350,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
@ -371,8 +368,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         hidden_states = self.language_model(
             input_ids=input_ids,
             positions=positions,
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
             intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
@ -16,13 +16,13 @@
 # limitations under the License.
 """Inference-only Gemma model compatible with HuggingFace weights."""
 from functools import cache
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import GemmaConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -183,13 +183,11 @@ class GemmaAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@ -233,8 +231,6 @@ class GemmaDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@ -247,8 +243,6 @@ class GemmaDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )

         # Fully Connected
@ -298,8 +292,6 @@ class GemmaModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@ -313,13 +305,10 @@ class GemmaModel(nn.Module):
         else:
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
+        for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
                 residual,
             )
         if not get_pp_group().is_last_rank:
@ -370,13 +359,10 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states

@ -15,13 +15,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import Gemma2Config

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -164,13 +164,11 @@ class Gemma2Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@ -220,8 +218,6 @@ class Gemma2DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if residual is None:
@ -233,8 +229,6 @@ class Gemma2DecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)

@ -284,8 +278,6 @@ class Gemma2Model(nn.Module):
         self,
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@ -300,13 +292,10 @@ class Gemma2Model(nn.Module):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
+        for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
                 residual,
             )
         if not get_pp_group().is_last_rank:
@ -415,13 +404,10 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states

@ -4,7 +4,7 @@
 # https://github.com/THUDM/CogAgent
 """Inference-only CogAgent model compatible with THUDM weights."""
 from argparse import Namespace
-from typing import List, Literal, Mapping, Optional, TypedDict, Union
+from typing import Literal, Mapping, Optional, TypedDict, Union

 import torch
 from torch import nn
@ -15,7 +15,6 @@ from transformers import PreTrainedTokenizer, TensorType
 from transformers.image_utils import ImageInput
 from transformers.tokenization_utils_base import TextInput

-from vllm.attention import AttentionMetadata
 from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@ -628,8 +627,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
@ -645,8 +642,7 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                                                       vision_embeddings)
             input_ids = None

-        hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors,
-                                         inputs_embeds)
+        hidden_states = self.transformer(input_ids, positions,
+                                         intermediate_tensors, inputs_embeds)

         return hidden_states
@ -18,13 +18,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-2 model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import GPT2Config

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed.parallel_state import (
@ -92,12 +92,10 @@ class GPT2Attention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         attn_output, _ = self.c_proj(attn_output)
         return attn_output

@ -164,16 +162,10 @@ class GPT2Block(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
-        attn_output = self.attn(
-            hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
-        )
+        attn_output = self.attn(hidden_states=hidden_states)
         # residual connection
         hidden_states = attn_output + residual

@ -222,8 +214,6 @@ class GPT2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor],
     ) -> Union[torch.Tensor, IntermediateTensors]:
@ -236,11 +226,8 @@ class GPT2Model(nn.Module):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]

-        for i in range(self.start_layer, self.end_layer):
-            layer = self.h[i]
-            hidden_states = layer(hidden_states,
-                                  kv_caches[i - self.start_layer],
-                                  attn_metadata)
+        for layer in self.h[self.start_layer:self.end_layer]:
+            hidden_states = layer(hidden_states)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
@ -279,14 +266,11 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors,
-                                         inputs_embeds)
+        hidden_states = self.transformer(input_ids, positions,
+                                         intermediate_tensors, inputs_embeds)
         return hidden_states

     def compute_logits(
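The `*ForCausalLM` / `*LMHeadModel` wrappers change in lockstep: their `forward` drops the two parameters and passes the remaining arguments straight through, so the inner call becomes `self.transformer(input_ids, positions, intermediate_tensors, inputs_embeds)` with the argument order shown in the hunks above. A minimal sketch of that wrapper shape, using made-up toy classes rather than the real vLLM interfaces:

```python
# Toy sketch of the trimmed wrapper signature; all names are hypothetical.
from typing import Any, Optional

import torch
from torch import nn


class ToyTransformer(nn.Module):

    def __init__(self, vocab_size: int, hidden_size: int) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                intermediate_tensors: Optional[Any] = None,
                inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Either consume precomputed embeddings or look them up.
        if inputs_embeds is not None:
            return inputs_embeds
        return self.embed(input_ids)


class ToyLMHeadModel(nn.Module):

    def __init__(self, vocab_size: int = 100, hidden_size: int = 32) -> None:
        super().__init__()
        self.transformer = ToyTransformer(vocab_size, hidden_size)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                intermediate_tensors: Optional[Any] = None,
                inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
        # After the change: no kv_caches / attn_metadata to forward.
        hidden_states = self.transformer(input_ids, positions,
                                         intermediate_tensors, inputs_embeds)
        return hidden_states


model = ToyLMHeadModel()
ids = torch.randint(0, 100, (8,))
hidden = model(ids, torch.arange(8))  # -> shape [8, 32]
```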
@ -19,13 +19,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPTBigCode model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -101,8 +101,6 @@ class GPTBigCodeAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.split(
@ -112,7 +110,7 @@ class GPTBigCodeAttention(nn.Module):
             ],
             dim=-1,
         )
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         attn_output, _ = self.c_proj(attn_output)
         return attn_output

@ -173,16 +171,10 @@ class GPTBigCodeBlock(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
-        attn_output = self.attn(
-            hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
-        )
+        attn_output = self.attn(hidden_states=hidden_states, )
         # residual connection
         hidden_states = attn_output + residual

@ -234,8 +226,6 @@ class GPTBigCodeModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@ -246,11 +236,8 @@ class GPTBigCodeModel(nn.Module):
         else:
             hidden_states = intermediate_tensors["hidden_states"]

-        for i in range(self.start_layer, self.end_layer):
-            layer = self.h[i]
-            hidden_states = layer(hidden_states,
-                                  kv_caches[i - self.start_layer],
-                                  attn_metadata)
+        for layer in self.h[self.start_layer:self.end_layer]:
+            hidden_states = layer(hidden_states)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
@ -302,14 +289,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors,
-                                         inputs_embeds)
+        hidden_states = self.transformer(input_ids, positions,
+                                         intermediate_tensors, inputs_embeds)
         return hidden_states

     def compute_logits(
@ -17,13 +17,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-J model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import GPTJConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -104,13 +104,11 @@ class GPTJAttention(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         attn_output, _ = self.out_proj(attn_output)
         return attn_output

@ -167,16 +165,12 @@ class GPTJBlock(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_output = self.attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         mlp_output = self.mlp(hidden_states)
         hidden_states = attn_output + mlp_output + residual
@ -217,8 +211,6 @@ class GPTJModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@ -229,14 +221,8 @@ class GPTJModel(nn.Module):
                 hidden_states = self.get_input_embeddings(input_ids)
         else:
             hidden_states = intermediate_tensors["hidden_states"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.h[i]
-            hidden_states = layer(
-                position_ids,
-                hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
-            )
+        for layer in self.h[self.start_layer:self.end_layer]:
+            hidden_states = layer(position_ids, hidden_states)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.ln_f(hidden_states)
@ -273,14 +259,11 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors,
-                                         inputs_embeds)
+        hidden_states = self.transformer(input_ids, positions,
+                                         intermediate_tensors, inputs_embeds)
         return hidden_states

     def compute_logits(
@ -17,13 +17,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import GPTNeoXConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -104,13 +104,11 @@ class GPTNeoXAttention(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.dense(attn_output)
         return output

@ -167,15 +165,11 @@ class GPTNeoXLayer(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         attn_input = self.input_layernorm(hidden_states)
         attn_output = self.attention(
             position_ids=position_ids,
             hidden_states=attn_input,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )

         if self.use_parallel_residual:
@ -230,8 +224,6 @@ class GPTNeoXModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@ -242,14 +234,8 @@ class GPTNeoXModel(nn.Module):
                 hidden_states = self.get_input_embeddings(input_ids)
         else:
             hidden_states = intermediate_tensors["hidden_states"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states = layer(
-                position_ids,
-                hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
-            )
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states = layer(position_ids, hidden_states)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.final_layer_norm(hidden_states)
@ -285,14 +271,11 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
-                                      attn_metadata, intermediate_tensors,
-                                      inputs_embeds)
+        hidden_states = self.gpt_neox(input_ids, positions,
+                                      intermediate_tensors, inputs_embeds)
         return hidden_states

     def compute_logits(
@ -22,13 +22,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only IBM Granite model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import GraniteConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -166,13 +166,11 @@ class GraniteAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@ -233,8 +231,6 @@ class GraniteDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         residual = hidden_states
@ -242,8 +238,6 @@ class GraniteDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         hidden_states = residual + hidden_states * self.residual_multiplier
         # Fully Connected
@ -300,8 +294,6 @@ class GraniteModel(nn.Module):
         self,
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@ -318,14 +310,8 @@ class GraniteModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states = layer(
-                positions,
-                hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
-            )
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states = layer(positions, hidden_states)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
@ -405,13 +391,10 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.model(input_ids, positions, kv_caches,
-                                  attn_metadata, intermediate_tensors,
+        model_output = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
         return model_output

@ -22,13 +22,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GraniteMoe model."""
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Iterable, Optional, Set, Tuple

 import torch
 from torch import nn
 from transformers.models.granitemoe import GraniteMoeConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -173,13 +173,11 @@ class GraniteMoeAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@ -226,8 +224,6 @@ class GraniteMoeDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
@ -235,8 +231,6 @@ class GraniteMoeDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         hidden_states = residual + hidden_states * self.residual_multiplier
         residual = hidden_states
@ -287,8 +281,6 @@ class GraniteMoeModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@ -303,11 +295,8 @@ class GraniteMoeModel(nn.Module):
@ -303,11 +295,8 @@ class GraniteMoeModel(nn.Module):
|
|||||||
assert intermediate_tensors is not None
|
assert intermediate_tensors is not None
|
||||||
hidden_states = intermediate_tensors["hidden_states"]
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
residual = intermediate_tensors["residual"]
|
residual = intermediate_tensors["residual"]
|
||||||
for i in range(self.start_layer, self.end_layer):
|
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||||
layer = self.layers[i]
|
hidden_states = layer(positions, hidden_states)
|
||||||
hidden_states = layer(positions, hidden_states,
|
|
||||||
kv_caches[i - self.start_layer],
|
|
||||||
attn_metadata)
|
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
return IntermediateTensors({
|
return IntermediateTensors({
|
||||||
"hidden_states": hidden_states,
|
"hidden_states": hidden_states,
|
||||||
@ -377,13 +366,10 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: List[torch.Tensor],
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||||
attn_metadata, intermediate_tensors,
|
|
||||||
inputs_embeds)
|
inputs_embeds)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
@@ -1,15 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0

 from array import array
-from typing import List, Optional, Union
+from typing import Optional, Union

 import torch
 import torch.nn as nn
 from xformers.ops.fmha.attn_bias import BlockDiagonalMask

-from vllm.attention import AttentionMetadata
 from vllm.attention.backends.xformers import XFormersImpl
 from vllm.config import ModelConfig, VllmConfig
+from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.pooler import PoolerHead
 from vllm.model_executor.models.llama import LlamaForCausalLM
@@ -217,13 +217,12 @@ class GritLM(LlamaForCausalLM):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         **kwargs,
     ) -> Union[torch.Tensor, IntermediateTensors]:

         # Change attention to non-causal for pooling tasks.
         if self.runner_type == "pooling":
+            attn_metadata = get_forward_context().attn_metadata
             assert attn_metadata.prefill_metadata.attn_bias is None
             attn_metadata.prefill_metadata.attn_bias = [
                 BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)
@@ -232,8 +231,6 @@ class GritLM(LlamaForCausalLM):
         return super().forward(
             input_ids=input_ids,
             positions=positions,
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
             **kwargs,
         )
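GritLM shows the pattern a model uses when it still needs the attention metadata after this change: rather than receiving it as a forward() argument, it reads it from the per-step forward context. Below is a hedged sketch of that lookup in an arbitrary module; the class, sizes and projection are illustrative and not part of the commit.

# Illustrative module showing the forward-context lookup used above.
import torch
from torch import nn

from vllm.forward_context import get_forward_context


class ContextAwareBlock(nn.Module):
    """Toy block; the linear projection stands in for real attention logic."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Populated by the runner via set_forward_context() for each step;
        # it can be None when no execution context is active.
        attn_metadata = get_forward_context().attn_metadata
        if attn_metadata is not None:
            # e.g. inspect prefill state here, as GritLM does for pooling.
            pass
        return self.proj(hidden_states)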
@@ -25,7 +25,6 @@ from torch import nn
 from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
                           Idefics3Processor)

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import ReplicatedLinear
@@ -563,8 +562,6 @@ class Idefics3Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -572,8 +569,6 @@ class Idefics3Model(nn.Module):
         hidden_states = self.text_model(
             input_ids,
             positions,
-            kv_caches,
-            attn_metadata,
             intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
@@ -645,8 +640,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
@@ -664,8 +657,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,

         hidden_states = self.model.text_model(input_ids,
                                               positions,
-                                              kv_caches,
-                                              attn_metadata,
                                               intermediate_tensors,
                                               inputs_embeds=inputs_embeds)
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union,
-                    overload, runtime_checkable)
+from typing import (TYPE_CHECKING, Optional, Protocol, Type, Union, overload,
+                    runtime_checkable)

 import torch
 import torch.nn as nn
@@ -11,7 +11,6 @@ from vllm.logger import init_logger
 from vllm.utils import supports_kw

 if TYPE_CHECKING:
-    from vllm.attention import AttentionMetadata
     from vllm.config import VllmConfig
     from vllm.model_executor.layers.pooler import PoolerOutput
     from vllm.model_executor.layers.sampler import SamplerOutput
@@ -46,8 +45,6 @@ class VllmModel(Protocol[T_co]):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: "AttentionMetadata",
     ) -> T_co:
         ...

@@ -62,7 +59,7 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
     if not callable(model_forward):
         return False

-    vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata")
+    vllm_kws = ("input_ids", "positions")
     missing_kws = tuple(kw for kw in vllm_kws
                         if not supports_kw(model_forward, kw))
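After this interface change, _check_vllm_model_forward only requires the "input_ids" and "positions" keywords. The sketch below shows a forward signature that would satisfy the trimmed VllmModel protocol; the class name, sizes and embedding arithmetic are made up purely for illustration.

# Hypothetical minimal model whose forward() passes the trimmed keyword check.
import torch
from torch import nn


class TinyProtocolModel(nn.Module):

    def __init__(self, vocab_size: int = 128, hidden_size: int = 32) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        # kv_caches / attn_metadata are gone from the contract; real models
        # obtain them through the forward context inside their layers.
        return self.embed(input_ids) + positions.to(torch.float32).unsqueeze(-1)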
@@ -1,13 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0

 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union

 import torch
 from torch import nn
 from transformers import PretrainedConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -175,13 +175,11 @@ class InternLM2Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.wqkv(hidden_states)
         q, k, v = self.split_qkv(qkv)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.wo(attn_output)
         return output

@@ -227,8 +225,6 @@ class InternLMDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -241,8 +237,6 @@ class InternLMDecoderLayer(nn.Module):
         hidden_states = self.attention(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )

         # Fully Connected
@@ -290,8 +284,6 @@ class InternLM2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -305,15 +297,8 @@ class InternLM2Model(nn.Module):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states, residual = layer(
-                positions,
-                hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
-                residual,
-            )
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states, residual = layer(positions, hidden_states, residual)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
@@ -363,13 +348,10 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states

@@ -466,13 +448,10 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         logits, _ = self.v_head(hidden_states)
         return logits
@@ -1,12 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union

 import torch
 from torch import nn
 from transformers import PretrainedConfig

-from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -65,8 +64,6 @@ class InternLM2VEDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         visual_token_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -80,8 +77,6 @@ class InternLM2VEDecoderLayer(nn.Module):
         hidden_states = self.attention(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )

         # Fully Connected
@@ -113,8 +108,6 @@ class InternLM2VEModel(InternLM2Model):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         visual_token_mask: Optional[torch.Tensor] = None,
@@ -129,13 +122,10 @@ class InternLM2VEModel(InternLM2Model):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
+        for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
                 residual,
                 visual_token_mask=visual_token_mask,
             )
@@ -17,7 +17,6 @@ import torchvision.transforms as T
 from PIL import Image
 from transformers import BatchFeature, PretrainedConfig, TensorType

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
@@ -929,8 +928,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
@@ -951,8 +948,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         forward_kwargs = {
             "input_ids": input_ids,
             "positions": positions,
-            "kv_caches": kv_caches,
-            "attn_metadata": attn_metadata,
             "intermediate_tensors": intermediate_tensors,
             "inputs_embeds": inputs_embeds,
         }
@@ -21,12 +21,12 @@
 """Inference-only Jais model compatible with HuggingFace weights."""

 import math
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -123,12 +123,10 @@ class JAISAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         attn_output, _ = self.c_proj(attn_output)
         return attn_output

@@ -200,16 +198,10 @@ class JAISBlock(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
-        attn_output = self.attn(
-            hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
-        )
+        attn_output = self.attn(hidden_states=hidden_states, )
         # residual connection
         hidden_states = attn_output + residual

@@ -266,8 +258,6 @@ class JAISModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[IntermediateTensors, torch.Tensor]:
@@ -285,11 +275,8 @@ class JAISModel(nn.Module):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]

-        for i in range(self.start_layer, self.end_layer):
-            layer = self.h[i]
-            hidden_states = layer(hidden_states,
-                                  kv_caches[i - self.start_layer],
-                                  attn_metadata)
+        for layer in self.h[self.start_layer:self.end_layer]:
+            hidden_states = layer(hidden_states)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
@@ -332,14 +319,11 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[IntermediateTensors, torch.Tensor]:
-        hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors,
-                                         inputs_embeds)
+        hidden_states = self.transformer(input_ids, positions,
+                                         intermediate_tensors, inputs_embeds)
         return hidden_states

     def compute_logits(
@@ -1,12 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 """Inference-only Jamba model."""
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Iterable, Optional, Set, Tuple

 import torch
 from torch import nn
 from transformers import JambaConfig

-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -138,7 +137,6 @@ class JambaMambaDecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
         **kwargs,
@@ -150,8 +148,7 @@ class JambaMambaDecoderLayer(nn.Module):
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)

-        hidden_states = self.mamba(hidden_states, attn_metadata,
-                                   mamba_cache_params)
+        hidden_states = self.mamba(hidden_states, mamba_cache_params)
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
             hidden_states, residual)
@@ -223,13 +220,11 @@ class JambaAttentionDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         **kwargs,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@@ -237,8 +232,6 @@ class JambaAttentionDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         **kwargs,
     ):
@@ -252,8 +245,6 @@ class JambaAttentionDecoderLayer(nn.Module):
         hidden_states = self.self_attention(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
@@ -320,8 +311,6 @@ class JambaModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
@@ -339,12 +328,9 @@ class JambaModel(nn.Module):

         kv_cache_index = 0
         mamba_cache_index = 0
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            kv_cache = None
+        for layer in self.layers[self.start_layer:self.end_layer]:
             layer_mamba_cache_params = None
             if isinstance(layer, JambaAttentionDecoderLayer):
-                kv_cache = kv_caches[kv_cache_index]
                 kv_cache_index += 1
             if isinstance(layer, JambaMambaDecoderLayer):
                 current_state_layer = mamba_cache_index
@@ -355,8 +341,6 @@ class JambaModel(nn.Module):
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
-                kv_cache=kv_cache,
-                attn_metadata=attn_metadata,
                 residual=residual,
                 mamba_cache_params=layer_mamba_cache_params)
         if not get_pp_group().is_last_rank:
@@ -429,8 +413,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
-                kv_caches: List[KVCache],
-                attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
@@ -443,8 +425,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,

         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, mamba_cache_params,
+        hidden_states = self.model(input_ids, positions, mamba_cache_params,
                                    intermediate_tensors, inputs_embeds)
         return hidden_states
@@ -22,13 +22,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union

 import torch
 from torch import nn
 from transformers import LlamaConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -197,13 +197,11 @@ class LlamaAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@@ -268,8 +266,6 @@ class LlamaDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -280,9 +276,7 @@ class LlamaDecoderLayer(nn.Module):
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
         hidden_states = self.self_attn(positions=positions,
-                                       hidden_states=hidden_states,
-                                       kv_cache=kv_cache,
-                                       attn_metadata=attn_metadata)
+                                       hidden_states=hidden_states)

         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
@@ -347,8 +341,6 @@ class LlamaModel(nn.Module):
         self,
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -363,11 +355,8 @@ class LlamaModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states, residual = layer(positions, hidden_states,
-                                            kv_caches[i - self.start_layer],
-                                            attn_metadata, residual)
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states, residual = layer(positions, hidden_states, residual)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
@@ -535,13 +524,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.model(input_ids, positions, kv_caches,
-                                  attn_metadata, intermediate_tensors,
+        model_output = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
         return model_output
@@ -15,7 +15,6 @@ from transformers import __version__ as TRANSFORMERS_VERSION
 from transformers.models.llava import LlavaProcessor
 from transformers.models.pixtral import PixtralProcessor

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.inputs import InputProcessingContext
 from vllm.model_executor.layers.activation import get_act_fn
@@ -658,8 +657,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
@@ -712,8 +709,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
-                                                  kv_caches,
-                                                  attn_metadata,
                                                   intermediate_tensors,
                                                   inputs_embeds=inputs_embeds)
@@ -12,7 +12,6 @@ from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
 from typing_extensions import NotRequired

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -508,8 +507,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
@@ -571,8 +568,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
-                                                  kv_caches,
-                                                  attn_metadata,
                                                   intermediate_tensors,
                                                   inputs_embeds=inputs_embeds)
         return hidden_states
@@ -10,7 +10,6 @@ import torch.nn as nn
 from transformers import (BatchFeature, LlavaNextVideoConfig,
                           LlavaNextVideoProcessor)

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -443,8 +442,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
@@ -468,8 +465,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
-                                                  kv_caches,
-                                                  attn_metadata,
                                                   intermediate_tensors,
                                                   inputs_embeds=inputs_embeds)
@@ -13,7 +13,6 @@ from transformers.models.llava_onevision.modeling_llava_onevision import (
     get_anyres_image_grid_shape, unpad_image)
 from typing_extensions import NotRequired

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -922,8 +921,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
@@ -955,8 +952,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
-                                                  kv_caches,
-                                                  attn_metadata,
                                                   intermediate_tensors,
                                                   inputs_embeds=inputs_embeds)
@@ -1,12 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 """PyTorch MAMBA model."""
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Iterable, Optional, Set, Tuple

 import torch
 from torch import nn
 from transformers import MambaConfig

-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -64,7 +63,6 @@ class MambaDecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
         **kwargs,
@@ -75,8 +73,7 @@ class MambaDecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.norm(hidden_states, residual)

-        hidden_states = self.mixer(hidden_states, attn_metadata,
-                                   mamba_cache_params)
+        hidden_states = self.mixer(hidden_states, mamba_cache_params)
         return hidden_states, residual


@@ -125,7 +122,6 @@ class MambaModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
@@ -146,7 +142,6 @@ class MambaModel(nn.Module):
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
-                attn_metadata=attn_metadata,
                 residual=residual,
                 mamba_cache_params=mamba_cache_params.at_layer_idx(
                     i - self.start_layer))
@@ -208,8 +203,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
-                kv_caches: List[KVCache],
-                attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
@@ -222,9 +215,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):

         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

-        hidden_states = self.backbone(input_ids, positions, attn_metadata,
-                                      mamba_cache_params, intermediate_tensors,
-                                      inputs_embeds)
+        hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
+                                      intermediate_tensors, inputs_embeds)

         return hidden_states
@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
"""PyTorch MAMBA2 model."""
|
"""PyTorch MAMBA2 model."""
|
||||||
from typing import Iterable, List, Optional, Set, Tuple
|
from typing import Iterable, Optional, Set, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionMetadata
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
from vllm.distributed.parallel_state import get_pp_group
|
from vllm.distributed.parallel_state import get_pp_group
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
|
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
|
||||||
@ -63,7 +64,6 @@ class Mamba2DecoderLayer(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
mamba_cache_params: MambaCacheParams,
|
mamba_cache_params: MambaCacheParams,
|
||||||
sequence_idx: Optional[torch.Tensor],
|
sequence_idx: Optional[torch.Tensor],
|
||||||
@ -75,8 +75,8 @@ class Mamba2DecoderLayer(nn.Module):
|
|||||||
else:
|
else:
|
||||||
hidden_states, residual = self.norm(hidden_states, residual)
|
hidden_states, residual = self.norm(hidden_states, residual)
|
||||||
|
|
||||||
hidden_states = self.mixer(hidden_states, attn_metadata,
|
hidden_states = self.mixer(hidden_states, mamba_cache_params,
|
||||||
mamba_cache_params, sequence_idx)
|
sequence_idx)
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
@ -122,7 +122,6 @@ class Mamba2Model(nn.Module):
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
mamba_cache_params: MambaCacheParams,
|
mamba_cache_params: MambaCacheParams,
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
@ -142,6 +141,7 @@ class Mamba2Model(nn.Module):
|
|||||||
# proper continuous batching computation including
|
# proper continuous batching computation including
|
||||||
# chunked prefill
|
# chunked prefill
|
||||||
seq_idx = None
|
seq_idx = None
|
||||||
|
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
||||||
if attn_metadata.num_prefills > 0:
|
if attn_metadata.num_prefills > 0:
|
||||||
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
|
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
|
||||||
for i, (srt, end) in enumerate(
|
for i, (srt, end) in enumerate(
|
||||||
@ -158,7 +158,6 @@ class Mamba2Model(nn.Module):
|
|||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
attn_metadata=attn_metadata,
|
|
||||||
residual=residual,
|
residual=residual,
|
||||||
mamba_cache_params=mamba_cache_params.at_layer_idx(
|
mamba_cache_params=mamba_cache_params.at_layer_idx(
|
||||||
i - self.start_layer),
|
i - self.start_layer),
|
||||||
@ -224,8 +223,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
|||||||
def forward(self,
|
def forward(self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: List[KVCache],
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
@ -238,9 +235,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
|||||||
|
|
||||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||||
|
|
||||||
hidden_states = self.backbone(input_ids, positions, attn_metadata,
|
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
|
||||||
mamba_cache_params, intermediate_tensors,
|
intermediate_tensors, inputs_embeds)
|
||||||
inputs_embeds)
|
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
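The Mamba2 hunks above set the pattern the remaining files in this commit repeat: `attn_metadata` is no longer threaded through every `forward()` signature but is read from the per-step forward context, and KV caches stop being passed in as arguments. A minimal sketch of what a model-level `forward()` can still do when it needs the metadata (the function name and return value below are illustrative, not taken from this commit):

```python
# Sketch only: reading attention metadata from the forward context instead of
# accepting it as a forward() argument. Assumes the engine has already entered
# the per-step forward context before calling into the model.
import torch

from vllm.forward_context import get_forward_context


def toy_model_forward(input_ids: torch.Tensor,
                      positions: torch.Tensor) -> torch.Tensor:
    attn_metadata = get_forward_context().attn_metadata
    if attn_metadata is not None and attn_metadata.num_prefills > 0:
        # Per-request bookkeeping (e.g. the chunked-prefill sequence indices
        # that Mamba2Model builds above) can still key off the metadata here.
        pass
    # ... run the decoder layers, which now take only positions/hidden_states ...
    return torch.zeros(input_ids.shape[0], 1)
```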
@@ -23,13 +23,13 @@
 # limitations under the License.
 """Inference-only MiniCPM model compatible with HuggingFace weights."""
 import math
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import PretrainedConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -257,8 +257,6 @@ class MiniCPMAttention(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 qkv, _ = self.qkv_proj(hidden_states)
 q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -266,7 +264,7 @@ class MiniCPMAttention(nn.Module):
 q, k = q.float(), k.float()
 q, k = self.rotary_emb(positions, q, k)
 q, k = q.to(orig_dtype), k.to(orig_dtype)
-attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+attn_output = self.attn(q, k, v)
 output, _ = self.o_proj(attn_output)
 return output

@@ -331,8 +329,6 @@ class MiniCPMDecoderLayer(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 residual: Optional[torch.Tensor],
 ) -> Tuple[torch.Tensor, torch.Tensor]:
 # Self Attention
@@ -341,8 +337,6 @@ class MiniCPMDecoderLayer(nn.Module):
 hidden_states = self.self_attn(
 positions=positions,
 hidden_states=hidden_states,
-kv_cache=kv_cache,
-attn_metadata=attn_metadata,
 )
 hidden_states = residual + hidden_states * \
 (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))
@@ -409,8 +403,6 @@ class MiniCPMModel(nn.Module):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -424,13 +416,10 @@ class MiniCPMModel(nn.Module):
 hidden_states = intermediate_tensors["hidden_states"]
 residual = intermediate_tensors["residual"]

-for i in range(self.start_layer, self.end_layer):
-layer = self.layers[i]
+for layer in self.layers[self.start_layer:self.end_layer]:
 hidden_states, residual = layer(
 positions,
 hidden_states,
-kv_caches[i - self.start_layer],
-attn_metadata,
 residual,
 )
 if not get_pp_group().is_last_rank:
@@ -579,13 +568,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:
-hidden_states = self.model(input_ids, positions, kv_caches,
-attn_metadata, intermediate_tensors,
+hidden_states = self.model(input_ids, positions, intermediate_tensors,
 inputs_embeds)
 return hidden_states

@@ -29,7 +29,7 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -129,8 +129,6 @@ class MiniCPM3Attention(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 q, _ = self.q_a_proj(hidden_states)
 q = self.q_a_layernorm(q)
@@ -170,7 +168,7 @@ class MiniCPM3Attention(nn.Module):
 v, [0, self.qk_head_dim - self.v_head_dim],
 value=0).view(-1, self.num_local_heads * self.qk_head_dim)

-attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+attn_output = self.attn(q, k, v)
 attn_output = attn_output.view(
 -1, self.num_local_heads,
 self.qk_head_dim)[..., :self.v_head_dim].reshape(
@@ -33,7 +33,6 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.whisper.modeling_whisper import (
 ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder)

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import MultiModalFieldConfig
@@ -792,8 +791,6 @@ class MiniCPMO(MiniCPMV2_6):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 **kwargs: Any,
 ) -> torch.Tensor:
@@ -818,8 +815,6 @@ class MiniCPMO(MiniCPMV2_6):
 output = self.llm.model(
 input_ids=input_ids,
 positions=positions,
-kv_caches=kv_caches,
-attn_metadata=attn_metadata,
 intermediate_tensors=intermediate_tensors,
 inputs_embeds=vlm_embeddings,
 )
@@ -37,7 +37,6 @@ from torch import nn
 from transformers import BatchFeature, PretrainedConfig
 from typing_extensions import TypeVar

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
@@ -1030,8 +1029,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 **kwargs: Any,
 ) -> torch.Tensor:
@@ -1051,8 +1048,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
 output = self.llm.model(
 input_ids=input_ids,
 positions=positions,
-kv_caches=kv_caches,
-attn_metadata=attn_metadata,
 intermediate_tensors=intermediate_tensors,
 inputs_embeds=vlm_embeddings,
 )
@@ -22,13 +22,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Mixtral model."""
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import MixtralConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -175,13 +175,11 @@ class MixtralAttention(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 qkv, _ = self.qkv_proj(hidden_states)
 q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 q, k = self.rotary_emb(positions, q, k)
-attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+attn_output = self.attn(q, k, v)
 output, _ = self.o_proj(attn_output)
 return output

@@ -224,8 +222,6 @@ class MixtralDecoderLayer(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 residual: Optional[torch.Tensor],
 ) -> torch.Tensor:
 # Self Attention
@@ -238,8 +234,6 @@ class MixtralDecoderLayer(nn.Module):
 hidden_states = self.self_attn(
 positions=positions,
 hidden_states=hidden_states,
-kv_cache=kv_cache,
-attn_metadata=attn_metadata,
 )

 # Fully Connected
@@ -291,8 +285,6 @@ class MixtralModel(nn.Module):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors],
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -306,11 +298,8 @@ class MixtralModel(nn.Module):
 assert intermediate_tensors is not None
 hidden_states = intermediate_tensors["hidden_states"]
 residual = intermediate_tensors["residual"]
-for i in range(self.start_layer, self.end_layer):
-layer = self.layers[i]
-hidden_states, residual = layer(positions, hidden_states,
-kv_caches[i - self.start_layer],
-attn_metadata, residual)
+for layer in self.layers[self.start_layer:self.end_layer]:
+hidden_states, residual = layer(positions, hidden_states, residual)
 if not get_pp_group().is_last_rank:
 return IntermediateTensors({
 "hidden_states": hidden_states,
@@ -377,13 +366,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:
-hidden_states = self.model(input_ids, positions, kv_caches,
-attn_metadata, intermediate_tensors,
+hidden_states = self.model(input_ids, positions, intermediate_tensors,
 inputs_embeds)
 return hidden_states

|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Inference-only Mixtral model."""
|
"""Inference-only Mixtral model."""
|
||||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
from typing import Iterable, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -30,7 +30,7 @@ import torch.nn.functional as F
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import MixtralConfig
|
from transformers import MixtralConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
@ -229,13 +229,11 @@ class MixtralAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
qkv, _ = self.qkv_proj(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
q, k = self.rotary_emb(positions, q, k)
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
attn_output = self.attn(q, k, v)
|
||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -274,8 +272,6 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Self Attention
|
# Self Attention
|
||||||
@ -288,8 +284,6 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
hidden_states = self.self_attn(
|
hidden_states = self.self_attn(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
kv_cache=kv_cache,
|
|
||||||
attn_metadata=attn_metadata,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
@ -333,8 +327,6 @@ class MixtralModel(nn.Module):
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: List[torch.Tensor],
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors],
|
intermediate_tensors: Optional[IntermediateTensors],
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
@ -348,11 +340,8 @@ class MixtralModel(nn.Module):
|
|||||||
assert intermediate_tensors is not None
|
assert intermediate_tensors is not None
|
||||||
hidden_states = intermediate_tensors["hidden_states"]
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
residual = intermediate_tensors["residual"]
|
residual = intermediate_tensors["residual"]
|
||||||
for i in range(self.start_layer, self.end_layer):
|
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||||
layer = self.layers[i]
|
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||||
hidden_states, residual = layer(positions, hidden_states,
|
|
||||||
kv_caches[i - self.start_layer],
|
|
||||||
attn_metadata, residual)
|
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
return IntermediateTensors({
|
return IntermediateTensors({
|
||||||
"hidden_states": hidden_states,
|
"hidden_states": hidden_states,
|
||||||
@ -390,13 +379,10 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: List[torch.Tensor],
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||||
attn_metadata, intermediate_tensors,
|
|
||||||
inputs_embeds)
|
inputs_embeds)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
@@ -38,7 +38,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
 from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.attention.selector import _Backend
 from vllm.config import VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tp_group
+from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -416,11 +417,11 @@ class MllamaVisionSdpaAttention(nn.Module):
 prefix: str = ""):
 super().__init__()

-model_parallel_size = get_tensor_model_parallel_world_size()
+tensor_parallel_size = get_tp_group().world_size
 self.embed_dim = config.hidden_size
 self.num_heads = config.attention_heads
 self.head_dim = config.hidden_size // config.attention_heads
-self.num_local_heads = self.num_heads // model_parallel_size
+self.num_local_heads = self.num_heads // tensor_parallel_size
 self.q_size = self.num_local_heads * self.head_dim
 self.kv_size = self.num_local_heads * self.head_dim

@@ -771,12 +772,13 @@ class MllamaTextCrossAttention(nn.Module):
 ):
 super().__init__()
 self.config = config
-self.model_parallel_size = get_tensor_model_parallel_world_size()
+self.pipeline_parallel_rank = get_pp_group().rank_in_group
+self.tensor_parallel_size = get_tp_group().world_size
 self.num_heads = self.config.num_attention_heads
-self.num_local_heads = self.num_heads // self.model_parallel_size
+self.num_local_heads = self.num_heads // self.tensor_parallel_size
 self.num_key_value_heads = self.config.num_key_value_heads
 self.num_local_key_value_heads = \
-self.num_key_value_heads // self.model_parallel_size
+self.num_key_value_heads // self.tensor_parallel_size
 self.dropout = config.dropout
 self.hidden_size = config.hidden_size
 self.head_dim = config.hidden_size // self.num_heads
@@ -824,8 +826,6 @@ class MllamaTextCrossAttention(nn.Module):
 attention_mask: Optional[torch.Tensor],
 kv_range_for_decode: Optional[List[Tuple[int, int]]],
 cross_attention_states: Optional[torch.Tensor],
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 qkv_dec, _ = self.qkv_proj(hidden_states)
 q, _, _ = qkv_dec.split(
@@ -846,14 +846,11 @@ class MllamaTextCrossAttention(nn.Module):
 q = self.q_norm(q)

 if attention_mask is not None:
-output = self._attention_with_mask(q, k, v, kv_cache,
-attention_mask,
-kv_range_for_decode,
-attn_metadata)
+output = self._attention_with_mask(q, k, v, attention_mask,
+kv_range_for_decode)
 else:
 output = self.attn(
-q.view(-1, self.num_local_heads * self.head_dim), k, v,
-kv_cache, attn_metadata)
+q.view(-1, self.num_local_heads * self.head_dim), k, v)
 out, _ = self.o_proj(output)
 return out

@@ -862,11 +859,11 @@ class MllamaTextCrossAttention(nn.Module):
 q: torch.Tensor,
 k: torch.Tensor,
 v: torch.Tensor,
-kv_cache: torch.Tensor,
 attention_mask: torch.Tensor,
 kv_range_for_decode: List[Tuple[int, int]],
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
+kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank]
+attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
 # Skip writing kv-cache for the initial profiling run.
 if len(kv_cache.shape) > 1:
 i = torch.ones(1, dtype=torch.float32)
@@ -978,8 +975,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
 cross_attention_mask: torch.Tensor,
 kv_range_for_decode: Optional[List[Tuple[int, int]]],
 full_text_row_masked_out_mask: torch.Tensor,
-kv_cache: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 residual = hidden_states
 hidden_states = self.input_layernorm(hidden_states)
@@ -989,8 +984,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
 attention_mask=cross_attention_mask,
 kv_range_for_decode=kv_range_for_decode,
 cross_attention_states=cross_attention_states,
-kv_cache=kv_cache,
-attn_metadata=attn_metadata,
 )
 hidden_states = full_text_row_masked_out_mask * hidden_states
 hidden_states = residual + self.cross_attn_attn_gate.tanh(
@@ -1054,14 +1047,12 @@ class MllamaTextModel(nn.Module):
 kv_range_for_decode: Optional[List[Tuple[int, int]]],
 full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
 torch.Tensor]],
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 skip_cross_attention: bool,
 ) -> torch.Tensor:
 inputs_embeds = self.embed_tokens(input_ids)
 hidden_states = inputs_embeds

-for idx, decoder_layer in enumerate(self.layers):
+for decoder_layer in self.layers:
 if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
 if not skip_cross_attention:
 hidden_states = decoder_layer(
@@ -1071,15 +1062,11 @@ class MllamaTextModel(nn.Module):
 kv_range_for_decode=kv_range_for_decode,
 full_text_row_masked_out_mask=
 full_text_row_masked_out_mask,
-kv_cache=kv_caches[idx],
-attn_metadata=attn_metadata,
 )
 elif isinstance(decoder_layer, LlamaDecoderLayer):
 hidden_states, residual = decoder_layer(
 positions=positions,
 hidden_states=hidden_states,
-kv_cache=kv_caches[idx],
-attn_metadata=attn_metadata,
 residual=None,
 )
 hidden_states = hidden_states + residual
@@ -1124,8 +1111,6 @@ class MllamaForCausalLM(nn.Module):
 kv_range_for_decode: Optional[List[Tuple[int, int]]],
 full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
 torch.Tensor]],
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 skip_cross_attention: bool,
 ) -> torch.Tensor:
 hidden_states = self.model(
@@ -1135,8 +1120,6 @@ class MllamaForCausalLM(nn.Module):
 cross_attention_mask=cross_attention_mask,
 kv_range_for_decode=kv_range_for_decode,
 full_text_row_masked_out_mask=full_text_row_masked_out_mask,
-kv_caches=kv_caches,
-attn_metadata=attn_metadata,
 skip_cross_attention=skip_cross_attention,
 )
 return hidden_states
@@ -1353,10 +1336,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 **kwargs: object,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
+attn_metadata = get_forward_context().attn_metadata
 if attn_metadata.num_prefill_tokens > 0 and \
 attn_metadata.num_decode_tokens > 0:
 raise ValueError("Chunk prefill not supported")
@@ -1410,8 +1392,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
 cross_attention_mask=cross_attention_mask,
 kv_range_for_decode=kv_range_for_decode,
 full_text_row_masked_out_mask=full_text_row_masked_out_mask,
-kv_caches=kv_caches,
-attn_metadata=attn_metadata,
 skip_cross_attention=skip_cross_attention,
 )

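Mllama is the one place in this commit where a layer still needs the raw cache tensor; instead of receiving it as a parameter, `_attention_with_mask` now looks it up on its own `Attention` module and pulls the metadata from the forward context. A rough sketch of that lookup, assuming an `Attention` layer whose `kv_cache` is indexed the same way as in the hunk above (the helper name is hypothetical):

```python
import torch

from vllm.distributed import get_pp_group
from vllm.forward_context import get_forward_context


def fetch_kv_cache_and_metadata(attn: torch.nn.Module):
    # Mirrors the pattern in MllamaTextCrossAttention._attention_with_mask:
    # the Attention module owns its cache, indexed by the pipeline rank the
    # layer recorded at construction time, and the metadata comes from the
    # forward context rather than a forward() argument.
    pipeline_parallel_rank = get_pp_group().rank_in_group
    kv_cache = attn.kv_cache[pipeline_parallel_rank]
    attn_metadata = get_forward_context().attn_metadata
    return kv_cache, attn_metadata
```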
@@ -16,7 +16,7 @@ from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
 from transformers.image_utils import ImageInput
 from transformers.tokenization_utils_base import TextInput

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.attention.layer import MultiHeadAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
@@ -460,15 +460,13 @@ class MolmoAttention(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 qkv, _ = self.qkv_proj(hidden_states)
 q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 if self.q_norm is not None and self.k_norm is not None:
 q, k = self._apply_qk_norm(q, k)
 q, k = self.rotary_emb(positions, q, k)
-attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+attn_output = self.attn(q, k, v)
 output, _ = self.o_proj(attn_output)
 return output

@@ -580,8 +578,6 @@ class MolmoDecoderLayer(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 residual: Optional[torch.Tensor],
 ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
 # Self Attention
@@ -594,8 +590,6 @@ class MolmoDecoderLayer(nn.Module):
 hidden_states = self.self_attn(
 positions=positions,
 hidden_states=hidden_states,
-kv_cache=kv_cache,
-attn_metadata=attn_metadata,
 )

 hidden_states, residual = self.post_attention_layernorm(
@@ -610,8 +604,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 residual: Optional[torch.Tensor],
 ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
 # Self Attention
@@ -619,8 +611,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
 hidden_states = self.self_attn(
 positions=positions,
 hidden_states=hidden_states,
-kv_cache=kv_cache,
-attn_metadata=attn_metadata,
 )

 hidden_states = self.input_layernorm(hidden_states)
@@ -841,8 +831,6 @@ class MolmoModel(nn.Module, SupportsQuant):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
@@ -858,13 +846,10 @@ class MolmoModel(nn.Module, SupportsQuant):
 residual = intermediate_tensors["residual"]

 # Apply blocks one-by-one.
-for i in range(self.start_layer, self.end_layer):
-layer = self.layers[i]
+for layer in self.layers[self.start_layer:self.end_layer]:
 hidden_states, residual = layer(
 positions,
 hidden_states,
-kv_caches[i - self.start_layer],
-attn_metadata,
 residual,
 )
 if not get_pp_group().is_last_rank:
@@ -1643,8 +1628,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
 self,
 input_ids: torch.LongTensor,
 positions: torch.LongTensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 **kwargs: object,
@@ -1663,8 +1646,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,

 hidden_states = self.model(input_ids,
 positions,
-kv_caches,
-attn_metadata,
 intermediate_tensors,
 inputs_embeds=inputs_embeds)

@@ -2,12 +2,12 @@

 # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
 import math
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 import torch.nn as nn

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -125,8 +125,6 @@ class MPTAttention(nn.Module):
 self,
 position_ids: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 del position_ids  # unused.
 qkv, _ = self.Wqkv(hidden_states)
@@ -136,7 +134,7 @@ class MPTAttention(nn.Module):
 if self.qk_ln:
 q = self.q_ln(q)
 k = self.k_ln(k)
-attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+attn_output = self.attn(q, k, v)
 output, _ = self.out_proj(attn_output)
 return output

@@ -196,15 +194,11 @@ class MPTBlock(nn.Module):
 self,
 position_ids: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 x = self.norm_1(hidden_states)
 x = self.attn(
 position_ids=position_ids,
 hidden_states=x,
-kv_cache=kv_cache,
-attn_metadata=attn_metadata,
 )
 hidden_states = hidden_states + x
 x = self.norm_2(hidden_states)
@@ -253,8 +247,6 @@ class MPTModel(nn.Module):
 self,
 input_ids: torch.Tensor,
 position_ids: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors],
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -267,14 +259,8 @@ class MPTModel(nn.Module):
 assert intermediate_tensors is not None
 hidden_states = intermediate_tensors["hidden_states"]

-for i in range(self.start_layer, self.end_layer):
-block = self.blocks[i]
-hidden_states = block(
-position_ids,
-hidden_states,
-kv_caches[i - self.start_layer],
-attn_metadata,
-)
+for block in self.blocks[self.start_layer:self.end_layer]:
+hidden_states = block(position_ids, hidden_states)
 if not get_pp_group().is_last_rank:
 return IntermediateTensors({"hidden_states": hidden_states})
 hidden_states = self.norm_f(hidden_states)
@@ -306,14 +292,11 @@ class MPTForCausalLM(nn.Module, SupportsPP):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:
-hidden_states = self.transformer(input_ids, positions, kv_caches,
-attn_metadata, intermediate_tensors,
-inputs_embeds)
+hidden_states = self.transformer(input_ids, positions,
+intermediate_tensors, inputs_embeds)
 return hidden_states

 def compute_logits(
@@ -27,7 +27,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 import torch
 from torch import nn

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -204,13 +204,11 @@ class NemotronAttention(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 qkv, _ = self.qkv_proj(hidden_states)
 q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 q, k = self.rotary_emb(positions, q, k)
-attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+attn_output = self.attn(q, k, v)
 output, _ = self.o_proj(attn_output)
 return output

@@ -269,8 +267,6 @@ class NemotronDecoderLayer(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 residual: Optional[torch.Tensor],
 ) -> Tuple[torch.Tensor, torch.Tensor]:
 # Self Attention
@@ -283,8 +279,6 @@ class NemotronDecoderLayer(nn.Module):
 hidden_states = self.self_attn(
 positions=positions,
 hidden_states=hidden_states,
-kv_cache=kv_cache,
-attn_metadata=attn_metadata,
 )

 # Fully Connected
@@ -343,8 +337,6 @@ class NemotronModel(nn.Module):
 self,
 input_ids: Optional[torch.Tensor],
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors],
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -359,15 +351,8 @@ class NemotronModel(nn.Module):
 hidden_states = intermediate_tensors["hidden_states"]
 residual = intermediate_tensors["residual"]

-for i in range(self.start_layer, self.end_layer):
-layer = self.layers[i]
-hidden_states, residual = layer(
-positions,
-hidden_states,
-kv_caches[i - self.start_layer],
-attn_metadata,
-residual,
-)
+for layer in self.layers[self.start_layer:self.end_layer]:
+hidden_states, residual = layer(positions, hidden_states, residual)

 if not get_pp_group().is_last_rank:
 return IntermediateTensors({
@@ -444,13 +429,10 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:
-model_output = self.model(input_ids, positions, kv_caches,
-attn_metadata, intermediate_tensors,
+model_output = self.model(input_ids, positions, intermediate_tensors,
 inputs_embeds)
 return model_output

@ -22,13 +22,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Inference-only OLMo model compatible with HuggingFace weights."""
|
"""Inference-only OLMo model compatible with HuggingFace weights."""
|
||||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
from typing import Iterable, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import OlmoConfig
|
 from transformers import OlmoConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size

@@ -119,15 +119,13 @@ class OlmoAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         if self.clip_qkv is not None:
             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@@ -212,14 +210,11 @@ class OlmoDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
         # Attention block.
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-        hidden_states = self.self_attn(positions, hidden_states, kv_cache,
-                                       attn_metadata)
+        hidden_states = self.self_attn(positions, hidden_states)
         hidden_states = hidden_states + residual

         # MLP block.
@@ -263,8 +258,6 @@ class OlmoModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -281,14 +274,9 @@ class OlmoModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]

         # Apply blocks one-by-one.
-        for i in range(self.start_layer, self.end_layer):
+        for layer in self.layers[self.start_layer:self.end_layer]:
             # shape: (batch_size, seq_len, d_model)
-            hidden_states = self.layers[i](
-                positions,
-                hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
-            )
+            hidden_states = layer(positions, hidden_states)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
@@ -332,16 +320,12 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
             intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
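The pattern above repeats throughout this commit: `Attention.forward` is now called with only `q`, `k`, `v`, so the per-layer KV cache and the `AttentionMetadata` object no longer need to be threaded through every module signature. The sketch below only illustrates how such state can be supplied out-of-band through a context variable; it is not vLLM's actual mechanism (vLLM has its own forward-context helper), and every name in it is hypothetical.

```python
# Illustrative sketch only: a context-variable channel that lets an attention
# layer read per-step metadata without it being threaded through forward().
from contextlib import contextmanager
from contextvars import ContextVar

import torch
from torch import nn

_FORWARD_CTX: ContextVar = ContextVar("forward_ctx", default=None)


@contextmanager
def forward_context(**kwargs):
    # Publish per-step state once; reset it when the step finishes.
    token = _FORWARD_CTX.set(kwargs)
    try:
        yield
    finally:
        _FORWARD_CTX.reset(token)


class TinyAttention(nn.Module):
    """Attention stub that pulls metadata from the ambient context."""

    def forward(self, q, k, v):
        ctx = _FORWARD_CTX.get() or {}
        scale = ctx.get("scale", 1.0)  # stands in for attn_metadata / kv_cache
        attn = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1)
        return attn @ v


if __name__ == "__main__":
    attn = TinyAttention()
    q = k = v = torch.randn(2, 4, 8)
    with forward_context(scale=8 ** -0.5):
        out = attn(q, k, v)     # no kv_cache / attn_metadata arguments
    print(out.shape)            # torch.Size([2, 4, 8])
```

With a channel like this, the runner publishes the per-step metadata once and every attention layer in the stack reads it, which is what makes the signature trimming in these hunks possible without changing behaviour.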
@@ -24,12 +24,12 @@
 """Inference-only OLMo2 model compatible with HuggingFace weights."""

 from functools import partial
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, Optional, Tuple, Union

 import torch
 from torch import nn

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.distributed.communication_op import tensor_model_parallel_all_gather
@@ -153,14 +153,12 @@ class Olmo2Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@@ -239,13 +237,10 @@ class Olmo2DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         # Attention block.
         residual = hidden_states
-        hidden_states = self.self_attn(positions, hidden_states, kv_cache,
-                                       attn_metadata)
+        hidden_states = self.self_attn(positions, hidden_states)
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = hidden_states + residual

@@ -287,8 +282,6 @@ class Olmo2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
     ) -> Union[torch.Tensor, IntermediateTensors]:
         """
@@ -307,14 +300,9 @@ class Olmo2Model(nn.Module):
             assert isinstance(hidden_states, torch.Tensor)

         # Apply blocks one-by-one.
-        for i in range(self.start_layer, self.end_layer):
+        for layer in self.layers[self.start_layer:self.end_layer]:
             # shape: (batch_size, seq_len, d_model)
-            hidden_states = self.layers[i](
-                positions,
-                hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
-            )
+            hidden_states = layer(positions, hidden_states)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
@@ -357,15 +345,11 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
             intermediate_tensors=intermediate_tensors,
         )
         return hidden_states
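With this many signatures trimmed by hand, a quick audit helps confirm nothing was missed. The helper below is my own sketch, not part of this change: it flags any `forward` that still declares the removed parameters.

```python
# Hypothetical migration check: report nn.Module classes whose forward()
# still declares the parameters removed in this commit.
import inspect

from torch import nn


def stale_forward_params(module_cls,
                         removed=("kv_cache", "kv_caches", "attn_metadata")):
    params = inspect.signature(module_cls.forward).parameters
    return [name for name in removed if name in params]


class OldStyle(nn.Module):
    def forward(self, positions, hidden_states, kv_cache, attn_metadata):
        return hidden_states


class NewStyle(nn.Module):
    def forward(self, positions, hidden_states):
        return hidden_states


print(stale_forward_params(OldStyle))  # ['kv_cache', 'attn_metadata']
print(stale_forward_params(NewStyle))  # []
```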
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only OLMoE model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import PretrainedConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -168,14 +168,12 @@ class OlmoeAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@@ -222,8 +220,6 @@ class OlmoeDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -237,8 +233,6 @@ class OlmoeDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )

         # Fully Connected
@@ -283,8 +277,6 @@ class OlmoeModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -299,13 +291,10 @@ class OlmoeModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
+        for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
                 residual,
             )

@@ -347,13 +336,10 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states
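Note that call sites such as `self.model(input_ids, positions, intermediate_tensors, inputs_embeds)` pass their arguments positionally, so the callers have to change in the same commit as the signatures: with two parameters gone, a stale positional call does not necessarily fail loudly, it can simply bind the wrong values. A toy illustration of that failure mode, using hypothetical functions rather than the real models:

```python
# Toy illustration: removing positional parameters shifts later arguments,
# so every positional caller must be updated together with the signature.
def old_model(input_ids, positions, kv_caches, attn_metadata,
              intermediate_tensors, inputs_embeds=None):
    return intermediate_tensors


def new_model(input_ids, positions, intermediate_tensors, inputs_embeds=None):
    return intermediate_tensors


ids, pos, caches, meta, inter, emb = "ids", "pos", "caches", "meta", "inter", "emb"

# A leftover old-style call with four positional args raises no error against
# the new signature -- `caches` silently lands in intermediate_tensors.
assert new_model(ids, pos, caches, meta) == "caches"

# The updated call, matching the hunks above, binds correctly.
assert new_model(ids, pos, inter, emb) == "inter"
print("positional call sites updated in lockstep")
```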
@@ -18,13 +18,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only OPT model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import OPTConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -107,12 +107,10 @@ class OPTAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.out_proj(attn_output)
         return output

@@ -164,17 +162,13 @@ class OPTDecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
         # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
         if self.do_layer_norm_before:
             hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states = self.self_attn(hidden_states=hidden_states,
-                                       kv_cache=kv_cache,
-                                       attn_metadata=attn_metadata)
+        hidden_states = self.self_attn(hidden_states=hidden_states)
         hidden_states = residual + hidden_states
         # 350m applies layer norm AFTER attention
         if not self.do_layer_norm_before:
@@ -261,8 +255,6 @@ class OPTDecoder(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -277,11 +269,8 @@ class OPTDecoder(nn.Module):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]

-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states = layer(hidden_states,
-                                  kv_caches[i - self.start_layer],
-                                  attn_metadata)
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states = layer(hidden_states)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
@@ -317,15 +306,11 @@ class OPTModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         return self.decoder(input_ids,
                             positions,
-                            kv_caches,
-                            attn_metadata,
                             intermediate_tensors,
                             inputs_embeds=inputs_embeds)

@@ -362,13 +347,10 @@ class OPTForCausalLM(nn.Module, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states
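The loop rewrite relies on `nn.ModuleList` slicing returning the same layer modules that the old index arithmetic selected, so iterating `self.layers[self.start_layer:self.end_layer]` is behaviour-preserving. A small self-contained check of that equivalence with toy layers:

```python
# Toy check: iterating a ModuleList slice visits the same modules as the
# old index-based loop over [start_layer, end_layer).
import torch
from torch import nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(6)])
start_layer, end_layer = 2, 5

x = torch.randn(3, 8)

h_indexed = x
for i in range(start_layer, end_layer):
    h_indexed = layers[i](h_indexed)

h_sliced = x
for layer in layers[start_layer:end_layer]:
    h_sliced = layer(h_sliced)

assert torch.allclose(h_indexed, h_sliced)
print("index loop and slice loop are equivalent")
```

The only thing the index previously bought was `kv_caches[i - self.start_layer]`; once that lookup is gone, the plain slice iteration is the more idiomatic form.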
@@ -5,13 +5,13 @@
 # Copyright (c) OrionStar Inc.
 # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
 """Inference-only Orion-14B model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import PretrainedConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -136,13 +136,11 @@ class OrionAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@@ -189,8 +187,6 @@ class OrionDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         residual = hidden_states
@@ -198,8 +194,6 @@ class OrionDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )

         hidden_states = residual + hidden_states
@@ -247,8 +241,6 @@ class OrionModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -260,14 +252,8 @@ class OrionModel(nn.Module):
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states = layer(
-                positions,
-                hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
-            )
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states = layer(positions, hidden_states)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
@@ -303,13 +289,10 @@ class OrionForCausalLM(nn.Module, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states
@@ -1,13 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
+from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)

 import torch
 from torch import nn
 from transformers import PaliGemmaConfig

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
@@ -288,8 +287,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
-                kv_caches: List[torch.Tensor],
-                attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]:
@@ -306,8 +303,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
-                                                  kv_caches,
-                                                  attn_metadata,
                                                   intermediate_tensors,
                                                   inputs_embeds=inputs_embeds)
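PaliGemma (and the other multimodal wrappers further down) never touches attention state itself; its `forward` only merges embeddings and delegates to the wrapped language model, so the wrapper and the inner model shed the same two arguments. A stripped-down sketch of that delegation shape, with toy stand-in classes rather than the real vLLM modules:

```python
# Toy sketch of the wrapper pattern: the conditional-generation class merges
# embeddings and delegates to the language model with the trimmed arguments.
from typing import Optional

import torch
from torch import nn


class TinyLanguageModel(nn.Module):
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, input_ids, positions, intermediate_tensors=None,
                inputs_embeds: Optional[torch.Tensor] = None):
        assert inputs_embeds is not None
        return self.proj(inputs_embeds)


class TinyConditionalGeneration(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = TinyLanguageModel()
        self.embed = nn.Embedding(32, 16)

    def forward(self, input_ids, positions, intermediate_tensors=None,
                inputs_embeds=None, **kwargs):
        if inputs_embeds is None:
            inputs_embeds = self.embed(input_ids)   # image merge elided
        return self.language_model(input_ids, positions,
                                   intermediate_tensors,
                                   inputs_embeds=inputs_embeds)


model = TinyConditionalGeneration()
ids = torch.randint(0, 32, (2, 5))
pos = torch.arange(5).expand(2, 5)
print(model(ids, pos).shape)  # torch.Size([2, 5, 16])
```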
@@ -21,13 +21,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only persimmon model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import PersimmonConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -142,8 +142,6 @@ class PersimmonAttention(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         # [seq_length, 3 x hidden_size]
         qkv, _ = self.query_key_value(hidden_states)
@@ -161,7 +159,7 @@ class PersimmonAttention(nn.Module):
             k = self._merge_heads(k)

         q, k = self.rotary_emb(position_ids, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.dense(attn_output)
         return output

@@ -189,8 +187,6 @@ class PersimmonDecoderLayer(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         residual = hidden_states

@@ -200,8 +196,6 @@ class PersimmonDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         hidden_states = residual + hidden_states

@@ -248,8 +242,6 @@ class PersimmonModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -261,13 +253,8 @@ class PersimmonModel(nn.Module):
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
-        for i in range(self.start_layer, self.end_layer):
-            hidden_states = self.layers[i](
-                positions,
-                hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
-            )
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states = layer(positions, hidden_states)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.final_layernorm(hidden_states)
@@ -298,16 +285,12 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ):
         hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
             intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
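A change like this is cheap to guard with a smoke test that calls a layer with only the remaining arguments; any stale `kv_cache` or `attn_metadata` parameter left in the chain fails the call immediately. The sketch below uses a toy layer because the real vLLM layers need a full engine config to construct; it only shows the shape of such a test.

```python
# Sketch of a signature smoke test against a toy decoder layer; the real
# vLLM layers require a full VllmConfig, so this only demonstrates the shape.
import torch
from torch import nn


class ToyDecoderLayer(nn.Module):
    def __init__(self, hidden: int = 8):
        super().__init__()
        self.norm = nn.LayerNorm(hidden)
        self.mlp = nn.Linear(hidden, hidden)

    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states + self.mlp(self.norm(hidden_states))


def test_forward_takes_only_positions_and_hidden_states():
    layer = ToyDecoderLayer()
    positions = torch.arange(4)
    hidden_states = torch.randn(4, 8)
    out = layer(positions, hidden_states)
    assert out.shape == hidden_states.shape


if __name__ == "__main__":
    test_forward_takes_only_positions_and_hidden_states()
    print("ok")
```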
@@ -36,13 +36,13 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 """Inference-only Phi-1.5 model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import PhiConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -126,13 +126,11 @@ class PhiAttention(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.dense(attn_output)
         return output

@@ -186,16 +184,12 @@ class PhiLayer(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         attn_outputs = self.self_attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         feed_forward_hidden_states = self.mlp(hidden_states)
         hidden_states = attn_outputs + feed_forward_hidden_states + residual
@@ -234,8 +228,6 @@ class PhiModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -247,14 +239,8 @@ class PhiModel(nn.Module):
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states = layer(
-                positions,
-                hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
-            )
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states = layer(positions, hidden_states)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
@@ -304,13 +290,10 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)

         return hidden_states
@@ -1,13 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0

 import math
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -231,8 +231,6 @@ class Phi3SmallSelfAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
                Optional[Tuple[torch.Tensor]]]:
         qkv, _ = self.query_key_value(hidden_states)
@@ -248,7 +246,7 @@ class Phi3SmallSelfAttention(nn.Module):
         v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)

         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.dense(attn_output)

         return output
@@ -282,8 +280,6 @@ class Phi3SmallDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -291,8 +287,6 @@ class Phi3SmallDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         hidden_states = residual + hidden_states

@@ -338,8 +332,6 @@ class Phi3SmallModel(nn.Module):
         self,
         input_ids: torch.LongTensor,
         positions: Optional[torch.LongTensor],
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor],
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -354,14 +346,8 @@ class Phi3SmallModel(nn.Module):
         else:
             assert intermediate_tensors
             hidden_states = intermediate_tensors["hidden_states"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states = layer(
-                positions,
-                hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
-            )
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states = layer(positions, hidden_states)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.final_layernorm(hidden_states)
@@ -438,16 +424,12 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
         self,
         input_ids: torch.LongTensor,
         positions: Optional[torch.LongTensor],
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         output_hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
             intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
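The pipeline-parallel handshake is untouched by these hunks: a non-final rank still returns `IntermediateTensors({"hidden_states": ...})` instead of finished hidden states, and a non-first rank still reads its input from `intermediate_tensors`. A self-contained toy of that control flow, with the rank checks reduced to booleans (the real code asks `get_pp_group()`):

```python
# Toy of the pipeline-parallel passthrough retained by these hunks: non-last
# ranks hand hidden states forward, the last rank finishes the forward pass.
from typing import Optional

import torch
from torch import nn


class IntermediateTensors(dict):
    """Minimal stand-in for vllm.sequence.IntermediateTensors."""


class ToyStageModel(nn.Module):
    def __init__(self, is_first_rank: bool, is_last_rank: bool, hidden=8):
        super().__init__()
        self.is_first_rank = is_first_rank
        self.is_last_rank = is_last_rank
        self.embed = nn.Embedding(32, hidden)
        self.layer = nn.Linear(hidden, hidden)
        self.final_norm = nn.LayerNorm(hidden)

    def forward(self, input_ids, positions,
                intermediate_tensors: Optional[IntermediateTensors] = None):
        if self.is_first_rank:
            hidden_states = self.embed(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        hidden_states = self.layer(hidden_states)
        if not self.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        return self.final_norm(hidden_states)


ids = torch.randint(0, 32, (2, 4))
pos = torch.arange(4)
stage0 = ToyStageModel(is_first_rank=True, is_last_rank=False)
stage1 = ToyStageModel(is_first_rank=False, is_last_rank=True)
handoff = stage0(ids, pos)
print(type(handoff).__name__)           # IntermediateTensors
print(stage1(ids, pos, handoff).shape)  # torch.Size([2, 4, 8])
```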
@@ -23,7 +23,6 @@ import torch.nn as nn
 from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
                           ProcessorMixin)

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -672,8 +671,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
-                kv_caches: List[torch.Tensor],
-                attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs: object):
@@ -691,8 +688,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
-                                                  kv_caches,
-                                                  attn_metadata,
                                                   intermediate_tensors,
                                                   inputs_embeds=inputs_embeds)
@@ -22,13 +22,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only PhiMoE model."""
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -357,13 +357,11 @@ class PhiMoEAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output

@@ -410,8 +408,6 @@ class PhiMoEDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         residual = hidden_states
@@ -422,8 +418,6 @@ class PhiMoEDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         hidden_states = hidden_states + residual

@@ -478,8 +472,6 @@ class PhiMoEModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -494,13 +486,10 @@ class PhiMoEModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
+        for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i - self.start_layer],
-                attn_metadata,
                 residual,
             )

@@ -571,13 +560,10 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states
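PhiMoE, like OLMoE above and QWen below, threads a residual tensor through its decoder layers: each layer receives `(hidden_states, residual)` and returns the updated pair, and that interface survives the cleanup with only the KV arguments removed from the loop. A toy version of the threaded-residual pattern:

```python
# Toy of the (hidden_states, residual) threading used by the MoE-style
# decoder layers above; the KV-cache arguments are simply gone from the loop.
from typing import Optional, Tuple

import torch
from torch import nn


class ToyResidualLayer(nn.Module):
    def __init__(self, hidden: int = 8):
        super().__init__()
        self.norm = nn.LayerNorm(hidden)
        self.mlp = nn.Linear(hidden, hidden)

    def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor,
                residual: Optional[torch.Tensor]
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states = self.norm(hidden_states + residual)
            residual = hidden_states
        return self.mlp(hidden_states), residual


layers = nn.ModuleList([ToyResidualLayer() for _ in range(4)])
positions = torch.arange(6)
hidden_states = torch.randn(6, 8)
residual = None
for layer in layers:
    hidden_states, residual = layer(positions, hidden_states, residual)
print(hidden_states.shape, residual.shape)
```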
@@ -16,7 +16,6 @@ from transformers.models.pixtral.image_processing_pixtral import (
 from transformers.models.pixtral.modeling_pixtral import (
     PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
@@ -270,8 +269,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
@@ -291,8 +288,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
-                                                  kv_caches,
-                                                  attn_metadata,
                                                   intermediate_tensors,
                                                   inputs_embeds=inputs_embeds)
@@ -15,13 +15,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only IBM/NASA Prithvi Geospatial model."""
-from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
+from typing import Iterable, Mapping, Optional, Set, Tuple, Union

 import torch
 import torch.nn as nn
 from transformers import BatchFeature

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (IsAttentionFree,
@@ -181,8 +180,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
         self,
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
@ -6,13 +6,13 @@
|
|||||||
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
|
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
|
||||||
"""Inference-only QWen model compatible with HuggingFace weights."""
|
"""Inference-only QWen model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||||
@ -124,13 +124,11 @@ class QWenAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
qkv, _ = self.c_attn(hidden_states)
|
qkv, _ = self.c_attn(hidden_states)
|
||||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||||
q, k = self.rotary_emb(positions, q, k)
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
attn_output = self.attn(q, k, v)
|
||||||
output, _ = self.c_proj(attn_output)
|
output, _ = self.c_proj(attn_output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -168,8 +166,6 @@ class QWenBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Self Attention
|
# Self Attention
|
||||||
@ -181,8 +177,6 @@ class QWenBlock(nn.Module):
|
|||||||
hidden_states = self.attn(
|
hidden_states = self.attn(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
kv_cache=kv_cache,
|
|
||||||
attn_metadata=attn_metadata,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
@ -225,8 +219,6 @@ class QWenModel(nn.Module):
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: List[torch.Tensor],
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors],
|
intermediate_tensors: Optional[IntermediateTensors],
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
@ -241,13 +233,10 @@ class QWenModel(nn.Module):
|
|||||||
hidden_states = intermediate_tensors["hidden_states"]
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
residual = intermediate_tensors["residual"]
|
residual = intermediate_tensors["residual"]
|
||||||
|
|
||||||
for i in range(self.start_layer, self.end_layer):
|
for layer in self.h[self.start_layer:self.end_layer]:
|
||||||
layer = self.h[i]
|
|
||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(
|
||||||
positions,
|
positions,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
kv_caches[i - self.start_layer],
|
|
||||||
attn_metadata,
|
|
||||||
residual,
|
residual,
|
||||||
)
|
)
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
@ -373,12 +362,9 @@ class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: List[torch.Tensor],
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
hidden_states = self.transformer(input_ids, positions,
|
||||||
attn_metadata, intermediate_tensors,
|
intermediate_tensors, inputs_embeds)
|
||||||
inputs_embeds)
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
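
The hunks above and below all make the same mechanical change: `kv_cache`/`kv_caches` and `attn_metadata` disappear from the `forward` signatures, and the attention call becomes `self.attn(q, k, v)`. The layer is expected to obtain its per-step state from an ambient forward context instead (see the `from vllm.forward_context import get_forward_context` import added further down for Ultravox). The following self-contained sketch mimics that pattern with `contextvars`; `ToyForwardContext`, `set_toy_forward_context` and `ToyAttention` are illustrative stand-ins, not vLLM's real API.

```python
# Illustrative sketch only: mimics the "ambient forward context" pattern with
# contextvars instead of vLLM's real vllm.forward_context module.
import contextlib
import contextvars
from dataclasses import dataclass

import torch
import torch.nn as nn

_forward_ctx: contextvars.ContextVar["ToyForwardContext"] = contextvars.ContextVar("forward_ctx")


@dataclass
class ToyForwardContext:
    # Stand-in for the per-step state (attn_metadata, kv cache) that used to be
    # passed explicitly through every forward().
    attn_metadata: dict
    kv_cache: torch.Tensor


@contextlib.contextmanager
def set_toy_forward_context(ctx: ToyForwardContext):
    token = _forward_ctx.set(ctx)
    try:
        yield
    finally:
        _forward_ctx.reset(token)


class ToyAttention(nn.Module):
    """Attention layer whose forward() no longer takes kv_cache/attn_metadata."""

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        ctx = _forward_ctx.get()  # fetch per-step state from the ambient context
        _ = ctx.kv_cache, ctx.attn_metadata  # a real kernel would use these
        scores = torch.softmax(q @ k.transpose(-1, -2) / k.shape[-1] ** 0.5, dim=-1)
        return scores @ v


if __name__ == "__main__":
    attn = ToyAttention()
    q = k = v = torch.randn(4, 8)
    ctx = ToyForwardContext(attn_metadata={"num_tokens": 4}, kv_cache=torch.zeros(1))
    with set_toy_forward_context(ctx):
        out = attn(q, k, v)  # same call shape as the post-change self.attn(q, k, v)
    print(out.shape)
```
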
@@ -23,13 +23,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import Qwen2Config

-from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size

@@ -170,13 +170,11 @@ class Qwen2Attention(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 qkv, _ = self.qkv_proj(hidden_states)
 q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 q, k = self.rotary_emb(positions, q, k)
-attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+attn_output = self.attn(q, k, v)
 output, _ = self.o_proj(attn_output)
 return output

@@ -233,8 +231,6 @@ class Qwen2DecoderLayer(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 residual: Optional[torch.Tensor],
 ) -> Tuple[torch.Tensor, torch.Tensor]:
 # Self Attention

@@ -247,8 +243,6 @@ class Qwen2DecoderLayer(nn.Module):
 hidden_states = self.self_attn(
 positions=positions,
 hidden_states=hidden_states,
-kv_cache=kv_cache,
-attn_metadata=attn_metadata,
 )

 # Fully Connected

@@ -328,8 +322,6 @@ class Qwen2Model(nn.Module):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:

@@ -343,13 +335,10 @@ class Qwen2Model(nn.Module):
 assert intermediate_tensors is not None
 hidden_states = intermediate_tensors["hidden_states"]
 residual = intermediate_tensors["residual"]
-for i in range(self.start_layer, self.end_layer):
-layer = self.layers[i]
+for layer in self.layers[self.start_layer:self.end_layer]:
 hidden_states, residual = layer(
 positions,
 hidden_states,
-kv_caches[i - self.start_layer],
-attn_metadata,
 residual,
 )
 if not get_pp_group().is_last_rank:

@@ -468,13 +457,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:
-hidden_states = self.model(input_ids, positions, kv_caches,
-attn_metadata, intermediate_tensors,
+hidden_states = self.model(input_ids, positions, intermediate_tensors,
 inputs_embeds)
 return hidden_states

@@ -553,12 +539,9 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 ) -> torch.Tensor:
-return self.model(input_ids, positions, kv_caches, attn_metadata,
-intermediate_tensors)
+return self.model(input_ids, positions, intermediate_tensors)

 def pooler(
 self,
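
A second recurring change in these hunks is the layer loop: once the per-layer `kv_caches[i - self.start_layer]` lookup is gone, the index is unused, so `for i in range(self.start_layer, self.end_layer)` plus `layer = self.layers[i]` collapses into iteration over a slice of the `ModuleList`. A minimal sketch of the equivalence follows; `DummyLayer` is a hypothetical stand-in for the real decoder layer.

```python
# Illustrative sketch of the loop refactor: once per-layer kv_caches are gone,
# nothing in the body needs the index, so slicing the ModuleList is enough.
import torch
import torch.nn as nn


class DummyLayer(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor,
                residual: torch.Tensor):
        hidden_states = self.proj(hidden_states) + residual
        return hidden_states, hidden_states


layers = nn.ModuleList(DummyLayer() for _ in range(6))
start_layer, end_layer = 1, 5
positions = torch.arange(4)
hidden_states = torch.randn(4, 8)
residual = torch.zeros_like(hidden_states)

# Before (index-based, only needed for kv_caches[i - start_layer]):
#     for i in range(start_layer, end_layer):
#         layer = layers[i]
#         hidden_states, residual = layer(positions, hidden_states, ...)
# After (the form used throughout this diff):
for layer in layers[start_layer:end_layer]:
    hidden_states, residual = layer(positions, hidden_states, residual)
```
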
@@ -37,7 +37,6 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
 Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils

@@ -992,8 +991,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 **kwargs: object,

@@ -1047,8 +1044,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 hidden_states = self.language_model.model(
 input_ids=input_ids,
 positions=positions,
-kv_caches=kv_caches,
-attn_metadata=attn_metadata,
 intermediate_tensors=intermediate_tensors,
 inputs_embeds=inputs_embeds,
 )
@@ -22,8 +22,8 @@
 # limitations under the License.
 """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
 from functools import cached_property
-from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
-TypedDict, Union)
+from typing import (Any, Iterable, Mapping, Optional, Set, Tuple, TypedDict,
+Union)

 import torch
 import torch.nn as nn

@@ -33,7 +33,6 @@ from transformers.models.qwen2_audio import (Qwen2AudioConfig,
 Qwen2AudioProcessor)
 from transformers.models.whisper import WhisperFeatureExtractor

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata

@@ -380,8 +379,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 **kwargs: object,

@@ -400,8 +397,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,

 hidden_states = self.language_model.model(input_ids,
 positions,
-kv_caches,
-attn_metadata,
 intermediate_tensors,
 inputs_embeds=inputs_embeds)
 return hidden_states
@@ -23,14 +23,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union

 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group,

@@ -232,13 +232,11 @@ class Qwen2MoeAttention(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 qkv, _ = self.qkv_proj(hidden_states)
 q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 q, k = self.rotary_emb(positions, q, k)
-attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+attn_output = self.attn(q, k, v)
 output, _ = self.o_proj(attn_output)
 return output

@@ -296,8 +294,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 residual: Optional[torch.Tensor],
 ) -> torch.Tensor:
 # Self Attention

@@ -310,8 +306,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
 hidden_states = self.self_attn(
 positions=positions,
 hidden_states=hidden_states,
-kv_cache=kv_cache,
-attn_metadata=attn_metadata,
 )

 # Fully Connected

@@ -358,8 +352,6 @@ class Qwen2MoeModel(nn.Module):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:

@@ -373,11 +365,8 @@ class Qwen2MoeModel(nn.Module):
 assert intermediate_tensors is not None
 hidden_states = intermediate_tensors["hidden_states"]
 residual = intermediate_tensors["residual"]
-for i in range(self.start_layer, self.end_layer):
-layer = self.layers[i]
-hidden_states, residual = layer(positions, hidden_states,
-kv_caches[i - self.start_layer],
-attn_metadata, residual)
+for layer in self.layers[self.start_layer:self.end_layer]:
+hidden_states, residual = layer(positions, hidden_states, residual)
 if not get_pp_group().is_last_rank:
 return IntermediateTensors({
 "hidden_states": hidden_states,

@@ -416,13 +405,10 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:
-hidden_states = self.model(input_ids, positions, kv_caches,
-attn_metadata, intermediate_tensors,
+hidden_states = self.model(input_ids, positions, intermediate_tensors,
 inputs_embeds)
 return hidden_states
@@ -5,12 +5,11 @@
 # Copyright 2024 The Qwen team.
 # Copyright 2023 The vLLM team.
 """Inference-only Qwen2-RM model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 RowParallelLinear)

@@ -80,13 +79,10 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:
-hidden_states = self.model(input_ids, positions, kv_caches,
-attn_metadata, intermediate_tensors,
+hidden_states = self.model(input_ids, positions, intermediate_tensors,
 inputs_embeds)
 logits, _ = self.score(hidden_states)
 return logits
@@ -24,8 +24,8 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from functools import cached_property, partial
-from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
-Set, Tuple, Type, TypedDict, Union)
+from typing import (Any, Callable, Iterable, Literal, Mapping, Optional, Set,
+Tuple, Type, TypedDict, Union)

 import torch
 import torch.nn as nn

@@ -38,7 +38,6 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
 Qwen2VLConfig, Qwen2VLVisionConfig)
 from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils

@@ -1302,8 +1301,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 **kwargs: object,

@@ -1354,8 +1351,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 hidden_states = self.language_model.model(
 input_ids=input_ids,
 positions=positions,
-kv_caches=kv_caches,
-attn_metadata=attn_metadata,
 intermediate_tensors=intermediate_tensors,
 inputs_embeds=inputs_embeds,
 )
@@ -22,7 +22,6 @@ from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer,
 from transformers.image_utils import ImageInput
 from transformers.tokenization_utils_base import TextInput

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,

@@ -766,8 +765,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 **kwargs: object,

@@ -783,7 +780,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
 vision_embeddings)
 input_ids = None

-hidden_states = self.transformer(input_ids, positions, kv_caches,
-attn_metadata, intermediate_tensors,
-inputs_embeds)
+hidden_states = self.transformer(input_ids, positions,
+intermediate_tensors, inputs_embeds)
 return hidden_states
@@ -1,13 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0

 import itertools
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, Optional, Tuple

 import torch
 from torch import nn
 from transformers import RobertaConfig

-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.pooler import CrossEncodingPooler
 from vllm.model_executor.layers.vocab_parallel_embedding import (

@@ -243,16 +242,12 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
 self,
 input_ids: Optional[torch.Tensor],
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 token_type_ids: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 return self.roberta(input_ids=input_ids,
 position_ids=positions,
-kv_caches=kv_caches,
 inputs_embeds=inputs_embeds,
 intermediate_tensors=intermediate_tensors,
-attn_metadata=attn_metadata,
 token_type_ids=token_type_ids)
@@ -23,13 +23,13 @@
 # limitations under the License.
 """Inference-only Solar model compatible with HuggingFace weights."""

-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import PretrainedConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size

@@ -172,13 +172,11 @@ class SolarAttention(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 qkv, _ = self.qkv_proj(hidden_states)
 q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 q, k = self.rotary_emb(positions, q, k)
-attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+attn_output = self.attn(q, k, v)
 output, _ = self.o_proj(attn_output)
 return output

@@ -238,8 +236,6 @@ class SolarDecoderLayer(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 residual: Optional[torch.Tensor],
 ) -> Tuple[torch.Tensor, torch.Tensor]:
 # Self Attention

@@ -252,8 +248,6 @@ class SolarDecoderLayer(nn.Module):
 hidden_states = self.self_attn(
 positions=positions,
 hidden_states=hidden_states,
-kv_cache=kv_cache,
-attn_metadata=attn_metadata,
 )

 # Fully Connected

@@ -315,8 +309,6 @@ class SolarModel(nn.Module):
 self,
 input_ids: Optional[torch.Tensor],
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors],
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:

@@ -357,8 +349,6 @@ class SolarModel(nn.Module):
 hidden_states, residual = layer(
 positions,
 hidden_states,
-kv_caches[i - self.start_layer],
-attn_metadata,
 residual,
 )

@@ -438,13 +428,10 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:
-model_output = self.model(input_ids, positions, kv_caches,
-attn_metadata, intermediate_tensors,
+model_output = self.model(input_ids, positions, intermediate_tensors,
 inputs_embeds)
 return model_output
@@ -20,13 +20,13 @@
 # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
 """Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
 model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import StableLmConfig

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul

@@ -147,13 +147,11 @@ class StablelmAttention(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 qkv, _ = self.qkv_proj(hidden_states)
 q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 q, k = self.rotary_emb(positions, q, k)
-attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+attn_output = self.attn(q, k, v)
 output, _ = self.o_proj(attn_output)
 return output

@@ -183,8 +181,6 @@ class StablelmDecoderLayer(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
 # Self Attention
 residual = hidden_states

@@ -192,8 +188,6 @@ class StablelmDecoderLayer(nn.Module):
 hidden_states = self.self_attn(
 positions=positions,
 hidden_states=hidden_states,
-kv_cache=kv_cache,
-attn_metadata=attn_metadata,
 )
 hidden_states = residual + hidden_states

@@ -241,8 +235,6 @@ class StableLMEpochModel(nn.Module):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors],
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:

@@ -254,14 +246,8 @@ class StableLMEpochModel(nn.Module):
 else:
 assert intermediate_tensors is not None
 hidden_states = intermediate_tensors["hidden_states"]
-for i in range(self.start_layer, self.end_layer):
-layer = self.layers[i]
-hidden_states, residual = layer(
-positions,
-hidden_states,
-kv_caches[i - self.start_layer],
-attn_metadata,
-)
+for layer in self.layers[self.start_layer:self.end_layer]:
+hidden_states, residual = layer(positions, hidden_states)
 if not get_pp_group().is_last_rank:
 return IntermediateTensors({"hidden_states": hidden_states})
 hidden_states = self.norm(hidden_states)

@@ -296,13 +282,10 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:
-hidden_states = self.model(input_ids, positions, kv_caches,
-attn_metadata, intermediate_tensors,
+hidden_states = self.model(input_ids, positions, intermediate_tensors,
 inputs_embeds)
 return hidden_states
@@ -19,13 +19,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch Starcoder2 model."""
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import Starcoder2Config

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size

@@ -118,13 +118,11 @@ class Starcoder2Attention(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 qkv, _ = self.qkv_proj(hidden_states)
 q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 q, k = self.rotary_emb(positions, q, k)
-attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+attn_output = self.attn(q, k, v)
 output, _ = self.o_proj(attn_output)
 return output

@@ -184,8 +182,6 @@ class Starcoder2DecoderLayer(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
 # Self Attention
 residual = hidden_states

@@ -193,8 +189,6 @@ class Starcoder2DecoderLayer(nn.Module):
 hidden_states = self.self_attn(
 positions=positions,
 hidden_states=hidden_states,
-kv_cache=kv_cache,
-attn_metadata=attn_metadata,
 )
 hidden_states = residual + hidden_states

@@ -246,8 +240,6 @@ class Starcoder2Model(nn.Module):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors],
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:

@@ -259,11 +251,8 @@ class Starcoder2Model(nn.Module):
 else:
 assert intermediate_tensors is not None
 hidden_states = intermediate_tensors["hidden_states"]
-for i in range(self.start_layer, self.end_layer):
-layer = self.layers[i]
-hidden_states = layer(positions, hidden_states,
-kv_caches[i - self.start_layer],
-attn_metadata)
+for layer in self.layers[self.start_layer:self.end_layer]:
+hidden_states = layer(positions, hidden_states)
 if not get_pp_group().is_last_rank:
 return IntermediateTensors({"hidden_states": hidden_states})
 hidden_states = self.norm(hidden_states)

@@ -306,13 +295,10 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:
-hidden_states = self.model(input_ids, positions, kv_caches,
-attn_metadata, intermediate_tensors,
+hidden_states = self.model(input_ids, positions, intermediate_tensors,
 inputs_embeds)
 return hidden_states
@@ -22,7 +22,7 @@ from torch import nn
 from transformers import AutoModel, PreTrainedModel
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.utils import divide

@@ -59,7 +59,6 @@ def vllm_flash_attention_forward(
 # Transformers kwargs
 scaling: Optional[float] = None,
 # vLLM kwargs
-attn_metadata: Optional[AttentionMetadata] = None,
 attention_instances: Optional[list[Attention]] = None,
 **kwargs):
 self_attn = attention_instances[module.layer_idx]

@@ -68,12 +67,7 @@ def vllm_flash_attention_forward(
 hidden = query.shape[-2]
 query, key, value = (x.transpose(1, 2) for x in (query, key, value))
 query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
-return self_attn.forward(
-query,
-key,
-value,
-kv_cache=None, # argument not used
-attn_metadata=attn_metadata), None
+return self_attn.forward(query, key, value), None


 ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward

@@ -251,8 +245,6 @@ class TransformersModel(nn.Module, SupportsQuant):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: list[torch.Tensor], # argument not used
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, IntermediateTensors]:

@@ -260,7 +252,6 @@ class TransformersModel(nn.Module, SupportsQuant):
 input_ids[None, ...],
 use_cache=False,
 position_ids=positions[None, ...],
-attn_metadata=attn_metadata,
 intermediate_tensors=intermediate_tensors,
 attention_instances=self.attention_instances,
 return_dict=False)[0][0, ...] # we remove batch dimension for now
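
For the Transformers-backend model, the attention override registered under the "vllm" key keeps its Transformers-facing signature but now delegates with `(query, key, value)` only. The sketch below reproduces that shape with stand-ins: `ATTENTION_FUNCTIONS` and `StubAttention` are illustrative placeholders, not the real `ALL_ATTENTION_FUNCTIONS` registry or vLLM's `Attention` layer.

```python
# Illustrative sketch of the attention-override shape after the change: the
# callable registered for the "vllm" key reshapes and delegates to the
# attention layer with (query, key, value) only. Registry and layer are toys.
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.nn as nn

ATTENTION_FUNCTIONS: Dict[str, Callable] = {}  # stand-in for ALL_ATTENTION_FUNCTIONS


class StubAttention(nn.Module):
    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        scores = torch.softmax(q @ k.transpose(-1, -2) / k.shape[-1] ** 0.5, dim=-1)
        return scores @ v


def toy_attention_forward(
    module: nn.Module,
    query: torch.Tensor,  # [batch, heads, tokens, head_dim]
    key: torch.Tensor,
    value: torch.Tensor,
    scaling: Optional[float] = None,
    attention_instances: Optional[list] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    self_attn = attention_instances[module.layer_idx]
    hidden = query.shape[-2]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
    # No kv_cache / attn_metadata arguments any more:
    return self_attn.forward(query, key, value), None


ATTENTION_FUNCTIONS["vllm"] = toy_attention_forward

if __name__ == "__main__":
    layer = nn.Module()
    layer.layer_idx = 0
    q = k = v = torch.randn(1, 2, 4, 8)  # batch=1, heads=2, tokens=4, head_dim=8
    out, _ = ATTENTION_FUNCTIONS["vllm"](layer, q, k, v,
                                         attention_instances=[StubAttention()])
    print(out.shape)
```
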
@@ -4,8 +4,8 @@
 """PyTorch Ultravox model."""
 import math
 from functools import cached_property
-from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
-Tuple, TypedDict, Union)
+from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple,
+TypedDict, Union)

 import torch
 import torch.utils.checkpoint

@@ -16,8 +16,8 @@ from transformers.models.whisper import WhisperFeatureExtractor
 from transformers.models.whisper.modeling_whisper import WhisperEncoder

 from vllm import envs
-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler

@@ -495,13 +495,13 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
 self,
 input_ids: torch.Tensor,
 multimodal_embeddings: Optional[NestedTensors] = None,
-attn_metadata: Optional[AttentionMetadata] = None,
 ) -> torch.Tensor:
 inputs_embeds = self.language_model.get_input_embeddings(input_ids)
 if multimodal_embeddings is not None:

 # TODO(ywang96): remove this block after v0 is deprecated.
 if not envs.VLLM_USE_V1:
+attn_metadata = get_forward_context().attn_metadata
 merge_multimodal_embeddings_from_map(
 inputs_embeds, multimodal_embeddings,
 attn_metadata.multi_modal_placeholder_index_maps["audio"])

@@ -514,8 +514,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
 def forward(self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 intermediate_tensors: Optional[torch.Tensor] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 **kwargs) -> Union[torch.Tensor, IntermediateTensors]:

@@ -540,17 +538,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
 elif inputs_embeds is None:
 multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)

-# TODO(ywang96): remove attn_metadata from get_input_embeddings
-# after v0 is deprecated
 inputs_embeds = self.get_input_embeddings(input_ids,
-multimodal_embeddings,
-attn_metadata)
+multimodal_embeddings)
 input_ids = None

 hidden_states = self.language_model.model(input_ids,
 positions,
-kv_caches,
-attn_metadata,
 intermediate_tensors,
 inputs_embeds=inputs_embeds)
 return hidden_states
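
In the Ultravox hunks, `attn_metadata` is no longer a parameter of `get_input_embeddings`; the legacy (pre-V1) branch fetches it on demand from the forward context instead. The sketch below shows that lazy-fetch pattern in isolation; the context variable, `USE_V1` flag, and embedding logic are simplified stand-ins, not vLLM's implementation.

```python
# Illustrative sketch: attn_metadata is fetched from an ambient context only
# on the legacy code path, instead of being threaded through the signature.
import contextvars
from types import SimpleNamespace

import torch

_forward_ctx = contextvars.ContextVar("forward_ctx")
USE_V1 = False  # stand-in for envs.VLLM_USE_V1


def get_forward_context():
    return _forward_ctx.get()


def get_input_embeddings(input_ids: torch.Tensor,
                         multimodal_embeddings=None) -> torch.Tensor:
    inputs_embeds = torch.nn.functional.one_hot(input_ids, 16).float()
    if multimodal_embeddings is not None and not USE_V1:
        # Only the legacy path needs attn_metadata, so fetch it lazily here.
        attn_metadata = get_forward_context().attn_metadata
        placeholder = attn_metadata.multi_modal_placeholder_index_maps["audio"]
        inputs_embeds[placeholder] = multimodal_embeddings
    return inputs_embeds


if __name__ == "__main__":
    ctx = SimpleNamespace(attn_metadata=SimpleNamespace(
        multi_modal_placeholder_index_maps={"audio": torch.tensor([1, 2])}))
    _forward_ctx.set(ctx)
    ids = torch.arange(4)
    audio = torch.ones(2, 16)
    print(get_input_embeddings(ids, audio).shape)
```
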
@@ -10,7 +10,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
 WhisperProcessor)
 from transformers.models.whisper.modeling_whisper import sinusoids

-from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.attention import Attention, AttentionType
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger

@@ -134,13 +134,11 @@ class WhisperAttention(nn.Module):
 def forward(
 self,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ):
 qkv, _ = self.qkv_proj(hidden_states)
 q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

-attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+attn_output = self.attn(q, k, v)

 output, _ = self.out_proj(attn_output)

@@ -196,8 +194,6 @@ class WhisperCrossAttention(WhisperAttention):
 self,
 hidden_states: torch.Tensor,
 encoder_hidden_states: Optional[torch.Tensor],
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ):
 q, _ = self.q_proj(hidden_states)

@@ -209,13 +205,7 @@ class WhisperCrossAttention(WhisperAttention):
 else:
 k = v = None

-attn_output = self.attn(
-q,
-k,
-v,
-kv_cache,
-attn_metadata,
-)
+attn_output = self.attn(q, k, v)

 output, _ = self.out_proj(attn_output)

@@ -285,16 +275,10 @@ class WhisperEncoderLayer(nn.Module):
 def forward(
 self,
 hidden_states: torch.Tensor,
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ):
 residual = hidden_states
 hidden_states = self.self_attn_layer_norm(hidden_states)
-hidden_states = self.self_attn(
-hidden_states=hidden_states,
-kv_cache=kv_cache,
-attn_metadata=attn_metadata,
-)
+hidden_states = self.self_attn(hidden_states=hidden_states)
 hidden_states = residual + hidden_states
 residual = hidden_states
 hidden_states = self.final_layer_norm(hidden_states)

@@ -348,14 +332,10 @@ class WhisperDecoderLayer(nn.Module):
 self,
 hidden_states: torch.Tensor,
 encoder_hidden_states: Optional[torch.Tensor],
-kv_cache: torch.Tensor,
-attn_metadata: AttentionMetadata,
 ):
 residual = hidden_states
 hidden_states = self.self_attn_layer_norm(hidden_states)
-hidden_states = self.self_attn(hidden_states=hidden_states,
-kv_cache=kv_cache,
-attn_metadata=attn_metadata)
+hidden_states = self.self_attn(hidden_states=hidden_states)
 hidden_states = residual + hidden_states

 residual = hidden_states

@@ -363,8 +343,6 @@ class WhisperDecoderLayer(nn.Module):
 hidden_states = self.encoder_attn(
 hidden_states=hidden_states,
 encoder_hidden_states=encoder_hidden_states,
-kv_cache=kv_cache,
-attn_metadata=attn_metadata,
 )
 hidden_states = residual + hidden_states

@@ -411,12 +389,7 @@ class WhisperEncoder(nn.Module):
 self.embed_positions.weight.copy_(
 sinusoids(*self.embed_positions.weight.shape))

-def forward(
-self,
-input_features: Union[torch.Tensor, List[torch.Tensor]],
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
-):
+def forward(self, input_features: Union[torch.Tensor, List[torch.Tensor]]):
 hidden_states = []
 for features in input_features:
 embeds = nn.functional.gelu(self.conv1(features))

@@ -426,12 +399,8 @@ class WhisperEncoder(nn.Module):
 hidden_states.append(embeds)
 hidden_states = torch.cat(hidden_states)

-for idx, encoder_layer in enumerate(self.layers):
-hidden_states = encoder_layer(
-hidden_states,
-kv_cache=kv_caches[idx],
-attn_metadata=attn_metadata,
-)
+for encoder_layer in self.layers:
+hidden_states = encoder_layer(hidden_states)

 hidden_states = self.layer_norm(hidden_states)
 return hidden_states

@@ -466,19 +435,15 @@ class WhisperDecoder(nn.Module):
 input_ids,
 positions: torch.Tensor,
 encoder_hidden_states: Optional[torch.Tensor],
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 ):
 inputs_embeds = self.get_input_embeddings(input_ids)
 positions = self.embed_positions(positions)
 hidden_states = inputs_embeds + positions

-for idx, decoder_layer in enumerate(self.layers):
+for decoder_layer in self.layers:
 hidden_states = decoder_layer(
 hidden_states,
 encoder_hidden_states=encoder_hidden_states,
-kv_cache=kv_caches[idx],
-attn_metadata=attn_metadata,
 )

 hidden_states = self.layer_norm(hidden_states)

@@ -505,36 +470,22 @@ class WhisperModel(nn.Module):
 input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
 input_ids: Optional[torch.Tensor],
 positions: torch.Tensor,
-kv_caches: List[torch.Tensor],
-attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
-encoder_outputs = self.get_encoder_outputs(
-input_features,
-kv_caches=kv_caches,
-attn_metadata=attn_metadata,
-)
+encoder_outputs = self.get_encoder_outputs(input_features)
 decoder_outputs = self.decoder(
|
decoder_outputs = self.decoder(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
encoder_hidden_states=encoder_outputs,
|
encoder_hidden_states=encoder_outputs,
|
||||||
kv_caches=kv_caches,
|
|
||||||
attn_metadata=attn_metadata,
|
|
||||||
)
|
)
|
||||||
return decoder_outputs
|
return decoder_outputs
|
||||||
|
|
||||||
def get_encoder_outputs(
|
def get_encoder_outputs(
|
||||||
self,
|
self,
|
||||||
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
|
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
|
||||||
kv_caches: List[torch.Tensor],
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
if input_features is None:
|
if input_features is None:
|
||||||
return None
|
return None
|
||||||
return self.encoder(
|
return self.encoder(input_features)
|
||||||
input_features,
|
|
||||||
kv_caches=kv_caches,
|
|
||||||
attn_metadata=attn_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str,
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
torch.Tensor]]) -> Set[str]:
|
torch.Tensor]]) -> Set[str]:
|
||||||
@ -733,8 +684,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: List[torch.Tensor],
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||||
@ -742,31 +691,19 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
|||||||
input_features=audio_input["input_features"],
|
input_features=audio_input["input_features"],
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
kv_caches=kv_caches,
|
|
||||||
attn_metadata=attn_metadata,
|
|
||||||
)
|
)
|
||||||
return decoder_outputs
|
return decoder_outputs
|
||||||
|
|
||||||
def get_multimodal_embeddings(
|
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
||||||
self,
|
|
||||||
kv_caches: List[torch.Tensor],
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
**kwargs,
|
|
||||||
) -> Optional[NestedTensors]:
|
|
||||||
# TODO: This method does not obey the interface for SupportsMultiModal.
|
# TODO: This method does not obey the interface for SupportsMultiModal.
|
||||||
# Refactor this once encoder/decoder support is implemented in V1.
|
# Refactor this once encoder/decoder support is implemented in V1.
|
||||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||||
return self.model.get_encoder_outputs(
|
return self.model.get_encoder_outputs(audio_input["input_features"])
|
||||||
audio_input["input_features"],
|
|
||||||
kv_caches=kv_caches,
|
|
||||||
attn_metadata=attn_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_input_embeddings(
|
def get_input_embeddings(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
multimodal_embeddings: Optional[NestedTensors] = None,
|
multimodal_embeddings: Optional[NestedTensors] = None,
|
||||||
attn_metadata: Optional[AttentionMetadata] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# TODO: This method just returns the decoder sequence embeddings since
|
# TODO: This method just returns the decoder sequence embeddings since
|
||||||
# Whisper does not have encoder text tokens. Refactor this once
|
# Whisper does not have encoder text tokens. Refactor this once
|
||||||
|
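Note on the pattern above: after this change the Whisper layers stop threading `kv_cache`/`attn_metadata` through every `forward()`; the attention layers resolve them internally from the per-step forward context. A minimal sketch of the calling convention, assuming the runner has already built `attn_metadata` and holds a `vllm_config` (the helper name and its arguments are illustrative, not part of this diff):

```python
import torch
from vllm.forward_context import set_forward_context


def run_decode_step(model, input_ids: torch.Tensor, positions: torch.Tensor,
                    attn_metadata, vllm_config) -> torch.Tensor:
    # The metadata is bound to the ambient forward context instead of being
    # passed through the model; forward() now takes only ids and positions.
    with set_forward_context(attn_metadata, vllm_config):
        return model(input_ids=input_ids, positions=positions)
```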
@@ -288,8 +288,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
             hidden_states = model_executable(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
-                kv_caches=kv_caches,
-                attn_metadata=model_input.attn_metadata,
                 intermediate_tensors=intermediate_tensors,
                 **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                              device=self.device),
@@ -939,8 +939,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             hidden_states = self.model(
                 input_ids=input_ids,
                 positions=positions,
-                kv_caches=self.kv_caches,
-                attn_metadata=None,
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
             )

@@ -1137,11 +1135,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _dummy_run(
         self,
         num_tokens: int,
-        kv_caches: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
         model = self.model
-        if kv_caches is None:
-            kv_caches = self.kv_caches
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = self.inputs_embeds[:num_tokens]

@@ -1172,26 +1167,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             hidden_states = model(
                 input_ids=input_ids,
                 positions=positions,
-                kv_caches=kv_caches,
-                attn_metadata=None,
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
             )
         return hidden_states

     def profile_run(self) -> None:
-        # use an empty tensor instead of `None`` to force Dynamo to pass
-        # it by reference, rather by specializing on the value `None`.
-        # the `dtype` argument does not matter, and we use `float32` as
-        # a placeholder (it has wide hardware support).
-        # it is important to create tensors inside the loop, rather than
-        # multiplying the list, to avoid Dynamo from treating them as
-        # tensor aliasing.
-        dummy_kv_caches = [
-            torch.tensor((), dtype=torch.float32, device=self.device)
-            for _ in range(self.num_attn_layers)
-        ]

         # Profile with multimodal encoder & encoder cache.
         # TODO: handle encoder-decoder models once we support them.
         if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0

@@ -1302,8 +1283,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         with self.maybe_profile_with_lora(self.lora_config,
                                           num_scheduled_tokens):
             # Trigger compilation for general shape.
-            hidden_states = self._dummy_run(self.max_num_tokens,
-                                            dummy_kv_caches)
+            hidden_states = self._dummy_run(self.max_num_tokens)
             if get_pp_group().is_last_rank:
                 hidden_states = hidden_states[logit_indices]
                 logits = self.model.compute_logits(hidden_states, None)
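The deleted `dummy_kv_caches` block above existed only so Dynamo would see a tensor argument rather than specialize on `None`. With the caches owned by the runner and the attention metadata coming from the forward context, the profiling pass can call the model with the same trimmed signature as a real step. A hedged sketch of the interface these runners now assume (`ToyModel` is illustrative, not vLLM code):

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Illustrative stand-in for a model after this change."""

    def __init__(self, vocab_size: int = 128, hidden_size: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        # A real model would run its decoder layers here; each attention layer
        # looks up its own KV cache and metadata, so neither appears in the
        # signature any more.
        return self.embed(input_ids) + positions.unsqueeze(-1).float()
```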
@@ -13,11 +13,10 @@ import torch.nn as nn
 import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr

-from vllm.attention import AttentionMetadata
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig
-from vllm.forward_context import set_forward_context
+from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.sampling_params import SamplingType

@@ -623,7 +622,6 @@ class TPUModelRunner:
             assert self.model is not None
             selected_token_ids = self.model(prompt_data.input_tokens,
                                             prompt_data.input_positions,
-                                            prompt_data.attn_metadata,
                                             self.kv_caches)

         # In parallel to TPU execution, prepare the next iteration

@@ -662,7 +660,6 @@ class TPUModelRunner:
             assert self.model is not None
             selected_token_ids = self.model(decode_data.input_tokens,
                                             decode_data.input_positions,
-                                            decode_data.attn_metadata,
                                             self.kv_caches)

         # Transfer sampled tokens from TPU to CPU

@@ -839,7 +836,7 @@ class TPUModelRunner:

         with set_forward_context(attn_metadata, self.vllm_config, 0):
             assert self.model is not None
-            self.model(token_ids, position_ids, attn_metadata, kv_caches)
+            self.model(token_ids, position_ids, kv_caches)

     def capture_model(self) -> None:
         """Compile the model."""

@@ -963,7 +960,6 @@ class ModelWrapperV1(nn.Module):
         self,
         token_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.

@@ -971,7 +967,6 @@ class ModelWrapperV1(nn.Module):
         Args:
             token_ids: The input token IDs of shape [batch_size, seq_len].
             position_ids: The input position IDs of shape [batch_size, seq_len].
-            attn_metadata: The Pallas attention metadata.
             input_lens: The actual input lengths of shape [batch_size].
             t: The sampling temperature of shape [batch_size].
             p: The top-p probability of shape [batch_size].

@@ -980,7 +975,8 @@ class ModelWrapperV1(nn.Module):
            memory profiling at initialization.
         """
         # Skip this in memory profiling at initialization.
-        if attn_metadata is not None and kv_caches[0][0].numel() > 0:
+        if kv_caches[0][0].numel() > 0:
+            attn_metadata = get_forward_context().attn_metadata
             # index_copy_(slot_mapping) only works when the inserted dimension
             # is 0. However, the KV cache in the Pallas backend has the shape
             # [num_kv_heads, num_blocks, block_size, head_size]. To make it

@@ -1001,12 +997,7 @@ class ModelWrapperV1(nn.Module):
             attn_metadata.slot_mapping = slot_mapping

         assert self.model is not None
-        hidden_states = self.model(
-            token_ids,
-            position_ids,
-            kv_caches,
-            attn_metadata,
-        )
+        hidden_states = self.model(token_ids, position_ids)

         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, None)
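The TPU wrapper above shows the lookup that replaces the old `attn_metadata` argument: the metadata is read from the forward context that the runner sets around the call. A minimal sketch of that pattern, assuming the caller wraps the invocation in `set_forward_context(...)` as the runner hunks do (the wrapper class itself is hypothetical):

```python
import torch
from torch import nn
from vllm.forward_context import get_forward_context


class ForwardContextWrapper(nn.Module):
    """Hypothetical wrapper illustrating the metadata lookup in this diff."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, token_ids: torch.Tensor,
                position_ids: torch.Tensor) -> torch.Tensor:
        # Populated by the runner via `with set_forward_context(...)`.
        attn_metadata = get_forward_context().attn_metadata
        if attn_metadata is not None:
            # e.g. fix up slot mappings here before the real forward pass
            pass
        return self.model(token_ids, position_ids)
```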
@@ -297,10 +297,6 @@ class CPUEncoderDecoderModelRunner(
             model_input.encoder_input_tokens,
             "encoder_positions":
             model_input.encoder_input_positions,
-            "kv_caches":
-            kv_caches,
-            "attn_metadata":
-            model_input.attn_metadata,
             **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
                                          device=self.device),
             "intermediate_tensors":
@@ -654,8 +654,6 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
             hidden_states = model_executable(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
-                kv_caches=kv_caches,
-                attn_metadata=model_input.attn_metadata,
                 intermediate_tensors=intermediate_tensors,
                 **execute_model_kwargs,
                 **multimodal_kwargs,
@@ -41,16 +41,6 @@ class CPUPoolingModelRunner(
             raise ValueError(
                 "CPU worker does not support multi-step execution.")

-        num_layers = self.model_config.get_num_layers(self.parallel_config)
-        # use an empty tensor instead of `None`` to force Dynamo to pass
-        # it by reference, rather by specializing on the value ``None``.
-        # the `dtype` argument does not matter, and we use `float32` as
-        # a placeholder (it has wide hardware support).
-        kv_caches = [
-            torch.tensor([], dtype=torch.float32, device=self.device)
-            for _ in range(num_layers)
-        ]

         model_executable = self.model
         cross_enc_kwargs = {}
         if model_input.token_type_ids is not None:

@@ -60,10 +50,6 @@ class CPUPoolingModelRunner(
             model_input.input_tokens,
             "positions":
             model_input.input_positions,
-            "kv_caches":
-            kv_caches,
-            "attn_metadata":
-            model_input.attn_metadata,
             **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
                                          device=self.device),
             **cross_enc_kwargs,
@@ -184,8 +184,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
             positions=model_input.input_positions,
             encoder_input_ids=model_input.encoder_input_tokens,
             encoder_positions=model_input.encoder_input_positions,
-            kv_caches=kv_caches,
-            attn_metadata=model_input.attn_metadata,
             intermediate_tensors=intermediate_tensors,
             **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                          device=self.device),

@@ -324,21 +322,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
                            or encoder_dummy_data.multi_modal_placeholders)
             seqs.append(seq)

-        # Run the model with the dummy inputs.
-        num_layers = self.model_config.get_num_layers(self.parallel_config)
-        # use an empty tensor instead of `None`` to force Dynamo to pass
-        # it by reference, rather by specializing on the value ``None``.
-        # the `dtype` argument does not matter, and we use `float32` as
-        # a placeholder (it has wide hardware support).
-        kv_caches = [
-            torch.tensor([], dtype=torch.float32, device=self.device)
-            for _ in range(num_layers)
-        ]
         finished_requests_ids = [seq.request_id for seq in seqs]
         model_input = self.prepare_model_input(
             seqs, finished_requests_ids=finished_requests_ids)
         intermediate_tensors = None
-        self.execute_model(model_input, kv_caches, intermediate_tensors)
+        self.execute_model(model_input, None, intermediate_tensors)
         torch.cuda.synchronize()
         return
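Across the CPU and encoder/decoder runners above the change is mechanical: `kv_caches` and `attn_metadata` simply disappear from the kwargs handed to the model. A small hedged sketch of what the trimmed kwargs look like (the helper function is illustrative; the real runners build these dicts inline):

```python
from typing import Any, Dict, Optional

import torch


def build_execute_kwargs(input_tokens: torch.Tensor,
                         input_positions: torch.Tensor,
                         intermediate_tensors: Optional[Any] = None
                         ) -> Dict[str, Any]:
    # KV caches and attention metadata are intentionally absent; the model
    # resolves them through its attention layers and the forward context.
    return {
        "input_ids": input_tokens,
        "positions": input_positions,
        "intermediate_tensors": intermediate_tensors,
    }
```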
@@ -384,11 +384,12 @@ class HpuModelAdapter:
         if 'virtual_engine' in kwargs:
             virtual_engine = kwargs.pop('virtual_engine')
         input_ids = kwargs['input_ids']
-        kwargs['attn_metadata'] = self._update_metadata(
-            kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
+        attn_metadata = self._update_metadata(kwargs.pop('attn_metadata'),
+                                              input_ids.size(0),
+                                              input_ids.size(1),
                                               input_ids.device, self.dtype)
         LoraMask.setLoraMask(kwargs.pop('lora_mask'))
-        with set_forward_context(kwargs['attn_metadata'], self.vllm_config,
+        with set_forward_context(attn_metadata, self.vllm_config,
                                  virtual_engine):
             hidden_states = self.model(*args, **kwargs)
             hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

@@ -1346,15 +1347,13 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
         max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
         max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
                              self.scheduler_config.max_num_seqs)
-        self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
-                             False, True)
+        self.warmup_scenario(max_batch_size, max_seq_len, True, False, True)
         return

     def warmup_scenario(self,
                         batch_size,
                         seq_len,
                         is_prompt,
-                        kv_caches,
                         is_pt_profiler_run=False,
                         is_lora_profile_run=False) -> None:
         use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)

@@ -1418,7 +1417,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
             profiler.start()
         for _ in range(times):
             inputs = self.prepare_model_input(seqs)
-            self.execute_model(inputs, kv_caches, warmup_mode=True)
+            self.execute_model(inputs, None, warmup_mode=True)
             torch.hpu.synchronize()
             if profiler:
                 profiler.step()

@@ -1470,17 +1469,16 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
                f"free_mem:{free_mem}")
         logger.info(msg)

-    def warmup_all_buckets(self, buckets, is_prompt, kv_caches):
+    def warmup_all_buckets(self, buckets, is_prompt):
         for i, (batch_size, seq_len) in enumerate(reversed(buckets)):
             self.log_warmup('Prompt' if is_prompt else 'Decode', i,
                             len(buckets), batch_size, seq_len)
-            self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
+            self.warmup_scenario(batch_size, seq_len, is_prompt)

     def warmup_graphs(self,
                       strategy,
                       buckets,
                       is_prompt,
-                      kv_caches,
                       available_mem,
                       starting_mem=0,
                       total_batch_seq=0.001):

@@ -1512,7 +1510,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
             self.graphed_buckets.add(graphed_bucket)
             self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
             with HabanaMemoryProfiler() as mem_prof:
-                self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
+                self.warmup_scenario(batch_size, seq_len, is_prompt)
             used_mem = align_workers(mem_prof.consumed_device_memory,
                                      torch.distributed.ReduceOp.MAX)
             available_mem -= used_mem

@@ -1542,8 +1540,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
                 graphs = graph == 't'
                 if graphs:
                     self.graphed_buckets.add((int(bs), int(seq_len), is_prompt))
-                self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches,
-                                     True)
+                self.warmup_scenario(int(bs), int(seq_len), is_prompt, True)
             raise AssertionError("Finished profiling")
         if self.skip_warmup:
             logger.info("Skipping warmup...")

@@ -1608,9 +1605,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
         with compile_only_mode_context(
         ) if can_use_compile_only_mode else contextlib.nullcontext():
             self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets,
-                                    True, kv_caches)
+                                    True)
             self.warmup_all_buckets(self.bucketing_global_state.decode_buckets,
-                                    False, kv_caches)
+                                    False)

         if not self.enforce_eager and htorch.utils.internal.is_lazy():
             assert self.mem_margin is not None, \

@@ -1641,11 +1638,11 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
             mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
                 self.warmup_graphs(
                     prompt_strategy, self.bucketing_global_state.prompt_buckets,
-                    True, kv_caches, prompt_available_memory)
+                    True, prompt_available_memory)
             mem_post_decode, decode_batch_seq, decode_captured_all = \
                 self.warmup_graphs(
                     decode_strategy, self.bucketing_global_state.decode_buckets,
-                    False, kv_caches, decode_available_memory)
+                    False, decode_available_memory)

             # Not all prompt buckets were captured, but all decode buckets
             # were captured and we have some free graph-allocated space

@@ -1656,7 +1653,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
                     self.warmup_graphs(
                         prompt_strategy,
                         self.bucketing_global_state.prompt_buckets, True,
-                        kv_caches,
                         graph_free_mem - mem_post_prompt - mem_post_decode,
                         mem_post_prompt, prompt_batch_seq))

@@ -1669,7 +1665,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
                 mem_post_decode, _, _ = self.warmup_graphs(
                     decode_strategy,
                     self.bucketing_global_state.decode_buckets, False,
-                    kv_caches,
                     graph_free_mem - mem_post_prompt - mem_post_decode,
                     mem_post_decode, decode_batch_seq)

@@ -1982,7 +1977,6 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
         execute_model_kwargs = {
             "input_ids": input_tokens,
             "positions": input_positions,
-            "kv_caches": kv_caches,
             "attn_metadata": self.trim_attn_metadata(attn_metadata),
             "intermediate_tensors": intermediate_tensors,
             "lora_mask": lora_mask,
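The HPU adapter above has to pop `attn_metadata` out of `**kwargs` before forwarding them, since the wrapped model no longer accepts that keyword; the popped value is bound to the forward context instead. A compact sketch of that adapter pattern, using the three-argument `set_forward_context(attn_metadata, vllm_config, virtual_engine)` form seen in this diff (the function name is illustrative):

```python
from vllm.forward_context import set_forward_context


def adapter_forward(model, vllm_config, virtual_engine: int = 0, **kwargs):
    # Remove the metadata from kwargs so the remaining keys match the model's
    # slimmed-down forward(), then expose it via the forward context.
    attn_metadata = kwargs.pop("attn_metadata", None)
    with set_forward_context(attn_metadata, vllm_config, virtual_engine):
        return model(**kwargs)
```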
@@ -26,7 +26,7 @@ from vllm.core.scheduler import SchedulerOutputs
 from vllm.distributed import get_kv_transfer_group, get_pp_group
 from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
                                              graph_capture)
-from vllm.forward_context import set_forward_context
+from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping

@@ -1727,8 +1727,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
             hidden_or_intermediate_states = model_executable(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
-                kv_caches=kv_caches,
-                attn_metadata=model_input.attn_metadata,
                 intermediate_tensors=intermediate_tensors,
                 **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                              device=self.device),

@@ -1913,8 +1911,6 @@ class CUDAGraphRunner(nn.Module):
             self.model(
                 input_ids=input_ids,
                 positions=positions,
-                kv_caches=kv_caches,
-                attn_metadata=attn_metadata,
                 intermediate_tensors=intermediate_inputs,
                 **kwargs,
             )

@@ -1927,8 +1923,6 @@ class CUDAGraphRunner(nn.Module):
             output_hidden_or_intermediate_states = self.model(
                 input_ids=input_ids,
                 positions=positions,
-                kv_caches=kv_caches,
-                attn_metadata=attn_metadata,
                 intermediate_tensors=intermediate_inputs,
                 **kwargs,
             )

@@ -1976,13 +1970,10 @@ class CUDAGraphRunner(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         **kwargs,
     ) -> torch.Tensor:
-        # KV caches are fixed tensors, so we don't need to copy them.
-        del kv_caches
+        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

         # Copy the input tensors to the input buffers.
         self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
@@ -476,7 +476,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
         # path for warm up runs
         if not model_input.is_multi_step:
             return self._base_model_runner.execute_model(
-                frozen_model_input, kv_caches, intermediate_tensors, num_steps)
+                frozen_model_input, None, intermediate_tensors, num_steps)

         # make sure we skip the sampler on the lask rank and only pythonize
         # if CPU is ahead.

@@ -538,7 +538,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):

         # Execute the model
         output = self._base_model_runner.execute_model(frozen_model_input,
-                                                       kv_caches,
+                                                       None,
                                                        intermediate_tensors,
                                                        num_steps=1)
Some files were not shown because too many files have changed in this diff.