[Kernel] Fullgraph and opcheck tests (#8479)

Authored by bnellnm on 2024-09-25 10:35:52 -04:00, committed by GitHub
parent 1c046447a6
commit 300da09177
26 changed files with 744 additions and 116 deletions


@@ -70,7 +70,7 @@ steps:
  - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
  - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py

- label: Core Test # 10min
  mirror_hardwares: [amd]
  fast_check: true
@@ -210,6 +210,21 @@ steps:
  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
  parallelism: 4

- label: "PyTorch Fullgraph Smoke Test"
  fast_check: true
  source_file_dependencies:
  - vllm/
  - tests/compile
  commands:
  - pytest -v -s compile/test_full_graph_smoke.py

- label: "PyTorch Fullgraph Test"
  source_file_dependencies:
  - vllm/
  - tests/compile
  commands:
  - pytest -v -s compile/test_full_graph.py

- label: Kernels Test %N # 30min each
  mirror_hardwares: [amd]
  source_file_dependencies:
@@ -355,7 +370,7 @@ steps:
  - tests/distributed/
  - vllm/compilation
  commands:
-  - pytest -v -s ./compile/test_full_graph.py
+  - pytest -v -s ./compile/test_full_graph_multi_gpu.py
  - pytest -v -s ./compile/test_wrapper.py
  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
  - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus


@@ -586,7 +586,7 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
  DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
    selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
  });
- std::vector<at::Tensor> result = {out, x.value()};
+ std::vector<at::Tensor> result = {out};
  if (has_z) { result.push_back(out_z); }
  return result;
}


@@ -275,7 +275,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      "Tensor! A, Tensor! B, Tensor! C,"
      "Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
      "bool delta_softplus,"
-     "Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
+     "Tensor? index_, Tensor!? x) -> Tensor[]");
  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

  ops.def(
@@ -292,7 +292,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      "Tensor? bias_,"
      "Tensor? seq_idx_,"
      "Tensor? initial_states_,"
-     "Tensor? final_states_out_,"
+     "Tensor!? final_states_out_,"
      "bool silu_activation) -> Tensor");
  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#endif
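The new schema strings use PyTorch's `Tensor!?` annotation: an optional tensor argument that the kernel is allowed to write into. A toy sketch of what the annotation means, checked with `torch.library.opcheck`; the `demo_ops::scale_to` op below is invented for illustration and is not part of vLLM:

import torch

# Hypothetical op whose "Tensor!? out" argument mirrors the x /
# final_states_out_ arguments declared above: optional, and mutable in place.
lib = torch.library.Library("demo_ops", "DEF")
lib.define("scale_to(Tensor x, float s, Tensor!? out) -> Tensor")


def _scale_to_impl(x, s, out=None):
    result = x * s
    if out is not None:
        out.copy_(result)  # allowed: out is declared mutable in the schema
    return result


lib.impl("scale_to", _scale_to_impl, "CompositeExplicitAutograd")

# test_schema verifies that the op mutates only arguments declared mutable.
torch.library.opcheck(torch.ops.demo_ops.scale_to,
                      (torch.randn(4), 2.0, torch.empty(4)),
                      test_utils=("test_schema", ))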


@@ -1,42 +1,13 @@
-import os
 import pytest
-from vllm.utils import cuda_device_count_stateless
-from ..utils import fork_new_process_for_each_test
+from vllm.compilation.backends import vllm_backend
+from .utils import TEST_MODELS, check_full_graph_support
-@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
-@pytest.mark.parametrize("tp_size", [1, 2])
-@fork_new_process_for_each_test
-def test_full_graph(model, tp_size):
-    # Skip the test if there are not enough CUDA devices.
-    if cuda_device_count_stateless() < tp_size:
-        pytest.skip("Not enough CUDA devices for the test.")
-    # make sure these models can be captured in full graph mode
-    if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
-        os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
-    from vllm import LLM, SamplingParams
-    prompts = [
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
-    ]
-    sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model=model,
-              enforce_eager=True,
-              tensor_parallel_size=tp_size,
-              disable_custom_all_reduce=True)
-    outputs = llm.generate(prompts, sampling_params)
-    # Print the outputs.
-    for output in outputs:
-        prompt = output.prompt
-        generated_text = output.outputs[0].text
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+@pytest.mark.parametrize("model_info", TEST_MODELS)
+@pytest.mark.parametrize("backend", ["eager", vllm_backend])
+def test_full_graph(model_info, backend):
+    model = model_info[0]
+    model_kwargs = model_info[1]
+    check_full_graph_support(model, model_kwargs, backend, tp_size=1)


@@ -0,0 +1,22 @@
import pytest
from vllm.compilation.backends import vllm_backend
from vllm.utils import cuda_device_count_stateless
from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS_SMOKE, check_full_graph_support
@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
@fork_new_process_for_each_test
def test_full_graph_multi_gpu(model_info, tp_size, backend):
model = model_info[0]
model_kwargs = model_info[1]
# Skip the test if there are not enough CUDA devices.
if cuda_device_count_stateless() < tp_size:
pytest.skip("Not enough CUDA devices for the test.")
check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size)


@@ -0,0 +1,13 @@
import pytest
from vllm.compilation.backends import vllm_backend
from .utils import TEST_MODELS_SMOKE, check_full_graph_support
@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
model = model_info[0]
model_kwargs = model_info[1]
check_full_graph_support(model, model_kwargs, backend, tp_size=1)

tests/compile/utils.py (new file, 104 lines)

@@ -0,0 +1,104 @@
import os
import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.plugins import set_torch_compile_backend
from vllm.utils import is_hip
TEST_MODELS_SMOKE = [
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
"quantization": "compressed-tensors"
}),
("meta-llama/Meta-Llama-3-8B", {}),
]
TEST_MODELS = [
("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
"dtype": torch.float16,
"quantization": "compressed-tensors"
}),
("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", {
"dtype": torch.float16,
"quantization": "fp8"
}),
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
"quantization": "compressed-tensors"
}),
("meta-llama/Meta-Llama-3-8B", {}),
]
# TODO: enable in pytorch 2.5
if False and is_quant_method_supported("aqlm"): # noqa: SIM223
TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
"quantization": "aqlm"
}))
# TODO: enable in pytorch 2.5
if False and is_quant_method_supported("gguf"): # noqa: SIM223
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
"quantization": "gguf"
}))
if is_quant_method_supported("gptq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
"quantization": "gptq"
}))
if is_quant_method_supported("gptq_marlin"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
"quantization": "gptq_marlin"
}))
if is_quant_method_supported("gptq_marlin_24"):
TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
"quantization": "gptq_marlin_24"
}))
if is_quant_method_supported("marlin"):
TEST_MODELS.append(("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
"quantization": "marlin"
}))
if not is_hip() and is_quant_method_supported("awq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
"quantization": "AWQ"
}))
def check_full_graph_support(model, model_kwargs, backend, tp_size=1):
# make sure these models can be captured in full graph mode
if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
# Inductor doesn't support fp8/gptq_marlin_24 yet.
quantization = model_kwargs.get("quantization")
if (quantization == "fp8" or quantization == "gptq_marlin"
or quantization == "gptq_marlin_24") and backend != "eager":
return
set_torch_compile_backend(backend)
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model=model,
enforce_eager=True,
tensor_parallel_size=tp_size,
disable_custom_all_reduce=True,
**model_kwargs)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@@ -169,6 +169,12 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    cleanup()

@pytest.fixture(autouse=True)
def dynamo_reset():
    yield
    torch._dynamo.reset()

@pytest.fixture
def example_prompts() -> List[str]:
    prompts = []


@@ -0,0 +1,37 @@
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
def test_aqlm_dequant_opcheck():
codes = torch.randint(-32768,
32767, (22016, 512, 1),
device='cuda',
dtype=torch.int16)
codebooks = torch.rand((2, 65536, 1, 8),
device='cuda',
dtype=torch.float16)
codebook_partition_sizes = [11008, 11008]
opcheck(torch.ops._C.aqlm_dequant,
(codes, codebooks, codebook_partition_sizes))
def test_aqlm_gemm_opcheck():
input = torch.rand((4, 4096), device='cuda', dtype=torch.float16)
codes = torch.randint(-32768,
32767, (12288, 512, 1),
device='cuda',
dtype=torch.int16)
codebooks = torch.rand((3, 65536, 1, 8),
device='cuda',
dtype=torch.float16)
scales = torch.rand((12288, 1, 1, 1), device='cuda', dtype=torch.float16)
codebook_partition_sizes = [4096, 4096, 4096]
bias = None
opcheck(torch.ops._C.aqlm_gemm,
(input, codes, codebooks, scales, codebook_partition_sizes, None))
opcheck(torch.ops._C.aqlm_gemm,
(input, codes, codebooks, scales, codebook_partition_sizes, bias))


@@ -205,7 +205,8 @@ def test_paged_attention(
                (output, query, key_cache, value_cache, num_kv_heads, scale,
                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
-               cond=(head_size == HEAD_SIZES[0]))
+               cond=(head_size == HEAD_SIZES[0]
+                     and block_size == BLOCK_SIZES[0]))

    elif version in ("v2", "rocm"):
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
@@ -246,7 +247,8 @@ def test_paged_attention(
                 key_cache, value_cache, num_kv_heads, scale, block_tables,
                 seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
-               cond=(head_size == HEAD_SIZES[0]))
+               cond=(head_size == HEAD_SIZES[0]
+                     and block_size == BLOCK_SIZES[0]))
        else:
            ops.paged_attention_rocm(
@@ -274,7 +276,8 @@ def test_paged_attention(
                 key_cache, value_cache, num_kv_heads, scale, block_tables,
                 seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale),
-               cond=(head_size == HEAD_SIZES[0]))
+               cond=(head_size == HEAD_SIZES[0]
+                     and block_size == BLOCK_SIZES[0]))

    else:
        raise AssertionError(f"Unknown version: {version}")

tests/kernels/test_awq.py (new file, 38 lines)

@@ -0,0 +1,38 @@
import os
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
def test_awq_dequantize_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
qweight = torch.randint(-2000000000,
2000000000, (8192, 256),
device='cuda',
dtype=torch.int32)
scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16)
zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32)
split_k_iters = 0
thx = 0
thy = 0
opcheck(torch.ops._C.awq_dequantize,
(qweight, scales, zeros, split_k_iters, thx, thy))
def test_awq_gemm_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
qweight = torch.randint(-2000000000,
2000000000, (8192, 256),
device='cuda',
dtype=torch.int32)
scales = torch.randint(-2000000000,
2000000000, (64, 256),
device='cuda',
dtype=torch.int32)
qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16)
split_k_iters = 8
opcheck(torch.ops._C.awq_gemm,
(input, qweight, qzeros, scales, split_k_iters))


@@ -5,6 +5,8 @@ import torch
import torch.nn.functional as F
from einops import rearrange

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
from vllm.utils import seed_everything
@@ -84,6 +86,64 @@ def causal_conv1d_update_ref(x: torch.Tensor,
    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
def causal_conv1d_opcheck_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None,
return_final_states: bool = False,
final_states_out=None,
activation: Optional[str] = "silu",
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
seq_idx: (batch, seqlen)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1), to be written to
activation: either None or "silu" or "swish"
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(2) != 1 and x.stride(1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
if seq_idx is not None:
assert (initial_states is
None), "initial_states must be None if seq_idx is not None"
assert (not return_final_states
), "If seq_idx is not None, we don't return final_states_out"
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
if initial_states is not None and (initial_states.stride(2) != 1
and initial_states.stride(1) != 1):
initial_states = initial_states.contiguous()
if return_final_states:
assert (
x.stride(1) == 1
), "Only channel-last layout support returning final_states_out"
if final_states_out is not None:
assert (final_states_out.stride(2) == 1
or final_states_out.stride(1) == 1)
else:
batch, dim, seqlen = x.shape
width = weight.shape[1]
final_states_out = torch.empty(batch,
width - 1,
dim,
device=x.device,
dtype=x.dtype).transpose(1, 2)
else:
final_states_out = None
opcheck(torch.ops._C.causal_conv1d_fwd,
(x, weight, bias, seq_idx, initial_states, final_states_out,
activation in ["silu", "swish"]))
@pytest.mark.parametrize("return_final_states", [False, True]) @pytest.mark.parametrize("return_final_states", [False, True])
@pytest.mark.parametrize("has_initial_states", [False, True]) @pytest.mark.parametrize("has_initial_states", [False, True])
@pytest.mark.parametrize("channel_last", [False, True]) @pytest.mark.parametrize("channel_last", [False, True])
@ -149,6 +209,14 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
initial_states=initial_states_ref, initial_states=initial_states_ref,
return_final_states=return_final_states, return_final_states=return_final_states,
activation=activation) activation=activation)
causal_conv1d_opcheck_fn(x_ref,
weight_ref,
bias_ref,
initial_states=initial_states_ref,
return_final_states=return_final_states,
activation=activation)
    if return_final_states:
        assert final_states is not None and final_states_ref is not None
        assert torch.allclose(final_states,
@@ -205,6 +273,10 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
opcheck(
torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation in ["silu", "swish"], None))
@pytest.mark.parametrize("itype", @pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16]) [torch.float32, torch.float16, torch.bfloat16])
@ -258,7 +330,5 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
bias, bias,
activation=activation) activation=activation)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)


@@ -15,6 +15,9 @@ CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]

def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
@@ -74,6 +77,9 @@ def cutlass_fp8_gemm_helper(m: int,
    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)

    opcheck(torch.ops._C.cutlass_scaled_mm,
            (out, a, b, scale_a, scale_b, bias))

def cutlass_int8_gemm_helper(m: int,
                             n: int,
@@ -425,3 +431,7 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
                        scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)

def test_cutlass_support_opcheck():
    opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, ))


@@ -4,6 +4,7 @@ import pytest
import torch

import vllm.attention.backends.flash_attn  # noqa: F401
from tests.kernels.utils import opcheck
from vllm.utils import seed_everything

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
@@ -127,19 +128,19 @@ def test_flash_attn_with_paged_kv(
    else:
        test_utils = ["test_faketensor"]

-    torch.library.opcheck(torch.ops.vllm.flash_attn_with_kvcache,
+    opcheck(torch.ops.vllm.flash_attn_with_kvcache,
            args=tuple(),
            kwargs=dict(
                decode_query=query.unsqueeze(1),
                key_cache=key_cache,
                value_cache=value_cache,
                softmax_scale=scale,
                causal=True,
                block_table=block_tables,
                cache_seqlens=kv_lens_tensor,
                softcap=soft_cap if soft_cap is not None else 0,
            ),
            test_utils=test_utils)

    ref_output = ref_paged_attn(
        query=query,
@@ -232,23 +233,23 @@ def test_varlen_with_paged_kv(
    else:
        test_utils = ["test_faketensor"]

-    torch.library.opcheck(torch.ops.vllm.flash_attn_varlen_func,
+    opcheck(torch.ops.vllm.flash_attn_varlen_func,
            args=tuple(),
            kwargs=dict(
                q=query,
                k=key_cache,
                v=value_cache,
                cu_seqlens_q=cu_query_lens,
                cu_seqlens_k=cu_kv_lens,
                max_seqlen_q=max_query_len,
                max_seqlen_k=max_kv_len,
                softmax_scale=scale,
                causal=True,
                window_size=window_size,
                block_table=block_tables,
                softcap=soft_cap if soft_cap is not None else 0,
            ),
            test_utils=test_utils)

    ref_output = ref_paged_attn(
        query=query,


@@ -5,6 +5,7 @@ import vllm._custom_ops as ops
from tests.kernels.quant_utils import (FP8_DTYPE,
                                       ref_dynamic_per_tensor_fp8_quant,
                                       ref_dynamic_per_token_quant)
from tests.kernels.utils import opcheck
from vllm.utils import seed_everything

DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -16,6 +17,26 @@ SCALE_UBS = [True, False]
SEEDS = [0]
def opcheck_fp8_quant(output,
input,
scale=None,
scale_ub=None,
use_per_token_if_dynamic=False):
if scale is not None:
opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale))
elif use_per_token_if_dynamic:
scale = torch.empty((input.shape[0], 1),
device=input.device,
dtype=torch.float32)
opcheck(torch.ops._C.dynamic_per_token_scaled_fp8_quant,
(output, input, scale, scale_ub))
else:
scale = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
opcheck(torch.ops._C.dynamic_scaled_fp8_quant, (output, input, scale))
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@ -41,6 +62,12 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
torch.testing.assert_close(ref_out.to(dtype=torch.float32), torch.testing.assert_close(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32)) ops_out.to(dtype=torch.float32))
opcheck_fp8_quant(ops_out,
x,
None,
scale_ub,
use_per_token_if_dynamic=True)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@ -60,6 +87,8 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
torch.testing.assert_close(ref_out.to(dtype=torch.float32), torch.testing.assert_close(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32)) ops_out.to(dtype=torch.float32))
opcheck_fp8_quant(ops_out, x)
# Regression test for a case with large activations where an int32 index cannot
# represent the number of elements.


@@ -0,0 +1,22 @@
import gguf
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
@pytest.mark.parametrize("quant_type", [12])
def test_ggml_opcheck(quant_type):
block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type]
shape = [256, 1152]
qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8)
m = qweight.shape[0]
n = qweight.shape[1] // type_size * block_size
opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n))
x = torch.rand((m, 512), device='cuda', dtype=torch.float16)
opcheck(torch.ops._C.ggml_mul_mat_a8,
(qweight, x, quant_type, qweight.shape[0]))
opcheck(torch.ops._C.ggml_mul_mat_vec_a8,
(qweight, x, quant_type, qweight.shape[0]))


@@ -0,0 +1,29 @@
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
def test_gptq_shuffle_opcheck():
weight = torch.randint(-2000000,
2000000, (1792, 4096),
device='cuda',
dtype=torch.int32)
perm = torch.empty((0, ), device='cuda', dtype=torch.int32)
bit = 4
opcheck(torch.ops._C.gptq_shuffle, (weight, perm, bit))
def test_gptq_gemm_opcheck():
a = torch.rand((240, 4096), device='cuda', dtype=torch.float16)
weight = torch.randint(-2000000,
2000000, (512, 6144),
device='cuda',
dtype=torch.int32)
zeros = torch.zeros((32, 768), device='cuda', dtype=torch.int32)
scales = torch.rand((32, 6144), device='cuda', dtype=torch.float16)
idx = torch.empty((0, ), device='cuda', dtype=torch.int32)
use_exllama = True
bit = 4
opcheck(torch.ops._C.gptq_gemm,
(a, weight, zeros, scales, idx, use_exllama, bit))


@@ -3,6 +3,8 @@ import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_scan_fn, selective_state_update)
from vllm.utils import seed_everything
@@ -161,6 +163,59 @@ def selective_scan_ref(u,
    return out if not return_last_state else (out, last_state)
def selective_scan_opcheck_fn(u,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
return_last_state=False,
position_indices=None,
prev_state=None):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
"""
if u.stride(-1) != 1:
u = u.contiguous()
if delta.stride(-1) != 1:
delta = delta.contiguous()
if D is not None:
D = D.contiguous()
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if z is not None and z.stride(-1) != 1:
z = z.contiguous()
if B.dim() == 3:
B = B.unsqueeze(1)
if C.dim() == 3:
C = C.unsqueeze(1)
n_chunks = int((u.shape[-1] + 2048 - 1) / 2048)
x = torch.zeros((
u.shape[0],
u.shape[1],
n_chunks,
int(A.shape[1] * 2),
),
device=u.device,
dtype=torch.float32,
requires_grad=False)
x[:, :, 0, 0::2] = 1
if prev_state is not None:
x[:, :, 0, 1::2].copy_(prev_state)
# Disable test_autograd_registration for now as it seems to trigger
# a bogus error.
opcheck(torch.ops._C.selective_scan_fwd,
(u, delta, A, B, C, D, z, delta_bias, delta_softplus,
position_indices, x),
test_utils=["test_schema", "test_faketensor"])
@pytest.mark.parametrize('wtype', [torch.float32])
@pytest.mark.parametrize('itype', [torch.float32])
@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
@@ -274,6 +329,17 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
        assert state is not None and state_ref is not None
        assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
selective_scan_opcheck_fn(u,
delta,
A,
B,
C,
D,
z=z,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
return_last_state=return_last_state)
@pytest.mark.parametrize("itype", @pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16]) [torch.float32, torch.float16, torch.bfloat16])


@@ -501,3 +501,18 @@ def test_marlin_qqq_gemm(
    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04
def test_marlin_gemm_opcheck():
size_m = 2048
size_n = 4096
size_k = 4096
a = torch.rand((size_m, size_n), device='cuda', dtype=torch.float16)
w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32)
s = torch.full((32, size_k), 0.125, device='cuda', dtype=torch.float16)
wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL).scratch
x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
torch.testing.assert_close(x, y)
opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k))


@@ -9,11 +9,14 @@ import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    fused_marlin_moe, single_marlin_moe)
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.fused_moe import (
+    fused_topk, moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
@@ -247,6 +250,35 @@ def test_fused_marlin_moe(
    assert compute_max_diff(marlin_output, triton_output) < 4e-2
if ops.supports_moe_ops:
token_expert_indicies = torch.empty(m,
topk,
dtype=torch.int32,
device=a.device)
opcheck(torch.ops._moe_C.topk_softmax, (
topk_weights,
topk_ids,
token_expert_indicies,
score.float(),
))
block_size_m = 4
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m,
e)
max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)
opcheck(torch.ops._moe_C.marlin_gemm_moe,
(a, qweight1, sorted_token_ids, topk_weights, topk_ids,
scales1, g_idx1, sort_indices1, workspace, quant_type, m,
2 * n, k, True, e, topk, block_size_m, True, False))
@pytest.mark.skip("This test is here for the sake of debugging, "
                  "don't run it in automated tests.")
@@ -319,3 +351,29 @@ def test_single_marlin_moe_multiply(
    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

    assert compute_max_diff(marlin_output, torch_output) < 1e-2
def test_moe_align_block_size_opcheck():
num_experts = 4
block_size = 4
topk_ids = torch.randint(0,
num_experts, (3, 4),
dtype=torch.int32,
device='cuda')
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
opcheck(torch.ops._C.moe_align_block_size,
(topk_ids, num_experts, block_size, sorted_ids, expert_ids,
num_tokens_post_pad))


@@ -0,0 +1,62 @@
"""
Tests for miscellaneous utilities
"""
from typing import Optional
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
def rotary_embedding_opcheck(rot,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None):
cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
opcheck(torch.ops._C.batched_rotary_embedding,
(positions, query, key, rot.head_size, cos_sin_cache,
rot.is_neox_style, rot.rotary_dim, offsets))
else:
opcheck(torch.ops._C.rotary_embedding,
(positions, query, key, rot.head_size, cos_sin_cache,
rot.is_neox_style))
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("max_position", [11, 4096, 32768])
@pytest.mark.parametrize("is_neox_style", [True, False])
@pytest.mark.parametrize("rotary_dim", [32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
def test_rotary_embedding_opcheck(dist_init, device, max_position,
is_neox_style, rotary_dim, head_size,
seq_len):
batch_size = 1
base = 0
num_heads = 7
rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, torch.float32)
positions = torch.randint(0,
max_position, (batch_size, seq_len),
device=device)
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=torch.float32,
device=device)
key = torch.randn_like(query)
rotary_embedding_opcheck(rot, positions, query, key)
offsets = torch.zeros(batch_size * seq_len,
device=device,
dtype=torch.long)
rotary_embedding_opcheck(rot, positions, query, key, offsets)


@@ -0,0 +1,24 @@
"""
Tests for miscellaneous utilities
"""
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.platforms import current_platform
def test_convert_fp8_opcheck():
data = torch.randn((256, 256), dtype=torch.float32, device="cuda")
result = torch.empty_like(data, dtype=torch.float8_e4m3fn)
opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8"))
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="Only supported for CUDA")
def test_cuda_utils_opcheck():
opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0))
opcheck(
torch.ops._C_cuda_utils.
get_max_shared_memory_per_block_device_attribute, (0, ))


@@ -2,12 +2,14 @@
import itertools
import random
import unittest
from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
                    Union)

import pytest
import torch
from torch._prims_common import TensorLikeType

from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
@@ -946,6 +948,34 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
                                output_under_test.view_as(ideal_output))
# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
a: TensorLikeType,
b: TensorLikeType,
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
) -> bool:
"""
Reference implementation of torch.allclose
"""
torch._refs._check_close_args(name="torch.allclose",
a=a,
b=b,
rtol=rtol,
atol=atol)
return bool(
torch.all(
torch.isclose(a.double(),
b.double(),
rtol=rtol,
atol=atol,
equal_nan=equal_nan)).item())
# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
                      torch._library.custom_ops.CustomOpDef],
            args: Tuple[Any, ...],
@@ -954,9 +984,10 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
            raise_exception: bool = True,
            cond: bool = True) -> Dict[str, str]:
-    return torch.library.opcheck(
-        op,
-        args,
-        kwargs,
-        test_utils=test_utils,
-        raise_exception=raise_exception) if cond else {}
+    with unittest.mock.patch('torch.allclose', new=fp8_allclose):
+        return torch.library.opcheck(
+            op,
+            args,
+            kwargs,
+            test_utils=test_utils,
+            raise_exception=raise_exception) if cond else {}
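A minimal usage sketch of this wrapper, mirroring the static-quant call made by opcheck_fp8_quant elsewhere in this PR (it assumes a CUDA build of vLLM's _C extension; passing cond=False skips the check and returns {}):

import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401  (loads the _C extension)

# fp8 output buffer plus a single-element scale tensor, as the static
# scaled-fp8 quant kernel expects; the shapes are illustrative.
x = torch.randn(16, 16, device="cuda", dtype=torch.float16)
out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
scale = torch.ones(1, device="cuda", dtype=torch.float32)

# Runs the restricted default test_utils with torch.allclose patched to
# fp8_allclose so the fp8 results can be compared.
opcheck(torch.ops._C.static_scaled_fp8_quant, (out, x, scale),
        cond=(x.shape[0] == 16))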


@@ -20,8 +20,10 @@ if not current_platform.is_tpu():
if current_platform.is_rocm():
    import vllm._rocm_C  # noqa: F401

supports_moe_ops = False
with contextlib.suppress(ImportError):
    import vllm._moe_C  # noqa: F401
    supports_moe_ops = True

def hint_on_error(fn):
@@ -253,9 +255,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                                 b_g_idx, use_exllama, bit)

-# TODO: has to be a better way to do this
-try:
-    torch.ops._C.gptq_gemm  # noqa B018
+if hasattr(torch.ops._C, "gptq_gemm"):

    @torch.library.register_fake("_C::gptq_gemm")
    def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
@@ -265,8 +265,6 @@ try:
        return torch.empty((a.size(0), b_q_weight.size(1)),
                           dtype=a.dtype,
                           device=a.device)
-except Exception:
-    pass

def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
@@ -292,9 +290,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                                            size_n, size_k)

-# TODO: has to be a better way to do this
-try:
-    torch.ops._C.gptq_marlin_24_gemm  # noqa B018
+if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):

    @torch.library.register_fake("_C::gptq_marlin_24_gemm")
    def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
@@ -420,8 +416,8 @@ try:
    @torch.library.register_fake("_C::machete_gemm")
    def machete_gemm_fake(
            a: torch.Tensor,
-           b_q: torch.
-           Tensor,  # Should be the tensor returned by machete_prepack_B
+           # Should be the tensor returned by machete_prepack_B
+           b_q: torch.Tensor,
            b_type: ScalarType,
            b_scales: Optional[torch.Tensor] = None,
            b_zeros: Optional[torch.Tensor] = None,
@@ -451,10 +447,10 @@ try:
        return torch.empty_like(x)

    @torch.library.register_fake("_C::causal_conv1d_update")
-    def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
-                                  weight: torch.Tensor,
-                                  bias_: Optional[torch.Tensor],
-                                  silu_activation: bool) -> torch.Tensor:
+    def causal_conv1d_update_fake(
+            x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
+            bias_: Optional[torch.Tensor], silu_activation: bool,
+            conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
        return torch.empty_like(x)

    @torch.library.register_fake("_C::selective_scan_fwd")
@@ -465,20 +461,11 @@ try:
                            delta_softplus: bool, index_: Optional[torch.Tensor],
                            x: Optional[torch.Tensor]) -> List[torch.Tensor]:
        a = torch.empty_like(u)
-        if x is not None:
-            b = x
-        else:
-            b = torch.empty((u.size(0), u.size(1), A.size(1)),
-                            dtype=u.dtype,
-                            device=u.device)
        if z_ is not None:
            c = torch.empty_like(z_)
-            return [a, b, c]
+            return [a, c]
        else:
-            return [a, b]
+            return [a]
-except Exception:
-    pass

# cutlass
@@ -626,16 +613,12 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
    return torch.ops._C.machete_prepack_B(b_q_weight, b_type)

-# TODO: has to be a better way to do this
-try:
-    torch.ops._C.permute_cols  # noqa B018
+if hasattr(torch.ops._C, "permute_cols"):

    @torch.library.register_fake("_C::permute_cols")
    def _permute_cols_fake(a: torch.Tensor,
                           perm: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(a)
-except Exception:
-    pass

def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
@@ -828,6 +811,24 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies, gating_output)
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
@torch.library.register_fake("_moe_C::marlin_gemm_moe")
def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
sorted_ids: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor, b_scales: torch.Tensor,
g_idx: torch.Tensor, perm: torch.Tensor,
workspace: torch.Tensor, b_q_type: ScalarType,
size_m: int, size_n: int, size_k: int,
is_k_full: bool, num_experts: int, topk: int,
moe_block_size: int, replicate_input: bool,
apply_weights: bool) -> torch.Tensor:
return torch.empty((size_m, topk, size_n),
dtype=a.dtype,
device=a.device)
def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
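The hasattr guards above register a fake (meta) kernel only when the corresponding op was compiled into the extension, which is what lets FakeTensor tracing (and hence Dynamo/torch.compile and opcheck's test_faketensor) propagate shapes without launching CUDA kernels. A small sketch of that effect, with shapes borrowed from the gptq opcheck test in this PR (assumes a CUDA-enabled build of vLLM):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

from vllm import _custom_ops as ops  # noqa: F401  (loads the _C extension)

# Under FakeTensorMode the registered _gptq_gemm_fake runs instead of the CUDA
# kernel, so only shapes and dtypes are propagated.
with FakeTensorMode():
    a = torch.empty(240, 4096, dtype=torch.float16, device="cuda")
    weight = torch.empty(512, 6144, dtype=torch.int32, device="cuda")
    zeros = torch.empty(32, 768, dtype=torch.int32, device="cuda")
    scales = torch.empty(32, 6144, dtype=torch.float16, device="cuda")
    idx = torch.empty(0, dtype=torch.int32, device="cuda")
    out = torch.ops._C.gptq_gemm(a, weight, zeros, scales, idx, True, 4)
    assert out.shape == (240, 6144)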


@@ -361,8 +361,8 @@ def selective_scan_fn(u,
    x[:, :, 0, 0::2] = 1
    if prev_state is not None:
        x[:, :, 0, 1::2].copy_(prev_state)
-    out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
-                                           delta_softplus, position_indices, x)
+    out, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
+                                        delta_softplus, position_indices, x)
    last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
    if z is None:
        return out if not return_last_state else (out, last_state)


@@ -217,6 +217,7 @@ class GPTQLinearMethod(LinearMethodBase):
        layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
        layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
        layer.scales = Parameter(layer.scales.data, requires_grad=False)

        # exllama needs to shuffle the weight after the weight is loaded
        # here we do the shuffle on first forward pass