[Model] Support Mamba2 (Codestral Mamba) (#9292)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
commit 1f69c4a892 (parent 7b623fca0b)
@@ -4,6 +4,7 @@
 Run `pytest tests/models/test_mamba.py`.
 """
 import pytest
+import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from vllm.engine.arg_utils import EngineArgs
@@ -11,7 +12,14 @@ from vllm.sampling_params import SamplingParams
 
 from ...utils import check_outputs_equal
 
-MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev"]
+MODELS = [
+    "state-spaces/mamba-130m-hf",
+    "tiiuae/falcon-mamba-tiny-dev",
+    # TODO: Compare to a Mamba2 model. The HF transformers implementation of
+    # Mamba2 is buggy for Codestral as it doesn't handle n_groups.
+    # See https://github.com/huggingface/transformers/pull/35943
+    # "mistralai/Mamba-Codestral-7B-v0.1",
+]
 
 
 # Use lower-level interfaces to create this greedy generator, as mamba will
@@ -21,6 +29,10 @@ def generate_greedy(model_name, example_prompts, max_tokens):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(model_name)
 
+    # Set the device (GPU if available, else CPU)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
     # Generate texts from the prompts
     outputs = []
     for prompt in example_prompts:
@@ -29,7 +41,9 @@ def generate_greedy(model_name, example_prompts, max_tokens):
         input_ids = inputs["input_ids"].to(model.device)
 
         # Generate text using the model's generate method directly
-        generated_ids = model.generate(input_ids, max_new_tokens=max_tokens)
+        generated_ids = model.generate(input_ids,
+                                       max_new_tokens=max_tokens,
+                                       do_sample=False)
         generated_text = tokenizer.decode(generated_ids[0],
                                           skip_special_tokens=True)
 
@@ -50,7 +64,8 @@ def test_models(
 ) -> None:
     hf_outputs = generate_greedy(model, example_prompts, max_tokens)
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    # Set max_num_seqs to keep Codestral from going OOM at fp32
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
     # This test is for verifying whether the model's extra_repr
@@ -81,7 +96,7 @@ def test_batching(
 ) -> None:
     # To pass the small model tests, we need full precision.
     for_loop_outputs = []
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         for prompt in example_prompts:
             for_loop_outputs.append(
                 vllm_model.generate_greedy([prompt], max_tokens)[0])
@@ -165,20 +180,22 @@ def test_parallel_sampling(
     max_tokens: int,
 ) -> None:
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    # Numerical differences produce slightly different output for these
+    if 'state-spaces' in model:
+        example_prompts.pop(0)
+        example_prompts.pop(0)
+        example_prompts.pop(0)
+
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         for_loop_outputs = []
         for _ in range(10):
            for_loop_outputs.append(
-                # using example_prompts index 1 instead of 0 since with 0 the
-                # logprobs get really close and the test doesn't pass
-                vllm_model.generate_greedy([example_prompts[1]], max_tokens)
-                [0])
+                vllm_model.generate_greedy(example_prompts, max_tokens)[0])
         sampling_params = SamplingParams(n=10,
                                          temperature=0.001,
                                          seed=0,
                                          max_tokens=max_tokens)
-        n_lt_1_outputs = vllm_model.generate([example_prompts[1]],
-                                             sampling_params)
+        n_lt_1_outputs = vllm_model.generate(example_prompts, sampling_params)
        token_ids, texts = n_lt_1_outputs[0]
        n_lt_1_outputs = [(token_id, text)
                          for token_id, text in zip(token_ids, texts)]
@@ -232,7 +249,7 @@ def test_models_preemption_recompute(
     # Tests that outputs are identical with and w/o preemtions (recompute)
     assert dtype == "float"
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         vllm_model.model.llm_engine.scheduler[
             0].ENABLE_ARTIFICIAL_PREEMPT = True
         preempt_vllm_outputs = vllm_model.generate_greedy(
@@ -283,7 +300,7 @@ def test_state_cleanup(
    # This test is for verifying that the Mamba state is cleaned up between
    # steps, If its not cleaned, an error would be expected.
    try:
-        with vllm_runner(model, dtype=dtype) as vllm_model:
+        with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
            for _ in range(10):
                vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
    except ValueError:
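The test above compares an HF reference greedy generator against vLLM with a capped max_num_seqs. The snippet below is a standalone sketch of that same comparison outside the test harness; it assumes a CUDA-capable machine, the model name and decoding settings are illustrative choices, and only the small mamba-130m checkpoint from the MODELS list is used.

# Minimal sketch of the HF-vs-vLLM greedy comparison, for illustration only.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM, SamplingParams

MODEL = "state-spaces/mamba-130m-hf"
PROMPT = "The capital of France is"
MAX_TOKENS = 32

# HF reference: greedy decoding on GPU if available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL)
hf_model = AutoModelForCausalLM.from_pretrained(MODEL).to(device)
input_ids = tokenizer(PROMPT, return_tensors="pt").input_ids.to(device)
hf_ids = hf_model.generate(input_ids, max_new_tokens=MAX_TOKENS, do_sample=False)
hf_text = tokenizer.decode(hf_ids[0], skip_special_tokens=True)

# vLLM side: max_num_seqs is capped, mirroring the test change above, so a
# large Mamba2 checkpoint would not exhaust memory at fp32.
llm = LLM(model=MODEL, dtype="float", max_num_seqs=16)
vllm_out = llm.generate([PROMPT], SamplingParams(temperature=0.0, max_tokens=MAX_TOKENS))
vllm_text = PROMPT + vllm_out[0].outputs[0].text

print("HF  :", hf_text)
print("vLLM:", vllm_text)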
@@ -145,6 +145,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
                                         is_available_online=False),
     "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
+    "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1",
+                                         is_available_online=False),
     "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"),  # noqa: E501
     "MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16",
                                           trust_remote_code=True),
@@ -293,7 +293,8 @@ def _chunk_scan_fwd_kernel(
         dA_cs_m_boundary = tl.load(
             dA_cumsum_ptr +
             (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize,
-            mask=(pid_m * BLOCK_SIZE_M + c_off - 1) > -1,
+            mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1)
+                  and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)),
             other=0.0).to(tl.float32)
 
     if HAS_SEQ_IDX:
@@ -463,7 +464,10 @@ def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
        p += (s % chunk_size > 0)
 
        # get the dimensions
-        _s, _e = s // chunk_size + p, e // chunk_size + p + 1
+        # - the + 1 for _e is to shift the boundary by one chunk
+        # - this shifting is not needed if chunk_size divides e
+        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
+                                                             > 0)
 
        # adjust inidces and offsets
        chunk_indices[_s:_e] -= p
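The second hunk above changes the upper chunk index `_e` so the extra chunk is added only when the sequence end `e` is not aligned to a chunk boundary. A toy calculation makes the difference concrete; `p` is fixed to 0 here purely for illustration, whereas the real function accumulates it across sequences.

# Toy illustration of the _e boundary change, not part of the diff.
chunk_size = 4
p = 0

for e in (7, 8):  # one unaligned end offset, one chunk-aligned
    old_e = e // chunk_size + p + 1                     # always shifts by one chunk
    new_e = e // chunk_size + p + (e % chunk_size > 0)  # shifts only if unaligned
    print(f"e={e}: old _e={old_e}, new _e={new_e}")

# e=7: old _e=2, new _e=2  (end falls mid-chunk, shift still applied)
# e=8: old _e=3, new _e=2  (chunk_size divides e, no shift needed)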
@@ -440,23 +440,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
-        # follow jamba
-        if self.scheduler_config is not None and \
-            not self.model_config.enforce_eager:
-            # for compilation
-            if self.scheduler_config.max_num_seqs > \
-                vllm_config.compilation_config.max_capture_size:
-                self.max_batch_size = \
-                    vllm_config.compilation_config.max_capture_size
-            else:
-                self.max_batch_size = vllm_config.pad_for_cudagraph(
-                    self.scheduler_config.max_num_seqs)
-        elif self.scheduler_config is not None:
-            # for eager just take the scheduler_config if avail
-            self.max_batch_size = self.scheduler_config.max_num_seqs
-        else:
-            self.max_batch_size = 8192 + 2
-
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
@@ -474,8 +457,8 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
 
             self.mamba_cache = MambaCacheManager(
-                self.lm_head.weight.dtype, num_mamba_layers,
-                self.max_batch_size, *self._get_mamba_cache_shape())
+                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+                *self._get_mamba_cache_shape())
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, mamba_cache_params,
@@ -426,17 +426,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
-        if self.scheduler_config is not None and \
-            not self.model_config.enforce_eager:
-            if self.scheduler_config.max_num_seqs > \
-                vllm_config.compilation_config.max_capture_size:
-                self.max_batch_size = \
-                    vllm_config.compilation_config.max_capture_size
-            else:
-                self.max_batch_size = vllm_config.pad_for_cudagraph(
-                    self.scheduler_config.max_num_seqs)
-        else:
-            self.max_batch_size = 8192 + 2
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
@@ -453,8 +442,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
-                self.lm_head.weight.dtype, num_mamba_layers,
-                self.max_batch_size, *self._get_mamba_cache_shape())
+                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+                *self._get_mamba_cache_shape())
 
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
@@ -166,14 +166,13 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         lora_config = vllm_config.lora_config
-        scheduler_config = vllm_config.scheduler_config
+        self.scheduler_config = vllm_config.scheduler_config
         assert not cache_config.enable_prefix_caching, \
             "Mamba does not support prefix caching"
 
         super().__init__()
         self.config = config
         self.vllm_config = vllm_config
-        self.scheduler_config = scheduler_config
         self.model_config = vllm_config.model_config
         self.backbone = MambaModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "backbone"))
@@ -202,17 +201,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
 
         self.make_empty_intermediate_tensors = (
             self.backbone.make_empty_intermediate_tensors)
-        if self.scheduler_config is not None and \
-            not self.model_config.enforce_eager:
-            if self.scheduler_config.max_num_seqs > \
-                vllm_config.compilation_config.max_capture_size:
-                self.max_batch_size = \
-                    vllm_config.compilation_config.max_capture_size
-            else:
-                self.max_batch_size = vllm_config.pad_for_cudagraph(
-                    self.scheduler_config.max_num_seqs)
-        else:
-            self.max_batch_size = 8192 + 2
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.backbone.get_input_embeddings(input_ids)
@@ -229,8 +217,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
-                self.lm_head.weight.dtype, num_mamba_layers,
-                self.max_batch_size, *self._get_mamba_cache_shape())
+                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+                *self._get_mamba_cache_shape())
 
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
vllm/model_executor/models/mamba2.py (new file, 320 lines)
@@ -0,0 +1,320 @@
# SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA2 model."""
from typing import Iterable, List, Optional, Set, Tuple

import torch
from torch import nn
from transformers import MambaConfig

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
    MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState,
                                                   IsAttentionFree)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                    MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType

from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

KVCache = Tuple[torch.Tensor, torch.Tensor]


class Mamba2DecoderLayer(nn.Module):

    def __init__(self,
                 config: MambaConfig,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        self.mixer = MambaMixer2(hidden_size=config.hidden_size,
                                 ssm_state_size=config.state_size,
                                 conv_kernel_size=config.conv_kernel,
                                 intermediate_size=getattr(
                                     config, "intermediate_size",
                                     config.expand * config.hidden_size),
                                 use_conv_bias=config.use_conv_bias,
                                 use_bias=config.use_bias,
                                 n_groups=config.n_groups,
                                 num_heads=config.num_heads,
                                 head_dim=config.head_dim,
                                 rms_norm_eps=config.layer_norm_epsilon,
                                 activation=config.hidden_act,
                                 chunk_size=config.chunk_size,
                                 quant_config=quant_config)

        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        sequence_idx: Optional[torch.Tensor],
        **kwargs,
    ):
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, residual = self.norm(hidden_states, residual)

        hidden_states = self.mixer(hidden_states, attn_metadata,
                                   mamba_cache_params, sequence_idx)
        return hidden_states, residual


class Mamba2Model(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        is_lora_enabled = bool(lora_config)
        assert not is_lora_enabled

        self.config = config
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embeddings = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Mamba2DecoderLayer(config,
                                              quant_config=quant_config),
            prefix=f"{prefix}.layers")

        self.norm_f = RMSNorm(config.hidden_size,
                              eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata,
        mamba_cache_params: MambaCacheParams,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # pass a sequence index tensor, that is required for
        # proper continuous batching computation including
        # chunked prefill
        seq_idx = None
        if attn_metadata.num_prefills > 0:
            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
            for i, (srt, end) in enumerate(
                    zip(
                        attn_metadata.query_start_loc,
                        attn_metadata.query_start_loc[1:],
                    )):
                seq_idx[srt:end] = i
            seq_idx.unsqueeze_(0)

        for i in range(len(self.layers)):
            layer = self.layers[i]

            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                attn_metadata=attn_metadata,
                residual=residual,
                mamba_cache_params=mamba_cache_params.at_layer_idx(
                    i - self.start_layer),
                sequence_idx=seq_idx)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm_f(hidden_states, residual)

        return hidden_states


class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
        assert not cache_config.enable_prefix_caching, \
            "Mamba does not support prefix caching"

        super().__init__()
        self.config = config
        self.vllm_config = vllm_config
        self.scheduler_config = scheduler_config
        self.model_config = vllm_config.model_config
        self.backbone = Mamba2Model(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "backbone"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        if config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)

        # Used to track and store by the Mamba cache between steps.
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.backbone.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.backbone.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[KVCache],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        if self.mamba_cache is None:
            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                self.vllm_config.parallel_config, LayerBlockType.mamba)
            self.mamba_cache = MambaCacheManager(
                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
                *self._get_mamba_cache_shape())

        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

        hidden_states = self.backbone(input_ids, positions, attn_metadata,
                                      mamba_cache_params, intermediate_tensors,
                                      inputs_embeds)

        return hidden_states

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def _get_mamba_cache_shape(
            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
        world_size = get_tensor_model_parallel_world_size()

        conv_state_shape, temporal_state_shape = None, None

        intermediate_size = getattr(
            self.config, "intermediate_size",
            self.config.expand * self.config.hidden_size)

        # if n_groups is not divisible by world_size, need to extend the shards
        # to ensure all groups needed by a head is sharded along with it
        n_groups = (
            self.config.n_groups +
            extra_groups_for_head_shards(self.config.n_groups, world_size))

        # - heads and n_groups are TP-ed
        conv_dim = (intermediate_size + 2 * n_groups * self.config.state_size)
        conv_state_shape = (
            divide(conv_dim, world_size),
            self.config.conv_kernel - 1,
        )

        # These are not TP-ed as they depend on A, dt_bias, D
        # - they are typically small
        #   e.g., (h_heads, d_head, d_state) = (128, 64, 128)
        temporal_state_shape = (
            divide(self.config.num_heads, world_size),
            self.config.head_dim,
            self.config.state_size,
        )
        return conv_state_shape, temporal_state_shape

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "A_log" in name:
                name = name.replace("A_log", "A")

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
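The comments in _get_mamba_cache_shape above explain that heads and groups are tensor-parallel sharded while the temporal state depends only on per-head quantities. The sketch below plugs illustrative Mamba2-style numbers (hidden_size=4096, expand=2, n_groups=8, state_size=128, conv_kernel=4, num_heads=128, head_dim=64, tensor parallel size 2) into those formulas; these values are assumptions chosen for the example, not read from any real checkpoint config, and the extra-groups term is zero because 8 is already divisible by 2.

# Back-of-the-envelope check of the cache shapes, with assumed config values.
world_size = 2
hidden_size, expand = 4096, 2
n_groups, state_size, conv_kernel = 8, 128, 4
num_heads, head_dim = 128, 64

intermediate_size = expand * hidden_size                     # 8192

# n_groups is divisible by world_size, so no extra group shards are needed
# and the extra_groups_for_head_shards term contributes 0 in this example.
conv_dim = intermediate_size + 2 * n_groups * state_size     # 8192 + 2048 = 10240

conv_state_shape = (conv_dim // world_size, conv_kernel - 1)             # (5120, 3)
temporal_state_shape = (num_heads // world_size, head_dim, state_size)   # (64, 64, 128)

print(conv_state_shape, temporal_state_shape)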
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 import torch
 
 from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.config import VllmConfig
 
 
 @dataclass
@@ -22,8 +23,14 @@ class MambaCacheParams:
 
 class MambaCacheManager:
 
-    def __init__(self, dtype, num_mamba_layers, max_batch_size,
-                 conv_state_shape, temporal_state_shape):
+    def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
+                 num_mamba_layers: int, conv_state_shape: Tuple[int, int],
+                 temporal_state_shape: Tuple[int, int]):
+
+        # Determine max batch size to set size of MambaCache
+        max_batch_size = vllm_config.scheduler_config.max_num_seqs
+        if not vllm_config.model_config.enforce_eager:
+            max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)
 
         conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                                  conv_state_shape,
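This hunk moves the max-batch-size decision out of the individual model classes and into MambaCacheManager itself. The sketch below restates that rule as a free function for clarity; pad_for_cudagraph is vLLM's own rounding to the nearest captured CUDA graph batch size, and the lambda used here is only a stand-in for illustration.

# Minimal sketch (not vLLM code) of the batch-size rule applied above.
def mamba_cache_batch_size(max_num_seqs: int,
                           enforce_eager: bool,
                           pad_for_cudagraph=lambda n: n) -> int:
    # Eager mode sizes the cache exactly to the scheduler's max_num_seqs;
    # CUDA graph mode pads it up so every captured graph has a cache slot.
    if enforce_eager:
        return max_num_seqs
    return pad_for_cudagraph(max_num_seqs)


print(mamba_cache_batch_size(10, enforce_eager=True))   # 10
print(mamba_cache_batch_size(10, enforce_eager=False,
                             pad_for_cudagraph=lambda n: ((n + 7) // 8) * 8))  # 16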
@@ -71,6 +71,7 @@ _TEXT_GENERATION_MODELS = {
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
     "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
+    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
     "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
     "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
     "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
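With Mamba2ForCausalLM registered, Codestral Mamba loads through the normal vLLM entry points. A minimal usage sketch follows; it assumes sufficient GPU memory for a 7B model, and tensor_parallel_size, max_num_seqs, and the prompt are illustrative choices rather than recommended settings.

# Illustrative usage only; settings below are assumptions.
from vllm import LLM, SamplingParams

llm = LLM(model="mistralai/Mamba-Codestral-7B-v0.1",
          tensor_parallel_size=1,
          max_num_seqs=16)

params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["Write a Python function that reverses a string."], params)
print(outputs[0].outputs[0].text)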