[BugFix] Fix input positions for long context with sliding window (#2088)
parent 096827c284
commit f1c8520146
tests/conftest.py
@@ -1,3 +1,4 @@
+import os
 from typing import List, Optional, Tuple
 
 import pytest
@@ -7,21 +8,32 @@ from transformers import AutoModelForCausalLM
 from vllm import LLM, SamplingParams
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
-_TEST_PROMPTS = [
-    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
-    "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
-    "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
-    "Describe the basic components of a neural network and how it can be trained.",
-    "Write a short story about a robot that dreams for the first time.",
-    "Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
-    "Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
-    "Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
-]
+_TEST_PROMPTS = ["prompts/example.txt"]
+_LONG_PROMPTS = ["prompts/summary.txt"]
+
+
+def _read_prompts(filename: str) -> List[str]:
+    prompts = []
+    with open(filename, "r") as f:
+        prompt = f.readline()
+        prompts.append(prompt)
+    return prompts
 
 
 @pytest.fixture
 def example_prompts() -> List[str]:
-    return _TEST_PROMPTS
+    prompts = []
+    for filename in _TEST_PROMPTS:
+        prompts += _read_prompts(os.path.join("tests", filename))
+    return prompts
+
+
+@pytest.fixture
+def example_long_prompts() -> List[str]:
+    prompts = []
+    for filename in _LONG_PROMPTS:
+        prompts += _read_prompts(os.path.join("tests", filename))
+    return prompts
 
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
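For orientation, here is a minimal sketch of how a test would consume the new fixture. It is hypothetical (not part of this commit): pytest injects a fixture by matching the parameter name, and the expected length follows from _read_prompts reading a single line per listed file.

def test_long_prompt_fixture(example_long_prompts):
    # pytest resolves the parameter name to the fixture defined above.
    assert len(example_long_prompts) == 1  # one file listed, one line read
    assert all(isinstance(p, str) for p in example_long_prompts)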
tests/models/test_mistral.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
+
+Run `pytest tests/models/test_mistral.py --forked`.
+"""
+import pytest
+
+MODELS = [
+    "mistralai/Mistral-7B-Instruct-v0.1",
+]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [128])
+def test_models(
+    hf_runner,
+    vllm_runner,
+    example_long_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    hf_model = hf_runner(model, dtype=dtype)
+    hf_outputs = hf_model.generate_greedy(example_long_prompts, max_tokens)
+    del hf_model
+
+    vllm_model = vllm_runner(model, dtype=dtype)
+    vllm_outputs = vllm_model.generate_greedy(example_long_prompts, max_tokens)
+    del vllm_model
+
+    for i in range(len(example_long_prompts)):
+        hf_output_ids, hf_output_str = hf_outputs[i]
+        vllm_output_ids, vllm_output_str = vllm_outputs[i]
+        assert hf_output_str == vllm_output_str, (
+            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
+        assert hf_output_ids == vllm_output_ids, (
+            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
tests/prompts/example.txt (new file, 8 lines)
@@ -0,0 +1,8 @@
+vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.
+Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.
+Compare and contrast artificial intelligence with human intelligence in terms of processing information.
+Describe the basic components of a neural network and how it can be trained.
+Write a short story about a robot that dreams for the first time.
+Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.
+Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.
+Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'
tests/prompts/summary.txt (new file, 1 line)
File diff suppressed because one or more lines are too long.
vllm/worker/model_runner.py
@@ -134,14 +134,14 @@ class ModelRunner:
             generation_token = seq_data.get_last_token_id()
             input_tokens.append([generation_token])
 
-            context_len = seq_data.get_len()
-            if self.sliding_window is not None:
-                context_len = min(context_len, self.sliding_window)
-            context_lens.append(context_len)
-
-            position = context_len - 1
+            seq_len = seq_data.get_len()
+            position = seq_len - 1
             input_positions.append([position])
 
+            context_len = seq_len if self.sliding_window is None else min(
+                seq_len, self.sliding_window)
+            context_lens.append(context_len)
+
             block_table = seq_group_metadata.block_tables[seq_id]
             block_number = block_table[position // self.block_size]
             block_offset = position % self.block_size