[Bugfix] Added Command-R GPTQ support (#3849)

Co-authored-by: Egor Tolmachev <t333ga@gmail.com>
This commit is contained in:
egortolmachev 2024-04-08 17:59:38 +03:00 committed by GitHub
parent b4543c8f6b
commit f46864d68d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -349,11 +349,21 @@ class CohereForCausalLM(nn.Module):
if shard_name not in name: if shard_name not in name:
continue continue
name = name.replace(shard_name, param_name) name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)