[CI/Build] Update Ruff version (#8469)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Aaron Pham 2024-09-18 07:00:56 -04:00 committed by GitHub
parent 6ffa3f314c
commit 9d104b5beb
27 changed files with 50 additions and 77 deletions

View File

@@ -25,10 +25,10 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff==0.1.5 codespell==2.3.0 tomli==2.0.1 isort==5.13.2
pip install -r requirements-lint.txt
- name: Analysing the code with ruff
run: |
ruff .
ruff check .
- name: Spelling check with codespell
run: |
codespell --toml pyproject.toml

View File

@@ -45,8 +45,7 @@ if __name__ == "__main__":
rows = int(math.ceil(len(results) / 2))
fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
axs = axs.flatten()
axs_idx = 0
for shape, data in results.items():
for axs_idx, (shape, data) in enumerate(results.items()):
plt.sca(axs[axs_idx])
df = pd.DataFrame(data)
sns.lineplot(data=df,
@@ -59,6 +58,5 @@ if __name__ == "__main__":
palette="Dark2")
plt.title(f"Shape: {shape}")
plt.ylabel("time (median, s)")
axs_idx += 1
plt.tight_layout()
plt.savefig("graph_machete_bench.pdf")
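
The two hunks above replace a manually incremented subplot index (axs_idx) with enumerate(), the pattern targeted by flake8-simplify's SIM113 rule ("use enumerate for index variables") if enabled; the trailing axs_idx += 1 becomes unnecessary. A minimal standalone sketch of the same refactor, with hypothetical data in place of the benchmark results:

labels = ["a", "b", "c"]

# Manual counter, as in the old loop:
pairs_old = []
idx = 0
for label in labels:
    pairs_old.append((idx, label))
    idx += 1

# enumerate() yields the index alongside each item, so there is no counter to maintain:
pairs_new = [(idx, label) for idx, label in enumerate(labels)]

assert pairs_old == pairs_new == [(0, "a"), (1, "b"), (2, "c")]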

View File

@@ -159,7 +159,7 @@ echo 'vLLM codespell: Done'
# Lint specified files
lint() {
ruff "$@"
ruff check "$@"
}
# Lint files that differ from main branch. Ignores dirs that are not slated
@@ -175,7 +175,7 @@ lint_changed() {
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
ruff
ruff check
fi
}

View File

@@ -42,6 +42,8 @@ ignore = [
"E731",
# Loop control variable not used within loop body
"B007",
# f-string format
"UP032",
]
[tool.mypy]

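For context, UP032 is Ruff's pyupgrade rule "use f-string instead of format call"; adding it to the ignore list keeps the upgraded Ruff from flagging existing str.format() usage. A minimal illustration of what the rule would otherwise report (hypothetical values, not from the repository):

name, replicas = "llama", 3
msg_old = "model={} replicas={}".format(name, replicas)  # what UP032 flags when enabled
msg_new = f"model={name} replicas={replicas}"            # the rewrite UP032 suggests
assert msg_old == msg_new
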
View File

@@ -2,7 +2,7 @@
yapf==0.32.0
toml==0.10.2
tomli==2.0.1
ruff==0.1.5
ruff==0.6.5
codespell==2.3.0
isort==5.13.2
clang-format==18.1.5

View File

@@ -158,10 +158,7 @@ def should_do_global_cleanup_after_test(request) -> bool:
to initialize torch.
"""
if request.node.get_closest_marker("skip_global_cleanup"):
return False
return True
return not request.node.get_closest_marker("skip_global_cleanup")
@pytest.fixture(autouse=True)
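
The collapsed return works because pytest's request.node.get_closest_marker() yields a Marker object when the mark is present and None otherwise, and `not` turns either into a plain bool, so the if/return False/return True ladder is redundant. A standalone sketch of the same simplification, using a hypothetical marker lookup in place of the pytest API:

from typing import Optional

def get_marker(markers: dict, name: str) -> Optional[str]:
    # Stand-in for request.node.get_closest_marker(): a value when present, None otherwise.
    return markers.get(name)

def should_do_global_cleanup(markers: dict) -> bool:
    # `not` maps a present marker to False and a missing one (None) to True.
    return not get_marker(markers, "skip_global_cleanup")

assert should_do_global_cleanup({}) is True
assert should_do_global_cleanup({"skip_global_cleanup": "set"}) is False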

View File

@@ -65,10 +65,7 @@ def should_do_global_cleanup_after_test(request) -> bool:
to initialize torch.
"""
if request.node.get_closest_marker("skip_global_cleanup"):
return False
return True
return not request.node.get_closest_marker("skip_global_cleanup")
@pytest.fixture(autouse=True)

View File

@@ -5,7 +5,7 @@ from vllm.multimodal.base import MultiModalInputs, NestedTensors
def assert_nested_tensors_equal(expected: NestedTensors,
actual: NestedTensors):
assert type(expected) == type(actual)
assert type(expected) == type(actual) # noqa: E721
if isinstance(expected, torch.Tensor):
assert torch.equal(expected, actual)
else:
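
E721 ("do not compare types; for exact checks use `is`, for instance checks use isinstance()") is enforced by the newer Ruff, so the commit adds `# noqa: E721` where comparing types with == is intentional. A standalone sketch of the distinction, with hypothetical classes:

class Base: ...
class Sub(Base): ...

a, b = Base(), Sub()

assert isinstance(b, Base)   # isinstance() accepts subclasses
assert type(a) != type(b)    # exact types differ, which isinstance() cannot express
assert type(b) == Sub        # the == spelling is what E721 flags
assert type(b) is Sub        # the identity spelling E721 recommends instead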

View File

@@ -66,8 +66,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
hashes.append([])
prompts = [prefix + prompt for prompt in sample_prompts]
seq_id = 0
for prompt in prompts:
for seq_id, prompt in enumerate(prompts):
hashes[-1].append([])
prompt_token_ids = tokenizer.encode(prompt)
seq = Sequence(seq_id,
@@ -83,8 +82,6 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
for idx in range(num_blocks):
hashes[-1][-1].append(seq.hash_of_block(idx))
seq_id += 1
# Check that hashes made with two prefixes with different first blocks are
# different everywhere.
for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])):

View File

@@ -111,7 +111,7 @@ def test_an_error_is_raised_when_custom_logging_config_file_does_not_exist():
configuration occurs."""
with pytest.raises(RuntimeError) as ex_info:
_configure_vllm_root_logger()
assert ex_info.type == RuntimeError
assert ex_info.type == RuntimeError # noqa: E721
assert "File does not exist" in str(ex_info)
@@ -152,7 +152,7 @@ def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json(
logging_config_file.name):
with pytest.raises(ValueError) as ex_info:
_configure_vllm_root_logger()
assert ex_info.type == ValueError
assert ex_info.type == ValueError # noqa: E721
assert "Invalid logging config. Expected Dict, got" in str(ex_info)

View File

@@ -453,8 +453,7 @@ def test_prepare_decode(batch_size):
# each sequence) in the decode phase
expected_selected_token_indices = []
selected_token_start_idx = 0
for seq_len in seq_lens:
for selected_token_start_idx, seq_len in enumerate(seq_lens):
# Compute the index offset of the final token in each
# sequence's decoded outputs; since a single token is
# decoded per iteration per sequence, then the length
@@ -463,7 +462,6 @@ def test_prepare_decode(batch_size):
# generated tokens is 0 (i.e. the expected sampling index
# for a given sequence is just `selected_token_start_idx`)
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
sampling_metadata = model_input.sampling_metadata
actual = sampling_metadata.selected_token_indices

View File

@@ -241,10 +241,8 @@ def test_prepare_decode_cuda_graph(batch_size):
# Verify Sampling
expected_selected_token_indices = []
selected_token_start_idx = 0
for _ in context_lens:
for selected_token_start_idx, _ in enumerate(context_lens):
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,

View File

@@ -42,7 +42,7 @@ def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]:
def get_adapter(adapter_id: int,
registered_adapters: Dict[int, Any]) -> Optional[Any]:
return registered_adapters.get(adapter_id, None)
return registered_adapters.get(adapter_id)
## worker functions
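
dict.get() already returns None when the key is missing, so spelling the default out as .get(key, None) is redundant; this is the pattern flake8-simplify calls dict-get-with-none-default (SIM910 in Ruff), though it may simply have been cleaned up by hand here. A quick standalone check with a hypothetical adapter registry:

registered = {7: "adapter-7"}

assert registered.get(7, None) == registered.get(7) == "adapter-7"
assert registered.get(42, None) is registered.get(42) is None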

View File

@@ -33,10 +33,8 @@ def is_block_tables_empty(block_tables: Union[None, Dict]):
"""
if block_tables is None:
return True
if isinstance(block_tables, dict) and all(
value is None for value in block_tables.values()):
return True
return False
return (isinstance(block_tables, dict)
and all(value is None for value in block_tables.values()))
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
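
The rewritten body returns the boolean expression directly instead of branching to return True or False, the shape flagged by flake8-simplify's SIM103 ("return the condition directly"), the same rule silenced with a noqa in the compressed-tensors hunk further down. A standalone sketch mirroring the hunk:

from typing import Dict, Optional, Union

def is_block_tables_empty(block_tables: Union[None, Dict[int, Optional[str]]]) -> bool:
    if block_tables is None:
        return True
    # Returning the condition replaces `if ...: return True` / `return False`.
    return isinstance(block_tables, dict) and all(
        value is None for value in block_tables.values())

assert is_block_tables_empty(None)
assert is_block_tables_empty({1: None})
assert not is_block_tables_empty({1: "table"})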

View File

@@ -417,9 +417,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
def is_block_cached(self, block: Block) -> bool:
assert block.content_hash is not None
if block.content_hash in self._cached_blocks:
return True
return False
return block.content_hash in self._cached_blocks
def promote_to_immutable_block(self, block: Block) -> BlockId:
"""Once a mutable block is full, it can be promoted to an immutable

View File

@@ -399,9 +399,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
"""
alloc_status = self._can_swap(seq_group, Device.CPU,
SequenceStatus.RUNNING)
if alloc_status == AllocStatus.OK:
return True
return False
return alloc_status == AllocStatus.OK
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
"""Returns the block id mapping (from GPU to CPU) generated by

View File

@@ -826,7 +826,7 @@ class AsyncLLMEngine:
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
Yields:
@@ -1042,7 +1042,7 @@ class AsyncLLMEngine:
async def start_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing
# inherited classes
if type(self.engine.model_executor) == GPUExecutorAsync:
if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721
self.engine.model_executor.start_profile()
else:
self.engine.model_executor._run_workers("start_profile")
@@ -1050,7 +1050,7 @@ class AsyncLLMEngine:
async def stop_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing
# inherited classes
if type(self.engine.model_executor) == GPUExecutorAsync:
if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721
self.engine.model_executor.stop_profile()
else:
self.engine.model_executor._run_workers("stop_profile")

View File

@@ -144,7 +144,7 @@ class LLMEngine:
decoding.
executor_class: The model executor class for managing distributed
execution.
prompt_adapter_config (Optional): The configuration related to serving
prompt_adapter_config (Optional): The configuration related to serving
prompt adapters.
log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection.
@@ -1605,7 +1605,7 @@ class LLMEngine:
def start_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing
# inherited classes (MultiprocessingGPUExecutor)
if type(self.model_executor) == GPUExecutor:
if type(self.model_executor) == GPUExecutor: # noqa: E721
self.model_executor.start_profile()
else:
self.model_executor._run_workers("start_profile")
@@ -1613,7 +1613,7 @@ class LLMEngine:
def stop_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing
# inherited classes (MultiprocessingGPUExecutor)
if type(self.model_executor) == GPUExecutor:
if type(self.model_executor) == GPUExecutor: # noqa: E721
self.model_executor.stop_profile()
else:
self.model_executor._run_workers("stop_profile")
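
The comments in these hunks explain why `type(x) == GPUExecutor` is used rather than isinstance(): profiling must be dispatched differently for subclasses such as MultiprocessingGPUExecutor, and isinstance() would match them too. The upgraded Ruff flags the == spelling as E721, hence the noqa markers. A minimal sketch of exact-type dispatch, with hypothetical stand-ins for the real executor classes:

class GPUExecutor:
    def start_profile(self) -> str:
        return "direct"

class MultiprocessingGPUExecutor(GPUExecutor):
    def _run_workers(self, method: str) -> str:
        return f"broadcast {method}"

def start_profile(executor: GPUExecutor) -> str:
    # isinstance() would also match the subclass; the exact-type check keeps
    # subclasses on the _run_workers() path, as in the engine code above.
    if type(executor) == GPUExecutor:  # noqa: E721
        return executor.start_profile()
    return executor._run_workers("start_profile")

assert start_profile(GPUExecutor()) == "direct"
assert start_profile(MultiprocessingGPUExecutor()) == "broadcast start_profile"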

View File

@@ -67,9 +67,9 @@ class BaseLogitsProcessor:
instruction = self._guide.get_next_instruction(
state=self._fsm_state[seq_id])
if type(instruction) == Generate:
if type(instruction) == Generate: # noqa: E721
allowed_tokens = instruction.tokens
elif type(instruction) == Write:
elif type(instruction) == Write: # noqa: E721
# TODO: support fast forward tokens
allowed_tokens = [instruction.tokens[0]]
else:

View File

@@ -110,9 +110,9 @@ class AWQMarlinConfig(QuantizationConfig):
def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
has_zp = quant_config.get("zero_point", None)
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
has_zp = quant_config.get("zero_point")
if quant_method != "awq":
return False

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast
import torch
from pydantic import BaseModel
@@ -79,8 +79,8 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
target_scheme_map: Dict[str, Any] = dict()
ignore: List[str] = config.get("ignore", None)
quant_format: str = config.get("format", None)
ignore = cast(List[str], config.get("ignore"))
quant_format = cast(str, config.get("format"))
# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
@@ -200,7 +200,7 @@ class CompressedTensorsConfig(QuantizationConfig):
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
])
if not (is_symmetric_weight and is_static_weight
if not (is_symmetric_weight and is_static_weight # noqa: SIM103
and is_per_tensor_or_channel_weight):
return False
@@ -333,7 +333,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""
Use the CompressedTensorsScheme associated with each layer to create
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer. See LinearMethodBase for param
details
"""
@@ -352,8 +352,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None):
"""
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
layer input. See LinearMethodBase for param details
"""

View File

@@ -132,10 +132,10 @@ class GPTQMarlinConfig(QuantizationConfig):
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None)
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
sym = quant_config.get("sym")
desc_act = quant_config.get("desc_act")
if quant_method != "gptq":
return False

View File

@@ -408,9 +408,7 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
"inferred as vLLM models, so setting vllm_tensorized=True is "
"only necessary for models serialized prior to this change.")
return True
if (".vllm_tensorized_marker" in deserializer):
return True
return False
return ".vllm_tensorized_marker" in deserializer
def serialize_vllm_model(

View File

@@ -884,7 +884,7 @@ class MiniCPMV(MiniCPMVBaseModel):
version = str(config.version).split(".")
version = tuple([int(x) for x in version])
# Dispatch class based on version
instance_class = _SUPPORT_VERSION.get(version, None)
instance_class = _SUPPORT_VERSION.get(version)
if instance_class is None:
raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")

View File

@@ -183,10 +183,7 @@ class TP1DraftModelRunner(ModelRunner):
return False
# TODO: Add soft-tuning prompt adapter support
if self.prompt_adapter_config:
return False
return True
return not self.prompt_adapter_config
@torch.inference_mode()
def execute_model(

View File

@@ -104,13 +104,10 @@ class AsyncMetricsCollector:
if self._rank != 0:
return False
if (now - self._last_metrics_collect_time <
self._rejsample_metrics_collect_interval_s):
return False
return True
return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s # noqa: E501
def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
"""Copy rejection/typical-acceptance sampling metrics
"""Copy rejection/typical-acceptance sampling metrics
(number of accepted tokens, etc) to CPU asynchronously.
Returns a CUDA event recording when the copy is complete.
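
The metrics-collection rewrite folds the timing guard into one expression by negating the comparison: "return False when the elapsed time is below the interval, otherwise True" becomes "return elapsed >= interval". A standalone equivalence check with hypothetical numbers:

def should_collect_old(elapsed: float, interval: float) -> bool:
    if elapsed < interval:
        return False
    return True

def should_collect_new(elapsed: float, interval: float) -> bool:
    # `elapsed >= interval` is exactly the negation of `elapsed < interval`.
    return elapsed >= interval

for elapsed in (0.0, 4.9, 5.0, 7.5):
    assert should_collect_old(elapsed, 5.0) == should_collect_new(elapsed, 5.0)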

View File

@@ -35,8 +35,8 @@ class LibEntry(triton.KernelInterface):
dns_key = [
arg.dtype if hasattr(
arg, "data_ptr") else type(arg) if not isinstance(arg, int)
else "i32" if -(2**31) <= arg and arg <= 2**31 -
1 else "u64" if 2**63 <= arg and arg <= 2**64 - 1 else "i64"
else "i32" if arg >= -(2**31) and arg <= 2**31 -
1 else "u64" if arg >= 2**63 and arg <= 2**64 - 1 else "i64"
for arg in dns_args
]
# const args passed by position
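
This last hunk only flips the operand order so that `arg` sits on the left of both bounds (`arg >= -(2**31)` instead of `-(2**31) <= arg`); the i32/u64/i64 classification is unchanged. As a hedged aside (not part of the commit), Python's chained comparisons express the same range tests more compactly; a sketch using the thresholds from the diff:

def int_key(arg: int) -> str:
    # Same thresholds as the diff: 32-bit signed range, then the upper unsigned
    # 64-bit range, else treat the value as signed 64-bit.
    if -(2**31) <= arg <= 2**31 - 1:
        return "i32"
    if 2**63 <= arg <= 2**64 - 1:
        return "u64"
    return "i64"

assert int_key(42) == "i32"
assert int_key(2**63) == "u64"
assert int_key(2**40) == "i64"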