[BugFix][V1] Fix overhead related to bad_words sampling when not in use (#14894)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
f6137adbcb
commit
fc1f67715d
@ -124,8 +124,9 @@ def _construct_expected_sampling_metadata(
|
||||
if req.sampling_params.allowed_token_ids:
|
||||
allowed_token_ids_mask[index_in_input_batch][
|
||||
req.sampling_params.allowed_token_ids] = True
|
||||
bad_words_token_ids[
|
||||
index_in_input_batch] = req.sampling_params.bad_words_token_ids
|
||||
if req.sampling_params.bad_words_token_ids:
|
||||
bad_words_token_ids[
|
||||
index_in_input_batch] = req.sampling_params.bad_words_token_ids
|
||||
|
||||
return SamplingMetadata(
|
||||
temperature=torch.tensor(temperature, dtype=torch.float,
|
||||
|
@ -235,7 +235,7 @@ class SamplingParams(
|
||||
|
||||
# Fields used for bad words
|
||||
bad_words: Optional[list[str]] = None
|
||||
_bad_words_token_ids: list[list[int]] = msgspec.field(default_factory=list)
|
||||
_bad_words_token_ids: Optional[list[list[int]]] = None
|
||||
|
||||
@staticmethod
|
||||
def from_optional(
|
||||
@ -464,8 +464,9 @@ class SamplingParams(
|
||||
self.stop_token_ids = list(eos_ids)
|
||||
|
||||
def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
|
||||
if self.bad_words is None:
|
||||
if not self.bad_words:
|
||||
return
|
||||
self._bad_words_token_ids = []
|
||||
for bad_word in self.bad_words:
|
||||
# To prohibit words both at the beginning
|
||||
# and in the middle of text
|
||||
@ -516,7 +517,7 @@ class SamplingParams(
|
||||
return self._all_stop_token_ids
|
||||
|
||||
@property
|
||||
def bad_words_token_ids(self) -> list[list[int]]:
|
||||
def bad_words_token_ids(self) -> Optional[list[list[int]]]:
|
||||
# For internal use only. Backward compatibility not guaranteed
|
||||
return self._bad_words_token_ids
|
||||
|
||||
|
@ -324,8 +324,9 @@ class InputBatch:
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
||||
sampling_params.allowed_token_ids] = False
|
||||
|
||||
self.bad_words_token_ids[
|
||||
req_index] = sampling_params.bad_words_token_ids
|
||||
if sampling_params.bad_words_token_ids:
|
||||
self.bad_words_token_ids[
|
||||
req_index] = sampling_params.bad_words_token_ids
|
||||
|
||||
# Add request lora ID
|
||||
if request.lora_request:
|
||||
|
Loading…
x
Reference in New Issue
Block a user