[Bugfix] Fix bnb quantization for models with both HF-format and Mistral-format weights (#14950)

Author: Tristan Leclercq (committed by GitHub)
Date:   2025-03-18 00:27:26 +01:00
Parent: 18551e820c
Commit: 5eeabc2a44
2 changed files with 27 additions and 6 deletions

View File

@@ -15,6 +15,8 @@ from ..utils import compare_two_settings, create_new_process_for_each_test
 models_4bit_to_test = [
     ("facebook/opt-125m", "quantize opt model inflight"),
+    ("mistralai/Mistral-7B-Instruct-v0.3",
+     "quantize inflight model with both HF and Mistral format weights")
 ]
 models_pre_qaunt_4bit_to_test = [

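For context, the new test entry covers in-flight bitsandbytes 4-bit quantization of a checkpoint that ships both HF-format sharded safetensors and a Mistral-format consolidated file. A minimal sketch of how such a run is typically launched follows; the model name and keyword arguments are assumptions about the caller's setup, not part of this commit, and require a bitsandbytes-enabled vLLM install.

# Sketch only (not from this commit): loading the Mistral checkpoint with
# in-flight bitsandbytes quantization through vLLM's offline entrypoint.
from vllm import LLM, SamplingParams

llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.3",
    quantization="bitsandbytes",   # quantize weights on the fly with bnb
    load_format="bitsandbytes",    # route loading through BitsAndBytesModelLoader
)
out = llm.generate(["Say hello."], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)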
View File

@@ -762,7 +762,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         model_name_or_path: str,
         allowed_patterns: List[str],
         revision: Optional[str] = None,
-    ) -> Tuple[List[str], str]:
+    ) -> Tuple[str, List[str], str]:
         """Retrieve weight files. Download the files if necessary.
         Return the weight files and the file pattern."""
@@ -773,7 +773,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 weight_files = glob.glob(
                     os.path.join(model_name_or_path, pattern))
                 if weight_files:
-                    return weight_files, pattern
+                    return model_name_or_path, weight_files, pattern
         else:
             hf_api = HfApi()
             repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
@@ -787,7 +787,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                         revision,
                         ignore_patterns=self.load_config.ignore_patterns,
                     )
-                    return glob.glob(os.path.join(hf_folder, pattern)), pattern
+                    return hf_folder, glob.glob(
+                        os.path.join(hf_folder, pattern)), pattern
         raise RuntimeError(
             f"No model weights found in: `{model_name_or_path}`")
@@ -798,10 +799,28 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
-        hf_weights_files, matched_pattern = self._get_weight_files(
+        hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
             model_name_or_path, allowed_patterns, revision)
-        if matched_pattern != "*.safetensors":
+        use_safetensors = matched_pattern == "*.safetensors"
+        is_local = os.path.isdir(model_name_or_path)
+        index_file = SAFE_WEIGHTS_INDEX_NAME
+        if use_safetensors:
+            # For models like Mistral-7B-Instruct-v0.3
+            # there are both sharded safetensors files and a consolidated
+            # safetensors file. Using both breaks.
+            # Here, we download the `model.safetensors.index.json` and filter
+            # any files not found in the index.
+            if not is_local:
+                download_safetensors_index_file_from_hf(
+                    model_name_or_path,
+                    index_file,
+                    self.load_config.download_dir,
+                    revision,
+                )
+            hf_weights_files = filter_duplicate_safetensors_files(
+                hf_weights_files, hf_folder, index_file)
+        else:
             hf_weights_files = filter_files_not_needed_for_inference(
                 hf_weights_files)
@@ -809,7 +828,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             raise RuntimeError(
                 f"Cannot find any model weights with `{model_name_or_path}`")
-        return hf_weights_files, matched_pattern == "*.safetensors"
+        return hf_weights_files, use_safetensors
     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
         if use_safetensors:
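The filtering step relies on `model.safetensors.index.json`: only shard files referenced by the index's `weight_map` are kept, which is what drops Mistral's `consolidated.safetensors`. A rough sketch of that idea follows, using a hypothetical `keep_only_indexed_safetensors` helper; it mirrors the intent of `filter_duplicate_safetensors_files` but is not vLLM's implementation.

# Sketch of index-based filtering (assumption: illustrates the idea only).
import json
import os
from typing import List

def keep_only_indexed_safetensors(weight_files: List[str], folder: str,
                                  index_file: str) -> List[str]:
    """Drop safetensors files that are not referenced by the index."""
    index_path = os.path.join(folder, index_file)
    if not os.path.isfile(index_path):
        # No index file: nothing to disambiguate, keep everything.
        return weight_files
    with open(index_path) as fp:
        weight_map = json.load(fp)["weight_map"]
    # The index maps parameter names to shard file names; files not listed
    # there (e.g. a consolidated.safetensors) are dropped.
    referenced = {os.path.join(folder, name) for name in weight_map.values()}
    return [path for path in weight_files if path in referenced]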