From b6087a6beead9165f4c77ceba592b3651bb37de9 Mon Sep 17 00:00:00 2001 From: Tobias Pitters <31857876+CloseChoice@users.noreply.github.com> Date: Thu, 2 Jan 2025 17:18:15 +0100 Subject: [PATCH] [mypy] Pass type checking in vllm/inputs (#11680) Signed-off-by: Tobias Pitters --- tools/mypy.sh | 1 + vllm/inputs/data.py | 21 +++++++++++---------- vllm/inputs/preprocess.py | 6 +++--- vllm/inputs/registry.py | 2 +- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/tools/mypy.sh b/tools/mypy.sh index 2454ff9f..bf95e4c5 100755 --- a/tools/mypy.sh +++ b/tools/mypy.sh @@ -23,6 +23,7 @@ run_mypy vllm/compilation run_mypy vllm/distributed run_mypy vllm/engine run_mypy vllm/executor +run_mypy vllm/inputs run_mypy vllm/lora run_mypy vllm/model_executor run_mypy vllm/plugins diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index d54cbb5c..cdaf6dd7 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -250,7 +250,7 @@ class SingletonInputsAdapter: if inputs["type"] == "token" or inputs["type"] == "multimodal": return inputs.get("prompt") - assert_never(inputs) + assert_never(inputs) # type: ignore[arg-type] @cached_property def prompt_token_ids(self) -> List[int]: @@ -259,7 +259,7 @@ class SingletonInputsAdapter: if inputs["type"] == "token" or inputs["type"] == "multimodal": return inputs.get("prompt_token_ids", []) - assert_never(inputs) + assert_never(inputs) # type: ignore[arg-type] @cached_property def token_type_ids(self) -> List[int]: @@ -268,7 +268,7 @@ class SingletonInputsAdapter: if inputs["type"] == "token" or inputs["type"] == "multimodal": return inputs.get("token_type_ids", []) - assert_never(inputs) + assert_never(inputs) # type: ignore[arg-type] @cached_property def prompt_embeds(self) -> Optional[torch.Tensor]: @@ -277,7 +277,7 @@ class SingletonInputsAdapter: if inputs["type"] == "token" or inputs["type"] == "multimodal": return None - assert_never(inputs) + assert_never(inputs) # type: ignore[arg-type] @cached_property def multi_modal_data(self) -> "MultiModalDataDict": @@ -289,7 +289,7 @@ class SingletonInputsAdapter: if inputs["type"] == "multimodal": return inputs.get("mm_kwargs", {}) - assert_never(inputs) + assert_never(inputs) # type: ignore[arg-type] @cached_property def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]: @@ -301,7 +301,7 @@ class SingletonInputsAdapter: if inputs["type"] == "multimodal": return inputs.get("mm_kwargs", {}) - assert_never(inputs) + assert_never(inputs) # type: ignore[arg-type] @cached_property def multi_modal_hashes(self) -> List[str]: @@ -311,9 +311,10 @@ class SingletonInputsAdapter: return inputs.get("multi_modal_hashes", []) if inputs["type"] == "multimodal": - return inputs.get("mm_hashes", []) + # only the case when we use MultiModalInputsV2 + return inputs.get("mm_hashes", []) # type: ignore[return-value] - assert_never(inputs) + assert_never(inputs) # type: ignore[arg-type] @cached_property def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": @@ -325,7 +326,7 @@ class SingletonInputsAdapter: if inputs["type"] == "multimodal": return inputs.get("mm_placeholders", {}) - assert_never(inputs) + assert_never(inputs) # type: ignore[arg-type] @cached_property def mm_processor_kwargs(self) -> Dict[str, Any]: @@ -337,7 +338,7 @@ class SingletonInputsAdapter: if inputs["type"] == "multimodal": return {} - assert_never(inputs) + assert_never(inputs) # type: ignore[arg-type] ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 3d606817..aaa10d27 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -436,7 +436,7 @@ class InputPreprocessor: or encoder_inputs["type"] == "multimodal"): pass else: - assert_never(encoder_inputs) + assert_never(encoder_inputs) # type: ignore[arg-type] if decoder_inputs is None: dec_token_ids = self._prepare_decoder_input_ids_for_generation( @@ -452,7 +452,7 @@ class InputPreprocessor: raise ValueError("Multi-modal decoder inputs of encoder-" "decoder models are not supported yet") else: - assert_never(encoder_inputs) + assert_never(encoder_inputs) # type: ignore[arg-type] return EncoderDecoderInputs( encoder=encoder_inputs, @@ -569,7 +569,7 @@ class InputPreprocessor: prompt_adapter_request=prompt_adapter_request, ) else: - assert_never(prompt_inputs) + assert_never(prompt_inputs) # type: ignore[arg-type] return prompt_inputs diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 09034770..2d9d024e 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -419,7 +419,7 @@ class InputRegistry: # Be more strict in V2 assert "mm_kwargs" in inputs else: - assert_never(inputs["type"]) + assert_never(inputs["type"]) # type: ignore[arg-type] def process_input(self, model_config: "ModelConfig", inputs: ProcessorInputs) -> ProcessorInputs: