[Bugfix] bitsandbytes models fail to run pipeline parallel (#10200)

Signed-off-by: Hoang Cong Duc <hoangcongducltt@gmail.com>
HoangCongDuc authored on 2024-11-14 00:56:39 +08:00, committed by GitHub
parent 0b8bb86bf1
commit ac49b59d8b
2 changed files with 35 additions and 1 deletion
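The fix has two parts: the bitsandbytes weight loader now skips quantization states whose target parameters belong to another pipeline stage, and a new test loads a 4-bit bitsandbytes model with --pipeline-parallel-size 2 and checks its output against a single-stage run. For orientation, the sketch below shows the kind of configuration this commit is about, using the offline LLM entrypoint with the same engine arguments as the server flags in the new test; whether pipeline parallelism is available from the offline entrypoint in a given vLLM build, plus the model, dtype, and memory values, are assumptions taken from the test, not requirements of the fix.

from vllm import LLM, SamplingParams

# Sketch only: mirrors the flags exercised by the new test below.
llm = LLM(
    model="facebook/opt-125m",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
    dtype="bfloat16",
    gpu_memory_utilization=0.7,
    pipeline_parallel_size=2,  # needs at least 2 GPUs; failed to load before this fix
)
print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16)))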


@@ -9,7 +9,7 @@ import pytest
 import torch

 from tests.quantization.utils import is_quant_method_supported
-from tests.utils import fork_new_process_for_each_test
+from tests.utils import compare_two_settings, fork_new_process_for_each_test

 models_4bit_to_test = [
     ("facebook/opt-125m", "quantize opt model inflight"),
@@ -82,6 +82,34 @@ def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                            vllm_tp_size=2)


+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason='Test requires at least 2 GPUs.')
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
+@fork_new_process_for_each_test
+def test_load_pp_4bit_bnb_model(model_name, description) -> None:
+    common_args = [
+        "--disable-log-stats",
+        "--disable-log-requests",
+        "--dtype",
+        "bfloat16",
+        "--enable-prefix-caching",
+        "--quantization",
+        "bitsandbytes",
+        "--load-format",
+        "bitsandbytes",
+        "--gpu-memory-utilization",
+        "0.7",
+    ]
+    pp_args = [
+        *common_args,
+        "--pipeline-parallel-size",
+        "2",
+    ]
+    compare_two_settings(model_name, common_args, pp_args)
+
+
 def log_generated_texts(prompts, outputs, runner_name):
     logged_texts = []
     for i, (_, generated_text) in enumerate(outputs):
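compare_two_settings (imported from tests.utils above) starts the OpenAI-compatible API server once per argument list and compares the responses to a small set of requests, so the new test asserts that adding --pipeline-parallel-size 2 to an otherwise identical bitsandbytes configuration leaves the completions unchanged. A rough, hypothetical equivalent against two manually started servers is sketched below; the ports, prompt, and greedy-sampling comparison are illustrative assumptions, not details of the test utility.

from openai import OpenAI

prompt = "The capital of France is"
outputs = []
# Assume one server was started with the common args on port 8000 and one
# with --pipeline-parallel-size 2 on port 8001.
for port in (8000, 8001):
    client = OpenAI(base_url=f"http://localhost:{port}/v1", api_key="EMPTY")
    completion = client.completions.create(model="facebook/opt-125m",
                                           prompt=prompt,
                                           temperature=0.0,
                                           max_tokens=16)
    outputs.append(completion.choices[0].text)

# With greedy decoding, the pipeline-parallel run should reproduce the
# single-stage output token for token.
assert outputs[0] == outputs[1]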


@@ -991,7 +991,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         param_dict = dict(model.named_parameters())
         stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
+        # TODO: Change this lazy import to normal import
+        # after the checks are updated to run on a new version
+        from vllm.model_executor.models.utils import is_pp_missing_parameter
+
         for quant_param_name in quant_state_dict:
+            if is_pp_missing_parameter(quant_param_name, model):
+                continue
             non_stacked_param_name = quant_param_name
             shard_index = 0
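With pipeline parallelism, each rank materializes only its own slice of the model's layers, so model.named_parameters() on a given rank is missing every parameter owned by another stage; without the new guard, the loader would try to attach quantization states to parameters that do not exist on that rank. is_pp_missing_parameter is vLLM's existing helper for exactly this check; the stand-in below only illustrates the idea, and the layer-range logic plus the example parameter names are invented for the sketch.

import re

def is_layer_outside_this_rank(param_name: str, start_layer: int,
                               end_layer: int) -> bool:
    # Hypothetical simplification: a parameter is "missing" on this pipeline
    # rank if it belongs to a decoder layer outside [start_layer, end_layer).
    match = re.search(r"layers\.(\d+)\.", param_name)
    if match is None:
        return False  # embeddings, lm_head, etc.: assume present on this rank
    layer_idx = int(match.group(1))
    return not (start_layer <= layer_idx < end_layer)

# Usage mirroring the loader change: skip quant states whose target
# parameter this rank never created.
quant_state_dict = {
    "model.layers.0.self_attn.qkv_proj.weight": "state-a",
    "model.layers.12.self_attn.qkv_proj.weight": "state-b",
}
for name in quant_state_dict:
    if is_layer_outside_this_rank(name, start_layer=0, end_layer=8):
        continue
    print("would attach quant state to", name)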