[BugFix] Fix fusion tests and add them to CI (#16287)

Signed-off-by: luka <luka@neuralmagic.com>
Luka Govedič 2025-04-09 02:46:45 -04:00 committed by GitHub
parent b1eb4ca152
commit 9cdde47289
3 changed files with 75 additions and 50 deletions

.buildkite/test-pipeline.yaml

@@ -292,6 +292,14 @@ steps:
   command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py
   parallelism: 4
 
+- label: PyTorch Compilation Unit Tests
+  source_file_dependencies:
+  - vllm/
+  - tests/compile
+  commands:
+  - pytest -v -s compile/test_pass_manager.py
+  - pytest -v -s compile/test_fusion.py
+
 - label: PyTorch Fullgraph Smoke Test # 9min
   source_file_dependencies:
   - vllm/
@@ -301,7 +309,6 @@ steps:
   # these tests need to be separated, cannot combine
   - pytest -v -s compile/piecewise/test_simple.py
   - pytest -v -s compile/piecewise/test_toy_llama.py
-  - pytest -v -s compile/test_pass_manager.py
 
 - label: PyTorch Fullgraph Test # 18min
   source_file_dependencies:

tests/compile/test_full_graph.py

@@ -2,7 +2,7 @@
 from __future__ import annotations
 
-from typing import Any, Union
+from typing import Any, Optional, Union
 
 import pytest
 import torch
@@ -15,7 +15,7 @@ from vllm.platforms import current_platform
 from ..utils import create_new_process_for_each_test
 
 
-def models_list(all: bool):
+def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
     TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
         ("facebook/opt-125m", {}),
         ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
@@ -32,9 +32,7 @@ def models_list(all: bool):
         ("meta-llama/Llama-3.2-1B-Instruct", {}),
     ]
 
-    if not all:
-        return TEST_MODELS
-
-    if is_quant_method_supported("aqlm"):
-        TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
-            "quantization": "aqlm"
+    if all:
+        if is_quant_method_supported("aqlm"):
+            TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
+                "quantization": "aqlm"
@@ -72,8 +70,13 @@ def models_list(all: bool):
-            "quantization": "AWQ"
-        }))
+                "quantization": "AWQ"
+            }))
 
-    return TEST_MODELS
+    if keywords is None:
+        return TEST_MODELS
+
+    # filter by keywords
+    pred = lambda model: any(keyword in model[0] for keyword in keywords)
+    return list(filter(pred, TEST_MODELS))
 
 
 @pytest.mark.parametrize(
     "optimization_level",
@@ -96,20 +99,30 @@ def test_full_graph(
     run_model(optimization_level, model, model_kwargs)
 
 
+PassConfig = CompilationConfig.PassConfig
+
+
 # TODO(luka) add other supported compilation config scenarios here
 @pytest.mark.parametrize(
-    "compilation_config",
-    # additional compile sizes
+    "compilation_config, model_info",
     [
-        CompilationConfig(level=CompilationLevel.PIECEWISE,
-                          compile_sizes=[1, 2])
+        # additional compile sizes, only some of the models
+        (CompilationConfig(level=CompilationLevel.PIECEWISE,
+                           compile_sizes=[1, 2]), model)
+        for model in models_list(all=False)
+    ] + [
+        # RMSNorm + quant fusion, only 8-bit quant models
+        (CompilationConfig(level=CompilationLevel.PIECEWISE,
+                           custom_ops=["+rms_norm"],
+                           pass_config=PassConfig(enable_fusion=True,
+                                                  enable_noop=True)), model)
+        for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
     ])
 # only test some of the models
-@pytest.mark.parametrize("model_info", models_list(all=False))
 @create_new_process_for_each_test()
 def test_custom_compile_config(
-    model_info: tuple[str, dict[str, Any]],
     compilation_config: CompilationConfig,
+    model_info: tuple[str, dict[str, Any]],
 ):
     model, model_kwargs = model_info
     print(f"MODEL={model}")

tests/compile/test_fusion.py

@@ -44,12 +44,17 @@ class TestModel(torch.nn.Module):
         resid = torch.sqrt(x)
         y = self.norm[0](x)
-        x2 = self.fp8_linear.apply(y, self.w[0], self.wscale[0], self.scale[0])
+        x2 = self.fp8_linear.apply(y,
+                                   self.w[0],
+                                   self.wscale[0],
+                                   input_scale=self.scale[0])
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)
-        x3 = self.fp8_linear.apply(y2, self.w[1], self.wscale[1],
-                                   self.scale[1])
+        x3 = self.fp8_linear.apply(y2,
+                                   self.w[1],
+                                   self.wscale[1],
+                                   input_scale=self.scale[1])
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3
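Switching self.scale[i] from a positional to a keyword argument is presumably the test bug being fixed: if optional parameters sit between weight_scale and input_scale in the apply signature, a fourth positional argument silently binds to the wrong parameter. A minimal sketch of the pitfall; the signature below is hypothetical, for illustration only, and not the real Fp8LinearOp.apply:

# Hypothetical signature: an optional out_dtype precedes input_scale,
# so the fourth positional slot is NOT the input scale.
def apply(x, weight, weight_scale, out_dtype=None, input_scale=None):
    return out_dtype, input_scale

scale = 0.5
out_dtype, input_scale = apply(1, 2, 3, scale)
assert out_dtype == scale and input_scale is None  # silently misbound

out_dtype, input_scale = apply(1, 2, 3, input_scale=scale)
assert out_dtype is None and input_scale == scale  # bound as intended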