[Bugfix] Validate logit biases to prevent out-of-vocab ids from crashing the engine (#16529)

Signed-off-by: Ryan McConville <ryan@ryanmcconville.com>
Author: Ryan McConville
Date: 2025-04-12 21:19:19 +01:00 (committed by GitHub)
parent 93e5f3c5fb
commit 6c11ecf8d3
3 changed files with 119 additions and 0 deletions
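
For context, a minimal client-side sketch of the failure mode this commit addresses (the base_url, api_key, and a locally running server are assumptions for illustration, not part of the commit): previously an out-of-vocab token id in logit_bias could travel all the way to the sampler and crash the engine; with this change the request is rejected up front with HTTP 400.

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

try:
    client.chat.completions.create(
        model="Qwen/Qwen2.5-1.5B-Instruct",
        messages=[{"role": "user", "content": "Hi"}],
        max_tokens=5,
        # A token id far beyond any real vocabulary.
        logit_bias={"999999999": 1.0},
    )
except openai.BadRequestError as err:
    # After this commit: rejected during request validation.
    print(err.status_code)  # 400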

New test file:

@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0

import openai
import pytest
import pytest_asyncio

from vllm.config import ModelConfig

from ...utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"


def get_vocab_size(model_name):
    config = ModelConfig(
        model=model_name,
        task="auto",
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="bfloat16",
    )
    return config.get_vocab_size()


@pytest.fixture(scope="module")
def server():
    args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "1024",
        "--enforce-eager",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_chat_logit_bias_valid(client):
    """Test that valid logit_bias values are accepted in chat completions."""
    vocab_size = get_vocab_size(MODEL_NAME)
    valid_token_id = vocab_size - 1

    completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "Testing valid logit bias"
        }],
        max_tokens=5,
        logit_bias={str(valid_token_id): 1.0},
    )

    assert completion.choices[0].message.content is not None


@pytest.mark.asyncio
async def test_chat_logit_bias_invalid(client):
    """Test that invalid logit_bias values are rejected in chat completions."""
    vocab_size = get_vocab_size(MODEL_NAME)
    invalid_token_id = vocab_size + 1

    with pytest.raises(openai.BadRequestError) as excinfo:
        await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{
                "role": "user",
                "content": "Testing invalid logit bias"
            }],
            max_tokens=5,
            logit_bias={str(invalid_token_id): 1.0},
        )

    error = excinfo.value
    error_message = str(error)

    assert error.status_code == 400
    assert str(invalid_token_id) in error_message
    assert str(vocab_size) in error_message
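
A note on types, since the tests send string keys while the Processor below compares ints: the OpenAI API transmits logit_bias as a mapping of stringified token ids to bias values, and the serving layer parses those keys into int token ids before they land on SamplingParams. A minimal sketch of that parsing step (the function name is illustrative, not vLLM's actual code):

def parse_logit_bias(raw: dict[str, float]) -> dict[int, float]:
    # OpenAI-style payloads use string keys, e.g. {"151642": 1.0};
    # downstream validation and the sampler want int token ids.
    try:
        return {int(token_id): bias for token_id, bias in raw.items()}
    except ValueError as exc:
        raise ValueError("logit_bias keys must be integer token ids") from exc

assert parse_logit_bias({"151642": 1.0}) == {151642: 1.0}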

Changes to the Processor:

@@ -77,6 +77,7 @@ class Processor:
        params: SamplingParams,
    ) -> None:
        self._validate_structured_output(params)
        self._validate_logit_bias(params)

        if params.allowed_token_ids is None:
            return
@@ -87,6 +88,26 @@ class Processor:
            raise ValueError(
                "allowed_token_ids contains out-of-vocab token id!")

    def _validate_logit_bias(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate logit_bias token IDs are within vocabulary range."""
        if not params.logit_bias:
            return

        vocab_size = self.model_config.get_vocab_size()
        invalid_token_ids = []

        for token_id in params.logit_bias:
            if token_id < 0 or token_id >= vocab_size:
                invalid_token_ids.append(token_id)

        if invalid_token_ids:
            raise ValueError(
                f"token_id(s) {invalid_token_ids} in logit_bias contain "
                f"out-of-vocab token ids. Vocabulary size: {vocab_size}")

    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
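
Read in isolation, the new check amounts to the following self-contained sketch (simplified: the vocabulary size is passed in directly rather than read from self.model_config):

def validate_logit_bias(logit_bias, vocab_size):
    # Simplified, standalone version of Processor._validate_logit_bias.
    if not logit_bias:
        return
    invalid = [t for t in logit_bias if t < 0 or t >= vocab_size]
    if invalid:
        raise ValueError(
            f"token_id(s) {invalid} in logit_bias contain "
            f"out-of-vocab token ids. Vocabulary size: {vocab_size}")

validate_logit_bias({5: 1.0}, vocab_size=10)   # passes silently
try:
    validate_logit_bias({11: 1.0, -1: 0.5}, vocab_size=10)
except ValueError as err:
    print(err)  # names [11, -1] and the vocabulary size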

Changes to the Sampler:

@@ -230,9 +230,19 @@ class Sampler(nn.Module):
        # TODO(houseroad): this implementation is extremely inefficient.
        # One idea is implement this as a PyTorch C++ op, and we may
        # even optimize the logit_bias layout.

        # Get vocabulary size from logits
        vocab_size = logits.shape[-1]

        for i, logit_bias in enumerate(sampling_metadata.logit_bias):
            if logit_bias:
                for token_id, bias in logit_bias.items():
                    # Check token_id bounds to ensure within vocabulary
                    if token_id < 0 or token_id >= vocab_size:
                        raise ValueError(
                            f"token_id {token_id} in logit_bias contains "
                            f"out-of-vocab token id. Vocabulary size: "
                            f"{vocab_size}")
                    logits[i, token_id] += bias
        return logits
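
The sampler-level check deliberately duplicates the Processor validation, as defense in depth: a request that reaches the sampler by some other path still cannot index past the end of the logits tensor, where an unhandled indexing error deep inside the forward pass is presumably the engine crash the commit title refers to. A toy illustration of the guarded update (shapes are made up):

import torch

logits = torch.zeros(2, 10)          # toy [num_seqs, vocab_size]
vocab_size = logits.shape[-1]

logit_bias = {12: 1.0}               # 12 is out of range for vocab_size 10
for token_id, bias in logit_bias.items():
    if token_id < 0 or token_id >= vocab_size:
        # Fail loudly with a clear message instead of letting
        # logits[0, token_id] fault deep inside the sampler.
        print(f"rejected token_id {token_id} (vocab size {vocab_size})")
    else:
        logits[0, token_id] += bias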