[Frontend] support matryoshka representation / support embedding API dimensions (#16331)
parent e92d7085bf → commit fbf722c6e6
examples/offline_inference/embed_matryoshka_fy.py (new file, 48 lines)
@@ -0,0 +1,48 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from argparse import Namespace
+
+from vllm import LLM, EngineArgs, PoolingParams
+from vllm.utils import FlexibleArgumentParser
+
+
+def main(args: Namespace):
+    # Sample prompts.
+    prompts = [
+        "Follow the white rabbit.",  # English
+        "Sigue al conejo blanco.",  # Spanish
+        "Suis le lapin blanc.",  # French
+        "跟着白兔走。",  # Chinese
+        "اتبع الأرنب الأبيض.",  # Arabic
+        "Folge dem weißen Kaninchen.",  # German
+    ]
+
+    # Create an LLM.
+    # You should pass task="embed" for embedding models
+    model = LLM(**vars(args))
+
+    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
+    outputs = model.embed(prompts, pooling_params=PoolingParams(dimensions=32))
+
+    # Print the outputs.
+    print("\nGenerated Outputs:")
+    print("-" * 60)
+    for prompt, output in zip(prompts, outputs):
+        embeds = output.outputs.embedding
+        embeds_trimmed = ((str(embeds[:16])[:-1] +
+                           ", ...]") if len(embeds) > 16 else embeds)
+        print(f"Prompt: {prompt!r} \n"
+              f"Embeddings: {embeds_trimmed} "
+              f"(size={len(embeds)})")
+        print("-" * 60)
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser()
+    parser = EngineArgs.add_cli_args(parser)
+    # Set example specific arguments
+    parser.set_defaults(model="jinaai/jina-embeddings-v3",
+                        task="embed",
+                        trust_remote_code=True)
+    args = parser.parse_args()
+    main(args)
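Note on what dimensions=32 does above: for a matryoshka-trained model, the truncated embedding should be close to the first 32 entries of the full embedding after re-normalization. A minimal sketch of that relationship (same model as the example; the tolerance is illustrative, not part of this commit):

import torch
import torch.nn.functional as F

from vllm import LLM, PoolingParams

model = LLM(model="jinaai/jina-embeddings-v3", task="embed",
            trust_remote_code=True)
prompt = "Follow the white rabbit."

full = model.embed([prompt])[0].outputs.embedding
small = model.embed(
    [prompt], pooling_params=PoolingParams(dimensions=32))[0].outputs.embedding

# Matryoshka property: truncating the full vector and re-normalizing
# approximately reproduces the dimensions=32 output.
truncated = F.normalize(torch.tensor(full[:32]), p=2, dim=-1)
print(torch.allclose(truncated, torch.tensor(small), atol=1e-2))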
@@ -960,19 +960,19 @@ class VllmRunner:
         req_outputs = self.model.classify(prompts)
         return [req_output.outputs.probs for req_output in req_outputs]

-    def encode(
-        self,
-        prompts: list[str],
-        images: Optional[PromptImageInput] = None,
-        videos: Optional[PromptVideoInput] = None,
-        audios: Optional[PromptAudioInput] = None,
-    ) -> list[list[float]]:
+    def encode(self,
+               prompts: list[str],
+               images: Optional[PromptImageInput] = None,
+               videos: Optional[PromptVideoInput] = None,
+               audios: Optional[PromptAudioInput] = None,
+               *args,
+               **kwargs) -> list[list[float]]:
         inputs = self.get_inputs(prompts,
                                  images=images,
                                  videos=videos,
                                  audios=audios)

-        req_outputs = self.model.embed(inputs)
+        req_outputs = self.model.embed(inputs, *args, **kwargs)
         return [req_output.outputs.embedding for req_output in req_outputs]

     def score(
tests/entrypoints/openai/test_embedding_dimensions.py (new file, 82 lines)
@@ -0,0 +1,82 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`.
+"""
+
+from typing import NamedTuple
+
+import openai
+import pytest
+
+from vllm.entrypoints.openai.protocol import EmbeddingResponse
+
+from ...utils import RemoteOpenAIServer
+
+
+class ModelInfo(NamedTuple):
+    name: str
+    is_matryoshka: bool
+
+
+MODELS = [
+    ModelInfo(name="BAAI/bge-m3", is_matryoshka=False),
+    ModelInfo(name="jinaai/jina-embeddings-v3", is_matryoshka=True),
+]
+
+input_texts = [
+    "The chef prepared a delicious meal.",
+] * 3
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model", MODELS)
+async def test_validating_dimensions(model: ModelInfo):
+    args = [
+        "--task",
+        "embed",
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--enforce-eager",
+        "--max-model-len",
+        "512",
+        "--trust_remote_code"
+    ]
+    with RemoteOpenAIServer(model.name, args) as remote_server:
+        client = remote_server.get_async_client()
+
+        async def make_request(dimensions):
+            embedding_response = await client.embeddings.create(
+                model=model.name,
+                input=input_texts,
+                dimensions=dimensions,
+                encoding_format="float",
+            )
+            embeddings = EmbeddingResponse.model_validate(
+                embedding_response.model_dump(mode="json"))
+
+            assert embeddings.id is not None
+            assert len(embeddings.data) == 3
+            assert len(embeddings.data[0].embedding) > 0
+            assert embeddings.usage.completion_tokens == 0
+            assert embeddings.usage.prompt_tokens > 0
+            assert embeddings.usage.total_tokens > 0
+
+            if dimensions is not None:
+                assert len(embeddings.data[0].embedding) == dimensions
+
+        if model.is_matryoshka:
+            for dimensions in [None, 16]:
+                await make_request(dimensions)
+
+            with pytest.raises(openai.BadRequestError):
+                for dimensions in [-1]:
+                    await make_request(dimensions)
+
+        else:
+            for dimensions in [None]:
+                await make_request(dimensions)
+
+            with pytest.raises(openai.BadRequestError):
+                for dimensions in [-1, 16]:
+                    await make_request(dimensions)
@@ -8,7 +8,8 @@ import math

 import pytest

-from tests.models.embedding.utils import check_embeddings_close
+from tests.models.embedding.utils import check_embeddings_close, matryoshka_fy
+from vllm import PoolingParams

 SCORING_MODELS = [
     "jinaai/jina-reranker-v2-base-multilingual",  # Roberta
@@ -126,3 +127,40 @@ def test_embeddings(
         name_1="vllm",
         tol=1e-2,
     )
+
+
+@pytest.mark.parametrize("model", EMBEDDING_MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dimensions", [16, 32])
+def test_matryoshka(
+    hf_runner,
+    vllm_runner,
+    model,
+    dtype: str,
+    dimensions: int,
+    monkeypatch,
+) -> None:
+
+    example_prompts = EMBEDDING_PROMPTS
+
+    with hf_runner(
+            model,
+            dtype=dtype,
+            is_sentence_transformer=True,
+    ) as hf_model:
+        hf_outputs = hf_model.encode(example_prompts, task="text-matching")
+        hf_outputs = matryoshka_fy(hf_outputs, dimensions)
+
+    with vllm_runner(model, task="embed", dtype=dtype,
+                     max_model_len=None) as vllm_model:
+        vllm_outputs = vllm_model.encode(
+            example_prompts,
+            pooling_params=PoolingParams(dimensions=dimensions))
+
+    check_embeddings_close(
+        embeddings_0_lst=hf_outputs,
+        embeddings_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+        tol=1e-2,
+    )
@@ -30,3 +30,10 @@ def check_embeddings_close(
                 f"\n{name_1}:\t{embeddings_1[:16]!r}")

         assert sim >= 1 - tol, fail_msg
+
+
+def matryoshka_fy(tensor, dimensions):
+    tensor = torch.tensor(tensor)
+    tensor = tensor[..., :dimensions]
+    tensor = F.normalize(tensor, p=2, dim=1)
+    return tensor
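The matryoshka_fy helper mirrors what the server-side pooler does: truncate to the requested size, then re-normalize each row (for 2-D batches, dim=1 and the pooler's dim=-1 coincide). A quick standalone check of the helper, with illustrative inputs:

import torch
import torch.nn.functional as F


def matryoshka_fy(tensor, dimensions):
    tensor = torch.tensor(tensor)
    tensor = tensor[..., :dimensions]
    tensor = F.normalize(tensor, p=2, dim=1)
    return tensor


batch = [[3.0, 4.0, 12.0], [1.0, 0.0, 1.0]]
print(matryoshka_fy(batch, 2))
# tensor([[0.6000, 0.8000],
#         [1.0000, 0.0000]])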
@@ -583,6 +583,15 @@ class ModelConfig:
             if getattr(user_config, k) is None:
                 setattr(user_config, k, v)

+            if self.is_matryoshka:
+                if user_config.normalize is None:
+                    user_config.normalize = True
+                elif not user_config.normalize:
+                    raise ValueError(
+                        "`normalize` must be enabled (set to True) "
+                        "for models that are compatible with "
+                        "Matryoshka Representation.")
+
             return user_config

         return None
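The normalize requirement exists because truncation breaks unit norm: without re-normalization, shortened embeddings are not comparable under cosine or dot-product similarity. A small numeric illustration (not part of this commit):

import torch
import torch.nn.functional as F

v = F.normalize(torch.randn(256), p=2, dim=-1)
print(v.norm().item())                                  # 1.0
print(v[:32].norm().item())                             # typically < 1.0
print(F.normalize(v[:32], p=2, dim=-1).norm().item())   # 1.0 again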
@@ -921,6 +921,11 @@ class LLM:
         if pooling_params is None:
             # Use default pooling params.
             pooling_params = PoolingParams()
+        elif isinstance(pooling_params, PoolingParams):
+            pooling_params.verify(self.llm_engine.model_config)
+        else:
+            for pooling_param in pooling_params:
+                pooling_param.verify(self.llm_engine.model_config)

         self._validate_and_add_requests(
             prompts=parsed_prompts,
@@ -939,6 +944,8 @@ class LLM:
         /,
         *,
         use_tqdm: bool = True,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> list[EmbeddingRequestOutput]:
@@ -953,6 +960,8 @@ class LLM:
             prompts: The prompts to the LLM. You may pass a sequence of prompts
                 for batch inference. See :class:`~vllm.inputs.PromptType`
                 for more details about the format of each prompts.
+            pooling_params: The pooling parameters for pooling. If None, we
+                use the default pooling parameters.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
             prompt_adapter_request: Prompt Adapter request to use for
@@ -968,6 +977,7 @@ class LLM:

         items = self.encode(prompts,
                             use_tqdm=use_tqdm,
+                            pooling_params=pooling_params,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)

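Since pooling_params now also accepts a Sequence[PoolingParams], each prompt can in principle request its own output size. A hedged sketch of that call shape (whether per-prompt params are honored end to end depends on _validate_and_add_requests, which this hunk does not show):

from vllm import LLM, PoolingParams

llm = LLM(model="jinaai/jina-embeddings-v3", task="embed",
          trust_remote_code=True)
outputs = llm.embed(
    ["short query", "long document"],
    pooling_params=[PoolingParams(dimensions=16),
                    PoolingParams(dimensions=64)],
)
print([len(o.outputs.embedding) for o in outputs])  # expected: [16, 64]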
@@ -1006,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
     # doc: end-embedding-extra-params

     def to_pooling_params(self):
-        return PoolingParams(additional_data=self.additional_data)
+        return PoolingParams(dimensions=self.dimensions,
+                             additional_data=self.additional_data)


 class EmbeddingChatRequest(OpenAIBaseModel):
@@ -1068,7 +1069,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
         return data

     def to_pooling_params(self):
-        return PoolingParams(additional_data=self.additional_data)
+        return PoolingParams(dimensions=self.dimensions,
+                             additional_data=self.additional_data)


 EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
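With `dimensions` wired into to_pooling_params(), the OpenAI-compatible endpoint accepts the standard dimensions field of the embeddings API. A client-side sketch, assuming a vLLM server already running on localhost:8000:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.embeddings.create(
    model="jinaai/jina-embeddings-v3",
    input="Follow the white rabbit.",
    dimensions=32,  # forwarded as PoolingParams(dimensions=32)
    encoding_format="float",
)
print(len(resp.data[0].embedding))  # 32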
@@ -80,9 +80,6 @@ class OpenAIServingEmbedding(OpenAIServing):
             return error_check_ret

         encoding_format = request.encoding_format
-        if request.dimensions is not None:
-            return self.create_error_response(
-                "dimensions is currently not supported")

         model_name = self._get_model_name(request.model)
         request_id = f"embd-{self._base_request_id(raw_request)}"
@@ -99,6 +96,13 @@ class OpenAIServingEmbedding(OpenAIServing):
                 "greater than max_model_len."
                 " Please, select a smaller truncation size.")

+        pooling_params = request.to_pooling_params()
+
+        try:
+            pooling_params.verify(self.model_config)
+        except ValueError as e:
+            return self.create_error_response(str(e))
+
         try:
             (
                 lora_request,
@@ -146,8 +150,6 @@ class OpenAIServingEmbedding(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
         try:
-            pooling_params = request.to_pooling_params()
-
             for i, engine_prompt in enumerate(engine_prompts):
                 request_id_item = f"{request_id}-{i}"
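The blanket "dimensions is currently not supported" rejection is replaced by an early verify(), so invalid requests now fail with a 400 before any prompt is scheduled. A sketch of the failure path from the client side (server address illustrative; the model is the non-matryoshka one from the test above):

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

try:
    # BAAI/bge-m3 is not matryoshka-capable, so requesting dimensions fails.
    client.embeddings.create(model="BAAI/bge-m3",
                             input="hello",
                             dimensions=16)
except openai.BadRequestError as e:
    print(e)  # error text says the model does not support matryoshka representation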
@@ -97,7 +97,7 @@ class SimplePooler(nn.Module):
         pooling_metadata: PoolingMetadata,
     ) -> PoolerOutput:
         pooled_data = self.extract_states(hidden_states, pooling_metadata)
-        pooled_data = self.head(pooled_data)
+        pooled_data = self.head(pooled_data, pooling_metadata)
         pooled_outputs = [self.build_output(data) for data in pooled_data]
         return PoolerOutput(outputs=pooled_outputs)

@@ -217,14 +217,28 @@ class PoolerHead(nn.Module):
         self.normalize = normalize
         self.softmax = softmax

-    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor]):
+    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
+                pooling_metadata: PoolingMetadata):
+
+        dimensions_list = [
+            pooling_param.dimensions
+            for _, pooling_param in pooling_metadata.seq_groups
+        ]
+        if any(d is not None for d in dimensions_list):
+            # change the output dimension
+            assert len(pooled_data) == len(dimensions_list)
+            pooled_data = [
+                vecs if d is None else vecs[..., :d]
+                for vecs, d in zip(pooled_data, dimensions_list)
+            ]
+
         if self.normalize:
             if isinstance(pooled_data, list):
                 pooled_data = [
-                    F.normalize(data, p=2, dim=1) for data in pooled_data
+                    F.normalize(data, p=2, dim=-1) for data in pooled_data
                 ]
             else:
-                pooled_data = F.normalize(pooled_data, p=2, dim=1)
+                pooled_data = F.normalize(pooled_data, p=2, dim=-1)

         if self.softmax:
             if isinstance(pooled_data, list):
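The per-request slicing in PoolerHead.forward, isolated as a standalone sketch (shapes illustrative): only requests that set dimensions get truncated, and normalization then runs along the last axis:

import torch
import torch.nn.functional as F

pooled_data = [torch.randn(8), torch.randn(8)]
dimensions_list = [4, None]  # first request wants 4 dims, second the full size

pooled_data = [
    vecs if d is None else vecs[..., :d]
    for vecs, d in zip(pooled_data, dimensions_list)
]
pooled_data = [F.normalize(data, p=2, dim=-1) for data in pooled_data]
print([tuple(t.shape) for t in pooled_data])  # [(4,), (8,)]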
@@ -1,9 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import msgspec

+if TYPE_CHECKING:
+    from vllm.config import ModelConfig
+

 class PoolingParams(
     msgspec.Struct,
@@ -12,14 +15,30 @@ class PoolingParams(
     """API parameters for pooling models. This is currently a placeholder.

     Attributes:
+        dimensions: Reduce the dimensions of embeddings
+            if the model supports matryoshka representation.
         additional_data: Any additional data needed for pooling.
     """

+    dimensions: Optional[int] = None
     additional_data: Optional[Any] = None

     def clone(self) -> "PoolingParams":
         """Returns a deep copy of the PoolingParams instance."""
-        return PoolingParams(additional_data=self.additional_data)
+        return PoolingParams(dimensions=self.dimensions,
+                             additional_data=self.additional_data)
+
+    def verify(self, model_config: "ModelConfig") -> None:
+        if self.dimensions is not None:
+            if not model_config.is_matryoshka:
+                raise ValueError(
+                    f'Model "{model_config.served_model_name}" does not '
+                    f'support matryoshka representation; '
+                    f'changing output dimensions will lead to poor results.')
+            if self.dimensions < 1:
+                raise ValueError("Dimensions must be greater than 0")

     def __repr__(self) -> str:
         return (f"PoolingParams("
+                f"dimensions={self.dimensions}, "
                 f"additional_metadata={self.additional_data})")
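A minimal check of the updated clone(), which now carries the new dimensions field along in the round-trip:

from vllm import PoolingParams

p = PoolingParams(dimensions=16)
assert p.clone().dimensions == 16
print(p)  # PoolingParams(dimensions=16, additional_metadata=None)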