[Frontend] Bad words sampling parameter (#9717)

Signed-off-by: Vasily Alexeev <alvasian@yandex.ru>

parent: 55137e8ee3
commit: 07e981fdf4
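This commit adds a bad_words option to SamplingParams: the listed words (tokenized both with and without a leading space) are prevented from being generated. A minimal usage sketch, reusing the gpt2 model and prompt from the tests below:

from vllm import LLM, SamplingParams

llm = LLM(model="openai-community/gpt2")
params = SamplingParams(temperature=0, bad_words=["years old"])
outputs = llm.generate(["How old are you? I am 10"], sampling_params=params)
# The completion may still contain "years", but never the phrase "years old".
print(outputs[0].outputs[0].text)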
tests/samplers/test_no_bad_words.py  (new file, +185 lines)
@@ -0,0 +1,185 @@
"""Make sure bad_words works.

Run `pytest tests/samplers/test_no_bad_words.py`.

"""
from typing import List, Optional

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams


def _generate(
    model: LLM,
    prompt: str,
    num_prompt_tokens: int,
    temperature: float = 0,
    bad_words: Optional[List[str]] = None,
) -> List[int]:
    sampling_params = SamplingParams(
        temperature=temperature,
        bad_words=bad_words,
    )

    # [([output_token_ids, ], [output_text, ]), ]
    output = model.generate([prompt], sampling_params=sampling_params)

    output_token_ids = output[0][0][0][num_prompt_tokens:]
    # [0] first (and only) request output
    # [0] token_ids (not text)
    # [0] first (and only) output completion

    return output_token_ids


class TestOneTokenBadWord:
    MODEL = "TheBloke/Llama-2-7B-fp16"

    PROMPT = "Hi! How are"
    TARGET_TOKEN = "you"

    def setup_method(self, method):
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
                                                       add_prefix_space=True)

        self.num_prompt_tokens = len(self._encode(self.PROMPT))
        self.target_token_id = self._encode(self.TARGET_TOKEN,
                                            add_special_tokens=False)[0]

    def test_one_token_bad_word(self, vllm_runner):
        with vllm_runner(self.MODEL) as llm:
            output_token_ids = self._generate(llm)
            assert output_token_ids[0] == self.target_token_id

            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN])
            assert self.target_token_id not in output_token_ids

    def _generate(self,
                  model: LLM,
                  bad_words: Optional[List[str]] = None) -> List[int]:
        return _generate(
            model=model,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    def _encode(self,
                prompt: str,
                add_special_tokens: bool = True) -> List[int]:
        return self.tokenizer(prompt,
                              add_special_tokens=add_special_tokens).input_ids


class TestTwoTokenBadWord:
    # Another model (with a different tokenizer behaviour)
    MODEL = "openai-community/gpt2"

    PROMPT = "How old are you? I am 10"
    TARGET_TOKEN1 = "years"
    TARGET_TOKEN2 = "old"
    NEIGHBOUR_TOKEN2 = "older"

    def setup_method(self, method):
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
                                                       add_prefix_space=True)

        self.num_prompt_tokens = len(self._encode(self.PROMPT))
        self.target_token_id1 = self._encode(self.TARGET_TOKEN1,
                                             add_special_tokens=False)[0]
        self.target_token_id2 = self._encode(self.TARGET_TOKEN2,
                                             add_special_tokens=False)[0]
        self.neighbour_token_id2 = self._encode(self.NEIGHBOUR_TOKEN2,
                                                add_special_tokens=False)[0]

    def test_two_token_bad_word(self, vllm_runner):
        with vllm_runner(self.MODEL) as llm:
            output_token_ids = self._generate(llm)
            assert output_token_ids[:2] == [
                self.target_token_id1, self.target_token_id2
            ]

            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN1])
            assert self.target_token_id1 not in output_token_ids

            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN2])
            assert output_token_ids[0] == self.target_token_id1
            assert self.target_token_id2 not in output_token_ids

            output_token_ids = self._generate(
                llm, bad_words=[f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}'])
            assert output_token_ids[0] == self.target_token_id1
            assert output_token_ids[:2] != [
                self.target_token_id1, self.target_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.target_token_id2])
            # Model dependent behaviour
            assert output_token_ids[:2] == [
                self.target_token_id1, self.neighbour_token_id2
            ]

            output_token_ids = self._generate(
                llm,
                bad_words=[
                    f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}',
                    f'{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}'
                ])
            assert output_token_ids[0] == self.target_token_id1
            assert output_token_ids[:2] != [
                self.target_token_id1, self.target_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.target_token_id2])
            assert output_token_ids[:2] != [
                self.target_token_id1, self.neighbour_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.neighbour_token_id2])
            assert ((self.target_token_id2 in output_token_ids)
                    or (self.neighbour_token_id2 in output_token_ids))

    def _generate(self,
                  model: LLM,
                  bad_words: Optional[List[str]] = None) -> List[int]:
        return _generate(
            model=model,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    @staticmethod
    def _contains(sequence: List[int], subsequence: List[int]) -> bool:
        searched = False

        for start in range(len(sequence)):
            end = start + len(subsequence)
            current_subsequence = sequence[start:end]

            if len(current_subsequence) < len(subsequence):
                continue

            searched = True

            assert len(current_subsequence) == len(subsequence)

            if current_subsequence == subsequence:
                return True

        assert searched, "All subsequences did not match in length..."

        return False

    def _encode(self,
                prompt: str,
                add_special_tokens: bool = True) -> List[int]:
        return self.tokenizer(prompt,
                              add_special_tokens=add_special_tokens).input_ids
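The two-token test above hinges on tokenizer behaviour: with add_prefix_space=True, the GPT-2 tokenizer is assumed (by the test's use of [0]) to encode "years", "old" and "older" as single tokens, and "years old" as exactly the two target tokens. A quick way to check those assumptions locally; the ids are tokenizer-dependent, so none are hard-coded here:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2",
                                          add_prefix_space=True)
for word in ["years", "old", "older", "years old"]:
    # Print each word next to its token ids.
    print(word, tokenizer(word, add_special_tokens=False).input_ids)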
vllm/engine/llm_engine.py
@@ -26,7 +26,8 @@ from vllm.engine.output_processor.interfaces import (
     SequenceGroupOutputProcessor)
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.engine.output_processor.util import create_output_by_sequence_group
-from vllm.entrypoints.openai.logits_processors import get_logits_processors
+from vllm.entrypoints.openai.logits_processors import (
+    get_logits_processors as get_openai_logits_processors)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.executor.ray_utils import initialize_ray_cluster
@@ -34,6 +35,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
                          EncoderDecoderInputs, InputRegistry, PromptType)
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
+from vllm.logits_process import get_bad_words_logits_processors
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
     get_local_guided_decoding_logits_processor)
@@ -1963,6 +1965,7 @@ class LLMEngine:
         logits_processors field. Returns the modified sampling params."""
+
         logits_processors = []

         if (guided_decoding := sampling_params.guided_decoding) is not None:
             logger.debug(
@@ -1984,7 +1987,7 @@ class LLMEngine:
         if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
             tokenizer = self.get_tokenizer(lora_request=lora_request)

-            processors = get_logits_processors(
+            processors = get_openai_logits_processors(
                 logit_bias=sampling_params.logit_bias,
                 allowed_token_ids=sampling_params.allowed_token_ids,
                 tokenizer=tokenizer)
@@ -1994,6 +1997,12 @@ class LLMEngine:
         sampling_params.logit_bias = None
         sampling_params.allowed_token_ids = None

+        if len(sampling_params.bad_words) > 0:
+            tokenizer = self.get_tokenizer(lora_request)
+            processors = get_bad_words_logits_processors(
+                bad_words=sampling_params.bad_words, tokenizer=tokenizer)
+            logits_processors.extend(processors)
+
         if logits_processors:
             if sampling_params.logits_processors is None:
                 sampling_params.logits_processors = logits_processors
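With this wiring, bad_words never reaches the sampler directly: _build_logits_processors converts it into logits processors and appends them to sampling_params.logits_processors. A rough standalone sketch of that conversion, using the function introduced by this diff with a plain HF tokenizer standing in for the engine's get_tokenizer():

from transformers import AutoTokenizer

from vllm.logits_process import get_bad_words_logits_processors

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
processors = get_bad_words_logits_processors(bad_words=["years old"],
                                             tokenizer=tokenizer)
# One processor is returned; the engine extends
# sampling_params.logits_processors with it, so it is applied to the
# logits at every decoding step of the request.
print(processors)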
vllm/logits_process.py  (new file, +119 lines)
@@ -0,0 +1,119 @@
from typing import Callable, List, Tuple, Union

import torch

from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
                        Callable[[List[int], List[int], torch.Tensor],
                                 torch.Tensor]]
"""LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor
for the next token and, optionally, prompt tokens as a
first argument, and returns a modified tensor of logits
to sample from."""


def get_bad_words_logits_processors(
        bad_words: List[str],
        tokenizer: AnyTokenizer) -> List[LogitsProcessor]:
    bad_words_ids: List[List[int]] = list()

    for bad_word in bad_words:
        # To prohibit words both at the beginning
        # and in the middle of text
        # (related to add_prefix_space tokenizer parameter)
        for add_prefix_space in [False, True]:
            prefix = " " if add_prefix_space else ""
            prompt = prefix + bad_word.lstrip()

            if isinstance(tokenizer, MistralTokenizer):
                # Mistral tokenizers should not add special tokens
                prompt_token_ids = tokenizer.encode(prompt=prompt)
            else:
                prompt_token_ids = tokenizer.encode(text=prompt,
                                                    add_special_tokens=False)

            # If no space at the beginning
            # or if prefix space produces a new word token
            if (not add_prefix_space) or (
                    add_prefix_space
                    and prompt_token_ids[0] != bad_words_ids[-1][0]
                    and len(prompt_token_ids) == len(bad_words_ids[-1])):
                bad_words_ids.append(prompt_token_ids)

    return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)]


class NoBadWordsLogitsProcessor:
    _SMALLEST_LOGIT = float("-inf")
    _NEUTRAL_LOGIT = 0.0

    def __init__(self, bad_words_ids: List[List[int]]):
        self.bad_words_ids = bad_words_ids
        self.word_bias: torch.FloatTensor = None

    def __call__(
        self,
        past_tokens_ids: Union[List[int], Tuple[int]],
        logits: torch.FloatTensor,
    ) -> torch.Tensor:
        if self.word_bias is None:
            self._init_word_bias(logits=logits)

        last_token_bias = torch.zeros_like(logits)

        for bad_word_ids in self.bad_words_ids:
            if len(bad_word_ids) == 1:  # 1-token words already processed
                continue

            if len(bad_word_ids) > len(past_tokens_ids) + 1:
                continue

            prefix_length = len(bad_word_ids) - 1
            last_token_id = bad_word_ids[-1]
            actual_prefix = past_tokens_ids[-prefix_length:]
            expected_prefix = bad_word_ids[:prefix_length]

            assert len(actual_prefix) == len(expected_prefix)

            is_match = tuple(actual_prefix) == tuple(expected_prefix)
            last_token_bias[last_token_id] += (self._SMALLEST_LOGIT if is_match
                                               else self._NEUTRAL_LOGIT)

        logits = logits + self.word_bias + last_token_bias

        return logits

    def _init_word_bias(self, logits: torch.FloatTensor) -> None:
        # Code based on NoBadWordsLogitsProcessor and SequenceBiasLogitsProcessor  # noqa: E501
        # from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py

        vocab_size = logits.shape[-1]

        self._check_token_ids_bounds(vocab_size=vocab_size)

        self.word_bias = torch.zeros((vocab_size, ),
                                     dtype=torch.float,
                                     device=logits.device)

        for bad_word_ids in self.bad_words_ids:
            if len(bad_word_ids) == 1:
                bad_word_id = bad_word_ids[-1]
                self.word_bias[bad_word_id] = self._SMALLEST_LOGIT

    def _check_token_ids_bounds(self, vocab_size: int) -> None:
        invalid_token_ids = []

        for bad_word_ids in self.bad_words_ids:
            for token_id in bad_word_ids:
                if token_id < 0 or token_id >= vocab_size:
                    invalid_token_ids.append(token_id)

        if len(invalid_token_ids) > 0:
            raise ValueError(
                f"The model vocabulary size is {vocab_size},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id < {vocab_size}.")
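The processor combines two biases: word_bias, computed once and cached, unconditionally bans single-token bad words; last_token_bias, recomputed each step, bans the final token of a multi-token bad word only when the already-generated tokens match the word's prefix. A self-contained check with made-up token ids over a hypothetical 10-token vocabulary:

import torch

from vllm.logits_process import NoBadWordsLogitsProcessor

# [7] is a one-token bad word; [3, 5] is a two-token one.
processor = NoBadWordsLogitsProcessor(bad_words_ids=[[7], [3, 5]])

out = processor(past_tokens_ids=[1, 2, 3], logits=torch.zeros(10))
assert out[7] == float("-inf")  # banned unconditionally
assert out[5] == float("-inf")  # banned: the previous token was 3

out = processor(past_tokens_ids=[1, 2, 4], logits=torch.zeros(10))
assert out[5] == 0.0  # prefix does not match, token 5 stays available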
vllm/model_executor/guided_decoding/__init__.py
@@ -1,6 +1,7 @@
 from typing import Optional

-from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
+from vllm.logits_process import LogitsProcessor
+from vllm.sampling_params import GuidedDecodingParams


 async def get_guided_decoding_logits_processor(
vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
@@ -9,7 +9,8 @@ from lmformatenforcer.integrations.vllm import (
     build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
 from transformers import PreTrainedTokenizerBase

-from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
+from vllm.logits_process import LogitsProcessor
+from vllm.sampling_params import GuidedDecodingParams


 def get_local_lm_format_enforcer_guided_decoding_logits_processor(
vllm/sampling_params.py
@@ -3,14 +3,14 @@ import copy
 from dataclasses import dataclass
 from enum import Enum, IntEnum
 from functools import cached_property
-from typing import Any, Callable, Dict, List, Optional, Set, Union
+from typing import Any, Dict, List, Optional, Set, Union

 import msgspec
-import torch
 from pydantic import BaseModel
 from typing_extensions import Annotated

 from vllm.logger import init_logger
+from vllm.logits_process import LogitsProcessor

 logger = init_logger(__name__)
@@ -24,16 +24,6 @@ class SamplingType(IntEnum):
     RANDOM_SEED = 2


-LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
-                        Callable[[List[int], List[int], torch.Tensor],
-                                 torch.Tensor]]
-"""LogitsProcessor is a function that takes a list
-of previously generated tokens, the logits tensor
-for the next token and, optionally, prompt tokens as a
-first argument, and returns a modified tensor of logits
-to sample from."""
-
-
 # maybe make msgspec?
 @dataclass
 class GuidedDecodingParams:
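These two hunks relocate the LogitsProcessor alias from vllm/sampling_params.py to the new vllm/logits_process.py, letting sampling_params.py drop its torch and Callable imports. The alias itself is unchanged, so custom processors still plug in through logits_processors; a minimal sketch of the two-argument form (ban_token_zero is a hypothetical example, not part of this diff):

from typing import List

import torch

from vllm import SamplingParams


def ban_token_zero(past_token_ids: List[int],
                   logits: torch.Tensor) -> torch.Tensor:
    # Two-argument form: previously generated token ids plus the
    # next-token logits; returns the modified logits.
    logits[0] = float("-inf")
    return logits


params = SamplingParams(logits_processors=[ban_token_zero])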
@@ -139,6 +129,10 @@ class SamplingParams(
         stop_token_ids: List of tokens that stop the generation when they are
             generated. The returned output will contain the stop tokens unless
             the stop tokens are special tokens.
+        bad_words: List of words that are not allowed to be generated.
+            More precisely, only the last token of a corresponding
+            token sequence is not allowed when the next generated token
+            can complete the sequence.
         include_stop_str_in_output: Whether to include the stop strings in
             output text. Defaults to False.
         ignore_eos: Whether to ignore the EOS token and continue generating
@@ -186,6 +180,7 @@ class SamplingParams(
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
     stop_token_ids: Optional[List[int]] = None
+    bad_words: Optional[List[str]] = None
     ignore_eos: bool = False
     max_tokens: Optional[int] = 16
     min_tokens: int = 0
@@ -228,6 +223,7 @@ class SamplingParams(
         seed: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
+        bad_words: Optional[List[str]] = None,
         include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
         max_tokens: Optional[int] = 16,
@@ -267,6 +263,7 @@ class SamplingParams(
             seed=seed,
             stop=stop,
             stop_token_ids=stop_token_ids,
+            bad_words=bad_words,
             include_stop_str_in_output=include_stop_str_in_output,
             ignore_eos=ignore_eos,
             max_tokens=max_tokens,
@@ -298,26 +295,36 @@ class SamplingParams(
                 f"got n={self.n} and best_of={self.best_of}.")
             self._real_n = self.n
             self.n = self.best_of

         if 0 < self.temperature < _MAX_TEMP:
             logger.warning(
                 "temperature %s is less than %s, which may cause numerical "
                 "errors nan or inf in tensors. We have maxed it out to %s.",
                 self.temperature, _MAX_TEMP, _MAX_TEMP)
             self.temperature = max(self.temperature, _MAX_TEMP)

         if self.seed == -1:
             self.seed = None
         else:
             self.seed = self.seed

         if self.stop is None:
             self.stop = []
         elif isinstance(self.stop, str):
             self.stop = [self.stop]
         else:
             self.stop = list(self.stop)

         if self.stop_token_ids is None:
             self.stop_token_ids = []
         else:
             self.stop_token_ids = list(self.stop_token_ids)

+        if self.bad_words is None:
+            self.bad_words = []
+        else:
+            self.bad_words = list(self.bad_words)
+
         self.logprobs = 1 if self.logprobs is True else self.logprobs
         self.prompt_logprobs = (1 if self.prompt_logprobs is True else
                                 self.prompt_logprobs)
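Like stop and stop_token_ids just above it, bad_words is normalized in __post_init__ so downstream code can always treat it as a list; a quick check of the behavior this hunk adds:

from vllm import SamplingParams

assert SamplingParams().bad_words == []
assert SamplingParams(bad_words=None).bad_words == []
assert SamplingParams(bad_words=["years old"]).bad_words == ["years old"]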
@@ -468,6 +475,7 @@ class SamplingParams(
             f"seed={self.seed}, "
             f"stop={self.stop}, "
             f"stop_token_ids={self.stop_token_ids}, "
+            f"bad_words={self.bad_words}, "
             f"include_stop_str_in_output={self.include_stop_str_in_output}, "
             f"ignore_eos={self.ignore_eos}, "
             f"max_tokens={self.max_tokens}, "