# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py
from dataclasses import dataclass

import pytest

import vllm
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform


@dataclass
class ModelWithQuantization:
    model_path: str
    quantization: str


MODELS: list[ModelWithQuantization]
#AWQ quantization is currently not supported in ROCm.
if current_platform.is_rocm():
    MODELS = [
        ModelWithQuantization(
            model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
            quantization="GPTQ"),
    ]
else:
    MODELS = [
        ModelWithQuantization(
            model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
            quantization="AWQ"),
        ModelWithQuantization(
            model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
            quantization="GPTQ"),
    ]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


def do_sample(llm: vllm.LLM,
              lora_path: str,
              lora_id: int,
              max_tokens: int = 256) -> list[str]:
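    """Greedy-decode two color prompts and return the generated texts.

    A non-zero ``lora_id`` routes the requests through the LoRA adapter at
    ``lora_path``; ``lora_id=0`` runs the base model without LoRA.
    """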
    raw_prompts = [
        "Give me an orange-ish brown color",
        "Give me a neon pink color",
    ]

    def format_prompt_tuples(prompt):
        return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

    prompts = [format_prompt_tuples(p) for p in raw_prompts]

    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=max_tokens,
                                          stop=["<|im_end|>"])
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", [1])
def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
                          tp_size):
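    """Run the quantized model with LoRA enabled and check that generations
    with and without the adapter match the per-quantization reference outputs.
    """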
    if num_gpus_available < tp_size and \
        tp_size > 1 and current_platform.is_cuda_alike():
        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

    llm = vllm.LLM(
        model=model.model_path,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        max_model_len=400,
        tensor_parallel_size=tp_size,
        gpu_memory_utilization=0.2,  # avoid OOM
        quantization=model.quantization,
        trust_remote_code=True,
        enable_chunked_prefill=True)
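
    # Reference generations (max_tokens=10 below) for each quantization
    # scheme; the LoRA adapter is expected to answer with hex color codes.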
    if model.quantization is None:
        expected_no_lora_output = [
            "Here are some examples of orange-brown colors",
            "I'm sorry, I don't have"
        ]
        expected_lora_output = [
            "#ff8050",
            "#ff8080",
        ]
    elif model.quantization == "AWQ":
        expected_no_lora_output = [
            "I'm sorry, I don't understand",
            "I'm sorry, I don't understand",
        ]
        expected_lora_output = [
            "#f07700: A v",
            "#f00000: A v",
        ]
    elif model.quantization == "GPTQ":
        expected_no_lora_output = [
            "I'm sorry, I don't have",
            "I'm sorry, I don't have",
        ]
        expected_lora_output = [
            "#f08800: This is",
            "#f07788 \n#",
        ]

    def expect_match(output, expected_output):
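        """Assert that ``output`` matches ``expected_output`` (GPTQ LoRA
        outputs are only checked loosely, see the HACK below)."""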
        # HACK: GPTQ lora outputs are just incredibly unstable.
        # Assert that the outputs changed.
        if (model.quantization == "GPTQ"
                and expected_output is expected_lora_output):
            assert output != expected_no_lora_output
            for i, o in enumerate(output):
                assert o.startswith(
                    '#'), f"Expected example {i} to start with # but got {o}"
            return
        assert output == expected_output

    max_tokens = 10

    print("lora adapter created")
    output = do_sample(llm,
                       tinyllama_lora_files,
                       lora_id=0,
                       max_tokens=max_tokens)
    expect_match(output, expected_no_lora_output)

    print("lora 1")
    output = do_sample(llm,
                       tinyllama_lora_files,
                       lora_id=1,
                       max_tokens=max_tokens)
    expect_match(output, expected_lora_output)

    print("no lora")
    output = do_sample(llm,
                       tinyllama_lora_files,
                       lora_id=0,
                       max_tokens=max_tokens)
    expect_match(output, expected_no_lora_output)

    print("lora 2")
    output = do_sample(llm,
                       tinyllama_lora_files,
                       lora_id=2,
                       max_tokens=max_tokens)
    expect_match(output, expected_lora_output)

    print("removing lora")

    del llm
    cleanup_dist_env_and_memory()


@pytest.mark.parametrize("model", MODELS)
def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
                                 model):
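    """Check that LoRA generations are identical with tensor parallelism
    1 and 2 for the quantized model.
    """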
    if num_gpus_available < 2:
        pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
    if model.quantization == "GPTQ":
        pytest.skip("GPTQ lora outputs are just incredibly unstable")
    llm_tp1 = vllm.LLM(
        model=model.model_path,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.2,  # avoid OOM
        quantization=model.quantization,
        trust_remote_code=True,
        enable_chunked_prefill=True)
    output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)

    del llm_tp1
    cleanup_dist_env_and_memory()

    llm_tp2 = vllm.LLM(
        model=model.model_path,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        tensor_parallel_size=2,
        gpu_memory_utilization=0.2,  # avoid OOM
        quantization=model.quantization,
        enable_chunked_prefill=True)
    output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)

    del llm_tp2
    cleanup_dist_env_and_memory()
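
    # Greedy decoding (temperature=0) should make the outputs independent of
    # the tensor-parallel degree.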
    assert output_tp1 == output_tp2