
- **Add SPDX license headers to python source files** - **Check for SPDX headers using pre-commit** commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745 Author: Russell Bryant <rbryant@redhat.com> Date: Fri Jan 31 14:18:24 2025 -0500 Add SPDX license headers to python source files This commit adds SPDX license headers to python source files as recommended to the project by the Linux Foundation. These headers provide a concise way that is both human and machine readable for communicating license information for each source file. It helps avoid any ambiguity about the license of the code and can also be easily used by tools to help manage license compliance. The Linux Foundation runs license scans against the codebase to help ensure we are in compliance with the licenses of the code we use, including dependencies. Having these headers in place helps that tool do its job. More information can be found on the SPDX site: - https://spdx.dev/learn/handling-license-info/ Signed-off-by: Russell Bryant <rbryant@redhat.com> commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea Author: Russell Bryant <rbryant@redhat.com> Date: Fri Jan 31 14:36:32 2025 -0500 Check for SPDX headers using pre-commit Signed-off-by: Russell Bryant <rbryant@redhat.com> --------- Signed-off-by: Russell Bryant <rbryant@redhat.com>
103 lines
3.8 KiB
Python
103 lines
3.8 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
import vllm.envs as envs
|
|
from vllm import LLM, SamplingParams
|
|
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
|
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
|
|
kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
|
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
|
from vllm.compilation.reshapes import RedundantReshapesPass
|
|
from vllm.config import CompilationConfig
|
|
|
|
from .backend import TestBackend
|
|
|
|
# Custom ops expected in the traced graph whether or not the fusion pass
# runs; test_fix_functionalization checks each is de-functionalized.
OPS_IN_MODEL = [
    torch.ops._C.rotary_embedding.default,
    torch.ops._C.fused_add_rms_norm.default,
    torch.ops._C.silu_and_mul.default,
]

# Plain (un-fused) RMSNorm op; the test expects it only when fusion is off.
RMS_OP = torch.ops._C.rms_norm.default

# RMSNorm + static FP8 quantization fused ops, grouped by scheme name.
# NOTE(review): test_fix_functionalization below takes its fused ops from
# FUSED_OPS instead of this table; confirm whether anything else still
# uses it before removing.
RMS_QUANT_OPS = {
    "static_fp8": [
        torch.ops._C.rms_norm_static_fp8_quant.default,
        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
    ],
}
|
|
|
|
# Short fixed prompts run through the engine for each compiled variant of
# the model; their generated texts are compared pairwise.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
|
|
|
|
|
|
@pytest.mark.parametrize(
    "model, quant_key",
    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym),
     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e",
      kFp8DynamicTokenSym)])
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
def test_fix_functionalization(model: str, quant_key: QuantKey,
                               do_fusion: bool):
    """Check that FixFunctionalizationPass preserves outputs and rewrites ops.

    The model is compiled twice: once with the reshape (and, if
    ``do_fusion``, fusion) passes followed by FixFunctionalizationPass, and
    once without FixFunctionalizationPass. Greedy generations from both
    variants must match. The post-pass graphs are then inspected:

    * without the pass, ``find_auto_fn`` locates every expected op;
    * with the pass, ``find_auto_fn_maybe`` finds none of them, and each
      op instead appears as a node matched by ``is_func``.

    Args:
        model: HF model id of an FP8-quantized test checkpoint.
        quant_key: Quantization key used to pick the fused ops from
            ``FUSED_OPS`` when fusion is enabled.
        do_fusion: Whether to run FusionPass before
            FixFunctionalizationPass.
    """
    # NOTE(review): this changes the process-wide default device and does
    # not restore it afterwards.
    torch.set_default_device("cuda")

    config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
                                          enable_reshape=True)
    reshape_pass = RedundantReshapesPass(config)
    fusion_pass = FusionPass.instance(config)

    passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass]
    func_pass = FixFunctionalizationPass(config)
    backend_func = TestBackend(*passes, func_pass)
    backend_no_func = TestBackend(*passes)

    # instantiate a full engine and manually compile the model 2x
    # (with and without FixFunctionalizationPass)
    llm = LLM(model=model, enforce_eager=True)
    model_runner = llm.llm_engine.model_executor.driver_worker.model_runner
    orig_model = model_runner.model
    # TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
    # Can only do that by using the decorator but then we'd have to instantiate
    # 2 LLM instances.

    # temperature=0.0 makes decoding greedy, so both variants are expected
    # to produce identical text.
    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
    model_runner.model = torch.compile(orig_model,
                                       fullgraph=True,
                                       backend=backend_func)
    gen_func = llm.generate(prompts, sampling_params)

    model_runner.model = torch.compile(orig_model,
                                       fullgraph=True,
                                       backend=backend_no_func)
    gen_no_func = llm.generate(prompts, sampling_params)

    # zip() would silently drop unmatched outputs, so check lengths first.
    assert len(gen_func) == len(gen_no_func)
    for output_func, output_no_func in zip(gen_func, gen_no_func):
        assert output_func.outputs[0].text == output_no_func.outputs[0].text

    # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion and
    # replaced by the fused RMSNorm+quant ops that FUSED_OPS maps this
    # quant_key to (both boolean variants of the key; presumably with and
    # without the residual add -- confirm against vllm.compilation.fusion).
    if do_fusion:
        rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]]
    else:
        rms_ops = [RMS_OP]
    ops = OPS_IN_MODEL + rms_ops

    for op in ops:
        # Presumably find_auto_fn fails if the op is absent; the graph
        # compiled without FixFunctionalizationPass must contain each op.
        find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
        assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
                                  op) is None

    # make sure the ops were all de-functionalized. A set (rather than
    # indexing a dict) makes a missing op fail the assertion with a useful
    # message instead of raising KeyError.
    defunctionalized = {
        op
        for node in backend_func.graph_post_pass.nodes
        for op in ops
        if is_func(node, op)
    }
    missing = [op for op in ops if op not in defunctionalized]
    assert not missing, (
        f"ops not de-functionalized by FixFunctionalizationPass: {missing}")
|