[Feature][kernel] tensor parallelism with bitsandbytes quantization (#8434)
This commit is contained in:
parent
1009e93c5d
commit
9855b99502
@ -64,6 +64,24 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
|
||||
model_name)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason='Test requires at least 2 GPUs.')
|
||||
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
|
||||
reason='bitsandbytes is not supported on this GPU type.')
|
||||
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
|
||||
@fork_new_process_for_each_test
|
||||
def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
|
||||
model_name, description) -> None:
|
||||
|
||||
hf_model_kwargs = {"load_in_4bit": True}
|
||||
validate_generated_texts(hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts[:1],
|
||||
model_name,
|
||||
hf_model_kwargs,
|
||||
vllm_tp_size=2)
|
||||
|
||||
|
||||
def log_generated_texts(prompts, outputs, runner_name):
|
||||
logged_texts = []
|
||||
for i, (_, generated_text) in enumerate(outputs):
|
||||
@ -80,22 +98,21 @@ def validate_generated_texts(hf_runner,
|
||||
vllm_runner,
|
||||
prompts,
|
||||
model_name,
|
||||
hf_model_kwargs=None):
|
||||
hf_model_kwargs=None,
|
||||
vllm_tp_size=1):
|
||||
|
||||
# NOTE: run vLLM first, as it requires a clean process
|
||||
# when using distributed inference
|
||||
|
||||
#Run with vLLM runner
|
||||
with vllm_runner(model_name,
|
||||
quantization='bitsandbytes',
|
||||
load_format='bitsandbytes',
|
||||
tensor_parallel_size=vllm_tp_size,
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.8) as llm:
|
||||
vllm_outputs = llm.generate_greedy(prompts, 8)
|
||||
vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
|
||||
|
||||
# Clean up the GPU memory for the next test
|
||||
torch.cuda.synchronize()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@ -108,7 +125,6 @@ def validate_generated_texts(hf_runner,
|
||||
hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
|
||||
|
||||
# Clean up the GPU memory for the next test
|
||||
torch.cuda.synchronize()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@ -393,12 +393,6 @@ class ModelConfig:
|
||||
"Pipeline parallelism is only supported for the following "
|
||||
f" architectures: {_PP_SUPPORTED_MODELS}.")
|
||||
|
||||
if self.quantization == "bitsandbytes" and (
|
||||
parallel_config.tensor_parallel_size > 1
|
||||
or parallel_config.pipeline_parallel_size > 1):
|
||||
raise ValueError(
|
||||
"BitAndBytes quantization with TP or PP is not supported yet.")
|
||||
|
||||
# Remove the constraint after the bitsandbytes issue is fixed:
|
||||
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
|
||||
if self.quantization == "bitsandbytes" and self.enforce_eager is False:
|
||||
|
@ -530,8 +530,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
param_data = param_data.narrow(output_dim, shard_offset,
|
||||
shard_size)
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow here
|
||||
if not use_bitsandbytes_4bit:
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
# Special case for AQLM codebooks.
|
||||
elif is_metadata:
|
||||
# metadata indicates fixed size concatenated along dim 0
|
||||
@ -899,8 +902,13 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
else:
|
||||
shard_id = tp_rank // self.num_kv_head_replicas
|
||||
start_idx = shard_id * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow here
|
||||
if not use_bitsandbytes_4bit:
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
|
||||
# Special case for for AQLM codebooks.
|
||||
elif is_metadata:
|
||||
# metadata indicates fixed size concatenated along dim 0
|
||||
@ -1000,6 +1008,7 @@ class RowParallelLinear(LinearBase):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||
|
||||
# Special case for GGUF
|
||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||
@ -1015,7 +1024,9 @@ class RowParallelLinear(LinearBase):
|
||||
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
|
||||
|
||||
param_data = param.data
|
||||
if input_dim is not None:
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow here
|
||||
if input_dim is not None and not use_bitsandbytes_4bit:
|
||||
shard_size = param_data.shape[input_dim]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
|
||||
|
@ -22,6 +22,8 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
|
||||
LoRAConfig, ModelConfig, MultiModalConfig,
|
||||
ParallelConfig, SchedulerConfig)
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.envs import VLLM_USE_MODELSCOPE
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@ -689,6 +691,8 @@ class ShardedStateLoader(BaseModelLoader):
|
||||
class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
"""Model loader to load model weights with BitAndBytes quantization."""
|
||||
|
||||
# TODO: these module names are for Llama only,
|
||||
# change so that it works with other models as well
|
||||
default_target_modules = [
|
||||
"gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
|
||||
"o_proj"
|
||||
@ -911,13 +915,44 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
def _unquantized_generator(self, hf_weights_files, use_safetensors,
|
||||
quant_state_dict) -> Generator:
|
||||
from bitsandbytes.functional import quantize_4bit
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
for weight_name, weight_tensor in self._hf_weight_iter(
|
||||
hf_weights_files, use_safetensors):
|
||||
if any(target_module in weight_name
|
||||
for target_module in self.target_modules):
|
||||
weight_name = weight_name.replace(".weight", ".qweight")
|
||||
|
||||
# weight partitions of different modules occur at
|
||||
# different dimensions
|
||||
# TODO: these module names are for Llama only,
|
||||
# change so that it works with other models as well
|
||||
if 'down_proj' in weight_name or 'o_proj' in weight_name:
|
||||
total_size = weight_tensor.size(-1)
|
||||
start_index = total_size // tp_size * tp_rank
|
||||
end_index = total_size // tp_size * (tp_rank + 1)
|
||||
weight_sub_tensor = weight_tensor[...,
|
||||
start_index:end_index]
|
||||
|
||||
else:
|
||||
total_size = weight_tensor.size(0)
|
||||
start_index = total_size // tp_size * tp_rank
|
||||
end_index = total_size // tp_size * (tp_rank + 1)
|
||||
weight_sub_tensor = weight_tensor[start_index:end_index,
|
||||
...]
|
||||
|
||||
# bitsandbytes requires data in GPU
|
||||
loaded_weight = weight_tensor.cuda().data
|
||||
if weight_sub_tensor.is_cuda:
|
||||
loaded_weight = weight_sub_tensor
|
||||
else:
|
||||
loaded_weight = weight_sub_tensor.cuda()
|
||||
|
||||
# remove the following after the issue is fixed:
|
||||
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
|
||||
if loaded_weight.is_contiguous() is False:
|
||||
loaded_weight = loaded_weight.contiguous()
|
||||
|
||||
with set_default_torch_dtype(torch.float32):
|
||||
processed_weight, quant_state = quantize_4bit(
|
||||
loaded_weight,
|
||||
@ -958,6 +993,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
f"BitsAndBytes loader does not support {quant_method} "
|
||||
"quantization")
|
||||
|
||||
# The quant_states in pre_quantized models cannot work with a split
|
||||
# weight tensor. So TP does not work with pre_quantized bnb models.
|
||||
if pre_quant and get_tensor_model_parallel_world_size() > 1:
|
||||
raise ValueError(
|
||||
"Prequant BitsAndBytes models with TP is not supported."
|
||||
"Please try with PP.")
|
||||
|
||||
load_8bit = False
|
||||
if pre_quant:
|
||||
load_8bit = quant_config.get('load_in_8bit', False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user