[Bugfix] Fix bnb quantization for models with both HF-format and Mistral-format weights (#14950)

commit 5eeabc2a44 (parent 18551e820c)
Author: Tristan Leclercq
Date:   2025-03-18 00:27:26 +01:00 (committed by GitHub)
2 changed files with 27 additions and 6 deletions
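
This change fixes in-flight bitsandbytes (bnb) quantization for checkpoints that publish two copies of their weights: HF-format sharded files (model-0000*-of-*.safetensors) and a Mistral-format consolidated.safetensors. Both match the "*.safetensors" glob, so the loader previously picked up every tensor twice and loading failed. The fix resolves the weight folder, fetches model.safetensors.index.json, and keeps only the shards the index references.

A minimal sketch of the scenario the fix targets, assuming vLLM's Python entry point and the bitsandbytes flags documented for this release:

    from vllm import LLM

    # mistralai/Mistral-7B-Instruct-v0.3 ships both sharded HF-format
    # safetensors and a consolidated Mistral-format file; before this fix,
    # in-flight 4-bit quantization tripped over the duplicate weights.
    llm = LLM(
        model="mistralai/Mistral-7B-Instruct-v0.3",
        quantization="bitsandbytes",
        load_format="bitsandbytes",
    )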


@@ -15,6 +15,8 @@ from ..utils import compare_two_settings, create_new_process_for_each_test
 
 models_4bit_to_test = [
     ("facebook/opt-125m", "quantize opt model inflight"),
+    ("mistralai/Mistral-7B-Instruct-v0.3",
+     "quantize inflight model with both HF and Mistral format weights")
 ]
 
 models_pre_qaunt_4bit_to_test = [
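
The new test entry exercises exactly this layout: the Mistral-7B-Instruct-v0.3 repository contains sharded model-*.safetensors files alongside a consolidated.safetensors, so in-flight quantization of this model regresses if duplicate shards are not filtered.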


@@ -762,7 +762,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         model_name_or_path: str,
         allowed_patterns: List[str],
         revision: Optional[str] = None,
-    ) -> Tuple[List[str], str]:
+    ) -> Tuple[str, List[str], str]:
         """Retrieve weight files. Download the files if necessary.
 
         Return the weight files and the file pattern."""
@@ -773,7 +773,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 weight_files = glob.glob(
                     os.path.join(model_name_or_path, pattern))
                 if weight_files:
-                    return weight_files, pattern
+                    return model_name_or_path, weight_files, pattern
         else:
             hf_api = HfApi()
             repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
@@ -787,7 +787,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                         revision,
                         ignore_patterns=self.load_config.ignore_patterns,
                     )
-                    return glob.glob(os.path.join(hf_folder, pattern)), pattern
+                    return hf_folder, glob.glob(
+                        os.path.join(hf_folder, pattern)), pattern
 
         raise RuntimeError(
             f"No model weights found in: `{model_name_or_path}`")
@@ -798,10 +799,28 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
 
-        hf_weights_files, matched_pattern = self._get_weight_files(
+        hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
             model_name_or_path, allowed_patterns, revision)
 
-        if matched_pattern != "*.safetensors":
+        use_safetensors = matched_pattern == "*.safetensors"
+        is_local = os.path.isdir(model_name_or_path)
+        index_file = SAFE_WEIGHTS_INDEX_NAME
+        if use_safetensors:
+            # For models like Mistral-7B-Instruct-v0.3
+            # there are both sharded safetensors files and a consolidated
+            # safetensors file. Using both breaks.
+            # Here, we download the `model.safetensors.index.json` and filter
+            # any files not found in the index.
+            if not is_local:
+                download_safetensors_index_file_from_hf(
+                    model_name_or_path,
+                    index_file,
+                    self.load_config.download_dir,
+                    revision,
+                )
+            hf_weights_files = filter_duplicate_safetensors_files(
+                hf_weights_files, hf_folder, index_file)
+        else:
             hf_weights_files = filter_files_not_needed_for_inference(
                 hf_weights_files)
 
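
The dedup step relies on model.safetensors.index.json, whose "weight_map" maps every tensor name to the shard that stores it; any *.safetensors file the index never mentions (such as consolidated.safetensors) is dropped. A minimal sketch of that filtering logic, assuming the standard index layout; the real helper lives in vLLM's weight utilities:

    import json
    import os
    from typing import List

    def filter_duplicate_safetensors_files(hf_weights_files: List[str],
                                           hf_folder: str,
                                           index_file: str) -> List[str]:
        index_path = os.path.join(hf_folder, index_file)
        if not os.path.isfile(index_path):
            # No index shipped: nothing to deduplicate.
            return hf_weights_files
        with open(index_path) as f:
            weight_map = json.load(f)["weight_map"]
        # Keep only the files the index actually references.
        referenced = {
            os.path.join(hf_folder, shard)
            for shard in weight_map.values()
        }
        return [f for f in hf_weights_files if f in referenced]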
@@ -809,7 +828,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             raise RuntimeError(
                 f"Cannot find any model weights with `{model_name_or_path}`")
 
-        return hf_weights_files, matched_pattern == "*.safetensors"
+        return hf_weights_files, use_safetensors
 
     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
         if use_safetensors: