[Frontend] Separate pooling APIs in offline inference (#11129)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Parent: f93bf2b189
Commit: eeec9e3390
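In short, this change splits the catch-all ``LLM.encode`` pooling entrypoint into task-specific APIs: ``LLM.embed`` for embedding models, ``LLM.classify`` for classification models, and ``LLM.score`` for cross-encoder models, while ``LLM.encode`` keeps returning the raw pooled hidden states. A combined sketch of the new usage, drawn from the examples added below (each snippet stands alone; loading all three models in one process is not intended):

    from vllm import LLM

    # Embedding models (task="embed")
    llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
    (output,) = llm.embed("Hello, my name is")
    embeds = output.outputs.embedding        # embedding vector

    # Classification models (task="classify")
    llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
    (output,) = llm.classify("Hello, my name is")
    probs = output.outputs.probs             # class probabilities

    # Cross-encoder models (task="score")
    llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
    (output,) = llm.score("What is the capital of France?",
                          "The capital of Brazil is Brasilia.")
    score = output.outputs.score             # single similarity score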
@@ -181,14 +181,14 @@ steps:
   commands:
     - VLLM_USE_V1=1 pytest -v -s v1
 
-- label: Examples Test # 15min
+- label: Examples Test # 25min
   working_dir: "/vllm-workspace/examples"
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/entrypoints
   - examples/
   commands:
-    - pip install awscli tensorizer # for llava example and tensorizer test
+    - pip install tensorizer # for tensorizer test
    - python3 offline_inference.py
    - python3 cpu_offload.py
    - python3 offline_inference_chat.py
@@ -198,6 +198,9 @@ steps:
    - python3 offline_inference_vision_language_multi_image.py
    - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
    - python3 offline_inference_encoder_decoder.py
+   - python3 offline_inference_classification.py
+   - python3 offline_inference_embedding.py
+   - python3 offline_inference_scoring.py
    - python3 offline_profile.py --model facebook/opt-125m
 
 - label: Prefix Caching Test # 9min
@@ -6,7 +6,7 @@ Pooling Models
 vLLM also supports pooling models, including embedding, reranking and reward models.
 
 In vLLM, pooling models implement the :class:`~vllm.model_executor.models.VllmModelForPooling` interface.
-These models use a :class:`~vllm.model_executor.layers.Pooler` to aggregate the final hidden states of the input
+These models use a :class:`~vllm.model_executor.layers.Pooler` to extract the final hidden states of the input
 before returning them.
 
 .. note::
@@ -45,20 +45,48 @@ which takes priority over both the model's and Sentence Transformers's defaults.
 ^^^^^^^^^^^^^^
 
 The :class:`~vllm.LLM.encode` method is available to all pooling models in vLLM.
-It returns the aggregated hidden states directly.
+It returns the extracted hidden states directly, which is useful for reward models.
+
+.. code-block:: python
+
+    llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", task="reward")
+    output, = llm.encode("Hello, my name is")
+
+    data = output.outputs.data
+    print(f"Prompt: {prompt!r} | Data: {data!r}")
+
+``LLM.embed``
+^^^^^^^^^^^^^
+
+The :class:`~vllm.LLM.embed` method outputs an embedding vector for each prompt.
+It is primarily designed for embedding models.
 
 .. code-block:: python
 
     llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
-    outputs = llm.encode("Hello, my name is")
+    output, = llm.embed("Hello, my name is")
 
-    outputs = model.encode(prompts)
-    for output in outputs:
-        embeddings = output.outputs.embedding
-        print(f"Prompt: {prompt!r}, Embeddings (size={len(embeddings)}: {embeddings!r}")
+    embeds = output.outputs.embedding
+    print(f"Embeddings: {embeds!r} (size={len(embeds)})")
 
 A code example can be found in `examples/offline_inference_embedding.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_embedding.py>`_.
 
+``LLM.classify``
+^^^^^^^^^^^^^^^^
+
+The :class:`~vllm.LLM.classify` method outputs a probability vector for each prompt.
+It is primarily designed for classification models.
+
+.. code-block:: python
+
+    llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
+    output, = llm.classify("Hello, my name is")
+
+    probs = output.outputs.probs
+    print(f"Class Probabilities: {probs!r} (size={len(probs)})")
+
+A code example can be found in `examples/offline_inference_classification.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_classification.py>`_.
+
 ``LLM.score``
 ^^^^^^^^^^^^^
 
@@ -71,7 +99,16 @@ These types of models serve as rerankers between candidate query-document pairs
 vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.
 To handle RAG at a higher level, you should use integration frameworks such as `LangChain <https://github.com/langchain-ai/langchain>`_.
 
-You can use `these tests <https://github.com/vllm-project/vllm/blob/main/tests/models/embedding/language/test_scoring.py>`_ as reference.
+.. code-block:: python
+
+    llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
+    output, = llm.score("What is the capital of France?",
+                        "The capital of Brazil is Brasilia.")
+
+    score = output.outputs.score
+    print(f"Score: {score}")
+
+A code example can be found in `examples/offline_inference_scoring.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_scoring.py>`_.
 
 Online Inference
 ----------------
examples/offline_inference_classification.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+from vllm import LLM
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+# Create an LLM.
+# You should pass task="classify" for classification models
+model = LLM(
+    model="jason9693/Qwen2.5-1.5B-apeach",
+    task="classify",
+    enforce_eager=True,
+)
+
+# Generate logits. The output is a list of ClassificationRequestOutputs.
+outputs = model.classify(prompts)
+
+# Print the outputs.
+for prompt, output in zip(prompts, outputs):
+    probs = output.outputs.probs
+    probs_trimmed = ((str(probs[:16])[:-1] +
+                      ", ...]") if len(probs) > 16 else probs)
+    print(f"Prompt: {prompt!r} | "
+          f"Class Probabilities: {probs_trimmed} (size={len(probs)})")
@@ -9,14 +9,20 @@ prompts = [
 ]
 
 # Create an LLM.
+# You should pass task="embed" for embedding models
 model = LLM(
     model="intfloat/e5-mistral-7b-instruct",
-    task="embed", # You should pass task="embed" for embedding models
+    task="embed",
     enforce_eager=True,
 )
 
-# Generate embedding. The output is a list of PoolingRequestOutputs.
-outputs = model.encode(prompts)
+# Generate embedding. The output is a list of EmbeddingRequestOutputs.
+outputs = model.embed(prompts)
 
 # Print the outputs.
-for output in outputs:
-    print(output.outputs.embedding)  # list of 4096 floats
+for prompt, output in zip(prompts, outputs):
+    embeds = output.outputs.embedding
+    embeds_trimmed = ((str(embeds[:16])[:-1] +
+                       ", ...]") if len(embeds) > 16 else embeds)
+    print(f"Prompt: {prompt!r} | "
+          f"Embeddings: {embeds_trimmed} (size={len(embeds)})")
examples/offline_inference_scoring.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+from vllm import LLM
+
+# Sample prompts.
+text_1 = "What is the capital of France?"
+texts_2 = [
+    "The capital of Brazil is Brasilia.", "The capital of France is Paris."
+]
+
+# Create an LLM.
+# You should pass task="score" for cross-encoder models
+model = LLM(
+    model="BAAI/bge-reranker-v2-m3",
+    task="score",
+    enforce_eager=True,
+)
+
+# Generate scores. The output is a list of ScoringRequestOutputs.
+outputs = model.score(text_1, texts_2)
+
+# Print the outputs.
+for text_2, output in zip(texts_2, outputs):
+    score = output.outputs.score
+    print(f"Pair: {[text_1, text_2]!r} | Score: {score}")
@@ -133,7 +133,7 @@ def run_encode(model: str, modality: QueryModality):
     if req_data.image is not None:
         mm_data["image"] = req_data.image
 
-    outputs = req_data.llm.encode({
+    outputs = req_data.llm.embed({
         "prompt": req_data.prompt,
         "multi_modal_data": mm_data,
     })
@@ -719,14 +719,6 @@ class VllmRunner:
 
         return inputs
 
-    def classify(self, prompts: List[str]) -> List[str]:
-        req_outputs = self.model.encode(prompts)
-        outputs = []
-        for req_output in req_outputs:
-            embedding = req_output.outputs.embedding
-            outputs.append(embedding)
-        return outputs
-
     def generate(
         self,
         prompts: List[str],
@@ -897,6 +889,10 @@ class VllmRunner:
             returned_outputs.append((token_ids, texts))
         return returned_outputs
 
+    def classify(self, prompts: List[str]) -> List[List[float]]:
+        req_outputs = self.model.classify(prompts)
+        return [req_output.outputs.probs for req_output in req_outputs]
+
     def encode(
         self,
         prompts: List[str],
@@ -909,16 +905,16 @@ class VllmRunner:
             videos=videos,
             audios=audios)
 
-        req_outputs = self.model.encode(inputs)
+        req_outputs = self.model.embed(inputs)
         return [req_output.outputs.embedding for req_output in req_outputs]
 
     def score(
         self,
         text_1: Union[str, List[str]],
         text_2: Union[str, List[str]],
-    ) -> List[List[float]]:
+    ) -> List[float]:
         req_outputs = self.model.score(text_1, text_2)
-        return [req_output.outputs.embedding for req_output in req_outputs]
+        return [req_output.outputs.score for req_output in req_outputs]
 
     def __enter__(self):
         return self
@@ -39,8 +39,8 @@ async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
     assert score.id is not None
     assert score.data is not None
     assert len(score.data) == 2
-    assert score.data[0].score[0] <= 0.01
-    assert score.data[1].score[0] >= 0.9
+    assert score.data[0].score <= 0.01
+    assert score.data[1].score >= 0.9
 
 
 @pytest.mark.asyncio
@@ -67,8 +67,8 @@ async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
     assert score.id is not None
     assert score.data is not None
     assert len(score.data) == 2
-    assert score.data[0].score[0] <= 0.01
-    assert score.data[1].score[0] >= 0.9
+    assert score.data[0].score <= 0.01
+    assert score.data[1].score >= 0.9
 
 
 @pytest.mark.asyncio
@@ -90,4 +90,4 @@ async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
     assert score.id is not None
     assert score.data is not None
     assert len(score.data) == 1
-    assert score.data[0].score[0] >= 0.9
+    assert score.data[0].score >= 0.9
@@ -42,7 +42,7 @@ def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str):
     assert len(vllm_outputs) == 1
     assert len(hf_outputs) == 1
 
-    assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
+    assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
 
 
 @pytest.mark.parametrize("dtype", ["half"])
@@ -63,8 +63,8 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
     assert len(vllm_outputs) == 2
     assert len(hf_outputs) == 2
 
-    assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
-    assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01)
+    assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
+    assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)
 
 
 @pytest.mark.parametrize("dtype", ["half"])
@@ -85,5 +85,5 @@ def test_llm_N_to_N(vllm_runner, hf_runner, model_name, dtype: str):
     assert len(vllm_outputs) == 2
     assert len(hf_outputs) == 2
 
-    assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
-    assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01)
+    assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
+    assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)
@@ -2,7 +2,7 @@ import os
 
 import pytest
 
-from vllm import LLM, PoolingParams, SamplingParams
+from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
 
 from ..utils import fork_new_process_for_each_test
@@ -36,9 +36,8 @@ def test_oot_registration_text_generation(dummy_opt_path):
 def test_oot_registration_embedding(dummy_gemma2_embedding_path):
     os.environ["VLLM_PLUGINS"] = "register_dummy_model"
     prompts = ["Hello, my name is", "The text does not matter"]
-    sampling_params = PoolingParams()
     llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
-    outputs = llm.encode(prompts, sampling_params)
+    outputs = llm.embed(prompts)
 
     for output in outputs:
         assert all(v == 0 for v in output.outputs.embedding)
@@ -7,8 +7,11 @@ from vllm.entrypoints.llm import LLM
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import PromptType, TextPrompt, TokensPrompt
 from vllm.model_executor.models import ModelRegistry
-from vllm.outputs import (CompletionOutput, PoolingOutput,
-                          PoolingRequestOutput, RequestOutput)
+from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput,
+                          CompletionOutput, EmbeddingOutput,
+                          EmbeddingRequestOutput, PoolingOutput,
+                          PoolingRequestOutput, RequestOutput, ScoringOutput,
+                          ScoringRequestOutput)
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 
@@ -27,6 +30,12 @@ __all__ = [
     "CompletionOutput",
     "PoolingOutput",
     "PoolingRequestOutput",
+    "EmbeddingOutput",
+    "EmbeddingRequestOutput",
+    "ClassificationOutput",
+    "ClassificationRequestOutput",
+    "ScoringOutput",
+    "ScoringRequestOutput",
     "LLMEngine",
     "EngineArgs",
     "AsyncLLMEngine",
@@ -34,26 +43,3 @@ __all__ = [
     "initialize_ray_cluster",
     "PoolingParams",
 ]
-
-
-def __getattr__(name: str):
-    import warnings
-
-    if name == "EmbeddingOutput":
-        msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
-               "The original name will be removed in an upcoming version.")
-
-        warnings.warn(DeprecationWarning(msg), stacklevel=2)
-
-        return PoolingOutput
-
-    if name == "EmbeddingRequestOutput":
-        msg = ("EmbeddingRequestOutput has been renamed to "
-               "PoolingRequestOutput. "
-               "The original name will be removed in an upcoming version.")
-
-        warnings.warn(DeprecationWarning(msg), stacklevel=2)
-
-        return PoolingRequestOutput
-
-    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
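The removed ``__getattr__`` shim had aliased ``EmbeddingOutput`` and ``EmbeddingRequestOutput`` to the pooling classes with a ``DeprecationWarning``; after this change those names bind to dedicated classes exported directly from ``vllm.outputs``. A quick sanity check (a sketch; the subclass relationship is inferred from the ``from_base`` conversions elsewhere in this diff):

    from vllm import EmbeddingRequestOutput, PoolingRequestOutput

    # No DeprecationWarning is raised anymore: the name resolves to a real
    # class that specializes PoolingRequestOutput (assumed subclassing).
    assert issubclass(EmbeddingRequestOutput, PoolingRequestOutput)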
@@ -46,11 +46,10 @@ from vllm.outputs import (PoolingRequestOutput, RequestOutput,
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import RequestOutputKind, SamplingParams
-from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
-                           ParallelSampleSequenceGroup, Sequence,
-                           SequenceGroup, SequenceGroupBase,
-                           SequenceGroupMetadata, SequenceGroupOutput,
-                           SequenceStatus)
+from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
+                           PoolingSequenceGroupOutput, Sequence, SequenceGroup,
+                           SequenceGroupBase, SequenceGroupMetadata,
+                           SequenceGroupOutput, SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
@@ -966,9 +965,9 @@ class LLMEngine:
     @staticmethod
     def _process_sequence_group_outputs(
         seq_group: SequenceGroup,
-        outputs: List[EmbeddingSequenceGroupOutput],
+        outputs: List[PoolingSequenceGroupOutput],
     ) -> None:
-        seq_group.embeddings = outputs[0].embeddings
+        seq_group.pooled_data = outputs[0].data
 
         for seq in seq_group.get_seqs():
             seq.status = SequenceStatus.FINISHED_STOPPED
@@ -1784,8 +1783,8 @@ class LLMEngine:
                 num_prompt_tokens_iter)
         # Spec decode, if enabled, emits specialized metrics from the worker in
         # sampler output.
-        if model_output and (model_output[0].spec_decode_worker_metrics
-                             is not None):
+        if model_output and isinstance(model_output[0], SamplerOutput) and (
+                model_output[0].spec_decode_worker_metrics is not None):
             spec_decode_metrics = model_output[0].spec_decode_worker_metrics
         else:
             spec_decode_metrics = None
@@ -26,7 +26,9 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding.guided_fields import (
     GuidedDecodingRequest, LLMGuidedOptions)
-from vllm.outputs import PoolingRequestOutput, RequestOutput
+from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
+                          PoolingRequestOutput, RequestOutput,
+                          ScoringRequestOutput)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
@@ -120,7 +122,7 @@ class LLM:
         serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
     """
 
-    DEPRECATE_LEGACY: ClassVar[bool] = False
+    DEPRECATE_LEGACY: ClassVar[bool] = True
     """A flag to toggle whether to deprecate the legacy generate/encode API."""
 
     DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
|
|||||||
self,
|
self,
|
||||||
prompts: Union[PromptType, Sequence[PromptType]],
|
prompts: Union[PromptType, Sequence[PromptType]],
|
||||||
/,
|
/,
|
||||||
*,
|
|
||||||
sampling_params: Optional[Union[SamplingParams,
|
sampling_params: Optional[Union[SamplingParams,
|
||||||
Sequence[SamplingParams]]] = None,
|
Sequence[SamplingParams]]] = None,
|
||||||
|
*,
|
||||||
use_tqdm: bool = True,
|
use_tqdm: bool = True,
|
||||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||||
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
|
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||||
|
GuidedDecodingRequest]] = None,
|
||||||
) -> List[RequestOutput]:
|
) -> List[RequestOutput]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -275,6 +280,9 @@ class LLM:
|
|||||||
prompt_token_ids: Optional[List[int]] = None,
|
prompt_token_ids: Optional[List[int]] = None,
|
||||||
use_tqdm: bool = True,
|
use_tqdm: bool = True,
|
||||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||||
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
|
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||||
|
GuidedDecodingRequest]] = None,
|
||||||
) -> List[RequestOutput]:
|
) -> List[RequestOutput]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -288,6 +296,9 @@ class LLM:
|
|||||||
prompt_token_ids: Optional[List[List[int]]] = None,
|
prompt_token_ids: Optional[List[List[int]]] = None,
|
||||||
use_tqdm: bool = True,
|
use_tqdm: bool = True,
|
||||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||||
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
|
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||||
|
GuidedDecodingRequest]] = None,
|
||||||
) -> List[RequestOutput]:
|
) -> List[RequestOutput]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -302,6 +313,9 @@ class LLM:
|
|||||||
prompt_token_ids: List[int],
|
prompt_token_ids: List[int],
|
||||||
use_tqdm: bool = True,
|
use_tqdm: bool = True,
|
||||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||||
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
|
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||||
|
GuidedDecodingRequest]] = None,
|
||||||
) -> List[RequestOutput]:
|
) -> List[RequestOutput]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -316,6 +330,9 @@ class LLM:
|
|||||||
prompt_token_ids: List[List[int]],
|
prompt_token_ids: List[List[int]],
|
||||||
use_tqdm: bool = True,
|
use_tqdm: bool = True,
|
||||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||||
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
|
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||||
|
GuidedDecodingRequest]] = None,
|
||||||
) -> List[RequestOutput]:
|
) -> List[RequestOutput]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -328,6 +345,9 @@ class LLM:
|
|||||||
prompt_token_ids: Union[List[int], List[List[int]]],
|
prompt_token_ids: Union[List[int], List[List[int]]],
|
||||||
use_tqdm: bool = True,
|
use_tqdm: bool = True,
|
||||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||||
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
|
guided_options_request: Optional[Union[LLMGuidedOptions,
|
||||||
|
GuidedDecodingRequest]] = None,
|
||||||
) -> List[RequestOutput]:
|
) -> List[RequestOutput]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@@ -678,11 +698,12 @@ class LLM:
         self,
         prompts: Union[PromptType, Sequence[PromptType]],
         /,
-        *,
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
+        *,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> List[PoolingRequestOutput]:
         ...
 
@@ -696,6 +717,7 @@ class LLM:
         prompt_token_ids: Optional[List[int]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> List[PoolingRequestOutput]:
         ...
 
@@ -709,6 +731,7 @@ class LLM:
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> List[PoolingRequestOutput]:
         ...
 
@@ -723,6 +746,7 @@ class LLM:
         prompt_token_ids: List[int],
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> List[PoolingRequestOutput]:
         ...
 
@@ -737,6 +761,7 @@ class LLM:
         prompt_token_ids: List[List[int]],
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> List[PoolingRequestOutput]:
         ...
 
@@ -749,6 +774,7 @@ class LLM:
         prompt_token_ids: Union[List[int], List[List[int]]],
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> List[PoolingRequestOutput]:
         ...
@@ -768,7 +794,8 @@ class LLM:
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> List[PoolingRequestOutput]:
-        """Generates the completions for the input prompts.
+        """Apply pooling to the hidden states corresponding to the input
+        prompts.
 
         This class automatically batches the given prompts, considering
         the memory constraint. For the best performance, put all of your prompts
@@ -787,7 +814,7 @@ class LLM:
 
         Returns:
             A list of ``PoolingRequestOutput`` objects containing the
-            generated embeddings in the same order as the input prompts.
+            pooled hidden states in the same order as the input prompts.
 
         Note:
             Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
|
|||||||
return self.engine_class.validate_outputs(outputs,
|
return self.engine_class.validate_outputs(outputs,
|
||||||
PoolingRequestOutput)
|
PoolingRequestOutput)
|
||||||
|
|
||||||
|
def embed(
|
||||||
|
self,
|
||||||
|
prompts: Union[PromptType, Sequence[PromptType]],
|
||||||
|
/,
|
||||||
|
*,
|
||||||
|
use_tqdm: bool = True,
|
||||||
|
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||||
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
|
) -> List[EmbeddingRequestOutput]:
|
||||||
|
"""
|
||||||
|
Generate an embedding vector for each prompt.
|
||||||
|
|
||||||
|
This class automatically batches the given prompts, considering
|
||||||
|
the memory constraint. For the best performance, put all of your prompts
|
||||||
|
into a single list and pass it to this method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompts: The prompts to the LLM. You may pass a sequence of prompts
|
||||||
|
for batch inference. See :class:`~vllm.inputs.PromptType`
|
||||||
|
for more details about the format of each prompts.
|
||||||
|
use_tqdm: Whether to use tqdm to display the progress bar.
|
||||||
|
lora_request: LoRA request to use for generation, if any.
|
||||||
|
prompt_adapter_request: Prompt Adapter request to use for
|
||||||
|
generation, if any.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of ``EmbeddingRequestOutput`` objects containing the
|
||||||
|
embedding vectors in the same order as the input prompts.
|
||||||
|
"""
|
||||||
|
if self.llm_engine.model_config.task != "embed":
|
||||||
|
raise ValueError(
|
||||||
|
"Embedding API is only enabled for `--task embed`")
|
||||||
|
|
||||||
|
items = self.encode(prompts,
|
||||||
|
use_tqdm=use_tqdm,
|
||||||
|
lora_request=lora_request,
|
||||||
|
prompt_adapter_request=prompt_adapter_request)
|
||||||
|
|
||||||
|
return [EmbeddingRequestOutput.from_base(item) for item in items]
|
||||||
|
|
||||||
|
def classify(
|
||||||
|
self,
|
||||||
|
prompts: Union[PromptType, Sequence[PromptType]],
|
||||||
|
/,
|
||||||
|
*,
|
||||||
|
use_tqdm: bool = True,
|
||||||
|
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||||
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
|
) -> List[ClassificationRequestOutput]:
|
||||||
|
"""
|
||||||
|
Generate class logits for each prompt.
|
||||||
|
|
||||||
|
This class automatically batches the given prompts, considering
|
||||||
|
the memory constraint. For the best performance, put all of your prompts
|
||||||
|
into a single list and pass it to this method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompts: The prompts to the LLM. You may pass a sequence of prompts
|
||||||
|
for batch inference. See :class:`~vllm.inputs.PromptType`
|
||||||
|
for more details about the format of each prompts.
|
||||||
|
use_tqdm: Whether to use tqdm to display the progress bar.
|
||||||
|
lora_request: LoRA request to use for generation, if any.
|
||||||
|
prompt_adapter_request: Prompt Adapter request to use for
|
||||||
|
generation, if any.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of ``ClassificationRequestOutput`` objects containing the
|
||||||
|
embedding vectors in the same order as the input prompts.
|
||||||
|
"""
|
||||||
|
if self.llm_engine.model_config.task != "classify":
|
||||||
|
raise ValueError(
|
||||||
|
"Classification API is only enabled for `--task classify`")
|
||||||
|
|
||||||
|
items = self.encode(prompts,
|
||||||
|
use_tqdm=use_tqdm,
|
||||||
|
lora_request=lora_request,
|
||||||
|
prompt_adapter_request=prompt_adapter_request)
|
||||||
|
|
||||||
|
return [ClassificationRequestOutput.from_base(item) for item in items]
|
||||||
|
|
||||||
def score(
|
def score(
|
||||||
self,
|
self,
|
||||||
text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
|
text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
|
||||||
text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
|
text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
|
||||||
/,
|
/,
|
||||||
|
*,
|
||||||
truncate_prompt_tokens: Optional[int] = None,
|
truncate_prompt_tokens: Optional[int] = None,
|
||||||
use_tqdm: bool = True,
|
use_tqdm: bool = True,
|
||||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
) -> List[PoolingRequestOutput]:
|
) -> List[ScoringRequestOutput]:
|
||||||
"""Generates similarity scores for all pairs <text,text_pair>.
|
"""Generate similarity scores for all pairs ``<text,text_pair>``.
|
||||||
|
|
||||||
The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
|
The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
|
||||||
the text_1 sentence will be replicated N times to pair with the text_2
|
In the ``1 - N`` case the ``text_1`` sentence will be replicated ``N``
|
||||||
sentences. The input pairs are used to build a list of prompts for the
|
times to pair with the ``text_2`` sentences.
|
||||||
|
The input pairs are used to build a list of prompts for the
|
||||||
cross encoder model. This class automatically batches the prompts,
|
cross encoder model. This class automatically batches the prompts,
|
||||||
considering the memory constraint. For the best performance, put all
|
considering the memory constraint. For the best performance, put all
|
||||||
of your texts into a single list and pass it to this method.
|
of your texts into a single list and pass it to this method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text_1: can be a single prompt or a list of prompts, in which
|
text_1: can be a single prompt or a list of prompts, in which
|
||||||
case it has to have the same length as the text_2 list
|
case it has to have the same length as the ``text_2`` list
|
||||||
text_2: The texts to pair with the query to form the input
|
text_2: The texts to pair with the query to form the input
|
||||||
to the LLM. See :class:`~vllm.inputs.PromptType` for
|
to the LLM. See :class:`~vllm.inputs.PromptType` for
|
||||||
more details about the format of each prompts.
|
more details about the format of each prompts.
|
||||||
@ -864,7 +973,7 @@ class LLM:
|
|||||||
generation, if any.
|
generation, if any.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of ``PoolingRequestOutput`` objects containing the
|
A list of ``ScoringRequestOutput`` objects containing the
|
||||||
generated scores in the same order as the input prompts.
|
generated scores in the same order as the input prompts.
|
||||||
"""
|
"""
|
||||||
runner_type = self.llm_engine.model_config.runner_type
|
runner_type = self.llm_engine.model_config.runner_type
|
||||||
@ -884,6 +993,8 @@ class LLM:
|
|||||||
|
|
||||||
if not self.llm_engine.model_config.is_cross_encoder:
|
if not self.llm_engine.model_config.is_cross_encoder:
|
||||||
raise ValueError("Your model does not support cross encoding")
|
raise ValueError("Your model does not support cross encoding")
|
||||||
|
if self.llm_engine.model_config.task != "score":
|
||||||
|
raise ValueError("Score API is only enabled for `--task score`")
|
||||||
|
|
||||||
tokenizer = self.llm_engine.get_tokenizer()
|
tokenizer = self.llm_engine.get_tokenizer()
|
||||||
|
|
||||||
@ -954,8 +1065,10 @@ class LLM:
|
|||||||
)
|
)
|
||||||
|
|
||||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||||
return self.engine_class.validate_outputs(outputs,
|
items = self.engine_class.validate_outputs(outputs,
|
||||||
PoolingRequestOutput)
|
PoolingRequestOutput)
|
||||||
|
|
||||||
|
return [ScoringRequestOutput.from_base(item) for item in items]
|
||||||
|
|
||||||
def start_profile(self) -> None:
|
def start_profile(self) -> None:
|
||||||
self.llm_engine.start_profile()
|
self.llm_engine.start_profile()
|
||||||
|
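The ``1 -> N`` broadcasting described in the ``score`` docstring means a single query can be paired with several candidates in one call; a minimal sketch reusing the cross-encoder model from the examples above:

    from vllm import LLM

    llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")

    # text_1 (one query) is replicated to pair with each of the N candidates.
    outputs = llm.score(
        "What is the capital of France?",
        ["The capital of Brazil is Brasilia.",
         "The capital of France is Paris."],
    )
    assert len(outputs) == 2  # one ScoringRequestOutput per <query, candidate> pair
    for output in outputs:
        print(output.outputs.score)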
@@ -900,7 +900,7 @@ class EmbeddingResponse(OpenAIBaseModel):
 class ScoreResponseData(OpenAIBaseModel):
     index: int
     object: str = "score"
-    score: Union[List[float], str]
+    score: float
 
 
 class ScoreResponse(OpenAIBaseModel):
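With ``ScoreResponseData.score`` now typed as a plain ``float``, each item of a Score API response carries a single scalar instead of a one-element list, which is what the updated server tests below assert. An illustrative item (field values are made up; field names follow ``ScoreResponseData`` above):

    # Shape of one Score API response item after this change:
    item = {"index": 0, "object": "score", "score": 0.002}  # previously: "score": [0.002]
    assert isinstance(item["score"], float)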
@@ -18,14 +18,15 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                               ErrorResponse, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
 from vllm.logger import init_logger
-from vllm.outputs import PoolingOutput, PoolingRequestOutput
+from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
+                          PoolingRequestOutput)
 from vllm.utils import merge_async_iterators
 
 logger = init_logger(__name__)
 
 
 def _get_embedding(
-    output: PoolingOutput,
+    output: EmbeddingOutput,
     encoding_format: Literal["float", "base64"],
 ) -> Union[List[float], str]:
     if encoding_format == "float":
@@ -46,8 +47,10 @@ def request_output_to_embedding_response(
     data: List[EmbeddingResponseData] = []
     num_prompt_tokens = 0
     for idx, final_res in enumerate(final_res_batch):
+        embedding_res = EmbeddingRequestOutput.from_base(final_res)
         prompt_token_ids = final_res.prompt_token_ids
-        embedding = _get_embedding(final_res.outputs, encoding_format)
+
+        embedding = _get_embedding(embedding_res.outputs, encoding_format)
         embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
         data.append(embedding_data)
 
@@ -31,7 +31,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ErrorResponse,
                                               LoadLoraAdapterRequest,
                                               ModelCard, ModelList,
-                                              ModelPermission,
+                                              ModelPermission, ScoreRequest,
                                               TokenizeChatRequest,
                                               TokenizeCompletionRequest,
                                               UnloadLoraAdapterRequest)
@@ -73,7 +73,7 @@ class LoRAModulePath:
 
 
 CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
-                              EmbeddingCompletionRequest,
+                              EmbeddingCompletionRequest, ScoreRequest,
                               TokenizeCompletionRequest]
 
 ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
@@ -567,12 +567,14 @@ class OpenAIServing:
         return None
 
     @staticmethod
-    def _base_request_id(raw_request: Request,
+    def _base_request_id(raw_request: Optional[Request],
                          default: Optional[str] = None) -> Optional[str]:
         """Pulls the request id to use from a header, if provided"""
         default = default or random_uuid()
-        return raw_request.headers.get(
-            "X-Request-Id", default) if raw_request is not None else default
+        if raw_request is None:
+            return default
+
+        return raw_request.headers.get("X-Request-Id", default)
 
     @staticmethod
     def _get_decoded_token(logprob: Logprob,
@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
 from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
 from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
-from vllm.outputs import PoolingRequestOutput
+from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.utils import make_async, merge_async_iterators
 
@@ -24,13 +24,13 @@ def request_output_to_score_response(
     final_res_batch: List[PoolingRequestOutput], request_id: str,
     created_time: int, model_name: str) -> ScoreResponse:
     data: List[ScoreResponseData] = []
-    score = None
     num_prompt_tokens = 0
     for idx, final_res in enumerate(final_res_batch):
-        if final_res is not None:
-            score = final_res.outputs.embedding
-            score_data = ScoreResponseData(index=idx, score=score)
-            data.append(score_data)
+        classify_res = ScoringRequestOutput.from_base(final_res)
+
+        score_data = ScoreResponseData(index=idx,
+                                       score=classify_res.outputs.score)
+        data.append(score_data)
 
     usage = UsageInfo(
         prompt_tokens=num_prompt_tokens,
@ -1,14 +1,16 @@
|
|||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
from vllm.config import PoolerConfig
|
from vllm.config import PoolerConfig
|
||||||
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
|
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
|
||||||
PoolingTensors)
|
PoolingTensors)
|
||||||
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
|
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
||||||
from vllm.transformers_utils.config import (
|
from vllm.transformers_utils.config import (
|
||||||
get_cross_encoder_activation_function)
|
get_cross_encoder_activation_function)
|
||||||
|
|
||||||
@ -22,7 +24,7 @@ class PoolingType(IntEnum):
|
|||||||
MEAN = 4
|
MEAN = 4
|
||||||
|
|
||||||
|
|
||||||
class Pooler(nn.Module):
|
class SimplePooler(nn.Module):
|
||||||
"""A layer that pools specific information from hidden states.
|
"""A layer that pools specific information from hidden states.
|
||||||
|
|
||||||
This layer does the following:
|
This layer does the following:
|
||||||
@ -35,22 +37,204 @@ class Pooler(nn.Module):
|
|||||||
normalize: Whether to normalize the pooled data.
|
normalize: Whether to normalize the pooled data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pooling_type(
|
||||||
|
pooling_type: PoolingType,
|
||||||
|
*,
|
||||||
|
normalize: bool,
|
||||||
|
softmax: bool,
|
||||||
|
step_tag_id: Optional[int] = None,
|
||||||
|
returned_token_ids: Optional[List[int]] = None,
|
||||||
|
) -> "SimplePooler":
|
||||||
|
if pooling_type == PoolingType.LAST:
|
||||||
|
assert step_tag_id is None and returned_token_ids is None
|
||||||
|
return LastPool(normalize=normalize, softmax=softmax)
|
||||||
|
if pooling_type == PoolingType.ALL:
|
||||||
|
assert step_tag_id is None and returned_token_ids is None
|
||||||
|
return AllPool(normalize=normalize, softmax=softmax)
|
||||||
|
if pooling_type == PoolingType.CLS:
|
||||||
|
assert step_tag_id is None and returned_token_ids is None
|
||||||
|
return CLSPool(normalize=normalize, softmax=softmax)
|
||||||
|
if pooling_type == PoolingType.MEAN:
|
||||||
|
assert step_tag_id is None and returned_token_ids is None
|
||||||
|
return MeanPool(normalize=normalize, softmax=softmax)
|
||||||
|
if pooling_type == PoolingType.STEP:
|
||||||
|
return StepPool(normalize=normalize,
|
||||||
|
softmax=softmax,
|
||||||
|
step_tag_id=step_tag_id,
|
||||||
|
returned_token_ids=returned_token_ids)
|
||||||
|
|
||||||
|
assert_never(pooling_type)
|
||||||
|
|
||||||
|
def __init__(self, *, normalize: bool, softmax: bool) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.head = PoolerHead(normalize=normalize, softmax=softmax)
|
||||||
|
|
||||||
|
def get_prompt_lens(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return PoolingTensors.from_pooling_metadata(
|
||||||
|
pooling_metadata, hidden_states.device).prompt_lens
|
||||||
|
|
||||||
|
def extract_states(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def build_output(self, data: torch.Tensor) -> PoolingSequenceGroupOutput:
|
||||||
|
return PoolingSequenceGroupOutput(data)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> PoolerOutput:
|
||||||
|
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
||||||
|
pooled_data = self.head(pooled_data)
|
||||||
|
pooled_outputs = [self.build_output(data) for data in pooled_data]
|
||||||
|
return PoolerOutput(outputs=pooled_outputs)
|
||||||
|
|
||||||
|
|
||||||
|
class CLSPool(SimplePooler):
|
||||||
|
|
||||||
|
def extract_states(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||||
|
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
||||||
|
|
||||||
|
first_token_flat_indices = torch.zeros_like(prompt_lens)
|
||||||
|
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
|
||||||
|
return hidden_states[first_token_flat_indices]
|
||||||
|
|
||||||
|
|
||||||
|
class LastPool(SimplePooler):
|
||||||
|
|
||||||
|
def extract_states(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||||
|
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
||||||
|
|
||||||
|
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
|
||||||
|
return hidden_states[last_token_flat_indices]
|
||||||
|
|
||||||
|
|
||||||
|
class AllPool(SimplePooler):
|
||||||
|
|
||||||
|
def extract_states(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||||
|
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
pooled_data = list[torch.Tensor]()
|
||||||
|
for prompt_len in prompt_lens:
|
||||||
|
pooled_data.append(hidden_states[offset:offset + prompt_len])
|
||||||
|
offset += prompt_len
|
||||||
|
|
||||||
|
return pooled_data
|
||||||
|
|
||||||
|
|
||||||
|
class MeanPool(SimplePooler):
|
||||||
|
|
||||||
|
def extract_states(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||||
|
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
||||||
|
|
||||||
|
cumsum = torch.cumsum(hidden_states, dim=0)
|
||||||
|
start_indices = torch.cat([
|
||||||
|
torch.tensor([0], device=hidden_states.device),
|
||||||
|
torch.cumsum(prompt_lens[:-1], dim=0)
|
||||||
|
])
|
||||||
|
end_indices = torch.cumsum(prompt_lens, dim=0)
|
||||||
|
return (cumsum[end_indices - 1] - cumsum[start_indices] +
|
||||||
|
hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
class StepPool(SimplePooler):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
pooling_type: PoolingType,
|
*,
|
||||||
normalize: bool,
|
normalize: bool,
|
||||||
softmax: bool,
|
softmax: bool,
|
||||||
step_tag_id: Optional[int] = None,
|
step_tag_id: Optional[int] = None,
|
||||||
returned_token_ids: Optional[List[int]] = None,
|
returned_token_ids: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__(normalize=normalize, softmax=softmax)
|
||||||
|
|
||||||
self.pooling_type = pooling_type
|
|
||||||
self.normalize = normalize
|
|
||||||
self.softmax = softmax
|
|
||||||
self.step_tag_id = step_tag_id
|
self.step_tag_id = step_tag_id
|
||||||
self.returned_token_ids = returned_token_ids
|
self.returned_token_ids = returned_token_ids
|
||||||
|
|
||||||
|
def extract_states(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||||
|
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
||||||
|
|
||||||
|
returned_token_ids = self.returned_token_ids
|
||||||
|
if returned_token_ids is not None and len(returned_token_ids) > 0:
|
||||||
|
hidden_states = hidden_states[:, returned_token_ids]
|
||||||
|
|
||||||
|
step_tag_id = self.step_tag_id
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
pooled_data = list[torch.Tensor]()
|
||||||
|
for prompt_len, seq_data_i in zip(prompt_lens,
|
||||||
|
pooling_metadata.seq_data.values()):
|
||||||
|
pooled_data_i = hidden_states[offset:offset + prompt_len]
|
||||||
|
if step_tag_id is not None:
|
||||||
|
token_ids = torch.tensor(seq_data_i.prompt_token_ids)
|
||||||
|
pooled_data_i = pooled_data_i[token_ids == step_tag_id]
|
||||||
|
|
||||||
|
offset += prompt_len
|
||||||
|
pooled_data.append(pooled_data_i)
|
||||||
|
|
||||||
|
return pooled_data
|
||||||
|
|
||||||
|
|
||||||
+
+
+class PoolerHead(nn.Module):
+
+    def __init__(self, *, normalize: bool, softmax: bool) -> None:
+        super().__init__()
+
+        self.normalize = normalize
+        self.softmax = softmax
+
+    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor]):
+        if self.normalize:
+            if isinstance(pooled_data, list):
+                pooled_data = [
+                    F.normalize(data, p=2, dim=1) for data in pooled_data
+                ]
+            else:
+                pooled_data = F.normalize(pooled_data, p=2, dim=1)
+
+        if self.softmax:
+            if isinstance(pooled_data, list):
+                pooled_data = [F.softmax(data, dim=-1) for data in pooled_data]
+            else:
+                pooled_data = F.softmax(pooled_data, dim=-1)
+
+        return pooled_data
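``PoolerHead`` factors the normalize/softmax post-processing out of the pooling step itself, so custom poolers such as ``GritLMPooler`` (further down in this diff) can reuse it. A minimal usage sketch, assuming a build containing this commit:

.. code-block:: python

    import torch

    from vllm.model_executor.layers.pooler import PoolerHead

    # L2-normalization head, the configuration embedding models use.
    head = PoolerHead(normalize=True, softmax=False)
    pooled = head(torch.randn(2, 8))
    print(pooled.norm(dim=1))  # every row now has (approximately) unit norm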
+
+
 class Pooler(nn.Module):

     @classmethod
     def from_config_with_defaults(
         cls,
@@ -60,8 +244,8 @@ class Pooler(nn.Module):
         softmax: bool,
         step_tag_id: Optional[int] = None,
         returned_token_ids: Optional[List[int]] = None,
-    ) -> "Pooler":
-        return cls(
+    ) -> SimplePooler:
+        return SimplePooler.from_pooling_type(
             pooling_type=PoolingType[pooler_config.pooling_type]
             if pooler_config.pooling_type is not None else pooling_type,
             normalize=pooler_config.normalize
@@ -75,85 +259,6 @@ class Pooler(nn.Module):
             returned_token_ids,
         )
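``from_config_with_defaults`` now hands back a concrete ``SimplePooler`` subclass rather than the monolithic ``Pooler``. A sketch of the resulting call, assuming ``SimplePooler.from_pooling_type`` (defined earlier in this diff, not shown here) accepts the same keyword arguments the factory forwards:

.. code-block:: python

    from vllm.model_executor.layers.pooler import PoolingType, SimplePooler

    # Assumed signature: resolves the enum to a subclass, here a MeanPool
    # wired with an L2-normalizing head.
    pooler = SimplePooler.from_pooling_type(
        pooling_type=PoolingType.MEAN,
        normalize=True,
        softmax=False,
    )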
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> PoolerOutput:
-        """Pools specific information from hidden states based on metadata."""
-
-        prompt_lens = PoolingTensors.from_pooling_metadata(
-            pooling_metadata, hidden_states.device).prompt_lens
-
-        if self.pooling_type is PoolingType.CLS:
-            first_token_flat_indices = torch.zeros_like(prompt_lens)
-            first_token_flat_indices[1:] += torch.cumsum(prompt_lens,
-                                                         dim=0)[:-1]
-            pooled_data = hidden_states[first_token_flat_indices]
-        elif self.pooling_type == PoolingType.LAST:
-            last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
-            pooled_data = hidden_states[last_token_flat_indices]
-        elif self.pooling_type == PoolingType.ALL:
-            offset = 0
-            pooled_data = []
-            for prompt_len in prompt_lens:
-                pooled_data.append(hidden_states[offset:offset + prompt_len])
-                offset += prompt_len
-        elif self.pooling_type == PoolingType.MEAN:
-            # Calculate mean pooling
-            cumsum = torch.cumsum(hidden_states, dim=0)
-            start_indices = torch.cat([
-                torch.tensor([0], device=hidden_states.device),
-                torch.cumsum(prompt_lens[:-1], dim=0)
-            ])
-            end_indices = torch.cumsum(prompt_lens, dim=0)
-            pooled_data = (
-                cumsum[end_indices - 1] - cumsum[start_indices] +
-                hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
-        elif self.pooling_type == PoolingType.STEP:
-            returned_token_ids = self.returned_token_ids
-            if returned_token_ids is not None and len(returned_token_ids) > 0:
-                hidden_states = hidden_states[:, returned_token_ids]
-
-            step_tag_id = self.step_tag_id
-
-            offset = 0
-            pooled_data = []
-            for prompt_len, seq_data_i in zip(
-                    prompt_lens, pooling_metadata.seq_data.values()):
-                pooled_data_i = hidden_states[offset:offset + prompt_len]
-                if step_tag_id is not None:
-                    token_ids = torch.tensor(seq_data_i.prompt_token_ids)
-                    pooled_data_i = pooled_data_i[token_ids == step_tag_id]
-
-                offset += prompt_len
-                pooled_data.append(pooled_data_i)
-        else:
-            raise ValueError(f"Invalid pooling type: {self.pooling_type}")
-
-        if self.normalize:
-            if isinstance(pooled_data, list):
-                pooled_data = [
-                    nn.functional.normalize(data, p=2, dim=1)
-                    for data in pooled_data
-                ]
-            else:
-                pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
-
-        if self.softmax:
-            if isinstance(pooled_data, list):
-                pooled_data = [
-                    nn.functional.softmax(data, dim=-1) for data in pooled_data
-                ]
-            else:
-                pooled_data = nn.functional.softmax(pooled_data, dim=-1)
-
-        pooled_outputs = [
-            EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
-        ]
-
-        return PoolerOutput(outputs=pooled_outputs)


 class CrossEncodingPooler(nn.Module):
     """A layer that pools specific information from hidden states.
@@ -208,9 +313,8 @@ class CrossEncodingPooler(nn.Module):
         if self.pooler is not None:
             # apply classifier once on the full batch if possible
             pooled_output = self.classifier(pooled_output)
-        logits = self.default_activation_function(pooled_output)
-
-        pooled_outputs = [
-            EmbeddingSequenceGroupOutput(data.tolist()) for data in logits
-        ]
+        scores = self.default_activation_function(pooled_output).squeeze(-1)
+
+        pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
         return PoolerOutput(outputs=pooled_outputs)
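The cross-encoder change replaces per-sequence logit lists with one scalar per sequence: ``.squeeze(-1)`` collapses the trailing classifier dimension so each ``PoolingSequenceGroupOutput`` carries a plain score. A small tensor sketch (the sigmoid here stands in for ``default_activation_function``, which this diff does not show):

.. code-block:: python

    import torch

    logits = torch.tensor([[0.2], [1.7], [-0.3]])  # [batch, 1] classifier output
    scores = torch.sigmoid(logits).squeeze(-1)     # -> shape [3], one scalar each
    print(scores.shape)  # torch.Size([3])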
vllm/model_executor/models/gritlm.py
@@ -2,19 +2,20 @@ from array import array
 from typing import List, Optional, Union

 import torch
-from torch import nn
+import torch.nn as nn
 from xformers.ops.fmha.attn_bias import BlockDiagonalMask

 from vllm.attention import AttentionMetadata
 from vllm.attention.backends.xformers import XFormersImpl
 from vllm.config import ModelConfig, VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.pooler import PoolerHead
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                   PoolingTensors)
 from vllm.multimodal.utils import cached_get_tokenizer
-from vllm.sequence import (EmbeddingSequenceGroupOutput, IntermediateTensors,
-                           PoolerOutput)
+from vllm.sequence import (IntermediateTensors, PoolerOutput,
+                           PoolingSequenceGroupOutput)

 logger = init_logger(__name__)

@@ -52,6 +53,8 @@ class GritLMPooler(nn.Module):
         self.embed_pattern_ids = tokens_to_ids(
             ["▁<", "|", "embed", "|", ">", "<0x0A>"])

+        self.head = PoolerHead(normalize=True, softmax=False)
+
     def _find_array(self, arr: array, target: array, start_idx: int) -> int:
         """
         Find the first occurrence of target in arr starting from start_idx.
@@ -75,7 +78,7 @@ class GritLMPooler(nn.Module):
             return i
         return -1

-    def _get_instruction_len(self, prompt_token_ids: array) -> bool:
+    def _get_instruction_len(self, prompt_token_ids: array) -> int:
         """
         Get the length of the instruction in the prompt.

@@ -168,10 +171,10 @@ class GritLMPooler(nn.Module):
         mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
             1)

-        pooled_data = nn.functional.normalize(mean_embeddings, p=2, dim=1)
+        pooled_data = self.head(mean_embeddings)

         pooled_outputs = [
-            EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
+            PoolingSequenceGroupOutput(data) for data in pooled_data
         ]

         return PoolerOutput(outputs=pooled_outputs)
vllm/outputs.py
@@ -1,9 +1,13 @@
 import time
+import warnings
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import Dict, Generic, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Union

+import torch
+from typing_extensions import TypeVar
+
 from vllm.lora.request import LoRARequest
 from vllm.multimodal.inputs import MultiModalPlaceholderDict
 from vllm.sampling_params import RequestOutputKind
@@ -57,14 +61,26 @@ class PoolingOutput:
     """The output data of one pooling output of a request.

     Args:
-        embedding: The embedding vector, which is a list of floats. The
-        length of vector depends on the model as listed in the embedding guide.
+        data: The extracted hidden states.
     """
-    embedding: List[float]
+    data: torch.Tensor

     def __repr__(self) -> str:
-        return (f"PoolingOutput("
-                f"embedding={len(self.embedding)})")
+        return (f"PoolingOutput(data={self.data})")
+
+    def __eq__(self, other: object) -> bool:
+        return (isinstance(other, self.__class__) and bool(
+            (self.data == other.data).all()))
+
+    @property
+    def embedding(self) -> list[float]:
+        msg = ("`LLM.encode()` now returns raw outputs. "
+               "To return embeddings, use `LLM.embed()`. "
+               "To return class probabilities, use `LLM.classify()` "
+               "and access the `probs` attribute. ")
+        warnings.warn(msg, DeprecationWarning, stacklevel=2)
+
+        return self.data.tolist()


 class RequestOutput:
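``PoolingOutput`` now stores the raw tensor in ``data``; the old ``embedding`` attribute survives as a deprecated property. A sketch of both access paths, assuming a build containing this commit:

.. code-block:: python

    import warnings

    import torch

    from vllm.outputs import PoolingOutput

    output = PoolingOutput(data=torch.ones(4))
    print(output.data)  # the raw tensor, the new access path

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        legacy = output.embedding  # still works, but warns
    assert caught[0].category is DeprecationWarning
    print(legacy)  # [1.0, 1.0, 1.0, 1.0]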
@@ -316,7 +332,10 @@ class RequestOutput:
             f"multi_modal_placeholders={self.multi_modal_placeholders})")


-class PoolingRequestOutput:
+_O = TypeVar("_O", default=PoolingOutput)
+
+
+class PoolingRequestOutput(Generic[_O]):
     """
     The output data of a pooling request to the LLM.

@@ -327,24 +346,24 @@ class PoolingRequestOutput:
         finished (bool): A flag indicating whether the pooling is completed.
     """

-    def __init__(self, request_id: str, outputs: "PoolingOutput",
+    def __init__(self, request_id: str, outputs: _O,
                  prompt_token_ids: List[int], finished: bool):
         self.request_id = request_id
         self.prompt_token_ids = prompt_token_ids
         self.finished = finished
         self.outputs = outputs

-    @classmethod
-    def from_seq_group(cls,
-                       seq_group: 'SequenceGroup') -> "PoolingRequestOutput":
-        if seq_group.embeddings is None:
-            raise ValueError(
-                "Embeddings are missing in seq_group for EmbeddingRequest.")
-        output = PoolingOutput(seq_group.embeddings)
+    @staticmethod
+    def from_seq_group(seq_group: SequenceGroup) -> "PoolingRequestOutput":
+        pooled_data = seq_group.pooled_data
+        assert pooled_data is not None
+
+        output = PoolingOutput(pooled_data)
         prompt_token_ids = seq_group.prompt_token_ids
         finished = seq_group.is_finished()

-        return cls(seq_group.request_id, output, prompt_token_ids, finished)
+        return PoolingRequestOutput(seq_group.request_id, output,
+                                    prompt_token_ids, finished)
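The ``TypeVar`` imported from ``typing_extensions`` supports a PEP 696 ``default``, so a bare ``PoolingRequestOutput`` annotation means ``PoolingRequestOutput[PoolingOutput]`` to type checkers, while the task-specific subclasses added below pin ``_O`` to their own output type. A self-contained sketch of the same pattern (the ``Box`` class is illustrative only):

.. code-block:: python

    from typing import Generic

    from typing_extensions import TypeVar

    _T = TypeVar("_T", default=int)


    class Box(Generic[_T]):

        def __init__(self, value: _T) -> None:
            self.value = value


    # To a PEP 696-aware checker, bare `Box` means `Box[int]`, just as bare
    # `PoolingRequestOutput` means `PoolingRequestOutput[PoolingOutput]`.
    box: Box = Box(3)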
     def __repr__(self):
         """
@@ -356,89 +375,137 @@ class PoolingRequestOutput:
         Returns:
             str: A string representation of the PoolingRequestOutput instance.
         """
-        return (f"PoolingRequestOutput(request_id='{self.request_id}', "
-                f"outputs={repr(self.outputs)}, "
+        return (f"{type(self).__name__}(request_id={self.request_id!r}, "
+                f"outputs={self.outputs!r}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
                 f"finished={self.finished})")


-@dataclass
-class ScoreOutput:
-    """The output data of one completion output of a request.
-
-    Args:
-        score: The score, which is a list of floats.
-        index: The correspondent text index of the score.
-    """
-    index: int
-    score: List[float]
-
-    def __repr__(self) -> str:
-        return (f"ScoreOutput("
-                f"score={self.score}), "
-                f"index={self.index})")
-
-
-class ScoreRequestOutput:
-    """
-    The output data of an score request to the LLM.
-
-    Args:
-        request_id (str): A unique identifier for the score request.
-        outputs (score): The embedding results for the given input.
-    """
-
-    def __init__(self, request_id: str, outputs: "ScoreOutput"):
-        self.request_id = request_id
-        self.outputs = outputs
-
-    def __repr__(self):
-        """
-        Returns a string representation of an ScoreRequestOutput instance.
-
-        The representation includes the request_id and the number of outputs,
-        providing a quick overview of the embedding request's results.
-
-        Returns:
-            str: A string representation of the ScoreRequestOutput instance.
-        """
-        return (f"ScoreRequestOutput(request_id='{self.request_id}', "
-                f"outputs={repr(self.outputs)}")

 class RequestOutputFactory:

     @staticmethod
     def create(seq_group: SequenceGroup,
                seq_id_to_seq_group: Dict[str, SequenceGroupBase],
                use_cache: bool = False):
-        # Determine the type based on a condition, for example:
-        if hasattr(seq_group,
-                   'embeddings') and seq_group.embeddings is not None:
+        if seq_group.pooled_data is not None:
             return PoolingRequestOutput.from_seq_group(seq_group)
         else:
             return RequestOutput.from_seq_group(seq_group, use_cache,
                                                 seq_id_to_seq_group)
-def __getattr__(name: str):
-    import warnings
-
-    if name == "EmbeddingOutput":
-        msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
-               "The original name will be removed in an upcoming version.")
-
-        warnings.warn(DeprecationWarning(msg), stacklevel=2)
-
-        return PoolingOutput
-
-    if name == "EmbeddingRequestOutput":
-        msg = ("EmbeddingRequestOutput has been renamed to "
-               "PoolingRequestOutput. "
-               "The original name will be removed in an upcoming version.")
-
-        warnings.warn(DeprecationWarning(msg), stacklevel=2)
-
-        return PoolingRequestOutput
-
-    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+@dataclass
+class EmbeddingOutput:
+    """The output data of one embedding output of a request.
+
+    Args:
+        embedding: The embedding vector, which is a list of floats.
+        Its length depends on the hidden dimension of the model.
+    """
+    embedding: list[float]
+
+    @staticmethod
+    def from_base(pooling_output: PoolingOutput):
+        pooled_data = pooling_output.data
+        if pooled_data.ndim != 1:
+            raise ValueError("pooled_data should be a 1-D embedding vector")
+
+        return EmbeddingOutput(pooled_data.tolist())
+
+    @property
+    def hidden_size(self) -> int:
+        return len(self.embedding)
+
+    def __repr__(self) -> str:
+        return f"EmbeddingOutput(hidden_size={self.hidden_size})"
+
+
+class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]):
+
+    @staticmethod
+    def from_base(request_output: PoolingRequestOutput):
+        return EmbeddingRequestOutput(
+            request_id=request_output.request_id,
+            outputs=EmbeddingOutput.from_base(request_output.outputs),
+            prompt_token_ids=request_output.prompt_token_ids,
+            finished=request_output.finished,
+        )
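``EmbeddingRequestOutput.from_base`` is what the new ``LLM.embed()`` API wraps each raw pooling output in. A usage sketch in the spirit of the offline_inference_embedding.py example added to CI above (the model name is illustrative; any embedding model vLLM supports works):

.. code-block:: python

    from vllm import LLM

    # Model name is an assumption for illustration.
    llm = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)

    (output, ) = llm.embed("Hello, my name is")
    print(f"hidden size: {len(output.outputs.embedding)}")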
+
+
+@dataclass
+class ClassificationOutput:
+    """The output data of one classification output of a request.
+
+    Args:
+        probs: The probability vector, which is a list of floats.
+        Its length depends on the number of classes.
+    """
+    probs: list[float]
+
+    @staticmethod
+    def from_base(pooling_output: PoolingOutput):
+        pooled_data = pooling_output.data
+        if pooled_data.ndim != 1:
+            raise ValueError("pooled_data should be a 1-D probability vector")
+
+        return ClassificationOutput(pooled_data.tolist())
+
+    @property
+    def num_classes(self) -> int:
+        return len(self.probs)
+
+    def __repr__(self) -> str:
+        return f"ClassificationOutput(num_classes={self.num_classes})"
+
+
+class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]):
+
+    @staticmethod
+    def from_base(request_output: PoolingRequestOutput):
+        return ClassificationRequestOutput(
+            request_id=request_output.request_id,
+            outputs=ClassificationOutput.from_base(request_output.outputs),
+            prompt_token_ids=request_output.prompt_token_ids,
+            finished=request_output.finished,
+        )
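``ClassificationRequestOutput`` backs the new ``LLM.classify()`` API, exercised by the offline_inference_classification.py example added to CI above. A usage sketch (the model name and the ``task`` flag are assumptions for illustration; the flag may vary by vLLM version):

.. code-block:: python

    from vllm import LLM

    # Model name and task flag are illustrative assumptions.
    llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")

    (output, ) = llm.classify("Hello, my name is")
    print(f"num classes: {len(output.outputs.probs)}")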
+
+
+@dataclass
+class ScoringOutput:
+    """The output data of one scoring output of a request.
+
+    Args:
+        score: The similarity score, which is a scalar value.
+    """
+    score: float
+
+    @staticmethod
+    def from_base(pooling_output: PoolingOutput):
+        pooled_data = pooling_output.data
+        if pooled_data.ndim != 0:
+            raise ValueError("pooled_data should be a scalar score")
+
+        return ScoringOutput(pooled_data.item())
+
+    def __repr__(self) -> str:
+        return f"ScoringOutput(score={self.score})"
+
+    @property
+    def embedding(self) -> list[float]:
+        msg = ("`LLM.score()` now returns scalar scores. "
+               "Please access it via the `score` attribute. ")
+        warnings.warn(msg, DeprecationWarning, stacklevel=2)
+
+        return [self.score]
+
+
+class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):
+
+    @staticmethod
+    def from_base(request_output: PoolingRequestOutput):
+        return ScoringRequestOutput(
+            request_id=request_output.request_id,
+            outputs=ScoringOutput.from_base(request_output.outputs),
+            prompt_token_ids=request_output.prompt_token_ids,
+            finished=request_output.finished,
+        )
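``ScoringRequestOutput`` likewise backs ``LLM.score()``, which the new offline_inference_scoring.py CI example exercises. A usage sketch (the model name is illustrative; any cross-encoder vLLM supports works):

.. code-block:: python

    from vllm import LLM

    # Model name is an assumption for illustration.
    llm = LLM(model="BAAI/bge-reranker-v2-m3", enforce_eager=True)

    (output, ) = llm.score("What is the capital of France?",
                           "The capital of France is Paris.")
    print(f"score: {output.outputs.score}")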
vllm/sequence.py
@@ -617,10 +617,9 @@ class SequenceGroup:
         sampling_params: The sampling parameters used to generate the outputs.
         arrival_time: The arrival time of the request.
         lora_request: LoRA request.
-        embeddings: The embeddings vectors of the prompt of the sequence group
-            for a pooling model.
-        pooling_params: The pooling parameters used to generate the pooling
+        pooling_params: The parameters used to generate the pooler
             for a pooling model.
+        pooled_data: The extracted hidden states from a pooling model.
         encoder_seq: Optional, the single encoder sequence. Should be None
             unless you are working with an encoder/decoder model.
         trace_headers: OpenTelemetry trace headers.
@@ -635,8 +634,8 @@ class SequenceGroup:
         arrival_time: float,
         sampling_params: Optional[SamplingParams] = None,
         lora_request: Optional[LoRARequest] = None,
-        embeddings: Optional[List[float]] = None,
         pooling_params: Optional[PoolingParams] = None,
+        pooled_data: Optional[torch.Tensor] = None,
         encoder_seq: Optional[Sequence] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -658,8 +657,8 @@ class SequenceGroup:
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
         self.state = SequenceGroupState()
-        self.embeddings = embeddings
         self.pooling_params = pooling_params
+        self.pooled_data = pooled_data
         self.prompt_adapter_request = prompt_adapter_request
         self.encoder_seq = encoder_seq
         self.trace_headers = trace_headers
@@ -1033,8 +1032,8 @@ class CompletionSequenceGroupOutput(
         msgspec.Struct,
         omit_defaults=True,  # type: ignore[call-arg]
         array_like=True):  # type: ignore[call-arg]
-    __metaclass__ = SequenceGroupOutput
     """The model output associated with a completion sequence group."""
+    __metaclass__ = SequenceGroupOutput
     samples: List[SequenceOutput]
     # Prompt logprob for each prompt query token.
     prompt_logprobs: Optional[PromptLogprobs]
@@ -1050,23 +1049,24 @@ class CompletionSequenceGroupOutput(
             and self.prompt_logprobs == other.prompt_logprobs)


-class EmbeddingSequenceGroupOutput(
+class PoolingSequenceGroupOutput(
         msgspec.Struct,
         omit_defaults=True,  # type: ignore[call-arg]
         array_like=True,  # type: ignore[call-arg]
 ):
-    """The model output associated with an embedding sequence group."""
+    """The model output associated with a pooling sequence group."""
     __metaclass__ = SequenceGroupOutput
-    embeddings: List[int]
+    # Annotated as Any to be compatible with msgspec
+    # The actual type is in SequenceGroup.pooled_data
+    data: Any

     def __repr__(self) -> str:
-        return (f"EmbeddingSequenceGroupOutput("
-                f"embeddings_shape={len(self.embeddings)})")
+        return f"PoolingSequenceGroupOutput(data={self.data}"

     def __eq__(self, other: object) -> bool:
-        if not isinstance(other, EmbeddingSequenceGroupOutput):
+        if not isinstance(other, PoolingSequenceGroupOutput):
             raise NotImplementedError()
-        return self.embeddings == other.embeddings
+        return self.data == other.data


 # cannot use msgspec.Struct here because Dynamo does not support it
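``PoolingSequenceGroupOutput`` now carries the pooled tensor itself, typed ``Any`` only to satisfy msgspec. A construction sketch showing how a batch's ``PoolerOutput`` (updated in the hunk that follows) is assembled and indexed, assuming a build containing this commit:

.. code-block:: python

    import torch

    from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput

    # One entry per sequence group in the batch, each holding its pooled tensor.
    batch = PoolerOutput(outputs=[
        PoolingSequenceGroupOutput(torch.randn(8)),
        PoolingSequenceGroupOutput(torch.randn(8)),
    ])
    print(len(batch), batch[0])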
@@ -1085,7 +1085,7 @@ class IntermediateTensors:
         elif isinstance(key, slice):
             return self.__class__({k: v[key] for k, v in self.tensors.items()})

-    def __setitem__(self, key: str, value):
+    def __setitem__(self, key: str, value: torch.Tensor):
         self.tensors[key] = value

     def __len__(self):
@@ -1103,16 +1103,12 @@ class PoolerOutput(
         omit_defaults=True,  # type: ignore[call-arg]
         array_like=True):  # type: ignore[call-arg]
     """The output from a pooling operation in the pooling model."""
-    outputs: List[EmbeddingSequenceGroupOutput]
+    outputs: List[PoolingSequenceGroupOutput]

-    # lazy import to avoid circular import
-    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
-    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
-
-    def __getitem__(self, idx: int) -> EmbeddingSequenceGroupOutput:
+    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
         return self.outputs[idx]

-    def __setitem__(self, idx: int, value):
+    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
         self.outputs[idx] = value

     def __len__(self):
@@ -1385,8 +1381,8 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
             arrival_time=seq_group.arrival_time,
             sampling_params=original_params,
             lora_request=seq_group.lora_request,
-            embeddings=seq_group.embeddings,
             pooling_params=seq_group.pooling_params,
+            pooled_data=seq_group.pooled_data,
             encoder_seq=seq_group.encoder_seq,
             trace_headers=seq_group.trace_headers,
             prompt_adapter_request=seq_group.prompt_adapter_request,