# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.

Run `pytest tests/models/multimodal/generation/test_pixtral.py`.
"""
import json
from dataclasses import asdict
from typing import TYPE_CHECKING, Any, Optional

import pytest
from mistral_common.multimodal import download_image
from mistral_common.protocol.instruct.messages import ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
from transformers import AutoProcessor

from vllm import RequestOutput, SamplingParams, TextPrompt, TokensPrompt
from vllm.multimodal import MultiModalDataBuiltins
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sequence import Logprob, SampleLogprobs

from ....utils import VLLM_PATH, large_gpu_test
from ...utils import check_logprobs_close

if TYPE_CHECKING:
    from _typeshed import StrPath

PIXTRAL_ID = "mistralai/Pixtral-12B-2409"
MISTRAL_SMALL_3_1_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"

MODELS = [PIXTRAL_ID, MISTRAL_SMALL_3_1_ID]

IMG_URLS = [
    "https://picsum.photos/id/237/400/300",
    "https://picsum.photos/id/231/200/300",
    "https://picsum.photos/id/27/500/500",
    "https://picsum.photos/id/17/150/600",
]
PROMPT = "Describe each image in one short sentence."


def _create_msg_format(urls: list[str]) -> list[dict[str, Any]]:
    """Build a chat message in the mistral-common format (images by URL)."""
    return [{
        "role": "user",
        "content": [{
            "type": "text",
            "text": PROMPT,
        }] + [{
            "type": "image_url",
            "image_url": {
                "url": url
            }
        } for url in urls],
    }]


def _create_msg_format_hf(urls: list[str]) -> list[dict[str, Any]]:
    """Build the equivalent HF chat message (images downloaded eagerly)."""
    return [{
        "role": "user",
        "content": [{
            "type": "text",
            "content": PROMPT,
        }, *({
            "type": "image",
            "image": download_image(url)
        } for url in urls)],
    }]


def _create_engine_inputs(urls: list[str]) -> TokensPrompt:
    """Tokenize with mistral-common and attach the images as raw data."""
    msg = _create_msg_format(urls)

    tokenizer = MistralTokenizer.from_model("pixtral")

    request = ChatCompletionRequest(messages=msg)  # type: ignore[type-var]
    tokenized = tokenizer.encode_chat_completion(request)

    engine_inputs = TokensPrompt(prompt_token_ids=tokenized.tokens)

    images = []
    for chunk in request.messages[0].content:
        if isinstance(chunk, ImageURLChunk):
            images.append(image_from_chunk(chunk))

    mm_data = MultiModalDataBuiltins(image=images)
    engine_inputs["multi_modal_data"] = mm_data

    return engine_inputs


def _create_engine_inputs_hf(urls: list[str]) -> TextPrompt:
    """Render the prompt with the HF chat template and attach the images."""
    msg = _create_msg_format_hf(urls)

    tokenizer = AutoProcessor.from_pretrained("mistral-community/pixtral-12b")
    prompt = tokenizer.apply_chat_template(msg)

    images = []
    for chunk in msg[0]["content"]:
        if chunk["type"] == "image":
            images.append(chunk["image"])

    mm_data = MultiModalDataBuiltins(image=images)
    engine_inputs = TextPrompt(prompt=prompt, multi_modal_data=mm_data)

    return engine_inputs


MSGS = [
    _create_msg_format(IMG_URLS[:1]),
    _create_msg_format(IMG_URLS[:2]),
    _create_msg_format(IMG_URLS),
]
ENGINE_INPUTS = [
    _create_engine_inputs(IMG_URLS[:1]),
    _create_engine_inputs(IMG_URLS[:2]),
    _create_engine_inputs(IMG_URLS),
]

SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
LIMIT_MM_PER_PROMPT = dict(image=4)

MAX_MODEL_LEN = [8192, 65536]

FIXTURES_PATH = VLLM_PATH / "tests/models/fixtures"
assert FIXTURES_PATH.exists()

FIXTURE_LOGPROBS_CHAT = {
    PIXTRAL_ID: FIXTURES_PATH / "pixtral_chat.json",
    MISTRAL_SMALL_3_1_ID: FIXTURES_PATH / "mistral_small_3_chat.json",
}

OutputsLogprobs = list[tuple[list[int], str, Optional[SampleLogprobs]]]


# Used by test authors to store golden outputs as JSON fixtures
def _dump_outputs_w_logprobs(
    outputs: OutputsLogprobs,
    filename: "StrPath",
) -> None:
    json_data = [(tokens, text,
                  [{k: asdict(v)
                    for k, v in token_logprobs.items()}
                   for token_logprobs in (logprobs or [])])
                 for tokens, text, logprobs in outputs]

    with open(filename, "w") as f:
        json.dump(json_data, f)


def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs:
    with open(filename, "rb") as f:
        json_data = json.load(f)

    return [(tokens, text,
             [{int(k): Logprob(**v)
               for k, v in token_logprobs.items()}
              for token_logprobs in logprobs])
            for tokens, text, logprobs in json_data]
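

# NOTE: The helper below is a hypothetical sketch, not called by any test.
# It shows how the golden fixtures consumed by `test_chat` below could be
# regenerated with the otherwise-unused `_dump_outputs_w_logprobs` above.
# It assumes the same `vllm_runner` pytest fixture used by the tests in this
# file, and should be run on the reference hardware (the fixtures are
# labelled "h100_ref" in `test_chat`).
def _regenerate_chat_fixture(vllm_runner, model: str,
                             max_model_len: int) -> None:
    with vllm_runner(
            model,
            dtype="bfloat16",
            tokenizer_mode="mistral",
            max_model_len=max_model_len,
            limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
    ) as vllm_model:
        outputs = []
        for msg in MSGS:
            outputs.extend(
                vllm_model.model.chat(msg, sampling_params=SAMPLING_PARAMS))

    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
    # Drop the trailing `None` prompt_logprobs entry, mirroring `test_chat`
    logprobs = [entry[:-1] for entry in logprobs]
    _dump_outputs_w_logprobs(logprobs, FIXTURE_LOGPROBS_CHAT[model])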


@large_gpu_test(min_gb=80)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_chat(
    vllm_runner,
    max_model_len: int,
    model: str,
    dtype: str,
) -> None:
    EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(
        FIXTURE_LOGPROBS_CHAT[model])
    with vllm_runner(
            model,
            dtype=dtype,
            tokenizer_mode="mistral",
            max_model_len=max_model_len,
            limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
    ) as vllm_model:
        outputs = []
        for msg in MSGS:
            output = vllm_model.model.chat(msg,
                                           sampling_params=SAMPLING_PARAMS)
            outputs.extend(output)

    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)

    # Remove the trailing `None` prompt_logprobs to compare with the fixture
    for i in range(len(logprobs)):
        assert logprobs[i][-1] is None
        logprobs[i] = logprobs[i][:-1]

    check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
                         outputs_1_lst=logprobs,
                         name_0="h100_ref",
                         name_1="output")


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize(
    "prompt,expected_ranges",
    [(_create_engine_inputs_hf(IMG_URLS[:1]), [{
        "offset": 11,
        "length": 494
    }]),
     (_create_engine_inputs_hf(IMG_URLS[1:4]), [{
         "offset": 11,
         "length": 266
     }, {
         "offset": 277,
         "length": 1056
     }, {
         "offset": 1333,
         "length": 418
     }])])
def test_multi_modal_placeholders(vllm_runner, prompt,
                                  expected_ranges: list[PlaceholderRange],
                                  monkeypatch) -> None:
    # This placeholder-checking test only works with the V0 engine, where
    # `multi_modal_placeholders` is returned as part of `RequestOutput`
    monkeypatch.setenv("VLLM_USE_V1", "0")
    with vllm_runner(
            "mistral-community/pixtral-12b",
            max_model_len=8192,
            limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
    ) as vllm_model:
        outputs = vllm_model.model.generate(prompt)

    assert len(outputs) == 1, f"{len(outputs)=}"
    output: RequestOutput = outputs[0]
    assert hasattr(output, "multi_modal_placeholders"), f"{output.__dict__=}"
    assert "image" in output.multi_modal_placeholders, \
        f"{output.multi_modal_placeholders.keys()=}"
    image_placeholder_ranges: list[
        PlaceholderRange] = output.multi_modal_placeholders["image"]
    assert len(image_placeholder_ranges) == len(
        expected_ranges), f"{image_placeholder_ranges=}"
    for real_range, expected_range in zip(image_placeholder_ranges,
                                          expected_ranges):
        assert real_range == expected_range, \
            f"{real_range=} {expected_range=}"
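

# NOTE: `ENGINE_INPUTS` above is not consumed by the tests in this file. The
# hypothetical smoke check below sketches how such pre-tokenized prompts
# would be fed to the engine, assuming the same `vllm_runner` fixture: a
# `TokensPrompt` already carries `prompt_token_ids` and `multi_modal_data`,
# so no chat template or tokenizer is applied inside vLLM.
def _smoke_check_engine_inputs(vllm_runner) -> None:
    with vllm_runner(
            PIXTRAL_ID,
            dtype="bfloat16",
            tokenizer_mode="mistral",
            max_model_len=MAX_MODEL_LEN[0],
            limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
    ) as vllm_model:
        outputs = vllm_model.model.generate(ENGINE_INPUTS[0],
                                            sampling_params=SAMPLING_PARAMS)
    assert len(outputs) == 1, f"{len(outputs)=}"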