[Model] Add classification Task with Qwen2ForSequenceClassification (#9704)

Signed-off-by: Kevin-Yang <ykcha9@gmail.com>
Co-authored-by: Kevin-Yang <ykcha9@gmail.com>
kakao-kevin-us 2024-10-27 02:53:35 +09:00 committed by GitHub
parent 07e981fdf4
commit 6650e6a930
6 changed files with 211 additions and 1 deletions

View File

@@ -361,6 +361,28 @@ Reward Modeling
.. note::
    As an interim measure, these models are supported via Embeddings API. See `this RFC <https://github.com/vllm-project/vllm/issues/8967>`_ for upcoming changes.

Classification
--------------

.. list-table::
  :widths: 25 25 50 5 5
  :header-rows: 1

  * - Architecture
    - Models
    - Example HF Models
    - :ref:`LoRA <lora>`
    - :ref:`PP <distributed_serving>`
  * - :code:`Qwen2ForSequenceClassification`
    - Qwen2-based
    - :code:`jason9693/Qwen2.5-1.5B-apeach`, etc.
    -
    - ✅︎

.. note::
    As an interim measure, these models are supported via Embeddings API; a dedicated Classification API will be added in the future (no reference API exists yet).
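Because classification is routed through the Embeddings API for now, the class probabilities come back in the ``embedding`` field of the output. A minimal offline sketch (the prompt and printed values are illustrative)::

    from vllm import LLM

    # The pooler applies softmax, so the "embedding" holds class probabilities.
    llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach")
    (output, ) = llm.encode(["vLLM is wonderful!"])
    probs = output.outputs.embedding
    print(probs)  # e.g. [0.97, 0.03] for a two-label checkpoint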
Multimodal Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^

View File

@@ -343,6 +343,17 @@ class HfRunner:
return all_inputs
    def classify(self, prompts: List[str]) -> List[List[float]]:
        # Returns the final class probabilities (softmax over the logits).
        all_inputs = self.get_inputs(prompts)

        outputs = []
        for inputs in all_inputs:
            output = self.model(**self.wrap_device(inputs))
            probs = output.logits.softmax(dim=-1)[0].tolist()
            outputs.append(probs)

        return outputs
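For reference, the same per-prompt computation outside the test harness looks like this (a self-contained sketch; the prompt is illustrative):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

name = "jason9693/Qwen2.5-1.5B-apeach"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(name)

inputs = tokenizer("vLLM is wonderful!", return_tensors="pt")
with torch.no_grad():
    probs = model(**inputs).logits.softmax(dim=-1)[0].tolist()
print(probs)  # class probabilities, matching HfRunner.classify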
def generate(
self,
prompts: List[str],
@@ -688,6 +699,14 @@ class VllmRunner:
return inputs
    def classify(self, prompts: List[str]) -> List[List[float]]:
        req_outputs = self.model.encode(prompts)

        outputs = []
        for req_output in req_outputs:
            # The pooler already applied softmax, so the "embedding"
            # field carries class probabilities, not raw embeddings.
            probs = req_output.outputs.embedding
            outputs.append(probs)

        return outputs
def generate(
self,
prompts: List[str],

View File

@@ -0,0 +1,53 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.
This test only tests small models. Big models such as 7B should be tested from
test_big_models.py because it could use a larger instance to run tests.
Run `pytest tests/models/test_cls_models.py`.
"""
import pytest
import torch
from transformers import AutoModelForSequenceClassification
CLASSIFICATION_MODELS = ["jason9693/Qwen2.5-1.5B-apeach"]
@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_classification_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
with hf_runner(model,
dtype=dtype,
auto_cls=AutoModelForSequenceClassification) as hf_model:
hf_outputs = hf_model.classify(example_prompts)
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.classify(example_prompts)
    # Check that the class probabilities match within tolerance.
    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        hf_output = torch.tensor(hf_output)
        vllm_output = torch.tensor(vllm_output)
        assert torch.allclose(hf_output, vllm_output, rtol=1e-3)
@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_classification_model_print(
vllm_runner,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
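The print test works because PyTorch interpolates each module's extra_repr() into the printed module tree. A toy sketch of the mechanism (plain PyTorch, not vLLM code):

import torch.nn as nn

class ToyScore(nn.Module):
    def __init__(self, in_features: int, num_labels: int):
        super().__init__()
        self.in_features = in_features
        self.num_labels = num_labels

    def extra_repr(self) -> str:
        # Shown inside the parentheses when the module is printed.
        return f"in_features={self.in_features}, num_labels={self.num_labels}"

print(ToyScore(16, 2))  # ToyScore(in_features=16, num_labels=2)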

View File

@@ -28,11 +28,15 @@ class Pooler(nn.Module):
        normalize: Whether to normalize the pooled data.
        softmax: Whether to apply softmax to the pooled data.
    """
def __init__(self, pooling_type: PoolingType, normalize: bool):
def __init__(self,
pooling_type: PoolingType,
normalize: bool,
softmax: bool = False):
super().__init__()
self.pooling_type = pooling_type
self.normalize = normalize
self.softmax = softmax
def forward(
self,
@@ -64,6 +68,9 @@ class Pooler(nn.Module):
if self.normalize:
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
if self.softmax:
pooled_data = nn.functional.softmax(pooled_data, dim=-1)
pooled_outputs = [
EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
]
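In effect, for this model the Pooler keeps only the last token's score vector per sequence and turns it into probabilities. A toy sketch of the semantics (plain PyTorch, not vLLM internals):

import torch

token_scores = torch.randn(5, 2)                   # [seq_len, num_labels]
last = token_scores[-1]                            # PoolingType.LAST
probs = torch.nn.functional.softmax(last, dim=-1)  # softmax=True
assert torch.isclose(probs.sum(), torch.tensor(1.0))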

View File

@@ -0,0 +1,107 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
# Copyright 2024 Kakao Corp. (Kanana-X Team)
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
"""Inference-only Qwen2-Classification model compatible with HF weights."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import Qwen2Config
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from .utils import AutoWeightsLoader
class Qwen2ForSequenceClassification(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
        # TODO (@robertgshaw2): see if this can be moved out
        if (cache_config is not None
                and cache_config.sliding_window is not None
                and hasattr(config, "max_window_layers")
                and config.max_window_layers < config.num_hidden_layers):
            raise ValueError("Sliding window for some but not all layers is "
                             "not supported. This model uses sliding window "
                             "but `max_window_layers` = %s is less than "
                             "`num_hidden_layers` = %s. Please open an issue "
                             "to discuss this feature." % (
                                 config.max_window_layers,
                                 config.num_hidden_layers,
                             ))
super().__init__()
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = Qwen2Model(config, cache_config, quant_config)
self.score = RowParallelLinear(config.hidden_size,
config.num_labels,
quant_config=quant_config)
self._pooler = Pooler(pooling_type=PoolingType.LAST,
normalize=False,
softmax=True)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
logits, _ = self.score(hidden_states)
return logits
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["lm_head."])
loader.load_weights(weights)
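Putting it together, the model projects the decoder's hidden states to num_labels scores, and the pooler emits softmaxed probabilities for the last token. A hedged usage sketch for mapping those probabilities back to label names (assumes the checkpoint's config defines id2label; the prompt is illustrative):

from transformers import AutoConfig
from vllm import LLM

name = "jason9693/Qwen2.5-1.5B-apeach"
config = AutoConfig.from_pretrained(name)  # id2label: index -> label name

llm = LLM(model=name)
probs = llm.encode(["vLLM is wonderful!"])[0].outputs.embedding
best = max(range(len(probs)), key=probs.__getitem__)
print(config.id2label[best], probs[best])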

View File

@@ -96,6 +96,8 @@ _EMBEDDING_MODELS = {
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
"MistralModel": ("llama", "LlamaEmbeddingModel"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForSequenceClassification": (
"qwen2_cls", "Qwen2ForSequenceClassification"),
# [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),