Add LoRA support for Gemma (#3050)

This commit is contained in:
Woosuk Kwon 2024-02-28 13:03:28 -08:00 committed by GitHub
parent 3b7178cfa4
commit 929b4f2973
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 82 additions and 7 deletions

View File

@ -50,7 +50,7 @@ steps:
command: pytest -v -s worker
- label: LoRA Test
command: pytest -v -s lora
command: pytest -v -s lora --forked
- label: Metrics Test
command: pytest -v -s metrics

View File

@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 5120) \
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \
f(in_T, out_T, W_T, narrow, 8192) \
@ -39,6 +40,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 14336) \
f(in_T, out_T, W_T, narrow, 16384) \
f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \

View File

@ -126,6 +126,11 @@ def mixtral_lora_files():
return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")
@pytest.fixture(scope="session")
def gemma_lora_files():
return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()

46
tests/lora/test_gemma.py Normal file
View File

@ -0,0 +1,46 @@
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "google/gemma-7b"
def do_sample(llm, lora_path: str, lora_id: int) -> str:
prompts = [
"Quote: Imagination is",
"Quote: Be yourself;",
"Quote: So many books,",
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
def test_gemma_lora(gemma_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4)
expected_lora_output = [
"more important than knowledge.\nAuthor: Albert Einstein\n",
"everyone else is already taken.\nAuthor: Oscar Wilde\n",
"so little time\nAuthor: Frank Zappa\n",
]
output1 = do_sample(llm, gemma_lora_files, lora_id=1)
for i in range(len(expected_lora_output)):
assert output1[i].startswith(expected_lora_output[i])
output2 = do_sample(llm, gemma_lora_files, lora_id=2)
for i in range(len(expected_lora_output)):
assert output2[i].startswith(expected_lora_output[i])

View File

@ -44,8 +44,8 @@ def _lora_ref_impl(
H1 = H2 = [
128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120,
5504, 5632, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000,
32256, 32512, 32768, 33024
5504, 5632, 6144, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336,
24576, 32000, 32256, 32512, 32768, 33024
]
SEED = [0xabcdabcd987]

View File

@ -20,6 +20,7 @@ import torch
from torch import nn
from transformers import GemmaConfig
from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import PagedAttention
@ -246,12 +247,36 @@ class GemmaModel(nn.Module):
class GemmaForCausalLM(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
config: GemmaConfig,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
del lora_config # Unused.
super().__init__()
self.config = config
self.linear_method = linear_method
@ -305,9 +330,6 @@ class GemmaForCausalLM(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra layer for lora models.
if "lm_head" in name:
continue
# GemmaRMSNorm is different from Llama's in that it multiplies
# (1 + weight) to the output, instead of just weight.
if "norm.weight" in name:

View File

@ -27,6 +27,7 @@ import torch
from torch import nn
from transformers import LlamaConfig
from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
@ -45,7 +46,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.config import LoRAConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]