[Feature][Kernel] Support bitsandbytes quantization and QLoRA (#4776)

chenqianfzh 2024-06-01 13:51:10 -07:00 committed by GitHub
parent 37464a0f74
commit b9c0605a8e
11 changed files with 752 additions and 8 deletions


@ -0,0 +1,140 @@
"""
This example shows how to use LoRA with different quantization techniques
for offline inference.
Requires HuggingFace credentials for access.
"""
import gc
from typing import List, Optional, Tuple
import torch
from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
def create_test_prompts(
lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
return [
# this is an example of using quantization without LoRA
("My name is",
SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128), None),
# the next three examples use quantization with LoRA
("my name is",
SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128),
LoRARequest("lora-test-1", 1, lora_path)),
("The capital of USA is",
SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128),
LoRARequest("lora-test-2", 1, lora_path)),
("The capital of France is",
SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128),
LoRARequest("lora-test-3", 1, lora_path)),
]
def process_requests(engine: LLMEngine,
test_prompts: List[Tuple[str, SamplingParams,
Optional[LoRARequest]]]):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0
while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params, lora_request = test_prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
lora_request=lora_request)
request_id += 1
request_outputs: List[RequestOutput] = engine.step()
for request_output in request_outputs:
if request_output.finished:
print("----------------------------------------------------")
print(f"Prompt: {request_output.prompt}")
print(f"Output: {request_output.outputs[0].text}")
def initialize_engine(model: str, quantization: str,
lora_repo: Optional[str]) -> LLMEngine:
"""Initialize the LLMEngine."""
if quantization == "bitsandbytes":
        # QLoRA (https://arxiv.org/abs/2305.14314) quantizes the base model at
        # load time, using config info from the LoRA adapter repo, so both
        # load_format and qlora_adapter_name_or_path must be set as below.
engine_args = EngineArgs(
model=model,
quantization=quantization,
qlora_adapter_name_or_path=lora_repo,
load_format="bitsandbytes",
enable_lora=True,
max_lora_rank=64,
            # Set enforce_eager only for GPUs with limited memory
enforce_eager=True)
else:
engine_args = EngineArgs(
model=model,
quantization=quantization,
enable_lora=True,
max_loras=4,
            # Set enforce_eager only for GPUs with limited memory
enforce_eager=True)
return LLMEngine.from_engine_args(engine_args)
def main():
"""Main function that sets up and runs the prompt processing."""
test_configs = [{
"name": "qlora_inference_example",
'model': "huggyllama/llama-7b",
'quantization': "bitsandbytes",
'lora_repo': 'timdettmers/qlora-flan-7b'
}, {
"name": "AWQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
'quantization': "awq",
'lora_repo': 'jashing/tinyllama-colorist-lora'
}, {
"name": "GPTQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
'quantization': "gptq",
'lora_repo': 'jashing/tinyllama-colorist-lora'
}]
for test_config in test_configs:
print(
f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~"
)
engine = initialize_engine(test_config['model'],
test_config['quantization'],
test_config['lora_repo'])
lora_path = snapshot_download(repo_id=test_config['lora_repo'])
test_prompts = create_test_prompts(lora_path)
process_requests(engine, test_prompts)
# Clean up the GPU memory for the next test
del engine
gc.collect()
torch.cuda.empty_cache()
if __name__ == '__main__':
main()
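For a quick smoke test of the bitsandbytes path without any LoRA adapter, a minimal sketch using the high-level LLM class can also work (not part of this commit; it assumes the LLM convenience class forwards load_format through its extra kwargs to EngineArgs):

# Minimal sketch, assuming LLM forwards these kwargs to EngineArgs unchanged;
# the model name and prompt are illustrative.
from vllm import LLM, SamplingParams

llm = LLM(model="huggyllama/llama-7b",
          quantization="bitsandbytes",
          load_format="bitsandbytes",
          enforce_eager=True)
outputs = llm.generate(["My name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)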


@ -35,3 +35,6 @@ aiohttp
# Multimodal
pillow
# Quantization
bitsandbytes==0.42.0


@ -0,0 +1,80 @@
'''Tests whether bitsandbytes computation is enabled correctly.
Run `pytest tests/quantization/test_bitsandbytes.py`.
'''
import pytest
import torch
from vllm import SamplingParams
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
@pytest.mark.skipif(
capability < QUANTIZATION_METHODS['bitsandbytes'].get_min_capability(),
reason='bitsandbytes is not supported on this GPU type.')
def test_load_bnb_model(vllm_runner) -> None:
llm = vllm_runner('huggyllama/llama-7b',
quantization='bitsandbytes',
load_format='bitsandbytes',
enforce_eager=True)
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
# check the weights in MLP & SelfAttention are quantized to torch.uint8
qweight = model.model.layers[0].mlp.gate_up_proj.qweight
assert qweight.dtype == torch.uint8, (
f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}')
qweight = model.model.layers[0].mlp.down_proj.qweight
assert qweight.dtype == torch.uint8, (
f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}')
qweight = model.model.layers[0].self_attn.o_proj.qweight
assert qweight.dtype == torch.uint8, (
f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}')
qweight = model.model.layers[0].self_attn.qkv_proj.qweight
assert qweight.dtype == torch.uint8, (
f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}')
# some weights should not be quantized
weight = model.lm_head.weight
assert weight.dtype != torch.uint8, (
'lm_head weight dtype should not be torch.uint8')
weight = model.model.embed_tokens.weight
assert weight.dtype != torch.uint8, (
'embed_tokens weight dtype should not be torch.uint8')
weight = model.model.layers[0].input_layernorm.weight
assert weight.dtype != torch.uint8, (
'input_layernorm weight dtype should not be torch.uint8')
    weight = model.model.layers[0].post_attention_layernorm.weight
    assert weight.dtype != torch.uint8, (
        'post_attention_layernorm weight dtype should not be torch.uint8')
    # check that the output of the model is as expected
sampling_params = SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=8)
prompts = ['That which does not kill us', 'To be or not to be,']
expected_outputs = [
'That which does not kill us makes us stronger.',
'To be or not to be, that is the question.'
]
outputs = llm.generate(prompts, sampling_params=sampling_params)
assert len(outputs) == len(prompts)
for index in range(len(outputs)):
# compare the first line of the output
actual_output = outputs[index][1][0].split('\n', 1)[0]
expected_output = expected_outputs[index].split('\n', 1)[0]
assert actual_output == expected_output, (
f'Expected: {expected_output}, but got: {actual_output}')


@ -241,6 +241,12 @@ class ModelConfig:
"must be divisible by pipeline parallel size "
f"({pipeline_parallel_size}).")
if self.quantization == "bitsandbytes" and (
parallel_config.tensor_parallel_size > 1
or parallel_config.pipeline_parallel_size > 1):
raise ValueError(
"BitAndBytes quantization with TP or PP is not supported yet.")
def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.
"""
@ -327,7 +333,7 @@ class ModelConfig:
def get_num_attention_heads(self,
parallel_config: "ParallelConfig") -> int:
return self.hf_text_config.num_attention_heads // \
            parallel_config.tensor_parallel_size
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
@ -487,6 +493,7 @@ class LoadFormat(str, enum.Enum):
DUMMY = "dummy"
TENSORIZER = "tensorizer"
SHARDED_STATE = "sharded_state"
BITSANDBYTES = "bitsandbytes"
@dataclass


@ -92,6 +92,8 @@ class EngineArgs:
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None
qlora_adapter_name_or_path: Optional[str] = None
def __post_init__(self):
if self.tokenizer is None:
self.tokenizer = self.model
@ -159,7 +161,8 @@ class EngineArgs:
type=str,
default=EngineArgs.load_format,
choices=[
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
'bitsandbytes'
],
help='The format of the model weights to load.\n\n'
'* "auto" will try to load the weights in the safetensors format '
@ -173,7 +176,9 @@ class EngineArgs:
'which is mainly for profiling.\n'
'* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
'section for more information.\n')
'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
'--dtype',
type=str,
@ -543,7 +548,10 @@ class EngineArgs:
"will also be used in `model_name` tag content of "
"prometheus metrics, if multiple names provided, metrics"
"tag will take the first one.")
parser.add_argument('--qlora-adapter-name-or-path',
type=str,
default=None,
help='Name or path of the QLoRA adapter.')
return parser
@classmethod
@ -555,6 +563,23 @@ class EngineArgs:
return engine_args
def create_engine_config(self, ) -> EngineConfig:
# bitsandbytes quantization needs a specific model loader
# so we make sure the quant method and the load format are consistent
if (self.quantization == "bitsandbytes" or
self.qlora_adapter_name_or_path is not None) and \
self.load_format != "bitsandbytes":
raise ValueError(
"BitsAndBytes quantization and QLoRA adapter only support "
f"'bitsandbytes' load format, but got {self.load_format}")
if (self.load_format == "bitsandbytes" or
self.qlora_adapter_name_or_path is not None) and \
self.quantization != "bitsandbytes":
raise ValueError(
"BitsAndBytes load format and QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}")
device_config = DeviceConfig(self.device)
model_config = ModelConfig(
self.model, self.tokenizer, self.tokenizer_mode,
@ -622,6 +647,13 @@ class EngineArgs:
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
if self.qlora_adapter_name_or_path is not None and \
self.qlora_adapter_name_or_path != "":
if self.model_loader_extra_config is None:
self.model_loader_extra_config = {}
self.model_loader_extra_config[
"qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
load_config = LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,


@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import List, Optional
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
@ -26,6 +26,21 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
def adjust_bitsandbytes_shard(param: Parameter,
qkv_offsets: Dict[str, Tuple[int, int]],
loaded_shard_id: str) -> Tuple[int, int]:
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
total, _ = qkv_offsets["total"]
orig_offset, orig_size = qkv_offsets[loaded_shard_id]
quantized_total = param.data.shape[0]
quantized_offset = orig_offset * quantized_total // total
quantized_size = orig_size * quantized_total // total
return quantized_size, quantized_offset
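A worked example of the rescaling above, with hypothetical Llama-7B-like shapes (numbers assumed for illustration, not taken from this commit):

# 32 query heads and 32 KV heads with head_size 128 and hidden size 4096;
# the packed qweight stores two 4-bit values per uint8 element.
qkv_offsets = {
    "q": (0, 4096),
    "k": (4096, 4096),
    "v": (8192, 4096),
    "total": (12288, 0),
}
total, _ = qkv_offsets["total"]
quantized_total = 12288 * 4096 // 2  # rows of the flattened [N, 1] qweight
for shard_id in ("q", "k", "v"):
    offset, size = qkv_offsets[shard_id]
    print(shard_id,
          size * quantized_total // total,    # quantized shard size
          offset * quantized_total // total)  # quantized shard offset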
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
@ -37,7 +52,7 @@ class LinearMethodBase(QuantizeMethodBase):
**extra_weight_attrs):
"""Create weights for a linear layer.
The weights will be set as attributes of the layer.
Args:
layer: The layer that is using the LinearMethodBase factory.
input_size_per_partition: Size of the weight input dim on rank X.
@ -416,6 +431,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
if use_bitsandbytes:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
@ -615,6 +636,22 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
if use_bitsandbytes:
orig_qkv_offsets = {
"q": (0, self.num_heads * self.head_size),
"k": (self.num_heads * self.head_size,
self.num_kv_heads * self.head_size),
"v":
((self.num_heads + self.num_kv_heads) * self.head_size,
self.num_kv_heads * self.head_size),
"total":
((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
0)
}
shard_size, shard_offset = adjust_bitsandbytes_shard(
param, orig_qkv_offsets, loaded_shard_id)
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if loaded_shard_id == "q":


@ -4,6 +4,8 @@ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.bitsandbytes import (
BitsAndBytesConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig)
from vllm.model_executor.layers.quantization.deepspeedfp import (
@ -30,6 +32,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"sparseml": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
}


@ -0,0 +1,175 @@
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class BitsAndBytesConfig(QuantizationConfig):
"""Config class for BitsAndBytes Quantization.
Reference: https://arxiv.org/abs/2305.14314
"""
def __init__(
self,
adapter_name_or_path: str,
target_modules: List[str],
) -> None:
self.adapter_name_or_path = adapter_name_or_path
self.target_modules = target_modules
def __repr__(self) -> str:
return (
f"BitsAndBytesConfig(adapter_name_or_path={self.adapter_name_or_path}"
)
@classmethod
def get_name(self) -> str:
return "bitsandbytes"
@classmethod
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.float32, torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(self) -> int:
return 70
@staticmethod
def get_config_filenames() -> List[str]:
return [
"adapter_config.json",
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
adapter_name = cls.get_from_keys(config, ["adapter_name_or_path"])
default_target_modules = [
"gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
"o_proj"
]
if adapter_name == "":
target_modules = default_target_modules
else:
target_modules = cls.get_from_keys(config, ["target_modules"])
return cls(adapter_name, target_modules)
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["BitsAndBytesLinearMethod"]:
if isinstance(layer, LinearBase):
return BitsAndBytesLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
class BitsAndBytesLinearMethod(LinearMethodBase):
"""Linear method for BitsAndBytes.
Args:
quant_config: The BitsAndBytes quantization config.
"""
def __init__(self, quant_config: BitsAndBytesConfig):
try:
import bitsandbytes
if bitsandbytes.__version__ < "0.42.0":
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.42.0.")
except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.42.0 via "
"`pip install bitsandbytes>=0.42.0` to use "
"bitsandbytes quantizer.") from err
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
quant_ratio = 0
if params_dtype.is_floating_point:
quant_ratio = torch.finfo(params_dtype).bits // torch.iinfo(
torch.uint8).bits
else:
quant_ratio = torch.iinfo(params_dtype).bits // torch.iinfo(
torch.uint8).bits
if input_size_per_partition * sum(
output_partition_sizes) % quant_ratio != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. ")
qweight = Parameter(
torch.empty(
input_size_per_partition * sum(output_partition_sizes) //
quant_ratio,
1,
dtype=torch.uint8,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
"input_dim": 0,
                # In bitsandbytes, a tensor of shape [n, m] is quantized to
                # [n * m / pack_ratio, 1], so the output_dim is 0
"output_dim": 0,
"pack_factor": quant_ratio,
"use_bitsandbytes": True,
})
layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# only load the bitsandbytes module when needed
from bitsandbytes import matmul_4bit
original_type = x.dtype
bf_x = x.to(torch.bfloat16)
qweight = layer.qweight
quant_states = qweight.bnb_quant_state
offsets = qweight.bnb_shard_offsets
out_dim_0 = x.shape[0]
out_dim_1 = sum(
[quant_state[1].shape[0] for quant_state in quant_states.items()])
out = torch.empty(out_dim_0,
out_dim_1,
dtype=torch.bfloat16,
device=x.device)
current_index = 0
for i in range(len(quant_states)):
output_size = quant_states[i].shape[0]
# It is more efficient to use out kwarg like
# matmul_4bit(..., out = ...). Infeasible now due to the bug
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
# Need to change after the bug is fixed.
out[:, current_index:current_index + output_size] = matmul_4bit(
bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
current_index += output_size
out = out.to(original_type)
if bias is not None:
out += bias
return out
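To make the placeholder shape in create_weights concrete: with fp16 parameters the pack factor is 16 // 8 = 2, so an [input, output] weight becomes a flattened [input * output // 2, 1] uint8 buffer. A small sanity-check sketch (shapes assumed for illustration):

import torch

params_dtype = torch.float16
input_size_per_partition = 4096               # assumed hidden size
output_partition_sizes = [4096, 4096, 4096]   # e.g. a fused QKV projection

quant_ratio = torch.finfo(params_dtype).bits // torch.iinfo(torch.uint8).bits
num_rows = input_size_per_partition * sum(output_partition_sizes) // quant_ratio
print(quant_ratio, num_rows)  # 2 25165824 -> qweight is allocated as [25165824, 1]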


@ -1,13 +1,18 @@
# ruff: noqa: SIM117
import collections
import copy
import fnmatch
import glob
import json
import math
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
import huggingface_hub
import numpy as np
import torch
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
@ -28,6 +33,7 @@ from vllm.model_executor.model_loader.weight_utils import (
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
@ -125,7 +131,7 @@ class DefaultModelLoader(BaseModelLoader):
def _maybe_download_from_modelscope(
self, model: str, revision: Optional[str]) -> Optional[str]:
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
Returns the path to the downloaded model, or None if the model is not
downloaded from ModelScope."""
if VLLM_USE_MODELSCOPE:
@ -247,6 +253,7 @@ class DefaultModelLoader(BaseModelLoader):
model,
"fall_back_to_pt_during_load",
True)), )
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
@ -539,6 +546,241 @@ class ShardedStateLoader(BaseModelLoader):
)
class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitAndBytes quantization."""
default_target_modules = [
"gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
"o_proj"
]
possible_config_file_names = ["adapter_config.json"]
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
# we don't need to quantize the whole model, only the target modules
# that are specified in the adapter config file. If the adapter config
# file is not provided, we will quantize the default modules.
if (not load_config.model_loader_extra_config
or "qlora_adapter_name_or_path"
not in load_config.model_loader_extra_config):
self.target_modules = self.default_target_modules
return
qlora_adapter = load_config.model_loader_extra_config[
"qlora_adapter_name_or_path"]
config_file_path = self._get_config_file(qlora_adapter)
with open(config_file_path, "r") as f:
config = json.load(f)
self.target_modules = config["target_modules"]
def _get_config_file(self, qlora_adapter: str) -> str:
is_local = os.path.isdir(qlora_adapter)
config_file_path = None
if is_local:
for file in self.possible_config_file_names:
config_file_path = os.path.join(qlora_adapter, file)
if os.path.exists(config_file_path):
break
else:
hf_api = HfApi()
repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
for file in self.possible_config_file_names:
if file in repo_files:
config_file_path = hf_hub_download(repo_id=qlora_adapter,
filename=file)
break
if not config_file_path:
raise ValueError(
f"Cannot find adapter config file in {qlora_adapter}")
return config_file_path
def _get_weight_files(
self,
model_name_or_path: str,
allowed_patterns: List[str],
revision: Optional[str] = None) -> Tuple[List[str], str]:
"""Retrieve weight files. Download the files if necessary.
Return the weight files and the file pattern."""
is_local = os.path.isdir(model_name_or_path)
if is_local:
for pattern in allowed_patterns:
weight_files = glob.glob(
os.path.join(model_name_or_path, pattern))
if weight_files:
return weight_files, pattern
else:
hf_api = HfApi()
repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
for pattern in allowed_patterns:
matching_files = fnmatch.filter(repo_files, pattern)
if matching_files:
hf_folder = download_weights_from_hf(
model_name_or_path, self.load_config.download_dir,
[pattern], revision)
return glob.glob(os.path.join(hf_folder, pattern)), pattern
raise RuntimeError(
f"No model weights found in: `{model_name_or_path}`")
def _prepare_weights(self, model_name_or_path: str,
revision: Optional[str]) -> Tuple[List[str], bool]:
"""Prepare weight files for the model."""
allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
hf_weights_files, matched_pattern = self._get_weight_files(
model_name_or_path, allowed_patterns, revision)
if matched_pattern != "*.safetensors":
hf_weights_files = filter_files_not_needed_for_inference(
hf_weights_files)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`")
return hf_weights_files, matched_pattern == "*.safetensors"
def _get_quantized_weights_iterator(
self, model_name_or_path: str, revision: Optional[str]
) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
Any]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
as well as the quantization state dictionary."""
# only load the bitsandbytes module when needed
try:
import bitsandbytes
if bitsandbytes.__version__ < "0.42.0":
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.42.0.")
from bitsandbytes.functional import quantize_4bit
except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.42.0 via "
"`pip install bitsandbytes>=0.42.0` to use "
"bitsandbytes quantizer.") from err
hf_weights_files, use_safetensors = self._prepare_weights(
model_name_or_path, revision)
quant_state_dict = {}
if use_safetensors:
weight_iterator = safetensors_weights_iterator(hf_weights_files)
else:
weight_iterator = pt_weights_iterator(hf_weights_files)
def generator():
for weight_name, weight_tensor in weight_iterator:
if any(target_module in weight_name
for target_module in self.target_modules):
weight_name = weight_name.replace(".weight", ".qweight")
                    # bitsandbytes requires the data to be on the GPU
loaded_weight = weight_tensor.cuda().data
with set_default_torch_dtype(torch.float32):
processed_weight, quant_state = quantize_4bit(
loaded_weight,
compress_statistics=True,
quant_type="nf4")
quant_state_dict[weight_name] = quant_state
else:
processed_weight = weight_tensor
yield weight_name, processed_weight
return generator(), quant_state_dict
def _load_weights(self, model_config: ModelConfig,
model: nn.Module) -> None:
if not hasattr(model, 'load_weights'):
raise AttributeError(
"The required method 'load_weights' is not defined in class"
f" {type(self).__name__}.")
if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
raise AttributeError(
f"Model {type(self).__name__} does not support BitsAndBytes "
"quantization yet.")
logger.info("Loading weights with BitsAndBytes quantization. "
" May take a while ...")
qweight_iterator, quant_state_dict = (
self._get_quantized_weights_iterator(model_config.model,
model_config.revision))
model.load_weights(qweight_iterator)
param_dict = dict(model.named_parameters())
stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
for quant_param_name in quant_state_dict:
non_stacked_param_name = quant_param_name
shard_index = 0
for shard_name, (
weight_name, index
) in model.bitsandbytes_stacked_params_mapping.items():
if shard_name in quant_param_name:
shard_index = index
quant_param_name = quant_param_name.replace(
shard_name, weight_name)
break
if quant_param_name not in param_dict:
raise ValueError(
f"Parameter {quant_param_name} not found in the model.")
if quant_param_name not in stacked_quant_state_dict:
stacked_quant_state_dict[quant_param_name] = {}
stacked_quant_state_dict[quant_param_name][shard_index] = (
quant_state_dict[non_stacked_param_name])
# save quant_states and offsets as the attributes of the parameters
for param_name, param in param_dict.items():
if param_name in stacked_quant_state_dict:
quant_states = stacked_quant_state_dict[param_name]
set_weight_attrs(param, {"bnb_quant_state": quant_states})
pack_ratio = getattr(param, "pack_factor", -1)
if pack_ratio == -1:
raise ValueError(
f"pack_factor not set for parameter {param_name}.")
num_elements = [0] * len(quant_states)
for seq, quant_state in enumerate(quant_states.items()):
num_elements[seq] = math.prod(
quant_state[1].shape) // pack_ratio
offsets = np.concatenate(([0], np.cumsum(num_elements)))
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config,
cache_config)
self._load_weights(model_config, model)
return model.eval()
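_load_weights relies on each model exposing a bitsandbytes_stacked_params_mapping (see the llama.py hunk at the end of this diff) to translate per-shard checkpoint names into vLLM's fused parameter names. A small sketch of that renaming, mirroring the loop above with an illustrative parameter name:

# The parameter name below is illustrative, not taken from a real checkpoint.
bitsandbytes_stacked_params_mapping = {
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

quant_param_name = "model.layers.0.self_attn.k_proj.qweight"
shard_index = 0
for shard_name, (weight_name, index) in bitsandbytes_stacked_params_mapping.items():
    if shard_name in quant_param_name:
        shard_index = index
        quant_param_name = quant_param_name.replace(shard_name, weight_name)
        break
print(quant_param_name, shard_index)
# model.layers.0.self_attn.qkv_proj.qweight 1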
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
@ -554,4 +796,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.SHARDED_STATE:
return ShardedStateLoader(load_config)
if load_config.load_format == LoadFormat.BITSANDBYTES:
return BitsAndBytesModelLoader(load_config)
return DefaultModelLoader(load_config)


@ -130,7 +130,17 @@ def get_quant_config(model_config: ModelConfig,
if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config)
model_name_or_path = model_config.model
# In case of bitsandbytes/QLoRA, get quant config from the adapter model.
if model_config.quantization == "bitsandbytes":
if (not load_config.model_loader_extra_config
or "qlora_adapter_name_or_path"
not in load_config.model_loader_extra_config):
return quant_cls.from_config({"adapter_name_or_path": ""})
model_name_or_path = load_config.model_loader_extra_config[
"qlora_adapter_name_or_path"]
else:
model_name_or_path = model_config.model
is_local = os.path.isdir(model_name_or_path)
if not is_local:
# Download the config files.
@ -169,6 +179,10 @@ def get_quant_config(model_config: ModelConfig,
quant_config_file = quant_config_files[0]
with open(quant_config_file, "r") as f:
config = json.load(f)
if model_config.quantization == "bitsandbytes":
config["adapter_name_or_path"] = model_name_or_path
return quant_cls.from_config(config)
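For the bitsandbytes path, the dict handed to from_config is the adapter repo's adapter_config.json contents plus the injected adapter_name_or_path. A sketch of what that call might look like (values are illustrative):

from vllm.model_executor.layers.quantization.bitsandbytes import (
    BitsAndBytesConfig)

config = {
    "adapter_name_or_path": "timdettmers/qlora-flan-7b",
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj",
                       "gate_proj", "up_proj", "down_proj"],
}
bnb_config = BitsAndBytesConfig.from_config(config)
print(bnb_config.target_modules)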


@ -319,6 +319,14 @@ class LlamaForCausalLM(nn.Module):
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,