[Bugfix] Fix tensor parallel for qwen2 classification model (#10297)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
ac49b59d8b
commit
15bb8330aa
@@ -21,14 +21,14 @@ def test_classification_models(
|
|||||||
model: str,
|
model: str,
|
||||||
dtype: str,
|
dtype: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||||
|
vllm_outputs = vllm_model.classify(example_prompts)
|
||||||
|
|
||||||
with hf_runner(model,
|
with hf_runner(model,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
auto_cls=AutoModelForSequenceClassification) as hf_model:
|
auto_cls=AutoModelForSequenceClassification) as hf_model:
|
||||||
hf_outputs = hf_model.classify(example_prompts)
|
hf_outputs = hf_model.classify(example_prompts)
|
||||||
|
|
||||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
|
||||||
vllm_outputs = vllm_model.classify(example_prompts)
|
|
||||||
|
|
||||||
print(hf_outputs, vllm_outputs)
|
print(hf_outputs, vllm_outputs)
|
||||||
|
|
||||||
# check logits difference
|
# check logits difference
|
||||||
|
@@ -69,9 +69,14 @@ class Qwen2ForSequenceClassification(nn.Module):
|
|||||||
self.model = Qwen2Model(vllm_config=vllm_config,
|
self.model = Qwen2Model(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "model"))
|
prefix=maybe_prefix(prefix, "model"))
|
||||||
|
|
||||||
|
# hidden_states from Qwen2Model has been reduced,
|
||||||
|
# the input of score layer is not parallelized.
|
||||||
self.score = RowParallelLinear(config.hidden_size,
|
self.score = RowParallelLinear(config.hidden_size,
|
||||||
config.num_labels,
|
config.num_labels,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
input_is_parallel=False,
|
||||||
|
bias=False,
|
||||||
|
prefix=maybe_prefix(prefix, "score"))
|
||||||
self._pooler = Pooler.from_config_with_defaults(
|
self._pooler = Pooler.from_config_with_defaults(
|
||||||
pooler_config,
|
pooler_config,
|
||||||
pooling_type=PoolingType.LAST,
|
pooling_type=PoolingType.LAST,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user