[BugFix][V1] Fix overhead related to bad_words sampling when not in use (#14894)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-03-16 14:53:34 -07:00 committed by GitHub
parent f6137adbcb
commit fc1f67715d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 10 additions and 7 deletions

View File

@ -124,8 +124,9 @@ def _construct_expected_sampling_metadata(
if req.sampling_params.allowed_token_ids:
allowed_token_ids_mask[index_in_input_batch][
req.sampling_params.allowed_token_ids] = True
bad_words_token_ids[
index_in_input_batch] = req.sampling_params.bad_words_token_ids
if req.sampling_params.bad_words_token_ids:
bad_words_token_ids[
index_in_input_batch] = req.sampling_params.bad_words_token_ids
return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float,

View File

@ -235,7 +235,7 @@ class SamplingParams(
# Fields used for bad words
bad_words: Optional[list[str]] = None
_bad_words_token_ids: list[list[int]] = msgspec.field(default_factory=list)
_bad_words_token_ids: Optional[list[list[int]]] = None
@staticmethod
def from_optional(
@ -464,8 +464,9 @@ class SamplingParams(
self.stop_token_ids = list(eos_ids)
def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
if self.bad_words is None:
if not self.bad_words:
return
self._bad_words_token_ids = []
for bad_word in self.bad_words:
# To prohibit words both at the beginning
# and in the middle of text
@ -516,7 +517,7 @@ class SamplingParams(
return self._all_stop_token_ids
@property
def bad_words_token_ids(self) -> list[list[int]]:
def bad_words_token_ids(self) -> Optional[list[list[int]]]:
# For internal use only. Backward compatibility not guaranteed
return self._bad_words_token_ids

View File

@ -324,8 +324,9 @@ class InputBatch:
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids] = False
self.bad_words_token_ids[
req_index] = sampling_params.bad_words_token_ids
if sampling_params.bad_words_token_ids:
self.bad_words_token_ids[
req_index] = sampling_params.bad_words_token_ids
# Add request lora ID
if request.lora_request: