[Model] Support Mamba2 (Codestral Mamba) (#9292)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

parent 7b623fca0b
commit 1f69c4a892
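With this change, Mamba2-style checkpoints such as Codestral Mamba can be run through the usual vLLM entry points once the architecture is registered (see the registry hunks below). A minimal offline-generation sketch, assuming the checkpoint fits in GPU memory; the prompt and sampling settings are illustrative only, not part of this diff:

    from vllm import LLM, SamplingParams

    # Any checkpoint that maps to Mamba2ForCausalLM should work the same way.
    llm = LLM(model="mistralai/Mamba-Codestral-7B-v0.1", max_num_seqs=16)
    params = SamplingParams(temperature=0.0, max_tokens=64)
    outputs = llm.generate(["Write a hello world program in C."], params)
    print(outputs[0].outputs[0].text)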
@@ -4,6 +4,7 @@
 Run `pytest tests/models/test_mamba.py`.
 """
 import pytest
+import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from vllm.engine.arg_utils import EngineArgs
@@ -11,7 +12,14 @@ from vllm.sampling_params import SamplingParams

 from ...utils import check_outputs_equal

-MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev"]
+MODELS = [
+    "state-spaces/mamba-130m-hf",
+    "tiiuae/falcon-mamba-tiny-dev",
+    # TODO: Compare to a Mamba2 model. The HF transformers implementation of
+    # Mamba2 is buggy for Codestral as it doesn't handle n_groups.
+    # See https://github.com/huggingface/transformers/pull/35943
+    # "mistralai/Mamba-Codestral-7B-v0.1",
+]


 # Use lower-level interfaces to create this greedy generator, as mamba will
@@ -21,6 +29,10 @@ def generate_greedy(model_name, example_prompts, max_tokens):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(model_name)

+    # Set the device (GPU if available, else CPU)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
     # Generate texts from the prompts
     outputs = []
     for prompt in example_prompts:
@@ -29,7 +41,9 @@ def generate_greedy(model_name, example_prompts, max_tokens):
         input_ids = inputs["input_ids"].to(model.device)

         # Generate text using the model's generate method directly
-        generated_ids = model.generate(input_ids, max_new_tokens=max_tokens)
+        generated_ids = model.generate(input_ids,
+                                       max_new_tokens=max_tokens,
+                                       do_sample=False)
         generated_text = tokenizer.decode(generated_ids[0],
                                           skip_special_tokens=True)

@@ -50,7 +64,8 @@ def test_models(
 ) -> None:
     hf_outputs = generate_greedy(model, example_prompts, max_tokens)

-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    # Set max_num_seqs to keep Codestral from going OOM at fp32
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

     # This test is for verifying whether the model's extra_repr
@@ -81,7 +96,7 @@ def test_batching(
 ) -> None:
     # To pass the small model tests, we need full precision.
     for_loop_outputs = []
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         for prompt in example_prompts:
             for_loop_outputs.append(
                 vllm_model.generate_greedy([prompt], max_tokens)[0])
@@ -165,20 +180,22 @@ def test_parallel_sampling(
     max_tokens: int,
 ) -> None:

-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    # Numerical differences produce slightly different output for these
+    if 'state-spaces' in model:
+        example_prompts.pop(0)
+        example_prompts.pop(0)
+        example_prompts.pop(0)
+
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         for_loop_outputs = []
         for _ in range(10):
             for_loop_outputs.append(
-                # using example_prompts index 1 instead of 0 since with 0 the
-                # logprobs get really close and the test doesn't pass
-                vllm_model.generate_greedy([example_prompts[1]], max_tokens)
-                [0])
+                vllm_model.generate_greedy(example_prompts, max_tokens)[0])
         sampling_params = SamplingParams(n=10,
                                          temperature=0.001,
                                          seed=0,
                                          max_tokens=max_tokens)
-        n_lt_1_outputs = vllm_model.generate([example_prompts[1]],
-                                             sampling_params)
+        n_lt_1_outputs = vllm_model.generate(example_prompts, sampling_params)
         token_ids, texts = n_lt_1_outputs[0]
         n_lt_1_outputs = [(token_id, text)
                           for token_id, text in zip(token_ids, texts)]
@@ -232,7 +249,7 @@ def test_models_preemption_recompute(
     # Tests that outputs are identical with and w/o preemtions (recompute)
     assert dtype == "float"

-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         vllm_model.model.llm_engine.scheduler[
             0].ENABLE_ARTIFICIAL_PREEMPT = True
         preempt_vllm_outputs = vllm_model.generate_greedy(
@@ -283,7 +300,7 @@ def test_state_cleanup(
     # This test is for verifying that the Mamba state is cleaned up between
     # steps, If its not cleaned, an error would be expected.
     try:
-        with vllm_runner(model, dtype=dtype) as vllm_model:
+        with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
             for _ in range(10):
                 vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
     except ValueError:
@@ -145,6 +145,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
                                         is_available_online=False),
     "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
+    "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1",
+                                         is_available_online=False),
     "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"),  # noqa: E501
     "MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16",
                                           trust_remote_code=True),
@@ -293,7 +293,8 @@ def _chunk_scan_fwd_kernel(
         dA_cs_m_boundary = tl.load(
             dA_cumsum_ptr +
             (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize,
-            mask=(pid_m * BLOCK_SIZE_M + c_off - 1) > -1,
+            mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1)
+                  and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)),
             other=0.0).to(tl.float32)

     if HAS_SEQ_IDX:
@@ -463,7 +464,10 @@ def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
         p += (s % chunk_size > 0)

         # get the dimensions
-        _s, _e = s // chunk_size + p, e // chunk_size + p + 1
+        # - the + 1 for _e is to shift the boundary by one chunk
+        # - this shifting is not needed if chunk_size divides e
+        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
+                                                             > 0)

         # adjust inidces and offsets
         chunk_indices[_s:_e] -= p
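The new `_e` boundary above can be sanity-checked with plain integer arithmetic; a small sketch (the values are made up, the names mirror the helper):

    chunk_size = 8
    p = 1
    for e in (16, 19):  # chunk_size divides e vs. it does not
        old_e = e // chunk_size + p + 1
        new_e = e // chunk_size + p + (e % chunk_size > 0)
        print(e, old_e, new_e)
    # e=16 -> old 4, new 3: no extra chunk when the boundary lands on a chunk edge
    # e=19 -> old 4, new 4: the boundary is still shifted by one chunk otherwise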
@@ -440,23 +440,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

-        # follow jamba
-        if self.scheduler_config is not None and \
-                not self.model_config.enforce_eager:
-            # for compilation
-            if self.scheduler_config.max_num_seqs > \
-                    vllm_config.compilation_config.max_capture_size:
-                self.max_batch_size = \
-                    vllm_config.compilation_config.max_capture_size
-            else:
-                self.max_batch_size = vllm_config.pad_for_cudagraph(
-                    self.scheduler_config.max_num_seqs)
-        elif self.scheduler_config is not None:
-            # for eager just take the scheduler_config if avail
-            self.max_batch_size = self.scheduler_config.max_num_seqs
-        else:
-            self.max_batch_size = 8192 + 2
-
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

@@ -474,8 +457,8 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                 self.vllm_config.parallel_config, LayerBlockType.mamba)

             self.mamba_cache = MambaCacheManager(
-                self.lm_head.weight.dtype, num_mamba_layers,
-                self.max_batch_size, *self._get_mamba_cache_shape())
+                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+                *self._get_mamba_cache_shape())
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, mamba_cache_params,
@@ -426,17 +426,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,

         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
-        if self.scheduler_config is not None and \
-                not self.model_config.enforce_eager:
-            if self.scheduler_config.max_num_seqs > \
-                    vllm_config.compilation_config.max_capture_size:
-                self.max_batch_size = \
-                    vllm_config.compilation_config.max_capture_size
-            else:
-                self.max_batch_size = vllm_config.pad_for_cudagraph(
-                    self.scheduler_config.max_num_seqs)
-        else:
-            self.max_batch_size = 8192 + 2

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
@@ -453,8 +442,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
-                self.lm_head.weight.dtype, num_mamba_layers,
-                self.max_batch_size, *self._get_mamba_cache_shape())
+                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+                *self._get_mamba_cache_shape())

         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

@@ -166,14 +166,13 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         lora_config = vllm_config.lora_config
-        scheduler_config = vllm_config.scheduler_config
+        self.scheduler_config = vllm_config.scheduler_config
         assert not cache_config.enable_prefix_caching, \
             "Mamba does not support prefix caching"

         super().__init__()
         self.config = config
         self.vllm_config = vllm_config
-        self.scheduler_config = scheduler_config
         self.model_config = vllm_config.model_config
         self.backbone = MambaModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "backbone"))
@@ -202,17 +201,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):

         self.make_empty_intermediate_tensors = (
             self.backbone.make_empty_intermediate_tensors)
-        if self.scheduler_config is not None and \
-                not self.model_config.enforce_eager:
-            if self.scheduler_config.max_num_seqs > \
-                    vllm_config.compilation_config.max_capture_size:
-                self.max_batch_size = \
-                    vllm_config.compilation_config.max_capture_size
-            else:
-                self.max_batch_size = vllm_config.pad_for_cudagraph(
-                    self.scheduler_config.max_num_seqs)
-        else:
-            self.max_batch_size = 8192 + 2

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.backbone.get_input_embeddings(input_ids)
@@ -229,8 +217,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
-                self.lm_head.weight.dtype, num_mamba_layers,
-                self.max_batch_size, *self._get_mamba_cache_shape())
+                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+                *self._get_mamba_cache_shape())

         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

vllm/model_executor/models/mamba2.py (new file, 320 lines)
@@ -0,0 +1,320 @@
+# SPDX-License-Identifier: Apache-2.0
+"""PyTorch MAMBA2 model."""
+from typing import Iterable, List, Optional, Set, Tuple
+
+import torch
+from torch import nn
+from transformers import MambaConfig
+
+from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.config import VllmConfig
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import get_pp_group
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba_mixer2 import (
+    MambaMixer2, extra_groups_for_head_shards)
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import (HasInnerState,
+                                                   IsAttentionFree)
+from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
+                                                    MambaCacheParams)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+from vllm.utils import LayerBlockType
+
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class Mamba2DecoderLayer(nn.Module):
+
+    def __init__(self,
+                 config: MambaConfig,
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
+        super().__init__()
+        self.config = config
+        self.mixer = MambaMixer2(hidden_size=config.hidden_size,
+                                 ssm_state_size=config.state_size,
+                                 conv_kernel_size=config.conv_kernel,
+                                 intermediate_size=getattr(
+                                     config, "intermediate_size",
+                                     config.expand * config.hidden_size),
+                                 use_conv_bias=config.use_conv_bias,
+                                 use_bias=config.use_bias,
+                                 n_groups=config.n_groups,
+                                 num_heads=config.num_heads,
+                                 head_dim=config.head_dim,
+                                 rms_norm_eps=config.layer_norm_epsilon,
+                                 activation=config.hidden_act,
+                                 chunk_size=config.chunk_size,
+                                 quant_config=quant_config)
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+        mamba_cache_params: MambaCacheParams,
+        sequence_idx: Optional[torch.Tensor],
+        **kwargs,
+    ):
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
+        hidden_states = self.mixer(hidden_states, attn_metadata,
+                                   mamba_cache_params, sequence_idx)
+        return hidden_states, residual
+
+
+class Mamba2Model(nn.Module):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+        is_lora_enabled = bool(lora_config)
+        assert not is_lora_enabled
+
+        self.config = config
+        lora_vocab = ((lora_config.lora_extra_vocab_size *
+                       (lora_config.max_loras or 1)) if lora_config else 0)
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+
+        self.embeddings = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: Mamba2DecoderLayer(config,
+                                              quant_config=quant_config),
+            prefix=f"{prefix}.layers")
+
+        self.norm_f = RMSNorm(config.hidden_size,
+                              eps=config.layer_norm_epsilon)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        mamba_cache_params: MambaCacheParams,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        # pass a sequence index tensor, that is required for
+        # proper continuous batching computation including
+        # chunked prefill
+        seq_idx = None
+        if attn_metadata.num_prefills > 0:
+            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
+            for i, (srt, end) in enumerate(
+                    zip(
+                        attn_metadata.query_start_loc,
+                        attn_metadata.query_start_loc[1:],
+                    )):
+                seq_idx[srt:end] = i
+            seq_idx.unsqueeze_(0)
+
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+
+            hidden_states, residual = layer(
+                positions=positions,
+                hidden_states=hidden_states,
+                attn_metadata=attn_metadata,
+                residual=residual,
+                mamba_cache_params=mamba_cache_params.at_layer_idx(
+                    i - self.start_layer),
+                sequence_idx=seq_idx)
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+
+        hidden_states, _ = self.norm_f(hidden_states, residual)
+
+        return hidden_states
+
+
+class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        lora_config = vllm_config.lora_config
+        scheduler_config = vllm_config.scheduler_config
+        assert not cache_config.enable_prefix_caching, \
+            "Mamba does not support prefix caching"
+
+        super().__init__()
+        self.config = config
+        self.vllm_config = vllm_config
+        self.scheduler_config = scheduler_config
+        self.model_config = vllm_config.model_config
+        self.backbone = Mamba2Model(vllm_config=vllm_config,
+                                    prefix=maybe_prefix(prefix, "backbone"))
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+
+        self.lm_head = ParallelLMHead(
+            self.unpadded_vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            # We need bigger padding if using lora for kernel
+            # compatibility
+            if not lora_config else lora_config.lora_vocab_padding_size,
+        )
+        if config.tie_word_embeddings:
+            self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)
+
+        # Used to track and store by the Mamba cache between steps.
+        self.mamba_cache: Optional[MambaCacheManager] = None
+
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.sampler = get_sampler()
+
+        self.make_empty_intermediate_tensors = (
+            self.backbone.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.backbone.get_input_embeddings(input_ids)
+
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
+                kv_caches: List[KVCache],
+                attn_metadata: AttentionMetadata,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
+                inputs_embeds: Optional[torch.Tensor] = None,
+                **kwargs):
+        if self.mamba_cache is None:
+            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
+                self.vllm_config.parallel_config, LayerBlockType.mamba)
+            self.mamba_cache = MambaCacheManager(
+                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+                *self._get_mamba_cache_shape())
+
+        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+
+        hidden_states = self.backbone(input_ids, positions, attn_metadata,
+                                      mamba_cache_params, intermediate_tensors,
+                                      inputs_embeds)
+
+        return hidden_states
+
+    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
+        return self.mamba_cache.copy_inputs_before_cuda_graphs(
+            input_buffers, **kwargs)
+
+    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
+        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
+
+    def _get_mamba_cache_shape(
+            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+        world_size = get_tensor_model_parallel_world_size()
+
+        conv_state_shape, temporal_state_shape = None, None
+
+        intermediate_size = getattr(
+            self.config, "intermediate_size",
+            self.config.expand * self.config.hidden_size)
+
+        # if n_groups is not divisible by world_size, need to extend the shards
+        # to ensure all groups needed by a head is sharded along with it
+        n_groups = (
+            self.config.n_groups +
+            extra_groups_for_head_shards(self.config.n_groups, world_size))
+
+        # - heads and n_groups are TP-ed
+        conv_dim = (intermediate_size + 2 * n_groups * self.config.state_size)
+        conv_state_shape = (
+            divide(conv_dim, world_size),
+            self.config.conv_kernel - 1,
+        )
+
+        # These are not TP-ed as they depend on A, dt_bias, D
+        # - they are typically small
+        # e.g., (h_heads, d_head, d_state) = (128, 64, 128)
+        temporal_state_shape = (
+            divide(self.config.num_heads, world_size),
+            self.config.head_dim,
+            self.config.state_size,
+        )
+        return conv_state_shape, temporal_state_shape
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "A_log" in name:
+                name = name.replace("A_log", "A")
+
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
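The `seq_idx` tensor built in `Mamba2Model.forward` above is what lets the mixer keep prefill sequences apart under continuous batching and chunked prefill. A self-contained illustration of that loop (the sequence lengths are made up):

    import torch

    # Hypothetical batch: three prefill sequences of 4, 2 and 3 tokens.
    query_start_loc = torch.tensor([0, 4, 6, 9])
    input_ids = torch.zeros(9, dtype=torch.long)

    seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
    for i, (srt, end) in enumerate(zip(query_start_loc, query_start_loc[1:])):
        seq_idx[srt:end] = i
    seq_idx.unsqueeze_(0)
    print(seq_idx)  # tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2]], dtype=torch.int32)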
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0

 from dataclasses import dataclass
-from typing import Dict, List
+from typing import Dict, List, Tuple

 import torch

 from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.config import VllmConfig


 @dataclass
@@ -22,8 +23,14 @@ class MambaCacheParams:


 class MambaCacheManager:

-    def __init__(self, dtype, num_mamba_layers, max_batch_size,
-                 conv_state_shape, temporal_state_shape):
+    def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
+                 num_mamba_layers: int, conv_state_shape: Tuple[int, int],
+                 temporal_state_shape: Tuple[int, int]):
+
+        # Determine max batch size to set size of MambaCache
+        max_batch_size = vllm_config.scheduler_config.max_num_seqs
+        if not vllm_config.model_config.enforce_eager:
+            max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)

         conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                                  conv_state_shape,
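After this change the models no longer compute or pass max_batch_size themselves: MambaCacheManager derives it from the scheduler's max_num_seqs and pads it for CUDA-graph capture unless enforce_eager is set. A call site under the new signature, as a sketch matching the bamba/jamba/mamba/mamba2 hunks in this commit:

    self.mamba_cache = MambaCacheManager(
        self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
        *self._get_mamba_cache_shape())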
@@ -71,6 +71,7 @@ _TEXT_GENERATION_MODELS = {
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
     "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
+    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
     "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
     "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
     "MistralForCausalLM": ("llama", "LlamaForCausalLM"),