[Frontend] Bad words sampling parameter (#9717)

Signed-off-by: Vasily Alexeev <alvasian@yandex.ru>
Vasiliy Alekseev authored 2024-10-26 19:29:38 +03:00, committed by GitHub
parent 55137e8ee3
commit 07e981fdf4
6 changed files with 339 additions and 16 deletions
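
For context, a minimal sketch of how the new parameter is meant to be used (the model name and prompt are illustrative only; any generative model served by vLLM works the same way):

from vllm import LLM, SamplingParams

llm = LLM(model="openai-community/gpt2")  # placeholder model
params = SamplingParams(
    temperature=0.0,
    max_tokens=16,
    bad_words=["you", "years old"],  # the parameter added by this commit
)
outputs = llm.generate(["How old are you? I am 10"], sampling_params=params)
print(outputs[0].outputs[0].text)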

View File

@@ -0,0 +1,185 @@
"""Make sure bad_words works.

Run `pytest tests/samplers/test_no_bad_words.py`.
"""
from typing import List, Optional

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams


def _generate(
    model: LLM,
    prompt: str,
    num_prompt_tokens: int,
    temperature: float = 0,
    bad_words: Optional[List[str]] = None,
) -> List[int]:
    sampling_params = SamplingParams(
        temperature=temperature,
        bad_words=bad_words,
    )

    # [([output_token_ids, ], [output_text, ]), ]
    output = model.generate([prompt], sampling_params=sampling_params)

    output_token_ids = output[0][0][0][num_prompt_tokens:]
    # [0] first (and only) request output
    # [0] token_ids (not text)
    # [0] first (and only) output completion

    return output_token_ids
class TestOneTokenBadWord:
    MODEL = "TheBloke/Llama-2-7B-fp16"

    PROMPT = "Hi! How are"
    TARGET_TOKEN = "you"

    def setup_method(self, method):
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
                                                       add_prefix_space=True)

        self.num_prompt_tokens = len(self._encode(self.PROMPT))
        self.target_token_id = self._encode(self.TARGET_TOKEN,
                                            add_special_tokens=False)[0]

    def test_one_token_bad_word(self, vllm_runner):
        with vllm_runner(self.MODEL) as llm:
            output_token_ids = self._generate(llm)
            assert output_token_ids[0] == self.target_token_id

            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN])
            assert self.target_token_id not in output_token_ids

    def _generate(self,
                  model: LLM,
                  bad_words: Optional[List[str]] = None) -> List[int]:
        return _generate(
            model=model,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    def _encode(self,
                prompt: str,
                add_special_tokens: bool = True) -> List[int]:
        return self.tokenizer(prompt,
                              add_special_tokens=add_special_tokens).input_ids
class TestTwoTokenBadWord:
    # Another model (with a different tokenizer behaviour)
    MODEL = "openai-community/gpt2"

    PROMPT = "How old are you? I am 10"
    TARGET_TOKEN1 = "years"
    TARGET_TOKEN2 = "old"
    NEIGHBOUR_TOKEN2 = "older"

    def setup_method(self, method):
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
                                                       add_prefix_space=True)

        self.num_prompt_tokens = len(self._encode(self.PROMPT))
        self.target_token_id1 = self._encode(self.TARGET_TOKEN1,
                                             add_special_tokens=False)[0]
        self.target_token_id2 = self._encode(self.TARGET_TOKEN2,
                                             add_special_tokens=False)[0]
        self.neighbour_token_id2 = self._encode(self.NEIGHBOUR_TOKEN2,
                                                add_special_tokens=False)[0]

    def test_two_token_bad_word(self, vllm_runner):
        with vllm_runner(self.MODEL) as llm:
            output_token_ids = self._generate(llm)
            assert output_token_ids[:2] == [
                self.target_token_id1, self.target_token_id2
            ]

            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN1])
            assert self.target_token_id1 not in output_token_ids

            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN2])
            assert output_token_ids[0] == self.target_token_id1
            assert self.target_token_id2 not in output_token_ids

            output_token_ids = self._generate(
                llm, bad_words=[f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}'])
            assert output_token_ids[0] == self.target_token_id1
            assert output_token_ids[:2] != [
                self.target_token_id1, self.target_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.target_token_id2])
            # Model dependent behaviour
            assert output_token_ids[:2] == [
                self.target_token_id1, self.neighbour_token_id2
            ]

            output_token_ids = self._generate(
                llm,
                bad_words=[
                    f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}',
                    f'{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}'
                ])
            assert output_token_ids[0] == self.target_token_id1
            assert output_token_ids[:2] != [
                self.target_token_id1, self.target_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.target_token_id2])
            assert output_token_ids[:2] != [
                self.target_token_id1, self.neighbour_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.neighbour_token_id2])
            assert ((self.target_token_id2 in output_token_ids)
                    or (self.neighbour_token_id2 in output_token_ids))

    def _generate(self,
                  model: LLM,
                  bad_words: Optional[List[str]] = None) -> List[int]:
        return _generate(
            model=model,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    @staticmethod
    def _contains(sequence: List[int], subsequence: List[int]) -> bool:
        searched = False
        for start in range(len(sequence)):
            end = start + len(subsequence)
            current_subsequence = sequence[start:end]

            if len(current_subsequence) < len(subsequence):
                continue

            searched = True
            assert len(current_subsequence) == len(subsequence)

            if current_subsequence == subsequence:
                return True

        assert searched, "All subsequences did not match in length..."

        return False

    def _encode(self,
                prompt: str,
                add_special_tokens: bool = True) -> List[int]:
        return self.tokenizer(prompt,
                              add_special_tokens=add_special_tokens).input_ids

View File

@@ -26,7 +26,8 @@ from vllm.engine.output_processor.interfaces import (
     SequenceGroupOutputProcessor)
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.engine.output_processor.util import create_output_by_sequence_group
-from vllm.entrypoints.openai.logits_processors import get_logits_processors
+from vllm.entrypoints.openai.logits_processors import (
+    get_logits_processors as get_openai_logits_processors)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.executor.ray_utils import initialize_ray_cluster
@@ -34,6 +35,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
     EncoderDecoderInputs, InputRegistry, PromptType)
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
+from vllm.logits_process import get_bad_words_logits_processors
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
     get_local_guided_decoding_logits_processor)
@@ -1963,6 +1965,7 @@ class LLMEngine:
         logits_processors field. Returns the modified sampling params."""
         logits_processors = []
         if (guided_decoding := sampling_params.guided_decoding) is not None:
             logger.debug(
@@ -1984,7 +1987,7 @@
         if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
             tokenizer = self.get_tokenizer(lora_request=lora_request)
-            processors = get_logits_processors(
+            processors = get_openai_logits_processors(
                 logit_bias=sampling_params.logit_bias,
                 allowed_token_ids=sampling_params.allowed_token_ids,
                 tokenizer=tokenizer)
@@ -1994,6 +1997,12 @@
             sampling_params.logit_bias = None
             sampling_params.allowed_token_ids = None

+        if len(sampling_params.bad_words) > 0:
+            tokenizer = self.get_tokenizer(lora_request)
+            processors = get_bad_words_logits_processors(
+                bad_words=sampling_params.bad_words, tokenizer=tokenizer)
+            logits_processors.extend(processors)
+
         if logits_processors:
             if sampling_params.logits_processors is None:
                 sampling_params.logits_processors = logits_processors
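
The engine change above only wires SamplingParams.bad_words into a logits processor; the helper can also be exercised on its own. A rough sketch, assuming a locally available transformers tokenizer (the model name is illustrative):

import torch
from transformers import AutoTokenizer

from vllm.logits_process import get_bad_words_logits_processors

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
processors = get_bad_words_logits_processors(bad_words=["years old"],
                                             tokenizer=tokenizer)

# Call the processor the way the sampler would: with the token ids generated
# so far and the raw logits for the next position.
past_token_ids = tokenizer("I am 10 years", add_special_tokens=False).input_ids
logits = torch.zeros(tokenizer.vocab_size)
new_logits = processors[0](past_token_ids, logits)
print(torch.isinf(new_logits).sum().item(), "token(s) masked at this step")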

vllm/logits_process.py (new file, 119 lines)
View File

@@ -0,0 +1,119 @@
from typing import Callable, List, Tuple, Union

import torch

from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
                        Callable[[List[int], List[int], torch.Tensor],
                                 torch.Tensor]]
"""LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor
for the next token and, optionally, prompt tokens as a
first argument, and returns a modified tensor of logits
to sample from."""


def get_bad_words_logits_processors(
        bad_words: List[str],
        tokenizer: AnyTokenizer) -> List[LogitsProcessor]:
    bad_words_ids: List[List[int]] = list()

    for bad_word in bad_words:
        # To prohibit words both at the beginning
        # and in the middle of text
        # (related to add_prefix_space tokenizer parameter)
        for add_prefix_space in [False, True]:
            prefix = " " if add_prefix_space else ""
            prompt = prefix + bad_word.lstrip()

            if isinstance(tokenizer, MistralTokenizer):
                # Mistral tokenizers should not add special tokens
                prompt_token_ids = tokenizer.encode(prompt=prompt)
            else:
                prompt_token_ids = tokenizer.encode(text=prompt,
                                                    add_special_tokens=False)

            # If no space at the beginning
            # or if prefix space produces a new word token
            if (not add_prefix_space) or (
                    add_prefix_space
                    and prompt_token_ids[0] != bad_words_ids[-1][0]
                    and len(prompt_token_ids) == len(bad_words_ids[-1])):
                bad_words_ids.append(prompt_token_ids)

    return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)]
class NoBadWordsLogitsProcessor:
    _SMALLEST_LOGIT = float("-inf")
    _NEUTRAL_LOGIT = 0.0

    def __init__(self, bad_words_ids: List[List[int]]):
        self.bad_words_ids = bad_words_ids
        self.word_bias: torch.FloatTensor = None

    def __call__(
        self,
        past_tokens_ids: Union[List[int], Tuple[int]],
        logits: torch.FloatTensor,
    ) -> torch.Tensor:
        if self.word_bias is None:
            self._init_word_bias(logits=logits)

        last_token_bias = torch.zeros_like(logits)

        for bad_word_ids in self.bad_words_ids:
            if len(bad_word_ids) == 1:  # 1-token words already processed
                continue

            if len(bad_word_ids) > len(past_tokens_ids) + 1:
                continue

            prefix_length = len(bad_word_ids) - 1
            last_token_id = bad_word_ids[-1]
            actual_prefix = past_tokens_ids[-prefix_length:]
            expected_prefix = bad_word_ids[:prefix_length]

            assert len(actual_prefix) == len(expected_prefix)

            is_match = tuple(actual_prefix) == tuple(expected_prefix)
            last_token_bias[last_token_id] += (self._SMALLEST_LOGIT if is_match
                                               else self._NEUTRAL_LOGIT)

        logits = logits + self.word_bias + last_token_bias

        return logits

    def _init_word_bias(self, logits: torch.FloatTensor) -> None:
        # Code based on NoBadWordsLogitsProcessor and SequenceBiasLogitsProcessor  # noqa: E501
        # from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py

        vocab_size = logits.shape[-1]

        self._check_token_ids_bounds(vocab_size=vocab_size)

        self.word_bias = torch.zeros((vocab_size, ),
                                     dtype=torch.float,
                                     device=logits.device)

        for bad_word_ids in self.bad_words_ids:
            if len(bad_word_ids) == 1:
                bad_word_id = bad_word_ids[-1]
                self.word_bias[bad_word_id] = self._SMALLEST_LOGIT

    def _check_token_ids_bounds(self, vocab_size: int) -> None:
        invalid_token_ids = []

        for bad_word_ids in self.bad_words_ids:
            for token_id in bad_word_ids:
                if token_id < 0 or token_id >= vocab_size:
                    invalid_token_ids.append(token_id)

        if len(invalid_token_ids) > 0:
            raise ValueError(
                f"The model vocabulary size is {vocab_size},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id < {vocab_size}.")

View File

@@ -1,6 +1,7 @@
 from typing import Optional

-from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
+from vllm.logits_process import LogitsProcessor
+from vllm.sampling_params import GuidedDecodingParams

 async def get_guided_decoding_logits_processor(

View File

@@ -9,7 +9,8 @@ from lmformatenforcer.integrations.vllm import (
     build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
 from transformers import PreTrainedTokenizerBase

-from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
+from vllm.logits_process import LogitsProcessor
+from vllm.sampling_params import GuidedDecodingParams

 def get_local_lm_format_enforcer_guided_decoding_logits_processor(

View File

@@ -3,14 +3,14 @@ import copy
 from dataclasses import dataclass
 from enum import Enum, IntEnum
 from functools import cached_property
-from typing import Any, Callable, Dict, List, Optional, Set, Union
+from typing import Any, Dict, List, Optional, Set, Union

 import msgspec
-import torch
 from pydantic import BaseModel
 from typing_extensions import Annotated

 from vllm.logger import init_logger
+from vllm.logits_process import LogitsProcessor

 logger = init_logger(__name__)
@@ -24,16 +24,6 @@ class SamplingType(IntEnum):
     RANDOM_SEED = 2

-LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
-                        Callable[[List[int], List[int], torch.Tensor],
-                                 torch.Tensor]]
-"""LogitsProcessor is a function that takes a list
-of previously generated tokens, the logits tensor
-for the next token and, optionally, prompt tokens as a
-first argument, and returns a modified tensor of logits
-to sample from."""

 # maybe make msgspec?
 @dataclass
 class GuidedDecodingParams:
@@ -139,6 +129,10 @@ class SamplingParams(
         stop_token_ids: List of tokens that stop the generation when they are
             generated. The returned output will contain the stop tokens unless
             the stop tokens are special tokens.
+        bad_words: List of words that are not allowed to be generated.
+            More precisely, only the last token of a corresponding
+            token sequence is not allowed when the next generated token
+            can complete the sequence.
         include_stop_str_in_output: Whether to include the stop strings in
             output text. Defaults to False.
         ignore_eos: Whether to ignore the EOS token and continue generating
@@ -186,6 +180,7 @@ class SamplingParams(
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
     stop_token_ids: Optional[List[int]] = None
+    bad_words: Optional[List[str]] = None
     ignore_eos: bool = False
     max_tokens: Optional[int] = 16
     min_tokens: int = 0
@@ -228,6 +223,7 @@ class SamplingParams(
         seed: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
+        bad_words: Optional[List[str]] = None,
         include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
         max_tokens: Optional[int] = 16,
@@ -267,6 +263,7 @@ class SamplingParams(
             seed=seed,
             stop=stop,
             stop_token_ids=stop_token_ids,
+            bad_words=bad_words,
             include_stop_str_in_output=include_stop_str_in_output,
             ignore_eos=ignore_eos,
             max_tokens=max_tokens,
@@ -298,26 +295,36 @@ class SamplingParams(
                 f"got n={self.n} and best_of={self.best_of}.")
             self._real_n = self.n
             self.n = self.best_of

         if 0 < self.temperature < _MAX_TEMP:
             logger.warning(
                 "temperature %s is less than %s, which may cause numerical "
                 "errors nan or inf in tensors. We have maxed it out to %s.",
                 self.temperature, _MAX_TEMP, _MAX_TEMP)
             self.temperature = max(self.temperature, _MAX_TEMP)

         if self.seed == -1:
             self.seed = None
         else:
             self.seed = self.seed

         if self.stop is None:
             self.stop = []
         elif isinstance(self.stop, str):
             self.stop = [self.stop]
         else:
             self.stop = list(self.stop)

         if self.stop_token_ids is None:
             self.stop_token_ids = []
         else:
             self.stop_token_ids = list(self.stop_token_ids)

+        if self.bad_words is None:
+            self.bad_words = []
+        else:
+            self.bad_words = list(self.bad_words)

         self.logprobs = 1 if self.logprobs is True else self.logprobs
         self.prompt_logprobs = (1 if self.prompt_logprobs is True else
                                 self.prompt_logprobs)
@@ -468,6 +475,7 @@ class SamplingParams(
             f"seed={self.seed}, "
             f"stop={self.stop}, "
             f"stop_token_ids={self.stop_token_ids}, "
+            f"bad_words={self.bad_words}, "
             f"include_stop_str_in_output={self.include_stop_str_in_output}, "
             f"ignore_eos={self.ignore_eos}, "
             f"max_tokens={self.max_tokens}, "