[Bugfix] Fix bnb quantization for models with both HF-format and Mistral-format weights (#14950)
commit 5eeabc2a44 (parent 18551e820c)
@@ -15,6 +15,8 @@ from ..utils import compare_two_settings, create_new_process_for_each_test
 models_4bit_to_test = [
     ("facebook/opt-125m", "quantize opt model inflight"),
+    ("mistralai/Mistral-7B-Instruct-v0.3",
+     "quantize inflight model with both HF and Mistral format weights")
 ]
 
 models_pre_qaunt_4bit_to_test = [
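
Each entry pairs a model id with a human-readable description; the surrounding test file feeds these tuples to the tests via pytest parametrization. A minimal sketch of that shape (the test name and body here are placeholders, not vLLM's actual test):

    import pytest

    models_4bit_to_test = [
        ("facebook/opt-125m", "quantize opt model inflight"),
        ("mistralai/Mistral-7B-Instruct-v0.3",
         "quantize inflight model with both HF and Mistral format weights"),
    ]

    @pytest.mark.parametrize("model_name, description", models_4bit_to_test)
    def test_4bit_inflight_quantization(model_name, description):
        # Placeholder body: the real vLLM test loads each model with
        # bitsandbytes in-flight quantization and validates generations.
        assert model_name and description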
@@ -762,7 +762,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         model_name_or_path: str,
         allowed_patterns: List[str],
         revision: Optional[str] = None,
-    ) -> Tuple[List[str], str]:
+    ) -> Tuple[str, List[str], str]:
         """Retrieve weight files. Download the files if necessary.
 
         Return the weight files and the file pattern."""
@@ -773,7 +773,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 weight_files = glob.glob(
                     os.path.join(model_name_or_path, pattern))
                 if weight_files:
-                    return weight_files, pattern
+                    return model_name_or_path, weight_files, pattern
         else:
             hf_api = HfApi()
             repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
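
With the widened signature, every branch now returns the resolved weights folder alongside the file list and the matched pattern, so the caller can later find the safetensors index that sits next to the weights. A standalone sketch of the local-directory branch under the new three-tuple contract (a simplification for illustration, not the vLLM method itself):

    import glob
    import os
    from typing import List, Tuple

    def get_weight_files_local(
            model_dir: str,
            allowed_patterns: List[str]) -> Tuple[str, List[str], str]:
        # Try each pattern in priority order; on the first hit, return the
        # containing folder, the matched files, and the pattern that matched.
        for pattern in allowed_patterns:
            weight_files = glob.glob(os.path.join(model_dir, pattern))
            if weight_files:
                return model_dir, weight_files, pattern
        raise RuntimeError(f"No model weights found in: `{model_dir}`")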
@@ -787,7 +787,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 revision,
                 ignore_patterns=self.load_config.ignore_patterns,
             )
-            return glob.glob(os.path.join(hf_folder, pattern)), pattern
+            return hf_folder, glob.glob(
+                os.path.join(hf_folder, pattern)), pattern
 
         raise RuntimeError(
             f"No model weights found in: `{model_name_or_path}`")
@@ -798,10 +799,28 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
 
-        hf_weights_files, matched_pattern = self._get_weight_files(
+        hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
             model_name_or_path, allowed_patterns, revision)
 
-        if matched_pattern != "*.safetensors":
+        use_safetensors = matched_pattern == "*.safetensors"
+        is_local = os.path.isdir(model_name_or_path)
+        index_file = SAFE_WEIGHTS_INDEX_NAME
+        if use_safetensors:
+            # For models like Mistral-7B-Instruct-v0.3
+            # there are both sharded safetensors files and a consolidated
+            # safetensors file. Using both breaks.
+            # Here, we download the `model.safetensors.index.json` and filter
+            # any files not found in the index.
+            if not is_local:
+                download_safetensors_index_file_from_hf(
+                    model_name_or_path,
+                    index_file,
+                    self.load_config.download_dir,
+                    revision,
+                )
+            hf_weights_files = filter_duplicate_safetensors_files(
+                hf_weights_files, hf_folder, index_file)
+        else:
             hf_weights_files = filter_files_not_needed_for_inference(
                 hf_weights_files)
 
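
The comment in the hunk above is the heart of the fix: a repo like Mistral-7B-Instruct-v0.3 ships both sharded model-*.safetensors files and a consolidated.safetensors file, and globbing *.safetensors matches both, so the loader would otherwise see every tensor twice. A minimal sketch of what the index-based filtering accomplishes, assuming the standard HF index layout where the weight_map values name the shard files (illustrative only, not vLLM's filter_duplicate_safetensors_files):

    import json
    import os
    from typing import List

    def keep_indexed_shards(weight_files: List[str],
                            hf_folder: str,
                            index_file: str = "model.safetensors.index.json"
                            ) -> List[str]:
        index_path = os.path.join(hf_folder, index_file)
        if not os.path.isfile(index_path):
            # No index means a single-file checkpoint; nothing to filter.
            return weight_files
        with open(index_path) as fp:
            indexed = set(json.load(fp)["weight_map"].values())
        # Drop files the index does not list, e.g. consolidated.safetensors.
        return [path for path in weight_files
                if os.path.basename(path) in indexed]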
@@ -809,7 +828,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             raise RuntimeError(
                 f"Cannot find any model weights with `{model_name_or_path}`")
 
-        return hf_weights_files, matched_pattern == "*.safetensors"
+        return hf_weights_files, use_safetensors
 
     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
         if use_safetensors:
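
With the duplicates filtered out, in-flight bnb quantization of such dual-format repos loads cleanly end to end. A usage sketch matching the new test entry (the quantization and load_format flags follow the vLLM bitsandbytes docs of this era; adjust for your install):

    from vllm import LLM

    # In-flight 4-bit bitsandbytes quantization of a repo that carries both
    # sharded HF-format weights and a consolidated Mistral-format file.
    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3",
              quantization="bitsandbytes",
              load_format="bitsandbytes")
    out = llm.generate("Hello, my name is")
    print(out[0].outputs[0].text)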