[Kernel] Fullgraph and opcheck tests (#8479)

bnellnm authored 2024-09-25 10:35:52 -04:00, committed by GitHub
parent 1c046447a6
commit 300da09177
26 changed files with 744 additions and 116 deletions
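
For context: the opcheck tests added in this commit are built on torch.library.opcheck, which exercises a registered custom op against its schema, its fake-tensor (meta) implementation, and its dispatch/autograd registrations. A minimal sketch of the pattern the new tests follow, mirroring the gptq_shuffle test below (illustrative shapes; a CUDA build of vLLM is assumed):

import torch

from tests.kernels.utils import opcheck  # thin wrapper added in this commit
from vllm import _custom_ops as ops  # noqa: F401  (registers torch.ops._C)

# One valid call's worth of arguments for the op under test.
weight = torch.randint(-2000000, 2000000, (1792, 4096),
                       device="cuda", dtype=torch.int32)
perm = torch.empty((0, ), device="cuda", dtype=torch.int32)
bit = 4

# Checks the op's registered schema and its fake-tensor implementation and
# raises if either disagrees with the real kernel's behaviour.
opcheck(torch.ops._C.gptq_shuffle, (weight, perm, bit))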


@ -210,6 +210,21 @@ steps:
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
parallelism: 4
- label: "PyTorch Fullgraph Smoke Test"
fast_check: true
source_file_dependencies:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_full_graph_smoke.py
- label: "PyTorch Fullgraph Test"
source_file_dependencies:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_full_graph.py
- label: Kernels Test %N # 30min each
mirror_hardwares: [amd]
source_file_dependencies:
@ -355,7 +370,7 @@ steps:
- tests/distributed/
- vllm/compilation
commands:
- pytest -v -s ./compile/test_full_graph.py
- pytest -v -s ./compile/test_full_graph_multi_gpu.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus


@ -586,7 +586,7 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
});
std::vector<at::Tensor> result = {out, x.value()};
std::vector<at::Tensor> result = {out};
if (has_z) { result.push_back(out_z); }
return result;
}


@ -275,7 +275,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
"Tensor? index_, Tensor!? x) -> Tensor[]");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
ops.def(
@ -292,7 +292,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? bias_,"
"Tensor? seq_idx_,"
"Tensor? initial_states_,"
"Tensor? final_states_out_,"
"Tensor!? final_states_out_,"
"bool silu_activation) -> Tensor");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#endif
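
For reference, the schema tweaks above concern mutability annotations: in TORCH_LIBRARY schema strings, Tensor! marks an argument the kernel writes to and Tensor!? an optional tensor that is written when present, which is what the updated selective_scan_fwd (x) and causal_conv1d_fwd (final_states_out_) signatures declare. A tiny self-contained illustration of the same syntax from Python, on a throwaway namespace that is not part of vLLM:

import torch

# "demo" is an illustrative library namespace, not a vLLM binding.
lib = torch.library.Library("demo", "FRAGMENT")

# `Tensor! out` is mutated by the kernel; `Tensor!? state` is an optional
# tensor that is mutated when it is provided.  If either annotation were
# malformed, define() would raise at schema-parse time.
lib.define("axpy_inplace(Tensor! out, Tensor x, float alpha, Tensor!? state) -> ()")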


@ -1,42 +1,13 @@
import os
import pytest
from vllm.utils import cuda_device_count_stateless
from vllm.compilation.backends import vllm_backend
from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS, check_full_graph_support
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
@pytest.mark.parametrize("tp_size", [1, 2])
@fork_new_process_for_each_test
def test_full_graph(model, tp_size):
# Skip the test if there are not enough CUDA devices.
if cuda_device_count_stateless() < tp_size:
pytest.skip("Not enough CUDA devices for the test.")
# make sure these models can be captured in full graph mode
if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model=model,
enforce_eager=True,
tensor_parallel_size=tp_size,
disable_custom_all_reduce=True)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.parametrize("model_info", TEST_MODELS)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
model = model_info[0]
model_kwargs = model_info[1]
check_full_graph_support(model, model_kwargs, backend, tp_size=1)


@ -0,0 +1,22 @@
import pytest
from vllm.compilation.backends import vllm_backend
from vllm.utils import cuda_device_count_stateless
from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS_SMOKE, check_full_graph_support
@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
@fork_new_process_for_each_test
def test_full_graph_multi_gpu(model_info, tp_size, backend):
model = model_info[0]
model_kwargs = model_info[1]
# Skip the test if there are not enough CUDA devices.
if cuda_device_count_stateless() < tp_size:
pytest.skip("Not enough CUDA devices for the test.")
check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size)


@ -0,0 +1,13 @@
import pytest
from vllm.compilation.backends import vllm_backend
from .utils import TEST_MODELS_SMOKE, check_full_graph_support
@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
model = model_info[0]
model_kwargs = model_info[1]
check_full_graph_support(model, model_kwargs, backend, tp_size=1)

tests/compile/utils.py (new file, 104 lines)

@ -0,0 +1,104 @@
import os
import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.plugins import set_torch_compile_backend
from vllm.utils import is_hip
TEST_MODELS_SMOKE = [
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
"quantization": "compressed-tensors"
}),
("meta-llama/Meta-Llama-3-8B", {}),
]
TEST_MODELS = [
("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
"dtype": torch.float16,
"quantization": "compressed-tensors"
}),
("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", {
"dtype": torch.float16,
"quantization": "fp8"
}),
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
"quantization": "compressed-tensors"
}),
("meta-llama/Meta-Llama-3-8B", {}),
]
# TODO: enable in pytorch 2.5
if False and is_quant_method_supported("aqlm"): # noqa: SIM223
TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
"quantization": "aqlm"
}))
# TODO: enable in pytorch 2.5
if False and is_quant_method_supported("gguf"): # noqa: SIM223
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
"quantization": "gguf"
}))
if is_quant_method_supported("gptq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
"quantization": "gptq"
}))
if is_quant_method_supported("gptq_marlin"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
"quantization": "gptq_marlin"
}))
if is_quant_method_supported("gptq_marlin_24"):
TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
"quantization": "gptq_marlin_24"
}))
if is_quant_method_supported("marlin"):
TEST_MODELS.append(("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
"quantization": "marlin"
}))
if not is_hip() and is_quant_method_supported("awq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
"quantization": "AWQ"
}))
def check_full_graph_support(model, model_kwargs, backend, tp_size=1):
# make sure these models can be captured in full graph mode
if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
# Inductor doesn't support fp8/gptq_marlin_24 yet.
quantization = model_kwargs.get("quantization")
if (quantization == "fp8" or quantization == "gptq_marlin"
or quantization == "gptq_marlin_24") and backend != "eager":
return
set_torch_compile_backend(backend)
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model=model,
enforce_eager=True,
tensor_parallel_size=tp_size,
disable_custom_all_reduce=True,
**model_kwargs)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@ -169,6 +169,12 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
cleanup()
@pytest.fixture(autouse=True)
def dynamo_reset():
yield
torch._dynamo.reset()
@pytest.fixture
def example_prompts() -> List[str]:
prompts = []


@ -0,0 +1,37 @@
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
def test_aqlm_dequant_opcheck():
codes = torch.randint(-32768,
32767, (22016, 512, 1),
device='cuda',
dtype=torch.int16)
codebooks = torch.rand((2, 65536, 1, 8),
device='cuda',
dtype=torch.float16)
codebook_partition_sizes = [11008, 11008]
opcheck(torch.ops._C.aqlm_dequant,
(codes, codebooks, codebook_partition_sizes))
def test_aqlm_gemm_opcheck():
input = torch.rand((4, 4096), device='cuda', dtype=torch.float16)
codes = torch.randint(-32768,
32767, (12288, 512, 1),
device='cuda',
dtype=torch.int16)
codebooks = torch.rand((3, 65536, 1, 8),
device='cuda',
dtype=torch.float16)
scales = torch.rand((12288, 1, 1, 1), device='cuda', dtype=torch.float16)
codebook_partition_sizes = [4096, 4096, 4096]
bias = None
opcheck(torch.ops._C.aqlm_gemm,
(input, codes, codebooks, scales, codebook_partition_sizes, None))
opcheck(torch.ops._C.aqlm_gemm,
(input, codes, codebooks, scales, codebook_partition_sizes, bias))


@ -205,7 +205,8 @@ def test_paged_attention(
(output, query, key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]))
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
elif version in ("v2", "rocm"):
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
@ -246,7 +247,8 @@ def test_paged_attention(
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]))
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
else:
ops.paged_attention_rocm(
@ -274,7 +276,8 @@ def test_paged_attention(
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
else:
raise AssertionError(f"Unknown version: {version}")

tests/kernels/test_awq.py (new file, 38 lines)

@ -0,0 +1,38 @@
import os
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
def test_awq_dequantize_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
qweight = torch.randint(-2000000000,
2000000000, (8192, 256),
device='cuda',
dtype=torch.int32)
scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16)
zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32)
split_k_iters = 0
thx = 0
thy = 0
opcheck(torch.ops._C.awq_dequantize,
(qweight, scales, zeros, split_k_iters, thx, thy))
def test_awq_gemm_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
qweight = torch.randint(-2000000000,
2000000000, (8192, 256),
device='cuda',
dtype=torch.int32)
scales = torch.randint(-2000000000,
2000000000, (64, 256),
device='cuda',
dtype=torch.int32)
qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16)
split_k_iters = 8
opcheck(torch.ops._C.awq_gemm,
(input, qweight, qzeros, scales, split_k_iters))


@ -5,6 +5,8 @@ import torch
import torch.nn.functional as F
from einops import rearrange
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.utils import seed_everything
@ -84,6 +86,64 @@ def causal_conv1d_update_ref(x: torch.Tensor,
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
def causal_conv1d_opcheck_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None,
return_final_states: bool = False,
final_states_out=None,
activation: Optional[str] = "silu",
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
seq_idx: (batch, seqlen)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1), to be written to
activation: either None or "silu" or "swish"
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(2) != 1 and x.stride(1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
if seq_idx is not None:
assert (initial_states is
None), "initial_states must be None if seq_idx is not None"
assert (not return_final_states
), "If seq_idx is not None, we don't return final_states_out"
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
if initial_states is not None and (initial_states.stride(2) != 1
and initial_states.stride(1) != 1):
initial_states = initial_states.contiguous()
if return_final_states:
assert (
x.stride(1) == 1
), "Only channel-last layout support returning final_states_out"
if final_states_out is not None:
assert (final_states_out.stride(2) == 1
or final_states_out.stride(1) == 1)
else:
batch, dim, seqlen = x.shape
width = weight.shape[1]
final_states_out = torch.empty(batch,
width - 1,
dim,
device=x.device,
dtype=x.dtype).transpose(1, 2)
else:
final_states_out = None
opcheck(torch.ops._C.causal_conv1d_fwd,
(x, weight, bias, seq_idx, initial_states, final_states_out,
activation in ["silu", "swish"]))
@pytest.mark.parametrize("return_final_states", [False, True])
@pytest.mark.parametrize("has_initial_states", [False, True])
@pytest.mark.parametrize("channel_last", [False, True])
@ -149,6 +209,14 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
initial_states=initial_states_ref,
return_final_states=return_final_states,
activation=activation)
causal_conv1d_opcheck_fn(x_ref,
weight_ref,
bias_ref,
initial_states=initial_states_ref,
return_final_states=return_final_states,
activation=activation)
if return_final_states:
assert final_states is not None and final_states_ref is not None
assert torch.allclose(final_states,
@ -205,6 +273,10 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
assert torch.equal(conv_state, conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
opcheck(
torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation in ["silu", "swish"], None))
@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@ -258,7 +330,5 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
bias,
activation=activation)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)


@ -15,6 +15,9 @@ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
@ -74,6 +77,9 @@ def cutlass_fp8_gemm_helper(m: int,
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)
opcheck(torch.ops._C.cutlass_scaled_mm,
(out, a, b, scale_a, scale_b, bias))
def cutlass_int8_gemm_helper(m: int,
n: int,
@ -425,3 +431,7 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
def test_cutlass_support_opcheck():
opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, ))


@ -4,6 +4,7 @@ import pytest
import torch
import vllm.attention.backends.flash_attn # noqa: F401
from tests.kernels.utils import opcheck
from vllm.utils import seed_everything
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
@ -127,7 +128,7 @@ def test_flash_attn_with_paged_kv(
else:
test_utils = ["test_faketensor"]
torch.library.opcheck(torch.ops.vllm.flash_attn_with_kvcache,
opcheck(torch.ops.vllm.flash_attn_with_kvcache,
args=tuple(),
kwargs=dict(
decode_query=query.unsqueeze(1),
@ -232,7 +233,7 @@ def test_varlen_with_paged_kv(
else:
test_utils = ["test_faketensor"]
torch.library.opcheck(torch.ops.vllm.flash_attn_varlen_func,
opcheck(torch.ops.vllm.flash_attn_varlen_func,
args=tuple(),
kwargs=dict(
q=query,


@ -5,6 +5,7 @@ import vllm._custom_ops as ops
from tests.kernels.quant_utils import (FP8_DTYPE,
ref_dynamic_per_tensor_fp8_quant,
ref_dynamic_per_token_quant)
from tests.kernels.utils import opcheck
from vllm.utils import seed_everything
DTYPES = [torch.half, torch.bfloat16, torch.float]
@ -16,6 +17,26 @@ SCALE_UBS = [True, False]
SEEDS = [0]
def opcheck_fp8_quant(output,
input,
scale=None,
scale_ub=None,
use_per_token_if_dynamic=False):
if scale is not None:
opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale))
elif use_per_token_if_dynamic:
scale = torch.empty((input.shape[0], 1),
device=input.device,
dtype=torch.float32)
opcheck(torch.ops._C.dynamic_per_token_scaled_fp8_quant,
(output, input, scale, scale_ub))
else:
scale = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
opcheck(torch.ops._C.dynamic_scaled_fp8_quant, (output, input, scale))
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@ -41,6 +62,12 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
torch.testing.assert_close(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
opcheck_fp8_quant(ops_out,
x,
None,
scale_ub,
use_per_token_if_dynamic=True)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@ -60,6 +87,8 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
torch.testing.assert_close(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
opcheck_fp8_quant(ops_out, x)
# Regression test for a case with large activations where an int32 index cannot
# represent the number of elements.


@ -0,0 +1,22 @@
import gguf
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
@pytest.mark.parametrize("quant_type", [12])
def test_ggml_opcheck(quant_type):
block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type]
shape = [256, 1152]
qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8)
m = qweight.shape[0]
n = qweight.shape[1] // type_size * block_size
opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n))
x = torch.rand((m, 512), device='cuda', dtype=torch.float16)
opcheck(torch.ops._C.ggml_mul_mat_a8,
(qweight, x, quant_type, qweight.shape[0]))
opcheck(torch.ops._C.ggml_mul_mat_vec_a8,
(qweight, x, quant_type, qweight.shape[0]))


@ -0,0 +1,29 @@
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
def test_gptq_shuffle_opcheck():
weight = torch.randint(-2000000,
2000000, (1792, 4096),
device='cuda',
dtype=torch.int32)
perm = torch.empty((0, ), device='cuda', dtype=torch.int32)
bit = 4
opcheck(torch.ops._C.gptq_shuffle, (weight, perm, bit))
def test_gptq_gemm_opcheck():
a = torch.rand((240, 4096), device='cuda', dtype=torch.float16)
weight = torch.randint(-2000000,
2000000, (512, 6144),
device='cuda',
dtype=torch.int32)
zeros = torch.zeros((32, 768), device='cuda', dtype=torch.int32)
scales = torch.rand((32, 6144), device='cuda', dtype=torch.float16)
idx = torch.empty((0, ), device='cuda', dtype=torch.int32)
use_exllama = True
bit = 4
opcheck(torch.ops._C.gptq_gemm,
(a, weight, zeros, scales, idx, use_exllama, bit))


@ -3,6 +3,8 @@ import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.utils import seed_everything
@ -161,6 +163,59 @@ def selective_scan_ref(u,
return out if not return_last_state else (out, last_state)
def selective_scan_opcheck_fn(u,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
return_last_state=False,
position_indices=None,
prev_state=None):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
"""
if u.stride(-1) != 1:
u = u.contiguous()
if delta.stride(-1) != 1:
delta = delta.contiguous()
if D is not None:
D = D.contiguous()
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if z is not None and z.stride(-1) != 1:
z = z.contiguous()
if B.dim() == 3:
B = B.unsqueeze(1)
if C.dim() == 3:
C = C.unsqueeze(1)
n_chunks = int((u.shape[-1] + 2048 - 1) / 2048)
x = torch.zeros((
u.shape[0],
u.shape[1],
n_chunks,
int(A.shape[1] * 2),
),
device=u.device,
dtype=torch.float32,
requires_grad=False)
x[:, :, 0, 0::2] = 1
if prev_state is not None:
x[:, :, 0, 1::2].copy_(prev_state)
# Disable test_autograd_registration for now as it seems to trigger
# a bogus error.
opcheck(torch.ops._C.selective_scan_fwd,
(u, delta, A, B, C, D, z, delta_bias, delta_softplus,
position_indices, x),
test_utils=["test_schema", "test_faketensor"])
@pytest.mark.parametrize('wtype', [torch.float32])
@pytest.mark.parametrize('itype', [torch.float32])
@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
@ -274,6 +329,17 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
assert state is not None and state_ref is not None
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
selective_scan_opcheck_fn(u,
delta,
A,
B,
C,
D,
z=z,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
return_last_state=return_last_state)
@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])


@ -501,3 +501,18 @@ def test_marlin_qqq_gemm(
max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04
def test_marlin_gemm_opcheck():
size_m = 2048
size_n = 4096
size_k = 4096
a = torch.rand((size_m, size_n), device='cuda', dtype=torch.float16)
w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32)
s = torch.full((32, size_k), 0.125, device='cuda', dtype=torch.float16)
wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL).scratch
x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
torch.testing.assert_close(x, y)
opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k))


@ -9,11 +9,14 @@ import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
@ -247,6 +250,35 @@ def test_fused_marlin_moe(
assert compute_max_diff(marlin_output, triton_output) < 4e-2
if ops.supports_moe_ops:
token_expert_indicies = torch.empty(m,
topk,
dtype=torch.int32,
device=a.device)
opcheck(torch.ops._moe_C.topk_softmax, (
topk_weights,
topk_ids,
token_expert_indicies,
score.float(),
))
block_size_m = 4
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m,
e)
max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)
opcheck(torch.ops._moe_C.marlin_gemm_moe,
(a, qweight1, sorted_token_ids, topk_weights, topk_ids,
scales1, g_idx1, sort_indices1, workspace, quant_type, m,
2 * n, k, True, e, topk, block_size_m, True, False))
@pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.")
@ -319,3 +351,29 @@ def test_single_marlin_moe_multiply(
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
assert compute_max_diff(marlin_output, torch_output) < 1e-2
def test_moe_align_block_size_opcheck():
num_experts = 4
block_size = 4
topk_ids = torch.randint(0,
num_experts, (3, 4),
dtype=torch.int32,
device='cuda')
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
opcheck(torch.ops._C.moe_align_block_size,
(topk_ids, num_experts, block_size, sorted_ids, expert_ids,
num_tokens_post_pad))


@ -0,0 +1,62 @@
"""
Tests for miscellaneous utilities
"""
from typing import Optional
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
def rotary_embedding_opcheck(rot,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None):
cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
opcheck(torch.ops._C.batched_rotary_embedding,
(positions, query, key, rot.head_size, cos_sin_cache,
rot.is_neox_style, rot.rotary_dim, offsets))
else:
opcheck(torch.ops._C.rotary_embedding,
(positions, query, key, rot.head_size, cos_sin_cache,
rot.is_neox_style))
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("max_position", [11, 4096, 32768])
@pytest.mark.parametrize("is_neox_style", [True, False])
@pytest.mark.parametrize("rotary_dim", [32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
def test_rotary_embedding_opcheck(dist_init, device, max_position,
is_neox_style, rotary_dim, head_size,
seq_len):
batch_size = 1
base = 0
num_heads = 7
rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, torch.float32)
positions = torch.randint(0,
max_position, (batch_size, seq_len),
device=device)
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=torch.float32,
device=device)
key = torch.randn_like(query)
rotary_embedding_opcheck(rot, positions, query, key)
offsets = torch.zeros(batch_size * seq_len,
device=device,
dtype=torch.long)
rotary_embedding_opcheck(rot, positions, query, key, offsets)


@ -0,0 +1,24 @@
"""
Tests for miscellaneous utilities
"""
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.platforms import current_platform
def test_convert_fp8_opcheck():
data = torch.randn((256, 256), dtype=torch.float32, device="cuda")
result = torch.empty_like(data, dtype=torch.float8_e4m3fn)
opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8"))
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="Only supported for CUDA")
def test_cuda_utils_opcheck():
opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0))
opcheck(
torch.ops._C_cuda_utils.
get_max_shared_memory_per_block_device_attribute, (0, ))


@ -2,12 +2,14 @@
import itertools
import random
import unittest
from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
Union)
import pytest
import torch
from torch._prims_common import TensorLikeType
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
@ -946,6 +948,34 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
output_under_test.view_as(ideal_output))
# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
a: TensorLikeType,
b: TensorLikeType,
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
) -> bool:
"""
Reference implementation of torch.allclose
"""
torch._refs._check_close_args(name="torch.allclose",
a=a,
b=b,
rtol=rtol,
atol=atol)
return bool(
torch.all(
torch.isclose(a.double(),
b.double(),
rtol=rtol,
atol=atol,
equal_nan=equal_nan)).item())
# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
torch._library.custom_ops.CustomOpDef],
args: Tuple[Any, ...],
@ -954,6 +984,7 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
raise_exception: bool = True,
cond: bool = True) -> Dict[str, str]:
with unittest.mock.patch('torch.allclose', new=fp8_allclose):
return torch.library.opcheck(
op,
args,

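A short usage sketch of the wrapper above (illustrative, assuming a CUDA build; the tensors mirror the convert_fp8 opcheck test earlier in this diff). test_utils narrows which of opcheck's sub-tests run, and cond=False turns the call into a no-op, which is how the attention tests restrict opcheck to a single representative head/block size:

import torch

from tests.kernels.utils import opcheck

data = torch.randn((256, 256), dtype=torch.float32, device="cuda")
result = torch.empty_like(data, dtype=torch.float8_e4m3fn)

# Run only the schema and fake-tensor checks for this op; torch.allclose is
# patched to fp8_allclose inside the wrapper so any value comparison copes
# with fp8 outputs.
opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8"),
        test_utils=["test_schema", "test_faketensor"],
        cond=True)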

@ -20,8 +20,10 @@ if not current_platform.is_tpu():
if current_platform.is_rocm():
import vllm._rocm_C # noqa: F401
supports_moe_ops = False
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
supports_moe_ops = True
def hint_on_error(fn):
@ -253,9 +255,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_g_idx, use_exllama, bit)
# TODO: has to be a better way to do this
try:
torch.ops._C.gptq_gemm # noqa B018
if hasattr(torch.ops._C, "gptq_gemm"):
@torch.library.register_fake("_C::gptq_gemm")
def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
@ -265,8 +265,6 @@ try:
return torch.empty((a.size(0), b_q_weight.size(1)),
dtype=a.dtype,
device=a.device)
except Exception:
pass
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
@ -292,9 +290,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n, size_k)
# TODO: has to be a better way to do this
try:
torch.ops._C.gptq_marlin_24_gemm # noqa B018
if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@torch.library.register_fake("_C::gptq_marlin_24_gemm")
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
@ -420,8 +416,8 @@ try:
@torch.library.register_fake("_C::machete_gemm")
def machete_gemm_fake(
a: torch.Tensor,
b_q: torch.
Tensor, # Should be the tensor returned by machete_prepack_B
# Should be the tensor returned by machete_prepack_B
b_q: torch.Tensor,
b_type: ScalarType,
b_scales: Optional[torch.Tensor] = None,
b_zeros: Optional[torch.Tensor] = None,
@ -451,10 +447,10 @@ try:
return torch.empty_like(x)
@torch.library.register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor,
bias_: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
def causal_conv1d_update_fake(
x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor], silu_activation: bool,
conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
return torch.empty_like(x)
@torch.library.register_fake("_C::selective_scan_fwd")
@ -465,20 +461,11 @@ try:
delta_softplus: bool, index_: Optional[torch.Tensor],
x: Optional[torch.Tensor]) -> List[torch.Tensor]:
a = torch.empty_like(u)
if x is not None:
b = x
else:
b = torch.empty((u.size(0), u.size(1), A.size(1)),
dtype=u.dtype,
device=u.device)
if z_ is not None:
c = torch.empty_like(z_)
return [a, b, c]
return [a, c]
else:
return [a, b]
except Exception:
pass
return [a]
# cutlass
@ -626,16 +613,12 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
return torch.ops._C.machete_prepack_B(b_q_weight, b_type)
# TODO: has to be a better way to do this
try:
torch.ops._C.permute_cols # noqa B018
if hasattr(torch.ops._C, "permute_cols"):
@torch.library.register_fake("_C::permute_cols")
def _permute_cols_fake(a: torch.Tensor,
perm: torch.Tensor) -> torch.Tensor:
return torch.empty_like(a)
except Exception:
pass
def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
@ -828,6 +811,24 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies, gating_output)
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
@torch.library.register_fake("_moe_C::marlin_gemm_moe")
def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
sorted_ids: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor, b_scales: torch.Tensor,
g_idx: torch.Tensor, perm: torch.Tensor,
workspace: torch.Tensor, b_q_type: ScalarType,
size_m: int, size_n: int, size_k: int,
is_k_full: bool, num_experts: int, topk: int,
moe_block_size: int, replicate_input: bool,
apply_weights: bool) -> torch.Tensor:
return torch.empty((size_m, topk, size_n),
dtype=a.dtype,
device=a.device)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,

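The hasattr-guarded register_fake blocks above give each C++ op a fake (meta) implementation, which is what opcheck's test_faketensor pass and Dynamo tracing rely on. The same real-impl/fake-impl pairing in a fully self-contained form, using the pure-Python custom-op API on a throwaway namespace (nothing below is vLLM code):

import torch

# "mylib::scale_rows" is an illustrative op, not part of vLLM.
@torch.library.custom_op("mylib::scale_rows", mutates_args=())
def scale_rows(a: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    return a * s.unsqueeze(-1)

@scale_rows.register_fake
def _(a: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # Fake impl: produce an output with the right shape/dtype, do no real work.
    return torch.empty_like(a)

# With both a real and a fake implementation registered, opcheck passes.
torch.library.opcheck(scale_rows, (torch.randn(4, 8), torch.randn(4)))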

@ -361,7 +361,7 @@ def selective_scan_fn(u,
x[:, :, 0, 0::2] = 1
if prev_state is not None:
x[:, :, 0, 1::2].copy_(prev_state)
out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
out, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
delta_softplus, position_indices, x)
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
if z is None:


@ -217,6 +217,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass