[Bugfix] Fix JambaForCausalLM LoRA (#14370)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent e5e03c2c1b
commit ddd1ef66ec
@@ -6,7 +6,6 @@ from typing import TypedDict
from unittest.mock import MagicMock, patch

import pytest
import safetensors
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
@@ -191,29 +190,6 @@ def mixtral_lora_files_all_target_modules():
    return snapshot_download(repo_id="dyang415/mixtral-lora-v0")


@pytest.fixture(scope="session")
def jamba_lora_files():
    # some of the adapters have unnecessary weights for serving,
    # hence we remove them
    def remove_unnecessary_weights(path):
        lora_path = f"{adapter_path}/adapter_model.safetensors"
        tensors = safetensors.torch.load_file(lora_path)
        nonlora_keys = []
        for k in list(tensors.keys()):
            if "lora" not in k:
                nonlora_keys.append(k)
        for k in nonlora_keys:
            del tensors[k]
        safetensors.torch.save_file(tensors, lora_path)

    adapter_path = snapshot_download(
        repo_id=
        "hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora")

    remove_unnecessary_weights(adapter_path)
    return adapter_path


@pytest.fixture(scope="session")
def gemma_lora_files():
    return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
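
Note: the following standalone sketch is not part of the diff; it only illustrates the idea behind remove_unnecessary_weights() above, namely listing which keys of a safetensors adapter are not LoRA weights and would be stripped before serving. The local path "adapter_model.safetensors" is a hypothetical example.

    # Minimal sketch, assuming a locally downloaded adapter file.
    from safetensors import safe_open

    with safe_open("adapter_model.safetensors", framework="pt") as f:
        # Keys without "lora" in their name are base-model weights that the
        # serving path does not need; the fixture above deletes them.
        nonlora_keys = [k for k in f.keys() if "lora" not in k]
    print(nonlora_keys)
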
@@ -1,54 +0,0 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini"

MAX_TOKENS = 40


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
              prompts: list[str]) -> list[str]:

    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


@pytest.mark.parametrize("tp_size", [4])
def test_jamba_lora(jamba_lora_files, tp_size):
    """Original test, the LoRA model has the common target modules, not all"""
    if torch.cuda.device_count() < tp_size:
        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

    prompts = ["Write a story about a sheep and a goat."]

    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        distributed_executor_backend="ray",
        tensor_parallel_size=tp_size,
    )

    expected_jamba_output = [
        """Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. Clara was a gentle creature, always nibbling on the soft grass and humming"""  # noqa: E501
    ]
    assert do_sample(llm, jamba_lora_files, lora_id=1,
                     prompts=prompts) == expected_jamba_output
@@ -632,6 +632,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_replicated_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
@@ -757,6 +758,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_parallel_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
@@ -904,6 +906,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = create_column_parallel_packed_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, sublora_dict = populate_loras(
            id_to_index,
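
The assertion added in each of these tests checks one invariant: the LoRA wrapper must expose the wrapped layer's weight tensor unchanged. A toy sketch of that invariant follows; TinyWrapper is an illustrative stand-in, not vLLM code.

    import torch
    import torch.nn as nn

    class TinyWrapper:
        """Illustrative stand-in for a *WithLoRA layer; not vLLM code."""

        def __init__(self, base_layer: nn.Module):
            self.base_layer = base_layer

        @property
        def weight(self) -> torch.Tensor:
            # Delegate to the wrapped layer, as the new property in layers.py does.
            return self.base_layer.weight

    base = nn.Linear(8, 8)
    assert torch.equal(base.weight, TinyWrapper(base).weight)
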
@@ -274,6 +274,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    ) -> bool:
        return type(source_layer) is VocabParallelEmbedding

    @property
    def weight(self):
        return self.base_layer.weight


class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):

@@ -409,6 +413,34 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
                                            self.output_slices)
        return output

    @property
    def weight(self) -> torch.Tensor:

        # unquantizedLinear
        if hasattr(self.base_layer, "weight"):
            return self.base_layer.weight
        # Compressed Tensor
        elif hasattr(self.base_layer, "weight_packed"):
            return self.base_layer.weight_packed
        # GPTQ/AWQ
        elif hasattr(self.base_layer, "qweight"):
            return self.base_layer.qweight
        # marlin
        elif hasattr(self.base_layer, "B"):
            return self.base_layer.B
        # HQQ marlin
        elif hasattr(self.base_layer, "W_q"):
            return self.base_layer.W_q
        else:
            raise ValueError(f"Unsupported base layer: {self.base_layer}")

    @property
    def bias(self) -> Optional[torch.Tensor]:
        if hasattr(self.base_layer, "bias"):
            return self.base_layer.bias
        else:
            return None


class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):

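
The weight property above dispatches on whichever attribute the wrapped quantization backend exposes (weight, weight_packed, qweight, B, or W_q). The sketch below shows the same lookup pattern as a free function for illustration only; resolve_weight is not a vLLM API.

    import torch
    import torch.nn as nn

    # Attribute names used by the supported backends, tried in order:
    # unquantized, compressed-tensors, GPTQ/AWQ, Marlin, HQQ Marlin.
    _WEIGHT_ATTRS = ("weight", "weight_packed", "qweight", "B", "W_q")

    def resolve_weight(base_layer: nn.Module) -> torch.Tensor:
        for name in _WEIGHT_ATTRS:
            if hasattr(base_layer, name):
                return getattr(base_layer, name)
        raise ValueError(f"Unsupported base layer: {base_layer}")
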
@@ -902,11 +934,6 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):

        return output, output_bias

    @property
    def weight(self):
        return (self.base_layer.weight if hasattr(self.base_layer, "weight")
                else self.base_layer.qweight)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
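
With weight now defined on BaseLinearLayerWithLoRA, the per-class property above is redundant: Python property lookup falls through to the base class. A toy illustration; Base and Row are illustrative names, not vLLM classes.

    class Base:
        @property
        def weight(self):
            return "resolved in Base"

    class Row(Base):
        # No weight property of its own any more.
        pass

    print(Row().weight)  # -> "resolved in Base"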