diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py
index 3a45c354..5311a4ce 100644
--- a/tests/compile/test_full_graph.py
+++ b/tests/compile/test_full_graph.py
@@ -2,21 +2,20 @@
 
 from __future__ import annotations
 
-from typing import Any
+from typing import Any, Union
 
 import pytest
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
-from vllm.config import CompilationLevel
+from vllm.config import CompilationConfig, CompilationLevel
 from vllm.platforms import current_platform
 
 from ..utils import create_new_process_for_each_test
 
 
-@pytest.fixture(params=None, name="model_info")
-def models_list_fixture(request):
+def models_list(all: bool):
     TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
         ("facebook/opt-125m", {}),
         ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
@@ -33,6 +32,9 @@ def models_list_fixture(request):
         ("meta-llama/Llama-3.2-1B-Instruct", {}),
     ]
 
+    if not all:
+        return TEST_MODELS
+
     if is_quant_method_supported("aqlm"):
         TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
             "quantization": "aqlm"
         }))
@@ -77,7 +79,7 @@
     "optimization_level",
     [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE],
 )
-@pytest.mark.parametrize("model_info", "", indirect=True)
+@pytest.mark.parametrize("model_info", models_list(all=True))
 @create_new_process_for_each_test()
 def test_full_graph(
     monkeypatch: pytest.MonkeyPatch,
@@ -91,25 +93,50 @@ def test_full_graph(
         m.setenv("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1")
         print(f"MODEL={model}")
 
-        prompts = [
-            "Hello, my name is",
-            "The president of the United States is",
-            "The capital of France is",
-            "The future of AI is",
-        ]
-        sampling_params = SamplingParams(temperature=0)
-        llm = LLM(
-            model=model,
-            enforce_eager=True,
-            tensor_parallel_size=1,
-            disable_custom_all_reduce=True,
-            compilation_config=optimization_level,
-            **model_kwargs,
-        )
-        outputs = llm.generate(prompts, sampling_params)
+        run_model(optimization_level, model, model_kwargs)
 
-        # Print the outputs.
-        for output in outputs:
-            prompt = output.prompt
-            generated_text = output.outputs[0].text
-            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+# TODO(luka) add other supported compilation config scenarios here
+@pytest.mark.parametrize(
+    "compilation_config",
+    # additional compile sizes
+    [
+        CompilationConfig(level=CompilationLevel.PIECEWISE,
+                          compile_sizes=[1, 2])
+    ])
+# only test some of the models
+@pytest.mark.parametrize("model_info", models_list(all=False))
+@create_new_process_for_each_test()
+def test_custom_compile_config(
+    model_info: tuple[str, dict[str, Any]],
+    compilation_config: CompilationConfig,
+):
+    model, model_kwargs = model_info
+    print(f"MODEL={model}")
+    run_model(compilation_config, model, model_kwargs)
+
+
+def run_model(compile_config: Union[int, CompilationConfig], model: str,
+              model_kwargs: dict[str, Any]):
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(temperature=0)
+    llm = LLM(
+        model=model,
+        enforce_eager=True,
+        tensor_parallel_size=1,
+        disable_custom_all_reduce=True,
+        compilation_config=compile_config,
+        **model_kwargs,
+    )
+    outputs = llm.generate(prompts, sampling_params)
+
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py
index d6e44fa6..5a22cf70 100644
--- a/vllm/compilation/compiler_interface.py
+++ b/vllm/compilation/compiler_interface.py
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
+import contextlib
 import copy
 import hashlib
+import importlib.metadata
 import os
 from contextlib import ExitStack
 from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -9,6 +11,7 @@ from unittest.mock import patch
 import torch
 import torch._inductor.compile_fx
 import torch.fx as fx
+from packaging.version import Version
 
 from vllm.config import VllmConfig
 
@@ -285,6 +288,9 @@ class InductorAdaptor(CompilerInterface):
                     "torch._inductor.codecache.FxGraphCache._check_can_cache",
                     _check_can_cache))
 
+            # Dynamo metrics context, see method for more details.
+            stack.enter_context(self.metrics_context())
+
             compiled_graph = compile_fx(
                 graph,
                 example_inputs,
@@ -309,8 +315,14 @@
         hash_str = handle[0]
 
         from torch._inductor.codecache import FxGraphCache
-        with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
-                   lambda *args, **kwargs: AlwaysHitShapeEnv()):
+        with ExitStack() as exit_stack:
+            exit_stack.enter_context(
+                patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
+                      lambda *args, **kwargs: AlwaysHitShapeEnv()))
+
+            # Dynamo metrics context, see method for more details.
+            exit_stack.enter_context(self.metrics_context())
+
             if torch.__version__.startswith("2.5"):
                 inductor_compiled_graph = FxGraphCache._lookup_graph(
                     hash_str, example_inputs, True, False)
@@ -351,6 +363,28 @@
         return compiled_graph
 
+    def metrics_context(self) -> contextlib.AbstractContextManager:
+        """
+        This method returns the Dynamo metrics context (if it exists,
+        otherwise a null context). It is used by various compile components.
+        Present in torch>=2.6, it's used inside FxGraphCache in
+        torch==2.6 (but not after). It might also be used in various other
+        torch.compile internal functions.
+
+        Because it is re-entrant, we always set it (even if entering via Dynamo
+        and the context was already entered). We might want to revisit if it
+        should be set at a different level of compilation.
+
+        This is likely a bug in PyTorch: public APIs should not rely on
+        manually setting up internal contexts. But we also rely on non-public
+        APIs which might not provide these guarantees.
+        """
+        if Version(importlib.metadata.version('torch')) >= Version("2.6"):
+            import torch._dynamo.utils
+            return torch._dynamo.utils.get_metrics_context()
+        else:
+            return contextlib.nullcontext()
+
 
 class EagerAdaptor(CompilerInterface):
     name = "eager"