From 55dcce91df150f576c28520d987eaf1498fcb0bd Mon Sep 17 00:00:00 2001 From: Lu Fang <30275821+houseroad@users.noreply.github.com> Date: Mon, 7 Apr 2025 08:06:27 -0700 Subject: [PATCH] Upstream Llama4 Support to Main (#16113) Signed-off-by: Aston Zhang <22279212+astonzhang@users.noreply.github.com> Signed-off-by: Chris Thi Signed-off-by: drisspg Signed-off-by: Jon Swenson Signed-off-by: Keyun Tong Signed-off-by: Lu Fang Signed-off-by: Xiaodong Wang Signed-off-by: Yang Chen Signed-off-by: Ye (Charlotte) Qi Signed-off-by: Yong Hoon Shin Signed-off-by: Zijing Liu Signed-off-by: Lu Fang Signed-off-by: Lu Fang Signed-off-by: Lucia Fang Signed-off-by: Roger Wang Signed-off-by: DarkLight1337 Co-authored-by: Lu Fang Co-authored-by: Roger Wang Co-authored-by: DarkLight1337 --- .buildkite/test-pipeline.yaml | 3 +- benchmarks/kernels/benchmark_moe.py | 3 + docs/source/models/supported_models.md | 11 +- examples/offline_inference/audio_language.py | 2 +- examples/offline_inference/vision_language.py | 37 + .../vision_language_multi_image.py | 38 + requirements/common.txt | 2 +- requirements/test.in | 2 +- requirements/test.txt | 2 +- .../audio_language/test_ultravox.py | 14 +- .../vision_language/test_models.py | 37 +- .../vision_language/test_phi3v.py | 9 + .../vision_language/test_pixtral.py | 2 + .../multimodal/processing/test_common.py | 1 + .../multimodal/processing/test_llama4.py | 99 ++ tests/models/registry.py | 6 +- tests/models/test_initialization.py | 17 +- vllm/config.py | 2 + vllm/entrypoints/chat_utils.py | 2 +- ...=1024,device_name=AMD_Instinct_MI300X.json | 200 ++++ .../layers/fused_moe/cutlass_moe.py | 15 +- .../layers/fused_moe/fused_moe.py | 28 +- vllm/model_executor/layers/fused_moe/layer.py | 65 +- vllm/model_executor/layers/layernorm.py | 7 +- .../layers/quantization/awq_marlin.py | 5 + .../compressed_tensors_moe.py | 38 +- .../layers/quantization/experts_int8.py | 27 +- .../model_executor/layers/quantization/fp8.py | 2 + .../layers/quantization/gguf.py | 6 + .../layers/quantization/gptq_marlin.py | 5 + .../layers/quantization/moe_wna16.py | 33 +- .../layers/quantization/quark/quark_moe.py | 30 +- .../model_executor/layers/rotary_embedding.py | 68 ++ vllm/model_executor/model_loader/loader.py | 4 +- vllm/model_executor/models/llama.py | 26 +- vllm/model_executor/models/llama4.py | 531 +++++++++++ vllm/model_executor/models/mllama4.py | 895 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/model_executor/models/telechat2.py | 5 +- vllm/model_executor/models/teleflm.py | 5 +- vllm/v1/attention/backends/flash_attn.py | 250 ++++- vllm/v1/attention/backends/triton_attn.py | 55 +- vllm/v1/worker/gpu_model_runner.py | 1 + 43 files changed, 2436 insertions(+), 155 deletions(-) create mode 100644 tests/models/multimodal/processing/test_llama4.py create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json create mode 100644 vllm/model_executor/models/llama4.py create mode 100644 vllm/model_executor/models/mllama4.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 0b775851..55530d0d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -389,7 +389,8 @@ steps: - pytest -v -s models/test_transformers.py - pytest -v -s models/test_registry.py # V1 Test: https://github.com/vllm-project/vllm/issues/14531 - - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py + - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4' + - VLLM_USE_V1=0 pytest 
-v -s models/test_initialization.py -k 'llama4' - label: Language Models Test (Standard) # 32min #mirror_hardwares: [amd] diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index f1803b39..afe0b530 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -553,6 +553,9 @@ def main(args: argparse.Namespace): intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size else: + if not hasattr(config, "hidden_size"): + # Support for llama4 + config = config.text_config # Default: Mixtral. E = config.num_local_experts topk = config.num_experts_per_tok diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 8b568de7..2fb969ea 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -24,7 +24,7 @@ vLLM also supports model implementations that are available in Transformers. Thi To check if the modeling backend is Transformers, you can simply do this: -```python +```python from vllm import LLM llm = LLM(model=..., task="generate") # Name or path of your model llm.apply_model(lambda model: print(type(model))) @@ -55,7 +55,7 @@ If your model is neither supported natively by vLLM or Transformers, you can sti Simply set `trust_remote_code=True` and vLLM will run any model on the Model Hub that is compatible with Transformers. Provided that the model writer implements their model in a compatible way, this means that you can run new models before they are officially supported in Transformers or vLLM! -```python +```python from vllm import LLM llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of your model llm.apply_model(lambda model: print(model.__class__)) @@ -850,6 +850,13 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ +- * `Llama4ForConditionalGeneration` + * Llama-4-17B-Omni-Instruct + * T + I+ + * `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. 
+ * + * ✅︎ + * ✅︎ - * `LlavaForConditionalGeneration` * LLaVA-1.5 * T + IE+ diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 840892ea..f33efbab 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -47,7 +47,7 @@ def run_minicpmo(question: str, audio_count: int) -> ModelRequestData: model=model_name, trust_remote_code=True, max_model_len=4096, - max_num_seqs=5, + max_num_seqs=2, limit_mm_per_prompt={"audio": audio_count}, ) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index c1115708..61d53dda 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -582,6 +582,42 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData: ) +def run_llama4(questions: list[str], modality: str): + assert modality == "image" + + model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct" + + engine_args = EngineArgs( + model=model_name, + max_model_len=8192, + max_num_seqs=4, + tensor_parallel_size=8, + disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + gpu_memory_utilization=0.4, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + messages = [[{ + "role": + "user", + "content": [{ + "type": "image" + }, { + "type": "text", + "text": f"{question}" + }] + }] for question in questions] + prompts = tokenizer.apply_chat_template(messages, + add_generation_prompt=True, + tokenize=False) + stop_token_ids = None + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + stop_token_ids=stop_token_ids, + ) + + # Molmo def run_molmo(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -907,6 +943,7 @@ model_example_map = { "minicpmv": run_minicpmv, "mistral3": run_mistral3, "mllama": run_mllama, + "llama4": run_llama4, "molmo": run_molmo, "NVLM_D": run_nvlm_d, "paligemma": run_paligemma, diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 39951e5e..e03ebe48 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -253,6 +253,43 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct" + + engine_args = EngineArgs( + model=model_name, + max_model_len=8192, + max_num_seqs=4, + tensor_parallel_size=8, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [{ + "role": + "user", + "content": [ + *placeholders, + { + "type": "text", + "text": question + }, + ], + }] + + processor = AutoProcessor.from_pretrained(model_name) + + prompt = processor.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" @@ -567,6 +604,7 @@ model_example_map = { "h2ovl_chat": load_h2ovl, "idefics3": load_idefics3, "internvl_chat": load_internvl, + "llama4": load_llama4, "mistral3": load_mistral3, "mllama": load_mllama, 
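The two example loaders above (`run_llama4` and `load_llama4`) wire Llama-4 Scout into the offline-inference example scripts. For readers who want the same flow outside the example harness, here is a condensed sketch of single-image offline inference, reusing the settings those loaders choose (8-way tensor parallelism, 8192-token context); the image URL is a placeholder and the GPU count is an assumption of this sketch, not a requirement stated by the patch.

```python
# Minimal offline-inference sketch for Llama-4 Scout, mirroring run_llama4/load_llama4
# above. The image URL is a placeholder; 8 GPUs is an assumption of this sketch.
from transformers import AutoProcessor

from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image

model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
processor = AutoProcessor.from_pretrained(model_name)

messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "What is in this image?"},
    ],
}]
prompt = processor.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

llm = LLM(model=model_name,
          max_model_len=8192,
          max_num_seqs=4,
          tensor_parallel_size=8,
          limit_mm_per_prompt={"image": 1})

image = fetch_image("https://example.com/some_image.jpg")  # placeholder URL
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)
```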
"NVLM_D": load_nvlm_d, diff --git a/requirements/common.txt b/requirements/common.txt index 7365a5b4..24a1e6d6 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -6,7 +6,7 @@ requests >= 2.26.0 tqdm blake3 py-cpuinfo -transformers >= 4.50.3 +transformers >= 4.51.0 huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. diff --git a/requirements/test.in b/requirements/test.in index 364747e9..ac7f451e 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -30,7 +30,7 @@ mistral_common[opencv] >= 1.5.4 # required for pixtral test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.8 # required for model evaluation test -transformers==4.50.3 +transformers==4.51.0 huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads. # quantization bitsandbytes>=0.45.3 diff --git a/requirements/test.txt b/requirements/test.txt index 236b8be3..39d6ed1a 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -645,7 +645,7 @@ tqdm==4.66.6 # transformers tqdm-multiprocess==0.0.11 # via lm-eval -transformers==4.50.3 +transformers==4.51.0 # via # -r requirements/test.in # genai-perf diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index 83ece5d2..a843e41a 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -12,6 +12,7 @@ from vllm.sequence import SampleLogprobs from ....conftest import HfRunner, VllmRunner from ....utils import RemoteOpenAIServer +from ...registry import HF_EXAMPLE_MODELS from ...utils import check_logprobs_close MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" @@ -55,7 +56,10 @@ def server(request, audio_assets): for key, value in request.param.items() ] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + with RemoteOpenAIServer(MODEL_NAME, + args, + env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": + "30"}) as remote_server: yield remote_server @@ -106,6 +110,10 @@ def run_test( **kwargs, ): """Inference result should be the same between hf and vllm.""" + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. 
# if we run HF first, the cuda initialization will be done and it @@ -156,6 +164,10 @@ def run_multi_audio_test( num_logprobs: int, **kwargs, ): + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + with vllm_runner(model, dtype=dtype, enforce_eager=True, diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 3b34f012..9d9e8278 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -160,17 +160,32 @@ VLM_TEST_SETTINGS = { ), "aya_vision": VLMTestInfo( models=["CohereForAI/aya-vision-8b"], - test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + test_type=(VLMTestType.IMAGE), prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501 single_image_prompts=IMAGE_ASSETS.prompts({ "stop_sign": "What's the content in the center of the image?", # noqa: E501 "cherry_blossom": "What is the season?", # noqa: E501 }), multi_image_prompt="Describe the two images in detail.", # noqa: E501 - max_model_len=8192, + max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, - vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}} + vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}}, + ), + "aya_vision-multi_image": VLMTestInfo( + models=["CohereForAI/aya-vision-8b"], + test_type=(VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts({ + "stop_sign": "What's the content in the center of the image?", # noqa: E501 + "cherry_blossom": "What is the season?", # noqa: E501 + }), + multi_image_prompt="Describe the two images in detail.", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + auto_cls=AutoModelForImageTextToText, + vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}}, + marks=[large_gpu_mark(min_gb=32)], ), "blip2": VLMTestInfo( # TODO: Change back to 2.7b once head_dim = 80 is supported @@ -303,6 +318,22 @@ VLM_TEST_SETTINGS = { use_tokenizer_eos=True, patch_hf_runner=model_utils.internvl_patch_hf_runner, ), + "llama4": VLMTestInfo( + models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"], + prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501 + img_idx_to_prompt=lambda _: "<|image|>", + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + distributed_executor_backend="mp", + image_size_factors=[(.25, 0.5, 1.0)], + hf_model_kwargs={"device_map": "auto"}, + max_model_len=8192, + max_num_seqs=4, + dtype="bfloat16", + auto_cls=AutoModelForImageTextToText, + tensor_parallel_size=8, + vllm_runner_kwargs={"gpu_memory_utilization": 0.8}, + marks=multi_gpu_marks(num_gpus=8), + ), "llava_next": VLMTestInfo( models=["llava-hf/llava-v1.6-mistral-7b-hf"], test_type=(VLMTestType.IMAGE, VLMTestType.CUSTOM_INPUTS), diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py index 53b183b2..237d499d 100644 --- a/tests/models/decoder_only/vision_language/test_phi3v.py +++ 
b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -5,7 +5,9 @@ import re from typing import Optional import pytest +from packaging.version import Version from transformers import AutoTokenizer +from transformers import __version__ as TRANSFORMERS_VERSION from vllm.multimodal.image import rescale_image_size from vllm.platforms import current_platform @@ -81,6 +83,13 @@ def run_test( from transformers import AutoImageProcessor # noqa: F401 from transformers import AutoProcessor # noqa: F401 + # Once the model repo is updated to 4.49, we should be able to run the + # test in `test_models.py` without the above workaround + if Version(TRANSFORMERS_VERSION) >= Version("4.49"): + pytest.skip(f"`transformers=={TRANSFORMERS_VERSION}` installed, " + "but `transformers<=4.49` is required to run this model. " + "Reason: Cannot run HF implementation") + # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py index ee619d8d..2f14a8ea 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -176,6 +176,8 @@ def test_chat( model, dtype=dtype, tokenizer_mode="mistral", + load_format="mistral", + config_format="mistral", max_model_len=max_model_len, limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, ) as vllm_model: diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index fdcd7a9e..35334ef1 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -257,6 +257,7 @@ def _test_processing_correctness_mistral( "h2oai/h2ovl-mississippi-800m", "OpenGVLab/InternVL2-1B", "HuggingFaceM4/Idefics3-8B-Llama3", + "meta-llama/Llama-4-Scout-17B-16E-Instruct", "llava-hf/llava-1.5-7b-hf", "llava-hf/llava-v1.6-mistral-7b-hf", "llava-hf/LLaVA-NeXT-Video-7B-hf", diff --git a/tests/models/multimodal/processing/test_llama4.py b/tests/models/multimodal/processing/test_llama4.py new file mode 100644 index 00000000..7ec7c800 --- /dev/null +++ b/tests/models/multimodal/processing/test_llama4.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for Llama4's multimodal preprocessing kwargs.""" + +import pytest + +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.transformers_utils.tokenizer import encode_tokens + +from ....conftest import _ImageAssets +from ...utils import build_model_context + + +@pytest.mark.parametrize("model_id", + ["meta-llama/Llama-4-Scout-17B-16E-Instruct"]) +@pytest.mark.parametrize("mm_processor_kwargs", [{}]) +@pytest.mark.parametrize("num_imgs", [1, 5]) +@pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False]) +@pytest.mark.parametrize("tokenized_prompt", [True, False]) +def test_processor_override( + image_assets: _ImageAssets, + model_id: str, + mm_processor_kwargs: dict, + num_imgs: int, + disable_mm_preprocessor_cache: bool, + tokenized_prompt: bool, +): + """Ensure llama4 processor works properly.""" + ctx = build_model_context( + model_id, + mm_processor_kwargs=mm_processor_kwargs, + limit_mm_per_prompt={"image": num_imgs}, + disable_mm_preprocessor_cache=disable_mm_preprocessor_cache, + ) + processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) + config = 
processor.info.get_hf_config() + tokenizer = processor.info.get_tokenizer() + hf_processor = processor.info.get_hf_processor() + vocab = tokenizer.get_vocab() + + prompt = "<|begin_of_text|><|header_start|>user<|header_end|>" \ + + "<|image|>" * num_imgs \ + + "<|eot|><|header_start|>assistant<|header_end|>" + mm_data = { + "image": [ + image_assets[(i % len(image_assets))].pil_image + for i in range(num_imgs) + ] + } + if tokenized_prompt: + prompt = encode_tokens(tokenizer, prompt) + + processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) + mm_kwargs = processed_inputs["mm_kwargs"] + + # place holder replacements + prompt_token_ids = processed_inputs["prompt_token_ids"] + assert prompt_token_ids.count(config.boi_token_index) == num_imgs + assert prompt_token_ids.count(config.eoi_token_index) == num_imgs + assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs + aspect_ratios = mm_kwargs["aspect_ratios"] + num_x_separators = num_y_separators = 0 + for tiles_y, tiles_x in aspect_ratios: + if tiles_x * tiles_y > 1: + num_x_separators += (tiles_x - 1) * tiles_y + num_y_separators += tiles_y + assert prompt_token_ids.count(vocab[hf_processor.tile_token]) \ + == num_x_separators + assert prompt_token_ids.count(vocab[hf_processor.tile_global_token]) \ + == num_y_separators + + # image token offsets + img_locs = processed_inputs["mm_placeholders"].get("image", []) + assert len(img_locs) == num_imgs + assert [img_loc["offset"] for img_loc in img_locs] == \ + [i for i, v in enumerate(prompt_token_ids) \ + if v == config.boi_token_index] + + # patch sizes and masks + assert prompt_token_ids.count(config.image_token_index) \ + == sum(img_patch.sum() for img_patch in mm_kwargs["embed_is_patch"]) + patch_token_id = vocab[hf_processor.img_patch_token] + num_patches = processed_inputs["prompt_token_ids"].count(patch_token_id) + mm_counts = {"image": num_imgs} + assert num_patches / num_imgs <= \ + processor.info.get_mm_max_tokens_per_item(32768, mm_counts)["image"] + num_patches_per_chunk = processor.info.get_patch_per_chunk( + config.vision_config) + assert prompt_token_ids.count(config.image_token_index) \ + == mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk + assert mm_kwargs["pixel_values"].shape[0] \ + == mm_kwargs["patches_per_image"].sum() + + for embed_is_patch, aspect_ratio in zip(mm_kwargs["embed_is_patch"], + mm_kwargs["aspect_ratios"]): + assert embed_is_patch.shape[0] == \ + len(tokenizer.encode( + hf_processor._prompt_split_image( + aspect_ratio, num_patches_per_chunk), + add_special_tokens=False)) diff --git a/tests/models/registry.py b/tests/models/registry.py index 574b8d9e..e61cbc57 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -287,12 +287,16 @@ _MULTIMODAL_EXAMPLE_MODELS = { trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m", - extras={"2b": "h2oai/h2ovl-mississippi-2b"}), # noqa: E501 + extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501 + max_transformers_version="4.48", # noqa: E501 + transformers_version_reason="HF model is not compatible."), # noqa: E501 "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B", extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501 trust_remote_code=True), "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501 {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501 + "Llama4ForConditionalGeneration": 
_HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 + min_transformers_version="4.51"), "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf", extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501 "mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}), # noqa: E501 diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 58705637..cd2b8f00 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -7,6 +7,8 @@ from transformers import PretrainedConfig from vllm import LLM from vllm.engine.llm_engine import LLMEngine as V0LLMEngine +from vllm.utils import GiB_bytes +from vllm.v1.core.kv_cache_utils import get_kv_cache_config from vllm.v1.engine.core import EngineCore as V1EngineCore from .registry import HF_EXAMPLE_MODELS @@ -42,14 +44,21 @@ def test_can_initialize(model_arch): self.cache_config.num_gpu_blocks = 0 self.cache_config.num_cpu_blocks = 0 - def _initalize_kv_caches_v1(self, vllm_config): - # gpu_blocks (> 0), cpu_blocks - return 1, 0 + def _initialize_kv_caches_v1(self, vllm_config): + kv_cache_specs = self.model_executor.get_kv_cache_specs() + scheduler_kv_cache_config = get_kv_cache_config( + vllm_config, + kv_cache_specs[0], + 20 * GiB_bytes, + ) + + # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config + return 1, 0, scheduler_kv_cache_config with (patch.object(V0LLMEngine, "_initialize_kv_caches", _initialize_kv_caches_v0), patch.object(V1EngineCore, "_initialize_kv_caches", - _initalize_kv_caches_v1)): + _initialize_kv_caches_v1)): LLM( model_info.default, tokenizer=model_info.tokenizer, diff --git a/vllm/config.py b/vllm/config.py index d6f931ca..c232f0f5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -358,6 +358,8 @@ class ModelConfig: self.hf_config = hf_config self.hf_text_config = get_hf_text_config(self.hf_config) + self.attention_chunk_size = getattr(self.hf_text_config, + "attention_chunk_size", None) self.encoder_config = self._get_encoder_config() self.hf_image_processor_config = get_hf_image_processor_config( self.model, hf_token=hf_token, revision=revision) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 9041b92a..d7e8d045 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -500,7 +500,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): "internvl_chat", "skywork_chat", "NVLM_D", "h2ovl_chat", "idefics3"): return "" - if model_type == "mllama": + if model_type in ("mllama", "llama4"): return "<|image|>" if model_type in ("qwen2_vl", "qwen2_5_vl"): return "<|vision_start|><|image_pad|><|vision_end|>" diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 00000000..f10e3948 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + 
"GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index a17afd1b..d6a27aa0 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -23,6 +23,7 @@ def cutlass_moe_fp8( a1_scale: Optional[torch.Tensor] = None, 
a2_scale: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.half, + apply_router_weight_on_input: bool = False, ) -> torch.Tensor: """ This function computes a a8w8-quantized Mixture of Experts (MoE) layer @@ -96,8 +97,14 @@ def cutlass_moe_fp8( n = w2_q.size(1) topk = topk_ids.size(1) + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) + if apply_router_weight_on_input: + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + # TODO: this only works for topK=1, will need to update for topK>1 + a = a * topk_weights.to(out_dtype) a_q, a1_scale = ops.scaled_fp8_quant( a, a1_scale, use_per_token_if_dynamic=per_act_token) @@ -139,6 +146,8 @@ def cutlass_moe_fp8( ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale, expert_offsets[:-1], problem_sizes2, ab_strides2, ab_strides2, c_strides2) - - return (c2[c_map].view(m, topk, k) * - topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) + # Gather tokens + c2 = c2[c_map].view(m, topk, k) + if not apply_router_weight_on_input: + c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype) + return c2.sum(dim=1) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 0817879c..4ab99acb 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -954,6 +954,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", + apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -967,10 +968,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, - activation, use_fp8_w8a8, use_int8_w8a16, - use_int4_w4a16, global_num_experts, expert_map, - w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, - block_shape) + activation, apply_router_weight_on_input, use_fp8_w8a8, + use_int8_w8a16, use_int4_w4a16, global_num_experts, + expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, + a2_scale, block_shape) def inplace_fused_experts_fake( @@ -980,6 +981,7 @@ def inplace_fused_experts_fake( topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", + apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1010,6 +1012,7 @@ def outplace_fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", + apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1023,10 +1026,11 @@ def outplace_fused_experts( a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, - False, activation, use_fp8_w8a8, use_int8_w8a16, - use_int4_w4a16, global_num_experts, expert_map, - w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, - a2_scale, block_shape) + False, activation, apply_router_weight_on_input, + use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, + global_num_experts, expert_map, w1_scale, + w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, + block_shape) def outplace_fused_experts_fake( @@ -1084,6 +1088,7 @@ 
def fused_experts(hidden_states: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, activation: str = "silu", + apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1099,6 +1104,7 @@ def fused_experts(hidden_states: torch.Tensor, allow_deep_gemm: bool = False) -> torch.Tensor: if (allow_deep_gemm and use_fp8_w8a8 and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): + assert apply_router_weight_on_input is False return deep_gemm_moe_fp8( hidden_states=hidden_states, w1=w1, @@ -1122,6 +1128,7 @@ def fused_experts(hidden_states: torch.Tensor, topk_weights=topk_weights, topk_ids=topk_ids, activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, @@ -1143,6 +1150,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, activation: str = "silu", + apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1270,7 +1278,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, sorted_token_ids, expert_ids, num_tokens_post_padded, - False, + apply_router_weight_on_input, top_k_num, config, compute_type=compute_type, @@ -1307,7 +1315,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, sorted_token_ids, expert_ids, num_tokens_post_padded, - True, + not apply_router_weight_on_input, 1, config, compute_type=compute_type, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 661fb52b..0e35d8a8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -65,7 +65,9 @@ class FusedMoEMethodBase(QuantizeMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", ) -> torch.Tensor: raise NotImplementedError @@ -156,22 +158,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: - return self.forward(x=x, - layer=layer, - router_logits=router_logits, - top_k=top_k, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - topk_group=topk_group, - num_expert_group=num_expert_group, - global_num_experts=global_num_experts, - expert_map=expert_map, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - activation=activation) + return self.forward( + x=x, + layer=layer, + router_logits=router_logits, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + global_num_experts=global_num_experts, + expert_map=expert_map, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input) def forward_cuda( self, @@ -188,6 +193,7 @@ class 
UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: topk_weights, topk_ids = FusedMoE.select_experts( @@ -202,15 +208,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - return fused_experts(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map) + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map) def forward_cpu( self, @@ -228,9 +236,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", + apply_router_weight_on_input: bool = False, **kwargs, ): assert activation == "silu", f"{activation} is not supported." + assert apply_router_weight_on_input is False return layer.ipex_fusion( x, use_grouped_topk, @@ -259,6 +269,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: assert not use_grouped_topk @@ -266,6 +277,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): assert topk_group is None assert custom_routing_function is None assert layer is not None + assert apply_router_weight_on_input is False if scoring_func != "softmax": raise NotImplementedError( "Only softmax scoring function is supported for HPU.") @@ -290,12 +302,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: assert not use_grouped_topk assert num_expert_group is None assert topk_group is None assert custom_routing_function is None + assert apply_router_weight_on_input is False if scoring_func != "softmax": raise NotImplementedError( "Only softmax scoring function is supported for TPU.") @@ -401,6 +415,7 @@ class FusedMoE(torch.nn.Module): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, activation: str = "silu", ): super().__init__() @@ -486,6 +501,7 @@ class FusedMoE(torch.nn.Module): self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None + self.apply_router_weight_on_input = apply_router_weight_on_input moe_quant_params = { "num_experts": self.local_num_experts, "hidden_size": hidden_size, @@ -853,6 +869,7 @@ class FusedMoE(torch.nn.Module): scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, + 
apply_router_weight_on_input=self.apply_router_weight_on_input, ) if self.dp_size > 1: diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 76d3acb9..5e8eb6c5 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -92,6 +92,7 @@ class RMSNorm(CustomOp): eps: float = 1e-6, var_hidden_size: Optional[int] = None, has_weight: bool = True, + dtype: Optional[torch.dtype] = None, ) -> None: super().__init__() @@ -100,8 +101,10 @@ class RMSNorm(CustomOp): self.variance_size_override = (None if var_hidden_size == hidden_size else var_hidden_size) self.has_weight = has_weight - - self.weight = torch.ones(hidden_size) + if dtype is not None: + self.weight = torch.ones(hidden_size, dtype=dtype) + else: + self.weight = torch.ones(hidden_size) if self.has_weight: self.weight = nn.Parameter(self.weight) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 473816fc..cb1d5400 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -469,6 +469,7 @@ class AWQMoEMethod(FusedMoEMethodBase): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: assert activation == "silu", "Only SiLU activation is supported." @@ -476,6 +477,10 @@ class AWQMoEMethod(FusedMoEMethodBase): raise NotImplementedError( "Expert Parallelism is not supported for " "fused Marlin MoE method.") + if apply_router_weight_on_input: + raise NotImplementedError( + "Apply router weight on input is not supported for" + "fused Marlin MoE method.") topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index bf32bee8..f573c8ae 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -224,6 +224,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts @@ -240,20 +241,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - return fused_experts(x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - use_fp8_w8a8=True, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=True, + global_num_experts=global_num_experts, + expert_map=expert_map, + 
w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): @@ -438,6 +441,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: @@ -474,6 +478,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, out_dtype=x.dtype, + apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -778,6 +783,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: assert activation == "silu", "Only SiLU activation is supported." @@ -785,6 +791,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): raise NotImplementedError( "Expert Parallelism is not supported for " "fused Marlin MoE method.") + if apply_router_weight_on_input: + raise NotImplementedError( + "Apply router weight on input is not supported for " + "fused Marlin MoE method.") topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index d18ca55a..be19b809 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -113,6 +113,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts @@ -129,18 +130,20 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - return fused_experts(x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - use_int8_w8a16=True, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale) + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_int8_w8a16=True, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale) @staticmethod def quantizing_weight_loader(layer, weight_loader): diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index e7c733db..4435644c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -773,6 +773,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + 
apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts @@ -800,6 +801,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): activation=activation, use_fp8_w8a8=True, global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, w1_scale=(layer.w13_weight_scale_inv if self.block_quant else layer.w13_weight_scale), diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 9861e0a8..6b499f81 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -338,9 +338,15 @@ class GGUFMoEMethod(FusedMoEMethodBase): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, activation: str = "silu", ): assert activation == "silu", "Only SiLU activation is supported." + if apply_router_weight_on_input: + raise NotImplementedError( + "Apply router weight on input is not supported for" + "fused GGUF MoE method.") + topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 9f53ffc1..0615bb4a 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -592,9 +592,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: assert activation == "silu", "Only SiLU activation is supported." 
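The `apply_router_weight_on_input` flag threaded through the fused-MoE entry points and the quantized MoE methods above controls where the top-1 router weight is multiplied in: into the token activations before the expert MLP (the Llama-4 convention, as in the cutlass and Triton paths earlier in this patch) or into the expert outputs afterwards (the existing default). Because the expert MLP is nonlinear, the two orderings are not interchangeable. A toy sketch of the distinction, not the vLLM kernels themselves:

```python
# Toy top-1 MoE illustrating the two placements selected by
# apply_router_weight_on_input; toy_expert stands in for the fused w13/w2 expert MLP.
import torch
import torch.nn.functional as F


def toy_expert(x: torch.Tensor) -> torch.Tensor:
    # Nonlinear stand-in; with a nonlinearity the two orderings differ numerically.
    return F.silu(x)


def moe_top1(x: torch.Tensor, router_weight: torch.Tensor,
             apply_router_weight_on_input: bool) -> torch.Tensor:
    if apply_router_weight_on_input:
        # Scale the token before it enters the expert (Llama-4 style),
        # then take the expert output as-is.
        return toy_expert(x * router_weight)
    # Default: run the expert on the raw token, then scale its output.
    return toy_expert(x) * router_weight


x = torch.randn(4, 8)
w = torch.rand(4, 1)  # top-1 router weights, one per token
print(torch.allclose(moe_top1(x, w, True), moe_top1(x, w, False)))  # generally False
```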
+ if apply_router_weight_on_input is not None: + raise NotImplementedError( + "Apply router weight on input is not supported for" + "fused Marlin MoE method.") # The input must currently be float16 orig_dtype = x.dtype diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 41b75c9b..00c4b661 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -293,6 +293,7 @@ class MoeWNA16Method(FusedMoEMethodBase): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts @@ -312,21 +313,23 @@ class MoeWNA16Method(FusedMoEMethodBase): weight_bits = self.quant_config.weight_bits has_zp = self.quant_config.has_zp - return fused_experts(x, - layer.w13_qweight, - layer.w2_qweight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=layer.w13_scales, - w2_scale=layer.w2_scales, - w1_zp=layer.w13_qzeros if has_zp else None, - w2_zp=layer.w2_qzeros if has_zp else None, - block_shape=[0, layer.group_size]) + return fused_experts( + x, + layer.w13_qweight, + layer.w2_qweight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_int4_w4a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + w1_zp=layer.w13_qzeros if has_zp else None, + w2_zp=layer.w2_qzeros if has_zp else None, + block_shape=[0, layer.group_size]) @staticmethod def get_weight_loader(layer, weight_loader): diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index bc26a455..d1146c0f 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -202,6 +202,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts @@ -217,16 +219,18 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - return fused_experts(x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=True, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=True, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, 
+ a2_scale=layer.w2_input_scale) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index fd27775b..624ed63a 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -851,6 +851,70 @@ class Llama3RotaryEmbedding(RotaryEmbedding): return new_freqs +class Llama4VisionRotaryEmbedding(RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ): + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + inv_freqs = inv_freqs[:(self.rotary_dim // 2)] + return inv_freqs + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + + # self.max_position_embeddings here is number of image patches + # i.e. (image_size // patch_size) ** 2 + num_patches = self.max_position_embeddings + img_idx = torch.arange(num_patches, + dtype=torch.int32) \ + .reshape(num_patches, 1) + img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) + img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN + num_patches_single_dim = int(math.sqrt(num_patches)) + frequencies_x = img_idx % num_patches_single_dim + frequencies_y = img_idx // num_patches_single_dim + freqs_x = ((frequencies_x + 1)[..., None] * + inv_freq[None, None, :]).repeat_interleave(2, dim=-1) + freqs_y = ((frequencies_y + 1)[..., None] * + inv_freq[None, None, :]).repeat_interleave(2, dim=-1) + freqs = torch.cat([freqs_x, freqs_y], + dim=-1).float().contiguous()[..., ::2] + freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) + cache = torch.view_as_complex( + torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)) + return cache + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device) + query_ = torch.view_as_complex(query.float().reshape( + *query.shape[:-1], -1, 2)) + key_ = torch.view_as_complex(key.float().reshape( + *key.shape[:-1], -1, 2)) + broadcast_shape = [ + d if i == 1 or i == (query_.ndim - 1) else 1 + for i, d in enumerate(query_.shape) + ] + freqs_ci = self.cos_sin_cache.view(*broadcast_shape) + query_out = torch.view_as_real(query_ * freqs_ci).flatten(3) + key_out = torch.view_as_real(key_ * freqs_ci).flatten(3) + return query_out.type_as(query), key_out.type_as(key) + + class MRotaryEmbedding(RotaryEmbedding): """Rotary Embedding with Multimodal Sections.""" @@ -1130,6 +1194,10 @@ def get_rope( scaling_factor, low_freq_factor, high_freq_factor, original_max_position) + elif scaling_type == "mllama4": + rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype) elif scaling_type == "default": if "mrope_section" in rope_scaling: rotary_emb = MRotaryEmbedding( diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 5649cf2d..7e434388 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -111,10 +111,12 @@ def _initialize_model( vllm_config: VllmConfig, *, prefix: str = "", + model_class: Optional[type[nn.Module]] = None, ) -> nn.Module: """Initialize a model with the given configurations.""" model_config = vllm_config.model_config - model_class, 
_ = get_model_architecture(model_config) + if model_class is None: + model_class, _ = get_model_architecture(model_config) if vllm_config.quant_config is not None: configure_quant_config(vllm_config.quant_config, model_class) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 81b5d9bd..caa4a510 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union import torch from torch import nn @@ -65,6 +65,7 @@ class LlamaMLP(nn.Module): quant_config: Optional[QuantizationConfig] = None, bias: bool = False, prefix: str = "", + reduce_results: bool = True, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -79,6 +80,7 @@ class LlamaMLP(nn.Module): output_size=hidden_size, bias=bias, quant_config=quant_config, + reduce_results=reduce_results, prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": @@ -292,7 +294,7 @@ class LlamaModel(nn.Module): *, vllm_config: VllmConfig, prefix: str = "", - layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): + layer_type: type[nn.Module] = LlamaDecoderLayer): super().__init__() config = vllm_config.model_config.hf_config @@ -466,10 +468,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "ffn_norm": "post_attention_layernorm", "tok_embeddings": "model.embed_tokens", "output": "lm_head", - "norm": "model.norm" + "norm": "model.norm", } - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = LlamaDecoderLayer): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config @@ -478,7 +484,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lora_config = lora_config self.model = self._init_model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + prefix=maybe_prefix(prefix, "model"), + layer_type=layer_type) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size @@ -513,8 +520,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): - return LlamaModel(vllm_config=vllm_config, prefix=prefix) + def _init_model(self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = LlamaDecoderLayer): + return LlamaModel(vllm_config=vllm_config, + prefix=prefix, + layer_type=layer_type) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py new file mode 100644 index 00000000..029f6044 --- /dev/null +++ b/vllm/model_executor/models/llama4.py @@ -0,0 +1,531 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team. +# All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
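The llama.py hunk above turns `LlamaForCausalLM`/`LlamaModel` into a reusable base by accepting a `layer_type` argument, which is how the new Llama-4 text model plugs in its own decoder layer. A small sketch of that extension point; `MyDecoderLayer` and `MyForCausalLM` are hypothetical names for illustration, not classes added by this patch:

```python
# Sketch of the layer_type hook added to LlamaForCausalLM/LlamaModel above.
# MyDecoderLayer/MyForCausalLM are hypothetical; the patch's real user is llama4.py.
from vllm.config import VllmConfig
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
                                              LlamaForCausalLM)


class MyDecoderLayer(LlamaDecoderLayer):
    """Hypothetical variant that could override attention or MLP wiring."""


class MyForCausalLM(LlamaForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # The base class now forwards layer_type down to LlamaModel,
        # which instantiates one layer_type per pipeline-local layer.
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         layer_type=MyDecoderLayer)
```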
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only LLaMA model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import Llama4TextConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel +from .utils import (AutoWeightsLoader, extract_layer_index, + is_pp_missing_parameter) + + +class Llama4MoE(nn.Module): + + @staticmethod + def custom_routing_function( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + router_scores, router_indices = torch.topk(gating_output, topk, dim=-1) + router_scores = torch.sigmoid(router_scores.float()).to( + hidden_states.dtype) + return (router_scores, router_indices.to(torch.int32)) + + def __init__(self, + config: Llama4TextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.top_k = config.num_experts_per_tok + + intermediate_size_moe = config.intermediate_size + self.router = ReplicatedLinear(config.hidden_size, + config.num_local_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.router") + + self.experts = FusedMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + custom_routing_function=Llama4MoE.custom_routing_function, + intermediate_size=intermediate_size_moe, + apply_router_weight_on_input=True, + reduce_results=False, + renormalize=False, + quant_config=quant_config, + prefix=f"{prefix}.experts") + + self.shared_expert = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size_moe, + hidden_act="silu", + quant_config=quant_config, + bias=False, + prefix=f"{prefix}.shared_expert", + reduce_results=False, # We need to do scatter before reduce + ) + + def forward(self, hidden_states): + router_logits, _ = self.router(hidden_states) + shared_out = self.shared_expert(hidden_states) + routed_out = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + ) + experts_out = routed_out + shared_out + + if self.tp_size > 1: + experts_out = tensor_model_parallel_all_reduce(experts_out) + + return experts_out + + +class Llama4Attention(nn.Module): + + def __init__(self, + config: Llama4TextConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + 
rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.layer_idx = extract_layer_index(prefix) + self.hidden_size = hidden_size + self.no_rope_layers = config.no_rope_layers + self.nope = self.no_rope_layers[self.layer_idx] == 0 + self.use_qk_norm = config.use_qk_norm and not self.nope + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + # TODO: attn_temperature_tuning should be a bool in huggingface + self.attn_temperature_tuning = self.nope and \ + config.attn_temperature_tuning > 0 + + self.floor_scale = getattr(config, "floor_scale", 8192.0) + self.attn_scale = getattr(config, "attn_scale", 0.1) + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.n_rep = self.num_heads // self.num_kv_heads + self.q_norm = RMSNorm( + hidden_size=self.q_size, + eps=config.rms_norm_eps, + has_weight=False, + dtype=torch.float32, + ) if self.use_qk_norm else None + self.k_norm = RMSNorm( + hidden_size=self.kv_size, + eps=config.rms_norm_eps, + has_weight=False, + dtype=torch.float32, + ) if self.use_qk_norm else None + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + is_neox_style = True + is_gguf = quant_config and quant_config.get_name() == "gguf" + if is_gguf and config.model_type == "llama": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=int(rope_theta), + rope_scaling=rope_scaling if rope_scaling != "default" else None, + is_neox_style=is_neox_style, + ) if not self.nope else None + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=None, + use_irope=not self.nope, + prefix=f"{prefix}.attn", + ) + + def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor: + floor = torch.floor((positions + 1.0) / self.floor_scale) + attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0 + + return attn_scale.unsqueeze(-1) + + def forward( + self, + positions: torch.Tensor, + hidden_states: 
torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + if self.rotary_emb is not None: + q, k = self.rotary_emb(positions, q, k) + if self.q_norm is not None: + q = self.q_norm(q.float()).to(q.dtype) + if self.k_norm is not None: + k = self.k_norm(k.float()).to(k.dtype) + + # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) + # to NoPE layers, where the inference-time temperature tuning function + # is customized to not affect short context + # while working at very long context + # https://arxiv.org/abs/2501.19399 + # + # We should apply temperature tuning between (after) rotary / QK norm + # and (before) attention. + if self.attn_temperature_tuning and self.nope: + attn_scale = self._get_attn_scale(positions) + q = (q * attn_scale).to(q.dtype) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Llama4DecoderLayer(nn.Module): + + def __init__( + self, + config: Llama4TextConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.layer_idx = extract_layer_index(prefix) + self.hidden_size = config.hidden_size + rope_theta = config.rope_theta + rope_scaling = config.rope_scaling + max_position_embeddings = config.max_position_embeddings + + self.self_attn = Llama4Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=False, + bias_o_proj=False, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + is_moe_layer = (self.layer_idx + + 1) % config.interleave_moe_layer_step == 0 + if is_moe_layer: + self.feed_forward = Llama4MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + else: + self.feed_forward = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size_mlp, + hidden_act="silu", + quant_config=quant_config, + bias=False, + prefix=f"{prefix}.feed_forward", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Llama4Model(LlamaModel): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer): + self.num_experts = vllm_config.model_config.hf_config.num_local_experts + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=layer_type) + + def load_moe_expert_weights( + self, + name: str, + loaded_weight: torch.Tensor, + params_dict: 
Dict[str, nn.Parameter], + loaded_params: Set[str], + expert_params_mapping: List[Tuple[str, str, int, str]], + fused: bool = True, + ) -> bool: + expert_param_loaded = False + if "experts.gate_up_proj" in name: + loaded_weight = loaded_weight.chunk(2, dim=-1) + for (param_name, weight_name, expert_id, + shard_id) in expert_params_mapping: + new_loaded_weight = loaded_weight + if fused: + e_str, _, proj_str, _ = weight_name.split('.') + weight_name = f"{e_str}.{proj_str}" + param_name = f"{param_name}weight" + if weight_name not in name: + continue + full_param_name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[full_param_name] + weight_loader = param.weight_loader + if fused: + if "w13" in full_param_name: + shard_idx = 0 if shard_id == "w1" else 1 + new_loaded_weight = new_loaded_weight[shard_idx] + new_loaded_weight = new_loaded_weight.transpose(-1, -2) + layer_idx = extract_layer_index(name) + # EP mapping + expert_map = self.layers[ + layer_idx].feed_forward.experts.expert_map + if expert_map is not None: + local_expert_indices = (expert_map != -1) \ + .nonzero() \ + .flatten() \ + .to(new_loaded_weight.device) + new_loaded_weight = new_loaded_weight[local_expert_indices] + expert_id = local_expert_indices[0].item() + else: + # TODO: add EP support for non fused weights + pass + weight_loader(param, + new_loaded_weight, + full_param_name, + shard_id=shard_id, + expert_id=expert_id) + + loaded_params.add(full_param_name) + expert_param_loaded = True + return expert_param_loaded + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + fused_experts_params = False + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.num_experts) + expert_params_mapping_fused = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_up_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="gate_up_proj", + num_experts=1) + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "experts.gate_up_proj" in name or "experts.down_proj" in name: + fused_experts_params = True + expert_params_mapping = expert_params_mapping_fused + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name or "experts" in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + break + else: + moe_loaded = 
self.load_moe_expert_weights( + name, + loaded_weight, + params_dict, + loaded_params, + expert_params_mapping, + fused=fused_experts_params) + + if not moe_loaded: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Llama4ForCausalLM(LlamaForCausalLM): + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # Update temperature tuning config from generation config + gen_config = vllm_config.model_config.try_get_generation_config() + gen_config.update(vllm_config.model_config.override_generation_config) + vllm_config.model_config.hf_config.attn_temperature_tuning \ + = gen_config.get("attn_temperature_tuning", False) + + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=Llama4DecoderLayer) + + def _init_model(self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer): + return Llama4Model(vllm_config=vllm_config, + prefix=prefix, + layer_type=layer_type) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + weights = [ + self.permute_qk_weight_for_rotary(name, loaded_weight) + for name, loaded_weight in weights + ] + return loader.load_weights(weights) + + def permute_qk_weight_for_rotary( + self, + name: str, + loaded_weight: torch.Tensor, + ) -> Tuple[str, torch.Tensor]: + + def permute(w: torch.Tensor, n_heads: int): + attn_in = self.config.head_dim * n_heads + attn_out = self.config.hidden_size + + return w.view(n_heads, attn_in // n_heads // 2, 2, + attn_out).transpose(1, 2).reshape(attn_in, attn_out) + + modules = name.split(".") + + # rotary embeds should be sliced + if ("wk" in modules or "k_proj" in modules) \ + and modules[-1] == "weight": + loaded_weight = permute(loaded_weight, + self.config.num_key_value_heads) + elif ("wq" in modules or "q_proj" in modules) \ + and modules[-1] == "weight": + loaded_weight = permute(loaded_weight, + self.config.num_attention_heads) + + return name, loaded_weight diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py new file mode 100644 index 00000000..dae98093 --- /dev/null +++ b/vllm/model_executor/models/mllama4.py @@ -0,0 +1,895 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team. +# All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
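+"""Inference-only Llama4 multimodal model compatible with HuggingFace weights."""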
+import math +from collections.abc import Iterable, Mapping +from functools import cached_property +from itertools import tee +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union + +import torch +from torch import nn +from transformers import BatchFeature, Llama4Config, Llama4VisionConfig +from transformers.image_utils import SizeDict +from transformers.models.llama4 import Llama4Processor +from transformers.models.llama4.image_processing_llama4_fast import ( + find_supported_resolutions, get_best_fit) + +from vllm.attention.layer import MultiHeadAttention +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import InputProcessingContext +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.model_loader.loader import _initialize_model +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .llama4 import Llama4ForCausalLM +from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, + merge_multimodal_embeddings) +from .vision import scatter_patch_features, select_patch_features + +logger = init_logger(__name__) + + +class Llama4ImagePatchInputs(TypedDict): + type: Literal["pixel_values"] + flat_data: torch.Tensor + """ + Shape: + `(batch_size * num_chunks, num_channels, image size, image size)` + """ + patches_per_image: torch.Tensor + """ + The number of total patches for each image in the batch. + + This is used to split the embeddings which has the first two dimensions + flattened just like `flat_data`. + """ + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. + """ + aspect_ratios: Union[torch.Tensor, list[torch.Tensor]] + """ + A list of aspect ratios corresponding to the number of tiles + in each dimension that each image in the batch corresponds to. 
+ + Shape: + `(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)` + """ + + +class Llama4VisionMLP(nn.Module): + + def __init__(self, + input_size: int, + intermediate_size: int, + output_size: int, + bias: bool, + output_activation: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.fc1 = ColumnParallelLinear( + input_size=input_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + input_size=intermediate_size, + output_size=output_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + self.activation_fn = nn.GELU() + self.output_activation = output_activation + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + if self.output_activation: + return self.activation_fn(hidden_states) + return hidden_states + + +class Llama4MultiModalProjector(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.linear_1 = ColumnParallelLinear( + input_size=config.vision_config.vision_output_dim, + output_size=config.text_config.hidden_size, + bias=False, + quant_config=quant_config, + gather_output=True, + prefix=f"{prefix}.linear_1", + ) + + def forward(self, image_features): + hidden_states, _ = self.linear_1(image_features) + return hidden_states + + +def pixel_shuffle(input_tensor, shuffle_ratio): + # input_tensor: [batch_size, num_patches, channels] + batch_size, num_patches, channels = input_tensor.shape + patch_size = int(math.sqrt(num_patches)) + + input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1) + batch_size, height, width, channels = input_tensor.size() + + reshaped_tensor = input_tensor.view(batch_size, height, + int(width * shuffle_ratio), + int(channels / shuffle_ratio)) + reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() + + reshaped_tensor = reshaped_tensor.view(batch_size, + int(height * shuffle_ratio), + int(width * shuffle_ratio), + int(channels / (shuffle_ratio**2))) + reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() + + output_tensor = reshaped_tensor.view(batch_size, -1, + reshaped_tensor.shape[-1]) + return output_tensor + + +class Llama4VisionPixelShuffleMLP(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.pixel_shuffle_ratio = config.pixel_shuffle_ratio + self.inner_dim = int(config.projector_input_dim // + (self.pixel_shuffle_ratio**2)) + self.output_dim = config.projector_output_dim + self.mlp = Llama4VisionMLP( + input_size=config.intermediate_size, + intermediate_size=config.projector_input_dim, + output_size=config.projector_output_dim, + bias=config.multi_modal_projector_bias, + output_activation=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor: + encoded_patches = pixel_shuffle(encoded_patches, + self.pixel_shuffle_ratio) + return self.mlp(encoded_patches) + + +class Llama4VisionAttention(nn.Module): + + def __init__( + self, + config: Llama4VisionConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ): + super().__init__() + self.config = config + self.tp_size = 
get_tensor_model_parallel_world_size() + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // self.num_heads + assert self.num_heads % self.tp_size == 0 + self.num_local_heads = self.num_heads // self.tp_size + self.q_size = self.num_local_heads * self.head_dim + self.kv_size = self.num_local_heads * self.head_dim + self.attention_dropout = config.attention_dropout + self.scaling = self.head_dim**-0.5 + + self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim, + self.scaling) + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.embed_dim, + bias=True, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=config.hidden_size // config.num_attention_heads // 2, + # number of image patches + max_position=(config.image_size // config.patch_size)**2, + base=config.rope_theta, + rope_scaling={"rope_type": "mllama4"}, + is_neox_style=False, + dtype=torch.complex64, # important + ) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + input_shape = hidden_states.shape[:-1] + + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q = q.view(q.shape[0], q.shape[1], self.num_local_heads, self.head_dim) + k = k.view(k.shape[0], k.shape[1], self.num_local_heads, self.head_dim) + q, k = self.rotary_emb(q, k) + + q = q.view(q.shape[0], q.shape[1], -1) + k = k.view(k.shape[0], k.shape[1], -1) + + attn_output = self.attn(q, k, v) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output, _ = self.o_proj(attn_output) + + return attn_output + + +class Llama4VisionEncoderLayer(nn.Module): + + def __init__( + self, + config: Llama4VisionConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.intermediate_size = config.intermediate_size + + self.self_attn = Llama4VisionAttention(config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + self.mlp = Llama4VisionMLP(input_size=config.hidden_size, + intermediate_size=config.intermediate_size, + output_size=config.hidden_size, + bias=True, + output_activation=False, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + self.input_layernorm = nn.LayerNorm(config.hidden_size) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size) + + def forward( + self, + hidden_state: torch.Tensor, + ): + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state) + hidden_state = residual + hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + hidden_state = residual + hidden_state + + outputs = (hidden_state, ) + return outputs + + +class Llama4VisionEncoder(nn.Module): + + def __init__( + self, + config: Llama4VisionConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + Llama4VisionEncoderLayer( + config, + quant_config=quant_config, + 
prefix=f"{prefix}.layers.{layer_idx}", + ) for layer_idx in range(config.num_hidden_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape + `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. This is useful if you + want more control over how to convert `input_ids` indices into + associated vectors than the model's internal embedding + lookup matrix. + """ + + for encoder_layer in self.layers: + layer_outputs = encoder_layer(hidden_states) + hidden_states = layer_outputs[0] + + return hidden_states + + +class Llama4UnfoldConvolution(nn.Module): + + def __init__(self, + config: Llama4VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + kernel_size = config.patch_size + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + self.unfold = torch.nn.Unfold(kernel_size=kernel_size, + stride=config.patch_size) + self.linear = ColumnParallelLinear(config.num_channels * + kernel_size[0] * kernel_size[1], + config.hidden_size, + bias=False, + quant_config=quant_config, + gather_output=True, + prefix=f"{prefix}.linear") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.unfold(hidden_states) + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states, _ = self.linear(hidden_states) + return hidden_states + + +class Llama4VisionModel(nn.Module): + + def __init__( + self, + config: Llama4VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.image_size = config.image_size + self.patch_size = config.patch_size + self.hidden_size = config.hidden_size + self.num_channels = config.num_channels + + self.num_patches = (self.image_size // self.patch_size)**2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = Llama4UnfoldConvolution( + config, + quant_config=quant_config, + prefix=f"{prefix}.patch_embedding") + + self.class_embedding = nn.Parameter(self.scale * + torch.randn(self.hidden_size)) + self.positional_embedding_vlm = nn.Parameter( + self.scale * torch.randn(self.num_patches, self.hidden_size)) + + # layer norms + self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5) + self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5) + + # encoders + self.model = Llama4VisionEncoder(config, + quant_config=quant_config, + prefix=f"{prefix}.model") + self.vision_adapter = Llama4VisionPixelShuffleMLP( + config, quant_config, prefix=f"{prefix}.vision_adapter") + + def forward( + self, + images_flattened: torch.Tensor, + ) -> torch.Tensor: + # Patch embedding + hidden_state = self.patch_embedding(images_flattened) + num_tiles, num_patches, hidden_dim = hidden_state.shape + + # Add cls token + class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, + hidden_state.shape[-1]) + hidden_state = torch.cat([hidden_state, class_embedding], dim=1) + num_patches += 1 + + # Position embeddings + hidden_state = hidden_state.reshape( + num_tiles, + 1, + num_patches, + hidden_dim, + ) + positional_embedding = self.positional_embedding_vlm.to( + dtype=hidden_state.dtype, device=hidden_state.device) + hidden_state = hidden_state + positional_embedding + hidden_state = self.layernorm_pre(hidden_state) + hidden_state = hidden_state.view(num_tiles, -1, hidden_dim) + + # Apply 
encoder + hidden_state = self.model(hidden_state) + hidden_state = self.layernorm_post(hidden_state) + + # Remove CLS token output + hidden_state = hidden_state[:, :-1, :] + + # now, we use Llama4VisionPixelShuffle + mlp to project embeddings + hidden_state = self.vision_adapter(hidden_state) + + return hidden_state + + +class Mllama4ProcessingInfo(BaseProcessingInfo): + + def __init__(self, ctx: InputProcessingContext) -> None: + super().__init__(ctx) + + def get_hf_config(self) -> Llama4Config: + return self.ctx.get_hf_config(Llama4Config) + + def get_hf_processor(self, **kwargs: object) -> Llama4Processor: + return self.ctx.get_hf_processor(Llama4Processor, + use_fast=True, + **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 10} + + @staticmethod + def get_patch_per_chunk(vision_config: Llama4VisionConfig) -> int: + image_size = vision_config.image_size + patch_size = vision_config.patch_size + + assert (image_size % patch_size == 0), ( + f"chunk size {image_size} should be a multiple of " + f"patch_size {patch_size}") + + ds_ratio = int(round(1.0 / (vision_config.pixel_shuffle_ratio**2))) + return (image_size // patch_size)**2 // ds_ratio + + def get_max_num_tiles(self) -> int: + image_processor = self.get_hf_processor().image_processor + return image_processor.max_patches + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + vision_config = self.get_hf_config().vision_config + # image_start + local tiles * (patches + 1 x separator) + + # 1 global tile * (image x 1 + patches) + image_end + token_per_chunk = self.get_patch_per_chunk(vision_config) + 1 + mm_max_tokens = (self.get_max_num_tiles() + 1) * token_per_chunk + 2 + return {"image": mm_max_tokens} + + def get_image_size_with_most_features(self) -> ImageSize: + vision_config = self.get_hf_config().vision_config + image_size = vision_config.image_size + # Results in the max possible feature size (h:w = 16:1) + return ImageSize(height=self.get_max_num_tiles() * image_size, + width=image_size) + + +class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] + ): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + tokenizer = self.info.get_tokenizer() + + if mm_data is None: + return tokenizer(prompt, add_special_tokens=False) # exclude bos + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + processor = self.info.get_hf_processor(**mm_kwargs) + image_processor = processor.image_processor + vision_config = self.info.get_hf_config().vision_config + + if processed_outputs.get("pixel_values") is not None: + assert "images" in mm_data, \ + "images expected to be in mm_data when pixel_values is present" + + images = mm_data["images"] + parsed_images = (self._get_data_parser().parse_mm_data({ + "image": + images + }).get_items("image", ImageProcessorItems)) + + tile_size = vision_config.image_size + possible_resolutions = find_supported_resolutions( + max_num_chunks=self.info.get_max_num_tiles(), + patch_size=SizeDict(height=tile_size, width=tile_size), + ) + best_fit_sizes = [ + get_best_fit( + (image.size[1], image.size[0]), + torch.tensor(possible_resolutions), + resize_to_max_canvas=image_processor.resize_to_max_canvas) + for image in parsed_images + ] + # TODO tile height/width do not necessarily need to match + aspect_ratios = [(image_size[0] // 
tile_size, + image_size[1] // tile_size) + for image_size in best_fit_sizes] + patches_per_image = [ + 1 if r_h * r_w == 1 else 1 + r_h * r_w + for (r_h, r_w) in aspect_ratios + ] + + # embed_is_patch should have one feature per image-related token: + # <|image_start|>, <|tile_*_separator|>, <|image|>, <|image_end|> + # -> False + # <|patch|> -> True + # embed_is_patch has no entries corresponding to non-image-related + # tokens. + patch_id = tokenizer.get_vocab()[processor.img_patch_token] + num_patches_per_chunk = self.info.get_patch_per_chunk( + vision_config) + expanded_image_tokens_list = [ + processor._prompt_split_image(aspect_ratio, + num_patches_per_chunk) + for aspect_ratio in aspect_ratios + ] + expanded_image_token_ids = [ + tokenizer.encode(image_tokens, add_special_tokens=False) + for image_tokens in expanded_image_tokens_list + ] + embed_is_patch = [ + torch.tensor(tokens) == patch_id + for tokens in expanded_image_token_ids + ] + + processed_outputs["aspect_ratios"] = aspect_ratios + processed_outputs["patches_per_image"] = torch.tensor( + patches_per_image) + processed_outputs["embed_is_patch"] = embed_is_patch + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0)) + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", patches_per_image), + patches_per_image=MultiModalFieldConfig.batched("image"), + aspect_ratios=MultiModalFieldConfig.batched("image"), + embed_is_patch=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> List[PromptUpdate]: + assert ( + mm_items.get_count("image", strict=False) == 0 + or "aspect_ratios" in out_mm_kwargs + ), "Transformers expect to include aspect_ratios in out_mm_kwargs" + + config = self.info.get_hf_config() + vision_config = config.vision_config + + num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config) + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_token = hf_processor.image_token + + def get_replacement(item_idx: int): + aspect_ratio = out_mm_kwargs["aspect_ratios"][item_idx] + return hf_processor._prompt_split_image( + aspect_ratio=aspect_ratio, + num_patches_per_chunk=num_patches_per_chunk) + + return [ + PromptReplacement( + modality="image", + target=image_token, + replacement=get_replacement, + ) + ] + + +class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + (target_width, + target_height) = self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + image_token = self.info.get_hf_processor().fake_image_token + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) + + +@MULTIMODAL_REGISTRY.register_processor( + Mllama4MultiModalProcessor, + info=Mllama4ProcessingInfo, + dummy_inputs=Mllama4DummyInputsBuilder, +) +class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + } + + def 
__init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.quant_config = quant_config + self.multimodal_config = multimodal_config + self.vision_model = Llama4VisionModel(config.vision_config, + None, + prefix=maybe_prefix( + prefix, "vision_model")) + self.multi_modal_projector = Llama4MultiModalProjector( + self.config, + None, + prefix=maybe_prefix(prefix, "multi_modal_projector")) + + self.language_model = _initialize_model( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=maybe_prefix(prefix, "language_model"), + model_class=Llama4ForCausalLM, + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]: + # num_images, 1, num_chunks, channel, image_size, image_size + pixel_values = kwargs.pop("pixel_values", None) + if pixel_values is None: + return None + + # num_images x num_chunks, channel, image_size, image_size + # TODO: confirm handling for variable lengths + flat_pixel_values = flatten_bn(pixel_values, concat=True) + patches_per_image = flatten_bn(kwargs.pop("patches_per_image")) + + embed_is_patch = kwargs.pop("embed_is_patch", None) + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. " + f"Got type: {type(embed_is_patch)}") + + aspect_ratios = kwargs.pop("aspect_ratios", None) + if not isinstance(aspect_ratios, (torch.Tensor, list)): + raise ValueError("Incorrect type of aspect_ratios. 
" + f"Got type: {type(aspect_ratios)}") + + return Llama4ImagePatchInputs( + type="pixel_values", + flat_data=flat_pixel_values, + patches_per_image=patches_per_image, + embed_is_patch=embed_is_patch, + aspect_ratios=aspect_ratios, + ) + + def _process_image_input( + self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings: + flat_data = image_input["flat_data"] + patches_per_image = image_input["patches_per_image"].tolist() + vision_embeddings_flat = self.vision_model(flat_data) + return vision_embeddings_flat.split(patches_per_image, dim=0) + + def get_multimodal_embeddings(self, + **kwargs) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + # num_images x [num_chunks, num_patches, hidden_dim] + image_features = self._process_image_input(image_input) + # num_images x [num_chunks x num_patches, hidden_dim] + image_features_flat = [img.flatten(0, 1) for img in image_features] + # num_images x [1, input_len] -> num_images x [input_len] + embed_is_patch_flat = [ + is_patch.flatten(0, 1) + for is_patch in image_input["embed_is_patch"] + ] + + return scatter_patch_features( + image_features_flat, + embed_is_patch_flat, + ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + + if multimodal_embeddings is not None: + multimodal_embeddings = torch.cat(multimodal_embeddings) + mm_embeddings = self.multi_modal_projector(multimodal_embeddings) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, select_patch_features(mm_embeddings), + self.config.image_token_index) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, + # this condition is for v0 compatibility. 
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + return self.language_model(input_ids, positions, intermediate_tensors, + inputs_embeds) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample(self, logits: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def separate_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + prefix: str, + ) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[ + str, torch.Tensor]]]: + weights1, weights2 = tee(weights, 2) + + def get_prefix_weights() -> Iterable[Tuple[str, torch.Tensor]]: + for name, data in weights1: + if name.startswith(prefix): + yield (name, data) + + def get_other_weights() -> Iterable[Tuple[str, torch.Tensor]]: + for name, data in weights2: + if not name.startswith(prefix): + yield (name, data) + + return get_prefix_weights(), get_other_weights() + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), + (".self_attn.qkv_proj", ".self_attn.k_proj", "k"), + (".self_attn.qkv_proj", ".self_attn.v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + updated_params: Set[str] = set() + + # language_model is a Llama4ForCausalLM instance. We load its weights + # using llama4's load_weights routine. + language_model_weights, other_weights = self.separate_weights( + weights, prefix="language_model.model.") + loader = AutoWeightsLoader(self) + loaded_language_model_params = loader.load_weights( + language_model_weights) + assert loaded_language_model_params is not None + updated_params.update(loaded_language_model_params) + + for name, loaded_weight in other_weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + updated_params.add(name) + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + + weight_loader(param, loaded_weight) + updated_params.add(name) + return updated_params diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 080aef89..3abbb1f0 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -196,6 +196,7 @@ _MULTIMODAL_MODELS = { # [Encoder-decoder] "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 + "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501 "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"), "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501 } diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py index 062b1c2c..379e19e1 100644 --- a/vllm/model_executor/models/telechat2.py +++ 
b/vllm/model_executor/models/telechat2.py @@ -19,9 +19,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, Set, Tuple, Type +from typing import Iterable, Set, Tuple import torch +import torch.nn as nn from vllm.config import VllmConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -124,7 +125,7 @@ class TeleChat2ForCausalLM(LlamaForCausalLM): def _init_model(self, vllm_config: VllmConfig, prefix: str = "", - layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): + layer_type: type[nn.Module] = LlamaDecoderLayer): return TeleChat2Model(vllm_config=vllm_config, prefix=prefix) def load_weights(self, weights: Iterable[Tuple[str, diff --git a/vllm/model_executor/models/teleflm.py b/vllm/model_executor/models/teleflm.py index e670b1df..e05f23f9 100644 --- a/vllm/model_executor/models/teleflm.py +++ b/vllm/model_executor/models/teleflm.py @@ -22,9 +22,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Type - import torch +import torch.nn as nn from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -39,7 +38,7 @@ class TeleFLMModel(LlamaModel): *, vllm_config: VllmConfig, prefix: str = "", - layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer, + layer_type: type[nn.Module] = LlamaDecoderLayer, ): super().__init__(vllm_config=vllm_config, prefix=prefix, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 92e4ffd0..1a8d2420 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -96,6 +96,183 @@ class FlashAttentionMetadata: # For logging. num_input_tokens: int = 0 # Number of tokens including padding. + # for local attention + @dataclass + class LocalAttentionMetadata: + local_query_start_loc: torch.Tensor + local_seqused_k: torch.Tensor + local_block_table: torch.Tensor + local_max_query_len: int + local_max_seq_len: int + + local_attn_metadata: Optional[LocalAttentionMetadata] = None + + +# +# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into +# local attention blocks, where each block is passed to the attention kernel +# as an independent local ("virtual") batch item. 
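+# (Local attention here is used for Llama4's chunked-attention layers: the model +# runner provides model_config.attention_chunk_size and the attention layers +# opt in via use_irope.)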
+# +# For example, if we are performing a chunked prefill on a batch of 3 sequences: +# q_seqlens = [4, 10, 5] +# kv_seqlens = [6, 17, 9] +# Then normally for regular attention we would compute with an attention mask +# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like: +# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6) +# k_toks > 0 1 2 3 4 5 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# 2 | 1 1 1 1 1 +# 3 | 1 1 1 1 1 1 +# +# for local attention (with attn_chunk_size = 4) we would compute with an +# attention mask like: +# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4) +# k_toks > 0 1 2 3 4 5 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# 2 | 1 +# 3 | 1 1 +# +# We can simulate this mask using standard flash-attention by breaking the +# sequences into local ("virtual") batches, where each local batch item is a +# local attention block, so in this case batch idx 0 would be broken up into: +# +# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0) +# k_toks > 0 1 2 3 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0) +# k_toks > 4 5 +# q_toks v _____________ +# 2 | 1 +# 3 | 1 1 +# +# e.g. if we have: +# attn_chunk_size = 4 +# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5]) +# Then this function would return: +# __b0__ ______b1______ __b2__ < orig batch indices +# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1] +# cu_seqlens_q_local = [0, 2, 4, 5, 9, 13, 14, 18, 19] +# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1] +# block_table_local : shape[local_virtual_batches, pages_per_local_batch] +def make_local_attention_virtual_batches( + attn_chunk_size: int, + query_start_loc_np: np.ndarray, + seq_lens_np: np.ndarray, + block_table: torch.Tensor, + page_size: int = 0, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]: + q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] + actual_batch_size = seq_lens_np.shape[0] + + # Handle the case where we start in the middle of a local attention block: + # we assume q_seqlens > 0 (for all elements); for each batch idx we compute + # the number of tokens that are not in the first local attention block and + # then we can simply use a cdiv for the rest. + # For example if we have: + # attn_chunk_size = 4 + # q_seqlens = [4, 10, 5] + # k_seqlens = [6, 17, 9] + # Then we would get: + # q_tokens_in_first_block = [2, 1, 4] + # local_blocks = [2, 4, 2] + q_tokens_in_first_block = np.minimum( + attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), + q_seqlens).astype(np.int32) + tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) + local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, + attn_chunk_size) + + # Once we know the number of local blocks we can compute the request spans + # for each batch idx and figure out the number of "virtual" requests we + # have to make. + # For the above example we would get: + # seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1] + # + # First get the batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1]) + # (TODO: make a utility to share this code with _prepare_inputs) + # arange step 1. [2, 4, 2] -> [2, 6, 8] + cu_num_blocks = np.cumsum(local_blocks) + virtual_batches = cu_num_blocks[-1] + # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6] + block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks) + # arange step 3. 
[0, 1, 0, 1, 2, 3, 0, 1] + arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets + # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0]) + rarange = np.repeat(local_blocks, local_blocks) - arange - 1 + # Then we can compute the seqlens_q_local, handling the fact that the + # first and last blocks could be partial + seqlens_q_local = \ + np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) + # set the first block since this may be a partial block + seqlens_q_local[arange == 0] = q_tokens_in_first_block + # set the remaining blocks + seqlens_q_local[arange > 0] = np.minimum( + seqlens_q_local - attn_chunk_size * (arange - 1), + attn_chunk_size)[arange > 0] + + # convert from q_seqlens to cu_seqlens_q + cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\ + .astype(np.int32) + + # compute the seqlens_k_local, + # basically a full local attention block for all but the last block in each + # batch + # For our example this will be: + # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] + seqlens_k_local = np.full(cu_num_blocks[-1], + attn_chunk_size, + dtype=np.int32) + seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block + + k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \ + (rarange * attn_chunk_size + \ + np.repeat(tokens_in_last_block, local_blocks)) + # For the example the local attention blocks start at: + # _b0_ _____b1_____ _b2_ + # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] + block_starts = k_seqstarts_absolute // page_size + assert attn_chunk_size % page_size == 0, \ + f"attn_chunk_size {attn_chunk_size} is not " \ + f"divisible by page_size {page_size}" + pages_per_local_batch = attn_chunk_size // page_size + + # Create a block_table for the local attention blocks + # For our example if we have a block-table like (assuming page_size=2): + # block_table = [ + # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0 + # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1 + # [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2 + # ] + # Then for the local batches we would want a block-table like + # block_table_local = [ + # [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0]) + # [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4]) + # [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4]) + # [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8]) + # [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12]) + # [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16]) + # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) + # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) + # ] + block_indices = np.broadcast_to( + np.arange(pages_per_local_batch, dtype=np.int32), + (virtual_batches, pages_per_local_batch)) \ + + np.expand_dims(block_starts, axis=1) + block_indices = block_indices.flatten() + batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), + local_blocks * pages_per_local_batch) + block_table_local = block_table[batch_indices, block_indices]\ + .view(virtual_batches, -1) + + return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \ + block_table_local + + class FlashAttentionMetadataBuilder: @@ -109,18 +286,40 @@ class FlashAttentionMetadataBuilder: def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int): max_seq_len = self.runner.seq_lens_np[:num_reqs].max() - query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( - self.runner.device, non_blocking=True) - seq_lens = 
self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device, - non_blocking=True) + query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] + query_start_loc = query_start_loc_cpu.to(self.runner.device, + non_blocking=True) + seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] + seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True) block_table = ( self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( self.runner.device, non_blocking=True).long() + # for local attention + local_attn_metadata = None + if self.runner.attention_chunk_size is not None: + seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ + virt_block_table = make_local_attention_virtual_batches( + self.runner.attention_chunk_size, + self.runner.query_start_loc_np[:num_reqs + 1], + self.runner.seq_lens_np[:num_reqs], + block_table, + self.runner.block_size, + ) + local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata( + local_query_start_loc=torch.from_numpy( + virt_q_cu_seqlens_np).to(self.runner.device, + non_blocking=True), + local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to( + self.runner.device, non_blocking=True), + local_block_table=virt_block_table, + local_max_query_len=seqlens_q_local_np.max(), + local_max_seq_len=virt_k_seqlens_np.max(), + ) + use_cascade = common_prefix_len > 0 if use_cascade: - # TODO: Optimize. cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], dtype=torch.int32, device=self.runner.device) @@ -149,6 +348,7 @@ class FlashAttentionMetadataBuilder: cu_prefix_query_lens=cu_prefix_query_lens, prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, + local_attn_metadata=local_attn_metadata, ) return attn_metadata @@ -167,6 +367,7 @@ class FlashAttentionImpl(AttentionImpl): blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, + use_irope: bool = False, ) -> None: if blocksparse_params is not None: raise ValueError( @@ -203,6 +404,7 @@ class FlashAttentionImpl(AttentionImpl): "encoder/decoder cross-attention " "are not implemented for " "FlashAttentionImpl") + self.use_irope = use_irope self.vllm_flash_attn_version = get_flash_attn_version() if is_quantized_kv_cache(self.kv_cache_dtype) \ and not flash_attn_supports_fp8(): @@ -265,8 +467,7 @@ class FlashAttentionImpl(AttentionImpl): layer._k_scale, layer._v_scale, ) - descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, - key.shape[1]) + if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(torch.float8_e4m3fn) value_cache = value_cache.view(torch.float8_e4m3fn) @@ -278,22 +479,41 @@ class FlashAttentionImpl(AttentionImpl): query = query.reshape((num_tokens, num_heads, head_size)) # Compute attention and update output up to `num_actual_tokens`. - if not attn_metadata.use_cascade: - # Regular attention (common case). 
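+ # Chunked local attention: when the layer sets use_irope and local attention + # metadata was built, swap in the per-block "virtual batch" metadata produced + # by make_local_attention_virtual_batches.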
+        use_local_attn = \
+            (self.use_irope and attn_metadata.local_attn_metadata is not None)
+
+        if not attn_metadata.use_cascade or use_local_attn:
+            if use_local_attn:
+                assert attn_metadata.local_attn_metadata is not None
+                local_metadata = attn_metadata.local_attn_metadata
+                cu_seqlens_q = local_metadata.local_query_start_loc
+                seqused_k = local_metadata.local_seqused_k
+                max_seqlen_q = local_metadata.local_max_query_len
+                max_seqlen_k = local_metadata.local_max_seq_len
+                block_table = local_metadata.local_block_table
+            else:
+                cu_seqlens_q = attn_metadata.query_start_loc
+                seqused_k = attn_metadata.seq_lens
+                max_seqlen_q = attn_metadata.max_query_len
+                max_seqlen_k = attn_metadata.max_seq_len
+                block_table = attn_metadata.block_table
+
+            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+
             flash_attn_varlen_func(
                 q=query[:num_actual_tokens],
                 k=key_cache,
                 v=value_cache,
                 out=output[:num_actual_tokens],
-                cu_seqlens_q=attn_metadata.query_start_loc,
-                max_seqlen_q=attn_metadata.max_query_len,
-                seqused_k=attn_metadata.seq_lens,
-                max_seqlen_k=attn_metadata.max_seq_len,
+                cu_seqlens_q=cu_seqlens_q,
+                max_seqlen_q=max_seqlen_q,
+                seqused_k=seqused_k,
+                max_seqlen_k=max_seqlen_k,
                 softmax_scale=self.scale,
                 causal=True,
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
-                block_table=attn_metadata.block_table,
+                block_table=block_table,
                 softcap=self.logits_soft_cap,
                 fa_version=self.vllm_flash_attn_version,
                 q_descale=layer._q_scale.expand(descale_shape),
@@ -302,6 +522,8 @@ class FlashAttentionImpl(AttentionImpl):
             )
             return output
 
+        assert not use_local_attn, (
+            "Cascade attention does not support local attention.")
         # Cascade attention (rare case).
         cascade_attention(
             output[:num_actual_tokens],
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 15b49b14..5f961047 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -70,6 +70,7 @@ class TritonAttentionImpl(AttentionImpl):
         blocksparse_params: Optional[dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
+        use_irope: bool = False,
     ) -> None:
         if blocksparse_params is not None:
             raise ValueError(
@@ -86,6 +87,7 @@ class TritonAttentionImpl(AttentionImpl):
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
+        self.use_irope = use_irope
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -156,24 +158,41 @@ class TritonAttentionImpl(AttentionImpl):
             layer._v_scale,
         )
 
+        use_local_attn = \
+            (self.use_irope and attn_metadata.local_attn_metadata is not None)
+
+        if use_local_attn:
+            assert attn_metadata.local_attn_metadata is not None
+            local_metadata = attn_metadata.local_attn_metadata
+            cu_seqlens_q = local_metadata.local_query_start_loc
+            seqused_k = local_metadata.local_seqused_k
+            max_seqlen_q = local_metadata.local_max_query_len
+            max_seqlen_k = local_metadata.local_max_seq_len
+            block_table = local_metadata.local_block_table
+        else:
+            cu_seqlens_q = attn_metadata.query_start_loc
+            seqused_k = attn_metadata.seq_lens
+            max_seqlen_q = attn_metadata.max_query_len
+            max_seqlen_k = attn_metadata.max_seq_len
+            block_table = attn_metadata.block_table
+
         # Compute attention and update output up to `num_actual_tokens`.
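Both backends now pick cu_seqlens_q, seqused_k, the max sequence lengths, and the block table from either the local or the global metadata, and downstream shapes follow that choice. In particular, the FlashAttention path above derives descale_shape from the selected cu_seqlens_q rather than from attn_metadata.query_start_loc. A minimal sketch under the same assumed example (q_seqlens = [4, 10, 5]; num_kv_heads = 8 is a made-up value used only for illustration):

    import numpy as np

    num_kv_heads = 8  # hypothetical value, only for illustration

    # global metadata for q_seqlens = [4, 10, 5]
    query_start_loc = np.array([0, 4, 14, 19], dtype=np.int32)
    # local metadata for the 8 virtual batches (seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1])
    local_query_start_loc = np.concatenate(
        ([0], np.cumsum([2, 2, 1, 4, 4, 1, 4, 1]))).astype(np.int32)

    for cu_seqlens_q in (query_start_loc, local_query_start_loc):
        # mirrors descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
        print((cu_seqlens_q.shape[0] - 1, num_kv_heads))  # (3, 8) then (8, 8)

With local attention enabled, the leading dimension is the number of virtual batches handed to flash_attn_varlen_func rather than the number of requests, which is why the q/k/v descales are expanded against the selected cu_seqlens_q.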
-        chunked_prefill_paged_decode(
-            query=query[:num_actual_tokens],
-            key=key[:num_actual_tokens],
-            value=value[:num_actual_tokens],
-            output=output[:num_actual_tokens],
-            kv_cache_dtype=self.kv_cache_dtype,
-            key_cache=key_cache,
-            value_cache=value_cache,
-            block_table=attn_metadata.block_table,
-            query_start_loc=attn_metadata.query_start_loc,
-            seq_lens=attn_metadata.seq_lens,
-            max_seq_len=attn_metadata.max_seq_len,
-            max_query_len=attn_metadata.max_query_len,
-            k_scale=layer._k_scale,
-            v_scale=layer._v_scale,
-            alibi_slopes=self.alibi_slopes,
-            sliding_window=self.sliding_window[0],
-            sm_scale=self.scale)
+        chunked_prefill_paged_decode(query=query[:num_actual_tokens],
+                                     key=key[:num_actual_tokens],
+                                     value=value[:num_actual_tokens],
+                                     output=output[:num_actual_tokens],
+                                     kv_cache_dtype=self.kv_cache_dtype,
+                                     key_cache=key_cache,
+                                     value_cache=value_cache,
+                                     block_table=block_table,
+                                     query_start_loc=cu_seqlens_q,
+                                     seq_lens=seqused_k,
+                                     max_seq_len=max_seqlen_k,
+                                     max_query_len=max_seqlen_q,
+                                     k_scale=layer._k_scale,
+                                     v_scale=layer._v_scale,
+                                     alibi_slopes=self.alibi_slopes,
+                                     sliding_window=self.sliding_window[0],
+                                     sm_scale=self.scale)
 
         return output
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 82b07c6c..5133c637 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -113,6 +113,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
         self.head_size = model_config.get_head_size()
         self.hidden_size = model_config.get_hidden_size()
+        self.attention_chunk_size = model_config.attention_chunk_size
 
         self.attn_backend = get_attn_backend(
             self.head_size,