Add tests for models (#922)
This commit is contained in:
parent
c128d69856
commit
32b6816e55
@ -10,3 +10,4 @@ types-setuptools
|
||||
|
||||
# testing
|
||||
pytest
|
||||
pytest-forked
|
||||
|
132
tests/conftest.py
Normal file
132
tests/conftest.py
Normal file
@ -0,0 +1,132 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
_TEST_PROMPTS = [
|
||||
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
|
||||
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
|
||||
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
|
||||
"Describe the basic components of a neural network and how it can be trained.",
|
||||
"Write a short story about a robot that dreams for the first time.",
|
||||
"Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
|
||||
"Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
|
||||
"Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_prompts() -> List[str]:
|
||||
return _TEST_PROMPTS
|
||||
|
||||
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float": torch.float,
|
||||
}
|
||||
|
||||
|
||||
class HfRunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
tokenizer_name: Optional[str] = None,
|
||||
dtype: str = "half",
|
||||
) -> None:
|
||||
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
|
||||
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
).cuda()
|
||||
if tokenizer_name is None:
|
||||
tokenizer_name = model_name
|
||||
self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
**kwargs,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
outputs: List[Tuple[List[int], str]] = []
|
||||
for prompt in prompts:
|
||||
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
|
||||
output_ids = self.model.generate(
|
||||
input_ids.cuda(),
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
)
|
||||
output_str = self.tokenizer.batch_decode(
|
||||
output_ids,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)[0]
|
||||
output_ids = output_ids[0].cpu().tolist()
|
||||
outputs.append((output_ids, output_str))
|
||||
return outputs
|
||||
|
||||
def generate_greedy(
|
||||
self,
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
return self.generate(prompts, do_sample=False,
|
||||
max_new_tokens=max_tokens)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hf_runner():
|
||||
return HfRunner
|
||||
|
||||
|
||||
class VllmRunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
tokenizer_name: Optional[str] = None,
|
||||
dtype: str = "half",
|
||||
) -> None:
|
||||
self.model = LLM(
|
||||
model=model_name,
|
||||
tokenizer=tokenizer_name,
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
swap_space=0,
|
||||
)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
sampling_params: SamplingParams,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
req_outputs = self.model.generate(
|
||||
prompts, sampling_params=sampling_params)
|
||||
outputs = []
|
||||
for req_output in req_outputs:
|
||||
prompt_str = req_output.prompt
|
||||
prompt_ids = req_output.prompt_token_ids
|
||||
output_str = req_output.outputs[0].text
|
||||
output_ids = req_output.outputs[0].token_ids
|
||||
outputs.append((prompt_ids + output_ids, prompt_str + output_str))
|
||||
return outputs
|
||||
|
||||
def generate_greedy(
|
||||
self,
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||
return self.generate(prompts, greedy_params)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vllm_runner():
|
||||
return VllmRunner
|
45
tests/models/test_models.py
Normal file
45
tests/models/test_models.py
Normal file
@ -0,0 +1,45 @@
|
||||
"""Compare the outputs of HF and vLLM when using greedy sampling.
|
||||
|
||||
Run `pytest tests/models/test_models.py --forked`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
MODELS = [
|
||||
"facebook/opt-125m",
|
||||
"gpt2",
|
||||
"bigcode/tiny_starcoder_py",
|
||||
"EleutherAI/gpt-j-6b",
|
||||
"EleutherAI/pythia-70m",
|
||||
"bigscience/bloom-560m",
|
||||
"mosaicml/mpt-7b",
|
||||
"tiiuae/falcon-7b",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
Loading…
x
Reference in New Issue
Block a user