[Feature][kernel] tensor parallelism with bitsandbytes quantization (#8434)
commit 9855b99502 (parent 1009e93c5d)
@@ -64,6 +64,24 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                              model_name)
 
 
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason='Test requires at least 2 GPUs.')
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
+@fork_new_process_for_each_test
+def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
+                                model_name, description) -> None:
+
+    hf_model_kwargs = {"load_in_4bit": True}
+    validate_generated_texts(hf_runner,
+                             vllm_runner,
+                             example_prompts[:1],
+                             model_name,
+                             hf_model_kwargs,
+                             vllm_tp_size=2)
+
+
 def log_generated_texts(prompts, outputs, runner_name):
     logged_texts = []
     for i, (_, generated_text) in enumerate(outputs):
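Context for the hunk above: the new test drives the same configuration an end user would pass to vLLM. A minimal usage sketch under the same settings (the model name and prompt are illustrative, not taken from this commit) looks roughly like this:

# Usage sketch only (not part of this commit); mirrors the kwargs exercised by
# test_load_tp_4bit_bnb_model. Model name and prompt are illustrative.
from vllm import LLM, SamplingParams

llm = LLM(model="huggyllama/llama-7b",
          quantization="bitsandbytes",
          load_format="bitsandbytes",
          tensor_parallel_size=2,
          enforce_eager=True,
          gpu_memory_utilization=0.8)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)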
@@ -80,22 +98,21 @@ def validate_generated_texts(hf_runner,
                              vllm_runner,
                              prompts,
                              model_name,
-                             hf_model_kwargs=None):
+                             hf_model_kwargs=None,
+                             vllm_tp_size=1):
 
     # NOTE: run vLLM first, as it requires a clean process
     # when using distributed inference
-
-    #Run with vLLM runner
     with vllm_runner(model_name,
                      quantization='bitsandbytes',
                      load_format='bitsandbytes',
+                     tensor_parallel_size=vllm_tp_size,
                      enforce_eager=True,
                      gpu_memory_utilization=0.8) as llm:
         vllm_outputs = llm.generate_greedy(prompts, 8)
         vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
 
     # Clean up the GPU memory for the next test
-    torch.cuda.synchronize()
     gc.collect()
     torch.cuda.empty_cache()
 
@@ -108,7 +125,6 @@ def validate_generated_texts(hf_runner,
         hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
 
     # Clean up the GPU memory for the next test
-    torch.cuda.synchronize()
     gc.collect()
     torch.cuda.empty_cache()
 
@@ -393,12 +393,6 @@ class ModelConfig:
                 "Pipeline parallelism is only supported for the following "
                 f" architectures: {_PP_SUPPORTED_MODELS}.")
 
-        if self.quantization == "bitsandbytes" and (
-                parallel_config.tensor_parallel_size > 1
-                or parallel_config.pipeline_parallel_size > 1):
-            raise ValueError(
-                "BitAndBytes quantization with TP or PP is not supported yet.")
-
         # Remove the constraint after the bitsandbytes issue is fixed:
         # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
         if self.quantization == "bitsandbytes" and self.enforce_eager is False:
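The net effect of this hunk: bitsandbytes no longer rejects parallel configurations up front in ModelConfig; only the prequantized-checkpoint-plus-TP combination is refused, and that check now lives in the loader (see the last hunk below). A rough summary of the resulting policy, as a sketch rather than the literal vLLM code:

# Illustrative summary of the policy after this commit (not literal vLLM code):
# in-flight 4-bit quantization may be sharded across TP ranks, while already
# quantized (prequant) checkpoints combined with TP are rejected in the loader.
def bnb_parallelism_allowed(pre_quant: bool, tp_size: int) -> bool:
    return not (pre_quant and tp_size > 1)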
@@ -530,8 +530,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 param_data = param_data.narrow(output_dim, shard_offset,
                                                shard_size)
                 start_idx = tp_rank * shard_size
-                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                     shard_size)
+                # bitsandbytes loads the weights of the specific portion
+                # no need to narrow here
+                if not use_bitsandbytes_4bit:
+                    loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                         shard_size)
             # Special case for AQLM codebooks.
             elif is_metadata:
                 # metadata indicates fixed size concatenated along dim 0
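For clarity: without bitsandbytes, the weight loader receives the full checkpoint tensor and narrows it to this rank's shard; with 4-bit bitsandbytes, the generator in the model loader (last file in this diff) already yields only the rank-local slice, so narrowing again would be wrong. A small numeric sketch of the arithmetic being skipped (sizes are illustrative):

# Numeric sketch of the sharding arithmetic skipped for bitsandbytes weights
# (illustrative sizes; mirrors the start_idx/shard_size logic above).
import torch

full_weight = torch.arange(8 * 4).reshape(8, 4)    # full checkpoint tensor, output_dim=0
tp_size, tp_rank = 2, 1
shard_size = full_weight.shape[0] // tp_size       # 4 rows per rank
start_idx = tp_rank * shard_size                   # 4

# Regular (unquantized) path: narrow the full tensor to this rank's rows.
narrowed = full_weight.narrow(0, start_idx, shard_size)

# bitsandbytes 4-bit path: the loader already yields only this rank's slice,
# so the tensor arrives with shard_size rows and must not be narrowed again.
pre_sharded = full_weight[start_idx:start_idx + shard_size]
assert torch.equal(narrowed, pre_sharded)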
@@ -899,8 +902,13 @@ class QKVParallelLinear(ColumnParallelLinear):
             else:
                 shard_id = tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow here
+            if not use_bitsandbytes_4bit:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
+
         # Special case for for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -1000,6 +1008,7 @@ class RowParallelLinear(LinearBase):
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
 
         # Special case for GGUF
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -1015,7 +1024,9 @@ class RowParallelLinear(LinearBase):
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
         param_data = param.data
-        if input_dim is not None:
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow here
+        if input_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[input_dim]
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
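This is the same pattern as the column-parallel layers, only along the input dimension. The use_bitsandbytes_4bit flag is read off the parameter itself; how the flag gets attached is outside this diff, so the sketch below simply assumes it is set as a plain tensor attribute elsewhere:

# Minimal sketch, assuming the quantization method tags its parameters with
# these attributes elsewhere (that code is not part of this diff).
import torch

param = torch.nn.Parameter(torch.empty(16, 8), requires_grad=False)
param.input_dim = 1
param.use_bitsandbytes_4bit = True

input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
if input_dim is not None and not use_bitsandbytes_4bit:
    print("regular path: narrow the checkpoint tensor along input_dim")
else:
    print("bitsandbytes path: tensor is already the rank-local shard")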
@@ -22,6 +22,8 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
                          LoRAConfig, ModelConfig, MultiModalConfig,
                          ParallelConfig, SchedulerConfig)
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
@@ -689,6 +691,8 @@ class ShardedStateLoader(BaseModelLoader):
 class BitsAndBytesModelLoader(BaseModelLoader):
     """Model loader to load model weights with BitAndBytes quantization."""
 
+    # TODO: these module names are for Llama only,
+    # change so that it works with other models as well
     default_target_modules = [
         "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
         "o_proj"
@@ -911,13 +915,44 @@ class BitsAndBytesModelLoader(BaseModelLoader):
     def _unquantized_generator(self, hf_weights_files, use_safetensors,
                                quant_state_dict) -> Generator:
         from bitsandbytes.functional import quantize_4bit
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+
         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
             if any(target_module in weight_name
                    for target_module in self.target_modules):
                 weight_name = weight_name.replace(".weight", ".qweight")
+
+                # weight partitions of different modules occur at
+                # different dimensions
+                # TODO: these module names are for Llama only,
+                # change so that it works with other models as well
+                if 'down_proj' in weight_name or 'o_proj' in weight_name:
+                    total_size = weight_tensor.size(-1)
+                    start_index = total_size // tp_size * tp_rank
+                    end_index = total_size // tp_size * (tp_rank + 1)
+                    weight_sub_tensor = weight_tensor[...,
+                                                      start_index:end_index]
+
+                else:
+                    total_size = weight_tensor.size(0)
+                    start_index = total_size // tp_size * tp_rank
+                    end_index = total_size // tp_size * (tp_rank + 1)
+                    weight_sub_tensor = weight_tensor[start_index:end_index,
+                                                      ...]
+
                 # bitsandbytes requires data in GPU
-                loaded_weight = weight_tensor.cuda().data
+                if weight_sub_tensor.is_cuda:
+                    loaded_weight = weight_sub_tensor
+                else:
+                    loaded_weight = weight_sub_tensor.cuda()
+
+                # remove the following after the issue is fixed:
+                # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
+                if loaded_weight.is_contiguous() is False:
+                    loaded_weight = loaded_weight.contiguous()
+
                 with set_default_torch_dtype(torch.float32):
                     processed_weight, quant_state = quantize_4bit(
                         loaded_weight,
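The slicing above is the heart of the feature: each rank cuts its own stripe out of the fp16 checkpoint tensor before quantize_4bit runs, so quantization happens per shard. down_proj and o_proj feed row-parallel layers and are therefore sliced along the last (input) dimension, while the remaining Llama projections are column-parallel and are sliced along dim 0. A compact sketch of that index math with illustrative shapes:

# Sketch of the per-rank slice computed above (illustrative shapes only).
import torch

tp_size, tp_rank = 2, 1
w_o_proj = torch.randn(4096, 4096)     # row-parallel: sliced along the last dim
w_q_proj = torch.randn(4096, 4096)     # column-parallel: sliced along dim 0

def shard(weight: torch.Tensor, dim: int) -> torch.Tensor:
    total = weight.size(dim)
    start = total // tp_size * tp_rank
    end = total // tp_size * (tp_rank + 1)
    # .contiguous() mirrors the workaround for bitsandbytes issue #1342 above.
    return weight.narrow(dim, start, end - start).contiguous()

o_shard = shard(w_o_proj, -1)          # shape (4096, 2048)
q_shard = shard(w_q_proj, 0)           # shape (2048, 4096)
print(o_shard.shape, q_shard.shape)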
@@ -958,6 +993,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 f"BitsAndBytes loader does not support {quant_method} "
                 "quantization")
 
+        # The quant_states in pre_quantized models cannot work with a split
+        # weight tensor. So TP does not work with pre_quantized bnb models.
+        if pre_quant and get_tensor_model_parallel_world_size() > 1:
+            raise ValueError(
+                "Prequant BitsAndBytes models with TP is not supported."
+                "Please try with PP.")
+
         load_8bit = False
         if pre_quant:
             load_8bit = quant_config.get('load_in_8bit', False)
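Why prequantized checkpoints are excluded: quantize_4bit packs the flattened tensor blockwise and returns a quant_state tied to the full tensor, so a per-rank slice of already packed data has no valid quant_state. A conceptual sketch (assumes a CUDA device and bitsandbytes installed; shapes are illustrative):

# Conceptual sketch: the quant_state returned by quantize_4bit describes the
# whole tensor, which is why a prequantized checkpoint cannot be re-split
# across TP ranks after the fact.
import torch
from bitsandbytes.functional import dequantize_4bit, quantize_4bit

w = torch.randn(256, 128, device="cuda", dtype=torch.float16)
packed, quant_state = quantize_4bit(w, quant_type="nf4")
restored = dequantize_4bit(packed, quant_state, quant_type="nf4")
# Round-tripping the full tensor works, but there is no valid quant_state for
# just one rank's rows of `packed`, so splitting must happen before quantization
# (as the _unquantized_generator hunk does) rather than after.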