[Bugfix] Fix bnb quantization for models with both HF-format and Mistral-format weights (#14950)
This commit is contained in:
parent
18551e820c
commit
5eeabc2a44
@ -15,6 +15,8 @@ from ..utils import compare_two_settings, create_new_process_for_each_test
|
|||||||
|
|
||||||
models_4bit_to_test = [
|
models_4bit_to_test = [
|
||||||
("facebook/opt-125m", "quantize opt model inflight"),
|
("facebook/opt-125m", "quantize opt model inflight"),
|
||||||
|
("mistralai/Mistral-7B-Instruct-v0.3",
|
||||||
|
"quantize inflight model with both HF and Mistral format weights")
|
||||||
]
|
]
|
||||||
|
|
||||||
models_pre_qaunt_4bit_to_test = [
|
models_pre_qaunt_4bit_to_test = [
|
||||||
|
@ -762,7 +762,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
allowed_patterns: List[str],
|
allowed_patterns: List[str],
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
) -> Tuple[List[str], str]:
|
) -> Tuple[str, List[str], str]:
|
||||||
"""Retrieve weight files. Download the files if necessary.
|
"""Retrieve weight files. Download the files if necessary.
|
||||||
|
|
||||||
Return the weight files and the file pattern."""
|
Return the weight files and the file pattern."""
|
||||||
@ -773,7 +773,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
weight_files = glob.glob(
|
weight_files = glob.glob(
|
||||||
os.path.join(model_name_or_path, pattern))
|
os.path.join(model_name_or_path, pattern))
|
||||||
if weight_files:
|
if weight_files:
|
||||||
return weight_files, pattern
|
return model_name_or_path, weight_files, pattern
|
||||||
else:
|
else:
|
||||||
hf_api = HfApi()
|
hf_api = HfApi()
|
||||||
repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
|
repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
|
||||||
@ -787,7 +787,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
revision,
|
revision,
|
||||||
ignore_patterns=self.load_config.ignore_patterns,
|
ignore_patterns=self.load_config.ignore_patterns,
|
||||||
)
|
)
|
||||||
return glob.glob(os.path.join(hf_folder, pattern)), pattern
|
return hf_folder, glob.glob(
|
||||||
|
os.path.join(hf_folder, pattern)), pattern
|
||||||
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"No model weights found in: `{model_name_or_path}`")
|
f"No model weights found in: `{model_name_or_path}`")
|
||||||
@ -798,10 +799,28 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
|
|
||||||
allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
|
allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
|
||||||
|
|
||||||
hf_weights_files, matched_pattern = self._get_weight_files(
|
hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
|
||||||
model_name_or_path, allowed_patterns, revision)
|
model_name_or_path, allowed_patterns, revision)
|
||||||
|
|
||||||
if matched_pattern != "*.safetensors":
|
use_safetensors = matched_pattern == "*.safetensors"
|
||||||
|
is_local = os.path.isdir(model_name_or_path)
|
||||||
|
index_file = SAFE_WEIGHTS_INDEX_NAME
|
||||||
|
if use_safetensors:
|
||||||
|
# For models like Mistral-7B-Instruct-v0.3
|
||||||
|
# there are both sharded safetensors files and a consolidated
|
||||||
|
# safetensors file. Using both breaks.
|
||||||
|
# Here, we download the `model.safetensors.index.json` and filter
|
||||||
|
# any files not found in the index.
|
||||||
|
if not is_local:
|
||||||
|
download_safetensors_index_file_from_hf(
|
||||||
|
model_name_or_path,
|
||||||
|
index_file,
|
||||||
|
self.load_config.download_dir,
|
||||||
|
revision,
|
||||||
|
)
|
||||||
|
hf_weights_files = filter_duplicate_safetensors_files(
|
||||||
|
hf_weights_files, hf_folder, index_file)
|
||||||
|
else:
|
||||||
hf_weights_files = filter_files_not_needed_for_inference(
|
hf_weights_files = filter_files_not_needed_for_inference(
|
||||||
hf_weights_files)
|
hf_weights_files)
|
||||||
|
|
||||||
@ -809,7 +828,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Cannot find any model weights with `{model_name_or_path}`")
|
f"Cannot find any model weights with `{model_name_or_path}`")
|
||||||
|
|
||||||
return hf_weights_files, matched_pattern == "*.safetensors"
|
return hf_weights_files, use_safetensors
|
||||||
|
|
||||||
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
|
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
|
||||||
if use_safetensors:
|
if use_safetensors:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user