[BugFix] Fix fusion test and add them to CI (#16287)
Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
parent
b1eb4ca152
commit
9cdde47289
@ -292,6 +292,14 @@ steps:
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py
|
||||
parallelism: 4
|
||||
|
||||
- label: PyTorch Compilation Unit Tests
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/compile
|
||||
commands:
|
||||
- pytest -v -s compile/test_pass_manager.py
|
||||
- pytest -v -s compile/test_fusion.py
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test # 9min
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -301,7 +309,6 @@ steps:
|
||||
# these tests need to be separated, cannot combine
|
||||
- pytest -v -s compile/piecewise/test_simple.py
|
||||
- pytest -v -s compile/piecewise/test_toy_llama.py
|
||||
- pytest -v -s compile/test_pass_manager.py
|
||||
|
||||
- label: PyTorch Fullgraph Test # 18min
|
||||
source_file_dependencies:
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -15,7 +15,7 @@ from vllm.platforms import current_platform
|
||||
from ..utils import create_new_process_for_each_test
|
||||
|
||||
|
||||
def models_list(all: bool):
|
||||
def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
|
||||
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
|
||||
("facebook/opt-125m", {}),
|
||||
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
|
||||
@ -32,47 +32,50 @@ def models_list(all: bool):
|
||||
("meta-llama/Llama-3.2-1B-Instruct", {}),
|
||||
]
|
||||
|
||||
if not all:
|
||||
return TEST_MODELS
|
||||
|
||||
if is_quant_method_supported("aqlm"):
|
||||
TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
|
||||
"quantization": "aqlm"
|
||||
}))
|
||||
|
||||
# TODO: figure out why this fails.
|
||||
if False and is_quant_method_supported("gguf"): # noqa: SIM223
|
||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
|
||||
"quantization": "gguf"
|
||||
}))
|
||||
|
||||
if is_quant_method_supported("gptq"):
|
||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
|
||||
"quantization": "gptq"
|
||||
}))
|
||||
|
||||
if is_quant_method_supported("gptq_marlin"):
|
||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
|
||||
"quantization": "gptq_marlin"
|
||||
}))
|
||||
|
||||
if is_quant_method_supported("gptq_marlin_24"):
|
||||
TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
|
||||
"quantization": "gptq_marlin_24"
|
||||
}))
|
||||
|
||||
if is_quant_method_supported("marlin"):
|
||||
TEST_MODELS.append(
|
||||
("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
|
||||
"quantization": "marlin"
|
||||
if all:
|
||||
if is_quant_method_supported("aqlm"):
|
||||
TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
|
||||
"quantization": "aqlm"
|
||||
}))
|
||||
|
||||
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
|
||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
|
||||
"quantization": "AWQ"
|
||||
}))
|
||||
# TODO: figure out why this fails.
|
||||
if False and is_quant_method_supported("gguf"): # noqa: SIM223
|
||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
|
||||
"quantization": "gguf"
|
||||
}))
|
||||
|
||||
return TEST_MODELS
|
||||
if is_quant_method_supported("gptq"):
|
||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
|
||||
"quantization": "gptq"
|
||||
}))
|
||||
|
||||
if is_quant_method_supported("gptq_marlin"):
|
||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
|
||||
"quantization": "gptq_marlin"
|
||||
}))
|
||||
|
||||
if is_quant_method_supported("gptq_marlin_24"):
|
||||
TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
|
||||
"quantization": "gptq_marlin_24"
|
||||
}))
|
||||
|
||||
if is_quant_method_supported("marlin"):
|
||||
TEST_MODELS.append(
|
||||
("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
|
||||
"quantization": "marlin"
|
||||
}))
|
||||
|
||||
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
|
||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
|
||||
"quantization": "AWQ"
|
||||
}))
|
||||
|
||||
if keywords is None:
|
||||
return TEST_MODELS
|
||||
|
||||
# filter by keywords
|
||||
pred = lambda model: any(keyword in model[0] for keyword in keywords)
|
||||
return list(filter(pred, TEST_MODELS))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -96,20 +99,30 @@ def test_full_graph(
|
||||
run_model(optimization_level, model, model_kwargs)
|
||||
|
||||
|
||||
PassConfig = CompilationConfig.PassConfig
|
||||
|
||||
|
||||
# TODO(luka) add other supported compilation config scenarios here
|
||||
@pytest.mark.parametrize(
|
||||
"compilation_config",
|
||||
# additional compile sizes
|
||||
"compilation_config, model_info",
|
||||
[
|
||||
CompilationConfig(level=CompilationLevel.PIECEWISE,
|
||||
compile_sizes=[1, 2])
|
||||
# additional compile sizes, only some of the models
|
||||
(CompilationConfig(level=CompilationLevel.PIECEWISE,
|
||||
compile_sizes=[1, 2]), model)
|
||||
for model in models_list(all=False)
|
||||
] + [
|
||||
# RMSNorm + quant fusion, only 8-bit quant models
|
||||
(CompilationConfig(level=CompilationLevel.PIECEWISE,
|
||||
custom_ops=["+rms_norm"],
|
||||
pass_config=PassConfig(enable_fusion=True,
|
||||
enable_noop=True)), model)
|
||||
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
|
||||
])
|
||||
# only test some of the models
|
||||
@pytest.mark.parametrize("model_info", models_list(all=False))
|
||||
@create_new_process_for_each_test()
|
||||
def test_custom_compile_config(
|
||||
model_info: tuple[str, dict[str, Any]],
|
||||
compilation_config: CompilationConfig,
|
||||
model_info: tuple[str, dict[str, Any]],
|
||||
):
|
||||
model, model_kwargs = model_info
|
||||
print(f"MODEL={model}")
|
||||
|
@ -44,12 +44,17 @@ class TestModel(torch.nn.Module):
|
||||
resid = torch.sqrt(x)
|
||||
y = self.norm[0](x)
|
||||
|
||||
x2 = self.fp8_linear.apply(y, self.w[0], self.wscale[0], self.scale[0])
|
||||
x2 = self.fp8_linear.apply(y,
|
||||
self.w[0],
|
||||
self.wscale[0],
|
||||
input_scale=self.scale[0])
|
||||
# make sure resid is used for replacement to work
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
x3 = self.fp8_linear.apply(y2, self.w[1], self.wscale[1],
|
||||
self.scale[1])
|
||||
x3 = self.fp8_linear.apply(y2,
|
||||
self.w[1],
|
||||
self.wscale[1],
|
||||
input_scale=self.scale[1])
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
return y3
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user