[Bugfix] fix crash if max_tokens=None (#2570)
This commit is contained in:
parent
1e4277d2d1
commit
3209b49033
@ -22,6 +22,19 @@ def test_duplicated_ignored_sequence_group():
|
||||
assert len(prompts) == len(outputs)
|
||||
|
||||
|
||||
def test_max_tokens_none():
    """Regression test: generation must not crash when max_tokens is None.

    Builds a small OPT model and runs a single prompt through it with
    ``max_tokens=None`` (i.e. "generate until EOS / context limit").
    """
    params = SamplingParams(temperature=0.01,
                            top_p=0.1,
                            max_tokens=None)
    engine = LLM(model="facebook/opt-125m",
                 max_num_batched_tokens=4096,
                 tensor_parallel_size=1)
    prompts = ["Just say hello!"]
    outputs = engine.generate(prompts, sampling_params=params)

    # One RequestOutput per prompt confirms generation completed without
    # the crash this commit fixes.
    assert len(prompts) == len(outputs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly as a script; pytest collects
    # and executes the tests defined in this file.
    import pytest
    pytest.main(args=[__file__])
|
||||
|
13
tests/test_sampling_params.py
Normal file
13
tests/test_sampling_params.py
Normal file
@ -0,0 +1,13 @@
|
||||
"""Tests for the SamplingParams class.
|
||||
"""
|
||||
from vllm import SamplingParams
|
||||
|
||||
|
||||
def test_max_tokens_none():
    """Constructing SamplingParams with ``max_tokens=None`` must succeed."""
    # The constructor is the assertion: SamplingParams validates its
    # arguments in __init__ and raises ValueError on invalid values.
    SamplingParams(max_tokens=None, temperature=0.01, top_p=0.1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run this file's tests under pytest directly.
    import pytest
    pytest.main(args=[__file__])
|
@ -108,7 +108,7 @@ class SamplingParams:
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
include_stop_str_in_output: bool = False,
|
||||
ignore_eos: bool = False,
|
||||
max_tokens: int = 16,
|
||||
max_tokens: Optional[int] = 16,
|
||||
logprobs: Optional[int] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
skip_special_tokens: bool = True,
|
||||
@ -183,7 +183,7 @@ class SamplingParams:
|
||||
if not 0.0 <= self.min_p <= 1.0:
|
||||
raise ValueError("min_p must be in [0, 1], got "
|
||||
f"{self.min_p}.")
|
||||
if self.max_tokens < 1:
|
||||
if self.max_tokens is not None and self.max_tokens < 1:
|
||||
raise ValueError(
|
||||
f"max_tokens must be at least 1, got {self.max_tokens}.")
|
||||
if self.logprobs is not None and self.logprobs < 0:
|
||||
|
Loading…
x
Reference in New Issue
Block a user