[Minor] Fix duplication of ignored seq group in engine step (#1666)

This commit is contained in:
Simon Mo 2023-11-16 13:11:41 -08:00 committed by GitHub
parent 2a2c135b41
commit cb08cd0d75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 1 deletions

27
tests/test_regression.py Normal file
View File

@ -0,0 +1,27 @@
"""Containing tests that check for regressions in vLLM's behavior.
It should include tests that are reported by users and making sure they
will never happen again.
"""
from vllm import LLM, SamplingParams
def test_duplicated_ignored_sequence_group():
"""https://github.com/vllm-project/vllm/issues/1655"""
sampling_params = SamplingParams(temperature=0.01,
top_p=0.1,
max_tokens=256)
llm = LLM(model="facebook/opt-125m",
max_num_batched_tokens=4096,
tensor_parallel_size=1)
prompts = ["This is a short prompt", "This is a very long prompt " * 1000]
outputs = llm.generate(prompts, sampling_params=sampling_params)
assert len(prompts) == len(outputs)
if __name__ == "__main__":
import pytest
pytest.main([__file__])

View File

@ -567,7 +567,7 @@ class LLMEngine:
blocks_to_copy=scheduler_outputs.blocks_to_copy,
)
return self._process_model_outputs(output, scheduler_outputs) + ignored
return self._process_model_outputs(output, scheduler_outputs)
def _log_system_stats(
self,