[Bugfix] Fix JambaForCausalLM LoRA (#14370)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-03-07 14:05:47 +08:00 committed by GitHub
parent e5e03c2c1b
commit ddd1ef66ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 35 additions and 83 deletions

View File

@ -6,7 +6,6 @@ from typing import TypedDict
from unittest.mock import MagicMock, patch
import pytest
import safetensors
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
@ -191,29 +190,6 @@ def mixtral_lora_files_all_target_modules():
return snapshot_download(repo_id="dyang415/mixtral-lora-v0")
@pytest.fixture(scope="session")
def jamba_lora_files():
# some of the adapters have unnecessary weights for serving,
# hence we remove them
def remove_unnecessary_weights(path):
lora_path = f"{adapter_path}/adapter_model.safetensors"
tensors = safetensors.torch.load_file(lora_path)
nonlora_keys = []
for k in list(tensors.keys()):
if "lora" not in k:
nonlora_keys.append(k)
for k in nonlora_keys:
del tensors[k]
safetensors.torch.save_file(tensors, lora_path)
adapter_path = snapshot_download(
repo_id=
"hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora")
remove_unnecessary_weights(adapter_path)
return adapter_path
@pytest.fixture(scope="session")
def gemma_lora_files():
return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")

View File

@ -1,54 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini"
MAX_TOKENS = 40
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
prompts: list[str]) -> list[str]:
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.parametrize("tp_size", [4])
def test_jamba_lora(jamba_lora_files, tp_size):
"""Original test, the LoRA model has the common target modules, not all"""
if torch.cuda.device_count() < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
prompts = ["Write a story about a sheep and a goat."]
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
distributed_executor_backend="ray",
tensor_parallel_size=tp_size,
)
expected_jamba_output = [
"""Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. Clara was a gentle creature, always nibbling on the soft grass and humming""" # noqa: E501
]
assert do_sample(llm, jamba_lora_files, lora_id=1,
prompts=prompts) == expected_jamba_output

View File

@ -632,6 +632,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_random_linear_replicated_layer()
assert torch.equal(linear.weight, lora_linear.weight)
lora_linear.set_mapping(punica_wrapper)
lora_dict, _ = populate_loras(
id_to_index,
@ -757,6 +758,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_random_linear_parallel_layer()
assert torch.equal(linear.weight, lora_linear.weight)
lora_linear.set_mapping(punica_wrapper)
lora_dict, _ = populate_loras(
id_to_index,
@ -904,6 +906,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_column_parallel_packed_layer()
assert torch.equal(linear.weight, lora_linear.weight)
lora_linear.set_mapping(punica_wrapper)
lora_dict, sublora_dict = populate_loras(
id_to_index,

View File

@ -274,6 +274,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
) -> bool:
return type(source_layer) is VocabParallelEmbedding
@property
def weight(self):
return self.base_layer.weight
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
@ -409,6 +413,34 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
self.output_slices)
return output
@property
def weight(self) -> torch.Tensor:
# unquantizedLinear
if hasattr(self.base_layer, "weight"):
return self.base_layer.weight
# Compressed Tensor
elif hasattr(self.base_layer, "weight_packed"):
return self.base_layer.weight_packed
# GPTQ/AWQ
elif hasattr(self.base_layer, "qweight"):
return self.base_layer.qweight
# marlin
elif hasattr(self.base_layer, "B"):
return self.base_layer.B
# HQQ marlin
elif hasattr(self.base_layer, "W_q"):
return self.base_layer.W_q
else:
raise ValueError(f"Unsupported base layer: {self.base_layer}")
@property
def bias(self) -> Optional[torch.Tensor]:
if hasattr(self.base_layer, "bias"):
return self.base_layer.bias
else:
return None
class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
@ -902,11 +934,6 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
return output, output_bias
@property
def weight(self):
return (self.base_layer.weight if hasattr(self.base_layer, "weight")
else self.base_layer.qweight)
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(