[CI/Build] Update Ruff version (#8469)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Aaron Pham 2024-09-18 07:00:56 -04:00 committed by GitHub
parent 6ffa3f314c
commit 9d104b5beb
27 changed files with 50 additions and 77 deletions

View File

@@ -25,10 +25,10 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install ruff==0.1.5 codespell==2.3.0 tomli==2.0.1 isort==5.13.2
+        pip install -r requirements-lint.txt
     - name: Analysing the code with ruff
       run: |
-        ruff .
+        ruff check .
     - name: Spelling check with codespell
       run: |
        codespell --toml pyproject.toml

View File

@@ -45,8 +45,7 @@ if __name__ == "__main__":
     rows = int(math.ceil(len(results) / 2))
     fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
     axs = axs.flatten()
-    axs_idx = 0
-    for shape, data in results.items():
+    for axs_idx, (shape, data) in enumerate(results.items()):
         plt.sca(axs[axs_idx])
         df = pd.DataFrame(data)
         sns.lineplot(data=df,
@@ -59,6 +58,5 @@ if __name__ == "__main__":
                      palette="Dark2")
         plt.title(f"Shape: {shape}")
         plt.ylabel("time (median, s)")
-        axs_idx += 1
     plt.tight_layout()
     plt.savefig("graph_machete_bench.pdf")
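Several hunks in this commit apply the same counter-to-enumerate cleanup, here and in a few tests further down. A minimal standalone sketch of the pattern (the data dict is illustrative, not the benchmark's real results):

results = {"llama-7b": [1.2, 1.1], "llama-13b": [2.3, 2.2]}

# Before: a manually maintained index next to the loop variable.
idx = 0
for shape, data in results.items():
    print(idx, shape, data)
    idx += 1

# After: enumerate() yields the index together with each item,
# which is what Ruff's flake8-simplify checks (likely SIM113) suggest.
for idx, (shape, data) in enumerate(results.items()):
    print(idx, shape, data)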

View File

@@ -159,7 +159,7 @@ echo 'vLLM codespell: Done'

 # Lint specified files
 lint() {
-    ruff "$@"
+    ruff check "$@"
 }

 # Lint files that differ from main branch. Ignores dirs that are not slated
@@ -175,7 +175,7 @@ lint_changed() {
     if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
         git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
-            ruff
+            ruff check
     fi
 }

View File

@@ -42,6 +42,8 @@ ignore = [
     "E731",
     # Loop control variable not used within loop body
     "B007",
+    # f-string format
+    "UP032",
 ]

 [tool.mypy]
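UP032 is the pyupgrade rule that asks for f-strings in place of str.format(); ignoring it keeps the existing .format() calls in the tree untouched. A small sketch of what the rule would otherwise flag (names are illustrative):

shape = (4096, 4096)

# Flagged by UP032: a str.format() call that could be an f-string.
title = "Shape: {}".format(shape)

# The rewrite UP032 would propose; with the rule ignored, both forms pass lint.
title_f = f"Shape: {shape}"

assert title == title_f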

View File

@@ -2,7 +2,7 @@
 yapf==0.32.0
 toml==0.10.2
 tomli==2.0.1
-ruff==0.1.5
+ruff==0.6.5
 codespell==2.3.0
 isort==5.13.2
 clang-format==18.1.5

View File

@@ -158,10 +158,7 @@ def should_do_global_cleanup_after_test(request) -> bool:
     to initialize torch.
     """
-    if request.node.get_closest_marker("skip_global_cleanup"):
-        return False
-
-    return True
+    return not request.node.get_closest_marker("skip_global_cleanup")


 @pytest.fixture(autouse=True)
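This hunk (and the matching one in the next file) collapses an if/return True/return False ladder into a single boolean return. A minimal sketch of the equivalence, with a hypothetical has_skip_marker flag standing in for the pytest marker lookup:

def should_cleanup(has_skip_marker: bool) -> bool:
    # Before: explicit True/False branches.
    if has_skip_marker:
        return False
    return True


def should_cleanup_simplified(has_skip_marker: bool) -> bool:
    # After: return the negated condition directly.
    return not has_skip_marker


assert all(should_cleanup(x) == should_cleanup_simplified(x) for x in (True, False))

Note that `not` always produces a bool, so the `-> bool` annotation still holds even though get_closest_marker() returns a Marker object or None.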

View File

@@ -65,10 +65,7 @@ def should_do_global_cleanup_after_test(request) -> bool:
     to initialize torch.
     """
-    if request.node.get_closest_marker("skip_global_cleanup"):
-        return False
-
-    return True
+    return not request.node.get_closest_marker("skip_global_cleanup")


 @pytest.fixture(autouse=True)

View File

@@ -5,7 +5,7 @@ from vllm.multimodal.base import MultiModalInputs, NestedTensors

 def assert_nested_tensors_equal(expected: NestedTensors,
                                 actual: NestedTensors):
-    assert type(expected) == type(actual)
+    assert type(expected) == type(actual)  # noqa: E721
     if isinstance(expected, torch.Tensor):
         assert torch.equal(expected, actual)
     else:
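E721 flags direct type comparisons and suggests isinstance(), but these call sites want an exact-type match (subclasses must not compare equal), so the commit adds # noqa: E721 instead of rewriting them. A short sketch of the difference, with throwaway classes:

class Base:
    pass


class Child(Base):
    pass


obj = Child()

# isinstance() accepts subclasses; an exact-type comparison does not.
assert isinstance(obj, Base)
assert type(obj) != Base  # noqa: E721 -- intentional exact-type check
assert type(obj) is Child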

View File

@@ -66,8 +66,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
         hashes.append([])
         prompts = [prefix + prompt for prompt in sample_prompts]
-        seq_id = 0
-        for prompt in prompts:
+        for seq_id, prompt in enumerate(prompts):
             hashes[-1].append([])
             prompt_token_ids = tokenizer.encode(prompt)
             seq = Sequence(seq_id,
@@ -83,8 +82,6 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
             for idx in range(num_blocks):
                 hashes[-1][-1].append(seq.hash_of_block(idx))
-            seq_id += 1
-
     # Check that hashes made with two prefixes with different first blocks are
     # different everywhere.
     for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])):

View File

@@ -111,7 +111,7 @@ def test_an_error_is_raised_when_custom_logging_config_file_does_not_exist():
     configuration occurs."""
     with pytest.raises(RuntimeError) as ex_info:
         _configure_vllm_root_logger()
-    assert ex_info.type == RuntimeError
+    assert ex_info.type == RuntimeError  # noqa: E721
     assert "File does not exist" in str(ex_info)
@@ -152,7 +152,7 @@ def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json(
                            logging_config_file.name):
         with pytest.raises(ValueError) as ex_info:
             _configure_vllm_root_logger()
-        assert ex_info.type == ValueError
+        assert ex_info.type == ValueError  # noqa: E721
         assert "Invalid logging config. Expected Dict, got" in str(ex_info)

View File

@@ -453,8 +453,7 @@ def test_prepare_decode(batch_size):
     # each sequence) in the decode phase
     expected_selected_token_indices = []
-    selected_token_start_idx = 0
-    for seq_len in seq_lens:
+    for selected_token_start_idx, seq_len in enumerate(seq_lens):
         # Compute the index offset of the final token in each
         # sequence's decoded outputs; since a single token is
         # decoded per iteration per sequence, then the length
@@ -463,7 +462,6 @@ def test_prepare_decode(batch_size):
         # generated tokens is 0 (i.e. the expected sampling index
         # for a given sequence is just `selected_token_start_idx`)
         expected_selected_token_indices.append(selected_token_start_idx)
-        selected_token_start_idx += 1
     sampling_metadata = model_input.sampling_metadata
     actual = sampling_metadata.selected_token_indices

View File

@@ -241,10 +241,8 @@ def test_prepare_decode_cuda_graph(batch_size):
     # Verify Sampling
     expected_selected_token_indices = []
-    selected_token_start_idx = 0
-    for _ in context_lens:
+    for selected_token_start_idx, _ in enumerate(context_lens):
         expected_selected_token_indices.append(selected_token_start_idx)
-        selected_token_start_idx += 1
     sampling_metadata = SamplingMetadata.prepare(
         seq_group_metadata_list,
         seq_lens,

View File

@@ -42,7 +42,7 @@ def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]:

 def get_adapter(adapter_id: int,
                 registered_adapters: Dict[int, Any]) -> Optional[Any]:
-    return registered_adapters.get(adapter_id, None)
+    return registered_adapters.get(adapter_id)


 ## worker functions
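dict.get() already defaults to None when the key is missing, so the explicit None argument is redundant; this is what Ruff's dict-get-with-None-default check (SIM910, if memory serves) enforces, and it recurs in the quantization configs below. Sketch with an illustrative mapping:

registered_adapters = {1: "adapter-a"}

# Equivalent calls: None is already dict.get()'s default.
assert registered_adapters.get(2, None) is None
assert registered_adapters.get(2) is None
assert registered_adapters.get(1) == "adapter-a"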

View File

@@ -33,10 +33,8 @@ def is_block_tables_empty(block_tables: Union[None, Dict]):
     """
     if block_tables is None:
         return True
-    if isinstance(block_tables, dict) and all(
-            value is None for value in block_tables.values()):
-        return True
-    return False
+    return (isinstance(block_tables, dict)
+            and all(value is None for value in block_tables.values()))


 def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,

View File

@@ -417,9 +417,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):

     def is_block_cached(self, block: Block) -> bool:
         assert block.content_hash is not None
-        if block.content_hash in self._cached_blocks:
-            return True
-        return False
+        return block.content_hash in self._cached_blocks

     def promote_to_immutable_block(self, block: Block) -> BlockId:
         """Once a mutable block is full, it can be promoted to an immutable

View File

@@ -399,9 +399,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         """
         alloc_status = self._can_swap(seq_group, Device.CPU,
                                       SequenceStatus.RUNNING)
-        if alloc_status == AllocStatus.OK:
-            return True
-        return False
+        return alloc_status == AllocStatus.OK

     def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         """Returns the block id mapping (from GPU to CPU) generated by

View File

@@ -826,7 +826,7 @@ class AsyncLLMEngine:
             request_id: The unique id of the request.
             lora_request: LoRA request to use for generation, if any.
             trace_headers: OpenTelemetry trace headers.
-            prompt_adapter_request: Prompt Adapter request to use
+            prompt_adapter_request: Prompt Adapter request to use
                 for generation, if any.

         Yields:
@@ -1042,7 +1042,7 @@ class AsyncLLMEngine:
     async def start_profile(self) -> None:
         # using type instead of isinstance to check to avoid capturing
         # inherited classes
-        if type(self.engine.model_executor) == GPUExecutorAsync:
+        if type(self.engine.model_executor) == GPUExecutorAsync:  # noqa: E721
             self.engine.model_executor.start_profile()
         else:
             self.engine.model_executor._run_workers("start_profile")
@@ -1050,7 +1050,7 @@ class AsyncLLMEngine:
     async def stop_profile(self) -> None:
         # using type instead of isinstance to check to avoid capturing
         # inherited classes
-        if type(self.engine.model_executor) == GPUExecutorAsync:
+        if type(self.engine.model_executor) == GPUExecutorAsync:  # noqa: E721
             self.engine.model_executor.stop_profile()
         else:
             self.engine.model_executor._run_workers("stop_profile")

View File

@@ -144,7 +144,7 @@ class LLMEngine:
             decoding.
         executor_class: The model executor class for managing distributed
             execution.
-        prompt_adapter_config (Optional): The configuration related to serving
+        prompt_adapter_config (Optional): The configuration related to serving
             prompt adapters.
         log_stats: Whether to log statistics.
         usage_context: Specified entry point, used for usage info collection.
@@ -1605,7 +1605,7 @@ class LLMEngine:
     def start_profile(self) -> None:
         # using type instead of isinstance to check to avoid capturing
         # inherited classes (MultiprocessingGPUExecutor)
-        if type(self.model_executor) == GPUExecutor:
+        if type(self.model_executor) == GPUExecutor:  # noqa: E721
             self.model_executor.start_profile()
         else:
             self.model_executor._run_workers("start_profile")
@@ -1613,7 +1613,7 @@ class LLMEngine:
     def stop_profile(self) -> None:
         # using type instead of isinstance to check to avoid capturing
         # inherited classes (MultiprocessingGPUExecutor)
-        if type(self.model_executor) == GPUExecutor:
+        if type(self.model_executor) == GPUExecutor:  # noqa: E721
             self.model_executor.stop_profile()
         else:
             self.model_executor._run_workers("stop_profile")

View File

@@ -67,9 +67,9 @@ class BaseLogitsProcessor:
         instruction = self._guide.get_next_instruction(
             state=self._fsm_state[seq_id])

-        if type(instruction) == Generate:
+        if type(instruction) == Generate:  # noqa: E721
             allowed_tokens = instruction.tokens
-        elif type(instruction) == Write:
+        elif type(instruction) == Write:  # noqa: E721
             # TODO: support fast forward tokens
             allowed_tokens = [instruction.tokens[0]]
         else:

View File

@@ -110,9 +110,9 @@ class AWQMarlinConfig(QuantizationConfig):
     def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         # Extract data from quant config.
         quant_method = quant_config.get("quant_method", "").lower()
-        num_bits = quant_config.get("bits", None)
-        group_size = quant_config.get("group_size", None)
-        has_zp = quant_config.get("zero_point", None)
+        num_bits = quant_config.get("bits")
+        group_size = quant_config.get("group_size")
+        has_zp = quant_config.get("zero_point")

         if quant_method != "awq":
             return False

View File

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast

 import torch
 from pydantic import BaseModel
@@ -79,8 +79,8 @@ class CompressedTensorsConfig(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
         target_scheme_map: Dict[str, Any] = dict()
-        ignore: List[str] = config.get("ignore", None)
-        quant_format: str = config.get("format", None)
+        ignore = cast(List[str], config.get("ignore"))
+        quant_format = cast(str, config.get("format"))

         # The quant_config has multiple config_groups, each containing
         # an input_activations key with details about how the activations are
@@ -200,7 +200,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         is_per_tensor_or_channel_weight = (weight_quant.strategy in [
             QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
         ])
-        if not (is_symmetric_weight and is_static_weight
+        if not (is_symmetric_weight and is_static_weight  # noqa: SIM103
                 and is_per_tensor_or_channel_weight):
             return False
@@ -333,7 +333,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         """
-        Use the CompressedTensorsScheme associated with each layer to create
+        Use the CompressedTensorsScheme associated with each layer to create
         the necessary parameters for the layer. See LinearMethodBase for param
         details
         """
@@ -352,8 +352,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None):
         """
-        Use the output of create_weights and the CompressedTensorsScheme
-        associated with the layer to apply the forward pass with the
+        Use the output of create_weights and the CompressedTensorsScheme
+        associated with the layer to apply the forward pass with the
         layer input. See LinearMethodBase for param details
         """

View File

@@ -132,10 +132,10 @@ class GPTQMarlinConfig(QuantizationConfig):
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         # Extract data from quant config.
         quant_method = quant_config.get("quant_method", "").lower()
-        num_bits = quant_config.get("bits", None)
-        group_size = quant_config.get("group_size", None)
-        sym = quant_config.get("sym", None)
-        desc_act = quant_config.get("desc_act", None)
+        num_bits = quant_config.get("bits")
+        group_size = quant_config.get("group_size")
+        sym = quant_config.get("sym")
+        desc_act = quant_config.get("desc_act")

         if quant_method != "gptq":
             return False

View File

@@ -408,9 +408,7 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
                     "inferred as vLLM models, so setting vllm_tensorized=True is "
                     "only necessary for models serialized prior to this change.")
         return True
-    if (".vllm_tensorized_marker" in deserializer):
-        return True
-    return False
+    return ".vllm_tensorized_marker" in deserializer


 def serialize_vllm_model(

View File

@@ -884,7 +884,7 @@ class MiniCPMV(MiniCPMVBaseModel):
         version = str(config.version).split(".")
         version = tuple([int(x) for x in version])
         # Dispatch class based on version
-        instance_class = _SUPPORT_VERSION.get(version, None)
+        instance_class = _SUPPORT_VERSION.get(version)
         if instance_class is None:
             raise ValueError(
                 "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")

View File

@@ -183,10 +183,7 @@ class TP1DraftModelRunner(ModelRunner):
             return False

         # TODO: Add soft-tuning prompt adapter support
-        if self.prompt_adapter_config:
-            return False
-
-        return True
+        return not self.prompt_adapter_config

     @torch.inference_mode()
     def execute_model(

View File

@@ -104,13 +104,10 @@ class AsyncMetricsCollector:
         if self._rank != 0:
             return False

-        if (now - self._last_metrics_collect_time <
-                self._rejsample_metrics_collect_interval_s):
-            return False
-        return True
+        return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s  # noqa: E501

     def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
         """Copy rejection/typical-acceptance sampling metrics
         (number of accepted tokens, etc) to CPU asynchronously.

         Returns a CUDA event recording when the copy is complete.

View File

@@ -35,8 +35,8 @@ class LibEntry(triton.KernelInterface):
             dns_key = [
                 arg.dtype if hasattr(
                     arg, "data_ptr") else type(arg) if not isinstance(arg, int)
-                else "i32" if -(2**31) <= arg and arg <= 2**31 -
-                1 else "u64" if 2**63 <= arg and arg <= 2**64 - 1 else "i64"
+                else "i32" if arg >= -(2**31) and arg <= 2**31 -
+                1 else "u64" if arg >= 2**63 and arg <= 2**64 - 1 else "i64"
                 for arg in dns_args
             ]
             # const args passed by position
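This final hunk only reorders the comparisons so the variable sits on the left (the "Yoda condition" style that Ruff's flake8-simplify rules, most likely SIM300, complain about); the selected integer-width keys are unchanged. A standalone sketch of that key-selection logic, with a hypothetical int_key helper:

def int_key(arg: int) -> str:
    # Same ranges as the hunk above, written variable-first.
    if arg >= -(2**31) and arg <= 2**31 - 1:
        return "i32"
    if arg >= 2**63 and arg <= 2**64 - 1:
        return "u64"
    return "i64"


assert int_key(7) == "i32"
assert int_key(2**63) == "u64"
assert int_key(2**40) == "i64"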