Add LoRA support for Gemma (#3050)

commit 929b4f2973 (parent 3b7178cfa4)
@@ -50,7 +50,7 @@ steps:
   command: pytest -v -s worker
 
 - label: LoRA Test
-  command: pytest -v -s lora
+  command: pytest -v -s lora --forked
 
 - label: Metrics Test
   command: pytest -v -s metrics
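Running the LoRA suite under --forked gives every test its own process, so CUDA state left behind by one test cannot leak into the next. A minimal local equivalent of the updated CI step, assuming pytest plus the pytest-forked plugin are installed and the command is issued from the repository's tests/ directory:

# Local stand-in for the CI step above (sketch only). Assumes pytest and
# pytest-forked (which provides --forked) are installed and that this
# script is run from the tests/ directory.
import sys

import pytest

sys.exit(pytest.main(["-v", "-s", "lora", "--forked"]))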
@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 5120) \
     f(in_T, out_T, W_T, narrow, 5504) \
     f(in_T, out_T, W_T, narrow, 5632) \
+    f(in_T, out_T, W_T, narrow, 6144) \
     f(in_T, out_T, W_T, narrow, 6912) \
     f(in_T, out_T, W_T, narrow, 7168) \
     f(in_T, out_T, W_T, narrow, 8192) \
@@ -39,6 +40,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 14336) \
     f(in_T, out_T, W_T, narrow, 16384) \
     f(in_T, out_T, W_T, narrow, 20480) \
+    f(in_T, out_T, W_T, narrow, 24576) \
     f(in_T, out_T, W_T, narrow, 28672) \
     f(in_T, out_T, W_T, narrow, 32000) \
     f(in_T, out_T, W_T, narrow, 32256) \
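The BGMV kernels are only instantiated for an explicit list of weight dimensions, so Gemma's shapes have to be registered here; the new 24576 entry matches Gemma-7B's MLP intermediate size. A hedged sketch for checking a model's dimensions, assuming a transformers release with Gemma support and access to the gated google/gemma-7b repository:

# Sketch only: inspect Gemma-7B's configuration to see which projection sizes
# the Punica kernels need to cover. Requires Gemma support in transformers and
# access to the gated google/gemma-7b repository.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("google/gemma-7b")
print(cfg.hidden_size)        # expected: 3072
print(cfg.intermediate_size)  # expected: 24576 (one of the sizes added above)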
@@ -126,6 +126,11 @@ def mixtral_lora_files():
     return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")
 
 
+@pytest.fixture(scope="session")
+def gemma_lora_files():
+    return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
+
+
 @pytest.fixture
 def llama_2_7b_engine_extra_embeddings() -> nn.Module:
     cleanup()
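The new session-scoped fixture downloads the Gemma test adapter once per pytest session. The same adapter can be fetched outside pytest; a minimal sketch, assuming huggingface_hub is installed and the repo remains publicly available:

# Standalone sketch (not part of the commit): fetch the adapter that the
# gemma_lora_files fixture downloads. Assumes huggingface_hub is installed.
from huggingface_hub import snapshot_download

lora_path = snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
print(lora_path)  # local cache directory containing the LoRA adapter weights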
tests/lora/test_gemma.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+import vllm
+from vllm.lora.request import LoRARequest
+
+MODEL_PATH = "google/gemma-7b"
+
+
+def do_sample(llm, lora_path: str, lora_id: int) -> str:
+    prompts = [
+        "Quote: Imagination is",
+        "Quote: Be yourself;",
+        "Quote: So many books,",
+    ]
+    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+        if lora_id else None)
+    # Print the outputs.
+    generated_texts = []
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text.strip()
+        generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    return generated_texts
+
+
+def test_gemma_lora(gemma_lora_files):
+    llm = vllm.LLM(MODEL_PATH,
+                   max_model_len=1024,
+                   enable_lora=True,
+                   max_loras=4)
+
+    expected_lora_output = [
+        "more important than knowledge.\nAuthor: Albert Einstein\n",
+        "everyone else is already taken.\nAuthor: Oscar Wilde\n",
+        "so little time\nAuthor: Frank Zappa\n",
+    ]
+
+    output1 = do_sample(llm, gemma_lora_files, lora_id=1)
+    for i in range(len(expected_lora_output)):
+        assert output1[i].startswith(expected_lora_output[i])
+    output2 = do_sample(llm, gemma_lora_files, lora_id=2)
+    for i in range(len(expected_lora_output)):
+        assert output2[i].startswith(expected_lora_output[i])
@@ -44,8 +44,8 @@ def _lora_ref_impl(
 
 H1 = H2 = [
     128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120,
-    5504, 5632, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000,
-    32256, 32512, 32768, 33024
+    5504, 5632, 6144, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336,
+    24576, 32000, 32256, 32512, 32768, 33024
 ]
 SEED = [0xabcdabcd987]
 
@@ -20,6 +20,7 @@ import torch
 from torch import nn
 from transformers import GemmaConfig
 
+from vllm.config import LoRAConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.attention import PagedAttention
@@ -246,12 +247,36 @@ class GemmaModel(nn.Module):
 
 
 class GemmaForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    # Gemma does not apply LoRA to the embedding layer.
+    embedding_modules = {}
+    embedding_padding_modules = []
 
     def __init__(
         self,
         config: GemmaConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
+        del lora_config  # Unused.
         super().__init__()
         self.config = config
         self.linear_method = linear_method
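These class-level attributes are what vLLM's multi-LoRA machinery reads: packed_modules_mapping maps adapters trained against the separate q/k/v and gate/up projections onto the fused qkv_proj and gate_up_proj layers, supported_lora_modules lists the layers that may carry LoRA weights, and the empty embedding_modules reflects that Gemma applies no LoRA to its embeddings. With this in place, Gemma behaves like any other LoRA-enabled model; a minimal sketch mirroring tests/lora/test_gemma.py from this commit (the adapter path is a placeholder):

# Sketch based on tests/lora/test_gemma.py above; the adapter path below is a
# placeholder, not a real checkpoint shipped with vLLM.
import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM("google/gemma-7b",
               max_model_len=1024,
               enable_lora=True,
               max_loras=4)
outputs = llm.generate(
    ["Quote: Imagination is"],
    vllm.SamplingParams(temperature=0, max_tokens=32),
    lora_request=LoRARequest("gemma-lora", 1, "/path/to/gemma-lora-adapter"),
)
print(outputs[0].outputs[0].text)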
@@ -305,9 +330,6 @@ class GemmaForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra layer for lora models.
-                if "lm_head" in name:
-                    continue
                 # GemmaRMSNorm is different from Llama's in that it multiplies
                 # (1 + weight) to the output, instead of just weight.
                 if "norm.weight" in name:
|
@ -27,6 +27,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
|
from vllm.config import LoRAConfig
|
||||||
from vllm.model_executor.input_metadata import InputMetadata
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.attention import PagedAttention
|
from vllm.model_executor.layers.attention import PagedAttention
|
||||||
@@ -45,7 +46,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
-from vllm.config import LoRAConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 