[Misc] Reduce supported Punica dtypes (#4304)
parent e4bf860a54
commit 468d761b32
@@ -212,23 +212,11 @@ define_gpu_extension_target(
 
 set(VLLM_PUNICA_EXT_SRC
   "csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
-  "csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu"
-  "csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu"
-  "csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu"
   "csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
-  "csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu"
-  "csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu"
-  "csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu"
-  "csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu"
   "csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
-  "csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu"
   "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
   "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
-  "csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu"
-  "csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu"
   "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
-  "csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu"
-  "csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu"
   "csrc/punica/punica_ops.cc")
 
 #
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
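Each deleted .cu file above was a four-line stub instantiating one input/output/weight dtype combination; the surviving stubs have the same shape. A minimal sketch of how such a stub can be rendered from the generator's template follows. The exact TEMPLATE body and the weight_dtype field are assumptions (only the input_dtype/output_dtype format fields are visible in the generator hunk below); the DTYPE_MAP values are inferred from the file names versus the macro arguments.

    # Sketch only: render one bgmv instantiation stub.
    TEMPLATE = """\
    #include "bgmv_config.h"
    #include "bgmv_impl.cuh"

    FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
    """

    # Short dtype names -> CUDA type names, inferred from the bgmv_* file names.
    DTYPE_MAP = {"bf16": "nv_bfloat16", "fp16": "nv_half", "fp32": "float"}

    stub = TEMPLATE.format(input_dtype=DTYPE_MAP["fp16"],
                           output_dtype=DTYPE_MAP["fp16"],
                           weight_dtype=DTYPE_MAP["fp16"])
    print(stub)  # same content as the retained bgmv_fp16_fp16_fp16.cu stub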
@@ -18,6 +18,26 @@ for input_dtype in DTYPES:
             if weight_dtype == "fp32":
                 # FP32 weights are not supported.
                 continue
+            if output_dtype == "fp32":
+                # LoRA A matrix.
+                if input_dtype != weight_dtype:
+                    # NOTE(woosuk): While Punica supports the case where the
+                    # input and weight dtypes are different, we only generate
+                    # the kernels with the same dtypes to reduce the binary
+                    # size.
+                    continue
+            elif input_dtype == "fp32":
+                # LoRA B matrix.
+                if output_dtype != weight_dtype:
+                    # NOTE(woosuk): While Punica supports the case where the
+                    # output and weight dtypes are different, we only generate
+                    # the kernels with the same dtypes to reduce the binary
+                    # size.
+                    continue
+            elif not (input_dtype == output_dtype == weight_dtype):
+                # NOTE(woosuk): While Punica supports mixed data types for
+                # input, output, and weight, we only generate the kernels with
+                # the same data types to reduce the binary size.
+                continue
+
             kernel_definition = TEMPLATE.format(
                 input_dtype=DTYPE_MAP[input_dtype],
                 output_dtype=DTYPE_MAP[output_dtype],
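Taken together, the new checks emit a kernel only when the weight dtype is fp16 or bf16 and the non-fp32 side matches the weight dtype (per the comments above, output fp32 corresponds to the LoRA A matrix, input fp32 to the LoRA B matrix). A standalone sketch of that rule, using the bgmv_<input>_<output>_<weight>.cu naming inferred from the source list above (the DTYPES/WEIGHT_DTYPES constants here are assumed), reproduces exactly the six files kept in the CMake hunk:

    # Sketch only: enumerate the dtype combinations the generator now emits.
    DTYPES = ["bf16", "fp16", "fp32"]
    WEIGHT_DTYPES = ["bf16", "fp16", "fp32"]

    kept = []
    for input_dtype in DTYPES:
        for output_dtype in DTYPES:
            for weight_dtype in WEIGHT_DTYPES:
                if weight_dtype == "fp32":
                    continue  # FP32 weights are not supported.
                if output_dtype == "fp32":
                    if input_dtype != weight_dtype:
                        continue  # LoRA A: input must match weight.
                elif input_dtype == "fp32":
                    if output_dtype != weight_dtype:
                        continue  # LoRA B: output must match weight.
                elif not (input_dtype == output_dtype == weight_dtype):
                    continue  # Otherwise all three dtypes must match.
                kept.append(f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu")

    print(kept)
    # ['bgmv_bf16_bf16_bf16.cu', 'bgmv_bf16_fp32_bf16.cu', 'bgmv_fp16_fp16_fp16.cu',
    #  'bgmv_fp16_fp32_fp16.cu', 'bgmv_fp32_bf16_bf16.cu', 'bgmv_fp32_fp16_fp16.cu']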
@@ -50,6 +50,23 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
                                int64_t y_offset, int64_t full_y_size,
                                int64_t batch_size, int64_t num_layers,
                                int64_t layer_idx, float scale) {
+  // NOTE(woosuk): While Punica supports various combinations of input/output
+  // data types, we limit the supported data types to reduce the binary size.
+  constexpr bool is_input_float = std::is_same<in_T, float>::value;
+  constexpr bool is_output_float = std::is_same<out_T, float>::value;
+  if (is_input_float) {
+    if (!std::is_same<out_T, W_T>::value) {
+      return false;
+    }
+  } else if (is_output_float) {
+    if (!std::is_same<in_T, W_T>::value) {
+      return false;
+    }
+  } else if (!(std::is_same<in_T, W_T>::value &&
+               std::is_same<out_T, W_T>::value)) {
+    return false;
+  }
+
   switch (pack_u32(in_features, out_features)) {
 #define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
   case pack_u32(feat_in, feat_out):                          \
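After the new type guard, dispatch still proceeds by switching on a single key packed from the two feature dimensions (the CASE_ONESIDE expansions above). A rough sketch of that keying scheme, assuming pack_u32 places its first 32-bit argument in the high word; the shape table below is purely hypothetical and stands in for the shapes instantiated by the .cu stubs:

    # Sketch only: dispatch on a packed (in_features, out_features) key.
    def pack_u32(a: int, b: int) -> int:
        # Assumption: first argument in the high 32 bits, second in the low 32 bits.
        assert 0 <= a < 2**32 and 0 <= b < 2**32
        return (a << 32) | b

    # Hypothetical shape table; in the extension it comes from the
    # FOR_BGMV_WIDE_NARROW / CASE_ONESIDE expansions.
    SUPPORTED_SHAPES = {pack_u32(16, 4096), pack_u32(4096, 16)}

    def can_dispatch(in_features: int, out_features: int) -> bool:
        return pack_u32(in_features, out_features) in SUPPORTED_SHAPES

    print(can_dispatch(16, 4096))    # True
    print(can_dispatch(4096, 1024))  # False for this hypothetical table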
@@ -413,7 +413,9 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
 
     def _pretest():
         linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
-                                1024, vocab_size)
+                                1024,
+                                vocab_size,
+                                params_dtype=torch.float16)
         linear.weight.data = torch.rand_like(linear.weight.data)
         linear.weight.data[:, vocab_size:] = 0
         logits_processor = LogitsProcessor(
@@ -445,7 +447,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
         num_inputs=8 * num_loras, # * 3,
         input_size=(1, 1024),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
@@ -494,7 +496,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
         num_inputs=8 * num_loras * 3,
         input_size=(1, 1024),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
@@ -533,11 +535,17 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
 
     def create_random_linear_parallel_layer():
         if orientation == "row":
-            linear = RowParallelLinear(4096, 4096, bias=False)
+            linear = RowParallelLinear(4096,
+                                       4096,
+                                       bias=False,
+                                       params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
             lora_linear = RowParallelLinearWithLoRA(linear)
         else:
-            linear = ColumnParallelLinear(4096, 4096, bias=False)
+            linear = ColumnParallelLinear(4096,
+                                          4096,
+                                          bias=False,
+                                          params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
             lora_linear = ColumnParallelLinearWithLoRA(linear)
         lora_linear.create_lora_weights(max_loras, lora_config)
@@ -561,7 +569,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
         num_inputs=32 * num_loras,
         input_size=(1, 4096),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
@@ -600,7 +608,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
         num_inputs=32 * num_loras,
         input_size=(1, 4096),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
@@ -633,15 +641,24 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
     def create_column_parallel_packed_layer():
         if repeats == 2:
             linear = MergedColumnParallelLinear(4096, [4096] * repeats,
-                                                bias=False)
+                                                bias=False,
+                                                params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
             lora_linear = MergedColumnParallelLinearWithLoRA(linear)
         elif repeats == 3:
-            linear = QKVParallelLinear(4096, 64, 32, bias=False)
+            linear = QKVParallelLinear(4096,
+                                       64,
+                                       32,
+                                       bias=False,
+                                       params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
             lora_linear = MergedQKVParallelLinearWithLora(linear)
         else:
-            linear = QKVParallelLinear(4096, 64, 32, bias=False)
+            linear = QKVParallelLinear(4096,
+                                       64,
+                                       32,
+                                       bias=False,
+                                       params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
             lora_linear = QKVParallelLinearWithLora(linear)
 
@@ -676,7 +693,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
         num_inputs=32 * num_loras,
         input_size=(1, 4096),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
@@ -716,7 +733,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
         num_inputs=32 * num_loras,
         input_size=(1, 4096),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 