Support Cross encoder models (#10400)
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
parent: 49628fe13e
commit: 214efc2c3c
@@ -44,6 +44,148 @@ We currently support the following OpenAI APIs:
- This enables multi-modal inputs to be passed to embedding models, see [Using VLMs](../models/vlm.rst).
- *Note: You should run `vllm serve` with `--task embedding` to ensure that the model is being run in embedding mode.*
## Score API for Cross Encoder Models

vLLM supports *cross encoder models* at the **/v1/score** endpoint, which is not an OpenAI-standard endpoint. You can find the documentation for these kinds of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).

A ***Cross Encoder*** takes exactly two sentences / texts as input and predicts a score or label for this sentence pair. For example, it can predict the similarity of the sentence pair on a scale of 0 … 1.
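For intuition, the same kind of model can be run locally through the `sentence-transformers` library. The snippet below is a sketch for comparison only (it assumes `sentence-transformers` is installed and is not required by vLLM); the examples that follow instead call a running vLLM server, started with something like `vllm serve BAAI/bge-reranker-v2-m3 --task embedding`.

```python
# Sketch: scoring sentence pairs with a cross encoder directly via
# sentence-transformers, for comparison with the /v1/score endpoint below.
from sentence_transformers import CrossEncoder

model = CrossEncoder("BAAI/bge-reranker-v2-m3")
scores = model.predict([
    ("What is the capital of France?", "The capital of France is Paris."),
    ("What is the capital of France?", "The capital of Brazil is Brasilia."),
])
print(scores)  # one score per pair; higher means more relevant
```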
### Example of usage for a pair of a string and a list of texts

In this case, the model will compare the first given text to each of the texts in the list.

```bash
curl -X 'POST' \
  'http://127.0.0.1:8000/v1/score' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "model": "BAAI/bge-reranker-v2-m3",
  "text_1": "What is the capital of France?",
  "text_2": [
    "The capital of Brazil is Brasilia.",
    "The capital of France is Paris."
  ]
}'
```

Response:

```json
{
  "id": "score-request-id",
  "object": "list",
  "created": 693570,
  "model": "BAAI/bge-reranker-v2-m3",
  "data": [
    {
      "index": 0,
      "object": "score",
      "score": [
        0.001094818115234375
      ]
    },
    {
      "index": 1,
      "object": "score",
      "score": [
        1
      ]
    }
  ],
  "usage": {}
}
```

### Example of usage for a pair of two lists of texts

In this case, the model will compare the texts one by one, pairing the entries at the same index in each list.

```bash
curl -X 'POST' \
  'http://127.0.0.1:8000/v1/score' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "model": "BAAI/bge-reranker-v2-m3",
  "encoding_format": "float",
  "text_1": [
    "What is the capital of Brazil?",
    "What is the capital of France?"
  ],
  "text_2": [
    "The capital of Brazil is Brasilia.",
    "The capital of France is Paris."
  ]
}'
```

Response:

```json
{
  "id": "score-request-id",
  "object": "list",
  "created": 693447,
  "model": "BAAI/bge-reranker-v2-m3",
  "data": [
    {
      "index": 0,
      "object": "score",
      "score": [
        1
      ]
    },
    {
      "index": 1,
      "object": "score",
      "score": [
        1
      ]
    }
  ],
  "usage": {}
}
```

### Example of usage for a pair of two strings

In this case, the model will compare the two strings directly.

```bash
curl -X 'POST' \
  'http://127.0.0.1:8000/v1/score' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "model": "BAAI/bge-reranker-v2-m3",
  "encoding_format": "float",
  "text_1": "What is the capital of France?",
  "text_2": "The capital of France is Paris."
}'
```

Response:

```json
{
  "id": "score-request-id",
  "object": "list",
  "created": 693447,
  "model": "BAAI/bge-reranker-v2-m3",
  "data": [
    {
      "index": 0,
      "object": "score",
      "score": [
        1
      ]
    }
  ],
  "usage": {}
}
```
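The same requests can be issued from Python with `requests`; the snippet below is a sketch that mirrors the curl calls above and extracts the scores from the response (see also `examples/openai_cross_encoder_score.py` added in this PR):

```python
import requests

# Assumes a vLLM server started with --task embedding is listening locally.
resp = requests.post(
    "http://127.0.0.1:8000/v1/score",
    json={
        "model": "BAAI/bge-reranker-v2-m3",
        "text_1": "What is the capital of France?",
        "text_2": [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
        ],
    },
)
resp.raise_for_status()
for item in resp.json()["data"]:
    print(item["index"], item["score"])
```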
## Extra Parameters

vLLM supports a set of parameters that are not part of the OpenAI API.
examples/openai_cross_encoder_score.py (new file, 58 lines)
@@ -0,0 +1,58 @@
"""Example Python client for the vLLM /v1/score API (cross encoder models)."""

import argparse
import pprint

import requests


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    response = requests.post(api_url, headers=headers, json=prompt)
    return response


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="BAAI/bge-reranker-v2-m3")
    args = parser.parse_args()
    api_url = f"http://{args.host}:{args.port}/v1/score"

    model_name = args.model

    # text_1 is a string and text_2 is a list: 1 -> N scoring.
    text_1 = "What is the capital of France?"
    text_2 = [
        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
    ]
    prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
    score_response = post_http_request(prompt=prompt, api_url=api_url)
    print("Prompt when text_1 is a string and text_2 is a list:")
    pprint.pprint(prompt)
    print("Score Response:")
    pprint.pprint(score_response.json())

    # Both text_1 and text_2 are lists: pairs are formed by index (N -> N).
    text_1 = [
        "What is the capital of Brazil?", "What is the capital of France?"
    ]
    text_2 = [
        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
    ]
    prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
    score_response = post_http_request(prompt=prompt, api_url=api_url)
    print("Prompt when text_1 and text_2 are lists:")
    pprint.pprint(prompt)
    print("Score Response:")
    pprint.pprint(score_response.json())

    # Both text_1 and text_2 are strings: a single pair (1 -> 1).
    text_1 = "What is the capital of Brazil?"
    text_2 = "The capital of Brazil is Brasilia."
    prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
    score_response = post_http_request(prompt=prompt, api_url=api_url)
    print("Prompt when text_1 and text_2 are strings:")
    pprint.pprint(prompt)
    print("Score Response:")
    pprint.pprint(score_response.json())
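To run the example against a live server: `python examples/openai_cross_encoder_score.py --host localhost --port 8000 --model BAAI/bge-reranker-v2-m3` (these values are also the script's defaults, so the flags may be omitted).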
@@ -265,6 +265,7 @@ class HfRunner:
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
        postprocess_inputs: Callable[..., BatchEncoding] = identity,
@@ -282,6 +283,14 @@ class HfRunner:
                    device="cpu",
                    trust_remote_code=True,
                ).to(dtype=torch_dtype))
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder
            self.model = CrossEncoder(model_name,
                                      device="cpu",
                                      trust_remote_code=True)
            self.model.model = self.wrap_device(self.model.model)\
                .to(dtype=torch_dtype)
        else:
            model_kwargs = model_kwargs if model_kwargs is not None else {}
            self.model = self.wrap_device(
@@ -625,6 +634,9 @@ class HfRunner:
    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        return self.model.encode(prompts)

    def predict(self, prompts: List[List[str]]) -> torch.Tensor:
        return self.model.predict(prompts, convert_to_tensor=True)

    def __enter__(self):
        return self

@@ -898,6 +910,14 @@ class VllmRunner:
        req_outputs = self.model.encode(inputs)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def score(
        self,
        text_1: Union[str, List[str]],
        text_2: Union[str, List[str]],
    ) -> List[List[float]]:
        req_outputs = self.model.score(text_1, text_2)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def __enter__(self):
        return self
tests/entrypoints/openai/test_score.py (new file, 93 lines)
@@ -0,0 +1,93 @@
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ScoreResponse
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "BAAI/bge-reranker-v2-m3"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--enforce-eager",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
|
||||
model_name: str):
|
||||
text_1 = "What is the capital of France?"
|
||||
text_2 = [
|
||||
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
|
||||
]
|
||||
|
||||
score_response = requests.post(server.url_for("v1/score"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
})
|
||||
score_response.raise_for_status()
|
||||
score = ScoreResponse.model_validate(score_response.json())
|
||||
|
||||
assert score.id is not None
|
||||
assert score.data is not None
|
||||
assert len(score.data) == 2
|
||||
assert score.data[0].score[0] <= 0.01
|
||||
assert score.data[1].score[0] >= 0.9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
|
||||
model_name: str):
|
||||
text_1 = [
|
||||
"What is the capital of the United States?",
|
||||
"What is the capital of France?"
|
||||
]
|
||||
text_2 = [
|
||||
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
|
||||
]
|
||||
|
||||
score_response = requests.post(server.url_for("v1/score"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
})
|
||||
score_response.raise_for_status()
|
||||
score = ScoreResponse.model_validate(score_response.json())
|
||||
|
||||
assert score.id is not None
|
||||
assert score.data is not None
|
||||
assert len(score.data) == 2
|
||||
assert score.data[0].score[0] <= 0.01
|
||||
assert score.data[1].score[0] >= 0.9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
|
||||
model_name: str):
|
||||
text_1 = "What is the capital of France?"
|
||||
text_2 = "The capital of France is Paris."
|
||||
|
||||
score_response = requests.post(server.url_for("v1/score"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
})
|
||||
score_response.raise_for_status()
|
||||
score = ScoreResponse.model_validate(score_response.json())
|
||||
|
||||
assert score.id is not None
|
||||
assert score.data is not None
|
||||
assert len(score.data) == 1
|
||||
assert score.data[0].score[0] >= 0.9
|
tests/models/embedding/language/test_scoring.py (new file, 95 lines)
@@ -0,0 +1,95 @@
|
||||
"""Compare the embedding outputs of HF and vLLM models.
|
||||
|
||||
Run `pytest tests/models/embedding/language/test_embedding.py`.
|
||||
"""
|
||||
import math
|
||||
|
||||
import pytest
|
||||
|
||||
MODELS = [
|
||||
"cross-encoder/ms-marco-MiniLM-L-6-v2", # Bert
|
||||
"BAAI/bge-reranker-v2-m3", # Roberta
|
||||
]
|
||||
|
||||
TEXTS_1 = [
|
||||
"What is the capital of France?",
|
||||
"What is the capital of Germany?",
|
||||
]
|
||||
|
||||
TEXTS_2 = [
|
||||
"The capital of France is Paris.",
|
||||
"The capital of Germany is Berlin.",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=MODELS)
|
||||
def model_name(request):
|
||||
yield request.param
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str):
|
||||
|
||||
text_pair = [TEXTS_1[0], TEXTS_2[0]]
|
||||
|
||||
with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
|
||||
hf_outputs = hf_model.predict([text_pair]).tolist()
|
||||
|
||||
with vllm_runner(model_name,
|
||||
task="embedding",
|
||||
dtype=dtype,
|
||||
max_model_len=None) as vllm_model:
|
||||
vllm_outputs = vllm_model.score(text_pair[0], text_pair[1])
|
||||
|
||||
assert len(vllm_outputs) == 1
|
||||
assert len(hf_outputs) == 1
|
||||
|
||||
assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
|
||||
|
||||
text_pairs = [
|
||||
[TEXTS_1[0], TEXTS_2[0]],
|
||||
[TEXTS_1[0], TEXTS_2[1]],
|
||||
]
|
||||
|
||||
with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
|
||||
hf_outputs = hf_model.predict(text_pairs).tolist()
|
||||
|
||||
with vllm_runner(model_name,
|
||||
task="embedding",
|
||||
dtype=dtype,
|
||||
max_model_len=None) as vllm_model:
|
||||
vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2)
|
||||
|
||||
assert len(vllm_outputs) == 2
|
||||
assert len(hf_outputs) == 2
|
||||
|
||||
assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
|
||||
assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_llm_N_to_N(vllm_runner, hf_runner, model_name, dtype: str):
|
||||
|
||||
text_pairs = [
|
||||
[TEXTS_1[0], TEXTS_2[0]],
|
||||
[TEXTS_1[1], TEXTS_2[1]],
|
||||
]
|
||||
|
||||
with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
|
||||
hf_outputs = hf_model.predict(text_pairs).tolist()
|
||||
|
||||
with vllm_runner(model_name,
|
||||
task="embedding",
|
||||
dtype=dtype,
|
||||
max_model_len=None) as vllm_model:
|
||||
vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2)
|
||||
|
||||
assert len(vllm_outputs) == 2
|
||||
assert len(hf_outputs) == 2
|
||||
|
||||
assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
|
||||
assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01)
|
@ -135,6 +135,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
||||
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
|
||||
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
|
||||
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501
|
||||
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501
|
||||
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-large"),
|
||||
# [Multimodal]
|
||||
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
|
||||
@ -143,6 +144,13 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
||||
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
|
||||
}
|
||||
|
||||
_CROSS_ENCODER_EXAMPLE_MODELS = {
|
||||
# [Text-only]
|
||||
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
|
||||
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501
|
||||
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501
|
||||
}
|
||||
|
||||
_MULTIMODAL_EXAMPLE_MODELS = {
|
||||
# [Decoder-only]
|
||||
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501
|
||||
@ -195,6 +203,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
_EXAMPLE_MODELS = {
|
||||
**_TEXT_GENERATION_EXAMPLE_MODELS,
|
||||
**_EMBEDDING_EXAMPLE_MODELS,
|
||||
**_CROSS_ENCODER_EXAMPLE_MODELS,
|
||||
**_MULTIMODAL_EXAMPLE_MODELS,
|
||||
**_SPECULATIVE_DECODING_EXAMPLE_MODELS,
|
||||
}
|
||||
|
@ -6,7 +6,10 @@ import torch.cuda
|
||||
from vllm.model_executor.models import (is_embedding_model,
|
||||
is_text_generation_model,
|
||||
supports_multimodal)
|
||||
from vllm.model_executor.models.registry import (_EMBEDDING_MODELS,
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.models.registry import (_CROSS_ENCODER_MODELS,
|
||||
_EMBEDDING_MODELS,
|
||||
_MULTIMODAL_MODELS,
|
||||
_SPECULATIVE_DECODING_MODELS,
|
||||
_TEXT_GENERATION_MODELS,
|
||||
@ -29,22 +32,28 @@ def test_registry_imports(model_arch):
|
||||
model_arch in _TEXT_GENERATION_MODELS
|
||||
or model_arch in _MULTIMODAL_MODELS)
|
||||
|
||||
embedding_models = {**_EMBEDDING_MODELS, **_CROSS_ENCODER_MODELS}
|
||||
assert is_embedding_model(model_cls) is (model_arch
|
||||
in _EMBEDDING_MODELS)
|
||||
in embedding_models)
|
||||
|
||||
assert supports_multimodal(model_cls) is (model_arch
|
||||
in _MULTIMODAL_MODELS)
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
@pytest.mark.parametrize("model_arch,is_mm,init_cuda", [
|
||||
("LlamaForCausalLM", False, False),
|
||||
("MllamaForConditionalGeneration", True, False),
|
||||
("LlavaForConditionalGeneration", True, True),
|
||||
@pytest.mark.parametrize("model_arch,is_mm,init_cuda,is_ce", [
|
||||
("LlamaForCausalLM", False, False, False),
|
||||
("MllamaForConditionalGeneration", True, False, False),
|
||||
("LlavaForConditionalGeneration", True, True, False),
|
||||
("BertForSequenceClassification", False, False, True),
|
||||
("RobertaForSequenceClassification", False, False, True),
|
||||
("XLMRobertaForSequenceClassification", False, False, True),
|
||||
])
|
||||
def test_registry_is_multimodal(model_arch, is_mm, init_cuda):
|
||||
def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce):
|
||||
assert ModelRegistry.is_multimodal_model(model_arch) is is_mm
|
||||
|
||||
assert ModelRegistry.is_cross_encoder_model(model_arch) is is_ce
|
||||
|
||||
if init_cuda and current_platform.is_cuda_alike():
|
||||
assert not torch.cuda.is_initialized()
|
||||
|
||||
|
@ -712,6 +712,11 @@ class ModelConfig:
|
||||
def is_multimodal_model(self) -> bool:
|
||||
return self.multimodal_config is not None
|
||||
|
||||
@property
|
||||
def is_cross_encoder(self) -> bool:
|
||||
architectures = getattr(self.hf_config, "architectures", [])
|
||||
return ModelRegistry.is_cross_encoder_model(architectures)
|
||||
|
||||
|
||||
class CacheConfig:
|
||||
"""Configuration for the KV cache.
|
||||
|
@ -1357,6 +1357,7 @@ class Scheduler:
|
||||
encoder_seq_data=encoder_seq_data,
|
||||
cross_block_table=cross_block_table,
|
||||
state=seq_group.state,
|
||||
token_type_ids=seq_group.token_type_ids,
|
||||
# `multi_modal_data` will only be present for the 1st comm
|
||||
# between engine and worker.
|
||||
# the subsequent comms can still use delta, but
|
||||
|
@ -20,7 +20,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||
apply_mistral_chat_template,
|
||||
parse_chat_messages,
|
||||
resolve_chat_template_content_format)
|
||||
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
|
||||
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.inputs.parse import parse_and_batch_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -817,6 +817,128 @@ class LLM:
|
||||
return self.engine_class.validate_outputs(outputs,
|
||||
EmbeddingRequestOutput)
|
||||
|
||||
def score(
|
||||
self,
|
||||
text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
|
||||
text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
|
||||
/,
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
"""Generates similarity scores for all pairs <text,text_pair>.
|
||||
|
||||
The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 -> N case
|
||||
the text_1 sentence will be replicated N times to pair with the text_2
|
||||
sentences. The input pairs are used to build a list of prompts for the
|
||||
cross encoder model. This class automatically batches the prompts,
|
||||
considering the memory constraint. For the best performance, put all
|
||||
of your texts into a single list and pass it to this method.
|
||||
|
||||
Args:
|
||||
text_1: can be a single prompt or a list of prompts, in which
|
||||
case it has to have the same length as the text_2 list
|
||||
text_2: The texts to pair with the query to form the input
|
||||
to the LLM. See :class:`~vllm.inputs.PromptType` for
|
||||
more details about the format of each prompts.
|
||||
use_tqdm: Whether to use tqdm to display the progress bar.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
prompt_adapter_request: Prompt Adapter request to use for
|
||||
generation, if any.
|
||||
|
||||
Returns:
|
||||
A list of ``EmbeddingRequestOutput`` objects containing the
|
||||
generated scores in the same order as the input prompts.
|
||||
"""
|
||||
task = self.llm_engine.model_config.task
|
||||
if task != "embedding":
|
||||
messages = ["LLM.score() is only supported for embedding models."]
|
||||
|
||||
supported_tasks = self.llm_engine.model_config.supported_tasks
|
||||
if "embedding" in supported_tasks:
|
||||
messages.append(
|
||||
"Your model supports the 'embedding' task, but is "
|
||||
f"currently initialized for the '{task}' task. Please "
|
||||
"initialize the model using `--task embedding`.")
|
||||
|
||||
raise ValueError(" ".join(messages))
|
||||
|
||||
if not self.llm_engine.model_config.is_cross_encoder:
|
||||
raise ValueError("Your model does not support the cross encoding")
|
||||
|
||||
tokenizer = self.llm_engine.get_tokenizer()
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"MistralTokenizer not supported for cross-encoding")
|
||||
|
||||
# the tokenizer for models such as
|
||||
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
|
||||
# lists of tokens to the `text` and `text_pair` kwargs
|
||||
def ensure_str(prompt: SingletonPrompt):
|
||||
if isinstance(prompt, dict):
|
||||
if "multi_modal_data" in prompt:
|
||||
raise ValueError("Multi-modal prompt is not "
|
||||
"supported for cross encoding")
|
||||
elif "prompt_token_ids" in prompt:
|
||||
prompt = tokenizer.decode(
|
||||
cast(TokensPrompt, prompt)["prompt_token_ids"])
|
||||
elif "prompt" in prompt:
|
||||
prompt = cast(TextPrompt, prompt)["prompt"]
|
||||
assert type(prompt) is str
|
||||
return prompt
|
||||
|
||||
if isinstance(text_1, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
text_1 = [text_1]
|
||||
text_1 = [ensure_str(t) for t in text_1]
|
||||
|
||||
if isinstance(text_2, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
text_2 = [text_2]
|
||||
text_2 = [ensure_str(t) for t in text_2]
|
||||
|
||||
if len(text_1) > 1 and len(text_1) != len(text_2):
|
||||
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
|
||||
if len(text_1) == 0:
|
||||
raise ValueError("At least one text element must be given")
|
||||
if len(text_2) == 0:
|
||||
raise ValueError("At least one text_pair element must be given")
|
||||
|
||||
if len(text_1) == 1:
|
||||
text_1 = text_1 * len(text_2)
|
||||
|
||||
input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
|
||||
pooling_params = PoolingParams()
|
||||
|
||||
tokenization_kwargs: Dict[str, Any] = {}
|
||||
if truncate_prompt_tokens is not None:
|
||||
tokenization_kwargs["truncation"] = True
|
||||
tokenization_kwargs["max_length"] = truncate_prompt_tokens
|
||||
|
||||
parsed_prompts = []
|
||||
|
||||
for q, t in input_pairs:
|
||||
prompt_inputs = tokenizer(text=q,
|
||||
text_pair=t,
|
||||
**tokenization_kwargs)
|
||||
engine_prompt = TokensPrompt(
|
||||
prompt_token_ids=prompt_inputs["input_ids"],
|
||||
token_type_ids=prompt_inputs.get("token_type_ids"))
|
||||
parsed_prompts.append(engine_prompt)
|
||||
|
||||
self._validate_and_add_requests(
|
||||
prompts=parsed_prompts,
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
return self.engine_class.validate_outputs(outputs,
|
||||
EmbeddingRequestOutput)
|
||||
|
||||
def start_profile(self) -> None:
|
||||
self.llm_engine.start_profile()
|
||||
|
||||
|
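The `LLM.score()` method added in this hunk can also be exercised offline. The following is a rough sketch (assuming this branch is installed; the model name and `task="embedding"` setting follow the docs and tests in this PR):

```python
from vllm import LLM

llm = LLM(model="BAAI/bge-reranker-v2-m3", task="embedding")

# 1 -> N: the single query is paired with every candidate text.
outputs = llm.score(
    "What is the capital of France?",
    ["The capital of Brazil is Brasilia.",
     "The capital of France is Paris."],
)
for out in outputs:
    # Scores come back through EmbeddingRequestOutput.outputs.embedding.
    print(out.outputs.embedding)
```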
@ -45,6 +45,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse, ErrorResponse,
|
||||
LoadLoraAdapterRequest,
|
||||
ScoreRequest, ScoreResponse,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
UnloadLoraAdapterRequest)
|
||||
@ -53,6 +54,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
|
||||
from vllm.entrypoints.openai.serving_tokenization import (
|
||||
OpenAIServingTokenization)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
@ -280,6 +282,10 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
|
||||
return request.app.state.openai_serving_embedding
|
||||
|
||||
|
||||
def score(request: Request) -> Optional[OpenAIServingScores]:
|
||||
return request.app.state.openai_serving_scores
|
||||
|
||||
|
||||
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
@ -391,6 +397,23 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/v1/score")
|
||||
async def create_score(request: ScoreRequest, raw_request: Request):
|
||||
handler = score(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Score API")
|
||||
|
||||
generator = await handler.create_score(request, raw_request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
elif isinstance(generator, ScoreResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
logger.warning(
|
||||
"Torch Profiler is enabled in the API server. This should ONLY be "
|
||||
@ -466,8 +489,9 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(_, exc):
|
||||
chat = app.state.openai_serving_chat
|
||||
err = chat.create_error_response(message=str(exc))
|
||||
err = ErrorResponse(message=str(exc),
|
||||
type="BadRequestError",
|
||||
code=HTTPStatus.BAD_REQUEST)
|
||||
return JSONResponse(err.model_dump(),
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
@ -565,6 +589,13 @@ def init_app_state(
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
) if model_config.task == "embedding" else None
|
||||
state.openai_serving_scores = OpenAIServingScores(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
request_logger=request_logger
|
||||
) if (model_config.task == "embedding" \
|
||||
and model_config.is_cross_encoder) else None
|
||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||
engine_client,
|
||||
model_config,
|
||||
|
@ -806,6 +806,27 @@ class EmbeddingChatRequest(OpenAIBaseModel):
|
||||
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
|
||||
|
||||
|
||||
class ScoreRequest(OpenAIBaseModel):
|
||||
model: str
|
||||
text_1: Union[List[str], str]
|
||||
text_2: Union[List[str], str]
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||
|
||||
# doc: begin-chat-embedding-pooling-params
|
||||
additional_data: Optional[Any] = None
|
||||
# doc: end-chat-embedding-pooling-params
|
||||
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(additional_data=self.additional_data)
|
||||
|
||||
|
||||
class CompletionLogProbs(OpenAIBaseModel):
|
||||
text_offset: List[int] = Field(default_factory=list)
|
||||
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||
@ -876,6 +897,21 @@ class EmbeddingResponse(OpenAIBaseModel):
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
class ScoreResponseData(OpenAIBaseModel):
|
||||
index: int
|
||||
object: str = "score"
|
||||
score: Union[List[float], str]
|
||||
|
||||
|
||||
class ScoreResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
data: List[ScoreResponseData]
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
class FunctionCall(OpenAIBaseModel):
|
||||
name: str
|
||||
arguments: str
|
||||
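For reference, the new request/response models defined above can be constructed directly; a small sketch with illustrative values (field names as in this hunk, `UsageInfo` defaults assumed):

```python
req = ScoreRequest(
    model="BAAI/bge-reranker-v2-m3",
    text_1="What is the capital of France?",
    text_2=["The capital of France is Paris."],
)
params = req.to_pooling_params()  # PoolingParams(additional_data=None)

resp = ScoreResponse(
    model=req.model,
    data=[ScoreResponseData(index=0, score=[0.98])],
    usage=UsageInfo(),
)
print(resp.model_dump_json())
```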
|
vllm/entrypoints/openai/serving_score.py (new file, 215 lines)
@@ -0,0 +1,215 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
|
||||
ScoreResponse, ScoreResponseData,
|
||||
UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import EmbeddingRequestOutput
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils import merge_async_iterators, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def request_output_to_score_response(
|
||||
final_res_batch: List[EmbeddingRequestOutput], request_id: str,
|
||||
created_time: int, model_name: str) -> ScoreResponse:
|
||||
data: List[ScoreResponseData] = []
|
||||
score = None
|
||||
num_prompt_tokens = 0
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
if final_res is not None:
|
||||
score = final_res.outputs.embedding
|
||||
score_data = ScoreResponseData(index=idx, score=score)
|
||||
data.append(score_data)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return ScoreResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
data=data,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str],
|
||||
str]) -> List:
|
||||
if isinstance(text_1, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
text_1 = [text_1]
|
||||
text_1 = [t for t in text_1]
|
||||
|
||||
if isinstance(text_2, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
text_2 = [text_2]
|
||||
text_2 = [t for t in text_2]
|
||||
if len(text_1) > 1 and len(text_1) != len(text_2):
|
||||
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
|
||||
if len(text_1) == 0:
|
||||
raise ValueError("At least one text element must be given")
|
||||
if len(text_2) == 0:
|
||||
raise ValueError("At least one text_pair element must be given")
|
||||
|
||||
if len(text_1) == 1:
|
||||
text_1 = text_1 * len(text_2)
|
||||
|
||||
return [(t1, t2) for t1, t2 in zip(text_1, text_2)]
|
||||
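The pairing rules implemented by `make_pairs` above, shown as a small illustrative example:

```python
# 1:N — a single text_1 is broadcast against every text_2 entry.
make_pairs("q", ["a", "b"])            # [("q", "a"), ("q", "b")]

# N:N — entries are paired by index; lengths must match or a ValueError is raised.
make_pairs(["q1", "q2"], ["a", "b"])   # [("q1", "a"), ("q2", "b")]
```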
|
||||
|
||||
class OpenAIServingScores(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: List[BaseModelPath],
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
request_logger=request_logger)
|
||||
|
||||
async def create_score(
|
||||
self,
|
||||
request: ScoreRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[ScoreResponse, ErrorResponse]:
|
||||
"""
|
||||
Score API similar to Sentence Transformers cross encoder
|
||||
|
||||
See https://sbert.net/docs/package_reference/cross_encoder
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
model_name = request.model
|
||||
request_id = f"score-{random_uuid()}"
|
||||
created_time = int(time.monotonic())
|
||||
truncate_prompt_tokens = request.truncate_prompt_tokens
|
||||
|
||||
request_prompts = []
|
||||
engine_prompts = []
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
if prompt_adapter_request is not None:
|
||||
raise NotImplementedError("Prompt adapter is not supported "
|
||||
"for embedding models")
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"MistralTokenizer not supported for cross-encoding")
|
||||
|
||||
if not self.model_config.is_cross_encoder:
|
||||
raise ValueError("Model is not cross encoder.")
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
|
||||
|
||||
input_pairs = make_pairs(request.text_1, request.text_2)
|
||||
|
||||
for q, t in input_pairs:
|
||||
request_prompt = f"{q}{tokenizer.sep_token}{t}"
|
||||
|
||||
tokenization_kwargs: Dict[str, Any] = {}
|
||||
if truncate_prompt_tokens is not None:
|
||||
tokenization_kwargs["truncation"] = True
|
||||
tokenization_kwargs["max_length"] = truncate_prompt_tokens
|
||||
|
||||
prompt_inputs = tokenizer(text=q,
|
||||
text_pair=t,
|
||||
**tokenization_kwargs)
|
||||
engine_prompt = TokensPrompt(
|
||||
prompt_token_ids=prompt_inputs["input_ids"],
|
||||
token_type_ids=prompt_inputs.get("token_type_ids"))
|
||||
|
||||
request_prompts.append(request_prompt)
|
||||
engine_prompts.append(engine_prompt)
|
||||
|
||||
try:
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
request_prompts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = merge_async_iterators(
|
||||
*generators,
|
||||
is_cancelled=raw_request.is_disconnected if raw_request else None,
|
||||
)
|
||||
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: List[Optional[EmbeddingRequestOutput]]
|
||||
final_res_batch = [None] * num_prompts
|
||||
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
assert all(final_res is not None for final_res in final_res_batch)
|
||||
|
||||
final_res_batch_checked = cast(List[EmbeddingRequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = request_output_to_score_response(
|
||||
final_res_batch_checked, request_id, created_time, model_name)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return response
|
@ -38,6 +38,9 @@ class TokensPrompt(TypedDict):
|
||||
prompt_token_ids: List[int]
|
||||
"""A list of token IDs to pass to the model."""
|
||||
|
||||
token_type_ids: NotRequired[List[int]]
|
||||
"""A list of token type IDs to pass to the cross encoder model."""
|
||||
|
||||
multi_modal_data: NotRequired["MultiModalDataDict"]
|
||||
"""
|
||||
DEPRECATED: Optional multi-modal data to pass to the model,
|
||||
@ -133,6 +136,9 @@ class TokenInputs(TypedDict):
|
||||
prompt_token_ids: List[int]
|
||||
"""The token IDs of the prompt."""
|
||||
|
||||
token_type_ids: NotRequired[List[int]]
|
||||
"""The token type IDs of the prompt."""
|
||||
|
||||
prompt: NotRequired[str]
|
||||
"""
|
||||
The original prompt text corresponding to the token IDs, if available.
|
||||
@ -160,6 +166,7 @@ class TokenInputs(TypedDict):
|
||||
|
||||
def token_inputs(
|
||||
prompt_token_ids: List[int],
|
||||
token_type_ids: Optional[List[int]] = None,
|
||||
prompt: Optional[str] = None,
|
||||
multi_modal_data: Optional["MultiModalDataDict"] = None,
|
||||
multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
|
||||
@ -170,6 +177,8 @@ def token_inputs(
|
||||
|
||||
if prompt is not None:
|
||||
inputs["prompt"] = prompt
|
||||
if token_type_ids is not None:
|
||||
inputs["token_type_ids"] = token_type_ids
|
||||
if multi_modal_data is not None:
|
||||
inputs["multi_modal_data"] = multi_modal_data
|
||||
if multi_modal_placeholders is not None:
|
||||
@ -234,6 +243,15 @@ class SingletonInputsAdapter:
|
||||
|
||||
assert_never(inputs)
|
||||
|
||||
@cached_property
|
||||
def token_type_ids(self) -> List[int]:
|
||||
inputs = self.inputs
|
||||
|
||||
if inputs["type"] == "token" or inputs["type"] == "multimodal":
|
||||
return inputs.get("token_type_ids", [])
|
||||
|
||||
assert_never(inputs)
|
||||
|
||||
@cached_property
|
||||
def prompt_embeds(self) -> Optional[torch.Tensor]:
|
||||
inputs = self.inputs
|
||||
|
@ -305,6 +305,7 @@ class InputPreprocessor:
|
||||
tokens_content = parsed["content"]
|
||||
|
||||
prompt_token_ids = tokens_content["prompt_token_ids"]
|
||||
token_type_ids = tokens_content.get("token_type_ids")
|
||||
multi_modal_data = tokens_content.get("multi_modal_data")
|
||||
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
|
||||
|
||||
@ -318,6 +319,7 @@ class InputPreprocessor:
|
||||
|
||||
return token_inputs(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
multi_modal_data=multi_modal_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
@ -3,11 +3,14 @@ from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import PoolerConfig
|
||||
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
|
||||
PoolingTensors)
|
||||
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
|
||||
from vllm.transformers_utils.config import (
|
||||
get_cross_encoder_activation_function)
|
||||
|
||||
|
||||
class PoolingType(IntEnum):
|
||||
@ -152,3 +155,64 @@ class Pooler(nn.Module):
|
||||
]
|
||||
|
||||
return PoolerOutput(outputs=pooled_outputs)
|
||||
|
||||
|
||||
class CrossEncodingPooler(nn.Module):
|
||||
"""A layer that pools specific information from hidden states.
|
||||
|
||||
This layer does the following:
|
||||
1. Extracts specific tokens or aggregates data based on pooling method.
|
||||
2. Normalizes output if specified.
|
||||
3. Returns structured results as `PoolerOutput`.
|
||||
|
||||
Attributes:
|
||||
pooling_type: The type of pooling to use.
|
||||
normalize: Whether to normalize the pooled data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
classifier: nn.Module,
|
||||
pooler: Optional[nn.Module] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.classifier = classifier
|
||||
self.pooler = pooler
|
||||
self.default_activation_function = \
|
||||
get_cross_encoder_activation_function(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
"""Pools sentence pair scores from the hidden_states."""
|
||||
|
||||
prompt_lens = PoolingTensors.from_pooling_metadata(
|
||||
pooling_metadata, hidden_states.device).prompt_lens
|
||||
|
||||
offset = 0
|
||||
pooled_data_lst = []
|
||||
for prompt_len in prompt_lens:
|
||||
pooled_data_i = hidden_states[offset:offset + prompt_len]
|
||||
|
||||
if self.pooler is not None:
|
||||
final_shape_tensor = self.pooler(pooled_data_i)
|
||||
else:
|
||||
final_shape_tensor = self.classifier(pooled_data_i)
|
||||
|
||||
pooled_data_lst.append(final_shape_tensor)
|
||||
offset += prompt_len
|
||||
|
||||
pooled_output = torch.stack(pooled_data_lst)
|
||||
|
||||
if self.pooler is not None:
|
||||
# apply classifier once on the full batch if possible
|
||||
pooled_output = self.classifier(pooled_output)
|
||||
logits = self.default_activation_function(pooled_output)
|
||||
|
||||
pooled_outputs = [
|
||||
EmbeddingSequenceGroupOutput(data.tolist()) for data in logits
|
||||
]
|
||||
return PoolerOutput(outputs=pooled_outputs)
|
||||
|
@ -11,14 +11,18 @@ from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
|
||||
PoolingType)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.transformers_utils.config import (
|
||||
get_cross_encoder_activation_function)
|
||||
|
||||
from .utils import maybe_prefix
|
||||
|
||||
@ -48,7 +52,9 @@ class BertEmbedding(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
seq_lens: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
input_shape = input_ids.size()
|
||||
|
||||
@ -58,17 +64,34 @@ class BertEmbedding(nn.Module):
|
||||
# Position embeddings.
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
|
||||
# Token type embeddings. (TODO: move off hotpath?)
|
||||
token_type_embeddings = self.token_type_embeddings(
|
||||
torch.zeros(input_shape,
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape,
|
||||
dtype=torch.long,
|
||||
device=inputs_embeds.device))
|
||||
device=inputs_embeds.device)
|
||||
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
|
||||
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
class BertPooler(nn.Module):
|
||||
|
||||
def __init__(self, config: BertConfig):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token.
|
||||
first_token_tensor = hidden_states[0, :]
|
||||
pooled_output = self.dense(first_token_tensor)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
@ -309,7 +332,8 @@ class BertModel(nn.Module):
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
embedding_class: type = BertEmbedding):
|
||||
embedding_class: type = BertEmbedding,
|
||||
add_pooling_layer: bool = False):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
@ -319,6 +343,7 @@ class BertModel(nn.Module):
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.encoder")
|
||||
self.pooler = BertPooler(config) if add_pooling_layer else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -328,13 +353,17 @@ class BertModel(nn.Module):
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.embeddings(input_ids=input_ids,
|
||||
position_ids=position_ids)
|
||||
|
||||
assert hasattr(attn_metadata, "seq_lens_tensor")
|
||||
hidden_states = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
seq_lens=attn_metadata.seq_lens_tensor,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids)
|
||||
return self.encoder(hidden_states, kv_caches, attn_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
@ -349,7 +378,7 @@ class BertModel(nn.Module):
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "pooler" in name:
|
||||
if self.pooler is None and "pooler" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
@ -430,3 +459,78 @@ class BertEmbeddingModel(nn.Module):
|
||||
pooling_type=PoolingType.CLS,
|
||||
normalize=True,
|
||||
softmax=False)
|
||||
|
||||
|
||||
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
"""A model that uses Bert to provide embedding functionalities.
|
||||
|
||||
This class encapsulates the BertModel and provides an interface for
|
||||
embedding operations and customized pooling functions.
|
||||
|
||||
Attributes:
|
||||
model: An instance of BertModel used for forward operations.
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
self.default_activation_function = \
|
||||
get_cross_encoder_activation_function(config)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.bert = BertModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "bert"),
|
||||
embedding_class=BertEmbedding,
|
||||
add_pooling_layer=True)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
self._pooler = CrossEncodingPooler(config, self.classifier,
|
||||
self.bert.pooler)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
self_weights = []
|
||||
|
||||
def weight_filter():
|
||||
for name, weight in weights:
|
||||
if name.startswith("bert."):
|
||||
yield (name[len("bert."):], weight)
|
||||
else:
|
||||
self_weights.append((name, weight))
|
||||
|
||||
self.bert.load_weights(weight_filter())
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
for name, loaded_weight in self_weights:
|
||||
if name.startswith("classifier"):
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.bert(input_ids=input_ids,
|
||||
position_ids=positions,
|
||||
kv_caches=kv_caches,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
attn_metadata=attn_metadata,
|
||||
token_type_ids=token_type_ids)
|
||||
|
@ -7,6 +7,8 @@ from typing_extensions import TypeIs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import supports_kw
|
||||
|
||||
from .interfaces_base import is_embedding_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@ -350,3 +352,37 @@ def is_attention_free(
|
||||
return isinstance(model, _IsAttentionFreeType)
|
||||
|
||||
return isinstance(model, IsAttentionFree)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsCrossEncoding(Protocol):
|
||||
"""The interface required for all models that support cross encoding."""
|
||||
|
||||
supports_cross_encoding: ClassVar[Literal[True]] = True
|
||||
|
||||
|
||||
@overload
|
||||
def supports_cross_encoding(
|
||||
model: Type[object]) -> TypeIs[Type[SupportsCrossEncoding]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
|
||||
...
|
||||
|
||||
|
||||
def _supports_cross_encoding(
|
||||
model: Union[Type[object], object],
|
||||
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
|
||||
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, SupportsCrossEncoding)
|
||||
|
||||
return isinstance(model, SupportsCrossEncoding)
|
||||
|
||||
|
||||
def supports_cross_encoding(
|
||||
model: Union[Type[object], object],
|
||||
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
|
||||
return is_embedding_model(model) and _supports_cross_encoding(model)
|
||||
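For illustration, a model architecture opts in to cross encoding by inheriting the marker protocol defined above. This is a sketch only: `MyCrossEncoder` is hypothetical, and the full `supports_cross_encoding()` check also requires the class to satisfy the embedding-model interface, as `BertForSequenceClassification` does later in this diff.

```python
import torch.nn as nn

# Hypothetical model class: inheriting SupportsCrossEncoding sets the
# supports_cross_encoding ClassVar, which the registry inspects.
class MyCrossEncoder(nn.Module, SupportsCrossEncoding):
    ...
```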
|
@ -21,7 +21,8 @@ from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .interfaces import (has_inner_state, is_attention_free,
|
||||
supports_multimodal, supports_pp)
|
||||
supports_cross_encoding, supports_multimodal,
|
||||
supports_pp)
|
||||
from .interfaces_base import is_embedding_model, is_text_generation_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -100,6 +101,7 @@ _EMBEDDING_MODELS = {
|
||||
# [Text-only]
|
||||
"BertModel": ("bert", "BertEmbeddingModel"),
|
||||
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
|
||||
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
|
||||
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
|
||||
@ -121,6 +123,14 @@ _EMBEDDING_MODELS = {
|
||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501,
|
||||
}
|
||||
|
||||
_CROSS_ENCODER_MODELS = {
|
||||
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
|
||||
"RobertaForSequenceClassification": ("roberta",
|
||||
"RobertaForSequenceClassification"),
|
||||
"XLMRobertaForSequenceClassification": ("roberta",
|
||||
"RobertaForSequenceClassification"),
|
||||
}
|
||||
|
||||
_MULTIMODAL_MODELS = {
|
||||
# [Decoder-only]
|
||||
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
|
||||
@ -159,6 +169,7 @@ _SPECULATIVE_DECODING_MODELS = {
|
||||
_VLLM_MODELS = {
|
||||
**_TEXT_GENERATION_MODELS,
|
||||
**_EMBEDDING_MODELS,
|
||||
**_CROSS_ENCODER_MODELS,
|
||||
**_MULTIMODAL_MODELS,
|
||||
**_SPECULATIVE_DECODING_MODELS,
|
||||
}
|
||||
@ -193,6 +204,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
|
||||
class _ModelInfo:
|
||||
is_text_generation_model: bool
|
||||
is_embedding_model: bool
|
||||
supports_cross_encoding: bool
|
||||
supports_multimodal: bool
|
||||
supports_pp: bool
|
||||
has_inner_state: bool
|
||||
@ -203,6 +215,7 @@ class _ModelInfo:
|
||||
return _ModelInfo(
|
||||
is_text_generation_model=is_text_generation_model(model),
|
||||
is_embedding_model=is_embedding_model(model),
|
||||
supports_cross_encoding=supports_cross_encoding(model),
|
||||
supports_multimodal=supports_multimodal(model),
|
||||
supports_pp=supports_pp(model),
|
||||
has_inner_state=has_inner_state(model),
|
||||
@ -415,6 +428,12 @@ class _ModelRegistry:
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).is_embedding_model
|
||||
|
||||
def is_cross_encoder_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).supports_cross_encoding
|
||||
|
||||
def is_multimodal_model(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
|
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
@@ -6,10 +6,17 @@ from transformers import RobertaConfig

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import CrossEncodingPooler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.transformers_utils.config import (
    get_cross_encoder_activation_function)


class RobertaEmbedding(nn.Module):
@@ -39,34 +46,93 @@ class RobertaEmbedding(nn.Module):
    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        seq_lens: torch.Tensor,
        position_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        input_shape = input_ids.size()

        # Input embeddings.
        inputs_embeds = self.word_embeddings(input_ids)

        # TODO: figure out if there is a better way
        # to make to make position ids start at padding_idx + 1
        # Replace position ids because in RoBERTa models
        # they have to start at padding_idx + 1 and ignore
        # existing padding tokens
        # References:
        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
        position_ids += self.padding_idx + 1
        pos_list = []
        token_list = []
        offset = 0
        for seq_len in seq_lens:
            pos_list.append(position_ids[offset:offset + seq_len])
            token_list.append(input_ids[offset:offset + seq_len])
            offset += seq_len

        new_pos_list = []
        for positions, tokens in zip(pos_list, token_list):
            # Verify assumption that incoming position are
            # always a sequence from 0 to N.
            expected_pos = torch.arange(positions.size()[0],
                                        dtype=torch.long,
                                        device=inputs_embeds.device)
            assert torch.equal(positions, expected_pos)
            new_pos_list.append(
                create_position_ids_from_input_ids(tokens, self.padding_idx))
        position_ids = torch.cat(new_pos_list)

        # Position embeddings.
        position_embeddings = self.position_embeddings(position_ids)

        # Token type embeddings. (TODO: move off hotpath?)
        token_type_embeddings = self.token_type_embeddings(
            torch.zeros(input_shape,
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape,
                                         dtype=torch.long,
                        device=inputs_embeds.device))
                                         device=inputs_embeds.device)

        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings


# Adapted from transformers
def create_position_ids_from_input_ids(input_ids,
                                       padding_idx,
                                       past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers.
    Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        x: torch.Tensor x:

    Returns: torch.Tensor
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA.
    mask = input_ids.ne(padding_idx).int()

    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
                           past_key_values_length) * mask

    return incremental_indices.long() + padding_idx
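To make the padding-aware renumbering concrete, here is a small worked example of the helper above on a single flattened sequence; the function body is restated verbatim, and the token ids are illustrative:

```python
import torch

def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    # Same logic as above: non-padding tokens are numbered consecutively
    # starting at padding_idx + 1, while padding tokens keep padding_idx.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
                           past_key_values_length) * mask
    return incremental_indices.long() + padding_idx

# RoBERTa uses padding_idx = 1, so real tokens get positions 2, 3, 4, ...
tokens = torch.tensor([0, 31414, 232, 2, 1, 1])  # <s> ... </s> <pad> <pad>
print(create_position_ids_from_input_ids(tokens, padding_idx=1))
# tensor([2, 3, 4, 5, 1, 1])
```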


# Adapted from transformers
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config: RobertaConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features[0, :]  # take <s> token (equiv. to [CLS])
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.out_proj(x)
        return x


class RobertaEmbeddingModel(BertEmbeddingModel):
    """A model that uses Roberta to provide embedding functionalities.

@@ -85,6 +151,62 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
                         prefix=prefix,
                         embedding_class=RobertaEmbedding)


class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """A model that uses Roberta to provide embedding functionalities.

    This class encapsulates the BertModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        roberta: An instance of BertModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

        self.default_activation_function = \
            get_cross_encoder_activation_function(config)

        self.num_labels = config.num_labels
        self.roberta = BertModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "bert"),
                                 embedding_class=RobertaEmbedding,
                                 add_pooling_layer=False)
        self.classifier = RobertaClassificationHead(config)
        self._pooler = CrossEncodingPooler(config, self.classifier)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

        self_weights = []

        def weight_filter():
            for name, weight in weights:
                if name.startswith("roberta."):
                    yield (name[len("roberta."):], weight)
                else:
                    self_weights.append((name, weight))

        self.roberta.load_weights(weight_filter())

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in self_weights:
            if name.startswith("classifier"):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
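`load_weights` above splits the checkpoint by name prefix: backbone tensors are streamed into `self.roberta` with the `roberta.` prefix stripped, while classifier tensors are set aside and loaded locally with `default_weight_loader`. A standalone sketch of that generator-splitting pattern (the checkpoint entries below are made up for illustration):

```python
from typing import Iterable, Tuple

import torch

def split_by_prefix(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
    """Yield (name-without-prefix, tensor) for matching entries and collect
    the rest in a list, mirroring weight_filter()/self_weights above."""
    rest = []

    def filtered():
        for name, tensor in weights:
            if name.startswith(prefix):
                yield name[len(prefix):], tensor
            else:
                rest.append((name, tensor))

    return filtered(), rest

# Made-up checkpoint entries, not real model weights.
ckpt = [
    ("roberta.embeddings.word_embeddings.weight", torch.zeros(4, 2)),
    ("classifier.out_proj.weight", torch.zeros(1, 2)),
]
backbone, leftovers = split_by_prefix(ckpt, "roberta.")
print([name for name, _ in backbone])   # ['embeddings.word_embeddings.weight']
# Note: `leftovers` is only populated once the generator has been consumed,
# which is also the order load_weights() relies on above.
print([name for name, _ in leftovers])  # ['classifier.out_proj.weight']
```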

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
@@ -93,25 +215,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        # Verify assumption that position are always a sequence from
        # 0 to N. (Actually here we just check 0 and N to simplify).
        # This is important to fix the position which are assumed to
        # start from padding_idx + 1 instead of 0 in the Roberta models.
        assert hasattr(attn_metadata, "seq_lens_tensor")
        cumulative = attn_metadata.seq_lens_tensor.cumsum(dim=0)
        start_pos = torch.cat(
            (torch.tensor([0], device=attn_metadata.seq_lens_tensor.device),
             cumulative[:-1]))
        assert len(torch.nonzero(positions[start_pos])) == 0
        end_pos = cumulative - 1
        last_tokens = attn_metadata.seq_lens_tensor - 1
        assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0

        return super().forward(input_ids=input_ids,
                               positions=positions,
        return self.roberta(input_ids=input_ids,
                            position_ids=positions,
                            kv_caches=kv_caches,
                            attn_metadata=attn_metadata,
                            inputs_embeds=inputs_embeds,
                            intermediate_tensors=intermediate_tensors,
                               inputs_embeds=inputs_embeds)
                            attn_metadata=attn_metadata,
                            token_type_ids=token_type_ids)
@@ -6,7 +6,7 @@ import numpy as np
import torch
import torch.types
from PIL.Image import Image
from typing_extensions import TypeAlias
from typing_extensions import NotRequired, TypeAlias

from vllm.utils import JSONTree, is_list_of, json_map_leaves

@@ -208,6 +208,9 @@ class MultiModalInputsV2(TypedDict):
    prompt_token_ids: List[int]
    """The processed token IDs which includes placeholder tokens."""

    token_type_ids: NotRequired[List[int]]
    """The token type IDs of the prompt."""

    mm_kwargs: MultiModalKwargs
    """Keyword arguments to be directly passed to the model after batching."""

@@ -60,7 +60,6 @@ class EmbeddingOutput:
        embedding: The embedding vector, which is a list of floats. The
        length of vector depends on the model as listed in the embedding guide.
    """

    embedding: List[float]

    def __repr__(self) -> str:
@@ -363,6 +362,50 @@ class EmbeddingRequestOutput:
                f"finished={self.finished})")


@dataclass
class ScoreOutput:
    """The output data of one completion output of a request.

    Args:
        score: The score, which is a list of floats.
        index: The correspondent text index of the score.
    """
    index: int
    score: List[float]

    def __repr__(self) -> str:
        return (f"ScoreOutput("
                f"score={self.score}), "
                f"index={self.index})")


class ScoreRequestOutput:
    """
    The output data of an score request to the LLM.

    Args:
        request_id (str): A unique identifier for the score request.
        outputs (score): The embedding results for the given input.
    """

    def __init__(self, request_id: str, outputs: "ScoreOutput"):
        self.request_id = request_id
        self.outputs = outputs

    def __repr__(self):
        """
        Returns a string representation of an ScoreRequestOutput instance.

        The representation includes the request_id and the number of outputs,
        providing a quick overview of the embedding request's results.

        Returns:
            str: A string representation of the ScoreRequestOutput instance.
        """
        return (f"ScoreRequestOutput(request_id='{self.request_id}', "
                f"outputs={repr(self.outputs)}")
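These two classes are what the new scoring path ultimately returns per request: one `ScoreOutput` per compared pair, wrapped in a `ScoreRequestOutput`. A purely illustrative construction (the request id and score value below are made up):

```python
# Illustrative only: constructing the output objects defined above by hand.
out = ScoreOutput(index=0, score=[0.9876])
req = ScoreRequestOutput(request_id="score-request-id", outputs=out)
print(out.index, out.score)  # 0 [0.9876]
print(req)                   # repr showing the request_id and the nested ScoreOutput
```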


class RequestOutputFactory:

    @staticmethod
@@ -449,6 +449,10 @@ class Sequence:
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        return self.inputs.prompt_embeds

    @property
    def token_type_ids(self) -> List[int]:
        return self.inputs.token_type_ids

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        return self.inputs.multi_modal_data
@@ -687,6 +691,10 @@ class SequenceGroup:
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

    @property
    def token_type_ids(self) -> Optional[List[int]]:
        return self.first_seq.token_type_ids

    @property
    def multi_modal_data(self) -> MultiModalDataDict:
        return self.first_seq.multi_modal_data
@@ -909,6 +917,7 @@ class SequenceGroupMetadata(
        default_factory=lambda: SequenceGroupState())
    # "MultiModalDataDict" types. We have to use Any due to msgspec
    # doesn't allow to have union of 2 different dicts.
    token_type_ids: Optional[List[int]] = None
    multi_modal_data: Optional[Any] = None
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
@@ -9,6 +9,7 @@ from huggingface_hub import (file_exists, hf_hub_download,
from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError,
                                   RepositoryNotFoundError,
                                   RevisionNotFoundError)
from torch import nn
from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import (
    get_image_processor_config)
@@ -31,6 +32,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                             UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname

if VLLM_USE_MODELSCOPE:
    from modelscope import AutoConfig
@@ -577,3 +579,16 @@ def try_get_generation_config(
        return GenerationConfig.from_model_config(config)
    except OSError:  # Not found
        return None


def get_cross_encoder_activation_function(config: PretrainedConfig):
    if (hasattr(config, "sbert_ce_default_activation_function")
            and config.sbert_ce_default_activation_function is not None):

        function_name = config.sbert_ce_default_activation_function
        assert function_name.startswith("torch.nn.modules."), \
            "Loading of activation functions is restricted to " \
            "torch.nn.modules for security reasons"
        return resolve_obj_by_qualname(function_name)()
    else:
        return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()
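This helper follows the sentence-transformers convention: a cross-encoder config may carry an `sbert_ce_default_activation_function` entry naming a class under `torch.nn.modules`, which is instantiated; otherwise `Sigmoid` is used for single-label models and `Identity` otherwise. A hedged re-implementation of the same selection rule, with a `SimpleNamespace` standing in for the HF config object:

```python
from types import SimpleNamespace

import torch.nn as nn

def resolve_activation(config):
    # Sketch of the selection rule above (illustration only).
    fn_name = getattr(config, "sbert_ce_default_activation_function", None)
    if fn_name is not None:
        assert fn_name.startswith("torch.nn.modules."), \
            "restricted to torch.nn.modules for security reasons"
        module_path, _, class_name = fn_name.rpartition(".")
        module = __import__(module_path, fromlist=[class_name])
        return getattr(module, class_name)()
    return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()

# e.g. a config exported by sentence-transformers might carry:
cfg = SimpleNamespace(
    sbert_ce_default_activation_function="torch.nn.modules.activation.Sigmoid",
    num_labels=1,
)
print(resolve_activation(cfg))                            # Sigmoid()
print(resolve_activation(SimpleNamespace(num_labels=3)))  # Identity()
```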
@@ -50,6 +50,9 @@ class CPUEmbeddingModelRunner(
        ]

        model_executable = self.model
        cross_enc_kwargs = {}
        if model_input.token_type_ids is not None:
            cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids
        execute_model_kwargs = {
            "input_ids":
            model_input.input_tokens,
@@ -61,6 +64,7 @@ class CPUEmbeddingModelRunner(
            model_input.attn_metadata,
            **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
                                         device=self.device),
            **cross_enc_kwargs,
            "intermediate_tensors":
            intermediate_tensors,
        }
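The CPU and GPU runners thread the new field through with the same pattern: `token_type_ids` is only added to the model call when the batch actually carries it, so models that do not accept the keyword never receive it. A minimal sketch of that optional-kwarg pattern (the model and inputs below are placeholders):

```python
import torch

def run_model(model, input_tokens, token_type_ids=None):
    # Only cross-encoder batches populate token_type_ids, so other models
    # never see an unexpected keyword argument.
    cross_enc_kwargs = {}
    if token_type_ids is not None:
        cross_enc_kwargs["token_type_ids"] = token_type_ids
    return model(input_ids=input_tokens, **cross_enc_kwargs)

def fake_model(input_ids, token_type_ids=None):  # placeholder model
    return {"input_ids": input_ids, "token_type_ids": token_type_ids}

print(run_model(fake_model, torch.tensor([101, 2054, 102])))
print(run_model(fake_model, torch.tensor([101, 2054, 102]),
                token_type_ids=torch.tensor([0, 0, 1])))
```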
@@ -43,6 +43,7 @@ class ModelInputForCPU(ModelRunnerInputBase):
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    token_type_ids: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
@@ -54,6 +55,7 @@ class ModelInputForCPU(ModelRunnerInputBase):
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "token_type_ids": self.token_type_ids,
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
@@ -83,6 +85,7 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "token_type_ids": self.token_type_ids,
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
@@ -112,6 +115,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
            self.input_tokens: List[int] = []
            self.input_positions: Optional[
                List[int]] = [] if not self.use_mrope else None
            self.token_type_ids: Optional[List[int]] = []
            self.seq_lens: List[int] = []
            self.query_lens: List[int] = []
            self.prefill_block_tables: List[List[int]] = []
@@ -165,6 +169,10 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
            if not input_data.use_mrope else input_data.input_mrope_positions,
            dtype=torch.long,
            device="cpu")
        token_type_ids = torch.tensor(input_data.token_type_ids,
                                      dtype=torch.long,
                                      device="cpu") \
            if input_data.token_type_ids else None

        # For multi-modal models
        multi_modal_kwargs = None
@@ -178,6 +186,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
        return self.model_input_cls(
            input_tokens=input_tokens,
            input_positions=input_positions,
            token_type_ids=token_type_ids,
            seq_lens=input_data.seq_lens,
            query_lens=input_data.query_lens,
            attn_metadata=attn_metadata,
@@ -285,6 +294,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
        tokens = seq_data.get_token_ids()
        tokens = tokens[context_len:seq_len]
        token_positions = range(context_len, seq_len)
        token_types = seq_group_metadata.token_type_ids

        # For encoder-only models, the block_table is None,
        # and there is no need to initialize the slot_mapping.
@@ -301,6 +311,9 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
        if data.input_positions is not None:
            data.input_positions.extend(token_positions)

        if data.token_type_ids is not None:
            data.token_type_ids.extend(token_types if token_types else [])

        # Update fields
        data.input_tokens.extend(tokens)
        data.num_prefills += 1
@@ -97,6 +97,10 @@ class EmbeddingModelRunner(
            model_forward_end = torch.cuda.Event(enable_timing=True)
            model_forward_start.record()

        cross_enc_kwargs = {}
        if model_input.token_types is not None:
            cross_enc_kwargs["token_type_ids"] = model_input.token_types

        with set_forward_context(model_input.attn_metadata, self.vllm_config):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
@@ -105,7 +109,8 @@ class EmbeddingModelRunner(
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self.device))
                                             device=self.device),
                **cross_enc_kwargs)

        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
@@ -92,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    token_types: Optional[torch.Tensor] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
@@ -200,6 +201,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
        def simple_reinit(self):
            self.input_tokens[0].clear()  # type: ignore
            self.input_positions[0].clear()  # type: ignore
            self.token_types[0].clear()  # type: ignore
            self.mrope_input_positions = None  # type: ignore
            self.seq_lens[0] = 0  # type: ignore
            self.orig_seq_lens[0] = 0  # type: ignore
@@ -226,6 +228,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
            # Input tokens and positions.
            input_tokens: Optional[List[List[int]]] = None,
            input_positions: Optional[List[List[int]]] = None,
            token_types: Optional[List[List[int]]] = None,
            mrope_input_positions: Optional[List[List[List[int]]]] = None,

            # The sequence length (may be capped to the sliding window).
@@ -291,6 +294,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                    for seq_id in range(len(self.seq_ids)):
                        self.input_positions[seq_id].clear()

                if token_types:
                    self.token_types = token_types
                else:
                    for seq_id in range(len(self.seq_ids)):
                        self.token_types[seq_id].clear()

                self.mrope_input_positions = None

                if seq_lens:
@@ -354,6 +363,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
            else:
                self.input_tokens = input_tokens or []
                self.input_positions = input_positions or []
                self.token_types = token_types or []
                self.mrope_input_positions = mrope_input_positions or None
                self.seq_lens = seq_lens or []
                self.orig_seq_lens = orig_seq_lens or []
@@ -386,6 +396,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):

            self.input_tokens = [[] for _ in range(self.n_seqs)]
            self.input_positions = [[] for _ in range(self.n_seqs)]
            self.token_types = [[] for _ in range(self.n_seqs)]
            self.mrope_input_positions = None
            self.seq_lens = [0] * self.n_seqs
            self.orig_seq_lens = [0] * self.n_seqs
@@ -498,12 +509,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):

        # Compute tokens.
        tokens = seq_data.get_token_ids()[context_len:seq_len]
        token_types = seq_group_metadata.token_type_ids

        inter_data.seq_lens[seq_idx] = seq_len
        inter_data.orig_seq_lens[seq_idx] = seq_len
        inter_data.context_lens[seq_idx] = context_len
        inter_data.input_tokens[seq_idx].extend(tokens)
        inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
        inter_data.token_types[seq_idx].extend(
            token_types if token_types else [])
        inter_data.query_lens[seq_idx] = seq_len - context_len

        if seq_data.mrope_position_delta is not None:
@@ -561,6 +575,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                seq_idx][uncomputed_start:]
            inter_data.input_positions[seq_idx] = inter_data.input_positions[
                seq_idx][uncomputed_start:]
            inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
                uncomputed_start:]
            context_len = prefix_cache_len

            inter_data.context_lens[seq_idx] = context_len
@@ -575,6 +591,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                seq_idx][-1:]
            inter_data.input_positions[seq_idx] = inter_data.input_positions[
                seq_idx][-1:]
            inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
                -1:]
            inter_data.query_lens[seq_idx] = 1
            inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1

@@ -803,9 +821,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
        """
        # Combine and flatten intermediate data.
        input_tokens = []
        token_types = []
        for inter_data in self.inter_data_list:
            for cur_input_tokens in inter_data.input_tokens:
                input_tokens.extend(cur_input_tokens)
            for cur_token_types in inter_data.token_types:
                token_types.extend(cur_token_types)

        if not input_tokens:
            # This may happen when all prefill requests hit
@@ -874,6 +895,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
        input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
                                               self.runner.device,
                                               self.runner.pin_memory)

        token_types_tensor = async_tensor_h2d(token_types, torch.long,
                                              self.runner.device,
                                              self.runner.pin_memory) \
            if token_types else None

        if mrope_input_positions is not None:
            for idx in range(3):
                mrope_input_positions[idx].extend(
@@ -952,6 +979,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
        return self.model_input_cls(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
            token_types=token_types_tensor,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,