[BugFix][V1] Fix overhead related to bad_words sampling when not in use (#14894)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-03-16 14:53:34 -07:00 committed by GitHub
parent f6137adbcb
commit fc1f67715d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 10 additions and 7 deletions

View File

@@ -124,8 +124,9 @@ def _construct_expected_sampling_metadata(
if req.sampling_params.allowed_token_ids: if req.sampling_params.allowed_token_ids:
allowed_token_ids_mask[index_in_input_batch][ allowed_token_ids_mask[index_in_input_batch][
req.sampling_params.allowed_token_ids] = True req.sampling_params.allowed_token_ids] = True
bad_words_token_ids[ if req.sampling_params.bad_words_token_ids:
index_in_input_batch] = req.sampling_params.bad_words_token_ids bad_words_token_ids[
index_in_input_batch] = req.sampling_params.bad_words_token_ids
return SamplingMetadata( return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float, temperature=torch.tensor(temperature, dtype=torch.float,

View File

@@ -235,7 +235,7 @@ class SamplingParams(
# Fields used for bad words # Fields used for bad words
bad_words: Optional[list[str]] = None bad_words: Optional[list[str]] = None
_bad_words_token_ids: list[list[int]] = msgspec.field(default_factory=list) _bad_words_token_ids: Optional[list[list[int]]] = None
@staticmethod @staticmethod
def from_optional( def from_optional(
@@ -464,8 +464,9 @@ class SamplingParams(
self.stop_token_ids = list(eos_ids) self.stop_token_ids = list(eos_ids)
def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None: def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
if self.bad_words is None: if not self.bad_words:
return return
self._bad_words_token_ids = []
for bad_word in self.bad_words: for bad_word in self.bad_words:
# To prohibit words both at the beginning # To prohibit words both at the beginning
# and in the middle of text # and in the middle of text
@@ -516,7 +517,7 @@ class SamplingParams(
return self._all_stop_token_ids return self._all_stop_token_ids
@property @property
def bad_words_token_ids(self) -> list[list[int]]: def bad_words_token_ids(self) -> Optional[list[list[int]]]:
# For internal use only. Backward compatibility not guaranteed # For internal use only. Backward compatibility not guaranteed
return self._bad_words_token_ids return self._bad_words_token_ids

View File

@@ -324,8 +324,9 @@ class InputBatch:
self.allowed_token_ids_mask_cpu_tensor[req_index][ self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids] = False sampling_params.allowed_token_ids] = False
self.bad_words_token_ids[ if sampling_params.bad_words_token_ids:
req_index] = sampling_params.bad_words_token_ids self.bad_words_token_ids[
req_index] = sampling_params.bad_words_token_ids
# Add request lora ID # Add request lora ID
if request.lora_request: if request.lora_request: