[mypy] Pass type checking in vllm/inputs (#11680)

Signed-off-by: Tobias Pitters <tobias.pitters@gmail.com>
This commit is contained in:
Tobias Pitters 2025-01-02 17:18:15 +01:00 committed by GitHub
parent 23c1b10a4c
commit b6087a6bee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 16 additions and 14 deletions

View File

@ -23,6 +23,7 @@ run_mypy vllm/compilation
run_mypy vllm/distributed run_mypy vllm/distributed
run_mypy vllm/engine run_mypy vllm/engine
run_mypy vllm/executor run_mypy vllm/executor
run_mypy vllm/inputs
run_mypy vllm/lora run_mypy vllm/lora
run_mypy vllm/model_executor run_mypy vllm/model_executor
run_mypy vllm/plugins run_mypy vllm/plugins

View File

@ -250,7 +250,7 @@ class SingletonInputsAdapter:
if inputs["type"] == "token" or inputs["type"] == "multimodal": if inputs["type"] == "token" or inputs["type"] == "multimodal":
return inputs.get("prompt") return inputs.get("prompt")
assert_never(inputs) assert_never(inputs) # type: ignore[arg-type]
@cached_property @cached_property
def prompt_token_ids(self) -> List[int]: def prompt_token_ids(self) -> List[int]:
@ -259,7 +259,7 @@ class SingletonInputsAdapter:
if inputs["type"] == "token" or inputs["type"] == "multimodal": if inputs["type"] == "token" or inputs["type"] == "multimodal":
return inputs.get("prompt_token_ids", []) return inputs.get("prompt_token_ids", [])
assert_never(inputs) assert_never(inputs) # type: ignore[arg-type]
@cached_property @cached_property
def token_type_ids(self) -> List[int]: def token_type_ids(self) -> List[int]:
@ -268,7 +268,7 @@ class SingletonInputsAdapter:
if inputs["type"] == "token" or inputs["type"] == "multimodal": if inputs["type"] == "token" or inputs["type"] == "multimodal":
return inputs.get("token_type_ids", []) return inputs.get("token_type_ids", [])
assert_never(inputs) assert_never(inputs) # type: ignore[arg-type]
@cached_property @cached_property
def prompt_embeds(self) -> Optional[torch.Tensor]: def prompt_embeds(self) -> Optional[torch.Tensor]:
@ -277,7 +277,7 @@ class SingletonInputsAdapter:
if inputs["type"] == "token" or inputs["type"] == "multimodal": if inputs["type"] == "token" or inputs["type"] == "multimodal":
return None return None
assert_never(inputs) assert_never(inputs) # type: ignore[arg-type]
@cached_property @cached_property
def multi_modal_data(self) -> "MultiModalDataDict": def multi_modal_data(self) -> "MultiModalDataDict":
@ -289,7 +289,7 @@ class SingletonInputsAdapter:
if inputs["type"] == "multimodal": if inputs["type"] == "multimodal":
return inputs.get("mm_kwargs", {}) return inputs.get("mm_kwargs", {})
assert_never(inputs) assert_never(inputs) # type: ignore[arg-type]
@cached_property @cached_property
def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]: def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
@ -301,7 +301,7 @@ class SingletonInputsAdapter:
if inputs["type"] == "multimodal": if inputs["type"] == "multimodal":
return inputs.get("mm_kwargs", {}) return inputs.get("mm_kwargs", {})
assert_never(inputs) assert_never(inputs) # type: ignore[arg-type]
@cached_property @cached_property
def multi_modal_hashes(self) -> List[str]: def multi_modal_hashes(self) -> List[str]:
@ -311,9 +311,10 @@ class SingletonInputsAdapter:
return inputs.get("multi_modal_hashes", []) return inputs.get("multi_modal_hashes", [])
if inputs["type"] == "multimodal": if inputs["type"] == "multimodal":
return inputs.get("mm_hashes", []) # only the case when we use MultiModalInputsV2
return inputs.get("mm_hashes", []) # type: ignore[return-value]
assert_never(inputs) assert_never(inputs) # type: ignore[arg-type]
@cached_property @cached_property
def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
@ -325,7 +326,7 @@ class SingletonInputsAdapter:
if inputs["type"] == "multimodal": if inputs["type"] == "multimodal":
return inputs.get("mm_placeholders", {}) return inputs.get("mm_placeholders", {})
assert_never(inputs) assert_never(inputs) # type: ignore[arg-type]
@cached_property @cached_property
def mm_processor_kwargs(self) -> Dict[str, Any]: def mm_processor_kwargs(self) -> Dict[str, Any]:
@ -337,7 +338,7 @@ class SingletonInputsAdapter:
if inputs["type"] == "multimodal": if inputs["type"] == "multimodal":
return {} return {}
assert_never(inputs) assert_never(inputs) # type: ignore[arg-type]
ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]

View File

@ -436,7 +436,7 @@ class InputPreprocessor:
or encoder_inputs["type"] == "multimodal"): or encoder_inputs["type"] == "multimodal"):
pass pass
else: else:
assert_never(encoder_inputs) assert_never(encoder_inputs) # type: ignore[arg-type]
if decoder_inputs is None: if decoder_inputs is None:
dec_token_ids = self._prepare_decoder_input_ids_for_generation( dec_token_ids = self._prepare_decoder_input_ids_for_generation(
@ -452,7 +452,7 @@ class InputPreprocessor:
raise ValueError("Multi-modal decoder inputs of encoder-" raise ValueError("Multi-modal decoder inputs of encoder-"
"decoder models are not supported yet") "decoder models are not supported yet")
else: else:
assert_never(encoder_inputs) assert_never(encoder_inputs) # type: ignore[arg-type]
return EncoderDecoderInputs( return EncoderDecoderInputs(
encoder=encoder_inputs, encoder=encoder_inputs,
@ -569,7 +569,7 @@ class InputPreprocessor:
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
) )
else: else:
assert_never(prompt_inputs) assert_never(prompt_inputs) # type: ignore[arg-type]
return prompt_inputs return prompt_inputs

View File

@ -419,7 +419,7 @@ class InputRegistry:
# Be more strict in V2 # Be more strict in V2
assert "mm_kwargs" in inputs assert "mm_kwargs" in inputs
else: else:
assert_never(inputs["type"]) assert_never(inputs["type"]) # type: ignore[arg-type]
def process_input(self, model_config: "ModelConfig", def process_input(self, model_config: "ModelConfig",
inputs: ProcessorInputs) -> ProcessorInputs: inputs: ProcessorInputs) -> ProcessorInputs: