[BugFix][V1] Fix overhead related to bad_words sampling when not in use (#14894)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent f6137adbcb
commit fc1f67715d
@ -124,8 +124,9 @@ def _construct_expected_sampling_metadata(
|
|||||||
if req.sampling_params.allowed_token_ids:
|
if req.sampling_params.allowed_token_ids:
|
||||||
allowed_token_ids_mask[index_in_input_batch][
|
allowed_token_ids_mask[index_in_input_batch][
|
||||||
req.sampling_params.allowed_token_ids] = True
|
req.sampling_params.allowed_token_ids] = True
|
||||||
bad_words_token_ids[
|
if req.sampling_params.bad_words_token_ids:
|
||||||
index_in_input_batch] = req.sampling_params.bad_words_token_ids
|
bad_words_token_ids[
|
||||||
|
index_in_input_batch] = req.sampling_params.bad_words_token_ids
|
||||||
|
|
||||||
return SamplingMetadata(
|
return SamplingMetadata(
|
||||||
temperature=torch.tensor(temperature, dtype=torch.float,
|
temperature=torch.tensor(temperature, dtype=torch.float,
|
||||||
|
@ -235,7 +235,7 @@ class SamplingParams(
|
|||||||
|
|
||||||
# Fields used for bad words
|
# Fields used for bad words
|
||||||
bad_words: Optional[list[str]] = None
|
bad_words: Optional[list[str]] = None
|
||||||
_bad_words_token_ids: list[list[int]] = msgspec.field(default_factory=list)
|
_bad_words_token_ids: Optional[list[list[int]]] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_optional(
|
def from_optional(
|
||||||
@ -464,8 +464,9 @@ class SamplingParams(
|
|||||||
self.stop_token_ids = list(eos_ids)
|
self.stop_token_ids = list(eos_ids)
|
||||||
|
|
||||||
def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
|
def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
|
||||||
if self.bad_words is None:
|
if not self.bad_words:
|
||||||
return
|
return
|
||||||
|
self._bad_words_token_ids = []
|
||||||
for bad_word in self.bad_words:
|
for bad_word in self.bad_words:
|
||||||
# To prohibit words both at the beginning
|
# To prohibit words both at the beginning
|
||||||
# and in the middle of text
|
# and in the middle of text
|
||||||
@ -516,7 +517,7 @@ class SamplingParams(
|
|||||||
return self._all_stop_token_ids
|
return self._all_stop_token_ids
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def bad_words_token_ids(self) -> list[list[int]]:
|
def bad_words_token_ids(self) -> Optional[list[list[int]]]:
|
||||||
# For internal use only. Backward compatibility not guaranteed
|
# For internal use only. Backward compatibility not guaranteed
|
||||||
return self._bad_words_token_ids
|
return self._bad_words_token_ids
|
||||||
|
|
||||||
|
@ -324,8 +324,9 @@ class InputBatch:
|
|||||||
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
||||||
sampling_params.allowed_token_ids] = False
|
sampling_params.allowed_token_ids] = False
|
||||||
|
|
||||||
self.bad_words_token_ids[
|
if sampling_params.bad_words_token_ids:
|
||||||
req_index] = sampling_params.bad_words_token_ids
|
self.bad_words_token_ids[
|
||||||
|
req_index] = sampling_params.bad_words_token_ids
|
||||||
|
|
||||||
# Add request lora ID
|
# Add request lora ID
|
||||||
if request.lora_request:
|
if request.lora_request:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user