[Build] Avoid building too many extensions (#1624)

Yanming W authored on 2023-11-23 16:31:19 -08:00, committed by GitHub
parent de23687d16
commit e0c6f556e8
25 changed files with 206 additions and 272 deletions
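
This change collapses the previously separate CUDA extensions (vllm.cache_ops, vllm.attention_ops, vllm.pos_encoding_ops, vllm.layernorm_ops, vllm.activation_ops, vllm.quantization_ops, vllm.cuda_utils) into a single vllm._C extension whose pybind module exposes ops, cache_ops, and cuda_utils submodules; every Python call site in the diff below is updated accordingly. A minimal sketch of the import change, assuming a CUDA build of vLLM that includes this commit:

```python
# Before this commit: one compiled extension per kernel family, e.g.
#   from vllm import attention_ops, activation_ops, layernorm_ops
#   from vllm import pos_encoding_ops, quantization_ops, cache_ops, cuda_utils
# After this commit: a single compiled extension with three submodules.
from vllm._C import ops, cache_ops, cuda_utils
```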


@@ -4,7 +4,7 @@ import time
import torch
from vllm import attention_ops
from vllm._C import ops
NUM_BLOCKS = 1024
PARTITION_SIZE = 512
@@ -98,7 +98,7 @@ def main(
for _ in range(num_iters):
if version == "v1":
attention_ops.paged_attention_v1(
ops.paged_attention_v1(
output,
query,
key_cache,
@@ -112,7 +112,7 @@ def main(
alibi_slopes,
)
elif version == "v2":
attention_ops.paged_attention_v2(
ops.paged_attention_v2(
output,
exp_sums,
max_logits,


@@ -1,28 +0,0 @@
#include <torch/extension.h>
void silu_and_mul(
torch::Tensor& out,
torch::Tensor& input);
void gelu_new(
torch::Tensor& out,
torch::Tensor& input);
void gelu_fast(
torch::Tensor& out,
torch::Tensor& input);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"silu_and_mul",
&silu_and_mul,
"Activation function used in SwiGLU.");
m.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
m.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
}


@@ -1,42 +0,0 @@
#include <torch/extension.h>
#include <c10/util/Optional.h>
void paged_attention_v1(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& head_mapping,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes);
void paged_attention_v2(
torch::Tensor& out,
torch::Tensor& exp_sums,
torch::Tensor& max_logits,
torch::Tensor& tmp_out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& head_mapping,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"paged_attention_v1",
&paged_attention_v1,
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
m.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
}


@@ -26,22 +26,3 @@ void gather_cached_kv(
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"swap_blocks",
&swap_blocks,
"Swap in (out) the cache blocks from src to dst");
m.def(
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst");
m.def(
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
m.def(
"gather_cached_kv",
&gather_cached_kv,
"Gather key and value from the cache into contiguous QKV tensors");
}


@@ -1,13 +0,0 @@
#include <torch/extension.h>
int get_device_attribute(
int attribute,
int device_id);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"get_device_attribute",
&get_device_attribute,
"Gets the specified device attribute.");
}

csrc/cuda_utils.h (new file)

@@ -0,0 +1,5 @@
#include <torch/extension.h>
int get_device_attribute(
int attribute,
int device_id);


@@ -1,24 +0,0 @@
#include <torch/extension.h>
void rms_norm(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& weight,
float epsilon);
void fused_add_rms_norm(
torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& weight,
float epsilon);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"rms_norm",
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
m.def(
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
}

csrc/ops.h (new file)

@@ -0,0 +1,75 @@
#include <torch/extension.h>
void paged_attention_v1(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& head_mapping,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes);
void paged_attention_v2(
torch::Tensor& out,
torch::Tensor& exp_sums,
torch::Tensor& max_logits,
torch::Tensor& tmp_out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& head_mapping,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes);
void rms_norm(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& weight,
float epsilon);
void fused_add_rms_norm(
torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& weight,
float epsilon);
void rotary_embedding(
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache,
bool is_neox);
void silu_and_mul(
torch::Tensor& out,
torch::Tensor& input);
void gelu_new(
torch::Tensor& out,
torch::Tensor& input);
void gelu_fast(
torch::Tensor& out,
torch::Tensor& input);
torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters);
void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table);
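
The new csrc/ops.h gathers the kernel declarations that previously lived in the per-extension pybind sources deleted above; csrc/pybind.cpp (next file) binds each of them under vllm._C.ops. A hedged post-build check, not part of this commit, that the consolidated bindings are all present:

```python
# Sketch only: verify that every function declared in csrc/ops.h is exposed on
# the single vllm._C.ops submodule (requires a CUDA build with this commit).
from vllm._C import ops

declared = [
    "paged_attention_v1", "paged_attention_v2",
    "rms_norm", "fused_add_rms_norm",
    "rotary_embedding",
    "silu_and_mul", "gelu_new", "gelu_fast",
    "awq_gemm", "squeezellm_gemm",
]
missing = [name for name in declared if not hasattr(ops, name)]
assert not missing, f"missing bindings: {missing}"
```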


@@ -1,16 +0,0 @@
#include <torch/extension.h>
void rotary_embedding(
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache,
bool is_neox);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"rotary_embedding",
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
}

csrc/pybind.cpp (new file)

@@ -0,0 +1,80 @@
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// vLLM custom ops
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops
ops.def(
"paged_attention_v1",
&paged_attention_v1,
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
// Activation ops
ops.def(
"silu_and_mul",
&silu_and_mul,
"Activation function used in SwiGLU.");
ops.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
// Layernorm
ops.def(
"rms_norm",
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def(
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
// Rotary embedding
ops.def(
"rotary_embedding",
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
// Quantization ops
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def(
"swap_blocks",
&swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def(
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def(
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
cache_ops.def(
"gather_cached_kv",
&gather_cached_kv,
"Gather key and value from the cache into contiguous QKV tensors");
// Cuda utils
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
cuda_utils.def(
"get_device_attribute",
&get_device_attribute,
"Gets the specified device attribute.");
}
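
With the single PYBIND11_MODULE above, kernels are reached through submodules of vllm._C rather than separate top-level extension modules. A minimal usage sketch mirroring the activation test updated later in this diff; the tensor sizes are illustrative assumptions, and a CUDA device plus a build containing this commit are required:

```python
import torch
from vllm._C import ops

num_tokens, d = 7, 512  # illustrative sizes, not values fixed by this diff
x = torch.randn(num_tokens, 2 * d, dtype=torch.half, device="cuda")
out = torch.empty(num_tokens, d, dtype=torch.half, device="cuda")
# silu_and_mul writes silu(x[..., :d]) * x[..., d:] into `out` in place.
ops.silu_and_mul(out, x)
```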


@@ -1,19 +0,0 @@
#include <torch/extension.h>
torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters);
void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
}


@@ -142,87 +142,25 @@ if nvcc_cuda_version >= Version("11.2"):
NVCC_FLAGS += ["--threads", str(num_threads)]
ext_modules = []
# Cache operations.
cache_extension = CUDAExtension(
name="vllm.cache_ops",
sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(cache_extension)
# Attention kernels.
attention_extension = CUDAExtension(
name="vllm.attention_ops",
sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(attention_extension)
# Positional encoding kernels.
positional_encoding_extension = CUDAExtension(
name="vllm.pos_encoding_ops",
sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(positional_encoding_extension)
# Layer normalization kernels.
layernorm_extension = CUDAExtension(
name="vllm.layernorm_ops",
sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(layernorm_extension)
# Activation kernels.
activation_extension = CUDAExtension(
name="vllm.activation_ops",
sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(activation_extension)
# Quantization kernels.
quantization_extension = CUDAExtension(
name="vllm.quantization_ops",
vllm_extension = CUDAExtension(
name="vllm._C",
sources=[
"csrc/quantization.cpp",
"csrc/cache_kernels.cu",
"csrc/attention/attention_kernels.cu",
"csrc/pos_encoding_kernels.cu",
"csrc/activation_kernels.cu",
"csrc/layernorm_kernels.cu",
"csrc/quantization/awq/gemm_kernels.cu",
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
"csrc/cuda_utils_kernels.cu",
"csrc/pybind.cpp",
],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(quantization_extension)
# Misc. CUDA utils.
cuda_utils_extension = CUDAExtension(
name="vllm.cuda_utils",
sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(cuda_utils_extension)
ext_modules.append(vllm_extension)
def get_path(*filepath) -> str:
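
The setup.py hunk above replaces seven CUDAExtension targets with a single one, so a source build compiles one shared object instead of seven separate pybind modules, which is where the build-time saving comes from. A hedged way to confirm the layout after an editable install (the exact .so filename depends on the Python version and platform):

```python
# Sketch only: after building, the installed vllm package should contain a
# single compiled module for vllm._C rather than one per kernel family.
import pathlib
import vllm

pkg_dir = pathlib.Path(vllm.__file__).parent
print(sorted(p.name for p in pkg_dir.glob("*.so")))
# e.g. ['_C.cpython-310-x86_64-linux-gnu.so']
```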


@@ -3,7 +3,7 @@ import torch
import torch.nn.functional as F
from transformers.activations import get_activation
from vllm import activation_ops
from vllm._C import ops
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
@@ -31,7 +31,7 @@ def test_silu_and_mul(
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
activation_ops.silu_and_mul(out, x)
ops.silu_and_mul(out, x)
ref_out = ref_silu_and_mul(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@@ -51,7 +51,7 @@ def test_gelu_new(
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
activation_ops.gelu_new(out, x)
ops.gelu_new(out, x)
ref_out = get_activation("gelu_new")(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@@ -70,6 +70,6 @@ def test_gelu_fast(
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
activation_ops.gelu_fast(out, x)
ops.gelu_fast(out, x)
ref_out = get_activation("gelu_fast")(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


@@ -6,7 +6,7 @@ import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm import attention_ops
from vllm._C import ops
from vllm.utils import get_max_shared_memory_bytes
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
@@ -165,7 +165,7 @@ def test_paged_attention(
# Call the paged attention kernel.
output = torch.empty_like(query)
if version == "v1":
attention_ops.paged_attention_v1(
ops.paged_attention_v1(
output,
query,
key_cache,
@@ -194,7 +194,7 @@ def test_paged_attention(
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
attention_ops.paged_attention_v2(
ops.paged_attention_v2(
output,
exp_sums,
max_logits,


@@ -3,7 +3,7 @@ import random
import pytest
import torch
from vllm import cache_ops
from vllm._C import cache_ops
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [83] # Arbitrary values for testing


@@ -2,7 +2,7 @@ import pytest
import torch
import torch.nn as nn
from vllm import layernorm_ops
from vllm._C import ops
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [67, 768, 2048, 5120, 8192] # Arbitrary values for testing
@@ -48,7 +48,7 @@ def test_rms_norm(
ref = RefRMSNorm(hidden_size).to(dtype).cuda()
out = torch.empty_like(x)
layernorm_ops.rms_norm(
ops.rms_norm(
out,
x,
ref.weight.data,


@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm import pos_encoding_ops
from vllm._C import ops
IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -145,7 +145,7 @@ def test_rotary_embedding(
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
out_query = query.clone()
out_key = key.clone()
pos_encoding_ops.rotary_embedding(
ops.rotary_embedding(
positions,
out_query,
out_key,


@@ -4,7 +4,7 @@ from typing import Optional
import torch
import torch.nn as nn
from vllm import activation_ops
from vllm._C import ops
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -26,7 +26,7 @@ class SiluAndMul(nn.Module):
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
activation_ops.silu_and_mul(out, x)
ops.silu_and_mul(out, x)
return out
@@ -34,7 +34,7 @@ class NewGELU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
activation_ops.gelu_new(out, x)
ops.gelu_new(out, x)
return out
@@ -42,7 +42,7 @@ class FastGELU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
activation_ops.gelu_fast(out, x)
ops.gelu_fast(out, x)
return out


@@ -7,8 +7,8 @@ from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias)
from vllm import attention_ops
from vllm import cache_ops
from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -163,7 +163,7 @@ class PagedAttention(nn.Module):
max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
# Run PagedAttention V1.
attention_ops.paged_attention_v1(
ops.paged_attention_v1(
output,
query,
key_cache,
@@ -190,7 +190,7 @@ class PagedAttention(nn.Module):
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
attention_ops.paged_attention_v2(
ops.paged_attention_v2(
output,
exp_sums,
max_logits,


@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm import layernorm_ops
from vllm._C import ops
class RMSNorm(nn.Module):
@@ -29,7 +29,7 @@ class RMSNorm(nn.Module):
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
layernorm_ops.fused_add_rms_norm(
ops.fused_add_rms_norm(
x,
residual,
self.weight.data,
@@ -37,7 +37,7 @@ class RMSNorm(nn.Module):
)
return x, residual
out = torch.empty_like(x)
layernorm_ops.rms_norm(
ops.rms_norm(
out,
x,
self.weight.data,
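
The RMSNorm hunk above only changes which module the kernel is imported from; the call matches the rms_norm declaration in csrc/ops.h. A hedged standalone sketch with illustrative sizes and epsilon, requiring a CUDA build that includes this commit:

```python
import torch
from vllm._C import ops

hidden_size = 768  # illustrative
x = torch.randn(16, hidden_size, dtype=torch.half, device="cuda")
weight = torch.ones(hidden_size, dtype=torch.half, device="cuda")
out = torch.empty_like(x)
# out = x / sqrt(mean(x**2, dim=-1) + eps) * weight, computed by the CUDA kernel.
ops.rms_norm(out, x, weight, 1e-6)
```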


@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
@@ -151,8 +151,7 @@ class AWQLinearMethod(LinearMethodBase):
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
pack_factor)
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
if bias is not None:
out = out + bias
return out.reshape(out_shape)


@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
@@ -116,8 +116,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
reshaped_x = x.reshape(-1, x.shape[-1])
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
quantization_ops.squeezellm_gemm(reshaped_x, qweight, out,
lookup_table)
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
if bias is not None:
out = out + bias


@@ -27,7 +27,7 @@ from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm import pos_encoding_ops
from vllm._C import ops
class RotaryEmbedding(nn.Module):
@@ -87,11 +87,10 @@ class RotaryEmbedding(nn.Module):
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# pos_encoding_ops.rotary_embedding() is an in-place operation that
# ops.rotary_embedding() is an in-place operation that
# updates the query and key tensors.
pos_encoding_ops.rotary_embedding(positions, query, key,
self.head_size, self.cos_sin_cache,
self.is_neox_style)
ops.rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style)
return query, key


@@ -5,7 +5,7 @@ from platform import uname
import psutil
import torch
from vllm import cuda_utils
from vllm._C import cuda_utils
class Device(enum.Enum):
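
The cuda_utils extension is folded into vllm._C as well; the hunk above only updates the import. A hedged sketch of calling the binding declared in csrc/cuda_utils.h follows; the attribute id 97 is cudaDevAttrMaxSharedMemoryPerBlockOptin from the CUDA runtime API, not something defined in this diff:

```python
from vllm._C import cuda_utils

CUDA_DEV_ATTR_MAX_SHARED_MEM_PER_BLOCK_OPTIN = 97  # cudaDevAttrMaxSharedMemoryPerBlockOptin
max_shared = cuda_utils.get_device_attribute(
    CUDA_DEV_ATTR_MAX_SHARED_MEM_PER_BLOCK_OPTIN, 0)  # (attribute, device_id)
print(f"Opt-in shared memory per block on device 0: {max_shared} bytes")
```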


@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple
import torch
from vllm import cache_ops
from vllm._C import cache_ops
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import in_wsl