diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index a2a98abe..1323dba4 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -13,6 +13,7 @@ import pytest
 # and debugging.
 import ray
 import requests
+import torch
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
 from openai import BadRequestError
@@ -870,5 +871,24 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
     assert len(logprobs.tokens) > 5
 
 
+async def test_long_seed(server, client: openai.AsyncOpenAI):
+    for seed in [
+            torch.iinfo(torch.long).min - 1,
+            torch.iinfo(torch.long).max + 1
+    ]:
+        with pytest.raises(BadRequestError) as exc_info:
+            await client.chat.completions.create(
+                model=MODEL_NAME,
+                messages=[{
+                    "role": "system",
+                    "content": "You are a helpful assistant.",
+                }],
+                temperature=0,
+                seed=seed)
+
+        assert ("greater_than_equal" in exc_info.value.message
+                or "less_than_equal" in exc_info.value.message)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 731596e8..3cd9ddad 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -79,7 +79,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
     n: Optional[int] = 1
     presence_penalty: Optional[float] = 0.0
     response_format: Optional[ResponseFormat] = None
-    seed: Optional[int] = None
+    seed: Optional[int] = Field(None,
+                                ge=torch.iinfo(torch.long).min,
+                                le=torch.iinfo(torch.long).max)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     temperature: Optional[float] = 0.7
@@ -228,7 +230,9 @@ class CompletionRequest(OpenAIBaseModel):
     max_tokens: Optional[int] = 16
     n: int = 1
     presence_penalty: Optional[float] = 0.0
-    seed: Optional[int] = None
+    seed: Optional[int] = Field(None,
+                                ge=torch.iinfo(torch.long).min,
+                                le=torch.iinfo(torch.long).max)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     suffix: Optional[str] = None
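
For context on why the test asserts on the substrings "greater_than_equal" and "less_than_equal": those are the Pydantic v2 error types produced when a `ge` or `le` `Field` constraint is violated, and they surface in the HTTP 400 body that the OpenAI client wraps in BadRequestError. Below is a minimal standalone sketch of that mechanism, assuming Pydantic v2; `SeedModel` is a hypothetical name used only for illustration and is not part of the patch.

    # Sketch of the bounds check the patch adds, assuming Pydantic v2.
    from typing import Optional

    import torch
    from pydantic import BaseModel, Field, ValidationError

    class SeedModel(BaseModel):
        # Mirrors the patched `seed` field: the value must fit in a
        # signed 64-bit integer (torch.long).
        seed: Optional[int] = Field(None,
                                    ge=torch.iinfo(torch.long).min,
                                    le=torch.iinfo(torch.long).max)

    # One past the 64-bit maximum violates the `le` bound; Pydantic v2
    # reports the error type as "less_than_equal", the substring that
    # test_long_seed matches. A value below the minimum would yield
    # "greater_than_equal" instead.
    try:
        SeedModel(seed=torch.iinfo(torch.long).max + 1)
    except ValidationError as err:
        print(err.errors()[0]["type"])  # -> "less_than_equal"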