[Bugfix][Frontend] Guard against bad token ids (#9634)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>

parent 0ad216f575
commit 67bdf8e523
@@ -4,6 +4,12 @@ from vllm import LLM


 def test_empty_prompt():
-    llm = LLM(model="gpt2")
+    llm = LLM(model="gpt2", enforce_eager=True)
     with pytest.raises(ValueError, match='Prompt cannot be empty'):
         llm.generate([""])
+
+
+def test_out_of_vocab_token():
+    llm = LLM(model="gpt2", enforce_eager=True)
+    with pytest.raises(ValueError, match='out of vocabulary'):
+        llm.generate({"prompt_token_ids": [999999]})
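
For context, a minimal sketch of the behavior the new offline test covers, assuming the "gpt2" checkpoint can be fetched locally: with the guard in place, the bad token id is rejected up front with a ValueError instead of crashing a CUDA kernel mid-generation.

from vllm import LLM

llm = LLM(model="gpt2", enforce_eager=True)
try:
    # 999999 is far beyond GPT-2's highest token id (50256), so the new
    # validation rejects it before it can reach the model.
    llm.generate({"prompt_token_ids": [999999]})
except ValueError as err:
    print(err)  # message includes "out of vocabulary"
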
@@ -157,15 +157,15 @@ async def test_added_lora_tokens(client: openai.AsyncOpenAI):
 @pytest.mark.asyncio
 async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
     # test using token IDs
-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        prompt=[0, 0, 32000, 32001, 32002],
-        echo=True,
-        max_tokens=5,
-        temperature=0.0,
-    )
-    # Added tokens should not appear in tokenized prompt
-    assert "vllm" not in completion.choices[0].text
+    with pytest.raises(openai.BadRequestError, match="out of vocabulary"):
+        # Added tokens should be rejected by the base model
+        await client.completions.create(
+            model=MODEL_NAME,
+            prompt=[0, 0, 32000, 32001, 32002],
+            echo=True,
+            max_tokens=5,
+            temperature=0.0,
+        )


 @pytest.mark.asyncio
@@ -20,3 +20,18 @@ async def test_empty_prompt():
                                             prompt="",
                                             max_tokens=5,
                                             temperature=0.0)
+
+
+@pytest.mark.asyncio
+async def test_out_of_vocab_token_ids():
+    model_name = "gpt2"
+    server_args = ["--enforce-eager"]
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+
+        with pytest.raises(openai.BadRequestError,
+                           match=re.compile('.*out of vocabulary.*')):
+            await client.completions.create(model=model_name,
+                                            prompt=[999999],
+                                            max_tokens=5,
+                                            temperature=0.0)
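
The same check surfaces over HTTP as a 400 error. A hedged sketch using the synchronous OpenAI client against an already-running server (the base URL, API key, and a `vllm serve gpt2 --enforce-eager` launch are assumptions for illustration):

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
try:
    client.completions.create(model="gpt2", prompt=[999999], max_tokens=5)
except openai.BadRequestError as err:
    print(err)  # the error detail mentions "out of vocabulary"
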
@@ -412,6 +412,12 @@ class _AsyncLLMEngine(LLMEngine):
         """Stop the remote worker execution loop."""
         await self.model_executor.stop_remote_worker_execution_loop_async()

+    async def get_tokenizer_async(self,
+                                  lora_request: Optional[LoRARequest] = None
+                                  ) -> AnyTokenizer:
+        return await (
+            self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))
+
     @overload  # DEPRECATED
     async def add_request_async(
         self,
@@ -472,6 +478,10 @@ class _AsyncLLMEngine(LLMEngine):
         if arrival_time is None:
             arrival_time = time.time()

+        if self.tokenizer is not None:
+            tokenizer = await self.get_tokenizer_async(lora_request)
+            self._validate_token_prompt(prompt, tokenizer=tokenizer)
+
         preprocessed_inputs = await self.input_preprocessor.preprocess_async(
             prompt,
             request_id=request_id,
@@ -488,7 +498,7 @@ class _AsyncLLMEngine(LLMEngine):
             # implementation in the LLMEngine
             params = await build_guided_decoding_logits_processor_async(
                 sampling_params=params,
-                tokenizer=self.get_tokenizer(lora_request),
+                tokenizer=await self.get_tokenizer_async(lora_request),
                 default_guided_backend=self.decoding_config.
                 guided_decoding_backend)

@@ -715,8 +725,7 @@ class AsyncLLMEngine(EngineClient):
         self,
         lora_request: Optional[LoRARequest] = None,
     ) -> AnyTokenizer:
-        return await (self.engine.get_tokenizer_group().
-                      get_lora_tokenizer_async(lora_request))
+        return await self.engine.get_tokenizer_async(lora_request)

     def start_background_loop(self) -> None:
         """Start the background loop."""
@@ -10,7 +10,7 @@ from typing import Sequence as GenericSequence
 from typing import Set, Type, Union, cast, overload

 import torch
-from typing_extensions import TypeVar
+from typing_extensions import TypeIs, TypeVar

 import vllm.envs as envs
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
@@ -32,7 +32,8 @@ from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
-                         EncoderDecoderInputs, InputRegistry, PromptType)
+                         EncoderDecoderInputs, InputRegistry, PromptType,
+                         TokensPrompt)
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.logits_process import get_bad_words_logits_processors
@@ -667,7 +668,7 @@ class LLMEngine:
             )
             return None

-        self._validate_model_inputs(processed_inputs)
+        self._validate_model_inputs(processed_inputs, lora_request)
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
@@ -829,6 +830,11 @@ class LLMEngine:
         if arrival_time is None:
             arrival_time = time.time()

+        if self.tokenizer is not None:
+            self._validate_token_prompt(
+                prompt,
+                tokenizer=self.get_tokenizer(lora_request=lora_request))
+
         preprocessed_inputs = self.input_preprocessor.preprocess(
             prompt,
             request_id=request_id,
@@ -855,6 +861,31 @@ class LLMEngine:
             priority=priority,
         )

+    def _validate_token_prompt(self, prompt: PromptType,
+                               tokenizer: AnyTokenizer):
+        # Guard against out-of-vocab tokens.
+        # For some tokenizers, tokenizer.decode will happily return empty text
+        # for token ids that are out of vocab, and we don't detect token ids
+        # that are greater than the max token id before running the model.
+        # However, these token ids will later crash a cuda kernel at runtime
+        # with an index out of bounds error. This will crash the entire engine.
+        # This needs to happen before multimodal input pre-processing, which
+        # may add dummy <image> tokens that aren't part of the tokenizer's
+        # vocabulary.
+        if self._is_token_prompt(prompt):
+            prompt_ids = prompt["prompt_token_ids"]
+            if len(prompt_ids) == 0:
+                # Empty prompt check is handled later
+                return
+            max_input_id = max(prompt_ids)
+            if max_input_id > tokenizer.max_token_id:
+                raise ValueError(
+                    "Token id {} is out of vocabulary".format(max_input_id))
+
+    @staticmethod
+    def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
+        return isinstance(prompt, dict) and "prompt_token_ids" in prompt
+
     def _create_sequence_group_with_sampling(
         self,
         request_id: str,
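
A self-contained sketch of the validation logic added above, with illustrative names (the real implementation lives in `LLMEngine._validate_token_prompt` and compares against `tokenizer.max_token_id`):

from typing import List


def validate_token_ids(prompt_token_ids: List[int], max_token_id: int) -> None:
    """Reject token ids the model's embedding table cannot index."""
    if not prompt_token_ids:
        # Empty prompts are rejected by a separate check later on.
        return
    max_input_id = max(prompt_token_ids)
    if max_input_id > max_token_id:
        raise ValueError(f"Token id {max_input_id} is out of vocabulary")


validate_token_ids([50256], max_token_id=50256)       # in range: no error
try:
    validate_token_ids([999999], max_token_id=50256)  # raises ValueError
except ValueError as err:
    print(err)
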
@@ -1942,7 +1973,8 @@ class LLMEngine:
         return self.input_preprocessor.is_encoder_decoder_model()

     def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
-                                                   EncoderDecoderInputs]):
+                                                   EncoderDecoderInputs],
+                               lora_request: Optional[LoRARequest]):
         if self.model_config.is_multimodal_model:
             # For encoder-decoder multimodal models, the max_prompt_len
             # restricts the decoder prompt length
@@ -35,6 +35,7 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
         tokenizer.all_special_tokens_extended)
     tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
     tokenizer_len = len(tokenizer)
+    max_token_id = max(tokenizer.get_vocab().values())

     class CachedTokenizer(tokenizer.__class__):  # type: ignore

@@ -50,6 +51,10 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
         def all_special_tokens_extended(self):
             return tokenizer_all_special_tokens_extended

+        @property
+        def max_token_id(self):
+            return max_token_id
+
         def __len__(self):
             return tokenizer_len

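
A hedged sketch of what the new cached property exposes, assuming `transformers` is installed and using gpt2 purely for illustration. Taking the bound from `max(get_vocab().values())` rather than `len(tokenizer) - 1` keeps the check meaningful when added or special tokens make the two diverge:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
max_token_id = max(tok.get_vocab().values())
print(len(tok), max_token_id)  # for gpt2 these should be 50257 and 50256
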
@@ -85,6 +85,7 @@ class MistralTokenizer:
             raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")

         self.tokenizer = tokenizer_
+        self._max_token_id = max(self._vocab.values())

     @classmethod
     def from_pretrained(cls,
@@ -158,6 +159,10 @@ class MistralTokenizer:
     def vocab_size(self) -> int:
         return len(self._vocab)

+    @property
+    def max_token_id(self) -> int:
+        return self._max_token_id
+
     def __len__(self) -> int:
         return self.vocab_size
