[CI/Build] Update pixtral tests to use JSON (#8436)

Author: Cyrus Leung, 2024-09-13 11:47:52 +08:00 (committed by GitHub)
parent 3f79bc3d1a
commit 8427550488
6 changed files with 42 additions and 18 deletions

pyproject.toml

@@ -76,7 +76,7 @@ exclude = [
 [tool.codespell]
 ignore-words-list = "dout, te, indicies, subtile"
-skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
+skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
 
 [tool.isort]
 use_parentheses = true
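
Note: the fixtures directory is presumably added to the codespell skip list because the JSON golden files store raw decoded sub-word tokens, which look like misspellings to a spell checker. A hypothetical peek at a fixture illustrates the kind of content involved (the path and field names follow the loader introduced below; the values are made up):

import json

# Each fixture entry is (token_ids, text, per-token logprob dicts);
# the "decoded_token" strings are sub-word pieces, not dictionary
# words, so codespell would flag them as typos if it scanned them.
with open("tests/models/fixtures/pixtral_chat.json") as f:
    tokens, text, logprobs = json.load(f)[0]
print({k: v["decoded_token"] for k, v in logprobs[0].items()})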

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

tests/models/test_pixtral.py

@@ -2,9 +2,10 @@
 Run `pytest tests/models/test_mistral.py`.
 """
-import pickle
+import json
 import uuid
-from typing import Any, Dict, List
+from dataclasses import asdict
+from typing import Any, Dict, List, Optional, Tuple
 
 import pytest
 from mistral_common.protocol.instruct.messages import ImageURLChunk
@@ -14,6 +15,7 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
 
 from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
 from vllm.multimodal import MultiModalDataBuiltins
+from vllm.sequence import Logprob, SampleLogprobs
 
 from .utils import check_logprobs_close
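
The import changes track the serialization switch: json replaces pickle, dataclasses.asdict turns each Logprob entry into a plain dict on dump, and the Logprob/Optional/Tuple imports support the typed loader added below. A minimal round-trip sketch (field values are illustrative; Logprob is a dataclass in vllm.sequence):

from dataclasses import asdict
from vllm.sequence import Logprob

lp = Logprob(logprob=-0.5, rank=2, decoded_token="pix")
d = asdict(lp)             # {'logprob': -0.5, 'rank': 2, 'decoded_token': 'pix'}
assert Logprob(**d) == lp  # dataclass equality restores the original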
@@ -81,13 +83,33 @@ SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
 LIMIT_MM_PER_PROMPT = dict(image=4)
 
 MAX_MODEL_LEN = [8192, 65536]
-FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle"
-FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle"
+FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.json"
+FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.json"
+
+OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]]
 
-def load_logprobs(filename: str) -> Any:
-    with open(filename, 'rb') as f:
-        return pickle.load(f)
+
+# For the test author to store golden output in JSON
+def _dump_outputs_w_logprobs(outputs: OutputsLogprobs, filename: str) -> None:
+    json_data = [(tokens, text,
+                  [{k: asdict(v)
+                    for k, v in token_logprobs.items()}
+                   for token_logprobs in (logprobs or [])])
+                 for tokens, text, logprobs in outputs]
+
+    with open(filename, "w") as f:
+        json.dump(json_data, f)
+
+
+def load_outputs_w_logprobs(filename: str) -> OutputsLogprobs:
+    with open(filename, "rb") as f:
+        json_data = json.load(f)
+
+    return [(tokens, text,
+             [{int(k): Logprob(**v)
+               for k, v in token_logprobs.items()}
+              for token_logprobs in logprobs])
+            for tokens, text, logprobs in json_data]
 
 
 @pytest.mark.skip(
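
A hedged end-to-end sketch of the two new helpers, showing why the loader needs int(k): JSON stringifies integer dict keys and turns tuples into lists, so the loader rebuilds both (the values and file path below are invented, and the sketch assumes it runs inside this test module):

from vllm.sequence import Logprob

outputs = [(
    [1, 2024],        # token ids
    "Hello",          # generated text
    [{1: Logprob(logprob=-0.25, rank=1, decoded_token="Hello")}],
)]

_dump_outputs_w_logprobs(outputs, "/tmp/pixtral_demo.json")
loaded = load_outputs_w_logprobs("/tmp/pixtral_demo.json")

# Keys came back as int, and Logprob objects were rebuilt via Logprob(**v).
assert loaded[0][0] == [1, 2024]
assert loaded[0][2][0][1].logprob == -0.25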
@@ -103,7 +125,7 @@ def test_chat(
     model: str,
     dtype: str,
 ) -> None:
-    EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT)
+    EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
     with vllm_runner(
         model,
         dtype=dtype,
@@ -120,10 +142,10 @@ def test_chat(
         outputs.extend(output)
 
     logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
-    check_logprobs_close(outputs_0_lst=logprobs,
-                         outputs_1_lst=EXPECTED_CHAT_LOGPROBS,
-                         name_0="output",
-                         name_1="h100_ref")
+    check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
+                         outputs_1_lst=logprobs,
+                         name_0="h100_ref",
+                         name_1="output")
 
 
 @pytest.mark.skip(
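
The argument swap puts the stored H100 reference in the outputs_0 slot, so failure messages name the baseline first; the same swap is applied in test_model_engine below. For intuition, here is a simplified stand-in for the kind of check tests/models/utils.check_logprobs_close performs (not vLLM's exact implementation): at the first position where the two token streams diverge, each side's token must still appear in the other side's top-k logprobs.

def check_close_sketch(outputs_0, outputs_1, name_0, name_1):
    # Simplified sketch; each element is (token_ids, text, logprobs).
    for (ids_0, _, lp_0), (ids_1, _, lp_1) in zip(outputs_0, outputs_1):
        for i, (t0, t1) in enumerate(zip(ids_0, ids_1)):
            if t0 == t1:
                continue
            assert t0 in lp_1[i], f"{name_0} token not in {name_1} top-k"
            assert t1 in lp_0[i], f"{name_1} token not in {name_0} top-k"
            break  # outputs may legitimately diverge after the first mismatch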
@@ -133,7 +155,7 @@ def test_chat(
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
-    EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE)
+    EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
     args = EngineArgs(
         model=model,
         tokenizer_mode="mistral",
@@ -162,7 +184,7 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
             break
 
     logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
-    check_logprobs_close(outputs_0_lst=logprobs,
-                         outputs_1_lst=EXPECTED_ENGINE_LOGPROBS,
-                         name_0="output",
-                         name_1="h100_ref")
+    check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
+                         outputs_1_lst=logprobs,
+                         name_0="h100_ref",
+                         name_1="output")