[Bugfix] Validate logit biases to prevent out of vocab ids crashing engine (#16529)
Signed-off-by: Ryan McConville <ryan@ryanmcconville.com>
parent 93e5f3c5fb
commit 6c11ecf8d3
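
Before this change, a chat request whose logit_bias referenced a token id outside the model's vocabulary could reach the sampler and bring down the engine; the fix rejects such requests up front with HTTP 400. A minimal client-side sketch of the new behavior (assuming a local vLLM OpenAI-compatible server; the base URL, API key, and model name are illustrative, not part of this commit):

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

try:
    client.chat.completions.create(
        model="Qwen/Qwen2.5-1.5B-Instruct",
        messages=[{"role": "user", "content": "hi"}],
        max_tokens=5,
        # A token id far past the vocabulary end: previously this could
        # crash the engine; it is now rejected before reaching the sampler.
        logit_bias={"999999999": 1.0},
    )
except openai.BadRequestError as e:
    print(e.status_code, e)  # 400, with the offending id and vocab size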
tests/entrypoints/openai/test_chat_logit_bias_validation.py (new file, 88 lines)
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0

import openai
import pytest
import pytest_asyncio

from vllm.config import ModelConfig

from ...utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"


def get_vocab_size(model_name):
    config = ModelConfig(
        model=model_name,
        task="auto",
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="bfloat16",
    )
    return config.get_vocab_size()


@pytest.fixture(scope="module")
def server():
    args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "1024",
        "--enforce-eager",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_chat_logit_bias_valid(client):
    """Test that valid logit_bias values are accepted in chat completions."""
    vocab_size = get_vocab_size(MODEL_NAME)
    valid_token_id = vocab_size - 1

    completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "Testing valid logit bias"
        }],
        max_tokens=5,
        logit_bias={str(valid_token_id): 1.0},
    )

    assert completion.choices[0].message.content is not None


@pytest.mark.asyncio
async def test_chat_logit_bias_invalid(client):
    """Test that invalid logit_bias values are rejected in chat completions."""
    vocab_size = get_vocab_size(MODEL_NAME)
    invalid_token_id = vocab_size + 1

    with pytest.raises(openai.BadRequestError) as excinfo:
        await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{
                "role": "user",
                "content": "Testing invalid logit bias"
            }],
            max_tokens=5,
            logit_bias={str(invalid_token_id): 1.0},
        )

    error = excinfo.value
    error_message = str(error)

    assert error.status_code == 400
    assert str(invalid_token_id) in error_message
    assert str(vocab_size) in error_message
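
Both tests can be run in isolation against a local checkout (a hypothetical invocation; it assumes vLLM's test dependencies are installed and the machine can serve the model):

import pytest

# Spins up one RemoteOpenAIServer for the module, then runs both tests.
pytest.main(["tests/entrypoints/openai/test_chat_logit_bias_validation.py", "-v"])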
@@ -77,6 +77,7 @@ class Processor:
         params: SamplingParams,
     ) -> None:
         self._validate_structured_output(params)
+        self._validate_logit_bias(params)

         if params.allowed_token_ids is None:
             return
@@ -87,6 +88,26 @@ class Processor:
             raise ValueError(
                 "allowed_token_ids contains out-of-vocab token id!")

+    def _validate_logit_bias(
+        self,
+        params: SamplingParams,
+    ) -> None:
+        """Validate logit_bias token IDs are within vocabulary range."""
+        if not params.logit_bias:
+            return
+
+        vocab_size = self.model_config.get_vocab_size()
+        invalid_token_ids = []
+
+        for token_id in params.logit_bias:
+            if token_id < 0 or token_id >= vocab_size:
+                invalid_token_ids.append(token_id)
+
+        if invalid_token_ids:
+            raise ValueError(
+                f"token_id(s) {invalid_token_ids} in logit_bias contain "
+                f"out-of-vocab token ids. Vocabulary size: {vocab_size}")
+
     def _validate_supported_sampling_params(
         self,
         params: SamplingParams,
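
The Processor-level check runs during request validation, so a bad id surfaces as a client error instead of an engine fault. A standalone sketch of the same rule (a hypothetical free function for illustration; the real method lives on Processor and reads the vocabulary size from the model config):

def validate_logit_bias(logit_bias: dict[int, float], vocab_size: int) -> None:
    # Collect every id outside [0, vocab_size) and report them together.
    invalid = [t for t in logit_bias if t < 0 or t >= vocab_size]
    if invalid:
        raise ValueError(f"token_id(s) {invalid} in logit_bias contain "
                         f"out-of-vocab token ids. Vocabulary size: {vocab_size}")

validate_logit_bias({10: 2.0}, vocab_size=151936)        # passes silently
# validate_logit_bias({151937: 2.0}, vocab_size=151936)  # raises ValueError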
@@ -230,9 +230,19 @@ class Sampler(nn.Module):
         # TODO(houseroad): this implementation is extremely inefficient.
         # One idea is implement this as a PyTorch C++ op, and we may
         # even optimize the logit_bias layout.
+
+        # Get vocabulary size from logits
+        vocab_size = logits.shape[-1]
+
         for i, logit_bias in enumerate(sampling_metadata.logit_bias):
             if logit_bias:
                 for token_id, bias in logit_bias.items():
+                    # Check token_id bounds to ensure within vocabulary
+                    if token_id < 0 or token_id >= vocab_size:
+                        raise ValueError(
+                            f"token_id {token_id} in logit_bias contains "
+                            f"out-of-vocab token id. Vocabulary size: "
+                            f"{vocab_size}")
                     logits[i, token_id] += bias
         return logits