[Frontend] Bad words sampling parameter (#9717)
Signed-off-by: Vasily Alexeev <alvasian@yandex.ru>
This commit is contained in:
parent 55137e8ee3
commit 07e981fdf4
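In short, requests can now pass a list of banned words through SamplingParams. A minimal usage sketch (model and prompt are illustrative, not part of the change):

from vllm import LLM, SamplingParams

llm = LLM(model="openai-community/gpt2")
params = SamplingParams(temperature=0, bad_words=["years old"])
outputs = llm.generate(["How old are you? I am 10"], params)
# The continuation may still contain "years", but never "years old".
print(outputs[0].outputs[0].text)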
tests/samplers/test_no_bad_words.py (new file, 185 lines)
@@ -0,0 +1,185 @@
"""Make sure bad_words works.

Run `pytest tests/samplers/test_no_bad_words.py`.
"""
from typing import List, Optional

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams


def _generate(
    model: LLM,
    prompt: str,
    num_prompt_tokens: int,
    temperature: float = 0,
    bad_words: Optional[List[str]] = None,
) -> List[int]:
    sampling_params = SamplingParams(
        temperature=temperature,
        bad_words=bad_words,
    )

    # [([output_token_ids, ], [output_text, ]), ]
    output = model.generate([prompt], sampling_params=sampling_params)

    output_token_ids = output[0][0][0][num_prompt_tokens:]
    # [0] first (and only) request output
    # [0] token_ids (not text)
    # [0] first (and only) output completion

    return output_token_ids
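Unpacked, the dense triple indexing above reads as follows (a sketch of the same access path, assuming the ([token_ids, ...], [text, ...]) pair shape noted in the comment):

req_outputs = model.generate([prompt], sampling_params=sampling_params)
token_ids_lists, _texts = req_outputs[0]  # first (and only) request
completion_ids = token_ids_lists[0]  # first (and only) completion
generated_ids = completion_ids[num_prompt_tokens:]  # drop the echoed prompt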


class TestOneTokenBadWord:
    MODEL = "TheBloke/Llama-2-7B-fp16"

    PROMPT = "Hi! How are"
    TARGET_TOKEN = "you"

    def setup_method(self, method):
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
                                                       add_prefix_space=True)

        self.num_prompt_tokens = len(self._encode(self.PROMPT))
        self.target_token_id = self._encode(self.TARGET_TOKEN,
                                            add_special_tokens=False)[0]

    def test_one_token_bad_word(self, vllm_runner):
        with vllm_runner(self.MODEL) as llm:
            output_token_ids = self._generate(llm)
            assert output_token_ids[0] == self.target_token_id

            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN])
            assert self.target_token_id not in output_token_ids

    def _generate(self,
                  model: LLM,
                  bad_words: Optional[List[str]] = None) -> List[int]:
        return _generate(
            model=model,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    def _encode(self,
                prompt: str,
                add_special_tokens: bool = True) -> List[int]:
        return self.tokenizer(prompt,
                              add_special_tokens=add_special_tokens).input_ids


class TestTwoTokenBadWord:
    # Another model (with a different tokenizer behaviour)
    MODEL = "openai-community/gpt2"

    PROMPT = "How old are you? I am 10"
    TARGET_TOKEN1 = "years"
    TARGET_TOKEN2 = "old"
    NEIGHBOUR_TOKEN2 = "older"

    def setup_method(self, method):
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
                                                       add_prefix_space=True)

        self.num_prompt_tokens = len(self._encode(self.PROMPT))
        self.target_token_id1 = self._encode(self.TARGET_TOKEN1,
                                             add_special_tokens=False)[0]
        self.target_token_id2 = self._encode(self.TARGET_TOKEN2,
                                             add_special_tokens=False)[0]
        self.neighbour_token_id2 = self._encode(self.NEIGHBOUR_TOKEN2,
                                                add_special_tokens=False)[0]

    def test_two_token_bad_word(self, vllm_runner):
        with vllm_runner(self.MODEL) as llm:
            output_token_ids = self._generate(llm)
            assert output_token_ids[:2] == [
                self.target_token_id1, self.target_token_id2
            ]

            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN1])
            assert self.target_token_id1 not in output_token_ids

            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN2])
            assert output_token_ids[0] == self.target_token_id1
            assert self.target_token_id2 not in output_token_ids

            output_token_ids = self._generate(
                llm, bad_words=[f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}'])
            assert output_token_ids[0] == self.target_token_id1
            assert output_token_ids[:2] != [
                self.target_token_id1, self.target_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.target_token_id2])
            # Model dependent behaviour
            assert output_token_ids[:2] == [
                self.target_token_id1, self.neighbour_token_id2
            ]

            output_token_ids = self._generate(
                llm,
                bad_words=[
                    f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}',
                    f'{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}'
                ])
            assert output_token_ids[0] == self.target_token_id1
            assert output_token_ids[:2] != [
                self.target_token_id1, self.target_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.target_token_id2])
            assert output_token_ids[:2] != [
                self.target_token_id1, self.neighbour_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.neighbour_token_id2])
            assert ((self.target_token_id2 in output_token_ids)
                    or (self.neighbour_token_id2 in output_token_ids))

    def _generate(self,
                  model: LLM,
                  bad_words: Optional[List[str]] = None) -> List[int]:
        return _generate(
            model=model,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    @staticmethod
    def _contains(sequence: List[int], subsequence: List[int]) -> bool:
        searched = False

        for start in range(len(sequence)):
            end = start + len(subsequence)
            current_subsequence = sequence[start:end]

            if len(current_subsequence) < len(subsequence):
                continue

            searched = True

            assert len(current_subsequence) == len(subsequence)

            if current_subsequence == subsequence:
                return True

        assert searched, ("no full-length window was compared; "
                          "is the subsequence longer than the sequence?")

        return False

    def _encode(self,
                prompt: str,
                add_special_tokens: bool = True) -> List[int]:
        return self.tokenizer(prompt,
                              add_special_tokens=add_special_tokens).input_ids

vllm/engine/llm_engine.py
@@ -26,7 +26,8 @@ from vllm.engine.output_processor.interfaces import (
     SequenceGroupOutputProcessor)
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.engine.output_processor.util import create_output_by_sequence_group
-from vllm.entrypoints.openai.logits_processors import get_logits_processors
+from vllm.entrypoints.openai.logits_processors import (
+    get_logits_processors as get_openai_logits_processors)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.executor.ray_utils import initialize_ray_cluster
@@ -34,6 +35,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
     EncoderDecoderInputs, InputRegistry, PromptType)
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
+from vllm.logits_process import get_bad_words_logits_processors
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
     get_local_guided_decoding_logits_processor)
@@ -1963,6 +1965,7 @@ class LLMEngine:
         logits_processors field. Returns the modified sampling params."""

         logits_processors = []

         if (guided_decoding := sampling_params.guided_decoding) is not None:
             logger.debug(
@@ -1984,7 +1987,7 @@ class LLMEngine:
         if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
             tokenizer = self.get_tokenizer(lora_request=lora_request)

-            processors = get_logits_processors(
+            processors = get_openai_logits_processors(
                 logit_bias=sampling_params.logit_bias,
                 allowed_token_ids=sampling_params.allowed_token_ids,
                 tokenizer=tokenizer)
@@ -1994,6 +1997,12 @@ class LLMEngine:
         sampling_params.logit_bias = None
         sampling_params.allowed_token_ids = None

+        if len(sampling_params.bad_words) > 0:
+            tokenizer = self.get_tokenizer(lora_request)
+            processors = get_bad_words_logits_processors(
+                bad_words=sampling_params.bad_words, tokenizer=tokenizer)
+            logits_processors.extend(processors)
+
         if logits_processors:
             if sampling_params.logits_processors is None:
                 sampling_params.logits_processors = logits_processors
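The helper called above can also be driven standalone; a hedged sketch (tokenizer choice illustrative):

from transformers import AutoTokenizer
from vllm.logits_process import get_bad_words_logits_processors

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
processors = get_bad_words_logits_processors(bad_words=["years old"],
                                             tokenizer=tokenizer)
# A single processor is returned; the engine extends
# sampling_params.logits_processors with it, as shown above.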
vllm/logits_process.py (new file, 119 lines)
@@ -0,0 +1,119 @@
from typing import Callable, List, Tuple, Union

import torch

from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
                        Callable[[List[int], List[int], torch.Tensor],
                                 torch.Tensor]]
"""LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor
for the next token and, optionally, prompt tokens as a
first argument, and returns a modified tensor of logits
to sample from."""


def get_bad_words_logits_processors(
        bad_words: List[str],
        tokenizer: AnyTokenizer) -> List[LogitsProcessor]:
    bad_words_ids: List[List[int]] = list()

    for bad_word in bad_words:
        # To prohibit words both at the beginning
        # and in the middle of text
        # (related to the add_prefix_space tokenizer parameter)
        for add_prefix_space in [False, True]:
            prefix = " " if add_prefix_space else ""
            prompt = prefix + bad_word.lstrip()

            if isinstance(tokenizer, MistralTokenizer):
                # Mistral tokenizers should not add special tokens
                prompt_token_ids = tokenizer.encode(prompt=prompt)
            else:
                prompt_token_ids = tokenizer.encode(text=prompt,
                                                    add_special_tokens=False)

            # If no space at the beginning
            # or if prefix space produces a new word token
            if (not add_prefix_space) or (
                    add_prefix_space
                    and prompt_token_ids[0] != bad_words_ids[-1][0]
                    and len(prompt_token_ids) == len(bad_words_ids[-1])):
                bad_words_ids.append(prompt_token_ids)

    return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)]


class NoBadWordsLogitsProcessor:
    _SMALLEST_LOGIT = float("-inf")
    _NEUTRAL_LOGIT = 0.0

    def __init__(self, bad_words_ids: List[List[int]]):
        self.bad_words_ids = bad_words_ids
        self.word_bias: torch.FloatTensor = None

    def __call__(
        self,
        past_tokens_ids: Union[List[int], Tuple[int]],
        logits: torch.FloatTensor,
    ) -> torch.Tensor:
        if self.word_bias is None:
            self._init_word_bias(logits=logits)

        last_token_bias = torch.zeros_like(logits)

        for bad_word_ids in self.bad_words_ids:
            if len(bad_word_ids) == 1:  # 1-token words already processed
                continue

            if len(bad_word_ids) > len(past_tokens_ids) + 1:
                continue

            prefix_length = len(bad_word_ids) - 1
            last_token_id = bad_word_ids[-1]
            actual_prefix = past_tokens_ids[-prefix_length:]
            expected_prefix = bad_word_ids[:prefix_length]

            assert len(actual_prefix) == len(expected_prefix)

            is_match = tuple(actual_prefix) == tuple(expected_prefix)
            last_token_bias[last_token_id] += (self._SMALLEST_LOGIT if is_match
                                               else self._NEUTRAL_LOGIT)

        logits = logits + self.word_bias + last_token_bias

        return logits

    def _init_word_bias(self, logits: torch.FloatTensor) -> None:
        # Code based on NoBadWordsLogitsProcessor and SequenceBiasLogitsProcessor  # noqa: E501
        # from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py

        vocab_size = logits.shape[-1]

        self._check_token_ids_bounds(vocab_size=vocab_size)

        self.word_bias = torch.zeros((vocab_size, ),
                                     dtype=torch.float,
                                     device=logits.device)

        for bad_word_ids in self.bad_words_ids:
            if len(bad_word_ids) == 1:
                bad_word_id = bad_word_ids[-1]
                self.word_bias[bad_word_id] = self._SMALLEST_LOGIT

    def _check_token_ids_bounds(self, vocab_size: int) -> None:
        invalid_token_ids = []

        for bad_word_ids in self.bad_words_ids:
            for token_id in bad_word_ids:
                if token_id < 0 or token_id >= vocab_size:
                    invalid_token_ids.append(token_id)

        if len(invalid_token_ids) > 0:
            raise ValueError(
                f"The model vocabulary size is {vocab_size},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id < {vocab_size}.")
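A toy check of the masking rule implemented above (vocabulary size and token ids invented for illustration):

import torch

from vllm.logits_process import NoBadWordsLogitsProcessor

# Ban the two-token sequence [3, 5] and the single token [2].
processor = NoBadWordsLogitsProcessor(bad_words_ids=[[3, 5], [2]])

out = processor(past_tokens_ids=[1, 3], logits=torch.zeros(8))
assert out[2] == float("-inf")  # one-token bad word: always masked
assert out[5] == float("-inf")  # prefix [3] matches: completion masked

out = processor(past_tokens_ids=[1, 4], logits=torch.zeros(8))
assert out[5] == 0.0  # prefix does not match: token 5 stays allowed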
vllm/model_executor/guided_decoding/__init__.py
@@ -1,6 +1,7 @@
 from typing import Optional

-from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
+from vllm.logits_process import LogitsProcessor
+from vllm.sampling_params import GuidedDecodingParams


 async def get_guided_decoding_logits_processor(
vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
@@ -9,7 +9,8 @@ from lmformatenforcer.integrations.vllm import (
     build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
 from transformers import PreTrainedTokenizerBase

-from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
+from vllm.logits_process import LogitsProcessor
+from vllm.sampling_params import GuidedDecodingParams


 def get_local_lm_format_enforcer_guided_decoding_logits_processor(
vllm/sampling_params.py
@@ -3,14 +3,14 @@ import copy
 from dataclasses import dataclass
 from enum import Enum, IntEnum
 from functools import cached_property
-from typing import Any, Callable, Dict, List, Optional, Set, Union
+from typing import Any, Dict, List, Optional, Set, Union

 import msgspec
 import torch
 from pydantic import BaseModel
 from typing_extensions import Annotated

 from vllm.logger import init_logger
+from vllm.logits_process import LogitsProcessor

 logger = init_logger(__name__)
@@ -24,16 +24,6 @@ class SamplingType(IntEnum):
     RANDOM_SEED = 2


-LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
-                        Callable[[List[int], List[int], torch.Tensor],
-                                 torch.Tensor]]
-"""LogitsProcessor is a function that takes a list
-of previously generated tokens, the logits tensor
-for the next token and, optionally, prompt tokens as a
-first argument, and returns a modified tensor of logits
-to sample from."""
-
-
 # maybe make msgspec?
 @dataclass
 class GuidedDecodingParams:
@@ -139,6 +129,10 @@ class SamplingParams(
         stop_token_ids: List of tokens that stop the generation when they are
             generated. The returned output will contain the stop tokens unless
             the stop tokens are special tokens.
+        bad_words: List of words that are not allowed to be generated.
+            More precisely, only the last token of a corresponding
+            token sequence is not allowed when the next generated token
+            can complete the sequence.
         include_stop_str_in_output: Whether to include the stop strings in
             output text. Defaults to False.
         ignore_eos: Whether to ignore the EOS token and continue generating
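Illustrating the docstring's "only the last token" rule (a hedged sketch; the actual split depends on the model's tokenizer):

params = SamplingParams(bad_words=["years old"])
# If "years old" tokenizes to ["years", "old"], generation may still
# emit "years"; only "old" is masked when it would directly follow
# "years". One-token bad words are masked unconditionally.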
@ -186,6 +180,7 @@ class SamplingParams(
|
||||
seed: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stop_token_ids: Optional[List[int]] = None
|
||||
bad_words: Optional[List[str]] = None
|
||||
ignore_eos: bool = False
|
||||
max_tokens: Optional[int] = 16
|
||||
min_tokens: int = 0
|
||||
@@ -228,6 +223,7 @@ class SamplingParams(
         seed: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
+        bad_words: Optional[List[str]] = None,
         include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
         max_tokens: Optional[int] = 16,
@@ -267,6 +263,7 @@ class SamplingParams(
             seed=seed,
             stop=stop,
             stop_token_ids=stop_token_ids,
+            bad_words=bad_words,
             include_stop_str_in_output=include_stop_str_in_output,
             ignore_eos=ignore_eos,
             max_tokens=max_tokens,
@@ -298,26 +295,36 @@ class SamplingParams(
                 f"got n={self.n} and best_of={self.best_of}.")
             self._real_n = self.n
             self.n = self.best_of

         if 0 < self.temperature < _MAX_TEMP:
             logger.warning(
                 "temperature %s is less than %s, which may cause numerical "
                 "errors nan or inf in tensors. We have maxed it out to %s.",
                 self.temperature, _MAX_TEMP, _MAX_TEMP)
             self.temperature = max(self.temperature, _MAX_TEMP)

         if self.seed == -1:
             self.seed = None
         else:
             self.seed = self.seed

         if self.stop is None:
             self.stop = []
         elif isinstance(self.stop, str):
             self.stop = [self.stop]
         else:
             self.stop = list(self.stop)

         if self.stop_token_ids is None:
             self.stop_token_ids = []
         else:
             self.stop_token_ids = list(self.stop_token_ids)

+        if self.bad_words is None:
+            self.bad_words = []
+        else:
+            self.bad_words = list(self.bad_words)
+
         self.logprobs = 1 if self.logprobs is True else self.logprobs
         self.prompt_logprobs = (1 if self.prompt_logprobs is True else
                                 self.prompt_logprobs)
@@ -468,6 +475,7 @@ class SamplingParams(
             f"seed={self.seed}, "
             f"stop={self.stop}, "
             f"stop_token_ids={self.stop_token_ids}, "
+            f"bad_words={self.bad_words}, "
             f"include_stop_str_in_output={self.include_stop_str_in_output}, "
             f"ignore_eos={self.ignore_eos}, "
             f"max_tokens={self.max_tokens}, "