[Bugfix] bitsandbytes models fail to run pipeline parallel (#10200)
Signed-off-by: Hoang Cong Duc <hoangcongducltt@gmail.com>
commit ac49b59d8b
parent 0b8bb86bf1
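The change has two parts: the BitsAndBytes loader now skips quantized-weight names that belong to layers owned by other pipeline-parallel ranks, and the bitsandbytes quantization tests gain a case that loads a 4-bit model with pipeline-parallel size 2 and compares its outputs against a single-GPU baseline.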
@@ -9,7 +9,7 @@ import pytest
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
-from tests.utils import fork_new_process_for_each_test
+from tests.utils import compare_two_settings, fork_new_process_for_each_test
 
 models_4bit_to_test = [
     ("facebook/opt-125m", "quantize opt model inflight"),
@@ -82,6 +82,34 @@ def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                              vllm_tp_size=2)
 
 
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason='Test requires at least 2 GPUs.')
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
+@fork_new_process_for_each_test
+def test_load_pp_4bit_bnb_model(model_name, description) -> None:
+    common_args = [
+        "--disable-log-stats",
+        "--disable-log-requests",
+        "--dtype",
+        "bfloat16",
+        "--enable-prefix-caching",
+        "--quantization",
+        "bitsandbytes",
+        "--load-format",
+        "bitsandbytes",
+        "--gpu-memory-utilization",
+        "0.7",
+    ]
+    pp_args = [
+        *common_args,
+        "--pipeline-parallel-size",
+        "2",
+    ]
+    compare_two_settings(model_name, common_args, pp_args)
+
+
 def log_generated_texts(prompts, outputs, runner_name):
     logged_texts = []
     for i, (_, generated_text) in enumerate(outputs):
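For context: compare_two_settings (from vLLM's tests/utils) brings up an OpenAI-compatible server once per argument list and asserts that both configurations return matching completions, so the new test checks that a 4-bit bitsandbytes model produces the same outputs with and without pipeline parallelism. A minimal standalone sketch of the same check, with the argument lists trimmed for brevity:

# Sketch: compare a single-GPU bitsandbytes run against a two-stage
# pipeline-parallel run of the same 4-bit model. Assumes the
# compare_two_settings semantics described above; run from a vLLM
# checkout on a machine with at least 2 GPUs.
from tests.utils import compare_two_settings

common_args = [
    "--quantization", "bitsandbytes",
    "--load-format", "bitsandbytes",
    "--dtype", "bfloat16",
]
compare_two_settings(
    "facebook/opt-125m",
    common_args,                                      # baseline, PP size 1
    [*common_args, "--pipeline-parallel-size", "2"],  # two pipeline stages
)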
@@ -991,7 +991,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         param_dict = dict(model.named_parameters())
         stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
+        # TODO: Change this lazy import to normal import
+        # after the checks are updated to run on a new version
+        from vllm.model_executor.models.utils import is_pp_missing_parameter
         for quant_param_name in quant_state_dict:
+            if is_pp_missing_parameter(quant_param_name, model):
+                continue
+
             non_stacked_param_name = quant_param_name
 
             shard_index = 0
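The loader-side guard above is the actual fix. BitsAndBytesModelLoader derives quant_state_dict from the full checkpoint, but under pipeline parallelism each rank instantiates only its slice of the layers, so param_dict, built from model.named_parameters(), has no entry for the other ranks' weights and the later lookup fails. A minimal sketch of the guard's intent, approximating is_pp_missing_parameter with a simple membership test (the helper name and attribute below are illustrative, not vLLM internals):

from typing import Any, Dict

import torch

def bind_quant_states(quant_state_dict: Dict[str, Any],
                      param_dict: Dict[str, torch.nn.Parameter]) -> None:
    # Sketch only: vLLM's is_pp_missing_parameter inspects the model's
    # pipeline-parallel layer range; membership in this rank's param_dict
    # approximates the same "does this rank own the weight?" question.
    for name, quant_state in quant_state_dict.items():
        if name not in param_dict:  # owned by another PP rank: skip it
            continue
        # Attach the quant state to this rank's parameter (illustrative
        # attribute name, not the loader's real bookkeeping).
        param_dict[name].bnb_quant_state = quant_state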