[Bugfix] [torch.compile] Add Dynamo metrics context during compilation (#15639)

Signed-off-by: luka <luka@neuralmagic.com>
Luka Govedič 2025-03-28 16:01:09 -04:00 committed by GitHub
parent 038bededba
commit 04437e313d
2 changed files with 89 additions and 28 deletions

tests/compile/test_full_graph.py

@@ -2,21 +2,20 @@
 from __future__ import annotations

-from typing import Any
+from typing import Any, Union

 import pytest
 import torch

 from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
-from vllm.config import CompilationLevel
+from vllm.config import CompilationConfig, CompilationLevel
 from vllm.platforms import current_platform

 from ..utils import create_new_process_for_each_test


-@pytest.fixture(params=None, name="model_info")
-def models_list_fixture(request):
+def models_list(all: bool):
     TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
         ("facebook/opt-125m", {}),
         ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
@@ -33,6 +32,9 @@ def models_list_fixture(request):
         ("meta-llama/Llama-3.2-1B-Instruct", {}),
     ]

+    if not all:
+        return TEST_MODELS
+
     if is_quant_method_supported("aqlm"):
         TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
             "quantization": "aqlm"
@@ -77,7 +79,7 @@ def models_list_fixture(request):
     "optimization_level",
     [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE],
 )
-@pytest.mark.parametrize("model_info", "", indirect=True)
+@pytest.mark.parametrize("model_info", models_list(all=True))
 @create_new_process_for_each_test()
 def test_full_graph(
     monkeypatch: pytest.MonkeyPatch,
@@ -91,6 +93,31 @@ def test_full_graph(
         m.setenv("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1")
         print(f"MODEL={model}")
         run_model(optimization_level, model, model_kwargs)
+
+
+# TODO(luka) add other supported compilation config scenarios here
+@pytest.mark.parametrize(
+    "compilation_config",
+    # additional compile sizes
+    [
+        CompilationConfig(level=CompilationLevel.PIECEWISE,
+                          compile_sizes=[1, 2])
+    ])
+# only test some of the models
+@pytest.mark.parametrize("model_info", models_list(all=False))
+@create_new_process_for_each_test()
+def test_custom_compile_config(
+    model_info: tuple[str, dict[str, Any]],
+    compilation_config: CompilationConfig,
+):
+    model, model_kwargs = model_info
+    print(f"MODEL={model}")
+    run_model(compilation_config, model, model_kwargs)
+
+
+def run_model(compile_config: Union[int, CompilationConfig], model: str,
+              model_kwargs: dict[str, Any]):
     prompts = [
         "Hello, my name is",
         "The president of the United States is",
@@ -103,7 +130,7 @@ def test_full_graph(
         enforce_eager=True,
         tensor_parallel_size=1,
         disable_custom_all_reduce=True,
-        compilation_config=optimization_level,
+        compilation_config=compile_config,
         **model_kwargs,
     )
     outputs = llm.generate(prompts, sampling_params)
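
Aside: with this refactor, run_model accepts either a bare optimization level or a full CompilationConfig and forwards it to LLM(compilation_config=...), which accepts both forms. A minimal sketch of the two call styles, assuming the definitions in the test module above (the model name is only illustrative):

    from vllm.config import CompilationConfig, CompilationLevel

    # Bare integer level, as exercised by test_full_graph:
    run_model(CompilationLevel.PIECEWISE, "facebook/opt-125m", {})

    # Full config with extra compile sizes, as exercised by
    # test_custom_compile_config:
    run_model(
        CompilationConfig(level=CompilationLevel.PIECEWISE,
                          compile_sizes=[1, 2]),
        "facebook/opt-125m",
        {},
    )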

vllm/compilation/compiler_interface.py

@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
+import contextlib
 import copy
 import hashlib
+import importlib.metadata
 import os
 from contextlib import ExitStack
 from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -9,6 +11,7 @@ from unittest.mock import patch
 import torch
 import torch._inductor.compile_fx
 import torch.fx as fx
+from packaging.version import Version

 from vllm.config import VllmConfig
@@ -285,6 +288,9 @@ class InductorAdaptor(CompilerInterface):
                     "torch._inductor.codecache.FxGraphCache._check_can_cache",
                     _check_can_cache))

+            # Dynamo metrics context, see method for more details.
+            stack.enter_context(self.metrics_context())
+
             compiled_graph = compile_fx(
                 graph,
                 example_inputs,
@@ -309,8 +315,14 @@
         hash_str = handle[0]
         from torch._inductor.codecache import FxGraphCache
-        with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
-                   lambda *args, **kwargs: AlwaysHitShapeEnv()):
+        with ExitStack() as exit_stack:
+            exit_stack.enter_context(
+                patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
+                      lambda *args, **kwargs: AlwaysHitShapeEnv()))
+
+            # Dynamo metrics context, see method for more details.
+            exit_stack.enter_context(self.metrics_context())
+
             if torch.__version__.startswith("2.5"):
                 inductor_compiled_graph = FxGraphCache._lookup_graph(
                     hash_str, example_inputs, True, False)
@@ -351,6 +363,28 @@
         return compiled_graph

+    def metrics_context(self) -> contextlib.AbstractContextManager:
+        """
+        This method returns the Dynamo metrics context (if it exists,
+        otherwise a null context). It is used by various compile components.
+        Present in torch>=2.6, it's used inside FxGraphCache in
+        torch==2.6 (but not after). It might also be used in various other
+        torch.compile internal functions.
+
+        Because it is re-entrant, we always set it (even if entering via
+        Dynamo and the context was already entered). We might want to
+        revisit if it should be set at a different level of compilation.
+
+        This is likely a bug in PyTorch: public APIs should not rely on
+        manually setting up internal contexts. But we also rely on
+        non-public APIs which might not provide these guarantees.
+        """
+        if Version(importlib.metadata.version('torch')) >= Version("2.6"):
+            import torch._dynamo.utils
+            return torch._dynamo.utils.get_metrics_context()
+        else:
+            return contextlib.nullcontext()
+
+
 class EagerAdaptor(CompilerInterface):
     name = "eager"
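
Aside: a standalone sketch of the version-gated pattern that metrics_context() implements, relying on the docstring's claim that Dynamo's metrics context is re-entrant on torch>=2.6. The final `...` marks the hypothetical spot where compile_fx or the FxGraphCache lookup would run:

    import contextlib
    import importlib.metadata
    from contextlib import ExitStack

    from packaging.version import Version


    def metrics_context() -> contextlib.AbstractContextManager:
        # On torch>=2.6, return Dynamo's metrics context; otherwise a no-op.
        if Version(importlib.metadata.version("torch")) >= Version("2.6"):
            import torch._dynamo.utils
            return torch._dynamo.utils.get_metrics_context()
        return contextlib.nullcontext()


    with ExitStack() as stack:
        stack.enter_context(metrics_context())
        # Re-entrant: entering again is safe even if Dynamo (or an outer
        # frame) has already entered the same context.
        stack.enter_context(metrics_context())
        ...  # compilation would happen here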