[Model] Add classification Task with Qwen2ForSequenceClassification (#9704)
Signed-off-by: Kevin-Yang <ykcha9@gmail.com>
Co-authored-by: Kevin-Yang <ykcha9@gmail.com>
This commit is contained in:
parent 07e981fdf4
commit 6650e6a930
@@ -361,6 +361,28 @@ Reward Modeling
 
 .. note::
     As an interim measure, these models are supported via the Embeddings API. See `this RFC <https://github.com/vllm-project/vllm/issues/8967>`_ for upcoming changes.
 
+Classification
+---------------
+
+.. list-table::
+  :widths: 25 25 50 5 5
+  :header-rows: 1
+
+  * - Architecture
+    - Models
+    - Example HF Models
+    - :ref:`LoRA <lora>`
+    - :ref:`PP <distributed_serving>`
+  * - :code:`Qwen2ForSequenceClassification`
+    - Qwen2-based
+    - :code:`jason9693/Qwen2.5-1.5B-apeach`, etc.
+    -
+    - ✅︎
+
+.. note::
+    As an interim measure, these models are supported via the Embeddings API. They will be supported via the Classification API in the future (no reference APIs exist now).
+
+
 Multimodal Language Models
 ^^^^^^^^^^^^^^^^^^^^^^^^^^
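Since classification is served through the interim Embeddings API, a request currently looks like an embedding request whose returned vector holds the class probabilities. A minimal offline sketch of that path (not part of this commit; it mirrors the VllmRunner.classify helper further down):

# Minimal sketch, assuming the checkpoint is picked up as an embedding model.
# The "embedding" returned for Qwen2ForSequenceClassification is the
# softmax-ed class-probability vector produced by the Pooler.
from vllm import LLM

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach")
outputs = llm.encode(["An example sentence to classify."])
probs = outputs[0].outputs.embedding  # one probability per class label
print(probs)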
@@ -343,6 +343,17 @@ class HfRunner:
 
         return all_inputs
 
+    def classify(self, prompts: List[str]) -> List[List[float]]:
+        # output is the softmax over the final logits
+        all_inputs = self.get_inputs(prompts)
+        outputs = []
+        for inputs in all_inputs:
+            output = self.model(**self.wrap_device(inputs))
+            logits = output.logits.softmax(dim=-1)[0].tolist()
+            outputs.append(logits)
+
+        return outputs
+
     def generate(
         self,
         prompts: List[str],
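For reference, HfRunner.classify above is equivalent to this standalone transformers snippet (a sketch; the prompt and default CPU/float32 setup are illustrative):

# Standalone equivalent of HfRunner.classify for a single prompt.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "jason9693/Qwen2.5-1.5B-apeach"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

inputs = tokenizer("An example sentence to classify.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits           # shape: (1, num_labels)
probs = logits.softmax(dim=-1)[0].tolist()    # matches HfRunner.classify
print(probs)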
@@ -688,6 +699,14 @@ class VllmRunner:
 
         return inputs
 
+    def classify(self, prompts: List[str]) -> List[List[float]]:
+        req_outputs = self.model.encode(prompts)
+        outputs = []
+        for req_output in req_outputs:
+            embedding = req_output.outputs.embedding
+            outputs.append(embedding)
+        return outputs
+
     def generate(
         self,
         prompts: List[str],
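Because the vectors returned here are class probabilities rather than embeddings, downstream code can map them straight to a label. A hedged sketch (assumes `probs_list` is one row from VllmRunner.classify and that the checkpoint's config carries the usual HF id2label mapping):

# Sketch: turn one probability vector from classify() into a label name.
import torch
from transformers import AutoConfig

config = AutoConfig.from_pretrained("jason9693/Qwen2.5-1.5B-apeach")
probs = torch.tensor(probs_list)  # hypothetical: one output of classify()
label = config.id2label[int(probs.argmax())]
print(label, probs.max().item())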
tests/models/embedding/language/test_cls_models.py (new file, 53 lines)
@@ -0,0 +1,53 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.

This test only tests small models. Big models such as 7B should be tested in
test_big_models.py, which can use a larger instance to run them.

Run `pytest tests/models/embedding/language/test_cls_models.py`.
"""
import pytest
import torch
from transformers import AutoModelForSequenceClassification

CLASSIFICATION_MODELS = ["jason9693/Qwen2.5-1.5B-apeach"]


@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_classification_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    with hf_runner(model,
                   dtype=dtype,
                   auto_cls=AutoModelForSequenceClassification) as hf_model:
        hf_outputs = hf_model.classify(example_prompts)

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.classify(example_prompts)

    print(hf_outputs, vllm_outputs)

    # check the difference between the predicted class probabilities
    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        hf_output = torch.tensor(hf_output)
        vllm_output = torch.tensor(vllm_output)

        assert torch.allclose(hf_output, vllm_output, 1e-3)  # rtol=1e-3


@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_classification_model_print(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    with vllm_runner(model, dtype=dtype) as vllm_model:
        # This test verifies that the model's extra_repr
        # can be printed correctly.
        print(vllm_model.model.llm_engine.model_executor.driver_worker.
              model_runner.model)
@@ -28,11 +28,15 @@ class Pooler(nn.Module):
         normalize: Whether to normalize the pooled data.
     """
 
-    def __init__(self, pooling_type: PoolingType, normalize: bool):
+    def __init__(self,
+                 pooling_type: PoolingType,
+                 normalize: bool,
+                 softmax: bool = False):
         super().__init__()
 
         self.pooling_type = pooling_type
         self.normalize = normalize
+        self.softmax = softmax
 
     def forward(
         self,
@@ -64,6 +68,9 @@ class Pooler(nn.Module):
         if self.normalize:
             pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
 
+        if self.softmax:
+            pooled_data = nn.functional.softmax(pooled_data, dim=-1)
+
         pooled_outputs = [
             EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
         ]
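In isolation, PoolingType.LAST with the new softmax=True flag reduces to keeping each sequence's final token vector and normalizing it into a distribution. A minimal plain-PyTorch sketch of that behavior (illustrative shapes only, outside vLLM's batching machinery):

# What Pooler(pooling_type=PoolingType.LAST, normalize=False, softmax=True)
# computes for a single sequence of per-token class scores.
import torch

num_labels, seq_len = 2, 8                       # illustrative sizes
token_scores = torch.randn(seq_len, num_labels)  # per-token class scores
pooled = token_scores[-1]                        # LAST: keep the final token
probs = torch.nn.functional.softmax(pooled, dim=-1)
assert torch.isclose(probs.sum(), torch.tensor(1.0))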
vllm/model_executor/models/qwen2_cls.py (new file, 107 lines)
@@ -0,0 +1,107 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
# Copyright 2024 Kakao Corp. (Kanana-X Team)
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
"""Inference-only Qwen2-Classification model compatible with HF weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import Qwen2Config

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .utils import AutoWeightsLoader


class Qwen2ForSequenceClassification(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: Qwen2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        # TODO (@robertgshaw2): see if this can be moved out
        if (cache_config.sliding_window is not None
                and hasattr(config, "max_window_layers")):
            raise ValueError("Sliding window for some but not all layers is "
                             "not supported. This model uses sliding window "
                             "but `max_window_layers` = %s is less than "
                             "`num_hidden_layers` = %s. Please open an issue "
                             "to discuss this feature." % (
                                 config.max_window_layers,
                                 config.num_hidden_layers,
                             ))

        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = Qwen2Model(config, cache_config, quant_config)

        self.score = RowParallelLinear(config.hidden_size,
                                       config.num_labels,
                                       quant_config=quant_config)
        self._pooler = Pooler(pooling_type=PoolingType.LAST,
                              normalize=False,
                              softmax=True)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        logits, _ = self.score(hidden_states)
        return logits

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(self,
                                   ignore_unexpected_prefixes=["lm_head."])
        loader.load_weights(weights)
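load_weights delegates to AutoWeightsLoader with ignore_unexpected_prefixes=["lm_head."] because classification checkpoints reuse a causal-LM trunk whose language-model head has been replaced by the score projection. A toy sketch of that prefix-filtering idea (not vLLM's actual loader, which also handles sharding and stacked parameters):

# Toy sketch of prefix-based weight filtering.
from typing import Iterable, Tuple

import torch


def filter_weights(
    weights: Iterable[Tuple[str, torch.Tensor]],
    ignore_prefixes: Tuple[str, ...] = ("lm_head.",),
) -> Iterable[Tuple[str, torch.Tensor]]:
    for name, tensor in weights:
        # Skip tensors from the causal-LM head; `score` replaces it here.
        if any(name.startswith(p) for p in ignore_prefixes):
            continue
        yield name, tensor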
@@ -96,6 +96,8 @@ _EMBEDDING_MODELS = {
     "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
     "MistralModel": ("llama", "LlamaEmbeddingModel"),
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
+    "Qwen2ForSequenceClassification": (
+        "qwen2_cls", "Qwen2ForSequenceClassification"),
     # [Multimodal]
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
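Each _EMBEDDING_MODELS entry maps an architecture string from the HF config to a (module, class) pair under vllm.model_executor.models, so the new model is only imported when requested. Resolution is roughly this sketch (an approximation, not the registry's actual code):

# Rough sketch of lazy architecture resolution; the real registry adds
# caching and error handling on top of this.
from importlib import import_module


def resolve_arch(module_name: str, class_name: str):
    mod = import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(mod, class_name)


# e.g. resolve_arch("qwen2_cls", "Qwen2ForSequenceClassification")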