[Frontend] Separate pooling APIs in offline inference (#11129)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2024-12-13 18:40:07 +08:00 committed by GitHub
parent f93bf2b189
commit eeec9e3390
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 669 additions and 304 deletions

View File

@ -181,14 +181,14 @@ steps:
commands:
- VLLM_USE_V1=1 pytest -v -s v1
- label: Examples Test # 15min
- label: Examples Test # 25min
working_dir: "/vllm-workspace/examples"
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/entrypoints
- examples/
commands:
- pip install awscli tensorizer # for llava example and tensorizer test
- pip install tensorizer # for tensorizer test
- python3 offline_inference.py
- python3 cpu_offload.py
- python3 offline_inference_chat.py
@ -198,6 +198,9 @@ steps:
- python3 offline_inference_vision_language_multi_image.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference_encoder_decoder.py
- python3 offline_inference_classification.py
- python3 offline_inference_embedding.py
- python3 offline_inference_scoring.py
- python3 offline_profile.py --model facebook/opt-125m
- label: Prefix Caching Test # 9min

View File

@ -6,7 +6,7 @@ Pooling Models
vLLM also supports pooling models, including embedding, reranking and reward models.
In vLLM, pooling models implement the :class:`~vllm.model_executor.models.VllmModelForPooling` interface.
These models use a :class:`~vllm.model_executor.layers.Pooler` to aggregate the final hidden states of the input
These models use a :class:`~vllm.model_executor.layers.Pooler` to extract the final hidden states of the input
before returning them.
.. note::
@ -45,20 +45,48 @@ which takes priority over both the model's and Sentence Transformers's defaults.
^^^^^^^^^^^^^^
The :class:`~vllm.LLM.encode` method is available to all pooling models in vLLM.
It returns the aggregated hidden states directly.
It returns the extracted hidden states directly, which is useful for reward models.
.. code-block:: python
llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", task="reward")
prompt = "Hello, my name is"
output, = llm.encode(prompt)
data = output.outputs.data
print(f"Prompt: {prompt!r} | Data: {data!r}")
``LLM.embed``
^^^^^^^^^^^^^
The :class:`~vllm.LLM.embed` method outputs an embedding vector for each prompt.
It is primarily designed for embedding models.
.. code-block:: python
llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
outputs = llm.encode("Hello, my name is")
output, = llm.embed("Hello, my name is")
outputs = model.encode(prompts)
for output in outputs:
embeddings = output.outputs.embedding
print(f"Prompt: {prompt!r}, Embeddings (size={len(embeddings)}: {embeddings!r}")
embeds = output.outputs.embedding
print(f"Embeddings: {embeds!r} (size={len(embeds)})")
A code example can be found in `examples/offline_inference_embedding.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_embedding.py>`_.
``LLM.classify``
^^^^^^^^^^^^^^^^
The :class:`~vllm.LLM.classify` method outputs a probability vector for each prompt.
It is primarily designed for classification models.
.. code-block:: python
llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
output, = llm.classify("Hello, my name is")
probs = output.outputs.probs
print(f"Class Probabilities: {probs!r} (size={len(probs)})")
A code example can be found in `examples/offline_inference_classification.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_classification.py>`_.
``LLM.score``
^^^^^^^^^^^^^
@ -71,7 +99,16 @@ These types of models serve as rerankers between candidate query-document pairs
vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.
To handle RAG at a higher level, you should use integration frameworks such as `LangChain <https://github.com/langchain-ai/langchain>`_.
You can use `these tests <https://github.com/vllm-project/vllm/blob/main/tests/models/embedding/language/test_scoring.py>`_ as reference.
.. code-block:: python
llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
output, = llm.score("What is the capital of France?",
"The capital of Brazil is Brasilia.")
score = output.outputs.score
print(f"Score: {score}")
A code example can be found in `examples/offline_inference_scoring.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_scoring.py>`_.
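Under the hood, each of these task-specific methods is essentially a thin wrapper around :class:`~vllm.LLM.encode`: it pools the hidden states and then converts every ``PoolingRequestOutput`` with the corresponding ``from_base`` constructor. A minimal sketch of that relationship for the embedding case:
.. code-block:: python
from vllm import LLM
from vllm.outputs import EmbeddingRequestOutput

llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")

# LLM.embed() pools with LLM.encode() and then re-wraps each raw
# PoolingRequestOutput as an EmbeddingRequestOutput.
items = llm.encode("Hello, my name is")
outputs = [EmbeddingRequestOutput.from_base(item) for item in items]
print(len(outputs[0].outputs.embedding))  # hidden size of the model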
Online Inference
----------------

View File

@ -0,0 +1,28 @@
from vllm import LLM
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create an LLM.
# You should pass task="classify" for classification models
model = LLM(
model="jason9693/Qwen2.5-1.5B-apeach",
task="classify",
enforce_eager=True,
)
# Generate class probabilities. The output is a list of ClassificationRequestOutputs.
outputs = model.classify(prompts)
# Print the outputs.
for prompt, output in zip(prompts, outputs):
probs = output.outputs.probs
probs_trimmed = ((str(probs[:16])[:-1] +
", ...]") if len(probs) > 16 else probs)
print(f"Prompt: {prompt!r} | "
f"Class Probabilities: {probs_trimmed} (size={len(probs)})")

View File

@ -9,14 +9,20 @@ prompts = [
]
# Create an LLM.
# You should pass task="embed" for embedding models
model = LLM(
model="intfloat/e5-mistral-7b-instruct",
task="embed", # You should pass task="embed" for embedding models
task="embed",
enforce_eager=True,
)
# Generate embedding. The output is a list of PoolingRequestOutputs.
outputs = model.encode(prompts)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = model.embed(prompts)
# Print the outputs.
for output in outputs:
print(output.outputs.embedding) # list of 4096 floats
for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding
embeds_trimmed = ((str(embeds[:16])[:-1] +
", ...]") if len(embeds) > 16 else embeds)
print(f"Prompt: {prompt!r} | "
f"Embeddings: {embeds_trimmed} (size={len(embeds)})")

View File

@ -0,0 +1,23 @@
from vllm import LLM
# Sample prompts.
text_1 = "What is the capital of France?"
texts_2 = [
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]
# Create an LLM.
# You should pass task="score" for cross-encoder models
model = LLM(
model="BAAI/bge-reranker-v2-m3",
task="score",
enforce_eager=True,
)
# Generate scores. The output is a list of ScoringRequestOutputs.
outputs = model.score(text_1, texts_2)
# Print the outputs.
for text_2, output in zip(texts_2, outputs):
score = output.outputs.score
print(f"Pair: {[text_1, text_2]!r} | Score: {score}")

View File

@ -133,7 +133,7 @@ def run_encode(model: str, modality: QueryModality):
if req_data.image is not None:
mm_data["image"] = req_data.image
outputs = req_data.llm.encode({
outputs = req_data.llm.embed({
"prompt": req_data.prompt,
"multi_modal_data": mm_data,
})

View File

@ -719,14 +719,6 @@ class VllmRunner:
return inputs
def classify(self, prompts: List[str]) -> List[str]:
req_outputs = self.model.encode(prompts)
outputs = []
for req_output in req_outputs:
embedding = req_output.outputs.embedding
outputs.append(embedding)
return outputs
def generate(
self,
prompts: List[str],
@ -897,6 +889,10 @@ class VllmRunner:
returned_outputs.append((token_ids, texts))
return returned_outputs
def classify(self, prompts: List[str]) -> List[List[float]]:
req_outputs = self.model.classify(prompts)
return [req_output.outputs.probs for req_output in req_outputs]
def encode(
self,
prompts: List[str],
@ -909,16 +905,16 @@ class VllmRunner:
videos=videos,
audios=audios)
req_outputs = self.model.encode(inputs)
req_outputs = self.model.embed(inputs)
return [req_output.outputs.embedding for req_output in req_outputs]
def score(
self,
text_1: Union[str, List[str]],
text_2: Union[str, List[str]],
) -> List[List[float]]:
) -> List[float]:
req_outputs = self.model.score(text_1, text_2)
return [req_output.outputs.embedding for req_output in req_outputs]
return [req_output.outputs.score for req_output in req_outputs]
def __enter__(self):
return self

View File

@ -39,8 +39,8 @@ async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
assert score.id is not None
assert score.data is not None
assert len(score.data) == 2
assert score.data[0].score[0] <= 0.01
assert score.data[1].score[0] >= 0.9
assert score.data[0].score <= 0.01
assert score.data[1].score >= 0.9
@pytest.mark.asyncio
@ -67,8 +67,8 @@ async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
assert score.id is not None
assert score.data is not None
assert len(score.data) == 2
assert score.data[0].score[0] <= 0.01
assert score.data[1].score[0] >= 0.9
assert score.data[0].score <= 0.01
assert score.data[1].score >= 0.9
@pytest.mark.asyncio
@ -90,4 +90,4 @@ async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
assert score.id is not None
assert score.data is not None
assert len(score.data) == 1
assert score.data[0].score[0] >= 0.9
assert score.data[0].score >= 0.9

View File

@ -42,7 +42,7 @@ def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str):
assert len(vllm_outputs) == 1
assert len(hf_outputs) == 1
assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
@pytest.mark.parametrize("dtype", ["half"])
@ -63,8 +63,8 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
assert len(vllm_outputs) == 2
assert len(hf_outputs) == 2
assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01)
assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)
@pytest.mark.parametrize("dtype", ["half"])
@ -85,5 +85,5 @@ def test_llm_N_to_N(vllm_runner, hf_runner, model_name, dtype: str):
assert len(vllm_outputs) == 2
assert len(hf_outputs) == 2
assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01)
assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)

View File

@ -2,7 +2,7 @@ import os
import pytest
from vllm import LLM, PoolingParams, SamplingParams
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from ..utils import fork_new_process_for_each_test
@ -36,9 +36,8 @@ def test_oot_registration_text_generation(dummy_opt_path):
def test_oot_registration_embedding(dummy_gemma2_embedding_path):
os.environ["VLLM_PLUGINS"] = "register_dummy_model"
prompts = ["Hello, my name is", "The text does not matter"]
sampling_params = PoolingParams()
llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
outputs = llm.encode(prompts, sampling_params)
outputs = llm.embed(prompts)
for output in outputs:
assert all(v == 0 for v in output.outputs.embedding)

View File

@ -7,8 +7,11 @@ from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput)
from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput,
CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
@ -27,6 +30,12 @@ __all__ = [
"CompletionOutput",
"PoolingOutput",
"PoolingRequestOutput",
"EmbeddingOutput",
"EmbeddingRequestOutput",
"ClassificationOutput",
"ClassificationRequestOutput",
"ScoringOutput",
"ScoringRequestOutput",
"LLMEngine",
"EngineArgs",
"AsyncLLMEngine",
@ -34,26 +43,3 @@ __all__ = [
"initialize_ray_cluster",
"PoolingParams",
]
def __getattr__(name: str):
import warnings
if name == "EmbeddingOutput":
msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return PoolingOutput
if name == "EmbeddingRequestOutput":
msg = ("EmbeddingRequestOutput has been renamed to "
"PoolingRequestOutput. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return PoolingRequestOutput
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -46,11 +46,10 @@ from vllm.outputs import (PoolingRequestOutput, RequestOutput,
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
ParallelSampleSequenceGroup, Sequence,
SequenceGroup, SequenceGroupBase,
SequenceGroupMetadata, SequenceGroupOutput,
SequenceStatus)
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
PoolingSequenceGroupOutput, Sequence, SequenceGroup,
SequenceGroupBase, SequenceGroupMetadata,
SequenceGroupOutput, SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
@ -966,9 +965,9 @@ class LLMEngine:
@staticmethod
def _process_sequence_group_outputs(
seq_group: SequenceGroup,
outputs: List[EmbeddingSequenceGroupOutput],
outputs: List[PoolingSequenceGroupOutput],
) -> None:
seq_group.embeddings = outputs[0].embeddings
seq_group.pooled_data = outputs[0].data
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_STOPPED
@ -1784,8 +1783,8 @@ class LLMEngine:
num_prompt_tokens_iter)
# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
if model_output and (model_output[0].spec_decode_worker_metrics
is not None):
if model_output and isinstance(model_output[0], SamplerOutput) and (
model_output[0].spec_decode_worker_metrics is not None):
spec_decode_metrics = model_output[0].spec_decode_worker_metrics
else:
spec_decode_metrics = None

View File

@ -26,7 +26,9 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions)
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
@ -120,7 +122,7 @@ class LLM:
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
"""
DEPRECATE_LEGACY: ClassVar[bool] = False
DEPRECATE_LEGACY: ClassVar[bool] = True
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
@ -257,11 +259,14 @@ class LLM:
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
*,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
...
@ -275,6 +280,9 @@ class LLM:
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
...
@ -288,6 +296,9 @@ class LLM:
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
...
@ -302,6 +313,9 @@ class LLM:
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
...
@ -316,6 +330,9 @@ class LLM:
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
...
@ -328,6 +345,9 @@ class LLM:
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
...
@ -678,11 +698,12 @@ class LLM:
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
...
@ -696,6 +717,7 @@ class LLM:
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
...
@ -709,6 +731,7 @@ class LLM:
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
...
@ -723,6 +746,7 @@ class LLM:
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
...
@ -737,6 +761,7 @@ class LLM:
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
...
@ -749,6 +774,7 @@ class LLM:
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
...
@ -768,7 +794,8 @@ class LLM:
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
"""Generates the completions for the input prompts.
"""Apply pooling to the hidden states corresponding to the input
prompts.
This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
@ -787,7 +814,7 @@ class LLM:
Returns:
A list of ``PoolingRequestOutput`` objects containing the
generated embeddings in the same order as the input prompts.
pooled hidden states in the same order as the input prompts.
Note:
Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
@ -833,28 +860,110 @@ class LLM:
return self.engine_class.validate_outputs(outputs,
PoolingRequestOutput)
def embed(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]:
"""
Generate an embedding vector for each prompt.
This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of ``EmbeddingRequestOutput`` objects containing the
embedding vectors in the same order as the input prompts.
"""
if self.llm_engine.model_config.task != "embed":
raise ValueError(
"Embedding API is only enabled for `--task embed`")
items = self.encode(prompts,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
return [EmbeddingRequestOutput.from_base(item) for item in items]
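A minimal usage sketch of the new ``embed`` API (model name borrowed from the examples added in this PR); loading the model with any other ``task`` makes this call raise the ``ValueError`` above:

from vllm import LLM

llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
output, = llm.embed("Hello, my name is")
embeds = output.outputs.embedding  # plain list of floats, one per hidden dim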
def classify(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[ClassificationRequestOutput]:
"""
Generate class probabilities for each prompt.
This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of ``ClassificationRequestOutput`` objects containing the
probability vectors in the same order as the input prompts.
"""
if self.llm_engine.model_config.task != "classify":
raise ValueError(
"Classification API is only enabled for `--task classify`")
items = self.encode(prompts,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
return [ClassificationRequestOutput.from_base(item) for item in items]
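And the analogous sketch for ``classify``, assuming the default classification pooler applies a softmax so the returned probabilities sum to one:

from vllm import LLM

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
output, = llm.classify("Hello, my name is")
probs = output.outputs.probs            # one float per class
assert abs(sum(probs) - 1.0) < 1e-3     # softmax output, sums to ~1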
def score(
self,
text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
/,
*,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
"""Generates similarity scores for all pairs <text,text_pair>.
) -> List[ScoringRequestOutput]:
"""Generate similarity scores for all pairs ``<text,text_pair>``.
The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
the text_1 sentence will be replicated N times to pair with the text_2
sentences. The input pairs are used to build a list of prompts for the
The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
In the ``1 -> N`` case the ``text_1`` sentence will be replicated ``N``
times to pair with the ``text_2`` sentences.
The input pairs are used to build a list of prompts for the
cross encoder model. This class automatically batches the prompts,
considering the memory constraint. For the best performance, put all
of your texts into a single list and pass it to this method.
Args:
text_1: can be a single prompt or a list of prompts, in which
case it has to have the same length as the text_2 list
case it has to have the same length as the ``text_2`` list
text_2: The texts to pair with the query to form the input
to the LLM. See :class:`~vllm.inputs.PromptType` for
more details about the format of each prompts.
@ -864,7 +973,7 @@ class LLM:
generation, if any.
Returns:
A list of ``PoolingRequestOutput`` objects containing the
A list of ``ScoringRequestOutput`` objects containing the
generated scores in the same order as the input prompts.
"""
runner_type = self.llm_engine.model_config.runner_type
@ -884,6 +993,8 @@ class LLM:
if not self.llm_engine.model_config.is_cross_encoder:
raise ValueError("Your model does not support cross encoding")
if self.llm_engine.model_config.task != "score":
raise ValueError("Score API is only enabled for `--task score`")
tokenizer = self.llm_engine.get_tokenizer()
@ -954,8 +1065,10 @@ class LLM:
)
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs,
PoolingRequestOutput)
items = self.engine_class.validate_outputs(outputs,
PoolingRequestOutput)
return [ScoringRequestOutput.from_base(item) for item in items]
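For the ``1 -> N`` case described in the docstring, a small sketch showing that the single ``text_1`` is paired with every entry of ``text_2`` and one ``ScoringRequestOutput`` comes back per pair (model name reused from the examples in this PR):

from vllm import LLM

llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
outputs = llm.score("What is the capital of France?",
                    ["The capital of Brazil is Brasilia.",
                     "The capital of France is Paris."])
assert len(outputs) == 2                       # one result per pair
scores = [o.outputs.score for o in outputs]    # plain Python floats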
def start_profile(self) -> None:
self.llm_engine.start_profile()

View File

@ -900,7 +900,7 @@ class EmbeddingResponse(OpenAIBaseModel):
class ScoreResponseData(OpenAIBaseModel):
index: int
object: str = "score"
score: Union[List[float], str]
score: float
class ScoreResponse(OpenAIBaseModel):

View File

@ -18,14 +18,15 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
PoolingRequestOutput)
from vllm.utils import merge_async_iterators
logger = init_logger(__name__)
def _get_embedding(
output: PoolingOutput,
output: EmbeddingOutput,
encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
if encoding_format == "float":
@ -46,8 +47,10 @@ def request_output_to_embedding_response(
data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
embedding_res = EmbeddingRequestOutput.from_base(final_res)
prompt_token_ids = final_res.prompt_token_ids
embedding = _get_embedding(final_res.outputs, encoding_format)
embedding = _get_embedding(embedding_res.outputs, encoding_format)
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
data.append(embedding_data)

View File

@ -31,7 +31,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ErrorResponse,
LoadLoraAdapterRequest,
ModelCard, ModelList,
ModelPermission,
ModelPermission, ScoreRequest,
TokenizeChatRequest,
TokenizeCompletionRequest,
UnloadLoraAdapterRequest)
@ -73,7 +73,7 @@ class LoRAModulePath:
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
EmbeddingCompletionRequest,
EmbeddingCompletionRequest, ScoreRequest,
TokenizeCompletionRequest]
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
@ -567,12 +567,14 @@ class OpenAIServing:
return None
@staticmethod
def _base_request_id(raw_request: Request,
def _base_request_id(raw_request: Optional[Request],
default: Optional[str] = None) -> Optional[str]:
"""Pulls the request id to use from a header, if provided"""
default = default or random_uuid()
return raw_request.headers.get(
"X-Request-Id", default) if raw_request is not None else default
if raw_request is None:
return default
return raw_request.headers.get("X-Request-Id", default)
@staticmethod
def _get_decoded_token(logprob: Logprob,

View File

@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import make_async, merge_async_iterators
@ -24,13 +24,13 @@ def request_output_to_score_response(
final_res_batch: List[PoolingRequestOutput], request_id: str,
created_time: int, model_name: str) -> ScoreResponse:
data: List[ScoreResponseData] = []
score = None
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
if final_res is not None:
score = final_res.outputs.embedding
score_data = ScoreResponseData(index=idx, score=score)
data.append(score_data)
classify_res = ScoringRequestOutput.from_base(final_res)
score_data = ScoreResponseData(index=idx,
score=classify_res.outputs.score)
data.append(score_data)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,

View File

@ -1,14 +1,16 @@
from enum import IntEnum
from typing import List, Optional
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from typing_extensions import assert_never
from vllm.config import PoolerConfig
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors)
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.transformers_utils.config import (
get_cross_encoder_activation_function)
@ -22,7 +24,7 @@ class PoolingType(IntEnum):
MEAN = 4
class Pooler(nn.Module):
class SimplePooler(nn.Module):
"""A layer that pools specific information from hidden states.
This layer does the following:
@ -35,22 +37,204 @@ class Pooler(nn.Module):
normalize: Whether to normalize the pooled data.
"""
@staticmethod
def from_pooling_type(
pooling_type: PoolingType,
*,
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
) -> "SimplePooler":
if pooling_type == PoolingType.LAST:
assert step_tag_id is None and returned_token_ids is None
return LastPool(normalize=normalize, softmax=softmax)
if pooling_type == PoolingType.ALL:
assert step_tag_id is None and returned_token_ids is None
return AllPool(normalize=normalize, softmax=softmax)
if pooling_type == PoolingType.CLS:
assert step_tag_id is None and returned_token_ids is None
return CLSPool(normalize=normalize, softmax=softmax)
if pooling_type == PoolingType.MEAN:
assert step_tag_id is None and returned_token_ids is None
return MeanPool(normalize=normalize, softmax=softmax)
if pooling_type == PoolingType.STEP:
return StepPool(normalize=normalize,
softmax=softmax,
step_tag_id=step_tag_id,
returned_token_ids=returned_token_ids)
assert_never(pooling_type)
def __init__(self, *, normalize: bool, softmax: bool) -> None:
super().__init__()
self.head = PoolerHead(normalize=normalize, softmax=softmax)
def get_prompt_lens(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
def extract_states(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
raise NotImplementedError
def build_output(self, data: torch.Tensor) -> PoolingSequenceGroupOutput:
return PoolingSequenceGroupOutput(data)
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data)
pooled_outputs = [self.build_output(data) for data in pooled_data]
return PoolerOutput(outputs=pooled_outputs)
class CLSPool(SimplePooler):
def extract_states(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
return hidden_states[first_token_flat_indices]
class LastPool(SimplePooler):
def extract_states(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
return hidden_states[last_token_flat_indices]
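Both poolers above pick one row per prompt out of the flattened ``hidden_states``: ``CLSPool`` uses the exclusive cumulative sum of ``prompt_lens`` (the first token of each prompt), ``LastPool`` the inclusive cumulative sum minus one. A standalone sketch of the index arithmetic for two prompts of lengths 3 and 2:

import torch

hidden_states = torch.arange(5).unsqueeze(1).float()  # 5 tokens, hidden size 1
prompt_lens = torch.tensor([3, 2])

first = torch.zeros_like(prompt_lens)
first[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]     # tensor([0, 3])
last = torch.cumsum(prompt_lens, dim=0) - 1            # tensor([2, 4])

assert hidden_states[first].flatten().tolist() == [0.0, 3.0]
assert hidden_states[last].flatten().tolist() == [2.0, 4.0]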
class AllPool(SimplePooler):
def extract_states(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
offset = 0
pooled_data = list[torch.Tensor]()
for prompt_len in prompt_lens:
pooled_data.append(hidden_states[offset:offset + prompt_len])
offset += prompt_len
return pooled_data
class MeanPool(SimplePooler):
def extract_states(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
cumsum = torch.cumsum(hidden_states, dim=0)
start_indices = torch.cat([
torch.tensor([0], device=hidden_states.device),
torch.cumsum(prompt_lens[:-1], dim=0)
])
end_indices = torch.cumsum(prompt_lens, dim=0)
return (cumsum[end_indices - 1] - cumsum[start_indices] +
hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
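``MeanPool`` avoids a Python loop by taking a cumulative sum over the token axis and differencing it at the segment boundaries. A small sketch checking that trick against the naive per-prompt mean, assuming the same flattened layout:

import torch

hidden_states = torch.randn(5, 4)          # 5 tokens total, hidden size 4
prompt_lens = torch.tensor([3, 2])

cumsum = torch.cumsum(hidden_states, dim=0)
start = torch.cat([torch.tensor([0]), torch.cumsum(prompt_lens[:-1], dim=0)])
end = torch.cumsum(prompt_lens, dim=0)
means = (cumsum[end - 1] - cumsum[start] +
         hidden_states[start]) / prompt_lens.unsqueeze(1)

expected = torch.stack([hidden_states[:3].mean(dim=0),
                        hidden_states[3:].mean(dim=0)])
assert torch.allclose(means, expected, atol=1e-6)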
class StepPool(SimplePooler):
def __init__(
self,
pooling_type: PoolingType,
*,
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
):
super().__init__()
super().__init__(normalize=normalize, softmax=softmax)
self.pooling_type = pooling_type
self.normalize = normalize
self.softmax = softmax
self.step_tag_id = step_tag_id
self.returned_token_ids = returned_token_ids
def extract_states(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
returned_token_ids = self.returned_token_ids
if returned_token_ids is not None and len(returned_token_ids) > 0:
hidden_states = hidden_states[:, returned_token_ids]
step_tag_id = self.step_tag_id
offset = 0
pooled_data = list[torch.Tensor]()
for prompt_len, seq_data_i in zip(prompt_lens,
pooling_metadata.seq_data.values()):
pooled_data_i = hidden_states[offset:offset + prompt_len]
if step_tag_id is not None:
token_ids = torch.tensor(seq_data_i.prompt_token_ids)
pooled_data_i = pooled_data_i[token_ids == step_tag_id]
offset += prompt_len
pooled_data.append(pooled_data_i)
return pooled_data
class PoolerHead(nn.Module):
def __init__(self, *, normalize: bool, softmax: bool) -> None:
super().__init__()
self.normalize = normalize
self.softmax = softmax
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor]):
if self.normalize:
if isinstance(pooled_data, list):
pooled_data = [
F.normalize(data, p=2, dim=1) for data in pooled_data
]
else:
pooled_data = F.normalize(pooled_data, p=2, dim=1)
if self.softmax:
if isinstance(pooled_data, list):
pooled_data = [F.softmax(data, dim=-1) for data in pooled_data]
else:
pooled_data = F.softmax(pooled_data, dim=-1)
return pooled_data
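``PoolerHead`` is the shared post-processing step: L2 normalization for embedding-style outputs and softmax for classification-style outputs. A tiny sketch of what each flag does to a dummy batch:

import torch
import torch.nn.functional as F

pooled = torch.randn(2, 4)                # one row per prompt

embeds = F.normalize(pooled, p=2, dim=1)  # normalize=True, softmax=False
probs = F.softmax(pooled, dim=-1)         # normalize=False, softmax=True

assert torch.allclose(embeds.norm(dim=1), torch.ones(2))  # unit vectors
assert torch.allclose(probs.sum(dim=-1), torch.ones(2))   # rows sum to 1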
class Pooler(nn.Module):
@classmethod
def from_config_with_defaults(
cls,
@ -60,8 +244,8 @@ class Pooler(nn.Module):
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
) -> "Pooler":
return cls(
) -> SimplePooler:
return SimplePooler.from_pooling_type(
pooling_type=PoolingType[pooler_config.pooling_type]
if pooler_config.pooling_type is not None else pooling_type,
normalize=pooler_config.normalize
@ -75,85 +259,6 @@ class Pooler(nn.Module):
returned_token_ids,
)
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
"""Pools specific information from hidden states based on metadata."""
prompt_lens = PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
if self.pooling_type is PoolingType.CLS:
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens,
dim=0)[:-1]
pooled_data = hidden_states[first_token_flat_indices]
elif self.pooling_type == PoolingType.LAST:
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
pooled_data = hidden_states[last_token_flat_indices]
elif self.pooling_type == PoolingType.ALL:
offset = 0
pooled_data = []
for prompt_len in prompt_lens:
pooled_data.append(hidden_states[offset:offset + prompt_len])
offset += prompt_len
elif self.pooling_type == PoolingType.MEAN:
# Calculate mean pooling
cumsum = torch.cumsum(hidden_states, dim=0)
start_indices = torch.cat([
torch.tensor([0], device=hidden_states.device),
torch.cumsum(prompt_lens[:-1], dim=0)
])
end_indices = torch.cumsum(prompt_lens, dim=0)
pooled_data = (
cumsum[end_indices - 1] - cumsum[start_indices] +
hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
elif self.pooling_type == PoolingType.STEP:
returned_token_ids = self.returned_token_ids
if returned_token_ids is not None and len(returned_token_ids) > 0:
hidden_states = hidden_states[:, returned_token_ids]
step_tag_id = self.step_tag_id
offset = 0
pooled_data = []
for prompt_len, seq_data_i in zip(
prompt_lens, pooling_metadata.seq_data.values()):
pooled_data_i = hidden_states[offset:offset + prompt_len]
if step_tag_id is not None:
token_ids = torch.tensor(seq_data_i.prompt_token_ids)
pooled_data_i = pooled_data_i[token_ids == step_tag_id]
offset += prompt_len
pooled_data.append(pooled_data_i)
else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}")
if self.normalize:
if isinstance(pooled_data, list):
pooled_data = [
nn.functional.normalize(data, p=2, dim=1)
for data in pooled_data
]
else:
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
if self.softmax:
if isinstance(pooled_data, list):
pooled_data = [
nn.functional.softmax(data, dim=-1) for data in pooled_data
]
else:
pooled_data = nn.functional.softmax(pooled_data, dim=-1)
pooled_outputs = [
EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
]
return PoolerOutput(outputs=pooled_outputs)
class CrossEncodingPooler(nn.Module):
"""A layer that pools specific information from hidden states.
@ -208,9 +313,8 @@ class CrossEncodingPooler(nn.Module):
if self.pooler is not None:
# apply classifier once on the full batch if possible
pooled_output = self.classifier(pooled_output)
logits = self.default_activation_function(pooled_output)
pooled_outputs = [
EmbeddingSequenceGroupOutput(data.tolist()) for data in logits
]
scores = self.default_activation_function(pooled_output).squeeze(-1)
pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
return PoolerOutput(outputs=pooled_outputs)

View File

@ -2,19 +2,20 @@ from array import array
from typing import List, Optional, Union
import torch
from torch import nn
import torch.nn as nn
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
from vllm.attention import AttentionMetadata
from vllm.attention.backends.xformers import XFormersImpl
from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolerHead
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (EmbeddingSequenceGroupOutput, IntermediateTensors,
PoolerOutput)
from vllm.sequence import (IntermediateTensors, PoolerOutput,
PoolingSequenceGroupOutput)
logger = init_logger(__name__)
@ -52,6 +53,8 @@ class GritLMPooler(nn.Module):
self.embed_pattern_ids = tokens_to_ids(
["▁<", "|", "embed", "|", ">", "<0x0A>"])
self.head = PoolerHead(normalize=True, softmax=False)
def _find_array(self, arr: array, target: array, start_idx: int) -> int:
"""
Find the first occurrence of target in arr starting from start_idx.
@ -75,7 +78,7 @@ class GritLMPooler(nn.Module):
return i
return -1
def _get_instruction_len(self, prompt_token_ids: array) -> bool:
def _get_instruction_len(self, prompt_token_ids: array) -> int:
"""
Get the length of the instruction in the prompt.
@ -168,10 +171,10 @@ class GritLMPooler(nn.Module):
mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
1)
pooled_data = nn.functional.normalize(mean_embeddings, p=2, dim=1)
pooled_data = self.head(mean_embeddings)
pooled_outputs = [
EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
PoolingSequenceGroupOutput(data) for data in pooled_data
]
return PoolerOutput(outputs=pooled_outputs)

View File

@ -1,9 +1,13 @@
import time
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import Dict, Generic, List, Optional
from typing import Sequence as GenericSequence
from typing import Union
import torch
from typing_extensions import TypeVar
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalPlaceholderDict
from vllm.sampling_params import RequestOutputKind
@ -57,14 +61,26 @@ class PoolingOutput:
"""The output data of one pooling output of a request.
Args:
embedding: The embedding vector, which is a list of floats. The
length of vector depends on the model as listed in the embedding guide.
data: The extracted hidden states.
"""
embedding: List[float]
data: torch.Tensor
def __repr__(self) -> str:
return (f"PoolingOutput("
f"embedding={len(self.embedding)})")
return (f"PoolingOutput(data={self.data})")
def __eq__(self, other: object) -> bool:
return (isinstance(other, self.__class__) and bool(
(self.data == other.data).all()))
@property
def embedding(self) -> list[float]:
msg = ("`LLM.encode()` now returns raw outputs. "
"To return embeddings, use `LLM.embed()`. "
"To return class probabilities, use `LLM.classify()` "
"and access the `probs` attribute. ")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
return self.data.tolist()
class RequestOutput:
@ -316,7 +332,10 @@ class RequestOutput:
f"multi_modal_placeholders={self.multi_modal_placeholders})")
class PoolingRequestOutput:
_O = TypeVar("_O", default=PoolingOutput)
class PoolingRequestOutput(Generic[_O]):
"""
The output data of a pooling request to the LLM.
@ -327,24 +346,24 @@ class PoolingRequestOutput:
finished (bool): A flag indicating whether the pooling is completed.
"""
def __init__(self, request_id: str, outputs: "PoolingOutput",
def __init__(self, request_id: str, outputs: _O,
prompt_token_ids: List[int], finished: bool):
self.request_id = request_id
self.prompt_token_ids = prompt_token_ids
self.finished = finished
self.outputs = outputs
@classmethod
def from_seq_group(cls,
seq_group: 'SequenceGroup') -> "PoolingRequestOutput":
if seq_group.embeddings is None:
raise ValueError(
"Embeddings are missing in seq_group for EmbeddingRequest.")
output = PoolingOutput(seq_group.embeddings)
@staticmethod
def from_seq_group(seq_group: SequenceGroup) -> "PoolingRequestOutput":
pooled_data = seq_group.pooled_data
assert pooled_data is not None
output = PoolingOutput(pooled_data)
prompt_token_ids = seq_group.prompt_token_ids
finished = seq_group.is_finished()
return cls(seq_group.request_id, output, prompt_token_ids, finished)
return PoolingRequestOutput(seq_group.request_id, output,
prompt_token_ids, finished)
def __repr__(self):
"""
@ -356,89 +375,137 @@ class PoolingRequestOutput:
Returns:
str: A string representation of the PoolingRequestOutput instance.
"""
return (f"PoolingRequestOutput(request_id='{self.request_id}', "
f"outputs={repr(self.outputs)}, "
return (f"{type(self).__name__}(request_id={self.request_id!r}, "
f"outputs={self.outputs!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"finished={self.finished})")
@dataclass
class ScoreOutput:
"""The output data of one completion output of a request.
Args:
score: The score, which is a list of floats.
index: The correspondent text index of the score.
"""
index: int
score: List[float]
def __repr__(self) -> str:
return (f"ScoreOutput("
f"score={self.score}), "
f"index={self.index})")
class ScoreRequestOutput:
"""
The output data of an score request to the LLM.
Args:
request_id (str): A unique identifier for the score request.
outputs (score): The embedding results for the given input.
"""
def __init__(self, request_id: str, outputs: "ScoreOutput"):
self.request_id = request_id
self.outputs = outputs
def __repr__(self):
"""
Returns a string representation of an ScoreRequestOutput instance.
The representation includes the request_id and the number of outputs,
providing a quick overview of the embedding request's results.
Returns:
str: A string representation of the ScoreRequestOutput instance.
"""
return (f"ScoreRequestOutput(request_id='{self.request_id}', "
f"outputs={repr(self.outputs)}")
class RequestOutputFactory:
@staticmethod
def create(seq_group: SequenceGroup,
seq_id_to_seq_group: Dict[str, SequenceGroupBase],
use_cache: bool = False):
# Determine the type based on a condition, for example:
if hasattr(seq_group,
'embeddings') and seq_group.embeddings is not None:
if seq_group.pooled_data is not None:
return PoolingRequestOutput.from_seq_group(seq_group)
else:
return RequestOutput.from_seq_group(seq_group, use_cache,
seq_id_to_seq_group)
def __getattr__(name: str):
import warnings
@dataclass
class EmbeddingOutput:
"""The output data of one embedding output of a request.
if name == "EmbeddingOutput":
msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
"The original name will be removed in an upcoming version.")
Args:
embedding: The embedding vector, which is a list of floats.
Its length depends on the hidden dimension of the model.
"""
embedding: list[float]
warnings.warn(DeprecationWarning(msg), stacklevel=2)
@staticmethod
def from_base(pooling_output: PoolingOutput):
pooled_data = pooling_output.data
if pooled_data.ndim != 1:
raise ValueError("pooled_data should be a 1-D embedding vector")
return PoolingOutput
return EmbeddingOutput(pooled_data.tolist())
if name == "EmbeddingRequestOutput":
msg = ("EmbeddingRequestOutput has been renamed to "
"PoolingRequestOutput. "
"The original name will be removed in an upcoming version.")
@property
def hidden_size(self) -> int:
return len(self.embedding)
warnings.warn(DeprecationWarning(msg), stacklevel=2)
def __repr__(self) -> str:
return f"EmbeddingOutput(hidden_size={self.hidden_size})"
return PoolingRequestOutput
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]):
@staticmethod
def from_base(request_output: PoolingRequestOutput):
return EmbeddingRequestOutput(
request_id=request_output.request_id,
outputs=EmbeddingOutput.from_base(request_output.outputs),
prompt_token_ids=request_output.prompt_token_ids,
finished=request_output.finished,
)
@dataclass
class ClassificationOutput:
"""The output data of one classification output of a request.
Args:
probs: The probability vector, which is a list of floats.
Its length depends on the number of classes.
"""
probs: list[float]
@staticmethod
def from_base(pooling_output: PoolingOutput):
pooled_data = pooling_output.data
if pooled_data.ndim != 1:
raise ValueError("pooled_data should be a 1-D probability vector")
return ClassificationOutput(pooled_data.tolist())
@property
def num_classes(self) -> int:
return len(self.probs)
def __repr__(self) -> str:
return f"ClassificationOutput(num_classes={self.num_classes})"
class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]):
@staticmethod
def from_base(request_output: PoolingRequestOutput):
return ClassificationRequestOutput(
request_id=request_output.request_id,
outputs=ClassificationOutput.from_base(request_output.outputs),
prompt_token_ids=request_output.prompt_token_ids,
finished=request_output.finished,
)
@dataclass
class ScoringOutput:
"""The output data of one scoring output of a request.
Args:
score: The similarity score, which is a scalar value.
"""
score: float
@staticmethod
def from_base(pooling_output: PoolingOutput):
pooled_data = pooling_output.data
if pooled_data.ndim != 0:
raise ValueError("pooled_data should be a scalar score")
return ScoringOutput(pooled_data.item())
def __repr__(self) -> str:
return f"ScoringOutput(score={self.score})"
@property
def embedding(self) -> list[float]:
msg = ("`LLM.score()` now returns scalar scores. "
"Please access it via the `score` attribute. ")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
return [self.score]
class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):
@staticmethod
def from_base(request_output: PoolingRequestOutput):
return ScoringRequestOutput(
request_id=request_output.request_id,
outputs=ScoringOutput.from_base(request_output.outputs),
prompt_token_ids=request_output.prompt_token_ids,
finished=request_output.finished,
)
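All three wrappers follow the same pattern: validate the shape of ``PoolingOutput.data`` and convert the tensor into plain Python values. A compact sketch exercising the conversion chain on hand-built outputs:

import torch

from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingOutput, PoolingRequestOutput,
                          ScoringRequestOutput)

vec = PoolingRequestOutput("0", PoolingOutput(torch.randn(4)), [1, 2, 3], True)
embed = EmbeddingRequestOutput.from_base(vec)       # requires a 1-D tensor
probs = ClassificationRequestOutput.from_base(vec)  # also 1-D
print(embed.outputs.hidden_size, probs.outputs.num_classes)  # 4 4

scalar = PoolingRequestOutput("1", PoolingOutput(torch.tensor(0.9)), [1], True)
print(ScoringRequestOutput.from_base(scalar).outputs.score)  # 0-D tensor -> float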

View File

@ -617,10 +617,9 @@ class SequenceGroup:
sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request.
lora_request: LoRA request.
embeddings: The embeddings vectors of the prompt of the sequence group
for a pooling model.
pooling_params: The pooling parameters used to generate the pooling
pooling_params: The parameters used to generate the pooler
for a pooling model.
pooled_data: The extracted hidden states from a pooling model.
encoder_seq: Optional, the single encoder sequence. Should be None
unless you are working with an encoder/decoder model.
trace_headers: OpenTelemetry trace headers.
@ -635,8 +634,8 @@ class SequenceGroup:
arrival_time: float,
sampling_params: Optional[SamplingParams] = None,
lora_request: Optional[LoRARequest] = None,
embeddings: Optional[List[float]] = None,
pooling_params: Optional[PoolingParams] = None,
pooled_data: Optional[torch.Tensor] = None,
encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -658,8 +657,8 @@ class SequenceGroup:
self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()
self.embeddings = embeddings
self.pooling_params = pooling_params
self.pooled_data = pooled_data
self.prompt_adapter_request = prompt_adapter_request
self.encoder_seq = encoder_seq
self.trace_headers = trace_headers
@ -1033,8 +1032,8 @@ class CompletionSequenceGroupOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
__metaclass__ = SequenceGroupOutput
"""The model output associated with a completion sequence group."""
__metaclass__ = SequenceGroupOutput
samples: List[SequenceOutput]
# Prompt logprob for each prompt query token.
prompt_logprobs: Optional[PromptLogprobs]
@ -1050,23 +1049,24 @@ class CompletionSequenceGroupOutput(
and self.prompt_logprobs == other.prompt_logprobs)
class EmbeddingSequenceGroupOutput(
class PoolingSequenceGroupOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True, # type: ignore[call-arg]
):
"""The model output associated with an embedding sequence group."""
"""The model output associated with a pooling sequence group."""
__metaclass__ = SequenceGroupOutput
embeddings: List[int]
# Annotated as Any to be compatible with msgspec
# The actual type is in SequenceGroup.pooled_data
data: Any
def __repr__(self) -> str:
return (f"EmbeddingSequenceGroupOutput("
f"embeddings_shape={len(self.embeddings)})")
return f"PoolingSequenceGroupOutput(data={self.data}"
def __eq__(self, other: object) -> bool:
if not isinstance(other, EmbeddingSequenceGroupOutput):
if not isinstance(other, PoolingSequenceGroupOutput):
raise NotImplementedError()
return self.embeddings == other.embeddings
return self.data == other.data
# cannot use msgspec.Struct here because Dynamo does not support it
@ -1085,7 +1085,7 @@ class IntermediateTensors:
elif isinstance(key, slice):
return self.__class__({k: v[key] for k, v in self.tensors.items()})
def __setitem__(self, key: str, value):
def __setitem__(self, key: str, value: torch.Tensor):
self.tensors[key] = value
def __len__(self):
@ -1103,16 +1103,12 @@ class PoolerOutput(
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""The output from a pooling operation in the pooling model."""
outputs: List[EmbeddingSequenceGroupOutput]
outputs: List[PoolingSequenceGroupOutput]
# lazy import to avoid circular import
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
def __getitem__(self, idx: int) -> EmbeddingSequenceGroupOutput:
def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
return self.outputs[idx]
def __setitem__(self, idx: int, value):
def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
self.outputs[idx] = value
def __len__(self):
@ -1385,8 +1381,8 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
arrival_time=seq_group.arrival_time,
sampling_params=original_params,
lora_request=seq_group.lora_request,
embeddings=seq_group.embeddings,
pooling_params=seq_group.pooling_params,
pooled_data=seq_group.pooled_data,
encoder_seq=seq_group.encoder_seq,
trace_headers=seq_group.trace_headers,
prompt_adapter_request=seq_group.prompt_adapter_request,