[Mistral-Small 3.1] Update docs and tests (#14977)

Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Author: Patrick von Platen, 2025-03-18 11:29:42 +01:00 (committed by GitHub)
parent 400d483e87
commit f863ffc965
5 changed files with 34 additions and 60 deletions

View File

@@ -879,7 +879,7 @@ See [this page](#generative-models) for more information on how to use generativ
- * `PixtralForConditionalGeneration`
* Pixtral
* T + I<sup>+</sup>
- * `mistralai/Pixtral-12B-2409`, `mistral-community/pixtral-12b`, etc.
+ * `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistral-community/pixtral-12b`, etc.
*
* ✅︎
* ✅︎

View File

@@ -6,14 +6,14 @@ import argparse
from vllm import LLM
from vllm.sampling_params import SamplingParams
- # This script is an offline demo for running Pixtral.
+ # This script is an offline demo for running Mistral-Small-3
#
# If you want to run a server/client setup, please follow this code:
#
# - Server:
#
# ```bash
- # vllm serve mistralai/Pixtral-12B-2409 --tokenizer-mode mistral --limit-mm-per-prompt 'image=4' --max-model-len 16384
+ # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 --tokenizer-mode mistral --limit-mm-per-prompt 'image=4' --max-model-len 16384
# ```
#
# - Client:
@@ -23,7 +23,7 @@ from vllm.sampling_params import SamplingParams
# --header 'Content-Type: application/json' \
# --header 'Authorization: Bearer token' \
# --data '{
# "model": "mistralai/Pixtral-12B-2409",
# "model": "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
# "messages": [
# {
# "role": "user",
@@ -44,7 +44,7 @@ from vllm.sampling_params import SamplingParams
def run_simple_demo(args: argparse.Namespace):
model_name = "mistralai/Pixtral-12B-2409"
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
sampling_params = SamplingParams(max_tokens=8192)
# Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
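The comment above names the knobs to turn on smaller GPUs. A hedged sketch of what that could look like (the values are illustrative assumptions, not the demo's defaults):

```python
# Sketch only: shrink the context window and batch size to fit on low-VRAM GPUs.
from vllm import LLM

llm = LLM(
    model="mistralai/Mistral-Small-3.1-24B-Instruct-2503",
    tokenizer_mode="mistral",
    max_model_len=8192,  # smaller KV cache than the full context length
    max_num_seqs=2,      # fewer sequences scheduled per step
)
```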
@@ -83,7 +83,7 @@ def run_simple_demo(args: argparse.Namespace):
def run_advanced_demo(args: argparse.Namespace):
model_name = "mistralai/Pixtral-12B-2409"
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
max_img_per_msg = 5
max_tokens_per_img = 4096
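A hedged sketch of how the two limits above are commonly wired into the engine in this kind of demo (an assumption for illustration, not necessarily the exact code of this file): the per-message image cap becomes `limit_mm_per_prompt`, and the context length is sized to hold that many images.

```python
# Sketch only: derive engine limits from the per-message image budget.
from vllm import LLM

max_img_per_msg = 5
max_tokens_per_img = 4096

llm = LLM(
    model="mistralai/Mistral-Small-3.1-24B-Instruct-2503",
    tokenizer_mode="mistral",
    limit_mm_per_prompt={"image": max_img_per_msg},
    max_model_len=max_img_per_msg * max_tokens_per_img,
)
```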

View File

@@ -4,7 +4,6 @@
Run `pytest tests/models/test_mistral.py`.
"""
import json
- import uuid
from dataclasses import asdict
from typing import TYPE_CHECKING, Any, Optional
@@ -16,8 +15,7 @@ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
from transformers import AutoProcessor
- from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams,
- TextPrompt, TokensPrompt)
+ from vllm import RequestOutput, SamplingParams, TextPrompt, TokensPrompt
from vllm.multimodal import MultiModalDataBuiltins
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sequence import Logprob, SampleLogprobs
@@ -28,7 +26,11 @@ from ...utils import check_logprobs_close
if TYPE_CHECKING:
from _typeshed import StrPath
MODELS = ["mistralai/Pixtral-12B-2409"]
PIXTRAL_ID = "mistralai/Pixtral-12B-2409"
MISTRAL_SMALL_3_1_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
MODELS = [PIXTRAL_ID, MISTRAL_SMALL_3_1_ID]
IMG_URLS = [
"https://picsum.photos/id/237/400/300",
"https://picsum.photos/id/231/200/300",
@@ -125,8 +127,10 @@ MAX_MODEL_LEN = [8192, 65536]
FIXTURES_PATH = VLLM_PATH / "tests/models/fixtures"
assert FIXTURES_PATH.exists()
- FIXTURE_LOGPROBS_CHAT = FIXTURES_PATH / "pixtral_chat.json"
- FIXTURE_LOGPROBS_ENGINE = FIXTURES_PATH / "pixtral_chat_engine.json"
+ FIXTURE_LOGPROBS_CHAT = {
+ PIXTRAL_ID: FIXTURES_PATH / "pixtral_chat.json",
+ MISTRAL_SMALL_3_1_ID: FIXTURES_PATH / "mistral_small_3_chat.json",
+ }
OutputsLogprobs = list[tuple[list[int], str, Optional[SampleLogprobs]]]
@@ -166,12 +170,12 @@ def test_chat(
model: str,
dtype: str,
) -> None:
- EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
+ EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(
+ FIXTURE_LOGPROBS_CHAT[model])
with vllm_runner(
model,
dtype=dtype,
tokenizer_mode="mistral",
- enable_chunked_prefill=False,
max_model_len=max_model_len,
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
) as vllm_model:
@@ -183,70 +187,40 @@ def test_chat(
outputs.extend(output)
logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
+ # Remove last `None` prompt_logprobs to compare with fixture
+ for i in range(len(logprobs)):
+ assert logprobs[i][-1] is None
+ logprobs[i] = logprobs[i][:-1]
check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
outputs_1_lst=logprobs,
name_0="h100_ref",
name_1="output")
- @large_gpu_test(min_gb=80)
- @pytest.mark.parametrize("model", MODELS)
- @pytest.mark.parametrize("dtype", ["bfloat16"])
- def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
- EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
- args = EngineArgs(
- model=model,
- tokenizer_mode="mistral",
- enable_chunked_prefill=False,
- limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
- dtype=dtype,
- )
- engine = LLMEngine.from_engine_args(args)
- engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS)
- engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS)
- outputs = []
- count = 0
- while True:
- out = engine.step()
- count += 1
- for request_output in out:
- if request_output.finished:
- outputs.append(request_output)
- if count == 2:
- engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2],
- SAMPLING_PARAMS)
- if not engine.has_unfinished_requests():
- break
- logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
- check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
- outputs_1_lst=logprobs,
- name_0="h100_ref",
- name_1="output")
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize(
"prompt,expected_ranges",
[(_create_engine_inputs_hf(IMG_URLS[:1]), [{
"offset": 10,
"offset": 11,
"length": 494
}]),
(_create_engine_inputs_hf(IMG_URLS[1:4]), [{
"offset": 10,
"offset": 11,
"length": 266
}, {
"offset": 276,
"offset": 277,
"length": 1056
}, {
"offset": 1332,
"offset": 1333,
"length": 418
}])])
- def test_multi_modal_placeholders(
- vllm_runner, prompt, expected_ranges: list[PlaceholderRange]) -> None:
+ def test_multi_modal_placeholders(vllm_runner, prompt,
+ expected_ranges: list[PlaceholderRange],
+ monkeypatch) -> None:
+ # This placeholder checking test only works with V0 engine
+ # where `multi_modal_placeholders` is returned with `RequestOutput`
+ monkeypatch.setenv("VLLM_USE_V1", "0")
with vllm_runner(
"mistral-community/pixtral-12b",
max_model_len=8192,
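For context on the V0-only comment added above: under V0 each `RequestOutput` exposes `multi_modal_placeholders`, so a test like this can compare the reported ranges with `expected_ranges`. A hedged sketch of that check (the real test body lies outside the hunks shown; names and calls here are illustrative):

```python
# Sketch only: compare placeholder ranges reported by the engine against the
# expected offsets/lengths from the parametrize list above (assumes the ranges
# are exposed as PlaceholderRange mappings with "offset" and "length" keys).
outputs = vllm_model.model.generate(prompt)  # vllm_model from the vllm_runner(...) context
image_placeholders = outputs[0].multi_modal_placeholders["image"]
assert len(image_placeholders) == len(expected_ranges)
for got, want in zip(image_placeholders, expected_ranges):
    assert got["offset"] == want["offset"]
    assert got["length"] == want["length"]
```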

File diffs for the two remaining changed files are suppressed because one or more lines are too long.