Add extra punica sizes to support bigger vocabs (#4015)

This commit is contained in:
Antoni Baum 2024-04-11 15:18:57 -07:00 committed by GitHub
parent 95e7d4a97c
commit 1e96c3341a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 109 additions and 48 deletions

View File

@ -60,7 +60,17 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 49152) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
f(in_T, out_T, W_T, narrow, 64000) \
f(in_T, out_T, W_T, narrow, 64256) \
f(in_T, out_T, W_T, narrow, 64512) \
f(in_T, out_T, W_T, narrow, 102400) \
f(in_T, out_T, W_T, narrow, 102656) \
f(in_T, out_T, W_T, narrow, 102912) \
f(in_T, out_T, W_T, narrow, 128000) \
f(in_T, out_T, W_T, narrow, 128256) \
f(in_T, out_T, W_T, narrow, 128512) \
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \

View File

@ -20,8 +20,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
}
}
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
return (uint32_t(a) << 16) | uint32_t(b);
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
return (uint64_t(a) << 32) | uint64_t(b);
}
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
@ -46,13 +46,13 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
template <typename in_T, typename out_T, typename W_T>
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
const int64_t *lora_indices,
uint16_t in_features, uint16_t out_features,
uint32_t in_features, uint32_t out_features,
int64_t y_offset, int64_t full_y_size,
int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
switch (pack_u16(in_features, out_features)) {
switch (pack_u32(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
case pack_u16(feat_in, feat_out): \
case pack_u32(feat_in, feat_out): \
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
full_y_size, batch_size, num_layers, \
layer_idx, scale); \
@ -93,7 +93,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
CHECK_EQ(y.size(0), x.size(0));
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
bool ok = false;
if (h_in < 65536 && h_out < 65536) {
if (h_in <= 128512 && h_out <= 128512) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
@ -325,7 +325,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
CHECK_EQ(y.size(0), x.size(0));
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
bool ok = false;
if (h_in < 65536 && h_out < 65536) {
if (h_in <= 128512 && h_out <= 128512) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:

View File

@ -170,7 +170,8 @@ def create_random_inputs(
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_embeddings(dist_init, num_loras, device) -> None:
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
torch.set_default_device(device)
max_loras = 8
@ -179,9 +180,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
lora_dtype=torch.float16)
def create_random_embedding_layer():
embedding = VocabParallelEmbedding(512, 256)
embedding = VocabParallelEmbedding(vocab_size, 256)
embedding.weight.data = torch.rand_like(embedding.weight.data)
embedding.weight.data[512:, :] = 0
embedding.weight.data[vocab_size:, :] = 0
lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
lora_embedding.create_lora_weights(max_loras, lora_config)
@ -203,12 +204,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
active_lora_ids=list(lora_dict.keys()),
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, 512),
input_range=(1, vocab_size),
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size)
vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info)
lora_result = lora_embedding(torch.cat(inputs))
@ -240,12 +242,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
active_lora_ids=[0],
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, 512),
input_range=(1, vocab_size),
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size)
vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )
lora_result = lora_embedding(torch.cat(inputs))
@ -263,7 +266,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
# reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
vocab_size) -> None:
torch.set_default_device(device)
max_loras = 8
@ -272,15 +277,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
lora_dtype=torch.float16)
def create_random_embedding_layer():
embedding = VocabParallelEmbedding(512, 256)
embedding = VocabParallelEmbedding(vocab_size, 256)
embedding_data = torch.rand_like(embedding.weight.data)
embedding.weight.data = embedding_data
embedding.weight.data[512:, :] = 0
embedding.weight.data[vocab_size:, :] = 0
expanded_embedding = VocabParallelEmbedding(
512 + lora_config.lora_extra_vocab_size * max_loras,
vocab_size + lora_config.lora_extra_vocab_size * max_loras,
256,
org_num_embeddings=512)
expanded_embedding.weight.data[:512, :] = embedding_data
org_num_embeddings=vocab_size)
expanded_embedding.weight.data[:vocab_size, :] = embedding_data
# We need to deepcopy the embedding as it will be modified
# in place
lora_embedding = VocabParallelEmbeddingWithLoRA(
@ -298,7 +303,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
id_to_index,
layer=lora_embedding,
layer_weights=torch.zeros(
(256, 512 + lora_config.lora_extra_vocab_size)),
(256, vocab_size + lora_config.lora_extra_vocab_size)),
generate_embeddings_tensor=256,
)
@ -316,7 +321,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
active_lora_ids=list(lora_dict.keys()),
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, 512),
input_range=(1, vocab_size),
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
@ -327,16 +332,18 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
for input_, original_input_, lora_id in zip(inputs, original_inputs,
prompt_mapping):
embedding_id = lora_id - 1
input_[-1] = 512 + (embedding_id * embeddings_tensor_len)
original_input_[-1] = 512
input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1)
original_input_[-2] = 512 + embeddings_tensor_len - 1
input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
original_input_[-1] = vocab_size
input_[-2] = vocab_size + (
(embedding_id + 1) * embeddings_tensor_len - 1)
original_input_[-2] = vocab_size + embeddings_tensor_len - 1
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size)
vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )
expanded_embedding.weight[512:512 +
expanded_embedding.weight[vocab_size:vocab_size +
(embeddings_tensor_len *
max_loras)] = torch.cat(embeddings_tensors)
@ -370,14 +377,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
active_lora_ids=[0],
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, 512),
input_range=(1, vocab_size),
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
original_inputs = deepcopy(inputs)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size)
vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )
lora_result = lora_embedding(torch.cat(original_inputs))
@ -393,7 +401,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_lm_head_logits_processor(dist_init, num_loras, device,
vocab_size) -> None:
torch.set_default_device(device)
max_loras = 8
@ -402,12 +412,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
lora_dtype=torch.float16)
def _pretest():
linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
1024, 32000)
linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
1024, vocab_size)
linear.weight.data = torch.rand_like(linear.weight.data)
linear.weight.data[:, 32000:] = 0
linear.weight.data[:, vocab_size:] = 0
logits_processor = LogitsProcessor(
32000 + lora_config.lora_extra_vocab_size, 32000)
vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
lora_logits_processor = LogitsProcessorWithLoRA(
logits_processor, 1024, linear.weight.dtype, linear.weight.device)
lora_logits_processor.create_lora_weights(max_loras, lora_config)
@ -444,7 +454,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
lora_mapping,
id_to_index,
max_loras,
32000,
vocab_size,
lora_config.lora_extra_vocab_size,
)
lora_logits_processor.set_mapping(*mapping_info, )
@ -460,7 +470,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
org_vocab_size:logits_processor.org_vocab_size +
embeddings_tensor_len] = embeddings_tensor
logits_processor.org_vocab_size = (32000 +
logits_processor.org_vocab_size = (vocab_size +
lora_config.lora_extra_vocab_size)
expected_results = []
for input_, lora_id in zip(inputs, prompt_mapping):
@ -468,11 +478,11 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
result = logits_processor._get_logits(hidden_states=input_,
embedding=linear.weight,
embedding_bias=None)
result[:, 32000 + embeddings_tensor_len:] = float("-inf")
result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
logits_processor.org_vocab_size = 32000
logits_processor.org_vocab_size = vocab_size
# Check that resetting the lora weights succeeds
@ -489,14 +499,14 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
32000,
vocab_size,
lora_config.lora_extra_vocab_size)
lora_logits_processor.set_mapping(*mapping_info, )
lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=original_weight,
embedding_bias=None)[:, :32000]
embedding_bias=None)[:, :vocab_size]
expected_result = logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=original_weight,

View File

@ -43,10 +43,51 @@ def _lora_ref_impl(
H1 = H2 = [
128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456,
3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216,
10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512,
32768, 33024
128,
256,
512,
1024,
1152,
1280,
1536,
2048,
2304,
2560,
2752,
3072,
3456,
3584,
4096,
4608,
5120,
5504,
5632,
6144,
6848,
6912,
7168,
8192,
9216,
10240,
11008,
13824,
14336,
22016,
24576,
27392,
32000,
32256,
32512,
32768,
33024,
36864,
49152,
64000,
64256,
102400,
102656,
128000,
128256,
]
SEED = [0xabcdabcd987]
CUDA_DEVICES = [

View File

@ -935,9 +935,9 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config: Optional[PretrainedConfig] = None,
) -> None:
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
if 32000 < self.base_layer.vocab_size > 33024:
if 32000 < self.base_layer.vocab_size > 128512:
raise ValueError("When using LoRA, vocab size must be "
"32000 >= vocab_size <= 33024")
"32000 >= vocab_size <= 128512")
self.lora_a_stacked = torch.zeros(
(
max_loras,