[Bugfix] [torch.compile] Add Dynamo metrics context during compilation (#15639)

Signed-off-by: luka <luka@neuralmagic.com>
Luka Govedič 2025-03-28 16:01:09 -04:00 committed by GitHub
parent 038bededba
commit 04437e313d
2 changed files with 89 additions and 28 deletions

Changed file 1:

@@ -2,21 +2,20 @@
 from __future__ import annotations
-from typing import Any
+from typing import Any, Union
 import pytest
 import torch
 from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
-from vllm.config import CompilationLevel
+from vllm.config import CompilationConfig, CompilationLevel
 from vllm.platforms import current_platform
 from ..utils import create_new_process_for_each_test
-@pytest.fixture(params=None, name="model_info")
-def models_list_fixture(request):
+def models_list(all: bool):
     TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
         ("facebook/opt-125m", {}),
         ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
@@ -33,6 +32,9 @@ def models_list_fixture(request):
         ("meta-llama/Llama-3.2-1B-Instruct", {}),
     ]
+    if not all:
+        return TEST_MODELS
     if is_quant_method_supported("aqlm"):
         TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
             "quantization": "aqlm"
@@ -77,7 +79,7 @@ def models_list_fixture(request):
     "optimization_level",
     [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE],
 )
-@pytest.mark.parametrize("model_info", "", indirect=True)
+@pytest.mark.parametrize("model_info", models_list(all=True))
 @create_new_process_for_each_test()
 def test_full_graph(
     monkeypatch: pytest.MonkeyPatch,
@@ -91,25 +93,50 @@ def test_full_graph(
         m.setenv("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1")
         print(f"MODEL={model}")
-        prompts = [
-            "Hello, my name is",
-            "The president of the United States is",
-            "The capital of France is",
-            "The future of AI is",
-        ]
-        sampling_params = SamplingParams(temperature=0)
-        llm = LLM(
-            model=model,
-            enforce_eager=True,
-            tensor_parallel_size=1,
-            disable_custom_all_reduce=True,
-            compilation_config=optimization_level,
-            **model_kwargs,
-        )
-        outputs = llm.generate(prompts, sampling_params)
-        # Print the outputs.
-        for output in outputs:
-            prompt = output.prompt
-            generated_text = output.outputs[0].text
-            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        run_model(optimization_level, model, model_kwargs)
+
+
+# TODO(luka) add other supported compilation config scenarios here
+@pytest.mark.parametrize(
+    "compilation_config",
+    # additional compile sizes
+    [
+        CompilationConfig(level=CompilationLevel.PIECEWISE,
+                          compile_sizes=[1, 2])
+    ])
+# only test some of the models
+@pytest.mark.parametrize("model_info", models_list(all=False))
+@create_new_process_for_each_test()
+def test_custom_compile_config(
+    model_info: tuple[str, dict[str, Any]],
+    compilation_config: CompilationConfig,
+):
+    model, model_kwargs = model_info
+    print(f"MODEL={model}")
+    run_model(compilation_config, model, model_kwargs)
+
+
+def run_model(compile_config: Union[int, CompilationConfig], model: str,
+              model_kwargs: dict[str, Any]):
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(temperature=0)
+    llm = LLM(
+        model=model,
+        enforce_eager=True,
+        tensor_parallel_size=1,
+        disable_custom_all_reduce=True,
+        compilation_config=compile_config,
+        **model_kwargs,
+    )
+    outputs = llm.generate(prompts, sampling_params)
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
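For context outside the test harness: the new run_model helper shows that the compilation_config argument of the LLM constructor accepts either an integer optimization level or a full CompilationConfig object. A minimal standalone sketch of the latter usage, with model name and compile sizes taken as illustrative values from the test above:

    from vllm import LLM, SamplingParams
    from vllm.config import CompilationConfig, CompilationLevel

    # Piecewise compilation with extra compile sizes, as exercised by
    # test_custom_compile_config (values are illustrative).
    llm = LLM(
        model="facebook/opt-125m",
        enforce_eager=True,
        compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE,
                                             compile_sizes=[1, 2]),
    )
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0))
    print(outputs[0].outputs[0].text)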

Changed file 2:

@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
+import contextlib
 import copy
 import hashlib
+import importlib.metadata
 import os
 from contextlib import ExitStack
 from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -9,6 +11,7 @@ from unittest.mock import patch
 import torch
 import torch._inductor.compile_fx
 import torch.fx as fx
+from packaging.version import Version
 
 from vllm.config import VllmConfig
@@ -285,6 +288,9 @@ class InductorAdaptor(CompilerInterface):
                     "torch._inductor.codecache.FxGraphCache._check_can_cache",
                     _check_can_cache))
 
+            # Dynamo metrics context, see method for more details.
+            stack.enter_context(self.metrics_context())
+
             compiled_graph = compile_fx(
                 graph,
                 example_inputs,
@@ -309,8 +315,14 @@ class InductorAdaptor(CompilerInterface):
         hash_str = handle[0]
 
         from torch._inductor.codecache import FxGraphCache
-        with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
-                   lambda *args, **kwargs: AlwaysHitShapeEnv()):
+        with ExitStack() as exit_stack:
+            exit_stack.enter_context(
+                patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
+                      lambda *args, **kwargs: AlwaysHitShapeEnv()))
+
+            # Dynamo metrics context, see method for more details.
+            exit_stack.enter_context(self.metrics_context())
+
             if torch.__version__.startswith("2.5"):
                 inductor_compiled_graph = FxGraphCache._lookup_graph(
                     hash_str, example_inputs, True, False)
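The hunk above replaces a single with-patch block with an ExitStack so that the shape-env patch and the new metrics context are entered together and cleaned up in reverse order when the block exits. A generic sketch of that pattern (the two contexts here are placeholders):

    from contextlib import ExitStack, nullcontext

    with ExitStack() as stack:
        # Each enter_context activates a context manager immediately and
        # registers it for teardown when the with-block ends.
        stack.enter_context(nullcontext())  # e.g. the FxGraphCache patch
        stack.enter_context(nullcontext())  # e.g. the Dynamo metrics context
        # ... code that needs both contexts active runs here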
@@ -351,6 +363,28 @@ class InductorAdaptor(CompilerInterface):
 
         return compiled_graph
 
+    def metrics_context(self) -> contextlib.AbstractContextManager:
+        """
+        This method returns the Dynamo metrics context (if it exists,
+        otherwise a null context). It is used by various compile components.
+        Present in torch>=2.6, it's used inside FxGraphCache in
+        torch==2.6 (but not after). It might also be used in various other
+        torch.compile internal functions.
+
+        Because it is re-entrant, we always set it (even if entering via
+        Dynamo and the context was already entered). We might want to revisit
+        if it should be set at a different level of compilation.
+
+        This is likely a bug in PyTorch: public APIs should not rely on
+        manually setting up internal contexts. But we also rely on non-public
+        APIs which might not provide these guarantees.
+        """
+        if Version(importlib.metadata.version('torch')) >= Version("2.6"):
+            import torch._dynamo.utils
+            return torch._dynamo.utils.get_metrics_context()
+        else:
+            return contextlib.nullcontext()
+
 
 class EagerAdaptor(CompilerInterface):
     name = "eager"
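Taken out of the adaptor, the version-guarded pattern that metrics_context() implements looks roughly like the sketch below. The helper name dynamo_metrics_context is hypothetical; torch._dynamo.utils.get_metrics_context() is the API the commit relies on for torch >= 2.6.

    import contextlib
    import importlib.metadata

    from packaging.version import Version


    def dynamo_metrics_context() -> contextlib.AbstractContextManager:
        # On torch >= 2.6, return Dynamo's (re-entrant) metrics context;
        # on older versions, fall back to a no-op context.
        if Version(importlib.metadata.version("torch")) >= Version("2.6"):
            import torch._dynamo.utils
            return torch._dynamo.utils.get_metrics_context()
        return contextlib.nullcontext()


    # Usage: enter it around any Inductor compile or cache-lookup call.
    with dynamo_metrics_context():
        pass  # e.g. compile_fx(...) or an FxGraphCache lookup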