Upstream Llama4 Support to Main (#16113)
Signed-off-by: Aston Zhang <22279212+astonzhang@users.noreply.github.com> Signed-off-by: Chris Thi <chris.c.thi@gmail.com> Signed-off-by: drisspg <drisspguessous@gmail.com> Signed-off-by: Jon Swenson <jmswen@gmail.com> Signed-off-by: Keyun Tong <tongkeyun@gmail.com> Signed-off-by: Lu Fang <fanglu@meta.com> Signed-off-by: Xiaodong Wang <xdwang@meta.com> Signed-off-by: Yang Chen <yangche@fb.com> Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Zijing Liu <liuzijing2014@gmail.com> Signed-off-by: Lu Fang <lufang@fb.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Lucia Fang <fanglu@fb.com> Signed-off-by: Roger Wang <ywang@roblox.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: Lu Fang <fanglu@fb.com> Co-authored-by: Roger Wang <ywang@roblox.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
8017c8db7f
commit
55dcce91df
@ -389,7 +389,8 @@ steps:
|
||||
- pytest -v -s models/test_transformers.py
|
||||
- pytest -v -s models/test_registry.py
|
||||
# V1 Test: https://github.com/vllm-project/vllm/issues/14531
|
||||
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py
|
||||
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4'
|
||||
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
|
||||
|
||||
- label: Language Models Test (Standard) # 32min
|
||||
#mirror_hardwares: [amd]
|
||||
|
@ -553,6 +553,9 @@ def main(args: argparse.Namespace):
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
else:
|
||||
if not hasattr(config, "hidden_size"):
|
||||
# Support for llama4
|
||||
config = config.text_config
|
||||
# Default: Mixtral.
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
|
@ -850,6 +850,13 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
*
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
- * `Llama4ForConditionalGeneration`
|
||||
* Llama-4-17B-Omni-Instruct
|
||||
* T + I<sup>+</sup>
|
||||
* `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc.
|
||||
*
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
- * `LlavaForConditionalGeneration`
|
||||
* LLaVA-1.5
|
||||
* T + I<sup>E+</sup>
|
||||
|
@ -47,7 +47,7 @@ def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
|
@ -582,6 +582,42 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def run_llama4(questions: list[str], modality: str):
|
||||
assert modality == "image"
|
||||
|
||||
model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=4,
|
||||
tensor_parallel_size=8,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
gpu_memory_utilization=0.4,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
messages = [[{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image"
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": f"{question}"
|
||||
}]
|
||||
}] for question in questions]
|
||||
prompts = tokenizer.apply_chat_template(messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False)
|
||||
stop_token_ids = None
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
stop_token_ids=stop_token_ids,
|
||||
)
|
||||
|
||||
|
||||
# Molmo
|
||||
def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
@ -907,6 +943,7 @@ model_example_map = {
|
||||
"minicpmv": run_minicpmv,
|
||||
"mistral3": run_mistral3,
|
||||
"mllama": run_mllama,
|
||||
"llama4": run_llama4,
|
||||
"molmo": run_molmo,
|
||||
"NVLM_D": run_nvlm_d,
|
||||
"paligemma": run_paligemma,
|
||||
|
@ -253,6 +253,43 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=4,
|
||||
tensor_parallel_size=8,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
|
||||
placeholders = [{"type": "image", "image": url} for url in image_urls]
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
*placeholders,
|
||||
{
|
||||
"type": "text",
|
||||
"text": question
|
||||
},
|
||||
],
|
||||
}]
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_name)
|
||||
|
||||
prompt = processor.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
)
|
||||
|
||||
|
||||
def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
|
||||
@ -567,6 +604,7 @@ model_example_map = {
|
||||
"h2ovl_chat": load_h2ovl,
|
||||
"idefics3": load_idefics3,
|
||||
"internvl_chat": load_internvl,
|
||||
"llama4": load_llama4,
|
||||
"mistral3": load_mistral3,
|
||||
"mllama": load_mllama,
|
||||
"NVLM_D": load_nvlm_d,
|
||||
|
@ -6,7 +6,7 @@ requests >= 2.26.0
|
||||
tqdm
|
||||
blake3
|
||||
py-cpuinfo
|
||||
transformers >= 4.50.3
|
||||
transformers >= 4.51.0
|
||||
huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
|
||||
tokenizers >= 0.19.1 # Required for Llama 3.
|
||||
protobuf # Required by LlamaTokenizer.
|
||||
|
@ -30,7 +30,7 @@ mistral_common[opencv] >= 1.5.4 # required for pixtral test
|
||||
opencv-python-headless >= 4.11.0 # required for video test
|
||||
datamodel_code_generator # required for minicpm3 test
|
||||
lm-eval[api]==0.4.8 # required for model evaluation test
|
||||
transformers==4.50.3
|
||||
transformers==4.51.0
|
||||
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
|
||||
# quantization
|
||||
bitsandbytes>=0.45.3
|
||||
|
@ -645,7 +645,7 @@ tqdm==4.66.6
|
||||
# transformers
|
||||
tqdm-multiprocess==0.0.11
|
||||
# via lm-eval
|
||||
transformers==4.50.3
|
||||
transformers==4.51.0
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# genai-perf
|
||||
|
@ -12,6 +12,7 @@ from vllm.sequence import SampleLogprobs
|
||||
|
||||
from ....conftest import HfRunner, VllmRunner
|
||||
from ....utils import RemoteOpenAIServer
|
||||
from ...registry import HF_EXAMPLE_MODELS
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
|
||||
@ -55,7 +56,10 @@ def server(request, audio_assets):
|
||||
for key, value in request.param.items()
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
with RemoteOpenAIServer(MODEL_NAME,
|
||||
args,
|
||||
env_dict={"VLLM_AUDIO_FETCH_TIMEOUT":
|
||||
"30"}) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@ -106,6 +110,10 @@ def run_test(
|
||||
**kwargs,
|
||||
):
|
||||
"""Inference result should be the same between hf and vllm."""
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||
# vLLM needs a fresh new process without cuda initialization.
|
||||
# if we run HF first, the cuda initialization will be done and it
|
||||
@ -156,6 +164,10 @@ def run_multi_audio_test(
|
||||
num_logprobs: int,
|
||||
**kwargs,
|
||||
):
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
with vllm_runner(model,
|
||||
dtype=dtype,
|
||||
enforce_eager=True,
|
||||
|
@ -160,17 +160,32 @@ VLM_TEST_SETTINGS = {
|
||||
),
|
||||
"aya_vision": VLMTestInfo(
|
||||
models=["CohereForAI/aya-vision-8b"],
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
test_type=(VLMTestType.IMAGE),
|
||||
prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501
|
||||
single_image_prompts=IMAGE_ASSETS.prompts({
|
||||
"stop_sign": "<image>What's the content in the center of the image?", # noqa: E501
|
||||
"cherry_blossom": "<image>What is the season?", # noqa: E501
|
||||
}),
|
||||
multi_image_prompt="<image><image>Describe the two images in detail.", # noqa: E501
|
||||
max_model_len=8192,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}}
|
||||
vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}},
|
||||
),
|
||||
"aya_vision-multi_image": VLMTestInfo(
|
||||
models=["CohereForAI/aya-vision-8b"],
|
||||
test_type=(VLMTestType.MULTI_IMAGE),
|
||||
prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501
|
||||
single_image_prompts=IMAGE_ASSETS.prompts({
|
||||
"stop_sign": "<image>What's the content in the center of the image?", # noqa: E501
|
||||
"cherry_blossom": "<image>What is the season?", # noqa: E501
|
||||
}),
|
||||
multi_image_prompt="<image><image>Describe the two images in detail.", # noqa: E501
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}},
|
||||
marks=[large_gpu_mark(min_gb=32)],
|
||||
),
|
||||
"blip2": VLMTestInfo(
|
||||
# TODO: Change back to 2.7b once head_dim = 80 is supported
|
||||
@ -303,6 +318,22 @@ VLM_TEST_SETTINGS = {
|
||||
use_tokenizer_eos=True,
|
||||
patch_hf_runner=model_utils.internvl_patch_hf_runner,
|
||||
),
|
||||
"llama4": VLMTestInfo(
|
||||
models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"],
|
||||
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501
|
||||
img_idx_to_prompt=lambda _: "<|image|>",
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
distributed_executor_backend="mp",
|
||||
image_size_factors=[(.25, 0.5, 1.0)],
|
||||
hf_model_kwargs={"device_map": "auto"},
|
||||
max_model_len=8192,
|
||||
max_num_seqs=4,
|
||||
dtype="bfloat16",
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
tensor_parallel_size=8,
|
||||
vllm_runner_kwargs={"gpu_memory_utilization": 0.8},
|
||||
marks=multi_gpu_marks(num_gpus=8),
|
||||
),
|
||||
"llava_next": VLMTestInfo(
|
||||
models=["llava-hf/llava-v1.6-mistral-7b-hf"],
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.CUSTOM_INPUTS),
|
||||
|
@ -5,7 +5,9 @@ import re
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from packaging.version import Version
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
from vllm.platforms import current_platform
|
||||
@ -81,6 +83,13 @@ def run_test(
|
||||
from transformers import AutoImageProcessor # noqa: F401
|
||||
from transformers import AutoProcessor # noqa: F401
|
||||
|
||||
# Once the model repo is updated to 4.49, we should be able to run the
|
||||
# test in `test_models.py` without the above workaround
|
||||
if Version(TRANSFORMERS_VERSION) >= Version("4.49"):
|
||||
pytest.skip(f"`transformers=={TRANSFORMERS_VERSION}` installed, "
|
||||
"but `transformers<=4.49` is required to run this model. "
|
||||
"Reason: Cannot run HF implementation")
|
||||
|
||||
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||
# vLLM needs a fresh new process without cuda initialization.
|
||||
# if we run HF first, the cuda initialization will be done and it
|
||||
|
@ -176,6 +176,8 @@ def test_chat(
|
||||
model,
|
||||
dtype=dtype,
|
||||
tokenizer_mode="mistral",
|
||||
load_format="mistral",
|
||||
config_format="mistral",
|
||||
max_model_len=max_model_len,
|
||||
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
|
||||
) as vllm_model:
|
||||
|
@ -257,6 +257,7 @@ def _test_processing_correctness_mistral(
|
||||
"h2oai/h2ovl-mississippi-800m",
|
||||
"OpenGVLab/InternVL2-1B",
|
||||
"HuggingFaceM4/Idefics3-8B-Llama3",
|
||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"llava-hf/llava-1.5-7b-hf",
|
||||
"llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
"llava-hf/LLaVA-NeXT-Video-7B-hf",
|
||||
|
99
tests/models/multimodal/processing/test_llama4.py
Normal file
99
tests/models/multimodal/processing/test_llama4.py
Normal file
@ -0,0 +1,99 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Tests for Llama4's multimodal preprocessing kwargs."""
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.transformers_utils.tokenizer import encode_tokens
|
||||
|
||||
from ....conftest import _ImageAssets
|
||||
from ...utils import build_model_context
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id",
|
||||
["meta-llama/Llama-4-Scout-17B-16E-Instruct"])
|
||||
@pytest.mark.parametrize("mm_processor_kwargs", [{}])
|
||||
@pytest.mark.parametrize("num_imgs", [1, 5])
|
||||
@pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False])
|
||||
@pytest.mark.parametrize("tokenized_prompt", [True, False])
|
||||
def test_processor_override(
|
||||
image_assets: _ImageAssets,
|
||||
model_id: str,
|
||||
mm_processor_kwargs: dict,
|
||||
num_imgs: int,
|
||||
disable_mm_preprocessor_cache: bool,
|
||||
tokenized_prompt: bool,
|
||||
):
|
||||
"""Ensure llama4 processor works properly."""
|
||||
ctx = build_model_context(
|
||||
model_id,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
limit_mm_per_prompt={"image": num_imgs},
|
||||
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
config = processor.info.get_hf_config()
|
||||
tokenizer = processor.info.get_tokenizer()
|
||||
hf_processor = processor.info.get_hf_processor()
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
prompt = "<|begin_of_text|><|header_start|>user<|header_end|>" \
|
||||
+ "<|image|>" * num_imgs \
|
||||
+ "<|eot|><|header_start|>assistant<|header_end|>"
|
||||
mm_data = {
|
||||
"image": [
|
||||
image_assets[(i % len(image_assets))].pil_image
|
||||
for i in range(num_imgs)
|
||||
]
|
||||
}
|
||||
if tokenized_prompt:
|
||||
prompt = encode_tokens(tokenizer, prompt)
|
||||
|
||||
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
|
||||
mm_kwargs = processed_inputs["mm_kwargs"]
|
||||
|
||||
# place holder replacements
|
||||
prompt_token_ids = processed_inputs["prompt_token_ids"]
|
||||
assert prompt_token_ids.count(config.boi_token_index) == num_imgs
|
||||
assert prompt_token_ids.count(config.eoi_token_index) == num_imgs
|
||||
assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs
|
||||
aspect_ratios = mm_kwargs["aspect_ratios"]
|
||||
num_x_separators = num_y_separators = 0
|
||||
for tiles_y, tiles_x in aspect_ratios:
|
||||
if tiles_x * tiles_y > 1:
|
||||
num_x_separators += (tiles_x - 1) * tiles_y
|
||||
num_y_separators += tiles_y
|
||||
assert prompt_token_ids.count(vocab[hf_processor.tile_token]) \
|
||||
== num_x_separators
|
||||
assert prompt_token_ids.count(vocab[hf_processor.tile_global_token]) \
|
||||
== num_y_separators
|
||||
|
||||
# image token offsets
|
||||
img_locs = processed_inputs["mm_placeholders"].get("image", [])
|
||||
assert len(img_locs) == num_imgs
|
||||
assert [img_loc["offset"] for img_loc in img_locs] == \
|
||||
[i for i, v in enumerate(prompt_token_ids) \
|
||||
if v == config.boi_token_index]
|
||||
|
||||
# patch sizes and masks
|
||||
assert prompt_token_ids.count(config.image_token_index) \
|
||||
== sum(img_patch.sum() for img_patch in mm_kwargs["embed_is_patch"])
|
||||
patch_token_id = vocab[hf_processor.img_patch_token]
|
||||
num_patches = processed_inputs["prompt_token_ids"].count(patch_token_id)
|
||||
mm_counts = {"image": num_imgs}
|
||||
assert num_patches / num_imgs <= \
|
||||
processor.info.get_mm_max_tokens_per_item(32768, mm_counts)["image"]
|
||||
num_patches_per_chunk = processor.info.get_patch_per_chunk(
|
||||
config.vision_config)
|
||||
assert prompt_token_ids.count(config.image_token_index) \
|
||||
== mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk
|
||||
assert mm_kwargs["pixel_values"].shape[0] \
|
||||
== mm_kwargs["patches_per_image"].sum()
|
||||
|
||||
for embed_is_patch, aspect_ratio in zip(mm_kwargs["embed_is_patch"],
|
||||
mm_kwargs["aspect_ratios"]):
|
||||
assert embed_is_patch.shape[0] == \
|
||||
len(tokenizer.encode(
|
||||
hf_processor._prompt_split_image(
|
||||
aspect_ratio, num_patches_per_chunk),
|
||||
add_special_tokens=False))
|
@ -287,12 +287,16 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True,
|
||||
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
|
||||
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
|
||||
extras={"2b": "h2oai/h2ovl-mississippi-2b"}), # noqa: E501
|
||||
extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
|
||||
max_transformers_version="4.48", # noqa: E501
|
||||
transformers_version_reason="HF model is not compatible."), # noqa: E501
|
||||
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
|
||||
extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
|
||||
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501
|
||||
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
|
||||
min_transformers_version="4.51"),
|
||||
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
|
||||
extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501
|
||||
"mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}), # noqa: E501
|
||||
|
@ -7,6 +7,8 @@ from transformers import PretrainedConfig
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
|
||||
from vllm.utils import GiB_bytes
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
|
||||
from vllm.v1.engine.core import EngineCore as V1EngineCore
|
||||
|
||||
from .registry import HF_EXAMPLE_MODELS
|
||||
@ -42,14 +44,21 @@ def test_can_initialize(model_arch):
|
||||
self.cache_config.num_gpu_blocks = 0
|
||||
self.cache_config.num_cpu_blocks = 0
|
||||
|
||||
def _initalize_kv_caches_v1(self, vllm_config):
|
||||
# gpu_blocks (> 0), cpu_blocks
|
||||
return 1, 0
|
||||
def _initialize_kv_caches_v1(self, vllm_config):
|
||||
kv_cache_specs = self.model_executor.get_kv_cache_specs()
|
||||
scheduler_kv_cache_config = get_kv_cache_config(
|
||||
vllm_config,
|
||||
kv_cache_specs[0],
|
||||
20 * GiB_bytes,
|
||||
)
|
||||
|
||||
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
|
||||
return 1, 0, scheduler_kv_cache_config
|
||||
|
||||
with (patch.object(V0LLMEngine, "_initialize_kv_caches",
|
||||
_initialize_kv_caches_v0),
|
||||
patch.object(V1EngineCore, "_initialize_kv_caches",
|
||||
_initalize_kv_caches_v1)):
|
||||
_initialize_kv_caches_v1)):
|
||||
LLM(
|
||||
model_info.default,
|
||||
tokenizer=model_info.tokenizer,
|
||||
|
@ -358,6 +358,8 @@ class ModelConfig:
|
||||
self.hf_config = hf_config
|
||||
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
self.attention_chunk_size = getattr(self.hf_text_config,
|
||||
"attention_chunk_size", None)
|
||||
self.encoder_config = self._get_encoder_config()
|
||||
self.hf_image_processor_config = get_hf_image_processor_config(
|
||||
self.model, hf_token=hf_token, revision=revision)
|
||||
|
@ -500,7 +500,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
"internvl_chat", "skywork_chat", "NVLM_D",
|
||||
"h2ovl_chat", "idefics3"):
|
||||
return "<image>"
|
||||
if model_type == "mllama":
|
||||
if model_type in ("mllama", "llama4"):
|
||||
return "<|image|>"
|
||||
if model_type in ("qwen2_vl", "qwen2_5_vl"):
|
||||
return "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
|
@ -0,0 +1,200 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
}
|
||||
}
|
@ -23,6 +23,7 @@ def cutlass_moe_fp8(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
out_dtype: torch.dtype = torch.half,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
|
||||
@ -96,8 +97,14 @@ def cutlass_moe_fp8(
|
||||
n = w2_q.size(1)
|
||||
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||
if apply_router_weight_on_input:
|
||||
assert topk == 1, \
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
a = a * topk_weights.to(out_dtype)
|
||||
|
||||
a_q, a1_scale = ops.scaled_fp8_quant(
|
||||
a, a1_scale, use_per_token_if_dynamic=per_act_token)
|
||||
@ -139,6 +146,8 @@ def cutlass_moe_fp8(
|
||||
ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
|
||||
expert_offsets[:-1], problem_sizes2, ab_strides2,
|
||||
ab_strides2, c_strides2)
|
||||
|
||||
return (c2[c_map].view(m, topk, k) *
|
||||
topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
|
||||
# Gather tokens
|
||||
c2 = c2[c_map].view(m, topk, k)
|
||||
if not apply_router_weight_on_input:
|
||||
c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
|
||||
return c2.sum(dim=1)
|
||||
|
@ -954,6 +954,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
@ -967,10 +968,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> None:
|
||||
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
|
||||
activation, use_fp8_w8a8, use_int8_w8a16,
|
||||
use_int4_w4a16, global_num_experts, expert_map,
|
||||
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
|
||||
block_shape)
|
||||
activation, apply_router_weight_on_input, use_fp8_w8a8,
|
||||
use_int8_w8a16, use_int4_w4a16, global_num_experts,
|
||||
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
|
||||
a2_scale, block_shape)
|
||||
|
||||
|
||||
def inplace_fused_experts_fake(
|
||||
@ -980,6 +981,7 @@ def inplace_fused_experts_fake(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
@ -1010,6 +1012,7 @@ def outplace_fused_experts(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
@ -1023,10 +1026,11 @@ def outplace_fused_experts(
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> torch.Tensor:
|
||||
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
|
||||
False, activation, use_fp8_w8a8, use_int8_w8a16,
|
||||
use_int4_w4a16, global_num_experts, expert_map,
|
||||
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
|
||||
a2_scale, block_shape)
|
||||
False, activation, apply_router_weight_on_input,
|
||||
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16,
|
||||
global_num_experts, expert_map, w1_scale,
|
||||
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
|
||||
block_shape)
|
||||
|
||||
|
||||
def outplace_fused_experts_fake(
|
||||
@ -1084,6 +1088,7 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
@ -1099,6 +1104,7 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
allow_deep_gemm: bool = False) -> torch.Tensor:
|
||||
if (allow_deep_gemm and use_fp8_w8a8
|
||||
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
|
||||
assert apply_router_weight_on_input is False
|
||||
return deep_gemm_moe_fp8(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
@ -1122,6 +1128,7 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
@ -1143,6 +1150,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
@ -1270,7 +1278,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False,
|
||||
apply_router_weight_on_input,
|
||||
top_k_num,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
@ -1307,7 +1315,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
True,
|
||||
not apply_router_weight_on_input,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
|
@ -65,7 +65,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -156,22 +158,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
return self.forward(x=x,
|
||||
layer=layer,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
activation=activation)
|
||||
return self.forward(
|
||||
x=x,
|
||||
layer=layer,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
@ -188,6 +193,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
@ -202,15 +208,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
@ -228,9 +236,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
assert activation == "silu", f"{activation} is not supported."
|
||||
assert apply_router_weight_on_input is False
|
||||
return layer.ipex_fusion(
|
||||
x,
|
||||
use_grouped_topk,
|
||||
@ -259,6 +269,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
assert not use_grouped_topk
|
||||
@ -266,6 +277,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
assert topk_group is None
|
||||
assert custom_routing_function is None
|
||||
assert layer is not None
|
||||
assert apply_router_weight_on_input is False
|
||||
if scoring_func != "softmax":
|
||||
raise NotImplementedError(
|
||||
"Only softmax scoring function is supported for HPU.")
|
||||
@ -290,12 +302,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
assert not use_grouped_topk
|
||||
assert num_expert_group is None
|
||||
assert topk_group is None
|
||||
assert custom_routing_function is None
|
||||
assert apply_router_weight_on_input is False
|
||||
if scoring_func != "softmax":
|
||||
raise NotImplementedError(
|
||||
"Only softmax scoring function is supported for TPU.")
|
||||
@ -401,6 +415,7 @@ class FusedMoE(torch.nn.Module):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
@ -486,6 +501,7 @@ class FusedMoE(torch.nn.Module):
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
moe_quant_params = {
|
||||
"num_experts": self.local_num_experts,
|
||||
"hidden_size": hidden_size,
|
||||
@ -853,6 +869,7 @@ class FusedMoE(torch.nn.Module):
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
if self.dp_size > 1:
|
||||
|
@ -92,6 +92,7 @@ class RMSNorm(CustomOp):
|
||||
eps: float = 1e-6,
|
||||
var_hidden_size: Optional[int] = None,
|
||||
has_weight: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -100,8 +101,10 @@ class RMSNorm(CustomOp):
|
||||
self.variance_size_override = (None if var_hidden_size == hidden_size
|
||||
else var_hidden_size)
|
||||
self.has_weight = has_weight
|
||||
|
||||
self.weight = torch.ones(hidden_size)
|
||||
if dtype is not None:
|
||||
self.weight = torch.ones(hidden_size, dtype=dtype)
|
||||
else:
|
||||
self.weight = torch.ones(hidden_size)
|
||||
if self.has_weight:
|
||||
self.weight = nn.Parameter(self.weight)
|
||||
|
||||
|
@ -469,6 +469,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
@ -476,6 +477,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
raise NotImplementedError(
|
||||
"Expert Parallelism is not supported for "
|
||||
"fused Marlin MoE method.")
|
||||
if apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for"
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
|
@ -224,6 +224,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
@ -240,20 +241,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
return fused_experts(x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=True,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale)
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=True,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale)
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
||||
@ -438,6 +441,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
|
||||
@ -474,6 +478,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
out_dtype=x.dtype,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
@ -778,6 +783,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
@ -785,6 +791,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
raise NotImplementedError(
|
||||
"Expert Parallelism is not supported for "
|
||||
"fused Marlin MoE method.")
|
||||
if apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for "
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
|
@ -113,6 +113,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
@ -129,18 +130,20 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
return fused_experts(x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
use_int8_w8a16=True,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_scale,
|
||||
w2_scale=layer.w2_scale)
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
use_int8_w8a16=True,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_scale,
|
||||
w2_scale=layer.w2_scale)
|
||||
|
||||
@staticmethod
|
||||
def quantizing_weight_loader(layer, weight_loader):
|
||||
|
@ -773,6 +773,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
@ -800,6 +801,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
activation=activation,
|
||||
use_fp8_w8a8=True,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
w1_scale=(layer.w13_weight_scale_inv
|
||||
if self.block_quant else layer.w13_weight_scale),
|
||||
|
@ -338,9 +338,15 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
):
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
if apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for"
|
||||
"fused GGUF MoE method.")
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
|
@ -592,9 +592,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
if apply_router_weight_on_input is not None:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for"
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
# The input must currently be float16
|
||||
orig_dtype = x.dtype
|
||||
|
@ -293,6 +293,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
@ -312,21 +313,23 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
has_zp = self.quant_config.has_zp
|
||||
|
||||
return fused_experts(x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
use_int4_w4a16=weight_bits == 4,
|
||||
use_int8_w8a16=weight_bits == 8,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_scales,
|
||||
w2_scale=layer.w2_scales,
|
||||
w1_zp=layer.w13_qzeros if has_zp else None,
|
||||
w2_zp=layer.w2_qzeros if has_zp else None,
|
||||
block_shape=[0, layer.group_size])
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
use_int4_w4a16=weight_bits == 4,
|
||||
use_int8_w8a16=weight_bits == 8,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_scales,
|
||||
w2_scale=layer.w2_scales,
|
||||
w1_zp=layer.w13_qzeros if has_zp else None,
|
||||
w2_zp=layer.w2_qzeros if has_zp else None,
|
||||
block_shape=[0, layer.group_size])
|
||||
|
||||
@staticmethod
|
||||
def get_weight_loader(layer, weight_loader):
|
||||
|
@ -202,6 +202,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
@ -217,16 +219,18 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
return fused_experts(x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=True,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale)
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=True,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale)
|
||||
|
@ -851,6 +851,70 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
|
||||
return new_freqs
|
||||
|
||||
|
||||
class Llama4VisionRotaryEmbedding(RotaryEmbedding):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
inv_freqs = super()._compute_inv_freq(base)
|
||||
inv_freqs = inv_freqs[:(self.rotary_dim // 2)]
|
||||
return inv_freqs
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
|
||||
# self.max_position_embeddings here is number of image patches
|
||||
# i.e. (image_size // patch_size) ** 2
|
||||
num_patches = self.max_position_embeddings
|
||||
img_idx = torch.arange(num_patches,
|
||||
dtype=torch.int32) \
|
||||
.reshape(num_patches, 1)
|
||||
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
|
||||
img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN
|
||||
num_patches_single_dim = int(math.sqrt(num_patches))
|
||||
frequencies_x = img_idx % num_patches_single_dim
|
||||
frequencies_y = img_idx // num_patches_single_dim
|
||||
freqs_x = ((frequencies_x + 1)[..., None] *
|
||||
inv_freq[None, None, :]).repeat_interleave(2, dim=-1)
|
||||
freqs_y = ((frequencies_y + 1)[..., None] *
|
||||
inv_freq[None, None, :]).repeat_interleave(2, dim=-1)
|
||||
freqs = torch.cat([freqs_x, freqs_y],
|
||||
dim=-1).float().contiguous()[..., ::2]
|
||||
freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
|
||||
cache = torch.view_as_complex(
|
||||
torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
|
||||
return cache
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
|
||||
query_ = torch.view_as_complex(query.float().reshape(
|
||||
*query.shape[:-1], -1, 2))
|
||||
key_ = torch.view_as_complex(key.float().reshape(
|
||||
*key.shape[:-1], -1, 2))
|
||||
broadcast_shape = [
|
||||
d if i == 1 or i == (query_.ndim - 1) else 1
|
||||
for i, d in enumerate(query_.shape)
|
||||
]
|
||||
freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
|
||||
query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
|
||||
key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
|
||||
return query_out.type_as(query), key_out.type_as(key)
|
||||
|
||||
|
||||
class MRotaryEmbedding(RotaryEmbedding):
|
||||
"""Rotary Embedding with Multimodal Sections."""
|
||||
|
||||
@ -1130,6 +1194,10 @@ def get_rope(
|
||||
scaling_factor, low_freq_factor,
|
||||
high_freq_factor,
|
||||
original_max_position)
|
||||
elif scaling_type == "mllama4":
|
||||
rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
|
||||
max_position, base,
|
||||
is_neox_style, dtype)
|
||||
elif scaling_type == "default":
|
||||
if "mrope_section" in rope_scaling:
|
||||
rotary_emb = MRotaryEmbedding(
|
||||
|
@ -111,10 +111,12 @@ def _initialize_model(
|
||||
vllm_config: VllmConfig,
|
||||
*,
|
||||
prefix: str = "",
|
||||
model_class: Optional[type[nn.Module]] = None,
|
||||
) -> nn.Module:
|
||||
"""Initialize a model with the given configurations."""
|
||||
model_config = vllm_config.model_config
|
||||
model_class, _ = get_model_architecture(model_config)
|
||||
if model_class is None:
|
||||
model_class, _ = get_model_architecture(model_config)
|
||||
|
||||
if vllm_config.quant_config is not None:
|
||||
configure_quant_config(vllm_config.quant_config, model_class)
|
||||
|
@ -22,7 +22,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -65,6 +65,7 @@ class LlamaMLP(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
prefix: str = "",
|
||||
reduce_results: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
@ -79,6 +80,7 @@ class LlamaMLP(nn.Module):
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
@ -292,7 +294,7 @@ class LlamaModel(nn.Module):
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
|
||||
layer_type: type[nn.Module] = LlamaDecoderLayer):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
@ -466,10 +468,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"ffn_norm": "post_attention_layernorm",
|
||||
"tok_embeddings": "model.embed_tokens",
|
||||
"output": "lm_head",
|
||||
"norm": "model.norm"
|
||||
"norm": "model.norm",
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: type[nn.Module] = LlamaDecoderLayer):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
@ -478,7 +484,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.model = self._init_model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
prefix=maybe_prefix(prefix, "model"),
|
||||
layer_type=layer_type)
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
@ -513,8 +520,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
return LlamaModel(vllm_config=vllm_config, prefix=prefix)
|
||||
def _init_model(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: type[nn.Module] = LlamaDecoderLayer):
|
||||
return LlamaModel(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
layer_type=layer_type)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
531
vllm/model_executor/models/llama4.py
Normal file
531
vllm/model_executor/models/llama4.py
Normal file
@ -0,0 +1,531 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Llama4TextConfig
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
|
||||
from .utils import (AutoWeightsLoader, extract_layer_index,
|
||||
is_pp_missing_parameter)
|
||||
|
||||
|
||||
class Llama4MoE(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def custom_routing_function(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
router_scores, router_indices = torch.topk(gating_output, topk, dim=-1)
|
||||
router_scores = torch.sigmoid(router_scores.float()).to(
|
||||
hidden_states.dtype)
|
||||
return (router_scores, router_indices.to(torch.int32))
|
||||
|
||||
def __init__(self,
|
||||
config: Llama4TextConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.top_k = config.num_experts_per_tok
|
||||
|
||||
intermediate_size_moe = config.intermediate_size
|
||||
self.router = ReplicatedLinear(config.hidden_size,
|
||||
config.num_local_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.router")
|
||||
|
||||
self.experts = FusedMoE(
|
||||
num_experts=config.num_local_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||
intermediate_size=intermediate_size_moe,
|
||||
apply_router_weight_on_input=True,
|
||||
reduce_results=False,
|
||||
renormalize=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts")
|
||||
|
||||
self.shared_expert = LlamaMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size_moe,
|
||||
hidden_act="silu",
|
||||
quant_config=quant_config,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.shared_expert",
|
||||
reduce_results=False, # We need to do scatter before reduce
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
router_logits, _ = self.router(hidden_states)
|
||||
shared_out = self.shared_expert(hidden_states)
|
||||
routed_out = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
experts_out = routed_out + shared_out
|
||||
|
||||
if self.tp_size > 1:
|
||||
experts_out = tensor_model_parallel_all_reduce(experts_out)
|
||||
|
||||
return experts_out
|
||||
|
||||
|
||||
class Llama4Attention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: Llama4TextConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
bias_o_proj: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
self.hidden_size = hidden_size
|
||||
self.no_rope_layers = config.no_rope_layers
|
||||
self.nope = self.no_rope_layers[self.layer_idx] == 0
|
||||
self.use_qk_norm = config.use_qk_norm and not self.nope
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = config.head_dim
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
# TODO: attn_temperature_tuning should be a bool in huggingface
|
||||
self.attn_temperature_tuning = self.nope and \
|
||||
config.attn_temperature_tuning > 0
|
||||
|
||||
self.floor_scale = getattr(config, "floor_scale", 8192.0)
|
||||
self.attn_scale = getattr(config, "attn_scale", 0.1)
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.n_rep = self.num_heads // self.num_kv_heads
|
||||
self.q_norm = RMSNorm(
|
||||
hidden_size=self.q_size,
|
||||
eps=config.rms_norm_eps,
|
||||
has_weight=False,
|
||||
dtype=torch.float32,
|
||||
) if self.use_qk_norm else None
|
||||
self.k_norm = RMSNorm(
|
||||
hidden_size=self.kv_size,
|
||||
eps=config.rms_norm_eps,
|
||||
has_weight=False,
|
||||
dtype=torch.float32,
|
||||
) if self.use_qk_norm else None
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=hidden_size,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.total_num_heads,
|
||||
total_num_kv_heads=self.total_num_kv_heads,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
input_size=self.total_num_heads * self.head_dim,
|
||||
output_size=hidden_size,
|
||||
bias=bias_o_proj,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
is_neox_style = True
|
||||
is_gguf = quant_config and quant_config.get_name() == "gguf"
|
||||
if is_gguf and config.model_type == "llama":
|
||||
is_neox_style = False
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=int(rope_theta),
|
||||
rope_scaling=rope_scaling if rope_scaling != "default" else None,
|
||||
is_neox_style=is_neox_style,
|
||||
) if not self.nope else None
|
||||
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
per_layer_sliding_window=None,
|
||||
use_irope=not self.nope,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
|
||||
floor = torch.floor((positions + 1.0) / self.floor_scale)
|
||||
attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
|
||||
|
||||
return attn_scale.unsqueeze(-1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
if self.rotary_emb is not None:
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
if self.q_norm is not None:
|
||||
q = self.q_norm(q.float()).to(q.dtype)
|
||||
if self.k_norm is not None:
|
||||
k = self.k_norm(k.float()).to(k.dtype)
|
||||
|
||||
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
|
||||
# to NoPE layers, where the inference-time temperature tuning function
|
||||
# is customized to not affect short context
|
||||
# while working at very long context
|
||||
# https://arxiv.org/abs/2501.19399
|
||||
#
|
||||
# We should apply temperature tuning between (after) rotary / QK norm
|
||||
# and (before) attention.
|
||||
if self.attn_temperature_tuning and self.nope:
|
||||
attn_scale = self._get_attn_scale(positions)
|
||||
q = (q * attn_scale).to(q.dtype)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class Llama4DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Llama4TextConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = config.rope_theta
|
||||
rope_scaling = config.rope_scaling
|
||||
max_position_embeddings = config.max_position_embeddings
|
||||
|
||||
self.self_attn = Llama4Attention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
bias=False,
|
||||
bias_o_proj=False,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
is_moe_layer = (self.layer_idx +
|
||||
1) % config.interleave_moe_layer_step == 0
|
||||
if is_moe_layer:
|
||||
self.feed_forward = Llama4MoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
)
|
||||
else:
|
||||
self.feed_forward = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size_mlp,
|
||||
hidden_act="silu",
|
||||
quant_config=quant_config,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(positions=positions,
|
||||
hidden_states=hidden_states)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.feed_forward(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Llama4Model(LlamaModel):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
|
||||
self.num_experts = vllm_config.model_config.hf_config.num_local_experts
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
layer_type=layer_type)
|
||||
|
||||
def load_moe_expert_weights(
|
||||
self,
|
||||
name: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
params_dict: Dict[str, nn.Parameter],
|
||||
loaded_params: Set[str],
|
||||
expert_params_mapping: List[Tuple[str, str, int, str]],
|
||||
fused: bool = True,
|
||||
) -> bool:
|
||||
expert_param_loaded = False
|
||||
if "experts.gate_up_proj" in name:
|
||||
loaded_weight = loaded_weight.chunk(2, dim=-1)
|
||||
for (param_name, weight_name, expert_id,
|
||||
shard_id) in expert_params_mapping:
|
||||
new_loaded_weight = loaded_weight
|
||||
if fused:
|
||||
e_str, _, proj_str, _ = weight_name.split('.')
|
||||
weight_name = f"{e_str}.{proj_str}"
|
||||
param_name = f"{param_name}weight"
|
||||
if weight_name not in name:
|
||||
continue
|
||||
full_param_name = name.replace(weight_name, param_name)
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||
and name not in params_dict):
|
||||
continue
|
||||
param = params_dict[full_param_name]
|
||||
weight_loader = param.weight_loader
|
||||
if fused:
|
||||
if "w13" in full_param_name:
|
||||
shard_idx = 0 if shard_id == "w1" else 1
|
||||
new_loaded_weight = new_loaded_weight[shard_idx]
|
||||
new_loaded_weight = new_loaded_weight.transpose(-1, -2)
|
||||
layer_idx = extract_layer_index(name)
|
||||
# EP mapping
|
||||
expert_map = self.layers[
|
||||
layer_idx].feed_forward.experts.expert_map
|
||||
if expert_map is not None:
|
||||
local_expert_indices = (expert_map != -1) \
|
||||
.nonzero() \
|
||||
.flatten() \
|
||||
.to(new_loaded_weight.device)
|
||||
new_loaded_weight = new_loaded_weight[local_expert_indices]
|
||||
expert_id = local_expert_indices[0].item()
|
||||
else:
|
||||
# TODO: add EP support for non fused weights
|
||||
pass
|
||||
weight_loader(param,
|
||||
new_loaded_weight,
|
||||
full_param_name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id)
|
||||
|
||||
loaded_params.add(full_param_name)
|
||||
expert_param_loaded = True
|
||||
return expert_param_loaded
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
fused_experts_params = False
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.num_experts)
|
||||
expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_up_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="gate_up_proj",
|
||||
num_experts=1)
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
|
||||
fused_experts_params = True
|
||||
expert_params_mapping = expert_params_mapping_fused
|
||||
if (self.quant_config is not None and
|
||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
||||
loaded_weight[0])
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name or "experts" in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
loaded_params.add(name)
|
||||
break
|
||||
else:
|
||||
moe_loaded = self.load_moe_expert_weights(
|
||||
name,
|
||||
loaded_weight,
|
||||
params_dict,
|
||||
loaded_params,
|
||||
expert_params_mapping,
|
||||
fused=fused_experts_params)
|
||||
|
||||
if not moe_loaded:
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Llama4ForCausalLM(LlamaForCausalLM):
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
# Update temperature tuning config from generation config
|
||||
gen_config = vllm_config.model_config.try_get_generation_config()
|
||||
gen_config.update(vllm_config.model_config.override_generation_config)
|
||||
vllm_config.model_config.hf_config.attn_temperature_tuning \
|
||||
= gen_config.get("attn_temperature_tuning", False)
|
||||
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
layer_type=Llama4DecoderLayer)
|
||||
|
||||
def _init_model(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
|
||||
return Llama4Model(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
layer_type=layer_type)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
weights = [
|
||||
self.permute_qk_weight_for_rotary(name, loaded_weight)
|
||||
for name, loaded_weight in weights
|
||||
]
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def permute_qk_weight_for_rotary(
|
||||
self,
|
||||
name: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
) -> Tuple[str, torch.Tensor]:
|
||||
|
||||
def permute(w: torch.Tensor, n_heads: int):
|
||||
attn_in = self.config.head_dim * n_heads
|
||||
attn_out = self.config.hidden_size
|
||||
|
||||
return w.view(n_heads, attn_in // n_heads // 2, 2,
|
||||
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
|
||||
|
||||
modules = name.split(".")
|
||||
|
||||
# rotary embeds should be sliced
|
||||
if ("wk" in modules or "k_proj" in modules) \
|
||||
and modules[-1] == "weight":
|
||||
loaded_weight = permute(loaded_weight,
|
||||
self.config.num_key_value_heads)
|
||||
elif ("wq" in modules or "q_proj" in modules) \
|
||||
and modules[-1] == "weight":
|
||||
loaded_weight = permute(loaded_weight,
|
||||
self.config.num_attention_heads)
|
||||
|
||||
return name, loaded_weight
|
895
vllm/model_executor/models/mllama4.py
Normal file
895
vllm/model_executor/models/mllama4.py
Normal file
@ -0,0 +1,895 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping
|
||||
from functools import cached_property
|
||||
from itertools import tee
|
||||
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import BatchFeature, Llama4Config, Llama4VisionConfig
|
||||
from transformers.image_utils import SizeDict
|
||||
from transformers.models.llama4 import Llama4Processor
|
||||
from transformers.models.llama4.image_processing_llama4_fast import (
|
||||
find_supported_resolutions, get_best_fit)
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.model_loader.loader import _initialize_model
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .llama4 import Llama4ForCausalLM
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .vision import scatter_patch_features, select_patch_features
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Llama4ImagePatchInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
flat_data: torch.Tensor
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * num_chunks, num_channels, image size, image size)`
|
||||
"""
|
||||
patches_per_image: torch.Tensor
|
||||
"""
|
||||
The number of total patches for each image in the batch.
|
||||
|
||||
This is used to split the embeddings which has the first two dimensions
|
||||
flattened just like `flat_data`.
|
||||
"""
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
"""
|
||||
aspect_ratios: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A list of aspect ratios corresponding to the number of tiles
|
||||
in each dimension that each image in the batch corresponds to.
|
||||
|
||||
Shape:
|
||||
`(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)`
|
||||
"""
|
||||
|
||||
|
||||
class Llama4VisionMLP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
input_size: int,
|
||||
intermediate_size: int,
|
||||
output_size: int,
|
||||
bias: bool,
|
||||
output_activation: bool,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
input_size=input_size,
|
||||
output_size=intermediate_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc1",
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
input_size=intermediate_size,
|
||||
output_size=output_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc2",
|
||||
)
|
||||
self.activation_fn = nn.GELU()
|
||||
self.output_activation = output_activation
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states, _ = self.fc2(hidden_states)
|
||||
if self.output_activation:
|
||||
return self.activation_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Llama4MultiModalProjector(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.linear_1 = ColumnParallelLinear(
|
||||
input_size=config.vision_config.vision_output_dim,
|
||||
output_size=config.text_config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
gather_output=True,
|
||||
prefix=f"{prefix}.linear_1",
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states, _ = self.linear_1(image_features)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def pixel_shuffle(input_tensor, shuffle_ratio):
|
||||
# input_tensor: [batch_size, num_patches, channels]
|
||||
batch_size, num_patches, channels = input_tensor.shape
|
||||
patch_size = int(math.sqrt(num_patches))
|
||||
|
||||
input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
|
||||
batch_size, height, width, channels = input_tensor.size()
|
||||
|
||||
reshaped_tensor = input_tensor.view(batch_size, height,
|
||||
int(width * shuffle_ratio),
|
||||
int(channels / shuffle_ratio))
|
||||
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
reshaped_tensor = reshaped_tensor.view(batch_size,
|
||||
int(height * shuffle_ratio),
|
||||
int(width * shuffle_ratio),
|
||||
int(channels / (shuffle_ratio**2)))
|
||||
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
output_tensor = reshaped_tensor.view(batch_size, -1,
|
||||
reshaped_tensor.shape[-1])
|
||||
return output_tensor
|
||||
|
||||
|
||||
class Llama4VisionPixelShuffleMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
|
||||
self.inner_dim = int(config.projector_input_dim //
|
||||
(self.pixel_shuffle_ratio**2))
|
||||
self.output_dim = config.projector_output_dim
|
||||
self.mlp = Llama4VisionMLP(
|
||||
input_size=config.intermediate_size,
|
||||
intermediate_size=config.projector_input_dim,
|
||||
output_size=config.projector_output_dim,
|
||||
bias=config.multi_modal_projector_bias,
|
||||
output_activation=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
|
||||
def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
|
||||
encoded_patches = pixel_shuffle(encoded_patches,
|
||||
self.pixel_shuffle_ratio)
|
||||
return self.mlp(encoded_patches)
|
||||
|
||||
|
||||
class Llama4VisionAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Llama4VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = config.hidden_size // self.num_heads
|
||||
assert self.num_heads % self.tp_size == 0
|
||||
self.num_local_heads = self.num_heads // self.tp_size
|
||||
self.q_size = self.num_local_heads * self.head_dim
|
||||
self.kv_size = self.num_local_heads * self.head_dim
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
|
||||
self.scaling)
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.embed_dim,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.num_heads * self.head_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
head_size=self.head_dim,
|
||||
rotary_dim=config.hidden_size // config.num_attention_heads // 2,
|
||||
# number of image patches
|
||||
max_position=(config.image_size // config.patch_size)**2,
|
||||
base=config.rope_theta,
|
||||
rope_scaling={"rope_type": "mllama4"},
|
||||
is_neox_style=False,
|
||||
dtype=torch.complex64, # important
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
q = q.view(q.shape[0], q.shape[1], self.num_local_heads, self.head_dim)
|
||||
k = k.view(k.shape[0], k.shape[1], self.num_local_heads, self.head_dim)
|
||||
q, k = self.rotary_emb(q, k)
|
||||
|
||||
q = q.view(q.shape[0], q.shape[1], -1)
|
||||
k = k.view(k.shape[0], k.shape[1], -1)
|
||||
|
||||
attn_output = self.attn(q, k, v)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output, _ = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class Llama4VisionEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Llama4VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.intermediate_size = config.intermediate_size
|
||||
|
||||
self.self_attn = Llama4VisionAttention(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.mlp = Llama4VisionMLP(input_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
output_size=config.hidden_size,
|
||||
bias=True,
|
||||
output_activation=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size)
|
||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_state: torch.Tensor,
|
||||
):
|
||||
# Self Attention
|
||||
residual = hidden_state
|
||||
hidden_state = self.input_layernorm(hidden_state)
|
||||
hidden_state = self.self_attn(hidden_state)
|
||||
hidden_state = residual + hidden_state
|
||||
|
||||
# Feed forward
|
||||
residual = hidden_state
|
||||
hidden_state = self.post_attention_layernorm(hidden_state)
|
||||
hidden_state = self.mlp(hidden_state)
|
||||
hidden_state = residual + hidden_state
|
||||
|
||||
outputs = (hidden_state, )
|
||||
return outputs
|
||||
|
||||
|
||||
class Llama4VisionEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Llama4VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList([
|
||||
Llama4VisionEncoderLayer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
) for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
inputs_embeds (`torch.FloatTensor` of shape
|
||||
`(batch_size, sequence_length, hidden_size)`):
|
||||
Optionally, instead of passing `input_ids` you can choose to
|
||||
directly pass an embedded representation. This is useful if you
|
||||
want more control over how to convert `input_ids` indices into
|
||||
associated vectors than the model's internal embedding
|
||||
lookup matrix.
|
||||
"""
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
layer_outputs = encoder_layer(hidden_states)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Llama4UnfoldConvolution(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: Llama4VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
kernel_size = config.patch_size
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
self.unfold = torch.nn.Unfold(kernel_size=kernel_size,
|
||||
stride=config.patch_size)
|
||||
self.linear = ColumnParallelLinear(config.num_channels *
|
||||
kernel_size[0] * kernel_size[1],
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
gather_output=True,
|
||||
prefix=f"{prefix}.linear")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.unfold(hidden_states)
|
||||
hidden_states = hidden_states.permute(0, 2, 1)
|
||||
hidden_states, _ = self.linear(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Llama4VisionModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Llama4VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_channels = config.num_channels
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size)**2 + 1
|
||||
self.scale = config.hidden_size**-0.5
|
||||
|
||||
self.patch_embedding = Llama4UnfoldConvolution(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.patch_embedding")
|
||||
|
||||
self.class_embedding = nn.Parameter(self.scale *
|
||||
torch.randn(self.hidden_size))
|
||||
self.positional_embedding_vlm = nn.Parameter(
|
||||
self.scale * torch.randn(self.num_patches, self.hidden_size))
|
||||
|
||||
# layer norms
|
||||
self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5)
|
||||
self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)
|
||||
|
||||
# encoders
|
||||
self.model = Llama4VisionEncoder(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.model")
|
||||
self.vision_adapter = Llama4VisionPixelShuffleMLP(
|
||||
config, quant_config, prefix=f"{prefix}.vision_adapter")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
images_flattened: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Patch embedding
|
||||
hidden_state = self.patch_embedding(images_flattened)
|
||||
num_tiles, num_patches, hidden_dim = hidden_state.shape
|
||||
|
||||
# Add cls token
|
||||
class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1,
|
||||
hidden_state.shape[-1])
|
||||
hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
|
||||
num_patches += 1
|
||||
|
||||
# Position embeddings
|
||||
hidden_state = hidden_state.reshape(
|
||||
num_tiles,
|
||||
1,
|
||||
num_patches,
|
||||
hidden_dim,
|
||||
)
|
||||
positional_embedding = self.positional_embedding_vlm.to(
|
||||
dtype=hidden_state.dtype, device=hidden_state.device)
|
||||
hidden_state = hidden_state + positional_embedding
|
||||
hidden_state = self.layernorm_pre(hidden_state)
|
||||
hidden_state = hidden_state.view(num_tiles, -1, hidden_dim)
|
||||
|
||||
# Apply encoder
|
||||
hidden_state = self.model(hidden_state)
|
||||
hidden_state = self.layernorm_post(hidden_state)
|
||||
|
||||
# Remove CLS token output
|
||||
hidden_state = hidden_state[:, :-1, :]
|
||||
|
||||
# now, we use Llama4VisionPixelShuffle + mlp to project embeddings
|
||||
hidden_state = self.vision_adapter(hidden_state)
|
||||
|
||||
return hidden_state
|
||||
|
||||
|
||||
class Mllama4ProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def __init__(self, ctx: InputProcessingContext) -> None:
|
||||
super().__init__(ctx)
|
||||
|
||||
def get_hf_config(self) -> Llama4Config:
|
||||
return self.ctx.get_hf_config(Llama4Config)
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> Llama4Processor:
|
||||
return self.ctx.get_hf_processor(Llama4Processor,
|
||||
use_fast=True,
|
||||
**kwargs)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": 10}
|
||||
|
||||
@staticmethod
|
||||
def get_patch_per_chunk(vision_config: Llama4VisionConfig) -> int:
|
||||
image_size = vision_config.image_size
|
||||
patch_size = vision_config.patch_size
|
||||
|
||||
assert (
|
||||
image_size %
|
||||
patch_size == 0), f"chunk size {image_size} should be multiple of "
|
||||
f"patch_size {patch_size}"
|
||||
|
||||
ds_ratio = int(round(1.0 / (vision_config.pixel_shuffle_ratio**2)))
|
||||
return (image_size // patch_size)**2 // ds_ratio
|
||||
|
||||
def get_max_num_tiles(self) -> int:
|
||||
image_processor = self.get_hf_processor().image_processor
|
||||
return image_processor.max_patches
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
vision_config = self.get_hf_config().vision_config
|
||||
# image_start + local tiles * (patches + 1 x separator) +
|
||||
# 1 global tile * (image x 1 + patches) + image_end
|
||||
token_per_chunk = self.get_patch_per_chunk(vision_config) + 1
|
||||
mm_max_tokens = (self.get_max_num_tiles() + 1) * token_per_chunk + 2
|
||||
return {"image": mm_max_tokens}
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
vision_config = self.get_hf_config().vision_config
|
||||
image_size = vision_config.image_size
|
||||
# Result in the max possible feature size (h:w = 16:1)
|
||||
return ImageSize(height=self.get_max_num_tiles() * image_size,
|
||||
width=image_size)
|
||||
|
||||
|
||||
class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
|
||||
):
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
|
||||
if mm_data is None:
|
||||
return tokenizer(prompt, add_special_tokens=False) # exclude bos
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
processor = self.info.get_hf_processor(**mm_kwargs)
|
||||
image_processor = processor.image_processor
|
||||
vision_config = self.info.get_hf_config().vision_config
|
||||
|
||||
if processed_outputs.get("pixel_values") is not None:
|
||||
assert "images" in mm_data, \
|
||||
"images expected to be in mm_data when pixel_values is present"
|
||||
|
||||
images = mm_data["images"]
|
||||
parsed_images = (self._get_data_parser().parse_mm_data({
|
||||
"image":
|
||||
images
|
||||
}).get_items("image", ImageProcessorItems))
|
||||
|
||||
tile_size = vision_config.image_size
|
||||
possible_resolutions = find_supported_resolutions(
|
||||
max_num_chunks=self.info.get_max_num_tiles(),
|
||||
patch_size=SizeDict(height=tile_size, width=tile_size),
|
||||
)
|
||||
best_fit_sizes = [
|
||||
get_best_fit(
|
||||
(image.size[1], image.size[0]),
|
||||
torch.tensor(possible_resolutions),
|
||||
resize_to_max_canvas=image_processor.resize_to_max_canvas)
|
||||
for image in parsed_images
|
||||
]
|
||||
# TODO tile height/width do not necessarily need to match
|
||||
aspect_ratios = [(image_size[0] // tile_size,
|
||||
image_size[1] // tile_size)
|
||||
for image_size in best_fit_sizes]
|
||||
patches_per_image = [
|
||||
1 if r_h * r_w == 1 else 1 + r_h * r_w
|
||||
for (r_h, r_w) in aspect_ratios
|
||||
]
|
||||
|
||||
# embed_is_patch should have one feature per image-related token:
|
||||
# <|image_start|>, <|tile_*_separator|>, <|image|>, <|image_end|>
|
||||
# -> False
|
||||
# <|patch|> -> True
|
||||
# embed_is_patch has no entries corresponding to non-image-related
|
||||
# tokens.
|
||||
patch_id = tokenizer.get_vocab()[processor.img_patch_token]
|
||||
num_patches_per_chunk = self.info.get_patch_per_chunk(
|
||||
vision_config)
|
||||
expanded_image_tokens_list = [
|
||||
processor._prompt_split_image(aspect_ratio,
|
||||
num_patches_per_chunk)
|
||||
for aspect_ratio in aspect_ratios
|
||||
]
|
||||
expanded_image_token_ids = [
|
||||
tokenizer.encode(image_tokens, add_special_tokens=False)
|
||||
for image_tokens in expanded_image_tokens_list
|
||||
]
|
||||
embed_is_patch = [
|
||||
torch.tensor(tokens) == patch_id
|
||||
for tokens in expanded_image_token_ids
|
||||
]
|
||||
|
||||
processed_outputs["aspect_ratios"] = aspect_ratios
|
||||
processed_outputs["patches_per_image"] = torch.tensor(
|
||||
patches_per_image)
|
||||
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", patches_per_image),
|
||||
patches_per_image=MultiModalFieldConfig.batched("image"),
|
||||
aspect_ratios=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> List[PromptUpdate]:
|
||||
assert (
|
||||
mm_items.get_count("image", strict=False) == 0
|
||||
or "aspect_ratios" in out_mm_kwargs
|
||||
), "Transformers expect to include aspect_ratios in out_mm_kwargs"
|
||||
|
||||
config = self.info.get_hf_config()
|
||||
vision_config = config.vision_config
|
||||
|
||||
num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config)
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
image_token = hf_processor.image_token
|
||||
|
||||
def get_replacement(item_idx: int):
|
||||
aspect_ratio = out_mm_kwargs["aspect_ratios"][item_idx]
|
||||
return hf_processor._prompt_split_image(
|
||||
aspect_ratio=aspect_ratio,
|
||||
num_patches_per_chunk=num_patches_per_chunk)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=image_token,
|
||||
replacement=get_replacement,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
(target_width,
|
||||
target_height) = self.info.get_image_size_with_most_features()
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
image_token = self.info.get_hf_processor().fake_image_token
|
||||
return ProcessorInputs(
|
||||
prompt_text=image_token * num_images,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Mllama4MultiModalProcessor,
|
||||
info=Mllama4ProcessingInfo,
|
||||
dummy_inputs=Mllama4DummyInputsBuilder,
|
||||
)
|
||||
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.vision_model = Llama4VisionModel(config.vision_config,
|
||||
None,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "vision_model"))
|
||||
self.multi_modal_projector = Llama4MultiModalProjector(
|
||||
self.config,
|
||||
None,
|
||||
prefix=maybe_prefix(prefix, "multi_modal_projector"))
|
||||
|
||||
self.language_model = _initialize_model(
|
||||
vllm_config=vllm_config.with_hf_config(config.text_config),
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
model_class=Llama4ForCausalLM,
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
@cached_property
|
||||
def sampler(self):
|
||||
if hasattr(self.language_model, "sampler"):
|
||||
return self.language_model.sampler
|
||||
|
||||
return get_sampler()
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]:
|
||||
# num_images, 1, num_chunks, channel, image_size, image_size
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
if pixel_values is None:
|
||||
return None
|
||||
|
||||
# num_images x num_chunks, channel, image_size, image_size
|
||||
# TODO: confirm handling for variable lengths
|
||||
flat_pixel_values = flatten_bn(pixel_values, concat=True)
|
||||
patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
|
||||
|
||||
embed_is_patch = kwargs.pop("embed_is_patch", None)
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
aspect_ratios = kwargs.pop("aspect_ratios", None)
|
||||
if not isinstance(aspect_ratios, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of aspect_ratios. "
|
||||
f"Got type: {type(aspect_ratios)}")
|
||||
|
||||
return Llama4ImagePatchInputs(
|
||||
type="pixel_values",
|
||||
flat_data=flat_pixel_values,
|
||||
patches_per_image=patches_per_image,
|
||||
embed_is_patch=embed_is_patch,
|
||||
aspect_ratios=aspect_ratios,
|
||||
)
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
|
||||
flat_data = image_input["flat_data"]
|
||||
patches_per_image = image_input["patches_per_image"].tolist()
|
||||
vision_embeddings_flat = self.vision_model(flat_data)
|
||||
return vision_embeddings_flat.split(patches_per_image, dim=0)
|
||||
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs) -> Optional[MultiModalEmbeddings]:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
# num_images x [num_chunks, num_patches, hidden_dim]
|
||||
image_features = self._process_image_input(image_input)
|
||||
# num_images x [num_chunks x num_patches, hidden_dim]
|
||||
image_features_flat = [img.flatten(0, 1) for img in image_features]
|
||||
# num_images x [1, input_len] -> num_images x [input_len]
|
||||
embed_is_patch_flat = [
|
||||
is_patch.flatten(0, 1)
|
||||
for is_patch in image_input["embed_is_patch"]
|
||||
]
|
||||
|
||||
return scatter_patch_features(
|
||||
image_features_flat,
|
||||
embed_is_patch_flat,
|
||||
)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[NestedTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
|
||||
if multimodal_embeddings is not None:
|
||||
multimodal_embeddings = torch.cat(multimodal_embeddings)
|
||||
mm_embeddings = self.multi_modal_projector(multimodal_embeddings)
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, select_patch_features(mm_embeddings),
|
||||
self.config.image_token_index)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner,
|
||||
# this condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
input_ids = None
|
||||
|
||||
return self.language_model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.language_model.compute_logits(hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
def sample(self, logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
|
||||
return self.language_model.sample(logits, sampling_metadata)
|
||||
|
||||
def separate_weights(
|
||||
self,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
prefix: str,
|
||||
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[
|
||||
str, torch.Tensor]]]:
|
||||
weights1, weights2 = tee(weights, 2)
|
||||
|
||||
def get_prefix_weights() -> Iterable[Tuple[str, torch.Tensor]]:
|
||||
for name, data in weights1:
|
||||
if name.startswith(prefix):
|
||||
yield (name, data)
|
||||
|
||||
def get_other_weights() -> Iterable[Tuple[str, torch.Tensor]]:
|
||||
for name, data in weights2:
|
||||
if not name.startswith(prefix):
|
||||
yield (name, data)
|
||||
|
||||
return get_prefix_weights(), get_other_weights()
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
|
||||
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
|
||||
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
updated_params: Set[str] = set()
|
||||
|
||||
# language_model is an Llama4ForCausalLM instance. We load it's
|
||||
# using llama4's load_weights routine.
|
||||
language_model_weights, other_weights = self.separate_weights(
|
||||
weights, prefix="language_model.model.")
|
||||
loader = AutoWeightsLoader(self)
|
||||
loaded_language_model_params = loader.load_weights(
|
||||
language_model_weights)
|
||||
assert loaded_language_model_params is not None
|
||||
updated_params.update(loaded_language_model_params)
|
||||
|
||||
for name, loaded_weight in other_weights:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
updated_params.add(name)
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
weight_loader(param, loaded_weight)
|
||||
updated_params.add(name)
|
||||
return updated_params
|
@ -196,6 +196,7 @@ _MULTIMODAL_MODELS = {
|
||||
# [Encoder-decoder]
|
||||
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
|
||||
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
|
||||
"Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501
|
||||
"SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
|
||||
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
|
||||
}
|
||||
|
@ -19,9 +19,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Iterable, Set, Tuple, Type
|
||||
from typing import Iterable, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@ -124,7 +125,7 @@ class TeleChat2ForCausalLM(LlamaForCausalLM):
|
||||
def _init_model(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
|
||||
layer_type: type[nn.Module] = LlamaDecoderLayer):
|
||||
return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
|
@ -22,9 +22,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@ -39,7 +38,7 @@ class TeleFLMModel(LlamaModel):
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer,
|
||||
layer_type: type[nn.Module] = LlamaDecoderLayer,
|
||||
):
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
|
@ -96,6 +96,183 @@ class FlashAttentionMetadata:
|
||||
# For logging.
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
|
||||
# for local attention
|
||||
@dataclass
|
||||
class LocalAttentionMetadata:
|
||||
local_query_start_loc: torch.Tensor
|
||||
local_seqused_k: torch.Tensor
|
||||
local_block_table: torch.Tensor
|
||||
local_max_query_len: int
|
||||
local_max_seq_len: int
|
||||
|
||||
local_attn_metadata: Optional[LocalAttentionMetadata] = None
|
||||
|
||||
|
||||
#
|
||||
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
|
||||
# local attention blocks, where each block is passed to the attention kernel
|
||||
# as an independent local ("virtual") batch item.
|
||||
#
|
||||
# For example, if are performing a chunked prefill a batch of 3 sequences:
|
||||
# q_seqlens = [4, 10, 5]
|
||||
# kv_seqlens = [6, 17, 9]
|
||||
# Then normally for regular attention we would compute with an attention mask
|
||||
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
|
||||
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
|
||||
# k_toks > 0 1 2 3 4 5
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# 2 | 1 1 1 1 1
|
||||
# 3 | 1 1 1 1 1 1
|
||||
#
|
||||
# for local attention (with attn_chunk_size = 4) we would compute with an
|
||||
# attention mask like:
|
||||
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
|
||||
# k_toks > 0 1 2 3 4 5
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# 2 | 1
|
||||
# 3 | 1 1
|
||||
#
|
||||
# We can simulate this mask using standard flash-attention by breaking the
|
||||
# sequences into local ("virtual") batches, where each local batch item is a
|
||||
# local attention block, so in this case batch idx 0 would be broken up into:
|
||||
#
|
||||
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
|
||||
# k_toks > 0 1 2 3
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
|
||||
# k_toks > 4 5
|
||||
# q_toks v _____________
|
||||
# 2 | 1
|
||||
# 3 | 1 1
|
||||
#
|
||||
# e.g. if we have:
|
||||
# attn_chunk_size = 4
|
||||
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
|
||||
# Then this function would return:
|
||||
# __b0__ ______b1______ __b2__ < orig batch indices
|
||||
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
|
||||
# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
|
||||
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
|
||||
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
|
||||
def make_local_attention_virtual_batches(
|
||||
attn_chunk_size: int,
|
||||
query_start_loc_np: np.ndarray,
|
||||
seq_lens_np: np.ndarray,
|
||||
block_table: torch.Tensor,
|
||||
page_size: int = 0,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
|
||||
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
|
||||
actual_batch_size = seq_lens_np.shape[0]
|
||||
|
||||
# Handle if we are starting in the middle of a local attention block,
|
||||
# we assume q_seqlens > 0 (for all elements), for each batch idx we compute
|
||||
# the number of tokens that are not in the first local attention block and
|
||||
# then we can simply use a cdiv for the rest.
|
||||
# For example if we have:
|
||||
# attn_chunk_size = 4
|
||||
# q_seqlens = [4, 10, 5]
|
||||
# k_seqlens = [6, 17, 9]
|
||||
# Then we would get:
|
||||
# new_tokens_in_first_block = [2, 1, 4]
|
||||
# local_blocks = [2, 4, 2]
|
||||
q_tokens_in_first_block = np.minimum(
|
||||
attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
|
||||
q_seqlens).astype(np.int32)
|
||||
tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
|
||||
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
|
||||
attn_chunk_size)
|
||||
|
||||
# Once we know the number of local blocks we can compute the request spans
|
||||
# for each batch idx, we can figure out the number of "virtual" requests we
|
||||
# have to make,
|
||||
# For the above example we would get:
|
||||
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
|
||||
#
|
||||
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
|
||||
# (TODO: max a utility to share this code with _prepare_inputs)
|
||||
# arange step 1. [2, 4, 2] -> [2, 6, 8]
|
||||
cu_num_blocks = np.cumsum(local_blocks)
|
||||
virtual_batches = cu_num_blocks[-1]
|
||||
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
|
||||
block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
|
||||
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
|
||||
arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
|
||||
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
|
||||
rarange = np.repeat(local_blocks, local_blocks) - arange - 1
|
||||
# Then we can compute the seqlens_q_local, handling the fact that the
|
||||
# first and last blocks could be partial
|
||||
seqlens_q_local = \
|
||||
np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
|
||||
# set the first block since this may be a partial block
|
||||
seqlens_q_local[arange == 0] = q_tokens_in_first_block
|
||||
# set the remaining blocks
|
||||
seqlens_q_local[arange > 0] = np.minimum(
|
||||
seqlens_q_local - attn_chunk_size * (arange - 1),
|
||||
attn_chunk_size)[arange > 0]
|
||||
|
||||
# convert from q_seqlens to cu_seqlens_q
|
||||
cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
|
||||
.astype(np.int32)
|
||||
|
||||
# compute the seqlens_k_local,
|
||||
# basically a full local attention block for all but the last block in each
|
||||
# batch
|
||||
# For our example this will be:
|
||||
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
|
||||
seqlens_k_local = np.full(cu_num_blocks[-1],
|
||||
attn_chunk_size,
|
||||
dtype=np.int32)
|
||||
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
|
||||
|
||||
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
|
||||
(rarange * attn_chunk_size + \
|
||||
np.repeat(tokens_in_last_block, local_blocks))
|
||||
# For the example the local attention blocks start at:
|
||||
# _b0_ _____b1_____ _b2_
|
||||
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
|
||||
block_starts = k_seqstarts_absolute // page_size
|
||||
assert attn_chunk_size % page_size == 0, \
|
||||
f"attn_chunk_size {attn_chunk_size} is not " \
|
||||
f"divisible by page_size {page_size}"
|
||||
pages_per_local_batch = attn_chunk_size // page_size
|
||||
|
||||
# Create a block_table for the local attention blocks
|
||||
# For out example if we have a block-table like (assuming page_size=2):
|
||||
# block_table = [
|
||||
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
|
||||
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
|
||||
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
|
||||
# ]
|
||||
# Then for the local batches we would want a block-table like
|
||||
# block_table_local = [
|
||||
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
|
||||
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
|
||||
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
|
||||
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
|
||||
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
|
||||
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
|
||||
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
|
||||
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
|
||||
# ]
|
||||
block_indices= np.broadcast_to(
|
||||
np.arange(pages_per_local_batch, dtype=np.int32),
|
||||
(virtual_batches, pages_per_local_batch)) \
|
||||
+ np.expand_dims(block_starts, axis=1)
|
||||
block_indices = block_indices.flatten()
|
||||
batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
|
||||
local_blocks * pages_per_local_batch)
|
||||
block_table_local = block_table[batch_indices, block_indices]\
|
||||
.view(virtual_batches, -1)
|
||||
|
||||
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
|
||||
block_table_local
|
||||
|
||||
|
||||
class FlashAttentionMetadataBuilder:
|
||||
|
||||
@ -109,18 +286,40 @@ class FlashAttentionMetadataBuilder:
|
||||
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
||||
common_prefix_len: int):
|
||||
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
|
||||
query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
|
||||
self.runner.device, non_blocking=True)
|
||||
seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device,
|
||||
non_blocking=True)
|
||||
query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
|
||||
query_start_loc = query_start_loc_cpu.to(self.runner.device,
|
||||
non_blocking=True)
|
||||
seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
|
||||
seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True)
|
||||
block_table = (
|
||||
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
|
||||
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
self.runner.device, non_blocking=True).long()
|
||||
|
||||
# for local attention
|
||||
local_attn_metadata = None
|
||||
if self.runner.attention_chunk_size is not None:
|
||||
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
|
||||
virt_block_table = make_local_attention_virtual_batches(
|
||||
self.runner.attention_chunk_size,
|
||||
self.runner.query_start_loc_np[:num_reqs + 1],
|
||||
self.runner.seq_lens_np[:num_reqs],
|
||||
block_table,
|
||||
self.runner.block_size,
|
||||
)
|
||||
local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
||||
local_query_start_loc=torch.from_numpy(
|
||||
virt_q_cu_seqlens_np).to(self.runner.device,
|
||||
non_blocking=True),
|
||||
local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True),
|
||||
local_block_table=virt_block_table,
|
||||
local_max_query_len=seqlens_q_local_np.max(),
|
||||
local_max_seq_len=virt_k_seqlens_np.max(),
|
||||
)
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
if use_cascade:
|
||||
# TODO: Optimize.
|
||||
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
@ -149,6 +348,7 @@ class FlashAttentionMetadataBuilder:
|
||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||
prefix_kv_lens=prefix_kv_lens,
|
||||
suffix_kv_lens=suffix_kv_lens,
|
||||
local_attn_metadata=local_attn_metadata,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
@ -167,6 +367,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
@ -203,6 +404,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashAttentionImpl")
|
||||
self.use_irope = use_irope
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype) \
|
||||
and not flash_attn_supports_fp8():
|
||||
@ -265,8 +467,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
|
||||
key.shape[1])
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(torch.float8_e4m3fn)
|
||||
value_cache = value_cache.view(torch.float8_e4m3fn)
|
||||
@ -278,22 +479,41 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
query = query.reshape((num_tokens, num_heads, head_size))
|
||||
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
if not attn_metadata.use_cascade:
|
||||
# Regular attention (common case).
|
||||
use_local_attn = \
|
||||
(self.use_irope and attn_metadata.local_attn_metadata is not None)
|
||||
|
||||
if not attn_metadata.use_cascade or use_local_attn:
|
||||
if use_local_attn:
|
||||
assert attn_metadata.local_attn_metadata is not None
|
||||
local_metadata = attn_metadata.local_attn_metadata
|
||||
cu_seqlens_q = local_metadata.local_query_start_loc
|
||||
seqused_k = local_metadata.local_seqused_k
|
||||
max_seqlen_q = local_metadata.local_max_query_len
|
||||
max_seqlen_k = local_metadata.local_max_seq_len
|
||||
block_table = local_metadata.local_block_table
|
||||
else:
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
|
||||
|
||||
flash_attn_varlen_func(
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[:num_actual_tokens],
|
||||
cu_seqlens_q=attn_metadata.query_start_loc,
|
||||
max_seqlen_q=attn_metadata.max_query_len,
|
||||
seqused_k=attn_metadata.seq_lens,
|
||||
max_seqlen_k=attn_metadata.max_seq_len,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
seqused_k=seqused_k,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=attn_metadata.block_table,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
@ -302,6 +522,8 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
return output
|
||||
|
||||
assert not use_local_attn, (
|
||||
"Cascade attention does not support local attention.")
|
||||
# Cascade attention (rare case).
|
||||
cascade_attention(
|
||||
output[:num_actual_tokens],
|
||||
|
@ -70,6 +70,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
@ -86,6 +87,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.use_irope = use_irope
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
@ -156,24 +158,41 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
use_local_attn = \
|
||||
(self.use_irope and attn_metadata.local_attn_metadata is not None)
|
||||
|
||||
if use_local_attn:
|
||||
assert attn_metadata.local_attn_metadata is not None
|
||||
local_metadata = attn_metadata.local_attn_metadata
|
||||
cu_seqlens_q = local_metadata.local_query_start_loc
|
||||
sequesd_k = local_metadata.local_seqused_k
|
||||
max_seqlen_q = local_metadata.local_max_query_len
|
||||
max_seqlen_k = local_metadata.local_max_seq_len
|
||||
block_table = local_metadata.local_block_table
|
||||
else:
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
sequesd_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
chunked_prefill_paged_decode(
|
||||
query=query[:num_actual_tokens],
|
||||
key=key[:num_actual_tokens],
|
||||
value=value[:num_actual_tokens],
|
||||
output=output[:num_actual_tokens],
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
block_table=attn_metadata.block_table,
|
||||
query_start_loc=attn_metadata.query_start_loc,
|
||||
seq_lens=attn_metadata.seq_lens,
|
||||
max_seq_len=attn_metadata.max_seq_len,
|
||||
max_query_len=attn_metadata.max_query_len,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window[0],
|
||||
sm_scale=self.scale)
|
||||
chunked_prefill_paged_decode(query=query[:num_actual_tokens],
|
||||
key=key[:num_actual_tokens],
|
||||
value=value[:num_actual_tokens],
|
||||
output=output[:num_actual_tokens],
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
block_table=block_table,
|
||||
query_start_loc=cu_seqlens_q,
|
||||
seq_lens=sequesd_k,
|
||||
max_seq_len=max_seqlen_k,
|
||||
max_query_len=max_seqlen_q,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window[0],
|
||||
sm_scale=self.scale)
|
||||
|
||||
return output
|
||||
|
@ -113,6 +113,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
|
||||
self.head_size = model_config.get_head_size()
|
||||
self.hidden_size = model_config.get_hidden_size()
|
||||
self.attention_chunk_size = model_config.attention_chunk_size
|
||||
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.head_size,
|
||||
|
Loading…
x
Reference in New Issue
Block a user