[Kernel] Fullgraph and opcheck tests (#8479)
This commit is contained in: parent 1c046447a6, commit 300da09177
@@ -210,6 +210,21 @@ steps:
  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
  parallelism: 4

- label: "PyTorch Fullgraph Smoke Test"
  fast_check: true
  source_file_dependencies:
  - vllm/
  - tests/compile
  commands:
  - pytest -v -s compile/test_full_graph_smoke.py

- label: "PyTorch Fullgraph Test"
  source_file_dependencies:
  - vllm/
  - tests/compile
  commands:
  - pytest -v -s compile/test_full_graph.py

- label: Kernels Test %N # 30min each
  mirror_hardwares: [amd]
  source_file_dependencies:
@@ -355,7 +370,7 @@ steps:
  - tests/distributed/
  - vllm/compilation
  commands:
  - pytest -v -s ./compile/test_full_graph.py
  - pytest -v -s ./compile/test_full_graph_multi_gpu.py
  - pytest -v -s ./compile/test_wrapper.py
  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
  - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
@@ -586,7 +586,7 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
  DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
    selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
  });
  std::vector<at::Tensor> result = {out, x.value()};
  std::vector<at::Tensor> result = {out};
  if (has_z) { result.push_back(out_z); }
  return result;
}
@@ -275,7 +275,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      "Tensor! A, Tensor! B, Tensor! C,"
      "Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
      "bool delta_softplus,"
      "Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
      "Tensor? index_, Tensor!? x) -> Tensor[]");
  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

  ops.def(
@@ -292,7 +292,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      "Tensor? bias_,"
      "Tensor? seq_idx_,"
      "Tensor? initial_states_,"
      "Tensor? final_states_out_,"
      "Tensor!? final_states_out_,"
      "bool silu_activation) -> Tensor");
  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#endif
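Note: the `Tensor!` / `Tensor!?` markers in these schema strings declare which arguments the kernel mutates in place, and torch.library.opcheck's schema test verifies those declarations against the kernel's actual behavior. A minimal, hedged sketch of the same idea with a made-up op (mylib::scale_add_ is invented for illustration, not part of vLLM):

import torch

# Hypothetical op: `Tensor! out` promises that the kernel writes into `out`.
torch.library.define("mylib::scale_add_", "(Tensor! out, Tensor x, float alpha) -> ()")

@torch.library.impl("mylib::scale_add_", "cpu")
def _scale_add_(out: torch.Tensor, x: torch.Tensor, alpha: float) -> None:
    out.add_(x, alpha=alpha)  # in-place write, as promised by the schema

out = torch.zeros(4)
# Only the schema check is run here; a mismatch between the annotation and the
# actual mutation behavior would be reported as an error.
torch.library.opcheck(torch.ops.mylib.scale_add_, (out, torch.ones(4), 2.0),
                      test_utils=("test_schema", ))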
@@ -1,42 +1,13 @@
import os

import pytest

from vllm.utils import cuda_device_count_stateless
from vllm.compilation.backends import vllm_backend

from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS, check_full_graph_support


@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
@pytest.mark.parametrize("tp_size", [1, 2])
@fork_new_process_for_each_test
def test_full_graph(model, tp_size):

    # Skip the test if there are not enough CUDA devices.
    if cuda_device_count_stateless() < tp_size:
        pytest.skip("Not enough CUDA devices for the test.")

    # make sure these models can be captured in full graph mode
    if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
        os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"

    from vllm import LLM, SamplingParams
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0)
    llm = LLM(model=model,
              enforce_eager=True,
              tensor_parallel_size=tp_size,
              disable_custom_all_reduce=True)

    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.parametrize("model_info", TEST_MODELS)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
    model = model_info[0]
    model_kwargs = model_info[1]
    check_full_graph_support(model, model_kwargs, backend, tp_size=1)
tests/compile/test_full_graph_multi_gpu.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import pytest

from vllm.compilation.backends import vllm_backend
from vllm.utils import cuda_device_count_stateless

from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS_SMOKE, check_full_graph_support


@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
@fork_new_process_for_each_test
def test_full_graph_multi_gpu(model_info, tp_size, backend):
    model = model_info[0]
    model_kwargs = model_info[1]

    # Skip the test if there are not enough CUDA devices.
    if cuda_device_count_stateless() < tp_size:
        pytest.skip("Not enough CUDA devices for the test.")

    check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size)
tests/compile/test_full_graph_smoke.py (new file, 13 lines)
@@ -0,0 +1,13 @@
import pytest

from vllm.compilation.backends import vllm_backend

from .utils import TEST_MODELS_SMOKE, check_full_graph_support


@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
    model = model_info[0]
    model_kwargs = model_info[1]
    check_full_graph_support(model, model_kwargs, backend, tp_size=1)
tests/compile/utils.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import os

import torch

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.plugins import set_torch_compile_backend
from vllm.utils import is_hip

TEST_MODELS_SMOKE = [
    ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
        "quantization": "compressed-tensors"
    }),
    ("meta-llama/Meta-Llama-3-8B", {}),
]

TEST_MODELS = [
    ("facebook/opt-125m", {}),
    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
        "dtype": torch.float16,
        "quantization": "compressed-tensors"
    }),
    ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", {
        "dtype": torch.float16,
        "quantization": "fp8"
    }),
    ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
        "quantization": "compressed-tensors"
    }),
    ("meta-llama/Meta-Llama-3-8B", {}),
]

# TODO: enable in pytorch 2.5
if False and is_quant_method_supported("aqlm"):  # noqa: SIM223
    TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
        "quantization": "aqlm"
    }))

# TODO: enable in pytorch 2.5
if False and is_quant_method_supported("gguf"):  # noqa: SIM223
    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
        "quantization": "gguf"
    }))

if is_quant_method_supported("gptq"):
    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
        "quantization": "gptq"
    }))

if is_quant_method_supported("gptq_marlin"):
    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
        "quantization": "gptq_marlin"
    }))

if is_quant_method_supported("gptq_marlin_24"):
    TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
        "quantization": "gptq_marlin_24"
    }))

if is_quant_method_supported("marlin"):
    TEST_MODELS.append(("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
        "quantization": "marlin"
    }))

if not is_hip() and is_quant_method_supported("awq"):
    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
        "quantization": "AWQ"
    }))


def check_full_graph_support(model, model_kwargs, backend, tp_size=1):
    # make sure these models can be captured in full graph mode
    if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
        os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
        os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

    # Inductor doesn't support fp8/gptq_marlin_24 yet.
    quantization = model_kwargs.get("quantization")
    if (quantization == "fp8" or quantization == "gptq_marlin"
            or quantization == "gptq_marlin_24") and backend != "eager":
        return

    set_torch_compile_backend(backend)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0)
    llm = LLM(model=model,
              enforce_eager=True,
              tensor_parallel_size=tp_size,
              disable_custom_all_reduce=True,
              **model_kwargs)

    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
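Note: "full graph" capture here means Dynamo must trace the whole forward pass as a single FX graph, with no graph breaks. A minimal, self-contained illustration of that idea with plain torch.compile (not vLLM's wrapper or its environment flags):

import torch

def forward(x: torch.Tensor) -> torch.Tensor:
    # A data-dependent Python branch such as `if x.sum() > 0:` would force a
    # graph break; with fullgraph=True that becomes a hard error instead of a
    # silent fallback to eager execution.
    return torch.relu(x) * 2.0

compiled = torch.compile(forward, fullgraph=True, backend="eager")
print(compiled(torch.randn(4)))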
@@ -169,6 +169,12 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    cleanup()


@pytest.fixture(autouse=True)
def dynamo_reset():
    yield
    torch._dynamo.reset()


@pytest.fixture
def example_prompts() -> List[str]:
    prompts = []
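Note: the autouse fixture above clears Dynamo's compilation state after every test so one test's captured graphs and backend settings cannot leak into the next. Outside of pytest, the same reset looks like this (a hedged sketch, not taken from the diff):

import torch

compiled = torch.compile(torch.sin, backend="eager")
compiled(torch.randn(3))   # populates Dynamo's code caches and guards

torch._dynamo.reset()      # drop cached graphs/guards so later runs recompile from scratch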
tests/kernels/test_aqlm.py (new file, 37 lines)
@@ -0,0 +1,37 @@
import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401


def test_aqlm_dequant_opcheck():
    codes = torch.randint(-32768,
                          32767, (22016, 512, 1),
                          device='cuda',
                          dtype=torch.int16)
    codebooks = torch.rand((2, 65536, 1, 8),
                           device='cuda',
                           dtype=torch.float16)
    codebook_partition_sizes = [11008, 11008]

    opcheck(torch.ops._C.aqlm_dequant,
            (codes, codebooks, codebook_partition_sizes))


def test_aqlm_gemm_opcheck():
    input = torch.rand((4, 4096), device='cuda', dtype=torch.float16)
    codes = torch.randint(-32768,
                          32767, (12288, 512, 1),
                          device='cuda',
                          dtype=torch.int16)
    codebooks = torch.rand((3, 65536, 1, 8),
                           device='cuda',
                           dtype=torch.float16)
    scales = torch.rand((12288, 1, 1, 1), device='cuda', dtype=torch.float16)
    codebook_partition_sizes = [4096, 4096, 4096]
    bias = None

    opcheck(torch.ops._C.aqlm_gemm,
            (input, codes, codebooks, scales, codebook_partition_sizes, None))
    opcheck(torch.ops._C.aqlm_gemm,
            (input, codes, codebooks, scales, codebook_partition_sizes, bias))
@@ -205,7 +205,8 @@ def test_paged_attention(
            (output, query, key_cache, value_cache, num_kv_heads, scale,
             block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
             kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
            cond=(head_size == HEAD_SIZES[0]))
            cond=(head_size == HEAD_SIZES[0]
                  and block_size == BLOCK_SIZES[0]))

    elif version in ("v2", "rocm"):
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
@@ -246,7 +247,8 @@ def test_paged_attention(
                 key_cache, value_cache, num_kv_heads, scale, block_tables,
                 seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]))
                cond=(head_size == HEAD_SIZES[0]
                      and block_size == BLOCK_SIZES[0]))

        else:
            ops.paged_attention_rocm(
@@ -274,7 +276,8 @@ def test_paged_attention(
                 key_cache, value_cache, num_kv_heads, scale, block_tables,
                 seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale),
                cond=(head_size == HEAD_SIZES[0]))
                cond=(head_size == HEAD_SIZES[0]
                      and block_size == BLOCK_SIZES[0]))

    else:
        raise AssertionError(f"Unknown version: {version}")
tests/kernels/test_awq.py (new file, 38 lines)
@@ -0,0 +1,38 @@
import os

import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401


def test_awq_dequantize_opcheck():
    os.environ["VLLM_USE_TRITON_AWQ"] = "0"
    qweight = torch.randint(-2000000000,
                            2000000000, (8192, 256),
                            device='cuda',
                            dtype=torch.int32)
    scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16)
    zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32)
    split_k_iters = 0
    thx = 0
    thy = 0
    opcheck(torch.ops._C.awq_dequantize,
            (qweight, scales, zeros, split_k_iters, thx, thy))


def test_awq_gemm_opcheck():
    os.environ["VLLM_USE_TRITON_AWQ"] = "0"
    input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
    qweight = torch.randint(-2000000000,
                            2000000000, (8192, 256),
                            device='cuda',
                            dtype=torch.int32)
    scales = torch.randint(-2000000000,
                           2000000000, (64, 256),
                           device='cuda',
                           dtype=torch.int32)
    qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16)
    split_k_iters = 8
    opcheck(torch.ops._C.awq_gemm,
            (input, qweight, qzeros, scales, split_k_iters))
@@ -5,6 +5,8 @@ import torch
import torch.nn.functional as F
from einops import rearrange

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
from vllm.utils import seed_everything
@@ -84,6 +86,64 @@ def causal_conv1d_update_ref(x: torch.Tensor,
    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)


def causal_conv1d_opcheck_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out=None,
    activation: Optional[str] = "silu",
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(2) != 1 and x.stride(1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None
    if seq_idx is not None:
        assert (initial_states is
                None), "initial_states must be None if seq_idx is not None"
        assert (not return_final_states
                ), "If seq_idx is not None, we don't return final_states_out"
    seq_idx = seq_idx.contiguous() if seq_idx is not None else None
    if initial_states is not None and (initial_states.stride(2) != 1
                                       and initial_states.stride(1) != 1):
        initial_states = initial_states.contiguous()
    if return_final_states:
        assert (
            x.stride(1) == 1
        ), "Only channel-last layout support returning final_states_out"
        if final_states_out is not None:
            assert (final_states_out.stride(2) == 1
                    or final_states_out.stride(1) == 1)
        else:
            batch, dim, seqlen = x.shape
            width = weight.shape[1]
            final_states_out = torch.empty(batch,
                                           width - 1,
                                           dim,
                                           device=x.device,
                                           dtype=x.dtype).transpose(1, 2)
    else:
        final_states_out = None

    opcheck(torch.ops._C.causal_conv1d_fwd,
            (x, weight, bias, seq_idx, initial_states, final_states_out,
             activation in ["silu", "swish"]))


@pytest.mark.parametrize("return_final_states", [False, True])
@pytest.mark.parametrize("has_initial_states", [False, True])
@pytest.mark.parametrize("channel_last", [False, True])
@@ -149,6 +209,14 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
                                         initial_states=initial_states_ref,
                                         return_final_states=return_final_states,
                                         activation=activation)

    causal_conv1d_opcheck_fn(x_ref,
                             weight_ref,
                             bias_ref,
                             initial_states=initial_states_ref,
                             return_final_states=return_final_states,
                             activation=activation)

    if return_final_states:
        assert final_states is not None and final_states_ref is not None
        assert torch.allclose(final_states,
@@ -205,6 +273,10 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

    opcheck(
        torch.ops._C.causal_conv1d_update,
        (x, conv_state, weight, bias, activation in ["silu", "swish"], None))


@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@@ -258,7 +330,5 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
                           bias,
                           activation=activation)

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@@ -15,6 +15,9 @@ CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
@@ -74,6 +77,9 @@ def cutlass_fp8_gemm_helper(m: int,

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)

    opcheck(torch.ops._C.cutlass_scaled_mm,
            (out, a, b, scale_a, scale_b, bias))


def cutlass_int8_gemm_helper(m: int,
                             n: int,
@@ -425,3 +431,7 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
                        scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


def test_cutlass_support_opcheck():
    opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, ))
@@ -4,6 +4,7 @@ import pytest
import torch

import vllm.attention.backends.flash_attn  # noqa: F401
from tests.kernels.utils import opcheck
from vllm.utils import seed_everything

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
@@ -127,19 +128,19 @@ def test_flash_attn_with_paged_kv(
    else:
        test_utils = ["test_faketensor"]

    torch.library.opcheck(torch.ops.vllm.flash_attn_with_kvcache,
                          args=tuple(),
                          kwargs=dict(
                              decode_query=query.unsqueeze(1),
                              key_cache=key_cache,
                              value_cache=value_cache,
                              softmax_scale=scale,
                              causal=True,
                              block_table=block_tables,
                              cache_seqlens=kv_lens_tensor,
                              softcap=soft_cap if soft_cap is not None else 0,
                          ),
                          test_utils=test_utils)
    opcheck(torch.ops.vllm.flash_attn_with_kvcache,
            args=tuple(),
            kwargs=dict(
                decode_query=query.unsqueeze(1),
                key_cache=key_cache,
                value_cache=value_cache,
                softmax_scale=scale,
                causal=True,
                block_table=block_tables,
                cache_seqlens=kv_lens_tensor,
                softcap=soft_cap if soft_cap is not None else 0,
            ),
            test_utils=test_utils)

    ref_output = ref_paged_attn(
        query=query,
@@ -232,23 +233,23 @@ def test_varlen_with_paged_kv(
    else:
        test_utils = ["test_faketensor"]

    torch.library.opcheck(torch.ops.vllm.flash_attn_varlen_func,
                          args=tuple(),
                          kwargs=dict(
                              q=query,
                              k=key_cache,
                              v=value_cache,
                              cu_seqlens_q=cu_query_lens,
                              cu_seqlens_k=cu_kv_lens,
                              max_seqlen_q=max_query_len,
                              max_seqlen_k=max_kv_len,
                              softmax_scale=scale,
                              causal=True,
                              window_size=window_size,
                              block_table=block_tables,
                              softcap=soft_cap if soft_cap is not None else 0,
                          ),
                          test_utils=test_utils)
    opcheck(torch.ops.vllm.flash_attn_varlen_func,
            args=tuple(),
            kwargs=dict(
                q=query,
                k=key_cache,
                v=value_cache,
                cu_seqlens_q=cu_query_lens,
                cu_seqlens_k=cu_kv_lens,
                max_seqlen_q=max_query_len,
                max_seqlen_k=max_kv_len,
                softmax_scale=scale,
                causal=True,
                window_size=window_size,
                block_table=block_tables,
                softcap=soft_cap if soft_cap is not None else 0,
            ),
            test_utils=test_utils)

    ref_output = ref_paged_attn(
        query=query,
@@ -5,6 +5,7 @@ import vllm._custom_ops as ops
from tests.kernels.quant_utils import (FP8_DTYPE,
                                       ref_dynamic_per_tensor_fp8_quant,
                                       ref_dynamic_per_token_quant)
from tests.kernels.utils import opcheck
from vllm.utils import seed_everything

DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -16,6 +17,26 @@ SCALE_UBS = [True, False]
SEEDS = [0]


def opcheck_fp8_quant(output,
                      input,
                      scale=None,
                      scale_ub=None,
                      use_per_token_if_dynamic=False):
    if scale is not None:
        opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale))
    elif use_per_token_if_dynamic:
        scale = torch.empty((input.shape[0], 1),
                            device=input.device,
                            dtype=torch.float32)
        opcheck(torch.ops._C.dynamic_per_token_scaled_fp8_quant,
                (output, input, scale, scale_ub))
    else:
        scale = torch.empty((input.numel() // input.shape[-1], 1),
                            device=input.device,
                            dtype=torch.float32)
        opcheck(torch.ops._C.dynamic_scaled_fp8_quant, (output, input, scale))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@@ -41,6 +62,12 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
    torch.testing.assert_close(ref_out.to(dtype=torch.float32),
                               ops_out.to(dtype=torch.float32))

    opcheck_fp8_quant(ops_out,
                      x,
                      None,
                      scale_ub,
                      use_per_token_if_dynamic=True)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@@ -60,6 +87,8 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
    torch.testing.assert_close(ref_out.to(dtype=torch.float32),
                               ops_out.to(dtype=torch.float32))

    opcheck_fp8_quant(ops_out, x)


# Regression test for a case with large activations where an int32 index cannot
# represent the number of elements.
tests/kernels/test_ggml.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import gguf
import pytest
import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401


@pytest.mark.parametrize("quant_type", [12])
def test_ggml_opcheck(quant_type):
    block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type]
    shape = [256, 1152]
    qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8)
    m = qweight.shape[0]
    n = qweight.shape[1] // type_size * block_size
    opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n))

    x = torch.rand((m, 512), device='cuda', dtype=torch.float16)
    opcheck(torch.ops._C.ggml_mul_mat_a8,
            (qweight, x, quant_type, qweight.shape[0]))
    opcheck(torch.ops._C.ggml_mul_mat_vec_a8,
            (qweight, x, quant_type, qweight.shape[0]))
tests/kernels/test_gptq.py (new file, 29 lines)
@@ -0,0 +1,29 @@
import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401


def test_gptq_shuffle_opcheck():
    weight = torch.randint(-2000000,
                           2000000, (1792, 4096),
                           device='cuda',
                           dtype=torch.int32)
    perm = torch.empty((0, ), device='cuda', dtype=torch.int32)
    bit = 4
    opcheck(torch.ops._C.gptq_shuffle, (weight, perm, bit))


def test_gptq_gemm_opcheck():
    a = torch.rand((240, 4096), device='cuda', dtype=torch.float16)
    weight = torch.randint(-2000000,
                           2000000, (512, 6144),
                           device='cuda',
                           dtype=torch.int32)
    zeros = torch.zeros((32, 768), device='cuda', dtype=torch.int32)
    scales = torch.rand((32, 6144), device='cuda', dtype=torch.float16)
    idx = torch.empty((0, ), device='cuda', dtype=torch.int32)
    use_exllama = True
    bit = 4
    opcheck(torch.ops._C.gptq_gemm,
            (a, weight, zeros, scales, idx, use_exllama, bit))
@@ -3,6 +3,8 @@ import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_scan_fn, selective_state_update)
from vllm.utils import seed_everything
@@ -161,6 +163,59 @@ def selective_scan_ref(u,
    return out if not return_last_state else (out, last_state)


def selective_scan_opcheck_fn(u,
                              delta,
                              A,
                              B,
                              C,
                              D=None,
                              z=None,
                              delta_bias=None,
                              delta_softplus=False,
                              return_last_state=False,
                              position_indices=None,
                              prev_state=None):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate).
    """
    if u.stride(-1) != 1:
        u = u.contiguous()
    if delta.stride(-1) != 1:
        delta = delta.contiguous()
    if D is not None:
        D = D.contiguous()
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if z is not None and z.stride(-1) != 1:
        z = z.contiguous()
    if B.dim() == 3:
        B = B.unsqueeze(1)
    if C.dim() == 3:
        C = C.unsqueeze(1)
    n_chunks = int((u.shape[-1] + 2048 - 1) / 2048)
    x = torch.zeros((
        u.shape[0],
        u.shape[1],
        n_chunks,
        int(A.shape[1] * 2),
    ),
                    device=u.device,
                    dtype=torch.float32,
                    requires_grad=False)
    x[:, :, 0, 0::2] = 1
    if prev_state is not None:
        x[:, :, 0, 1::2].copy_(prev_state)

    # Disable test_autograd_registration for now as it seems to trigger
    # a bogus error.
    opcheck(torch.ops._C.selective_scan_fwd,
            (u, delta, A, B, C, D, z, delta_bias, delta_softplus,
             position_indices, x),
            test_utils=["test_schema", "test_faketensor"])


@pytest.mark.parametrize('wtype', [torch.float32])
@pytest.mark.parametrize('itype', [torch.float32])
@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
@@ -274,6 +329,17 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
    assert state is not None and state_ref is not None
    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)

    selective_scan_opcheck_fn(u,
                              delta,
                              A,
                              B,
                              C,
                              D,
                              z=z,
                              delta_bias=delta_bias,
                              delta_softplus=delta_softplus,
                              return_last_state=return_last_state)


@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@@ -501,3 +501,18 @@ def test_marlin_qqq_gemm(
    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


def test_marlin_gemm_opcheck():
    size_m = 2048
    size_n = 4096
    size_k = 4096
    a = torch.rand((size_m, size_n), device='cuda', dtype=torch.float16)
    w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32)
    s = torch.full((32, size_k), 0.125, device='cuda', dtype=torch.float16)
    wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                         GPTQ_MARLIN_MAX_PARALLEL).scratch
    x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
    y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
    torch.testing.assert_close(x, y)
    opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k))
@@ -9,11 +9,14 @@ import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk, moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
@@ -247,6 +250,35 @@ def test_fused_marlin_moe(

    assert compute_max_diff(marlin_output, triton_output) < 4e-2

    if ops.supports_moe_ops:
        token_expert_indicies = torch.empty(m,
                                            topk,
                                            dtype=torch.int32,
                                            device=a.device)

        opcheck(torch.ops._moe_C.topk_softmax, (
            topk_weights,
            topk_ids,
            token_expert_indicies,
            score.float(),
        ))

        block_size_m = 4

        sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m,
                                                      e)

        max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                device="cuda",
                                requires_grad=False)

        opcheck(torch.ops._moe_C.marlin_gemm_moe,
                (a, qweight1, sorted_token_ids, topk_weights, topk_ids,
                 scales1, g_idx1, sort_indices1, workspace, quant_type, m,
                 2 * n, k, True, e, topk, block_size_m, True, False))


@pytest.mark.skip("This test is here for the sake of debugging, "
                  "don't run it in automated tests.")
@@ -319,3 +351,29 @@ def test_single_marlin_moe_multiply(
    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

    assert compute_max_diff(marlin_output, torch_output) < 1e-2


def test_moe_align_block_size_opcheck():
    num_experts = 4
    block_size = 4
    topk_ids = torch.randint(0,
                             num_experts, (3, 4),
                             dtype=torch.int32,
                             device='cuda')

    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty((max_num_tokens_padded, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty((max_num_m_blocks, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    num_tokens_post_pad = torch.empty((1),
                                      dtype=torch.int32,
                                      device=topk_ids.device)

    opcheck(torch.ops._C.moe_align_block_size,
            (topk_ids, num_experts, block_size, sorted_ids, expert_ids,
             num_tokens_post_pad))
tests/kernels/test_rotary_embedding.py (new file, 62 lines)
@@ -0,0 +1,62 @@
"""
Tests for miscellaneous utilities
"""

from typing import Optional

import pytest
import torch

from tests.kernels.utils import opcheck
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding


def rotary_embedding_opcheck(rot,
                             positions: torch.Tensor,
                             query: torch.Tensor,
                             key: torch.Tensor,
                             offsets: Optional[torch.Tensor] = None):
    cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)

    # ops.rotary_embedding()/batched_rotary_embedding()
    # are in-place operations that update the query and key tensors.
    if offsets is not None:
        opcheck(torch.ops._C.batched_rotary_embedding,
                (positions, query, key, rot.head_size, cos_sin_cache,
                 rot.is_neox_style, rot.rotary_dim, offsets))
    else:
        opcheck(torch.ops._C.rotary_embedding,
                (positions, query, key, rot.head_size, cos_sin_cache,
                 rot.is_neox_style))


@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("max_position", [11, 4096, 32768])
@pytest.mark.parametrize("is_neox_style", [True, False])
@pytest.mark.parametrize("rotary_dim", [32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
def test_rotary_embedding_opcheck(dist_init, device, max_position,
                                  is_neox_style, rotary_dim, head_size,
                                  seq_len):
    batch_size = 1
    base = 0
    num_heads = 7
    rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                          is_neox_style, torch.float32)

    positions = torch.randint(0,
                              max_position, (batch_size, seq_len),
                              device=device)
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=torch.float32,
                        device=device)
    key = torch.randn_like(query)

    rotary_embedding_opcheck(rot, positions, query, key)
    offsets = torch.zeros(batch_size * seq_len,
                          device=device,
                          dtype=torch.long)
    rotary_embedding_opcheck(rot, positions, query, key, offsets)
tests/kernels/test_utils.py (new file, 24 lines)
@@ -0,0 +1,24 @@
"""
Tests for miscellaneous utilities
"""

import pytest
import torch

from tests.kernels.utils import opcheck
from vllm.platforms import current_platform


def test_convert_fp8_opcheck():
    data = torch.randn((256, 256), dtype=torch.float32, device="cuda")
    result = torch.empty_like(data, dtype=torch.float8_e4m3fn)
    opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8"))


@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="Only supported for CUDA")
def test_cuda_utils_opcheck():
    opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0))
    opcheck(
        torch.ops._C_cuda_utils.
        get_max_shared_memory_per_block_device_attribute, (0, ))
@@ -2,12 +2,14 @@

import itertools
import random
import unittest
from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
                    Union)

import pytest
import torch
from torch._prims_common import TensorLikeType

from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
@@ -946,6 +948,34 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
                                output_under_test.view_as(ideal_output))


# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose
    """
    torch._refs._check_close_args(name="torch.allclose",
                                  a=a,
                                  b=b,
                                  rtol=rtol,
                                  atol=atol)

    return bool(
        torch.all(
            torch.isclose(a.double(),
                          b.double(),
                          rtol=rtol,
                          atol=atol,
                          equal_nan=equal_nan)).item())


# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
                      torch._library.custom_ops.CustomOpDef],
            args: Tuple[Any, ...],
@@ -954,9 +984,10 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
            raise_exception: bool = True,
            cond: bool = True) -> Dict[str, str]:
    return torch.library.opcheck(
        op,
        args,
        kwargs,
        test_utils=test_utils,
        raise_exception=raise_exception) if cond else {}
    with unittest.mock.patch('torch.allclose', new=fp8_allclose):
        return torch.library.opcheck(
            op,
            args,
            kwargs,
            test_utils=test_utils,
            raise_exception=raise_exception) if cond else {}
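Note: the wrapper above keeps torch.library.opcheck's interface but adds a `cond` gate and patches torch.allclose so fp8 outputs can be compared. A typical call site, shown as a hedged sketch (it assumes a CUDA build of vLLM so torch.ops._C is populated; the tensor shapes are copied from test_gptq.py above):

import torch

from tests.kernels.utils import opcheck

weight = torch.randint(-2000000, 2000000, (1792, 4096),
                       device="cuda", dtype=torch.int32)
perm = torch.empty((0, ), device="cuda", dtype=torch.int32)

# `test_utils` narrows which opcheck sub-tests run, and a falsy `cond` turns
# the call into a no-op (it returns {}), which the attention tests use to
# check only one representative parameter combination per sweep.
opcheck(torch.ops._C.gptq_shuffle, (weight, perm, 4),
        test_utils=["test_schema", "test_faketensor"],
        cond=(weight.shape[0] == 1792))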
@@ -20,8 +20,10 @@ if not current_platform.is_tpu():
if current_platform.is_rocm():
    import vllm._rocm_C  # noqa: F401

supports_moe_ops = False
with contextlib.suppress(ImportError):
    import vllm._moe_C  # noqa: F401
    supports_moe_ops = True


def hint_on_error(fn):
@@ -253,9 +255,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                                 b_g_idx, use_exllama, bit)


# TODO: has to be a better way to do this
try:
    torch.ops._C.gptq_gemm  # noqa B018
if hasattr(torch.ops._C, "gptq_gemm"):

    @torch.library.register_fake("_C::gptq_gemm")
    def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
@@ -265,8 +265,6 @@ try:
        return torch.empty((a.size(0), b_q_weight.size(1)),
                           dtype=a.dtype,
                           device=a.device)
except Exception:
    pass


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
@@ -292,9 +290,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        size_n, size_k)


# TODO: has to be a better way to do this
try:
    torch.ops._C.gptq_marlin_24_gemm  # noqa B018
if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):

    @torch.library.register_fake("_C::gptq_marlin_24_gemm")
    def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
@@ -420,8 +416,8 @@ try:
    @torch.library.register_fake("_C::machete_gemm")
    def machete_gemm_fake(
            a: torch.Tensor,
            b_q: torch.
        Tensor,  # Should be the tensor returned by machete_prepack_B
            # Should be the tensor returned by machete_prepack_B
            b_q: torch.Tensor,
            b_type: ScalarType,
            b_scales: Optional[torch.Tensor] = None,
            b_zeros: Optional[torch.Tensor] = None,
@@ -451,10 +447,10 @@ try:
        return torch.empty_like(x)

    @torch.library.register_fake("_C::causal_conv1d_update")
    def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
                                  weight: torch.Tensor,
                                  bias_: Optional[torch.Tensor],
                                  silu_activation: bool) -> torch.Tensor:
    def causal_conv1d_update_fake(
            x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
            bias_: Optional[torch.Tensor], silu_activation: bool,
            conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
        return torch.empty_like(x)

    @torch.library.register_fake("_C::selective_scan_fwd")
@@ -465,20 +461,11 @@ try:
                            delta_softplus: bool, index_: Optional[torch.Tensor],
                            x: Optional[torch.Tensor]) -> List[torch.Tensor]:
        a = torch.empty_like(u)
        if x is not None:
            b = x
        else:
            b = torch.empty((u.size(0), u.size(1), A.size(1)),
                            dtype=u.dtype,
                            device=u.device)
        if z_ is not None:
            c = torch.empty_like(z_)
            return [a, b, c]
            return [a, c]
        else:
            return [a, b]

except Exception:
    pass
            return [a]


# cutlass
@@ -626,16 +613,12 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
    return torch.ops._C.machete_prepack_B(b_q_weight, b_type)


# TODO: has to be a better way to do this
try:
    torch.ops._C.permute_cols  # noqa B018
if hasattr(torch.ops._C, "permute_cols"):

    @torch.library.register_fake("_C::permute_cols")
    def _permute_cols_fake(a: torch.Tensor,
                           perm: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(a)
except Exception:
    pass


def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
@@ -828,6 +811,24 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies, gating_output)


if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):

    @torch.library.register_fake("_moe_C::marlin_gemm_moe")
    def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
                             sorted_ids: torch.Tensor,
                             topk_weights: torch.Tensor,
                             topk_ids: torch.Tensor, b_scales: torch.Tensor,
                             g_idx: torch.Tensor, perm: torch.Tensor,
                             workspace: torch.Tensor, b_q_type: ScalarType,
                             size_m: int, size_n: int, size_k: int,
                             is_k_full: bool, num_experts: int, topk: int,
                             moe_block_size: int, replicate_input: bool,
                             apply_weights: bool) -> torch.Tensor:
        return torch.empty((size_m, topk, size_n),
                           dtype=a.dtype,
                           device=a.device)


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
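Note: the pattern in this file (guard with hasattr, then register_fake) is what gives opcheck's test_faketensor something to run against; the fake implementation only describes output shapes, dtypes and devices. A stand-alone, hedged sketch of the same idea for a hypothetical op (mylib::row_sum is invented for illustration and is not part of vLLM):

from typing import Optional

import torch

torch.library.define("mylib::row_sum", "(Tensor x, Tensor? weight) -> Tensor")

@torch.library.impl("mylib::row_sum", "cpu")
def _row_sum(x: torch.Tensor, weight: Optional[torch.Tensor]) -> torch.Tensor:
    return (x * weight).sum(-1) if weight is not None else x.sum(-1)

# Only register the fake impl if the op actually exists in this build,
# mirroring the hasattr(...) guards used above for optional kernels.
if hasattr(torch.ops.mylib, "row_sum"):

    @torch.library.register_fake("mylib::row_sum")
    def _row_sum_fake(x, weight=None):
        # No real math: just report the result's shape/dtype/device.
        return x.new_empty(x.shape[:-1])

# Restricting test_utils mirrors what the mamba test above does.
torch.library.opcheck(torch.ops.mylib.row_sum, (torch.randn(4, 8), None),
                      test_utils=("test_schema", "test_faketensor"))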
@@ -361,8 +361,8 @@ def selective_scan_fn(u,
    x[:, :, 0, 0::2] = 1
    if prev_state is not None:
        x[:, :, 0, 1::2].copy_(prev_state)
    out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
                                           delta_softplus, position_indices, x)
    out, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
                                        delta_softplus, position_indices, x)
    last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
    if z is None:
        return out if not return_last_state else (out, last_state)
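Note: the change above reflects the new kernel schema: the state buffer `x` is allocated by the caller and mutated in place, so the kernel no longer returns it. A hedged toy version of the same calling convention, unrelated to the actual CUDA kernel:

import torch

def toy_scan_fwd(u: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
    # Stand-in for the kernel: writes its running state into the
    # caller-provided buffer instead of returning a fresh tensor.
    out = torch.cumsum(u, dim=-1)
    state.copy_(out[..., -1:])
    return out

u = torch.randn(2, 3, 8)
state = torch.zeros(2, 3, 1)   # caller-owned buffer, like `x` above
out = toy_scan_fwd(u, state)   # only `out` comes back; `state` was updated in place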
@@ -217,6 +217,7 @@ class GPTQLinearMethod(LinearMethodBase):
        layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
        layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
        layer.scales = Parameter(layer.scales.data, requires_grad=False)

        # exllama needs to shuffle the weight after the weight is loaded
        # here we do the shuffle on first forward pass