Add extra punica sizes to support bigger vocabs (#4015)
commit 1e96c3341a
parent 95e7d4a97c
@@ -60,7 +60,17 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, narrow, 33024) \
   f(in_T, out_T, W_T, narrow, 36864) \
   f(in_T, out_T, W_T, narrow, 49152) \
-// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
+  f(in_T, out_T, W_T, narrow, 64000) \
+  f(in_T, out_T, W_T, narrow, 64256) \
+  f(in_T, out_T, W_T, narrow, 64512) \
+  f(in_T, out_T, W_T, narrow, 102400) \
+  f(in_T, out_T, W_T, narrow, 102656) \
+  f(in_T, out_T, W_T, narrow, 102912) \
+  f(in_T, out_T, W_T, narrow, 128000) \
+  f(in_T, out_T, W_T, narrow, 128256) \
+  f(in_T, out_T, W_T, narrow, 128512) \
+// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
+// and vllm/tests/lora/test_punica.py
 
 // Keep this in sync with vllm/config::LoRAConfig
 #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
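The size list above is an X-macro: each f(in_T, out_T, W_T, narrow, wide) entry is expanded later, once per consumer macro, into an explicit kernel instantiation and a matching dispatch case, so every (narrow, wide) shape must be enumerated at compile time. A minimal, self-contained sketch of the pattern with made-up names (not vLLM's macros):

// Illustrative only: a toy version of the X-macro technique used by the
// config header. FOR_SIZES plays the role of the size list above.
#include <cstdio>

#define FOR_SIZES(f) \
  f(33024)           \
  f(64000)           \
  f(128256)

template <int kWide>
void toy_kernel() { std::printf("kernel<%d>\n", kWide); }

// Consumer 1: emit one explicit instantiation per listed size.
#define INSTANTIATE(wide) template void toy_kernel<wide>();
FOR_SIZES(INSTANTIATE)
#undef INSTANTIATE

// Consumer 2: emit one switch case per listed size in a runtime dispatcher.
bool dispatch(int wide) {
  switch (wide) {
#define CASE(wide_)        \
  case wide_:              \
    toy_kernel<wide_>();   \
    return true;
    FOR_SIZES(CASE)
#undef CASE
    default:
      return false;  // size was not compiled in
  }
}

int main() { return dispatch(64000) ? 0 : 1; }

This is why the new vocab-related sizes (64000 through 128512) have to be added to the list rather than handled generically at runtime.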
@@ -20,8 +20,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
   }
 }
 
-inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
-  return (uint32_t(a) << 16) | uint32_t(b);
+inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
+  return (uint64_t(a) << 32) | uint64_t(b);
 }
 
 #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
@@ -46,13 +46,13 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
 template <typename in_T, typename out_T, typename W_T>
 inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
                                const int64_t *lora_indices,
-                               uint16_t in_features, uint16_t out_features,
+                               uint32_t in_features, uint32_t out_features,
                                int64_t y_offset, int64_t full_y_size,
                                int64_t batch_size, int64_t num_layers,
                                int64_t layer_idx, float scale) {
-  switch (pack_u16(in_features, out_features)) {
+  switch (pack_u32(in_features, out_features)) {
 #define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
-  case pack_u16(feat_in, feat_out): \
+  case pack_u32(feat_in, feat_out): \
     bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
                                    full_y_size, batch_size, num_layers, \
                                    layer_idx, scale); \
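pack_u16 packed the two feature dimensions into a single 32-bit switch key, which only works while both dimensions fit in 16 bits; the new vocab-sized dimensions (up to 128512) exceed UINT16_MAX, so the key is widened to the 64-bit pack_u32. A small standalone check of the two packing schemes; only the two pack functions mirror the diff, the static_asserts and main() harness are illustrative additions:

#include <cstdint>
#include <cstdio>

// Old scheme: two 16-bit dims in a 32-bit key -- cannot represent >= 65536.
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
  return (uint32_t(a) << 16) | uint32_t(b);
}

// New scheme: two 32-bit dims in a 64-bit key -- room for 128512 and beyond.
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
  return (uint64_t(a) << 32) | uint64_t(b);
}

// 128512 does not fit in uint16_t, so the old key would have to truncate it
// and would collide with a smaller size; the 64-bit key keeps them distinct.
static_assert(128512u > UINT16_MAX, "new sizes exceed the 16-bit range");
static_assert(pack_u16(16, uint16_t(128512)) == pack_u16(16, 62976),
              "old 16-bit key collides after truncation");
static_assert(pack_u32(16, 128512) != pack_u32(16, 62976),
              "64-bit key keeps large dims distinct");

int main() {
  std::printf("pack_u32(16, 128512) = 0x%llx\n",
              (unsigned long long)pack_u32(16, 128512));
}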
@@ -93,7 +93,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   CHECK_EQ(y.size(0), x.size(0));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   bool ok = false;
-  if (h_in < 65536 && h_out < 65536) {
+  if (h_in <= 128512 && h_out <= 128512) {
     // TODO: See if we can get rid of this massive nested switch
     switch (x.scalar_type()) {
       case at::ScalarType::Half:
@@ -325,7 +325,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   CHECK_EQ(y.size(0), x.size(0));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   bool ok = false;
-  if (h_in < 65536 && h_out < 65536) {
+  if (h_in <= 128512 && h_out <= 128512) {
     // TODO: See if we can get rid of this massive nested switch
     switch (x.scalar_type()) {
       case at::ScalarType::Half:
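Both dispatchers previously capped h_in and h_out at 65536, matching the implicit uint16_t limit of the old key; the new bound of 128512 is simply the largest feature size instantiated in the config header. A hedged sketch of keeping that bound in one named constant so the guard and the size list cannot drift apart (the constant and helper are made up, not vLLM's code):

#include <cstdint>

// Hypothetical shared limit: the largest `wide` size listed in
// csrc/punica/bgmv/bgmv_config.h.
constexpr int64_t kMaxBgmvFeatureSize = 128512;

// Guard used before entering the nested dtype/size switch.
inline bool bgmv_shapes_supported(int64_t h_in, int64_t h_out) {
  return h_in <= kMaxBgmvFeatureSize && h_out <= kMaxBgmvFeatureSize;
}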
@@ -170,7 +170,8 @@ def create_random_inputs(
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_embeddings(dist_init, num_loras, device) -> None:
+@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
+def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
 
     torch.set_default_device(device)
     max_loras = 8
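Stacked pytest.mark.parametrize decorators take the cross product of their value lists, so adding the vocab_size axis multiplies the existing num_loras x device matrix by four. A toy example of the same mechanism (not vLLM's test):

import pytest


@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_parametrize_cross_product(num_loras, vocab_size):
    # pytest generates 4 x 4 = 16 cases, one per (vocab_size, num_loras) pair.
    assert num_loras > 0
    assert vocab_size >= 512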
@@ -179,9 +180,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
                              lora_dtype=torch.float16)
 
     def create_random_embedding_layer():
-        embedding = VocabParallelEmbedding(512, 256)
+        embedding = VocabParallelEmbedding(vocab_size, 256)
         embedding.weight.data = torch.rand_like(embedding.weight.data)
-        embedding.weight.data[512:, :] = 0
+        embedding.weight.data[vocab_size:, :] = 0
         lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
         lora_embedding.create_lora_weights(max_loras, lora_config)
 
@@ -203,12 +204,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
         active_lora_ids=list(lora_dict.keys()),
         num_inputs=num_loras * 3,
         input_size=(200, ),
-        input_range=(1, 512),
+        input_range=(1, vocab_size),
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   512, lora_config.lora_extra_vocab_size)
+                                   vocab_size,
+                                   lora_config.lora_extra_vocab_size)
     lora_embedding.set_mapping(*mapping_info)
 
     lora_result = lora_embedding(torch.cat(inputs))
@@ -240,12 +242,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
         active_lora_ids=[0],
         num_inputs=num_loras * 3,
         input_size=(200, ),
-        input_range=(1, 512),
+        input_range=(1, vocab_size),
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   512, lora_config.lora_extra_vocab_size)
+                                   vocab_size,
+                                   lora_config.lora_extra_vocab_size)
     lora_embedding.set_mapping(*mapping_info, )
 
     lora_result = lora_embedding(torch.cat(inputs))
@@ -263,7 +266,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
 #                    reason="Fails when loras are in any slot other than the first.")
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
+@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
+def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
+                                        vocab_size) -> None:
 
     torch.set_default_device(device)
     max_loras = 8
@@ -272,15 +277,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
                              lora_dtype=torch.float16)
 
     def create_random_embedding_layer():
-        embedding = VocabParallelEmbedding(512, 256)
+        embedding = VocabParallelEmbedding(vocab_size, 256)
         embedding_data = torch.rand_like(embedding.weight.data)
         embedding.weight.data = embedding_data
-        embedding.weight.data[512:, :] = 0
+        embedding.weight.data[vocab_size:, :] = 0
         expanded_embedding = VocabParallelEmbedding(
-            512 + lora_config.lora_extra_vocab_size * max_loras,
+            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
             256,
-            org_num_embeddings=512)
-        expanded_embedding.weight.data[:512, :] = embedding_data
+            org_num_embeddings=vocab_size)
+        expanded_embedding.weight.data[:vocab_size, :] = embedding_data
         # We need to deepcopy the embedding as it will be modified
         # in place
         lora_embedding = VocabParallelEmbeddingWithLoRA(
@@ -298,7 +303,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
         id_to_index,
         layer=lora_embedding,
         layer_weights=torch.zeros(
-            (256, 512 + lora_config.lora_extra_vocab_size)),
+            (256, vocab_size + lora_config.lora_extra_vocab_size)),
         generate_embeddings_tensor=256,
     )
 
@@ -316,7 +321,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
         active_lora_ids=list(lora_dict.keys()),
         num_inputs=num_loras * 3,
         input_size=(200, ),
-        input_range=(1, 512),
+        input_range=(1, vocab_size),
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
@@ -327,16 +332,18 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
     for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                 prompt_mapping):
         embedding_id = lora_id - 1
-        input_[-1] = 512 + (embedding_id * embeddings_tensor_len)
-        original_input_[-1] = 512
-        input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1)
-        original_input_[-2] = 512 + embeddings_tensor_len - 1
+        input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
+        original_input_[-1] = vocab_size
+        input_[-2] = vocab_size + (
+            (embedding_id + 1) * embeddings_tensor_len - 1)
+        original_input_[-2] = vocab_size + embeddings_tensor_len - 1
 
     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   512, lora_config.lora_extra_vocab_size)
+                                   vocab_size,
+                                   lora_config.lora_extra_vocab_size)
     lora_embedding.set_mapping(*mapping_info, )
 
-    expanded_embedding.weight[512:512 +
+    expanded_embedding.weight[vocab_size:vocab_size +
                               (embeddings_tensor_len *
                                max_loras)] = torch.cat(embeddings_tensors)
 
@@ -370,14 +377,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
         active_lora_ids=[0],
         num_inputs=num_loras * 3,
         input_size=(200, ),
-        input_range=(1, 512),
+        input_range=(1, vocab_size),
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
     original_inputs = deepcopy(inputs)
 
     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   512, lora_config.lora_extra_vocab_size)
+                                   vocab_size,
+                                   lora_config.lora_extra_vocab_size)
     lora_embedding.set_mapping(*mapping_info, )
 
     lora_result = lora_embedding(torch.cat(original_inputs))
@@ -393,7 +401,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
+@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
+def test_lm_head_logits_processor(dist_init, num_loras, device,
+                                  vocab_size) -> None:
 
     torch.set_default_device(device)
     max_loras = 8
@@ -402,12 +412,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
                              lora_dtype=torch.float16)
 
     def _pretest():
-        linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
-                                1024, 32000)
+        linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
+                                1024, vocab_size)
         linear.weight.data = torch.rand_like(linear.weight.data)
-        linear.weight.data[:, 32000:] = 0
+        linear.weight.data[:, vocab_size:] = 0
         logits_processor = LogitsProcessor(
-            32000 + lora_config.lora_extra_vocab_size, 32000)
+            vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
         lora_logits_processor = LogitsProcessorWithLoRA(
             logits_processor, 1024, linear.weight.dtype, linear.weight.device)
         lora_logits_processor.create_lora_weights(max_loras, lora_config)
@@ -444,7 +454,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
         lora_mapping,
         id_to_index,
         max_loras,
-        32000,
+        vocab_size,
         lora_config.lora_extra_vocab_size,
     )
     lora_logits_processor.set_mapping(*mapping_info, )
@@ -460,7 +470,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
                       org_vocab_size:logits_processor.org_vocab_size +
                       embeddings_tensor_len] = embeddings_tensor
 
-    logits_processor.org_vocab_size = (32000 +
+    logits_processor.org_vocab_size = (vocab_size +
                                        lora_config.lora_extra_vocab_size)
     expected_results = []
     for input_, lora_id in zip(inputs, prompt_mapping):
@@ -468,11 +478,11 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
         result = logits_processor._get_logits(hidden_states=input_,
                                               embedding=linear.weight,
                                               embedding_bias=None)
-        result[:, 32000 + embeddings_tensor_len:] = float("-inf")
+        result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
         result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
         expected_results.append(result)
     expected_result = torch.cat(expected_results)
-    logits_processor.org_vocab_size = 32000
+    logits_processor.org_vocab_size = vocab_size
 
     # Check that resetting the lora weights succeeds
 
@@ -489,14 +499,14 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   32000,
+                                   vocab_size,
                                    lora_config.lora_extra_vocab_size)
     lora_logits_processor.set_mapping(*mapping_info, )
 
     lora_result = lora_logits_processor._get_logits(
         hidden_states=torch.cat(inputs),
         embedding=original_weight,
-        embedding_bias=None)[:, :32000]
+        embedding_bias=None)[:, :vocab_size]
     expected_result = logits_processor._get_logits(
         hidden_states=torch.cat(inputs),
         embedding=original_weight,
@@ -43,10 +43,51 @@ def _lora_ref_impl(
 
 
 H1 = H2 = [
-    128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456,
-    3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216,
-    10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512,
-    32768, 33024
+    128,
+    256,
+    512,
+    1024,
+    1152,
+    1280,
+    1536,
+    2048,
+    2304,
+    2560,
+    2752,
+    3072,
+    3456,
+    3584,
+    4096,
+    4608,
+    5120,
+    5504,
+    5632,
+    6144,
+    6848,
+    6912,
+    7168,
+    8192,
+    9216,
+    10240,
+    11008,
+    13824,
+    14336,
+    22016,
+    24576,
+    27392,
+    32000,
+    32256,
+    32512,
+    32768,
+    33024,
+    36864,
+    49152,
+    64000,
+    64256,
+    102400,
+    102656,
+    128000,
+    128256,
 ]
 SEED = [0xabcdabcd987]
 CUDA_DEVICES = [
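This H1/H2 list drives the dimensions exercised by the punica kernel test and, per the comment added in the config header, is meant to stay in sync with the f(..., narrow, SIZE) instantiations there. A hypothetical consistency check one could run from the repository root; the regex, path handling, and abbreviated list are illustrative, not part of the commit:

import re
from pathlib import Path

# Abbreviated stand-in for the H1/H2 list above.
H1 = [128, 256, 512, 33024, 64000, 128256]


def kernel_sizes(config_header: Path) -> set:
    """Collect every `wide` size instantiated in the bgmv config header."""
    text = config_header.read_text()
    return {int(m) for m in re.findall(r"narrow,\s*(\d+)\)", text)}


if __name__ == "__main__":
    sizes = kernel_sizes(Path("csrc/punica/bgmv/bgmv_config.h"))
    missing = [h for h in H1 if h not in sizes]
    assert not missing, f"test sizes without a compiled kernel: {missing}"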
@@ -935,9 +935,9 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
         # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
-        if 32000 < self.base_layer.vocab_size > 33024:
+        if 32000 < self.base_layer.vocab_size > 128512:
             raise ValueError("When using LoRA, vocab size must be "
-                             "32000 >= vocab_size <= 33024")
+                             "32000 >= vocab_size <= 128512")
         self.lora_a_stacked = torch.zeros(
             (
                 max_loras,
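The check above is intended to keep the base vocab size inside the range the punica kernels are now compiled for, 32000 through 128512. Note that the chained comparison 32000 < x > 128512 evaluates in Python as (32000 < x) and (x > 128512), so only the upper bound actually triggers the error. A minimal sketch of an explicit range check (illustrative, not vLLM's code):

# Illustrative bounds check; the limits mirror the sizes instantiated in
# csrc/punica/bgmv/bgmv_config.h.
MIN_LORA_VOCAB_SIZE = 32000
MAX_LORA_VOCAB_SIZE = 128512


def check_lora_vocab_size(vocab_size: int) -> None:
    if not MIN_LORA_VOCAB_SIZE <= vocab_size <= MAX_LORA_VOCAB_SIZE:
        raise ValueError(
            f"When using LoRA, vocab size must be between "
            f"{MIN_LORA_VOCAB_SIZE} and {MAX_LORA_VOCAB_SIZE}, "
            f"got {vocab_size}.")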