[Model] Add classification Task with Qwen2ForSequenceClassification (#9704)
Signed-off-by: Kevin-Yang <ykcha9@gmail.com>
Co-authored-by: Kevin-Yang <ykcha9@gmail.com>

commit 6650e6a930
parent 07e981fdf4
docs/source/models/supported_models.rst
@@ -361,6 +361,28 @@ Reward Modeling
 .. note::
     As an interim measure, these models are supported via Embeddings API. See `this RFC <https://github.com/vllm-project/vllm/issues/8967>`_ for upcoming changes.
 
+Classification
+---------------
+
+.. list-table::
+  :widths: 25 25 50 5 5
+  :header-rows: 1
+
+  * - Architecture
+    - Models
+    - Example HF Models
+    - :ref:`LoRA <lora>`
+    - :ref:`PP <distributed_serving>`
+  * - :code:`Qwen2ForSequenceClassification`
+    - Qwen2-based
+    - :code:`jason9693/Qwen2.5-1.5B-apeach`, etc.
+    -
+    - ✅︎
+
+.. note::
+    As an interim measure, these models are supported via Embeddings API. They will be supported via a Classification API in the future (no reference APIs exist now).
+
 Multimodal Language Models
 ^^^^^^^^^^^^^^^^^^^^^^^^^^
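For reference, a minimal offline-inference sketch of the interim Embeddings-API path the note describes, assuming a build that contains this commit. For this model, the returned "embedding" is the softmaxed class-probability vector rather than a text embedding:

```python
# Minimal sketch, assuming a vLLM build containing this commit.
from vllm import LLM

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach")

# encode() is the embeddings entry point; for this classification model the
# pooler returns softmaxed logits, so each "embedding" is a probability vector.
outputs = llm.encode(["vLLM is a high-throughput inference engine."])
for output in outputs:
    probs = output.outputs.embedding
    print(probs)
```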
tests/conftest.py
@@ -343,6 +343,17 @@ class HfRunner:
 
         return all_inputs
 
+    def classify(self, prompts: List[str]) -> List[List[float]]:
+        # output is final logits, softmaxed into per-class probabilities
+        all_inputs = self.get_inputs(prompts)
+        outputs = []
+        for inputs in all_inputs:
+            output = self.model(**self.wrap_device(inputs))
+            logits = output.logits.softmax(dim=-1)[0].tolist()
+            outputs.append(logits)
+
+        return outputs
+
     def generate(
         self,
         prompts: List[str],
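As a sanity check, the same computation can be reproduced with plain `transformers` outside the test harness. A sketch, with the model name taken from the test file below and device/dtype handling omitted:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "jason9693/Qwen2.5-1.5B-apeach"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

inputs = tokenizer("This is a test sentence.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits          # final logits, shape (1, num_labels)
probs = logits.softmax(dim=-1)[0].tolist()   # same transform as HfRunner.classify
print(probs)
```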
@@ -688,6 +699,14 @@ class VllmRunner:
 
         return inputs
 
+    def classify(self, prompts: List[str]) -> List[List[float]]:
+        req_outputs = self.model.encode(prompts)
+        outputs = []
+        for req_output in req_outputs:
+            embedding = req_output.outputs.embedding
+            outputs.append(embedding)
+        return outputs
+
     def generate(
         self,
         prompts: List[str],
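Note that `VllmRunner.classify` goes through the same `encode()` path as the sketch after the docs hunk: because the model's pooler applies softmax, each returned "embedding" is a probability vector. A quick sanity check, reusing the `llm` object from that earlier sketch:

```python
import math

outputs = llm.encode(["A quick sanity check."])
probs = outputs[0].outputs.embedding
# probabilities from a softmax should sum to ~1
assert math.isclose(sum(probs), 1.0, rel_tol=1e-3)
```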
tests/models/embedding/language/test_cls_models.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+"""Compare the classification outputs of HF and vLLM models.
+
+This test only covers small models. Big models such as 7B should be tested
+in test_big_models.py, which can use a larger instance to run them.
+
+Run `pytest tests/models/embedding/language/test_cls_models.py`.
+"""
+import pytest
+import torch
+from transformers import AutoModelForSequenceClassification
+
+CLASSIFICATION_MODELS = ["jason9693/Qwen2.5-1.5B-apeach"]
+
+
+@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+def test_classification_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+    with hf_runner(model,
+                   dtype=dtype,
+                   auto_cls=AutoModelForSequenceClassification) as hf_model:
+        hf_outputs = hf_model.classify(example_prompts)
+
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.classify(example_prompts)
+
+    print(hf_outputs, vllm_outputs)
+
+    # check logits difference
+    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+        hf_output = torch.tensor(hf_output)
+        vllm_output = torch.tensor(vllm_output)
+
+        # the positional third argument of allclose is rtol
+        assert torch.allclose(hf_output, vllm_output, 1e-3)
+
+
+@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+def test_classification_model_print(
+    vllm_runner,
+    model: str,
+    dtype: str,
+) -> None:
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        # This test verifies that the model's extra_repr
+        # can be printed correctly.
+        print(vllm_model.model.llm_engine.model_executor.driver_worker.
+              model_runner.model)
vllm/model_executor/layers/pooler.py
@@ -28,11 +28,15 @@ class Pooler(nn.Module):
         normalize: Whether to normalize the pooled data.
     """
 
-    def __init__(self, pooling_type: PoolingType, normalize: bool):
+    def __init__(self,
+                 pooling_type: PoolingType,
+                 normalize: bool,
+                 softmax: bool = False):
         super().__init__()
 
         self.pooling_type = pooling_type
         self.normalize = normalize
+        self.softmax = softmax
 
     def forward(
         self,
@@ -64,6 +68,9 @@ class Pooler(nn.Module):
         if self.normalize:
             pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
 
+        if self.softmax:
+            pooled_data = nn.functional.softmax(pooled_data, dim=-1)
+
         pooled_outputs = [
             EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
         ]
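A toy illustration of the new branch, assuming `pooled_data` is a `(num_seqs, num_labels)` tensor of final logits (a sketch, not vLLM code):

```python
import torch
from torch import nn

pooled_data = torch.tensor([[2.0, -1.0], [0.5, 0.5]])  # stand-in logits
probs = nn.functional.softmax(pooled_data, dim=-1)

# each row is now a probability distribution over the labels
assert torch.allclose(probs.sum(dim=-1), torch.ones(2))
print(probs)
```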
vllm/model_executor/models/qwen2_cls.py (new file, 107 lines)
@@ -0,0 +1,107 @@
+# coding=utf-8
+# Adapted from
+# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
+# Copyright 2024 Kakao Corp. (Kanana-X Team)
+# Copyright 2024 The Qwen team.
+# Copyright 2023 The vLLM team.
+"""Inference-only Qwen2-Classification model compatible with HF weights."""
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import Qwen2Config
+
+from vllm.attention import AttentionMetadata
+from vllm.config import CacheConfig, LoRAConfig
+from vllm.model_executor.layers.linear import RowParallelLinear
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.models.qwen2 import Qwen2Model
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.sequence import IntermediateTensors, PoolerOutput
+
+from .utils import AutoWeightsLoader
+
+
+class Qwen2ForSequenceClassification(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __init__(
+        self,
+        config: Qwen2Config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ) -> None:
+        # TODO (@robertgshaw2): see if this can be moved out
+        if (cache_config.sliding_window is not None
+                and hasattr(config, "max_window_layers")):
+            raise ValueError("Sliding window for some but not all layers is "
+                             "not supported. This model uses sliding window "
+                             "but `max_window_layers` = %s is less than "
+                             "`num_hidden_layers` = %s. Please open an issue "
+                             "to discuss this feature." % (
+                                 config.max_window_layers,
+                                 config.num_hidden_layers,
+                             ))
+
+        super().__init__()
+
+        self.config = config
+        self.lora_config = lora_config
+
+        self.quant_config = quant_config
+        self.model = Qwen2Model(config, cache_config, quant_config)
+
+        self.score = RowParallelLinear(config.hidden_size,
+                                       config.num_labels,
+                                       quant_config=quant_config)
+        self._pooler = Pooler(pooling_type=PoolingType.LAST,
+                              normalize=False,
+                              softmax=True)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, intermediate_tensors)
+        logits, _ = self.score(hidden_states)
+        return logits
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(self,
+                                   ignore_unexpected_prefixes=["lm_head."])
+        loader.load_weights(weights)
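Conceptually, the model projects every token's hidden state through the `score` head, and the `LAST` pooler keeps only the final token's logits before applying softmax. A self-contained sketch of that flow in plain PyTorch (not vLLM internals; shapes are illustrative):

```python
import torch
from torch import nn

hidden_size, num_labels = 8, 2
score = nn.Linear(hidden_size, num_labels, bias=False)  # stand-in for RowParallelLinear

hidden_states = torch.randn(5, hidden_size)  # 5 tokens of one sequence
logits = score(hidden_states)                # per-token logits, shape (5, num_labels)

last_logits = logits[-1]                     # PoolingType.LAST keeps the final token
probs = last_logits.softmax(dim=-1)          # softmax=True in the Pooler
print(probs)
```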
vllm/model_executor/models/registry.py
@@ -96,6 +96,8 @@ _EMBEDDING_MODELS = {
     "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
     "MistralModel": ("llama", "LlamaEmbeddingModel"),
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
+    "Qwen2ForSequenceClassification": (
+        "qwen2_cls", "Qwen2ForSequenceClassification"),
     # [Multimodal]
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
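The key in `_EMBEDDING_MODELS` must match the `architectures` entry in the model's HF config, which is how vLLM resolves a checkpoint to this implementation. A quick way to check that for the example model (a sketch using `transformers`):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("jason9693/Qwen2.5-1.5B-apeach")
print(config.architectures)  # expected: ["Qwen2ForSequenceClassification"]
```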