[BugFix] Fix input positions for long context with sliding window (#2088)
This commit is contained in:
parent
096827c284
commit
f1c8520146
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
@ -7,21 +8,32 @@ from transformers import AutoModelForCausalLM
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
_TEST_PROMPTS = [
|
||||
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
|
||||
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
|
||||
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
|
||||
"Describe the basic components of a neural network and how it can be trained.",
|
||||
"Write a short story about a robot that dreams for the first time.",
|
||||
"Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
|
||||
"Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
|
||||
"Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
|
||||
]
|
||||
_TEST_PROMPTS = ["prompts/example.txt"]
|
||||
_LONG_PROMPTS = ["prompts/summary.txt"]
|
||||
|
||||
|
||||
def _read_prompts(filename: str) -> str:
|
||||
prompts = []
|
||||
with open(filename, "r") as f:
|
||||
prompt = f.readline()
|
||||
prompts.append(prompt)
|
||||
return prompts
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_prompts() -> List[str]:
|
||||
return _TEST_PROMPTS
|
||||
prompts = []
|
||||
for filename in _TEST_PROMPTS:
|
||||
prompts += _read_prompts(os.path.join("tests", filename))
|
||||
return prompts
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_long_prompts() -> List[str]:
|
||||
prompts = []
|
||||
for filename in _LONG_PROMPTS:
|
||||
prompts += _read_prompts(os.path.join("tests", filename))
|
||||
return prompts
|
||||
|
||||
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
|
37
tests/models/test_mistral.py
Normal file
37
tests/models/test_mistral.py
Normal file
@ -0,0 +1,37 @@
|
||||
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
|
||||
|
||||
Run `pytest tests/models/test_mistral.py --forked`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
MODELS = [
|
||||
"mistralai/Mistral-7B-Instruct-v0.1",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_long_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_long_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_long_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_long_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
8
tests/prompts/example.txt
Normal file
8
tests/prompts/example.txt
Normal file
@ -0,0 +1,8 @@
|
||||
vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.
|
||||
Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.
|
||||
Compare and contrast artificial intelligence with human intelligence in terms of processing information.
|
||||
Describe the basic components of a neural network and how it can be trained.
|
||||
Write a short story about a robot that dreams for the first time.
|
||||
Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.
|
||||
Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.
|
||||
Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'
|
1
tests/prompts/summary.txt
Normal file
1
tests/prompts/summary.txt
Normal file
File diff suppressed because one or more lines are too long
@ -134,14 +134,14 @@ class ModelRunner:
|
||||
generation_token = seq_data.get_last_token_id()
|
||||
input_tokens.append([generation_token])
|
||||
|
||||
context_len = seq_data.get_len()
|
||||
if self.sliding_window is not None:
|
||||
context_len = min(context_len, self.sliding_window)
|
||||
context_lens.append(context_len)
|
||||
|
||||
position = context_len - 1
|
||||
seq_len = seq_data.get_len()
|
||||
position = seq_len - 1
|
||||
input_positions.append([position])
|
||||
|
||||
context_len = seq_len if self.sliding_window is None else min(
|
||||
seq_len, self.sliding_window)
|
||||
context_lens.append(context_len)
|
||||
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
block_number = block_table[position // self.block_size]
|
||||
block_offset = position % self.block_size
|
||||
|
Loading…
x
Reference in New Issue
Block a user