[Core] Avoid the need to pass None values to Sequence.inputs (#5099)

This commit is contained in:
Cyrus Leung 2024-05-30 07:05:01 +08:00 committed by GitHub
parent eb6c50cdc2
commit b1c255630d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 6 additions and 19 deletions

View File

@@ -234,7 +234,6 @@ def test_append_slot_cow():
inputs={
"prompt": "one two three",
"prompt_token_ids": [1, 2, 3],
"multi_modal_data": None
},
block_size=block_size)
@@ -525,7 +524,6 @@ def test_sliding_window_multi_seq():
inputs={
"prompt": "one two three",
"prompt_token_ids": [0, 1, 2],
"multi_modal_data": None
},
block_size=block_size)
seq_group = SequenceGroup(request_id="1",

View File

@@ -25,7 +25,6 @@ def create_dummy_prompt(
inputs={
"prompt": prompt_str,
"prompt_token_ids": prompt_tokens,
"multi_modal_data": None,
},
block_size=block_size)
seq_group = SequenceGroup(request_id=request_id,
@@ -103,11 +102,7 @@ def create_seq_group(
for seq_id_offset, output_len in enumerate(seq_output_lens):
seq = Sequence(
seq_id=seq_id_start + seq_id_offset,
inputs={
"prompt": "",
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
inputs={"prompt_token_ids": prompt_token_ids},
block_size=16,
)

View File

@@ -15,11 +15,7 @@ def sequence_with_eos(text: str, eos_token: str,
"""
seq = Sequence(
seq_id=0,
inputs={
"prompt": "",
"prompt_token_ids": [],
"multi_modal_data": None,
},
inputs={"prompt_token_ids": []},
block_size=16,
eos_token_id=eos_token_id,
)

View File

@@ -74,7 +74,6 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
inputs={
"prompt": prompt,
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=block_size,
eos_token_id=tokenizer.tokenizer.eos_token_id,

View File

@@ -126,7 +126,6 @@ def create_sequence(prompt_token_ids=None):
inputs={
"prompt": "<s>",
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=16,
)

View File

@@ -126,5 +126,5 @@ PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
class LLMInputs(TypedDict):
prompt_token_ids: List[int]
prompt: Optional[str]
multi_modal_data: Optional["MultiModalData"]
prompt: NotRequired[Optional[str]]
multi_modal_data: NotRequired[Optional["MultiModalData"]]

View File

@@ -249,7 +249,7 @@ class Sequence:
@property
def prompt(self) -> Optional[str]:
return self.inputs["prompt"]
return self.inputs.get("prompt")
@property
def prompt_token_ids(self) -> List[int]:
@@ -257,7 +257,7 @@ class Sequence:
@property
def multi_modal_data(self) -> Optional["MultiModalData"]:
return self.inputs["multi_modal_data"]
return self.inputs.get("multi_modal_data")
@property
def lora_int_id(self) -> int: