[Bugfix] Added Command-R GPTQ support (#3849)
Co-authored-by: Egor Tolmachev <t333ga@gmail.com>
This commit is contained in:
parent
b4543c8f6b
commit
f46864d68d
@ -349,11 +349,21 @@ class CohereForCausalLM(nn.Module):
|
|||||||
if shard_name not in name:
|
if shard_name not in name:
|
||||||
continue
|
continue
|
||||||
name = name.replace(shard_name, param_name)
|
name = name.replace(shard_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
# lm_head is not used in vllm as it is tied with embed_token.
|
||||||
|
# To prevent errors, skip loading lm_head.weight.
|
||||||
|
if "lm_head.weight" in name:
|
||||||
|
continue
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user