[Kernel] support merge_attn_states CUDA kernel, 3x speedup (#16173)

Signed-off-by: DefTruth <qiustudent_r@163.com>
This commit is contained in:
DefTruth 2025-04-11 20:50:50 +08:00 committed by GitHub
parent 51baa9c333
commit e9528f6dc6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 519 additions and 4 deletions

View File

@ -230,6 +230,7 @@ set(VLLM_EXT_SRC
"csrc/cache_kernels.cu"
"csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"

View File

@ -0,0 +1,173 @@
#include <optional>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
namespace vllm {
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
template <typename scalar_t, const uint NUM_THREADS>
__global__ void merge_attn_states_kernel(
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
const float* prefix_lse, const scalar_t* suffix_output,
const float* suffix_lse, const uint num_tokens, const uint num_heads,
const uint head_size) {
using pack_128b_t = uint4;
const uint pack_size = 16 / sizeof(scalar_t);
const uint threads_per_head = head_size / pack_size;
const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
const uint token_head_threads = num_tokens * num_heads * threads_per_head;
if (global_idx >= token_head_threads) return;
// global_idx -> token_idx + head_idx + pack_idx
const uint token_head_idx = global_idx / threads_per_head;
const uint pack_idx = global_idx % threads_per_head;
const uint token_idx = token_head_idx / num_heads;
const uint head_idx = token_head_idx % num_heads;
const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
const uint head_offset =
token_idx * num_heads * head_size + head_idx * head_size;
const scalar_t* prefix_head_ptr = prefix_output + head_offset;
const scalar_t* suffix_head_ptr = suffix_output + head_offset;
scalar_t* output_head_ptr = output + head_offset;
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
const float max_lse = fmaxf(p_lse, s_lse);
p_lse = p_lse - max_lse;
s_lse = s_lse - max_lse;
const float p_se = expf(p_lse);
const float s_se = expf(s_lse);
const float out_se = p_se + s_se;
const float p_scale = p_se / out_se;
const float s_scale = s_se / out_se;
if (pack_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
prefix_head_ptr)[pack_offset / pack_size];
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
suffix_head_ptr)[pack_offset / pack_size];
pack_128b_t o_out_pack;
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
// Always use float for FMA to keep high precision.
// half(uint16_t), bfloat16, float -> float.
const float p_out_f =
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
const float s_out_f =
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
// float -> half(uint16_t), bfloat16, float.
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
}
// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
o_out_pack;
}
// We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) {
float out_lse = logf(out_se) + max_lse;
output_lse[head_idx * num_tokens + token_idx] = out_lse;
}
}
} // namespace vllm
// The following macro is used to dispatch the conversion function based on
// the output data type. The FN is a macro that calls a function with
// template<typename scalar_t>.
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
{ \
if (scalar_dtype == at::ScalarType::Float) { \
fn(float); \
} else if (scalar_dtype == at::ScalarType::Half) { \
fn(uint16_t); \
} else if (scalar_dtype == at::ScalarType::BFloat16) { \
fn(__nv_bfloat16); \
} else { \
TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
} \
}
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
{ \
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
num_heads, head_size); \
}
/*@brief Merges the attention states from prefix and suffix
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
*
* @param output [n,h,d] The output tensor to store the merged attention states.
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
* @param prefix_output [n,h,d] The prefix attention states.
* @param prefix_lse [h,d] The log-sum-exp values for the prefix attention
* states.
* @param suffix_output [n,h,d] The suffix attention states.
* @param suffix_lse [h,d] The log-sum-exp values for the suffix attention
* states.
*/
template <typename scalar_t>
void merge_attn_states_launcher(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse) {
constexpr uint NUM_THREADS = 128;
const uint num_tokens = output.size(0);
const uint num_heads = output.size(1);
const uint head_size = output.size(2);
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
float* output_lse_ptr = nullptr;
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>();
}
// process one pack elements per thread. float -> 4, half/bf16 -> 8
const uint threads_per_head = head_size / pack_size;
const uint total_threads = num_tokens * num_heads * threads_per_head;
dim3 block(NUM_THREADS);
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
}
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \
merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
prefix_lse, suffix_output, \
suffix_lse); \
}
void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse) {
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
}

View File

@ -52,6 +52,15 @@ void paged_attention_v2(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
#ifndef USE_ROCM
void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);
#endif
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon);

View File

@ -64,6 +64,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
#ifndef USE_ROCM
// Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
ops.def(
"merge_attn_states("
" Tensor! output,"
" Tensor!? output_lse,"
" Tensor prefix_output,"
" Tensor prefix_lse,"
" Tensor suffix_output,"
" Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
#endif
// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");

View File

@ -0,0 +1,265 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import pytest
import torch
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
from vllm.attention.ops.triton_merge_attn_states import (
merge_attn_states as merge_attn_states_triton)
from vllm.platforms import current_platform
# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
# can be used to combine partial attention results (in the split-KV case)
def merge_attn_states_torch(
output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS]
):
p_lse = prefix_lse
s_lse = suffix_lse
# inf -> -inf
p_lse[p_lse == torch.inf] = -torch.inf
s_lse[s_lse == torch.inf] = -torch.inf
# max_lse [NUM_HEADS, NUM_TOKENS]
max_lse = torch.maximum(p_lse, s_lse)
p_lse = p_lse - max_lse
s_lse = s_lse - max_lse
p_lse_exp = torch.exp(p_lse)
s_lse_exp = torch.exp(s_lse)
out_se = (p_lse_exp + s_lse_exp)
if output_lse is not None:
output_lse = torch.log(out_se) + max_lse
p_scale = p_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
p_scale = torch.transpose(p_scale, 0,
1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
s_scale = torch.transpose(s_scale, 0,
1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
output = prefix_output * p_scale + suffix_output * s_scale
return output, output_lse
NUM_BATCH_TOKENS = [256, 512, 613, 1024, 1536, 4096]
NUM_QUERY_HEADS = [4, 8, 16, 32, 48, 64]
HEAD_SIZES = [32, 48, 64, 96, 128, 256]
DTYPES = [torch.float32, torch.half, torch.bfloat16]
all_case_info: list[tuple] = []
def generate_markdown_table():
global all_case_info
table_header = ("| tokens | heads | headsize | dtype "
"| device | torch | triton | cuda | speedup |")
table_separator = "| --- | --- | --- | --- | --- | --- | --- | --- | --- |"
def shortly_dtype(dtype: torch.dtype) -> str:
return str(dtype).removeprefix("torch.")
def shortly_device(device: str) -> str:
return device.removeprefix("NVIDIA").strip()
print(table_header)
print(table_separator)
for info in all_case_info:
(num_tokens, num_heads, head_size, dtype, device,
avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel,
performance_improved) = info
dtype = shortly_dtype(dtype)
device = shortly_device(device)
print(f"| {num_tokens} | {num_heads} | {head_size} "
f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms "
f"| {avg_time_triton_kernel:.5f}ms "
f"| {avg_time_cuda_kernel:.5f}ms "
f"| {performance_improved:.4f}x |")
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("output_dtype", DTYPES)
@torch.inference_mode()
def test_merge_attn_states(num_tokens: int, num_query_heads: int,
head_size: int, output_dtype: torch.dtype):
if not current_platform.is_cuda():
pytest.skip('Currently only support compare triton merge_attn_states '
'with custom cuda merge_attn_states kernel')
NUM_TOKENS = num_tokens
NUM_HEADS = num_query_heads
HEAD_SIZE = head_size
print(f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
f"Device: {current_platform.get_device_name()}")
# prefix_lse and suffix_lse contain inf and normal values
prefix_lse = torch.randn(NUM_HEADS,
NUM_TOKENS,
dtype=torch.float32,
device="cuda")
suffix_lse = torch.randn(NUM_HEADS,
NUM_TOKENS,
dtype=torch.float32,
device="cuda")
# Generate boolean masks
mask_prefix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1
mask_suffix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1
# Ensure that the same position is not True at the same time
combined_mask = torch.logical_and(mask_prefix, mask_suffix)
mask_prefix = torch.logical_and(mask_prefix, ~combined_mask)
mask_suffix = torch.logical_and(mask_suffix, ~combined_mask)
prefix_lse[mask_prefix] = float('inf')
suffix_lse[mask_suffix] = float('inf')
# Other input tensors (need to be initialized but
# no actual calculation needed)
output = torch.zeros((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
dtype=output_dtype,
device="cuda")
output_lse = torch.zeros((NUM_HEADS, NUM_TOKENS),
dtype=torch.float32,
device="cuda")
prefix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
dtype=output_dtype,
device="cuda")
suffix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
dtype=output_dtype,
device="cuda")
warmup_times = 2
repeat_times = 20
output_torch = output.clone()
output_lse_torch = output_lse.clone()
total_time_torch_kernel = 0
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
# 0. Run the Torch kernel
prefix_lse_torch = prefix_lse.clone()
suffix_lse_torch = suffix_lse.clone()
for _ in range(warmup_times):
output_torch, output_lse_torch = merge_attn_states_torch(
output_torch, prefix_output, prefix_lse_torch, suffix_output,
suffix_lse_torch, output_lse_torch)
torch.cuda.synchronize()
for _ in range(repeat_times):
start.record()
output_torch, output_lse_torch = merge_attn_states_torch(
output_torch, prefix_output, prefix_lse_torch, suffix_output,
suffix_lse_torch, output_lse_torch)
end.record()
torch.cuda.synchronize()
total_time_torch_kernel += start.elapsed_time(end)
avg_time_torch_kernel = total_time_torch_kernel / repeat_times
# 1. Run the Triton kernel
output_ref_triton = output.clone()
output_lse_ref_triton = output_lse.clone()
total_time_triton_kernel = 0
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for _ in range(warmup_times):
merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse,
suffix_output, suffix_lse,
output_lse_ref_triton)
torch.cuda.synchronize()
for _ in range(repeat_times):
start.record()
merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse,
suffix_output, suffix_lse,
output_lse_ref_triton)
end.record()
torch.cuda.synchronize()
total_time_triton_kernel += start.elapsed_time(end)
avg_time_triton_kernel = total_time_triton_kernel / repeat_times
# 2. Run the CUDA kernel
total_time_cuda_kernel = 0
output_cuda = output.clone()
output_lse_cuda = output_lse.clone()
for _ in range(warmup_times):
merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse,
suffix_output, suffix_lse, output_lse_cuda)
torch.cuda.synchronize()
for _ in range(repeat_times):
start.record()
merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse,
suffix_output, suffix_lse, output_lse_cuda)
end.record()
torch.cuda.synchronize()
total_time_cuda_kernel += start.elapsed_time(end)
avg_time_cuda_kernel = total_time_cuda_kernel / repeat_times
# 3. Performance compare
performance_improved = avg_time_triton_kernel / avg_time_cuda_kernel
print(f" Torch time: {avg_time_torch_kernel:.6f}ms")
print(f"Triton time: {avg_time_triton_kernel:.6f}ms")
print(f" CUDA time: {avg_time_cuda_kernel:.6f}ms, "
f"Performance: {performance_improved:.5f}x")
print("-" * 100)
# 4. Correctness compare
# Liger Kernel: Efficient Triton Kernels for LLM Training
# https://arxiv.org/pdf/2410.10989, 3.3 Correctness
# use rtol = 1e-2 for bfloat16.
rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3
def diff(a: torch.Tensor, b: torch.Tensor):
max_diff = torch.max(torch.abs(a.float() - b.float()))
return max_diff
# Use Triton output as reference because we want to replace
# the Triton kernel with custom CUDA kernel for merge attn
# states operation.
output_ref = output_ref_triton
output_lse_ref = output_lse_ref_triton
torch.testing.assert_close(output_cuda.float(),
output_ref.float(),
atol=1e-3,
rtol=rtol)
print("Output all match, max abs diff:")
print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}")
print(f" (CUDA vs Torch) : {diff(output_torch, output_cuda)}")
print(f" (CUDA vs Triton): {diff(output_ref, output_cuda)}")
print("-" * 100)
torch.testing.assert_close(output_lse_cuda.float(),
output_lse_ref.float(),
atol=1e-3,
rtol=rtol)
print("Output LSE all match, max abs diff:")
print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}")
print(f" (CUDA vs Torch) : {diff(output_lse_torch, output_lse_cuda)}")
print(f" (CUDA vs Triton): {diff(output_lse_ref, output_lse_cuda)}")
print("-" * 100)
print("All output values test passed! All inf values "
"are correctly replaced with -inf.")
print("-" * 100)
device = current_platform.get_device_name()
all_case_info.append(
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE, output_dtype, device,
avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel,
performance_improved))
if len(all_case_info) == (len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) *
len(NUM_QUERY_HEADS) * len(DTYPES)):
generate_markdown_table()

View File

@ -138,6 +138,17 @@ def mla_decode_kvcache_cpu(
block_tables, seq_lens)
# merge attn states ops
def merge_attn_states(output: torch.Tensor,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output_lse: Optional[torch.Tensor] = None) -> None:
torch.ops._C.merge_attn_states(output, output_lse, prefix_output,
prefix_lse, suffix_output, suffix_lse)
# pos encoding ops
def rotary_embedding(
positions: torch.Tensor,

View File

@ -204,6 +204,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
@ -217,9 +218,7 @@ from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
if HAS_TRITON:
from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
else:
merge_attn_states = None
triton_attention = None
try:

View File

@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
from vllm.platforms import current_platform
def merge_attn_states(
output: torch.Tensor,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output_lse: Optional[torch.Tensor] = None,
) -> None:
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel
# is not support for FP8 dtype, fallback to use Triton kernel.
def supported_dtypes(o: torch.Tensor) -> bool:
return o.dtype in [torch.float32, torch.half, torch.bfloat16]
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA
# kernel load/store 128b(16 bytes) per memory issue within
# thread. Namely, the headsize(headdim) must be multiple of
# pack_size (float32 -> 4, half/bfloat16 -> 8).
def supported_headdim(o: torch.Tensor) -> bool:
headdim = o.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
if o.dtype == torch.float32:
return headdim % 4 == 0
return headdim % 8 == 0
if (current_platform.is_cuda() and supported_dtypes(output)
and supported_headdim(output)):
from vllm._custom_ops import merge_attn_states
return merge_attn_states(output, prefix_output, prefix_lse,
suffix_output, suffix_lse, output_lse)
else:
from vllm.attention.ops.triton_merge_attn_states import (
merge_attn_states)
return merge_attn_states(output, prefix_output, prefix_lse,
suffix_output, suffix_lse, output_lse)

View File

@ -10,7 +10,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv

View File

@ -195,7 +195,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,