[Bugfix] [torch.compile] Add Dynamo metrics context during compilation (#15639)
Signed-off-by: luka <luka@neuralmagic.com>
parent 038bededba
commit 04437e313d
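Note: the substance of this commit is the new InductorAdaptor.metrics_context() helper further down, which returns Dynamo's metrics context on torch >= 2.6 and falls back to a no-op context on older versions, so callers can always enter it unconditionally. A minimal standalone sketch of that version gate (the name dynamo_metrics_context is illustrative, not part of the commit):

# Sketch of the version-gated context selection, mirroring the method added below.
# torch._dynamo.utils.get_metrics_context() exists in torch >= 2.6; on older
# versions a contextlib.nullcontext() stands in so the caller can always "with" it.
import contextlib
import importlib.metadata

from packaging.version import Version


def dynamo_metrics_context() -> contextlib.AbstractContextManager:
    if Version(importlib.metadata.version("torch")) >= Version("2.6"):
        import torch._dynamo.utils
        return torch._dynamo.utils.get_metrics_context()
    return contextlib.nullcontext()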
@@ -2,21 +2,20 @@
 
 from __future__ import annotations
 
-from typing import Any
+from typing import Any, Union
 
 import pytest
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
-from vllm.config import CompilationLevel
+from vllm.config import CompilationConfig, CompilationLevel
 from vllm.platforms import current_platform
 
 from ..utils import create_new_process_for_each_test
 
 
-@pytest.fixture(params=None, name="model_info")
-def models_list_fixture(request):
+def models_list(all: bool):
     TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
         ("facebook/opt-125m", {}),
         ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
@@ -33,6 +32,9 @@ def models_list_fixture(request):
         ("meta-llama/Llama-3.2-1B-Instruct", {}),
     ]
 
+    if not all:
+        return TEST_MODELS
+
     if is_quant_method_supported("aqlm"):
         TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
             "quantization": "aqlm"
@@ -77,7 +79,7 @@ def models_list_fixture(request):
     "optimization_level",
     [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE],
 )
-@pytest.mark.parametrize("model_info", "", indirect=True)
+@pytest.mark.parametrize("model_info", models_list(all=True))
 @create_new_process_for_each_test()
 def test_full_graph(
     monkeypatch: pytest.MonkeyPatch,
@@ -91,25 +93,50 @@ def test_full_graph(
         m.setenv("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1")
         print(f"MODEL={model}")
 
-        prompts = [
-            "Hello, my name is",
-            "The president of the United States is",
-            "The capital of France is",
-            "The future of AI is",
-        ]
-        sampling_params = SamplingParams(temperature=0)
-        llm = LLM(
-            model=model,
-            enforce_eager=True,
-            tensor_parallel_size=1,
-            disable_custom_all_reduce=True,
-            compilation_config=optimization_level,
-            **model_kwargs,
-        )
-        outputs = llm.generate(prompts, sampling_params)
+        run_model(optimization_level, model, model_kwargs)
 
-        # Print the outputs.
-        for output in outputs:
-            prompt = output.prompt
-            generated_text = output.outputs[0].text
-            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+# TODO(luka) add other supported compilation config scenarios here
+@pytest.mark.parametrize(
+    "compilation_config",
+    # additional compile sizes
+    [
+        CompilationConfig(level=CompilationLevel.PIECEWISE,
+                          compile_sizes=[1, 2])
+    ])
+# only test some of the models
+@pytest.mark.parametrize("model_info", models_list(all=False))
+@create_new_process_for_each_test()
+def test_custom_compile_config(
+    model_info: tuple[str, dict[str, Any]],
+    compilation_config: CompilationConfig,
+):
+    model, model_kwargs = model_info
+    print(f"MODEL={model}")
+    run_model(compilation_config, model, model_kwargs)
+
+
+def run_model(compile_config: Union[int, CompilationConfig], model: str,
+              model_kwargs: dict[str, Any]):
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(temperature=0)
+    llm = LLM(
+        model=model,
+        enforce_eager=True,
+        tensor_parallel_size=1,
+        disable_custom_all_reduce=True,
+        compilation_config=compile_config,
+        **model_kwargs,
+    )
+    outputs = llm.generate(prompts, sampling_params)
+
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
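The hunks above refactor the full-graph compile test: the pytest fixture becomes a plain models_list(all: bool) function, and the LLM construction plus generation loop moves into a shared run_model helper that accepts either an integer CompilationLevel or a full CompilationConfig. A hedged usage sketch (it assumes run_model from the test module above is in scope; the model and empty kwargs are just the first TEST_MODELS entry):

# Hypothetical call sites for the shared helper introduced above.
from vllm.config import CompilationConfig, CompilationLevel

# Full-graph test path: pass the compilation level directly.
run_model(CompilationLevel.PIECEWISE, "facebook/opt-125m", {})

# Custom-config test path: pass a CompilationConfig with extra compile sizes.
run_model(
    CompilationConfig(level=CompilationLevel.PIECEWISE, compile_sizes=[1, 2]),
    "facebook/opt-125m",
    {},
)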
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
+import contextlib
 import copy
 import hashlib
+import importlib.metadata
 import os
 from contextlib import ExitStack
 from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -9,6 +11,7 @@ from unittest.mock import patch
 import torch
 import torch._inductor.compile_fx
 import torch.fx as fx
+from packaging.version import Version
 
 from vllm.config import VllmConfig
 
@@ -285,6 +288,9 @@ class InductorAdaptor(CompilerInterface):
                     "torch._inductor.codecache.FxGraphCache._check_can_cache",
                     _check_can_cache))
 
+            # Dynamo metrics context, see method for more details.
+            stack.enter_context(self.metrics_context())
+
             compiled_graph = compile_fx(
                 graph,
                 example_inputs,
@@ -309,8 +315,14 @@ class InductorAdaptor(CompilerInterface):
             hash_str = handle[0]
 
             from torch._inductor.codecache import FxGraphCache
-            with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
-                       lambda *args, **kwargs: AlwaysHitShapeEnv()):
+            with ExitStack() as exit_stack:
+                exit_stack.enter_context(
+                    patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
+                          lambda *args, **kwargs: AlwaysHitShapeEnv()))
+
+                # Dynamo metrics context, see method for more details.
+                exit_stack.enter_context(self.metrics_context())
+
                 if torch.__version__.startswith("2.5"):
                     inductor_compiled_graph = FxGraphCache._lookup_graph(
                         hash_str, example_inputs, True, False)
@@ -351,6 +363,28 @@ class InductorAdaptor(CompilerInterface):
 
         return compiled_graph
 
+    def metrics_context(self) -> contextlib.AbstractContextManager:
+        """
+        This method returns the Dynamo metrics context (if it exists,
+        otherwise a null context). It is used by various compile components.
+        Present in torch>=2.6, it's used inside FxGraphCache in
+        torch==2.6 (but not after). It might also be used in various other
+        torch.compile internal functions.
+
+        Because it is re-entrant, we always set it (even if entering via Dynamo
+        and the context was already entered). We might want to revisit if it
+        should be set at a different level of compilation.
+
+        This is likely a bug in PyTorch: public APIs should not rely on
+        manually setting up internal contexts. But we also rely on non-public
+        APIs which might not provide these guarantees.
+        """
+        if Version(importlib.metadata.version('torch')) >= Version("2.6"):
+            import torch._dynamo.utils
+            return torch._dynamo.utils.get_metrics_context()
+        else:
+            return contextlib.nullcontext()
+
 
 class EagerAdaptor(CompilerInterface):
     name = "eager"
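In the cache-lookup path above, the single with patch(...) block is replaced by an ExitStack so that the shape-env patch and the Dynamo metrics context are entered together and unwound in reverse order when the block exits. A small generic sketch of that pattern (the demo contexts are stand-ins, not vLLM or PyTorch APIs):

# Generic ExitStack sketch: enter several context managers dynamically and
# have them all exit together, in reverse order of entry.
import contextlib
from contextlib import ExitStack


@contextlib.contextmanager
def demo_context(name: str):
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")


with ExitStack() as stack:
    stack.enter_context(demo_context("shape-env patch"))
    stack.enter_context(demo_context("metrics context"))
    print("body runs with both contexts active")
# exit messages print in reverse order: metrics context first, then shape-env patch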