[CI/Build] Update pixtral tests to use JSON (#8436)
This commit is contained in:
parent
3f79bc3d1a
commit
8427550488
@ -76,7 +76,7 @@ exclude = [
|
||||
|
||||
[tool.codespell]
|
||||
ignore-words-list = "dout, te, indicies, subtile"
|
||||
skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
|
||||
skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
|
||||
|
||||
[tool.isort]
|
||||
use_parentheses = true
|
||||
|
1
tests/models/fixtures/pixtral_chat.json
Normal file
1
tests/models/fixtures/pixtral_chat.json
Normal file
File diff suppressed because one or more lines are too long
Binary file not shown.
1
tests/models/fixtures/pixtral_chat_engine.json
Normal file
1
tests/models/fixtures/pixtral_chat_engine.json
Normal file
File diff suppressed because one or more lines are too long
Binary file not shown.
@ -2,9 +2,10 @@
|
||||
|
||||
Run `pytest tests/models/test_mistral.py`.
|
||||
"""
|
||||
import pickle
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
from mistral_common.protocol.instruct.messages import ImageURLChunk
|
||||
@ -14,6 +15,7 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
|
||||
|
||||
from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
|
||||
from vllm.multimodal import MultiModalDataBuiltins
|
||||
from vllm.sequence import Logprob, SampleLogprobs
|
||||
|
||||
from .utils import check_logprobs_close
|
||||
|
||||
@ -81,13 +83,33 @@ SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
|
||||
LIMIT_MM_PER_PROMPT = dict(image=4)
|
||||
|
||||
MAX_MODEL_LEN = [8192, 65536]
|
||||
FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle"
|
||||
FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle"
|
||||
FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.json"
|
||||
FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.json"
|
||||
|
||||
OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]]
|
||||
|
||||
|
||||
def load_logprobs(filename: str) -> Any:
|
||||
with open(filename, 'rb') as f:
|
||||
return pickle.load(f)
|
||||
# For the test author to store golden output in JSON
|
||||
def _dump_outputs_w_logprobs(outputs: OutputsLogprobs, filename: str) -> None:
|
||||
json_data = [(tokens, text,
|
||||
[{k: asdict(v)
|
||||
for k, v in token_logprobs.items()}
|
||||
for token_logprobs in (logprobs or [])])
|
||||
for tokens, text, logprobs in outputs]
|
||||
|
||||
with open(filename, "w") as f:
|
||||
json.dump(json_data, f)
|
||||
|
||||
|
||||
def load_outputs_w_logprobs(filename: str) -> OutputsLogprobs:
|
||||
with open(filename, "rb") as f:
|
||||
json_data = json.load(f)
|
||||
|
||||
return [(tokens, text,
|
||||
[{int(k): Logprob(**v)
|
||||
for k, v in token_logprobs.items()}
|
||||
for token_logprobs in logprobs])
|
||||
for tokens, text, logprobs in json_data]
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
@ -103,7 +125,7 @@ def test_chat(
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT)
|
||||
EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
@ -120,10 +142,10 @@ def test_chat(
|
||||
outputs.extend(output)
|
||||
|
||||
logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
|
||||
check_logprobs_close(outputs_0_lst=logprobs,
|
||||
outputs_1_lst=EXPECTED_CHAT_LOGPROBS,
|
||||
name_0="output",
|
||||
name_1="h100_ref")
|
||||
check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
|
||||
outputs_1_lst=logprobs,
|
||||
name_0="h100_ref",
|
||||
name_1="output")
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
@ -133,7 +155,7 @@ def test_chat(
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
|
||||
EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE)
|
||||
EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
|
||||
args = EngineArgs(
|
||||
model=model,
|
||||
tokenizer_mode="mistral",
|
||||
@ -162,7 +184,7 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
|
||||
break
|
||||
|
||||
logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
|
||||
check_logprobs_close(outputs_0_lst=logprobs,
|
||||
outputs_1_lst=EXPECTED_ENGINE_LOGPROBS,
|
||||
name_0="output",
|
||||
name_1="h100_ref")
|
||||
check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
|
||||
outputs_1_lst=logprobs,
|
||||
name_0="h100_ref",
|
||||
name_1="output")
|
||||
|
Loading…
x
Reference in New Issue
Block a user