"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.

Run `pytest tests/models/test_mistral.py`.
"""
import json
import uuid
from dataclasses import asdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import pytest
from mistral_common.multimodal import download_image
from mistral_common.protocol.instruct.messages import ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
from transformers import AutoProcessor

from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams,
                  TextPrompt, TokensPrompt)
from vllm.multimodal import MultiModalDataBuiltins
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sequence import Logprob, SampleLogprobs

from ....utils import VLLM_PATH, large_gpu_test
from ...utils import check_logprobs_close

if TYPE_CHECKING:
    from _typeshed import StrPath

MODELS = ["mistralai/Pixtral-12B-2409"]
IMG_URLS = [
    "https://picsum.photos/id/237/400/300",
    "https://picsum.photos/id/231/200/300",
    "https://picsum.photos/id/27/500/500",
    "https://picsum.photos/id/17/150/600",
]
PROMPT = "Describe each image in one short sentence."


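# Build the prompt in mistral_common's native chat format: a single user
# message containing the text prompt followed by one image_url chunk per URL.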
def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]:
    return [{
        "role":
        "user",
        "content": [{
            "type": "text",
            "text": PROMPT,
        }] + [{
            "type": "image_url",
            "image_url": {
                "url": url
            }
        } for url in urls],
    }]


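# Same prompt in the HF processor chat-template format: images are downloaded
# up front and passed inline instead of being referenced by URL.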
def _create_msg_format_hf(urls: List[str]) -> List[Dict[str, Any]]:
    return [{
        "role":
        "user",
        "content": [{
            "type": "text",
            "content": PROMPT,
        }, *({
            "type": "image",
            "image": download_image(url)
        } for url in urls)],
    }]


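# Tokenize the chat request with the Mistral tokenizer and attach the images
# as multi-modal data, yielding a TokensPrompt the engine can consume directly.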
def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
    msg = _create_msg_format(urls)

    tokenizer = MistralTokenizer.from_model("pixtral")

    request = ChatCompletionRequest(messages=msg)  # type: ignore[type-var]
    tokenized = tokenizer.encode_chat_completion(request)

    engine_inputs = TokensPrompt(prompt_token_ids=tokenized.tokens)

    images = []
    for chunk in request.messages[0].content:
        if isinstance(chunk, ImageURLChunk):
            images.append(image_from_chunk(chunk))

    mm_data = MultiModalDataBuiltins(image=images)
    engine_inputs["multi_modal_data"] = mm_data

    return engine_inputs


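# HF counterpart: render the prompt text through the AutoProcessor chat
# template and attach the downloaded images as multi-modal data in a
# TextPrompt.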
def _create_engine_inputs_hf(urls: List[str]) -> TextPrompt:
    msg = _create_msg_format_hf(urls)

    tokenizer = AutoProcessor.from_pretrained("mistral-community/pixtral-12b")
    prompt = tokenizer.apply_chat_template(msg)

    images = []
    for chunk in msg[0]["content"]:
        if chunk["type"] == "image":
            images.append(chunk["image"])

    mm_data = MultiModalDataBuiltins(image=images)
    engine_inputs = TextPrompt(prompt=prompt, multi_modal_data=mm_data)

    return engine_inputs


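# Prompt/engine-input variants covering one, two, and all four test images.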
MSGS = [
    _create_msg_format(IMG_URLS[:1]),
    _create_msg_format(IMG_URLS[:2]),
    _create_msg_format(IMG_URLS),
]
ENGINE_INPUTS = [
    _create_engine_inputs(IMG_URLS[:1]),
    _create_engine_inputs(IMG_URLS[:2]),
    _create_engine_inputs(IMG_URLS),
]

SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
LIMIT_MM_PER_PROMPT = dict(image=4)

MAX_MODEL_LEN = [8192, 65536]

FIXTURES_PATH = VLLM_PATH / "tests/models/fixtures"
assert FIXTURES_PATH.exists()

FIXTURE_LOGPROBS_CHAT = FIXTURES_PATH / "pixtral_chat.json"
FIXTURE_LOGPROBS_ENGINE = FIXTURES_PATH / "pixtral_chat_engine.json"

OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]]


# For the test author to store golden output in JSON
def _dump_outputs_w_logprobs(
    outputs: OutputsLogprobs,
    filename: "StrPath",
) -> None:
    json_data = [(tokens, text, [{
        k: asdict(v)
        for k, v in token_logprobs.items()
    } for token_logprobs in (logprobs or [])])
                 for tokens, text, logprobs in outputs]

    with open(filename, "w") as f:
        json.dump(json_data, f)


def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs:
    with open(filename, "rb") as f:
        json_data = json.load(f)

    return [(tokens, text, [{
        int(k): Logprob(**v)
        for k, v in token_logprobs.items()
    } for token_logprobs in logprobs]) for tokens, text, logprobs in json_data]
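
# NOTE: the "h100_ref" label used in the comparisons below suggests the stored
# fixtures were captured on an H100 reference machine. Regenerating them would
# presumably amount to dumping freshly computed logprobs instead of comparing
# them, e.g. (sketch only, not executed by these tests):
#     _dump_outputs_w_logprobs(logprobs, FIXTURE_LOGPROBS_CHAT)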


@large_gpu_test(min_gb=80)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_chat(
    vllm_runner,
    max_model_len: int,
    model: str,
    dtype: str,
) -> None:
    EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
    with vllm_runner(
            model,
            dtype=dtype,
            tokenizer_mode="mistral",
            enable_chunked_prefill=False,
            max_model_len=max_model_len,
            limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
    ) as vllm_model:
        outputs = []
        for msg in MSGS:
            output = vllm_model.model.chat(msg,
                                           sampling_params=SAMPLING_PARAMS)

            outputs.extend(output)

    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
    check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
                         outputs_1_lst=logprobs,
                         name_0="h100_ref",
                         name_1="output")


@large_gpu_test(min_gb=80)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
    EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
    args = EngineArgs(
        model=model,
        tokenizer_mode="mistral",
        enable_chunked_prefill=False,
        limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
        dtype=dtype,
    )
    engine = LLMEngine.from_engine_args(args)

    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS)
    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS)

    outputs = []
    count = 0
    while True:
        out = engine.step()
        count += 1
        for request_output in out:
            if request_output.finished:
                outputs.append(request_output)

        if count == 2:
            engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2],
                               SAMPLING_PARAMS)
        if not engine.has_unfinished_requests():
            break

    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
    check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
                         outputs_1_lst=logprobs,
                         name_0="h100_ref",
                         name_1="output")


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize(
    "prompt,expected_ranges",
    [(_create_engine_inputs_hf(IMG_URLS[:1]), [{
        "offset": 10,
        "length": 494
    }]),
     (_create_engine_inputs_hf(IMG_URLS[1:4]), [{
         "offset": 10,
         "length": 266
     }, {
         "offset": 276,
         "length": 1056
     }, {
         "offset": 1332,
         "length": 418
     }])])
def test_multi_modal_placeholders(
        vllm_runner, prompt, expected_ranges: list[PlaceholderRange]) -> None:
    with vllm_runner(
            "mistral-community/pixtral-12b",
            max_model_len=8192,
            limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
    ) as vllm_model:
        outputs = vllm_model.model.generate(prompt)

        assert len(outputs) == 1, f"{len(outputs)=}"
        output: RequestOutput = outputs[0]
        assert hasattr(output,
                       "multi_modal_placeholders"), f"{output.__dict__=}"
        assert "image" in output.multi_modal_placeholders, \
            f"{output.multi_modal_placeholders.keys()=}"
        image_placeholder_ranges: list[
            PlaceholderRange] = output.multi_modal_placeholders["image"]
        assert len(image_placeholder_ranges) == len(
            expected_ranges), f"{image_placeholder_ranges=}"
        for real_range, expected_range in zip(image_placeholder_ranges,
                                              expected_ranges):
            assert real_range == expected_range, \
                f"{real_range=} {expected_range=}"