[Bugfix] bitsandbytes models fail to run pipeline parallel (#10200)
Signed-off-by: Hoang Cong Duc <hoangcongducltt@gmail.com>
This commit is contained in:
parent
0b8bb86bf1
commit
ac49b59d8b
@ -9,7 +9,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from tests.utils import fork_new_process_for_each_test
|
||||
from tests.utils import compare_two_settings, fork_new_process_for_each_test
|
||||
|
||||
models_4bit_to_test = [
|
||||
("facebook/opt-125m", "quantize opt model inflight"),
|
||||
@ -82,6 +82,34 @@ def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
|
||||
vllm_tp_size=2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason='Test requires at least 2 GPUs.')
|
||||
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
|
||||
reason='bitsandbytes is not supported on this GPU type.')
|
||||
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
|
||||
@fork_new_process_for_each_test
|
||||
def test_load_pp_4bit_bnb_model(model_name, description) -> None:
|
||||
common_args = [
|
||||
"--disable-log-stats",
|
||||
"--disable-log-requests",
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--enable-prefix-caching",
|
||||
"--quantization",
|
||||
"bitsandbytes",
|
||||
"--load-format",
|
||||
"bitsandbytes",
|
||||
"--gpu-memory-utilization",
|
||||
"0.7",
|
||||
]
|
||||
pp_args = [
|
||||
*common_args,
|
||||
"--pipeline-parallel-size",
|
||||
"2",
|
||||
]
|
||||
compare_two_settings(model_name, common_args, pp_args)
|
||||
|
||||
|
||||
def log_generated_texts(prompts, outputs, runner_name):
|
||||
logged_texts = []
|
||||
for i, (_, generated_text) in enumerate(outputs):
|
||||
|
@ -991,7 +991,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
|
||||
param_dict = dict(model.named_parameters())
|
||||
stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
|
||||
# TODO: Change this lazy import to normal import
|
||||
# after the checks are updated to run on a new version
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
for quant_param_name in quant_state_dict:
|
||||
if is_pp_missing_parameter(quant_param_name, model):
|
||||
continue
|
||||
|
||||
non_stacked_param_name = quant_param_name
|
||||
|
||||
shard_index = 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user