[Misc] Auto detect bitsandbytes pre-quantized models (#16027)

Signed-off-by: Tristan Leclercq <tristanleclercq@gmail.com>
Authored by Tristan Leclercq on 2025-04-05 08:30:45 +02:00, committed by GitHub
parent 63375f0cdb
commit 4285e423a6
3 changed files with 16 additions and 7 deletions


@@ -19,17 +19,20 @@ And usually, these repositories have a config.json file that includes a quantiza
 ## Read quantized checkpoint

+For pre-quantized checkpoints, vLLM will try to infer the quantization method from the config file, so you don't need to explicitly specify the quantization argument.
+
 ```python
 from vllm import LLM
 import torch
 # unsloth/tinyllama-bnb-4bit is a pre-quantized checkpoint.
 model_id = "unsloth/tinyllama-bnb-4bit"
-llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, \
-quantization="bitsandbytes")
+llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True)
 ```

 ## Inflight quantization: load as 4bit quantization

+For inflight 4bit quantization with BitsAndBytes, you need to explicitly specify the quantization argument.
+
 ```python
 from vllm import LLM
 import torch
@@ -40,7 +43,7 @@ quantization="bitsandbytes")
 ## OpenAI Compatible Server

-Append the following to your 4bit model arguments:
+Append the following to your model arguments for 4bit inflight quantization:

 ```console
 --quantization bitsandbytes
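
Taken together, the documentation change above amounts to two usage patterns. A minimal sketch, assuming bitsandbytes is installed; the pre-quantized checkpoint is the one from the docs, while `huggyllama/llama-7b` and the `PRE_QUANTIZED` switch are illustrative choices, not part of this diff:

```python
from vllm import LLM
import torch

PRE_QUANTIZED = True  # flip to False to try inflight quantization instead

if PRE_QUANTIZED:
    # Pre-quantized bitsandbytes checkpoint: the quantization method is now
    # inferred from the checkpoint's config.json, so no quantization argument.
    llm = LLM(model="unsloth/tinyllama-bnb-4bit",
              dtype=torch.bfloat16,
              trust_remote_code=True)
else:
    # Inflight 4bit quantization of an unquantized model still requires the
    # explicit quantization argument (the model id here is only an example).
    llm = LLM(model="huggyllama/llama-7b",
              dtype=torch.bfloat16,
              quantization="bitsandbytes")

outputs = llm.generate("Hello, my name is")
print(outputs[0].outputs[0].text)
```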


@@ -41,7 +41,7 @@ def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
     hf_model_kwargs = {"load_in_4bit": True}
     validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
-                             model_name, hf_model_kwargs)
+                             model_name, False, hf_model_kwargs)


 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
@@ -53,7 +53,7 @@ def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                                        model_name, description) -> None:
     validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
-                             model_name)
+                             model_name, True)


 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
@@ -65,7 +65,7 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                              model_name, description) -> None:
     validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
-                             model_name)
+                             model_name, True)


 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -82,6 +82,7 @@ def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                              vllm_runner,
                              example_prompts[:1],
                              model_name,
+                             False,
                              hf_model_kwargs,
                              vllm_tp_size=2)
@@ -128,13 +129,14 @@ def validate_generated_texts(hf_runner,
                              vllm_runner,
                              prompts,
                              model_name,
+                             pre_quant=False,
                              hf_model_kwargs=None,
                              vllm_tp_size=1):

     # NOTE: run vLLM first, as it requires a clean process
     # when using distributed inference
     with vllm_runner(model_name,
-                     quantization='bitsandbytes',
+                     quantization=None if pre_quant else 'bitsandbytes',
                      tensor_parallel_size=vllm_tp_size,
                      enforce_eager=False) as llm:
         vllm_outputs = llm.generate_greedy(prompts, 8)
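
For context, a minimal sketch of how a test in this module might exercise the new `pre_quant` flag; the skipif guard, fixtures, and helper mirror the existing tests, while the test name and the import path of `is_quant_method_supported` are assumptions:

```python
import pytest

# Assumed import path for the existing test helper.
from tests.quantization.utils import is_quant_method_supported


@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason="bitsandbytes is not supported on this GPU type.")
def test_load_pre_quant_bnb_model_sketch(hf_runner, vllm_runner,
                                         example_prompts) -> None:
    # pre_quant=True: drop the explicit quantization argument and let vLLM
    # detect bitsandbytes from the checkpoint's config.json.
    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
                             "unsloth/tinyllama-bnb-4bit", True)
```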


@@ -1275,6 +1275,10 @@ class EngineArgs:
             self.model_loader_extra_config[
                 "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

+        # bitsandbytes pre-quantized models need a specific model loader
+        if model_config.quantization == "bitsandbytes":
+            self.quantization = self.load_format = "bitsandbytes"
+
         load_config = self.create_load_config()

         prompt_adapter_config = PromptAdapterConfig(
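
The engine-args hunk above is what forces the bitsandbytes loader once the quantization method has been detected on the model config. As an illustrative sketch only, not vLLM's actual detection code, this is roughly the signal a pre-quantized checkpoint exposes; `detect_bnb_prequant` is a hypothetical helper:

```python
import json

from huggingface_hub import hf_hub_download


def detect_bnb_prequant(model_id: str) -> bool:
    """Return True if the checkpoint ships a bitsandbytes quantization_config."""
    # Pre-quantized repos carry a quantization_config section in config.json.
    config_path = hf_hub_download(repo_id=model_id, filename="config.json")
    with open(config_path) as f:
        config = json.load(f)
    quant_cfg = config.get("quantization_config") or {}
    return quant_cfg.get("quant_method") == "bitsandbytes"


if __name__ == "__main__":
    # Expected to print True for the pre-quantized checkpoint used in the docs above.
    print(detect_bnb_prequant("unsloth/tinyllama-bnb-4bit"))
```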