vllm/tests/models/test_jamba.py
Mor Zusman 9d6a8daa87
[Model] Jamba support (#4115)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: Erez Schwartz <erezs@ai21.com>
Co-authored-by: Mor Zusman <morz@ai21.com>
Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: Tomer Asida <tomera@ai21.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
2024-07-02 23:11:29 +00:00

66 lines
2.1 KiB
Python

import pytest
MODELS = ["ai21labs/Jamba-tiny-random"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
assert dtype == "float"
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(
vllm_runner,
model: str,
dtype: str,
example_prompts,
) -> None:
# This test is for verifying that the Jamba state is cleaned up between
# steps, If its not cleaned, an error would be expected.
try:
with vllm_runner(model, dtype=dtype) as vllm_model:
for _ in range(10):
vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
except ValueError:
pytest.fail("Jamba inner state wasn't cleaned up between states, "
"could be related to finished_requests_ids")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_model_print(
vllm_runner,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)