[Kernel] optimize performance of gptq marlin kernel when n is small (#14138)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Parent: 58abe35455
Commit: d0feea31c7
@@ -538,6 +538,7 @@ __global__ void Marlin(
     int  prob_n,          // output dimension n
     int  prob_k,          // reduction dimension k
     int* locks,           // extra global storage for barrier synchronization
+    bool use_atomic_add,  // whether to use atomic add to reduce
     bool use_fp32_reduce  // whether to use fp32 global reduce
 ) {
   // Each threadblock processes one "stripe" of the B matrix with (roughly) the
@@ -1542,7 +1543,17 @@ __global__ void Marlin(
          i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
          i++) {
       if (c_gl_wr < c_gl_wr_end) {
-        C[c_gl_wr] = sh_red[c_sh_rd];
+        if (use_atomic_add && slice_count > 1) {
+          scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[c_gl_wr]);
+          scalar_t2* sh_red_half2 =
+              reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]);
+  #pragma unroll
+          for (int a = 0; a < 4; a++) {
+            atomicAdd(&C_half2[a], sh_red_half2[a]);
+          }
+        } else {
+          C[c_gl_wr] = sh_red[c_sh_rd];
+        }
         c_gl_wr += c_gl_wr_delta;
         c_sh_rd += c_sh_rd_delta;
       }
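Note: the hunk above is the core of this commit. When use_atomic_add is set and a slice is split across more than one threadblock, each block adds its fp16 partial result straight into C with half2 atomicAdd instead of staging it for the lock-based global reduce. The standalone sketch below is not part of this commit and its kernel/buffer names are illustrative only; it shows the same accumulation pattern under two assumptions that also hold for the real kernel: the output starts at zero (hence torch::zeros later in this diff) and half2 atomicAdd is available (sm_60+). The commit's Python helper additionally keeps bfloat16 on the old path before sm_90.

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// partial: [num_slices, n] per-slice results; out: [n], must start at zero.
__global__ void accumulate_slices(const half* __restrict__ partial,
                                  half* __restrict__ out,
                                  int num_slices, int n) {
  int col = 2 * (blockIdx.x * blockDim.x + threadIdx.x);  // two halves per thread
  int slice = blockIdx.y;                                  // one "stripe" per y-block
  if (col + 1 >= n || slice >= num_slices) return;
  const __half2 v = *reinterpret_cast<const __half2*>(partial + slice * n + col);
  atomicAdd(reinterpret_cast<__half2*>(out + col), v);     // accumulate in place
}

int main() {
  const int num_slices = 4, n = 1024;
  std::vector<half> h_partial(num_slices * n, __float2half(1.0f));

  half *d_partial, *d_out;
  cudaMalloc(&d_partial, num_slices * n * sizeof(half));
  cudaMalloc(&d_out, n * sizeof(half));
  cudaMemcpy(d_partial, h_partial.data(), num_slices * n * sizeof(half),
             cudaMemcpyHostToDevice);
  cudaMemset(d_out, 0, n * sizeof(half));  // zero-init, mirrors torch::zeros below

  dim3 block(128), grid((n / 2 + block.x - 1) / block.x, num_slices);
  accumulate_slices<<<grid, block>>>(d_partial, d_out, num_slices, n);

  std::vector<half> h_out(n);
  cudaMemcpy(h_out.data(), d_out, n * sizeof(half), cudaMemcpyDeviceToHost);
  printf("out[0] = %f (expect %d)\n", __half2float(h_out[0]), num_slices);

  cudaFree(d_partial);
  cudaFree(d_out);
  return 0;
}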
@@ -1644,7 +1655,7 @@ __global__ void Marlin(
         }
         cp_async_fence();
       } else {
-        if (last) {
+        if (last || use_atomic_add) {
           if (s_sh_wr_pred) {
             cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
           }
@@ -1664,7 +1675,7 @@ __global__ void Marlin(
         }

       } else {
-        if (last) {
+        if (last || use_atomic_add) {
           cp_async_wait<0>();
           __syncthreads();
           if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1703,8 +1714,8 @@ __global__ void Marlin(
         }
       }

-      if (slice_count > 1) {  // only globally reduce if there is more than one
-        // block in a slice
+      if (slice_count > 1 && !use_atomic_add) {
+        // only globally reduce if there is more than one block in a slice
         barrier_acquire(&locks[slice_col], slice_idx);
         if (use_fp32_reduce) {
           global_reduce_fp32(slice_idx == 0, last);
@@ -1713,7 +1724,8 @@ __global__ void Marlin(
         }
         barrier_release(&locks[slice_col], last);
       }
-      if (last)  // only the last block in a slice actually writes the result
+      if (last || use_atomic_add)
+        // only the last block in a slice actually writes the result
         write_result();
       slice_row = 0;
       slice_col_par++;
@@ -1768,7 +1780,8 @@ __global__ void Marlin(
                   HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>                          \
               <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(              \
                   A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,   \
-                  num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
+                  num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add,  \
+                  use_fp32_reduce);                                           \
       }                                                                       \
     }

@@ -2062,7 +2075,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
                vllm::ScalarType const& q_type, bool has_act_order,
                bool is_k_full, bool has_zp, int num_groups, int group_size,
                int dev, cudaStream_t stream, int thread_k, int thread_n,
-               int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) {
+               int sms, int max_par, bool use_atomic_add, bool use_fp32_reduce,
+               bool is_zp_float) {
   if (has_zp) {
     TORCH_CHECK(
         q_type == vllm::kU4 || q_type == vllm::kU8,
@@ -2243,7 +2257,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& workspace,
                                vllm::ScalarTypeId const& b_q_type_id,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp,
+                               bool is_k_full, bool has_zp, bool use_atomic_add,
                                bool use_fp32_reduce, bool is_zp_float) {
   vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
   if (has_zp) {
@@ -2306,19 +2320,34 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Alloc buffers
   const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
   auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
-  torch::Tensor c = torch::empty({size_m, size_n}, options);
-  torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
+  torch::Tensor c;
+  if (use_atomic_add) {
+    c = torch::zeros({size_m, size_n}, options);
+  } else {
+    c = torch::empty({size_m, size_n}, options);
+  }
+
+  torch::Tensor a_tmp;
+  bool has_act_order = g_idx.size(0) != 0;
+  if (has_act_order) {
+    a_tmp = torch::empty({size_m, size_k}, options);
+  } else {
+    a_tmp = torch::empty({0}, options);
+  }

   // Alloc C tmp buffer that is going to be used for the global reduce
+  torch::Tensor c_tmp;
   int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
   int reduce_n = size_n;
   auto options_fp32 =
       torch::TensorOptions().dtype(at::kFloat).device(a.device());
-  if (!use_fp32_reduce) {
+  if (use_fp32_reduce) {
+    c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
+  } else {
     reduce_max_m = 0;
     reduce_n = 0;
+    c_tmp = torch::empty({0}, options_fp32);
   }
-  torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);

   // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
   // auto -1)
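Note: the allocation logic above follows directly from the new reduce path. Partial sums are added on top of C, so C must start at zero when use_atomic_add is set, and the fp32 scratch c_tmp is now only materialized when the lock-based fp32 global reduce is actually requested. A minimal illustrative helper, not part of this commit and with a hypothetical name, stating the same design choice:

#include <torch/torch.h>

// Illustrative only: the atomicAdd path accumulates into the output, so it
// must be zero-initialized; the overwrite path can reuse uninitialized memory.
torch::Tensor alloc_marlin_output(bool use_atomic_add, int64_t m, int64_t n,
                                  const torch::TensorOptions& options) {
  return use_atomic_add ? torch::zeros({m, n}, options)   // partial sums add on top
                        : torch::empty({m, n}, options);  // last block overwrites
}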
@@ -2339,7 +2368,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Detect groupsize and act_order
   int num_groups = -1;
   int group_size = -1;
-  bool has_act_order = g_idx.size(0) != 0;

   int rank = b_scales.sizes().size();
   TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
@@ -2407,7 +2435,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
+        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
+        use_fp32_reduce, is_zp_float);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     marlin::marlin_mm<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
@@ -2416,7 +2445,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
         workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
+        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
+        use_fp32_reduce, is_zp_float);
   } else {
     TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
   }
@@ -272,7 +272,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
       "int b_q_type, "
      "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
-      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
+      "bool has_zp, bool use_atomic_add, bool use_fp32_reduce, "
+      "bool is_zp_float) -> Tensor",
       {stride_tag});
   // conditionally compiled so impl registration is in source file

@@ -34,6 +34,7 @@ from vllm.scalar_type import scalar_types

 ACT_ORDER_OPTS = [False, True]
 K_FULL_OPTS = [False, True]
+USE_ATOMIC_ADD_OPTS = [False, True]
 USE_FP32_REDUCE_OPTS = [False, True]

 MARLIN_K_CHUNKS = [128]
@@ -194,6 +195,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
 @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
 @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
 @pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
+@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
 @pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
 def test_gptq_marlin_gemm(
     k_chunk,
@@ -203,6 +205,7 @@ def test_gptq_marlin_gemm(
     mnk_factors,
     act_order,
     is_k_full,
+    use_atomic_add,
     use_fp32_reduce,
 ):
     m_factor, n_factor, k_factor = mnk_factors
@@ -228,12 +231,12 @@ def test_gptq_marlin_gemm(
     workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                 GPTQ_MARLIN_MAX_PARALLEL)

-    opcheck(
-        torch.ops._C.gptq_marlin_gemm,
-        (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
-         workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
-         a_input.shape[1], is_k_full, False, use_fp32_reduce, False),
-        test_utils=DEFAULT_OPCHECK_TEST_UTILS)
+    opcheck(torch.ops._C.gptq_marlin_gemm,
+            (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
+             workspace.scratch, quant_type.id, a_input.shape[0],
+             b_weight.shape[1], a_input.shape[1], is_k_full, False,
+             use_atomic_add, use_fp32_reduce, False),
+            test_utils=DEFAULT_OPCHECK_TEST_UTILS)

     output = ops.gptq_marlin_gemm(
         a_input,
@@ -249,6 +252,7 @@ def test_gptq_marlin_gemm(
         a_input.shape[1],
         is_k_full=is_k_full,
         has_zp=False,
+        use_atomic_add=use_atomic_add,
         use_fp32_reduce=use_fp32_reduce,
         is_zp_float=False,
     )
@@ -301,6 +301,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
                               size_k: torch.SymInt,
                               is_k_full: bool,
                               has_zp: bool = False,
+                              use_atomic_add: bool = False,
                               use_fp32_reduce: bool = False,
                               is_zp_float: bool = False) -> torch.Tensor:
         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@@ -713,12 +714,14 @@ def gptq_marlin_gemm(a: torch.Tensor,
                      size_k: int,
                      is_k_full: bool,
                      has_zp: bool = False,
+                     use_atomic_add: bool = False,
                      use_fp32_reduce: bool = False,
                      is_zp_float: bool = False) -> torch.Tensor:
     return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
                                          g_idx, perm, workspace, b_q_type.id,
                                          size_m, size_n, size_k, is_k_full,
-                                         has_zp, use_fp32_reduce, is_zp_float)
+                                         has_zp, use_atomic_add,
+                                         use_fp32_reduce, is_zp_float)


 # fp8 marlin
@@ -95,6 +95,7 @@ if TYPE_CHECKING:
     VLLM_DP_SIZE: int = 1
     VLLM_DP_MASTER_IP: str = ""
     VLLM_DP_MASTER_PORT: int = 0
+    VLLM_MARLIN_USE_ATOMIC_ADD: bool = False


 def get_default_cache_root():
@@ -630,6 +631,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Whether to use S3 path for model loading in CI via RunAI Streamer
     "VLLM_CI_USE_S3":
     lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",
+
+    # Whether to use atomicAdd reduce in gptq/awq marlin kernel.
+    "VLLM_MARLIN_USE_ATOMIC_ADD":
+    lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1",
 }

 # end-env-vars-definition
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
 import numpy
 import torch

+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.platforms import current_platform
@@ -290,6 +291,23 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
     return output


+def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
+                                 dtype: torch.dtype) -> bool:
+    # disable atomicAdd reduce by default,
+    # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
+    if not envs.VLLM_MARLIN_USE_ATOMIC_ADD or device.type != "cuda":
+        return False
+
+    # sm8x doesn't support atomicAdd + bfloat16 natively
+    device_capability = torch.cuda.get_device_capability(device)
+    if device_capability[0] < 9 and dtype == torch.bfloat16:
+        return False
+
+    # the performance of atomicAdd is better than global reduce
+    # only when m*n is small and k is large
+    return max(m, 64) * n < 64 * 2048 and k >= 2048
+
+
 def apply_gptq_marlin_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
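Note: should_use_atomic_add_reduce above gates the new path. It stays off unless VLLM_MARLIN_USE_ATOMIC_ADD=1 is set, it falls back for bfloat16 on pre-sm90 GPUs, and it only fires when m*n is small and k is large, which is the small-n case this commit targets. A compile-time sketch of the same inequality, with a hypothetical helper name and example shapes chosen purely for illustration:

#include <algorithm>

// Mirrors the Python heuristic: use atomicAdd reduce only when
// max(m, 64) * n < 64 * 2048 (= 131072) and k >= 2048.
constexpr bool use_atomic_add_heuristic(long m, long n, long k) {
  return std::max(m, 64L) * n < 64L * 2048L && k >= 2048L;
}

// Single-token decode against a narrow output: 64 * 1024 = 65536 < 131072 -> atomicAdd.
static_assert(use_atomic_add_heuristic(1, 1024, 4096), "small n, large k -> atomicAdd");
// Wide output keeps the lock-based global reduce: 64 * 4096 = 262144 >= 131072.
static_assert(!use_atomic_add_heuristic(1, 4096, 4096), "large n -> global reduce");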
@@ -307,6 +325,12 @@ def apply_gptq_marlin_linear(
     reshaped_x = input.reshape(-1, input.shape[-1])
     out_shape = input.shape[:-1] + (output_size_per_partition, )

+    use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
+                                                  n=output_size_per_partition,
+                                                  k=reshaped_x.size(1),
+                                                  device=input.device,
+                                                  dtype=input.dtype)
+
     output = ops.gptq_marlin_gemm(reshaped_x,
                                   weight,
                                   weight_scale,
@@ -320,6 +344,7 @@ def apply_gptq_marlin_linear(
                                   size_k=input_size_per_partition,
                                   is_k_full=is_k_full,
                                   has_zp=False,
+                                  use_atomic_add=use_atomic_add,
                                   use_fp32_reduce=use_fp32_reduce,
                                   is_zp_float=False)

@@ -345,6 +370,12 @@ def apply_awq_marlin_linear(
     reshaped_x = input.reshape(-1, input.shape[-1])
     out_shape = input.shape[:-1] + (output_size_per_partition, )

+    use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
+                                                  n=output_size_per_partition,
+                                                  k=reshaped_x.size(1),
+                                                  device=input.device,
+                                                  dtype=input.dtype)
+
     output = ops.gptq_marlin_gemm(reshaped_x,
                                   weight,
                                   weight_scale,
@@ -358,6 +389,7 @@ def apply_awq_marlin_linear(
                                   size_k=input_size_per_partition,
                                   is_k_full=True,
                                   has_zp=True,
+                                  use_atomic_add=use_atomic_add,
                                   use_fp32_reduce=use_fp32_reduce,
                                   is_zp_float=False)
