# vllm/vllm/_custom_ops.py

import contextlib
import functools
from typing import List, Optional, Tuple, Type

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)

try:
    import vllm._C
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

with contextlib.suppress(ImportError):
    import vllm._moe_C

with contextlib.suppress(ImportError):
    # ruff: noqa: F401
    import vllm._punica_C


def is_custom_op_supported(op_name: str) -> bool:
    op, overloads = torch._C._jit_get_operation(op_name)
    return op is not None


def hint_on_error(fn):

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except AttributeError as e:
            msg = (
                "Error in calling custom op %s: %s\n"
                "Possibly you have built or installed an obsolete version "
                "of vllm.\n"
                "Please try a clean build and install of vllm, "
                "or remove old built files such as vllm/*cpython*.so and build/ ."
            )
            logger.error(msg, fn.__name__, e)
            raise e

    return wrapper
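

# NOTE: hint_on_error is applied automatically to every tensor-taking wrapper
# defined in this module by the loop at the bottom of the file, so the ops
# below do not need to be decorated by hand.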


# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.silu_and_mul(out, x)


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_and_mul(out, x)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_tanh_and_mul(out, x)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_fast(out, x)


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_new(out, x)
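

# Illustrative sketch (not part of the original op surface): the *_and_mul
# activations consume a fused gate/up projection of width 2 * d and write an
# output of width d, i.e. out = act(x[..., :d]) * x[..., d:]. Assumes a CUDA
# device and a successfully built vllm._C extension.
def _example_silu_and_mul() -> torch.Tensor:
    x = torch.randn(8, 2 * 128, device="cuda", dtype=torch.float16)
    out = torch.empty(8, 128, device="cuda", dtype=x.dtype)
    silu_and_mul(out, x)
    return out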


# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    kv_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    torch.ops._C.paged_attention_v1(
        out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
        kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step)


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    kv_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    torch.ops._C.paged_attention_v2(
        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
        alibi_slopes, kv_cache_dtype, kv_scale, tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step)


# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    torch.ops._C.rotary_embedding(positions, query, key, head_size,
                                  cos_sin_cache, is_neox)


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
                                          cos_sin_cache, is_neox, rot_dim,
                                          cos_sin_cache_offsets)


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    torch.ops._C.rms_norm(out, input, weight, epsilon)


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
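

# Illustrative sketch (not part of the original op surface): rms_norm writes
# into a caller-allocated output tensor, while fused_add_rms_norm is assumed
# here to update `hidden` and `residual` in place. Assumes a CUDA device and
# a successfully built vllm._C extension.
def _example_rms_norm() -> torch.Tensor:
    hidden = torch.randn(8, 128, device="cuda", dtype=torch.float16)
    weight = torch.ones(128, device="cuda", dtype=torch.float16)
    out = torch.empty_like(hidden)
    rms_norm(out, hidden, weight, 1e-6)
    residual = torch.zeros_like(hidden)
    fused_add_rms_norm(hidden, residual, weight, 1e-6)
    return out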


# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
                                       thx, thy)


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                                  b_g_idx, use_exllama, bit)


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)


# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
                    lookup_table: torch.Tensor) -> None:
    torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
                                    size_n, size_k)


# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, num_bits: int, size_m: int,
                        size_n: int, size_k: int) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
                                            workspace, num_bits, size_m,
                                            size_n, size_k)


# cutlass
def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: Type[torch.dtype]) -> torch.Tensor:
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b)

    return out
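

# Illustrative sketch (not part of the original op surface), assuming int8
# inputs with per-tensor float32 scales, a CUDA device, a successfully built
# vllm._C extension, and that the quantized weight operand `b` is laid out
# column-major. Both dimensions of `b` must be multiples of 16 and out_dtype
# must be float16 or bfloat16, per the assertions above.
def _example_cutlass_scaled_mm() -> torch.Tensor:
    m, k, n = 16, 32, 64
    a = torch.randint(-8, 8, (m, k), device="cuda", dtype=torch.int8)
    # Transposing a contiguous (n, k) tensor yields a column-major (k, n) view.
    b = torch.randint(-8, 8, (n, k), device="cuda", dtype=torch.int8).t()
    scale_a = torch.tensor([0.01], device="cuda", dtype=torch.float32)
    scale_b = torch.tensor([0.02], device="cuda", dtype=torch.float32)
    return cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float16)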


# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
                                  codebook_partition_sizes, bias)


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
    return torch.ops._C.aqlm_dequant(codes, codebooks,
                                     codebook_partition_sizes)


# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
                                           num_bits)


def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor, g_idx: torch.Tensor,
                     perm: torch.Tensor, workspace: torch.Tensor,
                     num_bits: int, size_m: int, size_n: int, size_k: int,
                     is_k_full: bool) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
                                         workspace, num_bits, size_m, size_n,
                                         size_k, is_k_full)


# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    batch_dim_padding: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensor for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8
        scale: Optional scaling factor for the FP8 quantization
        batch_dim_padding: If specified, pad the first dimension
            of the output to at least this value.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.
    """
    if batch_dim_padding:
        shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
        output = torch.empty(shape,
                             device=input.device,
                             dtype=torch.float8_e4m3fn)
    else:
        output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
    if scale is None:
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)
    return output, scale
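

# Illustrative sketch (not part of the original op surface): omitting the
# scale selects dynamic scaling, while passing a scale back in switches to
# static scaling. Assumes a CUDA device and a successfully built vllm._C
# extension.
def _example_scaled_fp8_quant() -> Tuple[torch.Tensor, torch.Tensor]:
    x = torch.randn(16, 256, device="cuda", dtype=torch.float16)
    # Dynamic: the returned scale is computed by the kernel.
    x_fp8, scale = scaled_fp8_quant(x)
    # Static: reuse the previously computed scale and pad the batch dimension
    # to at least 32 rows for downstream kernels.
    x_fp8_padded, _ = scaled_fp8_quant(x, scale=scale, batch_dim_padding=32)
    return x_fp8_padded, scale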


# int8
def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to int8 and return the quantized tensor and scale.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output int8 tensor and scales.
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # static-per-tensor quantization.
        torch.ops._C.static_scaled_int8_quant(output, input, scale)
        return output, scale

    # dynamic-per-token quantization.
    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                               device=input.device,
                               dtype=torch.float32)
    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales)
    return output, input_scales
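

# Illustrative sketch (not part of the original op surface): with no scale the
# kernel performs dynamic per-token quantization and returns one scale per
# row; with a caller-provided scale (assumed here to be a single-element
# float32 tensor) it performs static per-tensor quantization. Assumes a CUDA
# device and a successfully built vllm._C extension.
def _example_scaled_int8_quant() -> Tuple[torch.Tensor, torch.Tensor]:
    x = torch.randn(16, 256, device="cuda", dtype=torch.float16)
    # Dynamic per-token: `scales` has shape (16, 1).
    x_int8, scales = scaled_int8_quant(x)
    # Static per-tensor with a fixed scale.
    static_scale = torch.tensor([0.05], device="cuda", dtype=torch.float32)
    x_int8_static, _ = scaled_int8_quant(x, static_scale)
    return x_int8_static, scales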


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
                                      sorted_token_ids, experts_ids,
                                      num_tokens_post_pad)


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: torch.Tensor) -> None:
    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
                                  token_expert_indicies, gating_output)


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    kv_scale: float,
) -> None:
    torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
                                             value_cache, slot_mapping,
                                             kv_cache_dtype, kv_scale)


def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
) -> None:
    torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                   value_cache, slot_mapping,
                                                   kv_cache_dtype)


def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(output: torch.Tensor,
                input: torch.Tensor,
                scale: float = 1.0,
                kv_dtype: str = "fp8") -> None:
    torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    # ruff: noqa: E501
    return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(
        device)


# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                   handles: List[str], offsets: List[int], rank: int,
                   full_nvlink: bool) -> int:
    return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles,
                                                 offsets, rank, full_nvlink)


def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
                     full_nvlink: bool) -> bool:
    return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size,
                                                   full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)


def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
                     out: torch.Tensor) -> None:
    torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)


def dispose(fa: int) -> None:
    torch.ops._C_custom_ar.dispose(fa)


def meta_size() -> int:
    return torch.ops._C_custom_ar.meta_size()


def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
                    offsets: List[int]) -> None:
    return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
    return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)


def register_graph_buffers(fa: int, handles: List[str],
                           offsets: List[List[int]]) -> None:
    torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)


# punica
def dispatch_bgmv(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.Tensor,
    layer_idx: int,
    scale: float,
) -> None:
    torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx,
                                      scale)


def dispatch_bgmv_low_level(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.Tensor,
    layer_idx: int,
    scale: float,
    h_in: int,
    h_out: int,
    y_offset: int,
) -> None:
    torch.ops._punica_C.dispatch_bgmv_low_level(
        y,
        x,
        w_t_all,
        indicies,
        layer_idx,
        scale,
        h_in,
        h_out,
        y_offset,
    )


# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
names_and_values = globals()
names_and_values_to_update = {}
# prepare variables to avoid dict size change during iteration
k, v, arg = None, None, None
fn_type = type(lambda x: x)
for k, v in names_and_values.items():
    # find functions that are defined in this file and have torch.Tensor
    # in their annotations. `arg == "torch.Tensor"` is used to handle
    # the case when users use `from __future__ import annotations` to turn
    # type hints into strings.
    if isinstance(v, fn_type) \
            and v.__code__.co_filename == __file__ \
            and any(arg is torch.Tensor or arg == "torch.Tensor"
                    for arg in v.__annotations__.values()):
        names_and_values_to_update[k] = hint_on_error(v)

names_and_values.update(names_and_values_to_update)
del names_and_values_to_update, names_and_values, v, k, fn_type