[Bugfix] Add validation for seed (#4529)
This commit is contained in:
parent
24bb4fe432
commit
c47ba4aaa9
@ -13,6 +13,7 @@ import pytest
|
|||||||
# and debugging.
|
# and debugging.
|
||||||
import ray
|
import ray
|
||||||
import requests
|
import requests
|
||||||
|
import torch
|
||||||
# downloading lora to test lora requests
|
# downloading lora to test lora requests
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from openai import BadRequestError
|
from openai import BadRequestError
|
||||||
@ -870,5 +871,24 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
|
|||||||
assert len(logprobs.tokens) > 5
|
assert len(logprobs.tokens) > 5
|
||||||
|
|
||||||
|
|
||||||
|
async def test_long_seed(server, client: openai.AsyncOpenAI):
|
||||||
|
for seed in [
|
||||||
|
torch.iinfo(torch.long).min - 1,
|
||||||
|
torch.iinfo(torch.long).max + 1
|
||||||
|
]:
|
||||||
|
with pytest.raises(BadRequestError) as exc_info:
|
||||||
|
await client.chat.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
messages=[{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant.",
|
||||||
|
}],
|
||||||
|
temperature=0,
|
||||||
|
seed=seed)
|
||||||
|
|
||||||
|
assert ("greater_than_equal" in exc_info.value.message
|
||||||
|
or "less_than_equal" in exc_info.value.message)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
@ -79,7 +79,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
n: Optional[int] = 1
|
n: Optional[int] = 1
|
||||||
presence_penalty: Optional[float] = 0.0
|
presence_penalty: Optional[float] = 0.0
|
||||||
response_format: Optional[ResponseFormat] = None
|
response_format: Optional[ResponseFormat] = None
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = Field(None,
|
||||||
|
ge=torch.iinfo(torch.long).min,
|
||||||
|
le=torch.iinfo(torch.long).max)
|
||||||
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
||||||
stream: Optional[bool] = False
|
stream: Optional[bool] = False
|
||||||
temperature: Optional[float] = 0.7
|
temperature: Optional[float] = 0.7
|
||||||
@ -228,7 +230,9 @@ class CompletionRequest(OpenAIBaseModel):
|
|||||||
max_tokens: Optional[int] = 16
|
max_tokens: Optional[int] = 16
|
||||||
n: int = 1
|
n: int = 1
|
||||||
presence_penalty: Optional[float] = 0.0
|
presence_penalty: Optional[float] = 0.0
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = Field(None,
|
||||||
|
ge=torch.iinfo(torch.long).min,
|
||||||
|
le=torch.iinfo(torch.long).max)
|
||||||
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
||||||
stream: Optional[bool] = False
|
stream: Optional[bool] = False
|
||||||
suffix: Optional[str] = None
|
suffix: Optional[str] = None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user