[Bugfix] [torch.compile] Add Dynamo metrics context during compilation (#15639)
Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
parent
038bededba
commit
04437e313d
@ -2,21 +2,20 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationLevel
|
||||
from vllm.config import CompilationConfig, CompilationLevel
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import create_new_process_for_each_test
|
||||
|
||||
|
||||
@pytest.fixture(params=None, name="model_info")
|
||||
def models_list_fixture(request):
|
||||
def models_list(all: bool):
|
||||
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
|
||||
("facebook/opt-125m", {}),
|
||||
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
|
||||
@ -33,6 +32,9 @@ def models_list_fixture(request):
|
||||
("meta-llama/Llama-3.2-1B-Instruct", {}),
|
||||
]
|
||||
|
||||
if not all:
|
||||
return TEST_MODELS
|
||||
|
||||
if is_quant_method_supported("aqlm"):
|
||||
TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
|
||||
"quantization": "aqlm"
|
||||
@ -77,7 +79,7 @@ def models_list_fixture(request):
|
||||
"optimization_level",
|
||||
[CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE],
|
||||
)
|
||||
@pytest.mark.parametrize("model_info", "", indirect=True)
|
||||
@pytest.mark.parametrize("model_info", models_list(all=True))
|
||||
@create_new_process_for_each_test()
|
||||
def test_full_graph(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
@ -91,25 +93,50 @@ def test_full_graph(
|
||||
m.setenv("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1")
|
||||
print(f"MODEL={model}")
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
llm = LLM(
|
||||
model=model,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=1,
|
||||
disable_custom_all_reduce=True,
|
||||
compilation_config=optimization_level,
|
||||
**model_kwargs,
|
||||
)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
run_model(optimization_level, model, model_kwargs)
|
||||
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
# TODO(luka) add other supported compilation config scenarios here
|
||||
@pytest.mark.parametrize(
|
||||
"compilation_config",
|
||||
# additional compile sizes
|
||||
[
|
||||
CompilationConfig(level=CompilationLevel.PIECEWISE,
|
||||
compile_sizes=[1, 2])
|
||||
])
|
||||
# only test some of the models
|
||||
@pytest.mark.parametrize("model_info", models_list(all=False))
|
||||
@create_new_process_for_each_test()
|
||||
def test_custom_compile_config(
|
||||
model_info: tuple[str, dict[str, Any]],
|
||||
compilation_config: CompilationConfig,
|
||||
):
|
||||
model, model_kwargs = model_info
|
||||
print(f"MODEL={model}")
|
||||
run_model(compilation_config, model, model_kwargs)
|
||||
|
||||
|
||||
def run_model(compile_config: Union[int, CompilationConfig], model: str,
|
||||
model_kwargs: dict[str, Any]):
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
llm = LLM(
|
||||
model=model,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=1,
|
||||
disable_custom_all_reduce=True,
|
||||
compilation_config=compile_config,
|
||||
**model_kwargs,
|
||||
)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import contextlib
|
||||
import copy
|
||||
import hashlib
|
||||
import importlib.metadata
|
||||
import os
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
@ -9,6 +11,7 @@ from unittest.mock import patch
|
||||
import torch
|
||||
import torch._inductor.compile_fx
|
||||
import torch.fx as fx
|
||||
from packaging.version import Version
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
@ -285,6 +288,9 @@ class InductorAdaptor(CompilerInterface):
|
||||
"torch._inductor.codecache.FxGraphCache._check_can_cache",
|
||||
_check_can_cache))
|
||||
|
||||
# Dynamo metrics context, see method for more details.
|
||||
stack.enter_context(self.metrics_context())
|
||||
|
||||
compiled_graph = compile_fx(
|
||||
graph,
|
||||
example_inputs,
|
||||
@ -309,8 +315,14 @@ class InductorAdaptor(CompilerInterface):
|
||||
hash_str = handle[0]
|
||||
|
||||
from torch._inductor.codecache import FxGraphCache
|
||||
with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
lambda *args, **kwargs: AlwaysHitShapeEnv()):
|
||||
with ExitStack() as exit_stack:
|
||||
exit_stack.enter_context(
|
||||
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
lambda *args, **kwargs: AlwaysHitShapeEnv()))
|
||||
|
||||
# Dynamo metrics context, see method for more details.
|
||||
exit_stack.enter_context(self.metrics_context())
|
||||
|
||||
if torch.__version__.startswith("2.5"):
|
||||
inductor_compiled_graph = FxGraphCache._lookup_graph(
|
||||
hash_str, example_inputs, True, False)
|
||||
@ -351,6 +363,28 @@ class InductorAdaptor(CompilerInterface):
|
||||
|
||||
return compiled_graph
|
||||
|
||||
def metrics_context(self) -> contextlib.AbstractContextManager:
|
||||
"""
|
||||
This method returns the Dynamo metrics context (if it exists,
|
||||
otherwise a null context). It is used by various compile components.
|
||||
Present in torch>=2.6, it's used inside FxGraphCache in
|
||||
torch==2.6 (but not after). It might also be used in various other
|
||||
torch.compile internal functions.
|
||||
|
||||
Because it is re-entrant, we always set it (even if entering via Dynamo
|
||||
and the context was already entered). We might want to revisit if it
|
||||
should be set at a different level of compilation.
|
||||
|
||||
This is likely a bug in PyTorch: public APIs should not rely on
|
||||
manually setting up internal contexts. But we also rely on non-public
|
||||
APIs which might not provide these guarantees.
|
||||
"""
|
||||
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
|
||||
import torch._dynamo.utils
|
||||
return torch._dynamo.utils.get_metrics_context()
|
||||
else:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
class EagerAdaptor(CompilerInterface):
|
||||
name = "eager"
|
||||
|
Loading…
x
Reference in New Issue
Block a user