[Misc] Auto detect bitsandbytes pre-quantized models (#16027)
Signed-off-by: Tristan Leclercq <tristanleclercq@gmail.com>
This commit is contained in:
parent
63375f0cdb
commit
4285e423a6
@ -19,17 +19,20 @@ And usually, these repositories have a config.json file that includes a quantiza
|
||||
|
||||
## Read quantized checkpoint
|
||||
|
||||
For pre-quantized checkpoints, vLLM will try to infer the quantization method from the config file, so you don't need to explicitly specify the quantization argument.
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
import torch
|
||||
# unsloth/tinyllama-bnb-4bit is a pre-quantized checkpoint.
|
||||
model_id = "unsloth/tinyllama-bnb-4bit"
|
||||
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, \
|
||||
quantization="bitsandbytes")
|
||||
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True)
|
||||
```
|
||||
|
||||
## Inflight quantization: load as 4bit quantization
|
||||
|
||||
For inflight 4bit quantization with BitsAndBytes, you need to explicitly specify the quantization argument.
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
import torch
|
||||
@ -40,7 +43,7 @@ quantization="bitsandbytes")
|
||||
|
||||
## OpenAI Compatible Server
|
||||
|
||||
Append the following to your 4bit model arguments:
|
||||
Append the following to your model arguments for 4bit inflight quantization:
|
||||
|
||||
```console
|
||||
--quantization bitsandbytes
|
||||
|
@ -41,7 +41,7 @@ def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
|
||||
|
||||
hf_model_kwargs = {"load_in_4bit": True}
|
||||
validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
|
||||
model_name, hf_model_kwargs)
|
||||
model_name, False, hf_model_kwargs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
|
||||
@ -53,7 +53,7 @@ def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
|
||||
model_name, description) -> None:
|
||||
|
||||
validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
|
||||
model_name)
|
||||
model_name, True)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
|
||||
@ -65,7 +65,7 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
|
||||
model_name, description) -> None:
|
||||
|
||||
validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
|
||||
model_name)
|
||||
model_name, True)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
@ -82,6 +82,7 @@ def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
|
||||
vllm_runner,
|
||||
example_prompts[:1],
|
||||
model_name,
|
||||
False,
|
||||
hf_model_kwargs,
|
||||
vllm_tp_size=2)
|
||||
|
||||
@ -128,13 +129,14 @@ def validate_generated_texts(hf_runner,
|
||||
vllm_runner,
|
||||
prompts,
|
||||
model_name,
|
||||
pre_quant=False,
|
||||
hf_model_kwargs=None,
|
||||
vllm_tp_size=1):
|
||||
|
||||
# NOTE: run vLLM first, as it requires a clean process
|
||||
# when using distributed inference
|
||||
with vllm_runner(model_name,
|
||||
quantization='bitsandbytes',
|
||||
quantization=None if pre_quant else 'bitsandbytes',
|
||||
tensor_parallel_size=vllm_tp_size,
|
||||
enforce_eager=False) as llm:
|
||||
vllm_outputs = llm.generate_greedy(prompts, 8)
|
||||
|
@ -1275,6 +1275,10 @@ class EngineArgs:
|
||||
self.model_loader_extra_config[
|
||||
"qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
|
||||
|
||||
# bitsandbytes pre-quantized model need a specific model loader
|
||||
if model_config.quantization == "bitsandbytes":
|
||||
self.quantization = self.load_format = "bitsandbytes"
|
||||
|
||||
load_config = self.create_load_config()
|
||||
|
||||
prompt_adapter_config = PromptAdapterConfig(
|
||||
|
Loading…
x
Reference in New Issue
Block a user