[Frontend] Separate pooling APIs in offline inference (#11129)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent f93bf2b189
commit eeec9e3390
@@ -181,14 +181,14 @@ steps:
commands:
- VLLM_USE_V1=1 pytest -v -s v1

- label: Examples Test # 15min
- label: Examples Test # 25min
working_dir: "/vllm-workspace/examples"
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/entrypoints
- examples/
commands:
- pip install awscli tensorizer # for llava example and tensorizer test
- pip install tensorizer # for tensorizer test
- python3 offline_inference.py
- python3 cpu_offload.py
- python3 offline_inference_chat.py
@@ -198,6 +198,9 @@ steps:
- python3 offline_inference_vision_language_multi_image.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference_encoder_decoder.py
- python3 offline_inference_classification.py
- python3 offline_inference_embedding.py
- python3 offline_inference_scoring.py
- python3 offline_profile.py --model facebook/opt-125m

- label: Prefix Caching Test # 9min
@@ -6,7 +6,7 @@ Pooling Models
vLLM also supports pooling models, including embedding, reranking and reward models.

In vLLM, pooling models implement the :class:`~vllm.model_executor.models.VllmModelForPooling` interface.
These models use a :class:`~vllm.model_executor.layers.Pooler` to aggregate the final hidden states of the input
These models use a :class:`~vllm.model_executor.layers.Pooler` to extract the final hidden states of the input
before returning them.
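
For intuition, the sketch below (an illustration only, not the actual vLLM implementation) shows what a simple last-token pooler does with the flattened hidden states of a batch of prompts:

.. code-block:: python

    import torch

    # Hidden states for two prompts of lengths 3 and 4, flattened into a
    # single (total_tokens, hidden_size) tensor, as a pooling model sees them.
    hidden_states = torch.randn(7, 16)
    prompt_lens = torch.tensor([3, 4])

    # Last-token pooling keeps the final hidden state of each prompt.
    last_token_indices = torch.cumsum(prompt_lens, dim=0) - 1
    pooled = hidden_states[last_token_indices]  # shape: (2, 16)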

.. note::
@@ -45,20 +45,48 @@ which takes priority over both the model's and Sentence Transformers's defaults.
^^^^^^^^^^^^^^

The :class:`~vllm.LLM.encode` method is available to all pooling models in vLLM.
It returns the aggregated hidden states directly.
It returns the extracted hidden states directly, which is useful for reward models.

.. code-block:: python

llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", task="reward")
output, = llm.encode("Hello, my name is")

data = output.outputs.data
print(f"Prompt: {prompt!r} | Data: {data!r}")
|
||||
|
||||

``LLM.embed``
^^^^^^^^^^^^^

The :class:`~vllm.LLM.embed` method outputs an embedding vector for each prompt.
It is primarily designed for embedding models.

.. code-block:: python

llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
outputs = llm.encode("Hello, my name is")
output, = llm.embed("Hello, my name is")

outputs = model.encode(prompts)
for output in outputs:
embeddings = output.outputs.embedding
print(f"Prompt: {prompt!r}, Embeddings (size={len(embeddings)}: {embeddings!r}")
embeds = output.outputs.embedding
print(f"Embeddings: {embeds!r} (size={len(embeds)})")

A code example can be found in `examples/offline_inference_embedding.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_embedding.py>`_.

``LLM.classify``
^^^^^^^^^^^^^^^^

The :class:`~vllm.LLM.classify` method outputs a probability vector for each prompt.
It is primarily designed for classification models.

.. code-block:: python

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
output, = llm.classify("Hello, my name is")

probs = output.outputs.probs
print(f"Class Probabilities: {probs!r} (size={len(probs)})")

A code example can be found in `examples/offline_inference_classification.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_classification.py>`_.

``LLM.score``
^^^^^^^^^^^^^

@@ -71,7 +99,16 @@ These types of models serve as rerankers between candidate query-document pairs
vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.
To handle RAG at a higher level, you should use integration frameworks such as `LangChain <https://github.com/langchain-ai/langchain>`_.

You can use `these tests <https://github.com/vllm-project/vllm/blob/main/tests/models/embedding/language/test_scoring.py>`_ as reference.
.. code-block:: python

llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
output, = llm.score("What is the capital of France?",
"The capital of Brazil is Brasilia.")

score = output.outputs.score
print(f"Score: {score}")

A code example can be found in `examples/offline_inference_scoring.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_scoring.py>`_.

Online Inference
----------------
28
examples/offline_inference_classification.py
Normal file
@@ -0,0 +1,28 @@
from vllm import LLM

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create an LLM.
# You should pass task="classify" for classification models
model = LLM(
    model="jason9693/Qwen2.5-1.5B-apeach",
    task="classify",
    enforce_eager=True,
)

# Generate logits. The output is a list of ClassificationRequestOutputs.
outputs = model.classify(prompts)

# Print the outputs.
for prompt, output in zip(prompts, outputs):
    probs = output.outputs.probs
    probs_trimmed = ((str(probs[:16])[:-1] +
                      ", ...]") if len(probs) > 16 else probs)
    print(f"Prompt: {prompt!r} | "
          f"Class Probabilities: {probs_trimmed} (size={len(probs)})")
@@ -9,14 +9,20 @@ prompts = [
]

# Create an LLM.
# You should pass task="embed" for embedding models
model = LLM(
    model="intfloat/e5-mistral-7b-instruct",
    task="embed", # You should pass task="embed" for embedding models
    task="embed",
    enforce_eager=True,
)

# Generate embedding. The output is a list of PoolingRequestOutputs.
outputs = model.encode(prompts)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = model.embed(prompts)

# Print the outputs.
for output in outputs:
    print(output.outputs.embedding) # list of 4096 floats
for prompt, output in zip(prompts, outputs):
    embeds = output.outputs.embedding
    embeds_trimmed = ((str(embeds[:16])[:-1] +
                       ", ...]") if len(embeds) > 16 else embeds)
    print(f"Prompt: {prompt!r} | "
          f"Embeddings: {embeds_trimmed} (size={len(embeds)})")
23
examples/offline_inference_scoring.py
Normal file
@@ -0,0 +1,23 @@
from vllm import LLM

# Sample prompts.
text_1 = "What is the capital of France?"
texts_2 = [
    "The capital of Brazil is Brasilia.", "The capital of France is Paris."
]

# Create an LLM.
# You should pass task="score" for cross-encoder models
model = LLM(
    model="BAAI/bge-reranker-v2-m3",
    task="score",
    enforce_eager=True,
)

# Generate scores. The output is a list of ScoringRequestOutputs.
outputs = model.score(text_1, texts_2)

# Print the outputs.
for text_2, output in zip(texts_2, outputs):
    score = output.outputs.score
    print(f"Pair: {[text_1, text_2]!r} | Score: {score}")
@@ -133,7 +133,7 @@ def run_encode(model: str, modality: QueryModality):
if req_data.image is not None:
mm_data["image"] = req_data.image

outputs = req_data.llm.encode({
outputs = req_data.llm.embed({
"prompt": req_data.prompt,
"multi_modal_data": mm_data,
})
@@ -719,14 +719,6 @@ class VllmRunner:

return inputs

def classify(self, prompts: List[str]) -> List[str]:
req_outputs = self.model.encode(prompts)
outputs = []
for req_output in req_outputs:
embedding = req_output.outputs.embedding
outputs.append(embedding)
return outputs

def generate(
self,
prompts: List[str],
@@ -897,6 +889,10 @@ class VllmRunner:
returned_outputs.append((token_ids, texts))
return returned_outputs

def classify(self, prompts: List[str]) -> List[List[float]]:
req_outputs = self.model.classify(prompts)
return [req_output.outputs.probs for req_output in req_outputs]

def encode(
self,
prompts: List[str],
@@ -909,16 +905,16 @@ class VllmRunner:
videos=videos,
audios=audios)

req_outputs = self.model.encode(inputs)
req_outputs = self.model.embed(inputs)
return [req_output.outputs.embedding for req_output in req_outputs]

def score(
self,
text_1: Union[str, List[str]],
text_2: Union[str, List[str]],
) -> List[List[float]]:
) -> List[float]:
req_outputs = self.model.score(text_1, text_2)
return [req_output.outputs.embedding for req_output in req_outputs]
return [req_output.outputs.score for req_output in req_outputs]

def __enter__(self):
return self
@ -39,8 +39,8 @@ async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
|
||||
assert score.id is not None
|
||||
assert score.data is not None
|
||||
assert len(score.data) == 2
|
||||
assert score.data[0].score[0] <= 0.01
|
||||
assert score.data[1].score[0] >= 0.9
|
||||
assert score.data[0].score <= 0.01
|
||||
assert score.data[1].score >= 0.9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -67,8 +67,8 @@ async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
|
||||
assert score.id is not None
|
||||
assert score.data is not None
|
||||
assert len(score.data) == 2
|
||||
assert score.data[0].score[0] <= 0.01
|
||||
assert score.data[1].score[0] >= 0.9
|
||||
assert score.data[0].score <= 0.01
|
||||
assert score.data[1].score >= 0.9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -90,4 +90,4 @@ async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
|
||||
assert score.id is not None
|
||||
assert score.data is not None
|
||||
assert len(score.data) == 1
|
||||
assert score.data[0].score[0] >= 0.9
|
||||
assert score.data[0].score >= 0.9
|
||||
|
@ -42,7 +42,7 @@ def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str):
|
||||
assert len(vllm_outputs) == 1
|
||||
assert len(hf_outputs) == 1
|
||||
|
||||
assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
|
||||
assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@ -63,8 +63,8 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
|
||||
assert len(vllm_outputs) == 2
|
||||
assert len(hf_outputs) == 2
|
||||
|
||||
assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
|
||||
assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01)
|
||||
assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
|
||||
assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@ -85,5 +85,5 @@ def test_llm_N_to_N(vllm_runner, hf_runner, model_name, dtype: str):
|
||||
assert len(vllm_outputs) == 2
|
||||
assert len(hf_outputs) == 2
|
||||
|
||||
assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
|
||||
assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01)
|
||||
assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
|
||||
assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)
|
||||
|
@ -2,7 +2,7 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, PoolingParams, SamplingParams
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.image import ImageAsset
|
||||
|
||||
from ..utils import fork_new_process_for_each_test
|
||||
@ -36,9 +36,8 @@ def test_oot_registration_text_generation(dummy_opt_path):
|
||||
def test_oot_registration_embedding(dummy_gemma2_embedding_path):
|
||||
os.environ["VLLM_PLUGINS"] = "register_dummy_model"
|
||||
prompts = ["Hello, my name is", "The text does not matter"]
|
||||
sampling_params = PoolingParams()
|
||||
llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
|
||||
outputs = llm.encode(prompts, sampling_params)
|
||||
outputs = llm.embed(prompts)
|
||||
|
||||
for output in outputs:
|
||||
assert all(v == 0 for v in output.outputs.embedding)
|
||||
|
@ -7,8 +7,11 @@ from vllm.entrypoints.llm import LLM
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.outputs import (CompletionOutput, PoolingOutput,
|
||||
PoolingRequestOutput, RequestOutput)
|
||||
from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput,
|
||||
CompletionOutput, EmbeddingOutput,
|
||||
EmbeddingRequestOutput, PoolingOutput,
|
||||
PoolingRequestOutput, RequestOutput, ScoringOutput,
|
||||
ScoringRequestOutput)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
@ -27,6 +30,12 @@ __all__ = [
|
||||
"CompletionOutput",
|
||||
"PoolingOutput",
|
||||
"PoolingRequestOutput",
|
||||
"EmbeddingOutput",
|
||||
"EmbeddingRequestOutput",
|
||||
"ClassificationOutput",
|
||||
"ClassificationRequestOutput",
|
||||
"ScoringOutput",
|
||||
"ScoringRequestOutput",
|
||||
"LLMEngine",
|
||||
"EngineArgs",
|
||||
"AsyncLLMEngine",
|
||||
@ -34,26 +43,3 @@ __all__ = [
|
||||
"initialize_ray_cluster",
|
||||
"PoolingParams",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
import warnings
|
||||
|
||||
if name == "EmbeddingOutput":
|
||||
msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
|
||||
"The original name will be removed in an upcoming version.")
|
||||
|
||||
warnings.warn(DeprecationWarning(msg), stacklevel=2)
|
||||
|
||||
return PoolingOutput
|
||||
|
||||
if name == "EmbeddingRequestOutput":
|
||||
msg = ("EmbeddingRequestOutput has been renamed to "
|
||||
"PoolingRequestOutput. "
|
||||
"The original name will be removed in an upcoming version.")
|
||||
|
||||
warnings.warn(DeprecationWarning(msg), stacklevel=2)
|
||||
|
||||
return PoolingRequestOutput
|
||||
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
@ -46,11 +46,10 @@ from vllm.outputs import (PoolingRequestOutput, RequestOutput,
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
|
||||
ParallelSampleSequenceGroup, Sequence,
|
||||
SequenceGroup, SequenceGroupBase,
|
||||
SequenceGroupMetadata, SequenceGroupOutput,
|
||||
SequenceStatus)
|
||||
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
|
||||
PoolingSequenceGroupOutput, Sequence, SequenceGroup,
|
||||
SequenceGroupBase, SequenceGroupMetadata,
|
||||
SequenceGroupOutput, SequenceStatus)
|
||||
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
|
||||
init_tracer)
|
||||
from vllm.transformers_utils.config import try_get_generation_config
|
||||
@ -966,9 +965,9 @@ class LLMEngine:
|
||||
@staticmethod
|
||||
def _process_sequence_group_outputs(
|
||||
seq_group: SequenceGroup,
|
||||
outputs: List[EmbeddingSequenceGroupOutput],
|
||||
outputs: List[PoolingSequenceGroupOutput],
|
||||
) -> None:
|
||||
seq_group.embeddings = outputs[0].embeddings
|
||||
seq_group.pooled_data = outputs[0].data
|
||||
|
||||
for seq in seq_group.get_seqs():
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
@ -1784,8 +1783,8 @@ class LLMEngine:
|
||||
num_prompt_tokens_iter)
|
||||
# Spec decode, if enabled, emits specialized metrics from the worker in
|
||||
# sampler output.
|
||||
if model_output and (model_output[0].spec_decode_worker_metrics
|
||||
is not None):
|
||||
if model_output and isinstance(model_output[0], SamplerOutput) and (
|
||||
model_output[0].spec_decode_worker_metrics is not None):
|
||||
spec_decode_metrics = model_output[0].spec_decode_worker_metrics
|
||||
else:
|
||||
spec_decode_metrics = None
|
||||
|
@ -26,7 +26,9 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.guided_decoding.guided_fields import (
|
||||
GuidedDecodingRequest, LLMGuidedOptions)
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
|
||||
PoolingRequestOutput, RequestOutput,
|
||||
ScoringRequestOutput)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
|
||||
@ -120,7 +122,7 @@ class LLM:
|
||||
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
|
||||
"""
|
||||
|
||||
DEPRECATE_LEGACY: ClassVar[bool] = False
|
||||
DEPRECATE_LEGACY: ClassVar[bool] = True
|
||||
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
|
||||
|
||||
DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
|
||||
@ -257,11 +259,14 @@ class LLM:
|
||||
self,
|
||||
prompts: Union[PromptType, Sequence[PromptType]],
|
||||
/,
|
||||
*,
|
||||
sampling_params: Optional[Union[SamplingParams,
|
||||
Sequence[SamplingParams]]] = None,
|
||||
*,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||
GuidedDecodingRequest]] = None,
|
||||
) -> List[RequestOutput]:
|
||||
...
|
||||
|
||||
@ -275,6 +280,9 @@ class LLM:
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||
GuidedDecodingRequest]] = None,
|
||||
) -> List[RequestOutput]:
|
||||
...
|
||||
|
||||
@ -288,6 +296,9 @@ class LLM:
|
||||
prompt_token_ids: Optional[List[List[int]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||
GuidedDecodingRequest]] = None,
|
||||
) -> List[RequestOutput]:
|
||||
...
|
||||
|
||||
@ -302,6 +313,9 @@ class LLM:
|
||||
prompt_token_ids: List[int],
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||
GuidedDecodingRequest]] = None,
|
||||
) -> List[RequestOutput]:
|
||||
...
|
||||
|
||||
@ -316,6 +330,9 @@ class LLM:
|
||||
prompt_token_ids: List[List[int]],
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||
GuidedDecodingRequest]] = None,
|
||||
) -> List[RequestOutput]:
|
||||
...
|
||||
|
||||
@ -328,6 +345,9 @@ class LLM:
|
||||
prompt_token_ids: Union[List[int], List[List[int]]],
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||
GuidedDecodingRequest]] = None,
|
||||
) -> List[RequestOutput]:
|
||||
...
|
||||
|
||||
@ -678,11 +698,12 @@ class LLM:
|
||||
self,
|
||||
prompts: Union[PromptType, Sequence[PromptType]],
|
||||
/,
|
||||
*,
|
||||
pooling_params: Optional[Union[PoolingParams,
|
||||
Sequence[PoolingParams]]] = None,
|
||||
*,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> List[PoolingRequestOutput]:
|
||||
...
|
||||
|
||||
@ -696,6 +717,7 @@ class LLM:
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> List[PoolingRequestOutput]:
|
||||
...
|
||||
|
||||
@ -709,6 +731,7 @@ class LLM:
|
||||
prompt_token_ids: Optional[List[List[int]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> List[PoolingRequestOutput]:
|
||||
...
|
||||
|
||||
@ -723,6 +746,7 @@ class LLM:
|
||||
prompt_token_ids: List[int],
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> List[PoolingRequestOutput]:
|
||||
...
|
||||
|
||||
@ -737,6 +761,7 @@ class LLM:
|
||||
prompt_token_ids: List[List[int]],
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> List[PoolingRequestOutput]:
|
||||
...
|
||||
|
||||
@ -749,6 +774,7 @@ class LLM:
|
||||
prompt_token_ids: Union[List[int], List[List[int]]],
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> List[PoolingRequestOutput]:
|
||||
...
|
||||
|
||||
@ -768,7 +794,8 @@ class LLM:
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> List[PoolingRequestOutput]:
|
||||
"""Generates the completions for the input prompts.
|
||||
"""Apply pooling to the hidden states corresponding to the input
|
||||
prompts.
|
||||
|
||||
This class automatically batches the given prompts, considering
|
||||
the memory constraint. For the best performance, put all of your prompts
|
||||
@ -787,7 +814,7 @@ class LLM:
|
||||
|
||||
Returns:
|
||||
A list of ``PoolingRequestOutput`` objects containing the
|
||||
generated embeddings in the same order as the input prompts.
|
||||
pooled hidden states in the same order as the input prompts.
|
||||
|
||||
Note:
|
||||
Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
|
||||
@ -833,28 +860,110 @@ class LLM:
|
||||
return self.engine_class.validate_outputs(outputs,
|
||||
PoolingRequestOutput)
|
||||
|
||||
def embed(
|
||||
self,
|
||||
prompts: Union[PromptType, Sequence[PromptType]],
|
||||
/,
|
||||
*,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
"""
|
||||
Generate an embedding vector for each prompt.
|
||||
|
||||
This class automatically batches the given prompts, considering
|
||||
the memory constraint. For the best performance, put all of your prompts
|
||||
into a single list and pass it to this method.
|
||||
|
||||
Args:
|
||||
prompts: The prompts to the LLM. You may pass a sequence of prompts
|
||||
for batch inference. See :class:`~vllm.inputs.PromptType`
|
||||
for more details about the format of each prompts.
|
||||
use_tqdm: Whether to use tqdm to display the progress bar.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
prompt_adapter_request: Prompt Adapter request to use for
|
||||
generation, if any.
|
||||
|
||||
Returns:
|
||||
A list of ``EmbeddingRequestOutput`` objects containing the
|
||||
embedding vectors in the same order as the input prompts.
|
||||
"""
|
||||
if self.llm_engine.model_config.task != "embed":
|
||||
raise ValueError(
|
||||
"Embedding API is only enabled for `--task embed`")
|
||||
|
||||
items = self.encode(prompts,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
return [EmbeddingRequestOutput.from_base(item) for item in items]
|
||||
|
||||
def classify(
|
||||
self,
|
||||
prompts: Union[PromptType, Sequence[PromptType]],
|
||||
/,
|
||||
*,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> List[ClassificationRequestOutput]:
|
||||
"""
|
||||
Generate class logits for each prompt.
|
||||
|
||||
This class automatically batches the given prompts, considering
|
||||
the memory constraint. For the best performance, put all of your prompts
|
||||
into a single list and pass it to this method.
|
||||
|
||||
Args:
|
||||
prompts: The prompts to the LLM. You may pass a sequence of prompts
|
||||
for batch inference. See :class:`~vllm.inputs.PromptType`
|
||||
for more details about the format of each prompts.
|
||||
use_tqdm: Whether to use tqdm to display the progress bar.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
prompt_adapter_request: Prompt Adapter request to use for
|
||||
generation, if any.
|
||||
|
||||
Returns:
|
||||
A list of ``ClassificationRequestOutput`` objects containing the
|
||||
embedding vectors in the same order as the input prompts.
|
||||
"""
|
||||
if self.llm_engine.model_config.task != "classify":
|
||||
raise ValueError(
|
||||
"Classification API is only enabled for `--task classify`")
|
||||
|
||||
items = self.encode(prompts,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
return [ClassificationRequestOutput.from_base(item) for item in items]
|
||||
|
||||
def score(
|
||||
self,
|
||||
text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
|
||||
text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
|
||||
/,
|
||||
*,
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> List[PoolingRequestOutput]:
|
||||
"""Generates similarity scores for all pairs <text,text_pair>.
|
||||
) -> List[ScoringRequestOutput]:
|
||||
"""Generate similarity scores for all pairs ``<text,text_pair>``.
|
||||
|
||||
The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
|
||||
the text_1 sentence will be replicated N times to pair with the text_2
|
||||
sentences. The input pairs are used to build a list of prompts for the
|
||||
The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
|
||||
In the ``1 - N`` case the ``text_1`` sentence will be replicated ``N``
|
||||
times to pair with the ``text_2`` sentences.
|
||||
The input pairs are used to build a list of prompts for the
|
||||
cross encoder model. This class automatically batches the prompts,
|
||||
considering the memory constraint. For the best performance, put all
|
||||
of your texts into a single list and pass it to this method.
|
||||
|
||||
Args:
|
||||
text_1: can be a single prompt or a list of prompts, in which
|
||||
case it has to have the same length as the text_2 list
|
||||
case it has to have the same length as the ``text_2`` list
|
||||
text_2: The texts to pair with the query to form the input
|
||||
to the LLM. See :class:`~vllm.inputs.PromptType` for
|
||||
more details about the format of each prompts.
|
||||
@ -864,7 +973,7 @@ class LLM:
|
||||
generation, if any.
|
||||
|
||||
Returns:
|
||||
A list of ``PoolingRequestOutput`` objects containing the
|
||||
A list of ``ScoringRequestOutput`` objects containing the
|
||||
generated scores in the same order as the input prompts.
|
||||
"""
|
||||
runner_type = self.llm_engine.model_config.runner_type
|
||||
@ -884,6 +993,8 @@ class LLM:
|
||||
|
||||
if not self.llm_engine.model_config.is_cross_encoder:
|
||||
raise ValueError("Your model does not support cross encoding")
|
||||
if self.llm_engine.model_config.task != "score":
|
||||
raise ValueError("Score API is only enabled for `--task score`")
|
||||
|
||||
tokenizer = self.llm_engine.get_tokenizer()
|
||||
|
||||
@ -954,8 +1065,10 @@ class LLM:
|
||||
)
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
return self.engine_class.validate_outputs(outputs,
|
||||
PoolingRequestOutput)
|
||||
items = self.engine_class.validate_outputs(outputs,
|
||||
PoolingRequestOutput)
|
||||
|
||||
return [ScoringRequestOutput.from_base(item) for item in items]
|
||||
|
||||
def start_profile(self) -> None:
|
||||
self.llm_engine.start_profile()
|
||||
|
@ -900,7 +900,7 @@ class EmbeddingResponse(OpenAIBaseModel):
|
||||
class ScoreResponseData(OpenAIBaseModel):
|
||||
index: int
|
||||
object: str = "score"
|
||||
score: Union[List[float], str]
|
||||
score: float
|
||||
|
||||
|
||||
class ScoreResponse(OpenAIBaseModel):
|
||||
|
@ -18,14 +18,15 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
|
||||
ErrorResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
|
||||
PoolingRequestOutput)
|
||||
from vllm.utils import merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_embedding(
|
||||
output: PoolingOutput,
|
||||
output: EmbeddingOutput,
|
||||
encoding_format: Literal["float", "base64"],
|
||||
) -> Union[List[float], str]:
|
||||
if encoding_format == "float":
|
||||
@ -46,8 +47,10 @@ def request_output_to_embedding_response(
|
||||
data: List[EmbeddingResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
embedding_res = EmbeddingRequestOutput.from_base(final_res)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
embedding = _get_embedding(final_res.outputs, encoding_format)
|
||||
|
||||
embedding = _get_embedding(embedding_res.outputs, encoding_format)
|
||||
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
|
||||
data.append(embedding_data)
|
||||
|
||||
|
@ -31,7 +31,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ErrorResponse,
|
||||
LoadLoraAdapterRequest,
|
||||
ModelCard, ModelList,
|
||||
ModelPermission,
|
||||
ModelPermission, ScoreRequest,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest,
|
||||
UnloadLoraAdapterRequest)
|
||||
@ -73,7 +73,7 @@ class LoRAModulePath:
|
||||
|
||||
|
||||
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
EmbeddingCompletionRequest, ScoreRequest,
|
||||
TokenizeCompletionRequest]
|
||||
|
||||
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
|
||||
@ -567,12 +567,14 @@ class OpenAIServing:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _base_request_id(raw_request: Request,
|
||||
def _base_request_id(raw_request: Optional[Request],
|
||||
default: Optional[str] = None) -> Optional[str]:
|
||||
"""Pulls the request id to use from a header, if provided"""
|
||||
default = default or random_uuid()
|
||||
return raw_request.headers.get(
|
||||
"X-Request-Id", default) if raw_request is not None else default
|
||||
if raw_request is None:
|
||||
return default
|
||||
|
||||
return raw_request.headers.get("X-Request-Id", default)
|
||||
|
||||
@staticmethod
|
||||
def _get_decoded_token(logprob: Logprob,
|
||||
|
@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils import make_async, merge_async_iterators
|
||||
|
||||
@ -24,13 +24,13 @@ def request_output_to_score_response(
|
||||
final_res_batch: List[PoolingRequestOutput], request_id: str,
|
||||
created_time: int, model_name: str) -> ScoreResponse:
|
||||
data: List[ScoreResponseData] = []
|
||||
score = None
|
||||
num_prompt_tokens = 0
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
if final_res is not None:
|
||||
score = final_res.outputs.embedding
|
||||
score_data = ScoreResponseData(index=idx, score=score)
|
||||
data.append(score_data)
|
||||
classify_res = ScoringRequestOutput.from_base(final_res)
|
||||
|
||||
score_data = ScoreResponseData(index=idx,
|
||||
score=classify_res.outputs.score)
|
||||
data.append(score_data)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
|
@ -1,14 +1,16 @@
|
||||
from enum import IntEnum
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import PoolerConfig
|
||||
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
|
||||
PoolingTensors)
|
||||
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
|
||||
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
||||
from vllm.transformers_utils.config import (
|
||||
get_cross_encoder_activation_function)
|
||||
|
||||
@ -22,7 +24,7 @@ class PoolingType(IntEnum):
|
||||
MEAN = 4
|
||||
|
||||
|
||||
class Pooler(nn.Module):
|
||||
class SimplePooler(nn.Module):
|
||||
"""A layer that pools specific information from hidden states.
|
||||
|
||||
This layer does the following:
|
||||
@ -35,22 +37,204 @@ class Pooler(nn.Module):
|
||||
normalize: Whether to normalize the pooled data.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def from_pooling_type(
|
||||
pooling_type: PoolingType,
|
||||
*,
|
||||
normalize: bool,
|
||||
softmax: bool,
|
||||
step_tag_id: Optional[int] = None,
|
||||
returned_token_ids: Optional[List[int]] = None,
|
||||
) -> "SimplePooler":
|
||||
if pooling_type == PoolingType.LAST:
|
||||
assert step_tag_id is None and returned_token_ids is None
|
||||
return LastPool(normalize=normalize, softmax=softmax)
|
||||
if pooling_type == PoolingType.ALL:
|
||||
assert step_tag_id is None and returned_token_ids is None
|
||||
return AllPool(normalize=normalize, softmax=softmax)
|
||||
if pooling_type == PoolingType.CLS:
|
||||
assert step_tag_id is None and returned_token_ids is None
|
||||
return CLSPool(normalize=normalize, softmax=softmax)
|
||||
if pooling_type == PoolingType.MEAN:
|
||||
assert step_tag_id is None and returned_token_ids is None
|
||||
return MeanPool(normalize=normalize, softmax=softmax)
|
||||
if pooling_type == PoolingType.STEP:
|
||||
return StepPool(normalize=normalize,
|
||||
softmax=softmax,
|
||||
step_tag_id=step_tag_id,
|
||||
returned_token_ids=returned_token_ids)
|
||||
|
||||
assert_never(pooling_type)
|
||||
|
||||
def __init__(self, *, normalize: bool, softmax: bool) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.head = PoolerHead(normalize=normalize, softmax=softmax)
|
||||
|
||||
def get_prompt_lens(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> torch.Tensor:
|
||||
return PoolingTensors.from_pooling_metadata(
|
||||
pooling_metadata, hidden_states.device).prompt_lens
|
||||
|
||||
def extract_states(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def build_output(self, data: torch.Tensor) -> PoolingSequenceGroupOutput:
|
||||
return PoolingSequenceGroupOutput(data)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
||||
pooled_data = self.head(pooled_data)
|
||||
pooled_outputs = [self.build_output(data) for data in pooled_data]
|
||||
return PoolerOutput(outputs=pooled_outputs)
|
||||
|
||||
|
||||
class CLSPool(SimplePooler):
|
||||
|
||||
def extract_states(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
||||
|
||||
first_token_flat_indices = torch.zeros_like(prompt_lens)
|
||||
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
|
||||
return hidden_states[first_token_flat_indices]
|
||||
|
||||
|
||||
class LastPool(SimplePooler):
|
||||
|
||||
def extract_states(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
||||
|
||||
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
|
||||
return hidden_states[last_token_flat_indices]
|
||||
|
||||
|
||||
class AllPool(SimplePooler):
|
||||
|
||||
def extract_states(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
||||
|
||||
offset = 0
|
||||
pooled_data = list[torch.Tensor]()
|
||||
for prompt_len in prompt_lens:
|
||||
pooled_data.append(hidden_states[offset:offset + prompt_len])
|
||||
offset += prompt_len
|
||||
|
||||
return pooled_data
|
||||
|
||||
|
||||
class MeanPool(SimplePooler):

def extract_states(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

cumsum = torch.cumsum(hidden_states, dim=0)
start_indices = torch.cat([
torch.tensor([0], device=hidden_states.device),
torch.cumsum(prompt_lens[:-1], dim=0)
])
end_indices = torch.cumsum(prompt_lens, dim=0)
return (cumsum[end_indices - 1] - cumsum[start_indices] +
hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
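
# Illustrative sketch (not part of this change): the cumulative-sum expression
# above computes, for each prompt, the sum of its hidden states divided by its
# length, i.e. it is equivalent to this naive per-prompt loop:
#
#     pooled, offset = [], 0
#     for prompt_len in prompt_lens:
#         chunk = hidden_states[offset:offset + int(prompt_len)]
#         pooled.append(chunk.mean(dim=0))
#         offset += int(prompt_len)
#     # torch.stack(pooled) matches the vectorized result above.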
class StepPool(SimplePooler):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pooling_type: PoolingType,
|
||||
*,
|
||||
normalize: bool,
|
||||
softmax: bool,
|
||||
step_tag_id: Optional[int] = None,
|
||||
returned_token_ids: Optional[List[int]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(normalize=normalize, softmax=softmax)
|
||||
|
||||
self.pooling_type = pooling_type
|
||||
self.normalize = normalize
|
||||
self.softmax = softmax
|
||||
self.step_tag_id = step_tag_id
|
||||
self.returned_token_ids = returned_token_ids
|
||||
|
||||
def extract_states(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
||||
|
||||
returned_token_ids = self.returned_token_ids
|
||||
if returned_token_ids is not None and len(returned_token_ids) > 0:
|
||||
hidden_states = hidden_states[:, returned_token_ids]
|
||||
|
||||
step_tag_id = self.step_tag_id
|
||||
|
||||
offset = 0
|
||||
pooled_data = list[torch.Tensor]()
|
||||
for prompt_len, seq_data_i in zip(prompt_lens,
|
||||
pooling_metadata.seq_data.values()):
|
||||
pooled_data_i = hidden_states[offset:offset + prompt_len]
|
||||
if step_tag_id is not None:
|
||||
token_ids = torch.tensor(seq_data_i.prompt_token_ids)
|
||||
pooled_data_i = pooled_data_i[token_ids == step_tag_id]
|
||||
|
||||
offset += prompt_len
|
||||
pooled_data.append(pooled_data_i)
|
||||
|
||||
return pooled_data
|
||||
|
||||
|
||||
class PoolerHead(nn.Module):
|
||||
|
||||
def __init__(self, *, normalize: bool, softmax: bool) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.normalize = normalize
|
||||
self.softmax = softmax
|
||||
|
||||
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor]):
|
||||
if self.normalize:
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = [
|
||||
F.normalize(data, p=2, dim=1) for data in pooled_data
|
||||
]
|
||||
else:
|
||||
pooled_data = F.normalize(pooled_data, p=2, dim=1)
|
||||
|
||||
if self.softmax:
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = [F.softmax(data, dim=-1) for data in pooled_data]
|
||||
else:
|
||||
pooled_data = F.softmax(pooled_data, dim=-1)
|
||||
|
||||
return pooled_data
|
||||
|
||||
|
||||
class Pooler(nn.Module):
|
||||
|
||||
@classmethod
|
||||
def from_config_with_defaults(
|
||||
cls,
|
||||
@ -60,8 +244,8 @@ class Pooler(nn.Module):
|
||||
softmax: bool,
|
||||
step_tag_id: Optional[int] = None,
|
||||
returned_token_ids: Optional[List[int]] = None,
|
||||
) -> "Pooler":
|
||||
return cls(
|
||||
) -> SimplePooler:
|
||||
return SimplePooler.from_pooling_type(
|
||||
pooling_type=PoolingType[pooler_config.pooling_type]
|
||||
if pooler_config.pooling_type is not None else pooling_type,
|
||||
normalize=pooler_config.normalize
|
||||
@ -75,85 +259,6 @@ class Pooler(nn.Module):
|
||||
returned_token_ids,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
"""Pools specific information from hidden states based on metadata."""
|
||||
|
||||
prompt_lens = PoolingTensors.from_pooling_metadata(
|
||||
pooling_metadata, hidden_states.device).prompt_lens
|
||||
|
||||
if self.pooling_type is PoolingType.CLS:
|
||||
first_token_flat_indices = torch.zeros_like(prompt_lens)
|
||||
first_token_flat_indices[1:] += torch.cumsum(prompt_lens,
|
||||
dim=0)[:-1]
|
||||
pooled_data = hidden_states[first_token_flat_indices]
|
||||
elif self.pooling_type == PoolingType.LAST:
|
||||
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
|
||||
pooled_data = hidden_states[last_token_flat_indices]
|
||||
elif self.pooling_type == PoolingType.ALL:
|
||||
offset = 0
|
||||
pooled_data = []
|
||||
for prompt_len in prompt_lens:
|
||||
pooled_data.append(hidden_states[offset:offset + prompt_len])
|
||||
offset += prompt_len
|
||||
elif self.pooling_type == PoolingType.MEAN:
|
||||
# Calculate mean pooling
|
||||
cumsum = torch.cumsum(hidden_states, dim=0)
|
||||
start_indices = torch.cat([
|
||||
torch.tensor([0], device=hidden_states.device),
|
||||
torch.cumsum(prompt_lens[:-1], dim=0)
|
||||
])
|
||||
end_indices = torch.cumsum(prompt_lens, dim=0)
|
||||
pooled_data = (
|
||||
cumsum[end_indices - 1] - cumsum[start_indices] +
|
||||
hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
|
||||
elif self.pooling_type == PoolingType.STEP:
|
||||
returned_token_ids = self.returned_token_ids
|
||||
if returned_token_ids is not None and len(returned_token_ids) > 0:
|
||||
hidden_states = hidden_states[:, returned_token_ids]
|
||||
|
||||
step_tag_id = self.step_tag_id
|
||||
|
||||
offset = 0
|
||||
pooled_data = []
|
||||
for prompt_len, seq_data_i in zip(
|
||||
prompt_lens, pooling_metadata.seq_data.values()):
|
||||
pooled_data_i = hidden_states[offset:offset + prompt_len]
|
||||
if step_tag_id is not None:
|
||||
token_ids = torch.tensor(seq_data_i.prompt_token_ids)
|
||||
pooled_data_i = pooled_data_i[token_ids == step_tag_id]
|
||||
|
||||
offset += prompt_len
|
||||
pooled_data.append(pooled_data_i)
|
||||
else:
|
||||
raise ValueError(f"Invalid pooling type: {self.pooling_type}")
|
||||
|
||||
if self.normalize:
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = [
|
||||
nn.functional.normalize(data, p=2, dim=1)
|
||||
for data in pooled_data
|
||||
]
|
||||
else:
|
||||
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
|
||||
|
||||
if self.softmax:
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = [
|
||||
nn.functional.softmax(data, dim=-1) for data in pooled_data
|
||||
]
|
||||
else:
|
||||
pooled_data = nn.functional.softmax(pooled_data, dim=-1)
|
||||
|
||||
pooled_outputs = [
|
||||
EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
|
||||
]
|
||||
|
||||
return PoolerOutput(outputs=pooled_outputs)
|
||||
|
||||
|
||||
class CrossEncodingPooler(nn.Module):
|
||||
"""A layer that pools specific information from hidden states.
|
||||
@ -208,9 +313,8 @@ class CrossEncodingPooler(nn.Module):
|
||||
if self.pooler is not None:
|
||||
# apply classifier once on the full batch if possible
|
||||
pooled_output = self.classifier(pooled_output)
|
||||
logits = self.default_activation_function(pooled_output)
|
||||
|
||||
pooled_outputs = [
|
||||
EmbeddingSequenceGroupOutput(data.tolist()) for data in logits
|
||||
]
|
||||
scores = self.default_activation_function(pooled_output).squeeze(-1)
|
||||
|
||||
pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
|
||||
return PoolerOutput(outputs=pooled_outputs)
|
||||
|
@ -2,19 +2,20 @@ from array import array
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn as nn
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.attention.backends.xformers import XFormersImpl
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.pooler import PoolerHead
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
|
||||
PoolingTensors)
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import (EmbeddingSequenceGroupOutput, IntermediateTensors,
|
||||
PoolerOutput)
|
||||
from vllm.sequence import (IntermediateTensors, PoolerOutput,
|
||||
PoolingSequenceGroupOutput)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -52,6 +53,8 @@ class GritLMPooler(nn.Module):
|
||||
self.embed_pattern_ids = tokens_to_ids(
|
||||
["▁<", "|", "embed", "|", ">", "<0x0A>"])
|
||||
|
||||
self.head = PoolerHead(normalize=True, softmax=False)
|
||||
|
||||
def _find_array(self, arr: array, target: array, start_idx: int) -> int:
|
||||
"""
|
||||
Find the first occurrence of target in arr starting from start_idx.
|
||||
@ -75,7 +78,7 @@ class GritLMPooler(nn.Module):
|
||||
return i
|
||||
return -1
|
||||
|
||||
def _get_instruction_len(self, prompt_token_ids: array) -> bool:
|
||||
def _get_instruction_len(self, prompt_token_ids: array) -> int:
|
||||
"""
|
||||
Get the length of the instruction in the prompt.
|
||||
|
||||
@ -168,10 +171,10 @@ class GritLMPooler(nn.Module):
|
||||
mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
|
||||
1)
|
||||
|
||||
pooled_data = nn.functional.normalize(mean_embeddings, p=2, dim=1)
|
||||
pooled_data = self.head(mean_embeddings)
|
||||
|
||||
pooled_outputs = [
|
||||
EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
|
||||
PoolingSequenceGroupOutput(data) for data in pooled_data
|
||||
]
|
||||
|
||||
return PoolerOutput(outputs=pooled_outputs)
|
||||
|
225
vllm/outputs.py
@ -1,9 +1,13 @@
|
||||
import time
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, Generic, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalPlaceholderDict
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
@ -57,14 +61,26 @@ class PoolingOutput:
|
||||
"""The output data of one pooling output of a request.
|
||||
|
||||
Args:
|
||||
embedding: The embedding vector, which is a list of floats. The
|
||||
length of vector depends on the model as listed in the embedding guide.
|
||||
data: The extracted hidden states.
|
||||
"""
|
||||
embedding: List[float]
|
||||
data: torch.Tensor
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"PoolingOutput("
|
||||
f"embedding={len(self.embedding)})")
|
||||
return (f"PoolingOutput(data={self.data})")
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (isinstance(other, self.__class__) and bool(
|
||||
(self.data == other.data).all()))
|
||||
|
||||
@property
|
||||
def embedding(self) -> list[float]:
|
||||
msg = ("`LLM.encode()` now returns raw outputs. "
|
||||
"To return embeddings, use `LLM.embed()`. "
|
||||
"To return class probabilities, use `LLM.classify()` "
|
||||
"and access the `probs` attribute. ")
|
||||
warnings.warn(msg, DeprecationWarning, stacklevel=2)
|
||||
|
||||
return self.data.tolist()
|
||||
|
||||
|
||||
class RequestOutput:
|
||||
@ -316,7 +332,10 @@ class RequestOutput:
|
||||
f"multi_modal_placeholders={self.multi_modal_placeholders})")
|
||||
|
||||
|
||||
class PoolingRequestOutput:
|
||||
_O = TypeVar("_O", default=PoolingOutput)
|
||||
|
||||
|
||||
class PoolingRequestOutput(Generic[_O]):
|
||||
"""
|
||||
The output data of a pooling request to the LLM.
|
||||
|
||||
@ -327,24 +346,24 @@ class PoolingRequestOutput:
|
||||
finished (bool): A flag indicating whether the pooling is completed.
|
||||
"""
|
||||
|
||||
def __init__(self, request_id: str, outputs: "PoolingOutput",
|
||||
def __init__(self, request_id: str, outputs: _O,
|
||||
prompt_token_ids: List[int], finished: bool):
|
||||
self.request_id = request_id
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.finished = finished
|
||||
self.outputs = outputs
|
||||
|
||||
@classmethod
|
||||
def from_seq_group(cls,
|
||||
seq_group: 'SequenceGroup') -> "PoolingRequestOutput":
|
||||
if seq_group.embeddings is None:
|
||||
raise ValueError(
|
||||
"Embeddings are missing in seq_group for EmbeddingRequest.")
|
||||
output = PoolingOutput(seq_group.embeddings)
|
||||
@staticmethod
|
||||
def from_seq_group(seq_group: SequenceGroup) -> "PoolingRequestOutput":
|
||||
pooled_data = seq_group.pooled_data
|
||||
assert pooled_data is not None
|
||||
|
||||
output = PoolingOutput(pooled_data)
|
||||
prompt_token_ids = seq_group.prompt_token_ids
|
||||
finished = seq_group.is_finished()
|
||||
|
||||
return cls(seq_group.request_id, output, prompt_token_ids, finished)
|
||||
return PoolingRequestOutput(seq_group.request_id, output,
|
||||
prompt_token_ids, finished)
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
@ -356,89 +375,137 @@ class PoolingRequestOutput:
|
||||
Returns:
|
||||
str: A string representation of the PoolingRequestOutput instance.
|
||||
"""
|
||||
return (f"PoolingRequestOutput(request_id='{self.request_id}', "
|
||||
f"outputs={repr(self.outputs)}, "
|
||||
return (f"{type(self).__name__}(request_id={self.request_id!r}, "
|
||||
f"outputs={self.outputs!r}, "
|
||||
f"prompt_token_ids={self.prompt_token_ids}, "
|
||||
f"finished={self.finished})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoreOutput:
|
||||
"""The output data of one completion output of a request.
|
||||
|
||||
Args:
|
||||
score: The score, which is a list of floats.
|
||||
index: The correspondent text index of the score.
|
||||
"""
|
||||
index: int
|
||||
score: List[float]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"ScoreOutput("
|
||||
f"score={self.score}), "
|
||||
f"index={self.index})")
|
||||
|
||||
|
||||
class ScoreRequestOutput:
|
||||
"""
|
||||
The output data of an score request to the LLM.
|
||||
|
||||
Args:
|
||||
request_id (str): A unique identifier for the score request.
|
||||
outputs (score): The embedding results for the given input.
|
||||
"""
|
||||
|
||||
def __init__(self, request_id: str, outputs: "ScoreOutput"):
|
||||
self.request_id = request_id
|
||||
self.outputs = outputs
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
Returns a string representation of an ScoreRequestOutput instance.
|
||||
|
||||
The representation includes the request_id and the number of outputs,
|
||||
providing a quick overview of the embedding request's results.
|
||||
|
||||
Returns:
|
||||
str: A string representation of the ScoreRequestOutput instance.
|
||||
"""
|
||||
return (f"ScoreRequestOutput(request_id='{self.request_id}', "
|
||||
f"outputs={repr(self.outputs)}")
|
||||
|
||||
|
||||
class RequestOutputFactory:
|
||||
|
||||
@staticmethod
|
||||
def create(seq_group: SequenceGroup,
|
||||
seq_id_to_seq_group: Dict[str, SequenceGroupBase],
|
||||
use_cache: bool = False):
|
||||
# Determine the type based on a condition, for example:
|
||||
if hasattr(seq_group,
|
||||
'embeddings') and seq_group.embeddings is not None:
|
||||
if seq_group.pooled_data is not None:
|
||||
return PoolingRequestOutput.from_seq_group(seq_group)
|
||||
else:
|
||||
return RequestOutput.from_seq_group(seq_group, use_cache,
|
||||
seq_id_to_seq_group)
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
import warnings
|
||||
@dataclass
|
||||
class EmbeddingOutput:
|
||||
"""The output data of one embedding output of a request.
|
||||
|
||||
if name == "EmbeddingOutput":
|
||||
msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
|
||||
"The original name will be removed in an upcoming version.")
|
||||
Args:
|
||||
embedding: The embedding vector, which is a list of floats.
|
||||
Its length depends on the hidden dimension of the model.
|
||||
"""
|
||||
embedding: list[float]
|
||||
|
||||
warnings.warn(DeprecationWarning(msg), stacklevel=2)
|
||||
@staticmethod
|
||||
def from_base(pooling_output: PoolingOutput):
|
||||
pooled_data = pooling_output.data
|
||||
if pooled_data.ndim != 1:
|
||||
raise ValueError("pooled_data should be a 1-D embedding vector")
|
||||
|
||||
return PoolingOutput
|
||||
return EmbeddingOutput(pooled_data.tolist())
|
||||
|
||||
if name == "EmbeddingRequestOutput":
|
||||
msg = ("EmbeddingRequestOutput has been renamed to "
|
||||
"PoolingRequestOutput. "
|
||||
"The original name will be removed in an upcoming version.")
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return len(self.embedding)
|
||||
|
||||
warnings.warn(DeprecationWarning(msg), stacklevel=2)
|
||||
def __repr__(self) -> str:
|
||||
return f"EmbeddingOutput(hidden_size={self.hidden_size})"
|
||||
|
||||
return PoolingRequestOutput
|
||||
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]):
|
||||
|
||||
@staticmethod
|
||||
def from_base(request_output: PoolingRequestOutput):
|
||||
return EmbeddingRequestOutput(
|
||||
request_id=request_output.request_id,
|
||||
outputs=EmbeddingOutput.from_base(request_output.outputs),
|
||||
prompt_token_ids=request_output.prompt_token_ids,
|
||||
finished=request_output.finished,
|
||||
)
|
||||
|
||||
|
||||


@dataclass
class ClassificationOutput:
    """The output data of one classification output of a request.

    Args:
        probs: The probability vector, which is a list of floats.
        Its length depends on the number of classes.
    """
    probs: list[float]

    @staticmethod
    def from_base(pooling_output: PoolingOutput):
        pooled_data = pooling_output.data
        if pooled_data.ndim != 1:
            raise ValueError("pooled_data should be a 1-D probability vector")

        return ClassificationOutput(pooled_data.tolist())

    @property
    def num_classes(self) -> int:
        return len(self.probs)

    def __repr__(self) -> str:
        return f"ClassificationOutput(num_classes={self.num_classes})"


class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]):

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        return ClassificationRequestOutput(
            request_id=request_output.request_id,
            outputs=ClassificationOutput.from_base(request_output.outputs),
            prompt_token_ids=request_output.prompt_token_ids,
            finished=request_output.finished,
        )
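

# --- Hedged usage sketch (not part of this commit) ----------------------
# Shows the ClassificationOutput path defined above: a 1-D probability
# vector is unwrapped into a list of per-class probabilities. Constructor
# forms and the import path are assumed as in the embedding sketch; the
# two-class tensor is arbitrary example data.
import torch

from vllm.outputs import (ClassificationRequestOutput, PoolingOutput,
                          PoolingRequestOutput)

pooled = PoolingRequestOutput(
    request_id="classify-0",
    outputs=PoolingOutput(data=torch.tensor([0.2, 0.8])),
    prompt_token_ids=[101, 2023],
    finished=True,
)
cls = ClassificationRequestOutput.from_base(pooled)
print(cls.outputs.num_classes)  # 2
print(cls.outputs.probs)        # [0.2, 0.8] (approximately, after tolist())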


@dataclass
class ScoringOutput:
    """The output data of one scoring output of a request.

    Args:
        score: The similarity score, which is a scalar value.
    """
    score: float

    @staticmethod
    def from_base(pooling_output: PoolingOutput):
        pooled_data = pooling_output.data
        if pooled_data.ndim != 0:
            raise ValueError("pooled_data should be a scalar score")

        return ScoringOutput(pooled_data.item())

    def __repr__(self) -> str:
        return f"ScoringOutput(score={self.score})"

    @property
    def embedding(self) -> list[float]:
        msg = ("`LLM.score()` now returns scalar scores. "
               "Please access it via the `score` attribute. ")
        warnings.warn(msg, DeprecationWarning, stacklevel=2)

        return [self.score]


class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        return ScoringRequestOutput(
            request_id=request_output.request_id,
            outputs=ScoringOutput.from_base(request_output.outputs),
            prompt_token_ids=request_output.prompt_token_ids,
            finished=request_output.finished,
        )
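

# --- Hedged usage sketch (not part of this commit) ----------------------
# The ScoringOutput above expects a 0-D (scalar) tensor; anything else
# raises ValueError, and the deprecated `embedding` property returns the
# score wrapped in a single-element list. Constructor forms and the import
# path are assumed as in the earlier sketches.
import torch

from vllm.outputs import (PoolingOutput, PoolingRequestOutput,
                          ScoringRequestOutput)

pooled = PoolingRequestOutput(
    request_id="score-0",
    outputs=PoolingOutput(data=torch.tensor(0.73)),  # scalar similarity
    prompt_token_ids=[101, 2023, 102, 2003],
    finished=True,
)
scored = ScoringRequestOutput.from_base(pooled)
print(scored.outputs.score)      # 0.73 (approximately)
print(scored.outputs.embedding)  # [0.73...], emits a DeprecationWarning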

@ -617,10 +617,9 @@ class SequenceGroup:
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        embeddings: The embeddings vectors of the prompt of the sequence group
            for a pooling model.
        pooling_params: The pooling parameters used to generate the pooling
        pooling_params: The parameters used to generate the pooler
            for a pooling model.
        pooled_data: The extracted hidden states from a pooling model.
        encoder_seq: Optional, the single encoder sequence. Should be None
            unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.

@ -635,8 +634,8 @@ class SequenceGroup:
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
        pooled_data: Optional[torch.Tensor] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,

@ -658,8 +657,8 @@ class SequenceGroup:
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.pooled_data = pooled_data
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers

@ -1033,8 +1032,8 @@ class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    __metaclass__ = SequenceGroupOutput
    """The model output associated with a completion sequence group."""
    __metaclass__ = SequenceGroupOutput
    samples: List[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]

@ -1050,23 +1049,24 @@ class CompletionSequenceGroupOutput(
                and self.prompt_logprobs == other.prompt_logprobs)


class EmbeddingSequenceGroupOutput(
class PoolingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with an embedding sequence group."""
    """The model output associated with a pooling sequence group."""
    __metaclass__ = SequenceGroupOutput
    embeddings: List[int]
    # Annotated as Any to be compatible with msgspec
    # The actual type is in SequenceGroup.pooled_data
    data: Any

    def __repr__(self) -> str:
        return (f"EmbeddingSequenceGroupOutput("
                f"embeddings_shape={len(self.embeddings)})")
        return f"PoolingSequenceGroupOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, EmbeddingSequenceGroupOutput):
        if not isinstance(other, PoolingSequenceGroupOutput):
            raise NotImplementedError()
        return self.embeddings == other.embeddings
        return self.data == other.data


# cannot use msgspec.Struct here because Dynamo does not support it

@ -1085,7 +1085,7 @@ class IntermediateTensors:
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})

    def __setitem__(self, key: str, value):
    def __setitem__(self, key: str, value: torch.Tensor):
        self.tensors[key] = value

    def __len__(self):

@ -1103,16 +1103,12 @@ class PoolerOutput(
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The output from a pooling operation in the pooling model."""
    outputs: List[EmbeddingSequenceGroupOutput]
    outputs: List[PoolingSequenceGroupOutput]

    # lazy import to avoid circular import
    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

    def __getitem__(self, idx: int) -> EmbeddingSequenceGroupOutput:
    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
        self.outputs[idx] = value

    def __len__(self):
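
# --- Hedged usage sketch (not part of this commit) ----------------------
# Shows the renamed container types in the hunk above: PoolerOutput now
# holds one PoolingSequenceGroupOutput per sequence group and supports
# indexing. The keyword constructors and the vllm.sequence import path are
# assumptions inferred from the class bodies shown here.
import torch

from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput

pooler_output = PoolerOutput(outputs=[
    PoolingSequenceGroupOutput(data=torch.randn(4096)),
    PoolingSequenceGroupOutput(data=torch.randn(4096)),
])
print(len(pooler_output))  # 2
print(pooler_output[0])    # PoolingSequenceGroupOutput(data=tensor(...))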

@ -1385,8 +1381,8 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
            arrival_time=seq_group.arrival_time,
            sampling_params=original_params,
            lora_request=seq_group.lora_request,
            embeddings=seq_group.embeddings,
            pooling_params=seq_group.pooling_params,
            pooled_data=seq_group.pooled_data,
            encoder_seq=seq_group.encoder_seq,
            trace_headers=seq_group.trace_headers,
            prompt_adapter_request=seq_group.prompt_adapter_request,