[Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734)

This commit is contained in:
Chang Su 2024-05-11 11:30:37 -07:00 committed by GitHub
parent 4e12131089
commit e254497b66
38 changed files with 1627 additions and 160 deletions

View File

@ -0,0 +1,17 @@
from vllm import LLM
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create an LLM.
model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = model.encode(prompts)
# Print the outputs.
for output in outputs:
print(output.outputs.embedding) # list of 4096 floats

View File

@ -0,0 +1,23 @@
from openai import OpenAI
# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
responses = client.embeddings.create(input=[
"Hello my name is",
"The best thing about vLLM is that it supports many different models"
],
model=model)
for data in responses.data:
print(data.embedding) # list of 4096 floats

View File

@ -19,12 +19,15 @@ pytest-forked
pytest-asyncio
pytest-rerunfailures
pytest-shard
httpx
# testing utils
awscli
einops # required for MPT
httpx
peft
requests
ray
peft
awscli
sentence-transformers # required for embedding
# Benchmarking
aiohttp

View File

@ -133,6 +133,10 @@ _VISION_LANGUAGE_MODELS = {
"llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration,
}
_EMBEDDING_MODELS = [
"intfloat/e5-mistral-7b-instruct",
]
class HfRunner:
@ -145,14 +149,7 @@ class HfRunner:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model_name = model_name
if model_name not in _VISION_LANGUAGE_MODELS:
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
self.processor = None
else:
if model_name in _VISION_LANGUAGE_MODELS:
self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained(
model_name,
torch_dtype=torch_dtype,
@ -162,6 +159,20 @@ class HfRunner:
model_name,
torch_dtype=torch_dtype,
)
elif model_name in _EMBEDDING_MODELS:
# Lazy init required for AMD CI
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(
model_name,
device="cpu",
).to(dtype=torch_dtype).cuda()
else:
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
self.processor = None
if tokenizer_name is None:
tokenizer_name = model_name
self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
@ -334,6 +345,9 @@ class HfRunner:
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
return self.model.encode(prompts)
def __del__(self):
del self.model
cleanup()
@ -459,6 +473,14 @@ class VllmRunner:
outputs = self.generate(prompts, beam_search_params)
return outputs
def encode(self, prompts: List[str]) -> List[List[float]]:
req_outputs = self.model.encode(prompts)
outputs = []
for req_output in req_outputs:
embedding = req_output.outputs.embedding
outputs.append(embedding)
return outputs
def __del__(self):
del self.model
cleanup()

View File

@ -9,8 +9,8 @@ from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput,
SequenceStatus)
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
@ -51,7 +51,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
new_token_ids = list(range(num_new_tokens))
outputs = [
SequenceGroupOutput(
CompletionSequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
@ -103,7 +103,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
new_token_ids = list(range(num_new_tokens))
outputs = [
SequenceGroupOutput(
CompletionSequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
@ -170,7 +170,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
new_token_ids[eos_index] = eos_token_id
outputs = [
SequenceGroupOutput(
CompletionSequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
@ -239,7 +239,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
new_token_ids[eos_index] = eos_token_id
outputs = [
SequenceGroupOutput(
CompletionSequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,

View File

@ -14,6 +14,7 @@ class MockModelConfig:
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
embedding_mode = False
@dataclass

View File

@ -23,6 +23,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 600 seconds
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
@ -121,7 +122,7 @@ def zephyr_lora_files():
return snapshot_download(repo_id=LORA_NAME)
@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def server(zephyr_lora_files):
ray.init()
server_runner = ServerRunner.remote([
@ -150,6 +151,25 @@ def server(zephyr_lora_files):
ray.shutdown()
@pytest.fixture(scope="module")
def embedding_server(zephyr_lora_files):
ray.shutdown()
ray.init()
server_runner = ServerRunner.remote([
"--model",
EMBEDDING_MODEL_NAME,
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--enforce-eager",
])
ray.get(server_runner.ready.remote())
yield server_runner
ray.shutdown()
@pytest.fixture(scope="module")
def client():
client = openai.AsyncOpenAI(
@ -890,5 +910,79 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
or "less_than_equal" in exc_info.value.message)
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],
)
async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
model_name: str):
input = [
"The chef prepared a delicious meal.",
]
# test single embedding
embeddings = await client.embeddings.create(
model=model_name,
input=input,
encoding_format="float",
)
assert embeddings.id is not None
assert embeddings.data is not None and len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 4096
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 9
assert embeddings.usage.total_tokens == 9
# test using token IDs
input = [1, 1, 1, 1, 1]
embeddings = await client.embeddings.create(
model=model_name,
input=input,
encoding_format="float",
)
assert embeddings.id is not None
assert embeddings.data is not None and len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 4096
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 5
assert embeddings.usage.total_tokens == 5
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],
)
async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
model_name: str):
# test List[str]
inputs = [
"The cat sat on the mat.", "A feline was resting on a rug.",
"Stars twinkle brightly in the night sky."
]
embeddings = await client.embeddings.create(
model=model_name,
input=inputs,
encoding_format="float",
)
assert embeddings.id is not None
assert embeddings.data is not None and len(embeddings.data) == 3
assert len(embeddings.data[0].embedding) == 4096
# test List[List[int]]
inputs = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
[25, 32, 64, 77]]
embeddings = await client.embeddings.create(
model=model_name,
input=inputs,
encoding_format="float",
)
assert embeddings.id is not None
assert embeddings.data is not None and len(embeddings.data) == 4
assert len(embeddings.data[0].embedding) == 4096
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 17
assert embeddings.usage.total_tokens == 17
if __name__ == "__main__":
pytest.main([__file__])

View File

@ -0,0 +1,44 @@
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
Run `pytest tests/models/test_llama_embedding.py`.
"""
import pytest
import torch
import torch.nn.functional as F
MODELS = [
"intfloat/e5-mistral-7b-instruct",
]
def compare_embeddings(embeddings1, embeddings2):
similarities = [
F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0)
for e1, e2 in zip(embeddings1, embeddings2)
]
return similarities
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.encode(example_prompts)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype)
vllm_outputs = vllm_model.encode(example_prompts)
del vllm_model
similarities = compare_embeddings(hf_outputs, vllm_outputs)
all_similarities = torch.stack(similarities)
tolerance = 1e-2
assert torch.all((all_similarities <= 1.0 + tolerance)
& (all_similarities >= 1.0 - tolerance)
), f"Not all values are within {tolerance} of 1.0"

View File

@ -36,14 +36,14 @@ def test_logits_processor_force_generate(
# test logits_processors when prompt_logprobs is not None
vllm_model.model._add_request(
prompt=example_prompts[0],
sampling_params=params_with_logprobs,
params=params_with_logprobs,
prompt_token_ids=None,
)
# test prompt_logprobs is not None
vllm_model.model._add_request(
prompt=example_prompts[1],
sampling_params=SamplingParams(
params=SamplingParams(
prompt_logprobs=3,
max_tokens=max_tokens,
),
@ -53,7 +53,7 @@ def test_logits_processor_force_generate(
# test grouped requests
vllm_model.model._add_request(
prompt=example_prompts[2],
sampling_params=SamplingParams(max_tokens=max_tokens),
params=SamplingParams(max_tokens=max_tokens),
prompt_token_ids=None,
)

View File

@ -60,7 +60,7 @@ def test_random_sample_with_seed(
llm._add_request(
prompt=prompt,
prompt_token_ids=None,
sampling_params=params,
params=params,
)
results = llm._run_engine(use_tqdm=False)

View File

@ -7,8 +7,8 @@ import torch
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, SamplerOutput, SequenceData,
SequenceGroupMetadata, SequenceGroupOutput,
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceData, SequenceGroupMetadata,
SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
@ -170,7 +170,7 @@ def create_sampler_output_list(
return [
SamplerOutput(outputs=[
SequenceGroupOutput(
CompletionSequenceGroupOutput(
samples=[
SequenceOutput(
output_token=token_id,

View File

@ -1,17 +1,17 @@
import pytest
from tests.core.utils import create_dummy_prompt
from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput,
SequenceOutput)
from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput,
SequenceData, SequenceOutput)
@pytest.fixture
def sample_outputs():
return [
SequenceGroupOutput(samples=[
CompletionSequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=0, output_token=i, logprobs={})
],
prompt_logprobs=None) for i in range(5)
prompt_logprobs=None) for i in range(5)
]
@ -32,10 +32,10 @@ def test_sampler_output_getitem(sampler_output, sample_outputs):
def test_sampler_output_setitem(sampler_output):
new_output = SequenceGroupOutput(samples=[
new_output = CompletionSequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=0, output_token=99, logprobs={})
],
prompt_logprobs=None)
prompt_logprobs=None)
sampler_output[2] = new_output
assert sampler_output[2] == new_output

View File

@ -6,7 +6,9 @@ from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
__version__ = "0.4.2"
@ -17,9 +19,12 @@ __all__ = [
"SamplingParams",
"RequestOutput",
"CompletionOutput",
"EmbeddingOutput",
"EmbeddingRequestOutput",
"LLMEngine",
"EngineArgs",
"AsyncLLMEngine",
"AsyncEngineArgs",
"initialize_ray_cluster",
"PoolingParams",
]

View File

@ -9,6 +9,7 @@ from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron
@ -22,6 +23,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
_GB = 1 << 30
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
class ModelConfig:
@ -126,6 +128,7 @@ class ModelConfig:
served_model_name)
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
self._verify_embedding_mode()
self._verify_quantization()
self._verify_cuda_graph()
@ -137,6 +140,11 @@ class ModelConfig:
"either 'auto' or 'slow'.")
self.tokenizer_mode = tokenizer_mode
def _verify_embedding_mode(self) -> None:
architectures = getattr(self.hf_config, "architectures", [])
self.embedding_mode = any(
ModelRegistry.is_embedding_model(arch) for arch in architectures)
def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["gptq", "squeezellm"]
@ -591,6 +599,7 @@ class SchedulerConfig:
prompt latency) before scheduling next prompt.
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
embedding_mode: Whether the running model is for embedding.
"""
def __init__(
@ -602,6 +611,7 @@ class SchedulerConfig:
num_lookahead_slots: int = 0,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
embedding_mode: Optional[bool] = False,
) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
@ -610,6 +620,10 @@ class SchedulerConfig:
# It is the values that have the best balance between ITL
# and TTFT on A100. Note it is not optimized for throughput.
self.max_num_batched_tokens = 512
elif embedding_mode:
# For embedding, choose specific value for higher throughput
self.max_num_batched_tokens = max(
max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
else:
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
@ -623,6 +637,7 @@ class SchedulerConfig:
self.num_lookahead_slots = num_lookahead_slots
self.delay_factor = delay_factor
self.chunked_prefill_enabled = enable_chunked_prefill
self.embedding_mode = embedding_mode
self._verify_args()
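To make the new scheduling branch easier to follow, here is a standalone sketch (plain Python, not taken from the commit) of how max_num_batched_tokens is resolved once embedding_mode is set; the final max(max_model_len, 2048) fallback is an assumption based on the truncated comment above.

_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768

def pick_max_num_batched_tokens(max_num_batched_tokens, max_model_len,
                                enable_chunked_prefill, embedding_mode):
    if max_num_batched_tokens is not None:
        return max_num_batched_tokens  # explicit user setting always wins
    if enable_chunked_prefill:
        return 512  # tuned for ITL/TTFT on A100, not for throughput
    if embedding_mode:
        # Embedding requests are prefill-only, so a much larger budget is used.
        return max(max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
    return max(max_model_len, 2048)  # assumed fallback for generation models

print(pick_max_num_batched_tokens(None, 8192, False, True))  # -> 32768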

View File

@ -0,0 +1,84 @@
from typing import List, Tuple
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup
class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
"""An embedding version of BlockSpaceManager for use in environments
with embedding models where block management is not required.
This class provides the same interface as BlockSpaceManager, but its
methods either perform no actions or return trivial values such as True.
It's designed for scenarios where the overhead of block management is
unnecessary, such as serving embedding models.
"""
def __init__(
self,
**kwargs,
) -> None:
pass
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# Always return OK for dummy purposes
return AllocStatus.OK
def allocate(self, seq_group: SequenceGroup) -> None:
# No actual allocation logic needed
pass
def can_append_slots(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool:
return True
def append_slots(
self,
seq: Sequence,
num_lookahead_slots: int,
) -> List[Tuple[int, int]]:
return None # type: ignore
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
pass
def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> AllocStatus:
return AllocStatus.OK
def swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> List[Tuple[int, int]]:
return None # type: ignore
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
return True
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
return None # type: ignore
def free(self, seq: Sequence) -> None:
# No operation on free
return
def get_block_table(self, seq: Sequence) -> List[int]:
return None # type: ignore
def get_num_free_gpu_blocks(self) -> int:
return 1
def get_num_free_cpu_blocks(self) -> int:
return 1
def access_all_blocks_in_seq(
self,
seq: Sequence,
access_time: float,
) -> None:
pass
def get_common_computed_block_ids(self,
seq_group: SequenceGroup) -> List[int]:
return None # type: ignore
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
pass

View File

@ -35,6 +35,11 @@ class BlockSpaceManager(ABC):
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
return BlockSpaceManagerV2
if version == "embedding":
from vllm.core.embedding_model_block_manager import (
EmbeddingModelBlockSpaceManager)
return EmbeddingModelBlockSpaceManager
raise ValueError(f"Unknown version {version=}")
@abstractmethod
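A minimal usage sketch of the new factory branch, assuming a vLLM build that includes this commit; the scheduler resolves the no-op embedding manager through the same hook as the v1/v2 managers.

from vllm.core.interfaces import BlockSpaceManager

manager_cls = BlockSpaceManager.get_block_space_manager_class("embedding")
manager = manager_cls()  # constructor kwargs are accepted and ignored
# No KV-cache blocks are tracked, so capacity queries return trivial values.
print(manager.get_num_free_gpu_blocks())  # -> 1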

View File

@ -270,9 +270,14 @@ class Scheduler:
self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens)
version = "v1"
if self.scheduler_config.use_v2_block_manager:
version = "v2"
if self.scheduler_config.embedding_mode:
version = "embedding"
BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
version="v2" if self.scheduler_config.
use_v2_block_manager else "v1")
version)
# Create the block space manager.
self.block_manager = BlockSpaceManagerImpl(
@ -968,6 +973,7 @@ class Scheduler:
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
do_sample=do_sample,
pooling_params=seq_group.pooling_params,
token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums,

View File

@ -574,6 +574,7 @@ class EngineArgs:
speculative_config.num_lookahead_slots),
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
embedding_mode=model_config.embedding_mode,
)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,

View File

@ -14,7 +14,8 @@ from vllm.engine.llm_engine import LLMEngine
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
from vllm.usage.usage_lib import UsageContext
@ -47,15 +48,16 @@ def _raise_exception_on_finish(
class AsyncStream:
"""A stream of RequestOutputs for a request that can be
iterated over asynchronously."""
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
that can be iterated over asynchronously."""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
def put(self, item: Union[RequestOutput, Exception]) -> None:
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
Exception]) -> None:
if self._finished:
return
self._queue.put_nowait(item)
@ -71,7 +73,7 @@ class AsyncStream:
def __aiter__(self):
return self
async def __anext__(self) -> RequestOutput:
async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
result = await self._queue.get()
if isinstance(result, Exception):
raise result
@ -108,7 +110,8 @@ class RequestTracker:
self.abort_request(rid)
def process_request_output(self,
request_output: RequestOutput,
request_output: Union[RequestOutput,
EmbeddingRequestOutput],
*,
verbose: bool = False) -> None:
"""Process a request output from the engine."""
@ -196,7 +199,8 @@ class RequestTracker:
class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""
async def step_async(self) -> List[RequestOutput]:
async def step_async(
self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible.
@ -251,7 +255,7 @@ class _AsyncLLMEngine(LLMEngine):
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
@ -270,8 +274,8 @@ class _AsyncLLMEngine(LLMEngine):
return self.add_request(request_id,
prompt=prompt,
params=params,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
@ -511,7 +515,7 @@ class AsyncLLMEngine:
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
@ -528,9 +532,9 @@ class AsyncLLMEngine:
max_log_len]
logger.info(
"Received request %s: prompt: %r, "
"sampling_params: %s, prompt_token_ids: %s, "
"lora_request: %s.", request_id, shortened_prompt,
sampling_params, shortened_token_ids, lora_request)
"params: %s, prompt_token_ids: %s, "
"lora_request: %s.", request_id, shortened_prompt, params,
shortened_token_ids, lora_request)
if not self.is_running:
if self.start_engine_loop:
@ -562,7 +566,7 @@ class AsyncLLMEngine:
stream = self._request_tracker.add_request(
request_id,
prompt=prompt,
sampling_params=sampling_params,
params=params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
@ -597,8 +601,8 @@ class AsyncLLMEngine:
multi_modal_data: Multi modal data per request.
Yields:
The output `RequestOutput` objects from the LLMEngine for the
request.
The output `RequestOutput` objects from the LLMEngine
for the request.
Details:
- If the engine is not running, start the background loop,
@ -643,25 +647,123 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> ...
"""
# Preprocess the request.
arrival_time = time.time()
try:
stream = await self.add_request(
async for output in self.process_request(
request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
)
prompt_token_ids,
lora_request,
multi_modal_data,
):
yield output
async def encode(
self,
prompt: Optional[str],
pooling_params: PoolingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None
) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
Details:
- If the engine is not running, start the background loop,
which iteratively invokes
:meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
the underlying engine.
Also, a corresponding `AsyncStream` will be created.
- Wait for the request outputs from `AsyncStream` and yield them.
Example:
>>> # Please refer to entrypoints/api_server.py for
>>> # the complete example.
>>>
>>> # initialize the engine and the example input
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
>>> example_input = {
>>> "input": "What is LLM?",
>>> "request_id": 0,
>>> }
>>>
>>> # start the generation
>>> results_generator = engine.encode(
>>> example_input["input"],
>>> PoolingParams(),
>>> example_input["request_id"])
>>>
>>> # get the results
>>> final_output = None
>>> async for request_output in results_generator:
>>> if await request.is_disconnected():
>>> # Abort the request if the client disconnects.
>>> await engine.abort(request_id)
>>> # Return or raise an error
>>> ...
>>> final_output = request_output
>>>
>>> # Process and return the final output
>>> ...
"""
async for output in self.process_request(
request_id,
prompt,
pooling_params,
prompt_token_ids,
lora_request,
multi_modal_data,
):
yield output
async def process_request(
self,
request_id: str,
prompt: Optional[str],
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or
PoolingParams."""
arrival_time = time.time()
stream = await self.add_request(
request_id,
prompt,
params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
)
try:
async for request_output in stream:
yield request_output
except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the
# request.
self._abort(request_id)
raise e

View File

@ -20,9 +20,12 @@ from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput,
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
MultiModalData, PoolerOutput, SamplerOutput,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
@ -169,7 +172,8 @@ class LLMEngine:
load_config=load_config,
)
self._initialize_kv_caches()
if not self.model_config.embedding_mode:
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
@ -354,7 +358,7 @@ class LLMEngine:
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
@ -370,7 +374,8 @@ class LLMEngine:
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters for text generation.
params: Parameters for sampling or pooling. SamplingParams
for text generation. PoolingParams for pooling.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
@ -404,13 +409,6 @@ class LLMEngine:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
max_logprobs = self.get_model_config().max_logprobs
if (sampling_params.logprobs
and sampling_params.logprobs > max_logprobs) or (
sampling_params.prompt_logprobs
and sampling_params.prompt_logprobs > max_logprobs):
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = self.encode_request(
@ -432,6 +430,50 @@ class LLMEngine:
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
eos_token_id, lora_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,
seq,
params,
arrival_time,
lora_request,
multi_modal_data,
)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time,
lora_request,
multi_modal_data,
)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def _create_sequence_group_with_sampling(
self,
request_id: str,
seq: Sequence,
sampling_params: SamplingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
if (sampling_params.logprobs
and sampling_params.logprobs > max_logprobs) or (
sampling_params.prompt_logprobs
and sampling_params.prompt_logprobs > max_logprobs):
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
@ -443,11 +485,35 @@ class LLMEngine:
self.generation_config_fields)
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time, lora_request, multi_modal_data)
seq_group = SequenceGroup(request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
sampling_params=sampling_params,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
return seq_group
def _create_sequence_group_with_pooling(
self,
request_id: str,
seq: Sequence,
pooling_params: PoolingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
pooling_params = pooling_params.clone()
# Create the sequence group.
seq_group = SequenceGroup(request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
pooling_params=pooling_params)
return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a request(s) with the given ID.
@ -484,13 +550,25 @@ class LLMEngine:
"""Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs()
def _process_sequence_group_outputs(
self,
seq_group: SequenceGroup,
outputs: List[EmbeddingSequenceGroupOutput],
) -> None:
seq_group.embeddings = outputs[0].embeddings
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_STOPPED
return
def _process_model_outputs(
self,
output: List[SamplerOutput],
output: List[Union[SamplerOutput, PoolerOutput]],
scheduled_seq_groups: List[ScheduledSequenceGroup],
ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[RequestOutput]:
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client.
@ -510,6 +588,9 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
if self.model_config.embedding_mode:
self._process_sequence_group_outputs(seq_group, outputs)
continue
self.output_processor.process_prompt_logprob(seq_group, outputs)
if seq_group_meta.do_sample:
@ -519,18 +600,19 @@ class LLMEngine:
self.scheduler.free_finished_seq_groups()
# Create the outputs.
request_outputs: List[RequestOutput] = []
request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = []
for scheduled_seq_group in scheduled_seq_groups:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutput.from_seq_group(seq_group)
request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output)
for seq_group in ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group)
request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output)
return request_outputs
def step(self) -> List[RequestOutput]:
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
.. figure:: https://i.imgur.com/sv2HssD.png
@ -570,7 +652,7 @@ class LLMEngine:
>>> while True:
>>> if example_inputs:
>>> req_id, prompt, sampling_params = example_inputs.pop(0)
>>> engine.add_request(str(req_id), prompt, sampling_params)
>>> engine.add_request(str(req_id),prompt,sampling_params)
>>>
>>> # continue the request processing
>>> request_outputs = engine.step()
@ -637,12 +719,15 @@ class LLMEngine:
# KV Cache Usage in %
num_total_gpu = self.cache_config.num_gpu_blocks
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
gpu_cache_usage_sys = 0.
if num_total_gpu is not None:
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks(
)
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
num_total_cpu = self.cache_config.num_cpu_blocks
cpu_cache_usage_sys = 0.
if num_total_cpu > 0:
if num_total_cpu is not None and num_total_cpu > 0:
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
)
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
@ -716,8 +801,10 @@ class LLMEngine:
seq.get_output_len()
for seq in seq_group.get_finished_seqs()
])
best_of_requests.append(seq_group.sampling_params.best_of)
n_requests.append(seq_group.sampling_params.n)
if seq_group.sampling_params is not None:
best_of_requests.append(
seq_group.sampling_params.best_of)
n_requests.append(seq_group.sampling_params.n)
finished_reason_requests.extend([
SequenceStatus.get_finished_reason(seq.status)
for seq in seq_group.get_finished_seqs()
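The hunk above swaps RequestOutput.from_seq_group for RequestOutputFactory.create. Because the vllm/outputs.py hunk at the end of this page is truncated before the factory itself, the following is only a plausible sketch of that dispatch, inferred from _process_sequence_group_outputs storing pooled embeddings on the SequenceGroup; the EmbeddingRequestOutput.from_seq_group classmethod is assumed.

from vllm.outputs import EmbeddingRequestOutput, RequestOutput

class RequestOutputFactory:

    @staticmethod
    def create(seq_group):
        # Embedding runs populate seq_group.embeddings (see
        # _process_sequence_group_outputs above); sampling runs leave it unset.
        if getattr(seq_group, "embeddings", None) is not None:
            # from_seq_group on EmbeddingRequestOutput is assumed to mirror
            # RequestOutput.from_seq_group.
            return EmbeddingRequestOutput.from_seq_group(seq_group)
        return RequestOutput.from_seq_group(seq_group)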

View File

@ -6,13 +6,17 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter
logger = init_logger(__name__)
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
@ -164,8 +168,89 @@ class LLM:
multi_modal_data: Multi modal data.
Returns:
A list of `RequestOutput` objects containing the generated
completions in the same order as the input prompts.
A list of `RequestOutput` objects containing the
generated completions in the same order as the input prompts.
"""
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
requests_data = self._validate_and_prepare_requests(
prompts,
sampling_params,
prompt_token_ids,
lora_request,
multi_modal_data,
)
# Add requests to the engine and run the engine
for request_data in requests_data:
self._add_request(**request_data)
return self._run_engine(use_tqdm)
def encode(
self,
prompts: Optional[Union[str, List[str]]] = None,
pooling_params: Optional[Union[PoolingParams,
List[PoolingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
"""Generates the completions for the input prompts.
NOTE: This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
prompts: A list of prompts to generate completions for.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns:
A list of `EmbeddingRequestOutput` objects containing the
generated embeddings in the same order as the input prompts.
"""
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
requests_data = self._validate_and_prepare_requests(
prompts,
pooling_params,
prompt_token_ids,
lora_request,
multi_modal_data,
)
# Add requests to the engine and run the engine
for request_data in requests_data:
self._add_request(**request_data)
return self._run_engine(use_tqdm)
def _validate_and_prepare_requests(
self,
prompts: Optional[Union[str, List[str]]],
params: Union[Union[SamplingParams, PoolingParams],
List[Union[SamplingParams,
PoolingParams]]], # Unified parameter
prompt_token_ids: Optional[List[List[int]]] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[dict]:
"""Validates and prepares request data for adding to the engine.
Ensures prompts and token IDs are consistent, and returns a list of
dictionaries with request data for further processing.
"""
if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be "
@ -188,40 +273,43 @@ class LLM:
assert prompt_token_ids is not None
num_requests = len(prompt_token_ids)
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
elif isinstance(sampling_params,
list) and len(sampling_params) != num_requests:
raise ValueError("The lengths of prompts and sampling_params "
if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
if multi_modal_data:
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine.
requests_data = []
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
i]
self._add_request(
multi_modal_item = MultiModalData(
type=multi_modal_data.type,
data=multi_modal_data.data[i].unsqueeze(0),
) if multi_modal_data else None
requests_data.append({
"prompt":
prompt,
sampling_params[i]
if isinstance(sampling_params, list) else sampling_params,
"params":
params[i] if isinstance(params, list) else params,
"prompt_token_ids":
token_ids,
lora_request=lora_request,
# Get ith image while maintaining the batch dim.
multi_modal_data=MultiModalData(
type=multi_modal_data.type,
data=multi_modal_data.data[i].unsqueeze(0))
if multi_modal_data else None,
)
return self._run_engine(use_tqdm)
"lora_request":
lora_request,
"multi_modal_data":
multi_modal_item,
})
return requests_data
def _add_request(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
@ -229,12 +317,14 @@ class LLM:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id,
prompt,
sampling_params,
params,
prompt_token_ids,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
def _run_engine(
self, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
@ -245,7 +335,7 @@ class LLM:
postfix=f"Generation Speed: {0:.2f} toks/s",
)
# Run the engine.
outputs: List[RequestOutput] = []
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_toks = 0
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
@ -253,10 +343,12 @@ class LLM:
if output.finished:
outputs.append(output)
if use_tqdm:
total_toks += (sum(
len(stp.token_ids) for stp in output.outputs))
spd = total_toks / pbar.format_dict["elapsed"]
pbar.postfix = f"Generation Speed: {spd:.2f} toks/s"
if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput
total_toks += sum(
len(stp.token_ids) for stp in output.outputs)
spd = total_toks / pbar.format_dict["elapsed"]
pbar.postfix = f"Generation Speed: {spd:.2f} toks/s"
pbar.update(1)
if use_tqdm:
pbar.close()

View File

@ -22,9 +22,11 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest, ErrorResponse)
CompletionRequest,
EmbeddingRequest, ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
@ -32,6 +34,8 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
logger = init_logger(__name__)
_running_tasks: Set[asyncio.Task] = set()
@ -123,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump())
@app.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
generator = await openai_serving_embedding.create_embedding(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
return JSONResponse(content=generator.model_dump())
if __name__ == "__main__":
args = parse_args()
@ -190,7 +205,8 @@ if __name__ == "__main__":
args.chat_template)
openai_serving_completion = OpenAIServingCompletion(
engine, model_config, served_model_names, args.lora_modules)
openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
served_model_names)
app.root_path = args.root_path
uvicorn.run(app,
host=args.host,

View File

@ -1,13 +1,14 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
import torch
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
@ -363,6 +364,24 @@ class CompletionRequest(OpenAIBaseModel):
return data
class EmbeddingRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
model: str
input: Union[List[int], List[List[int]], str, List[str]]
encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
dimensions: Optional[int] = None
user: Optional[str] = None
# doc: begin-embedding-pooling-params
additional_data: Optional[Any] = None
# doc: end-embedding-pooling-params
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)
class LogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
@ -416,6 +435,21 @@ class CompletionStreamResponse(OpenAIBaseModel):
usage: Optional[UsageInfo] = Field(default=None)
class EmbeddingResponseData(BaseModel):
index: int
object: str = "embedding"
embedding: List[float]
class EmbeddingResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "list"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
data: List[EmbeddingResponseData]
usage: UsageInfo
class ChatMessage(OpenAIBaseModel):
role: str
content: str
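A hedged sketch of the wire format implied by EmbeddingRequest and EmbeddingResponse above, against a locally running api_server started with an embedding model; field names come from the models, while the id prefix and token counts shown in comments are illustrative.

import requests  # any HTTP client works; the test requirements add httpx

payload = {
    "model": "intfloat/e5-mistral-7b-instruct",
    "input": ["The capital of France is"],  # str, List[str], or token IDs
    "encoding_format": "float",             # "base64" is rejected for now
}
resp = requests.post("http://localhost:8000/v1/embeddings", json=payload).json()
# Expected shape:
# {"id": "cmpl-...", "object": "list", "created": ..., "model": "...",
#  "data": [{"index": 0, "object": "embedding", "embedding": [... 4096 floats]}],
#  "usage": {"prompt_tokens": N, "total_tokens": N, "completion_tokens": 0}}
print(len(resp["data"][0]["embedding"]))  # 4096 for e5-mistral-7b-instruct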

View File

@ -0,0 +1,134 @@
import time
from typing import AsyncIterator, List, Tuple
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_completion import parse_prompt_format
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
TypeTokenIDs = List[int]
def request_output_to_embedding_response(
final_res_batch: List[EmbeddingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
) -> EmbeddingResponse:
data = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
assert final_res is not None
prompt_token_ids = final_res.prompt_token_ids
embedding_data = EmbeddingResponseData(
index=idx, embedding=final_res.outputs.embedding)
data.append(embedding_data)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return EmbeddingResponse(
id=request_id,
created=created_time,
model=model_name,
data=data,
usage=usage,
)
class OpenAIServingEmbedding(OpenAIServing):
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
served_model_names: List[str]):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=None)
self._check_embedding_mode(model_config.embedding_mode)
async def create_embedding(self, request: EmbeddingRequest,
raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# Return error for unsupported features.
if request.encoding_format == "base64":
return self.create_error_response(
"base64 encoding is not currently supported")
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.monotonic())
# Schedule the request and get the result generator.
generators = []
try:
prompt_is_tokens, prompts = parse_prompt_format(request.input)
pooling_params = request.to_pooling_params()
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
prompt_formats = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt)
else:
prompt_formats = self._validate_prompt_and_tokenize(
request, prompt=prompt)
prompt_ids, prompt_text = prompt_formats
generators.append(
self.engine.generate(prompt_text,
pooling_params,
f"{request_id}-{i}",
prompt_token_ids=prompt_ids))
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[
int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
# Non-streaming response
final_res_batch: List[EmbeddingRequestOutput] = [None] * len(prompts)
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
# TODO: Use a vllm-specific Validation Error
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
response = request_output_to_embedding_response(
final_res_batch, request_id, created_time, model_name)
return response
def _check_embedding_mode(self, embedding_mode: bool):
if not embedding_mode:
logger.warning(
"embedding_mode is False. Embedding API will not work.")
else:
logger.info("Activating the server engine with embedding enabled.")

View File

@ -9,7 +9,8 @@ from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest, ErrorResponse,
CompletionRequest,
EmbeddingRequest, ErrorResponse,
LogProbs, ModelCard, ModelList,
ModelPermission)
from vllm.logger import init_logger
@ -165,7 +166,8 @@ class OpenAIServing:
def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest],
request: Union[ChatCompletionRequest, CompletionRequest,
EmbeddingRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
@ -191,6 +193,16 @@ class OpenAIServing:
prompt_ids)
token_num = len(input_ids)
# Note: EmbeddingRequest doesn't have max_tokens
if isinstance(request, EmbeddingRequest):
if token_num > self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the input for embedding "
f"generation. Please reduce the length of the input.", )
return input_ids, input_text
if request.max_tokens is None:
if token_num >= self.max_model_len:
raise ValueError(

View File

@ -1,9 +1,9 @@
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase
@ -123,8 +123,8 @@ class GPUExecutor(ExecutorBase):
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
self, execute_model_req: ExecuteModelRequest
) -> List[Union[SamplerOutput, PoolerOutput]]:
output = self.driver_worker.execute_model(execute_model_req)
return output
@ -150,7 +150,7 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
) -> List[Union[SamplerOutput, PoolerOutput]]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
return output

View File

@ -0,0 +1,56 @@
from enum import IntEnum
import torch
import torch.nn as nn
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors)
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
class PoolingType(IntEnum):
"""Enumeration for different types of pooling methods."""
LAST = 0
class Pooler(nn.Module):
"""A layer that pools specific information from hidden states.
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`.
Attributes:
pooling_type: The type of pooling to use (currently only LAST).
normalize: Whether to normalize the pooled data.
"""
def __init__(self, pooling_type: PoolingType, normalize: bool):
super().__init__()
self.pooling_type = pooling_type
self.normalize = normalize
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
"""Pools specific information from hidden states based on metadata."""
prompt_lens = PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
if self.pooling_type == PoolingType.LAST:
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
pooled_data = hidden_states[last_token_flat_indices]
else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}")
if self.normalize:
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
pooled_outputs = [
EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
]
return PoolerOutput(outputs=pooled_outputs)
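A standalone illustration of the LAST pooling arithmetic above, using plain torch and toy shapes rather than vLLM internals: prompts of different lengths are packed into one flat hidden-state tensor, and cumsum(prompt_lens) - 1 picks each prompt's final token before L2 normalization.

import torch
import torch.nn.functional as F

prompt_lens = torch.tensor([3, 5, 2])                    # three packed prompts
hidden_states = torch.randn(int(prompt_lens.sum()), 8)   # [num_tokens, hidden]

last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1  # [2, 7, 9]
pooled = hidden_states[last_token_flat_indices]                 # [3, 8]
pooled = F.normalize(pooled, p=2, dim=1)                        # unit-length rows
print(pooled.shape, pooled.norm(dim=1))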

View File

@ -10,8 +10,9 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors,
SequenceGroupToSample)
from vllm.sampling_params import SamplingType
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
SamplerOutput, SequenceGroupOutput, SequenceOutput)
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceOutput)
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]
@ -1019,7 +1020,7 @@ def _build_sampler_output(
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
sampler_output.append(
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs))
# If not specified, store None values in SamplerOutput.
if on_device_tensors is not None:

View File

@ -9,7 +9,7 @@ from vllm.utils import is_hip
logger = init_logger(__name__)
# Architecture -> (module, class).
_MODELS = {
_GENERATION_MODELS = {
"AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
@ -58,6 +58,12 @@ _MODELS = {
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
}
_EMBEDDING_MODELS = {
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
}
_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS}
# Architecture -> type.
# out of tree models
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
@ -114,6 +120,10 @@ class ModelRegistry:
global _OOT_MODELS
_OOT_MODELS[model_arch] = model_cls
@staticmethod
def is_embedding_model(model_arch: str) -> bool:
return model_arch in _EMBEDDING_MODELS
__all__ = [
"ModelRegistry",

View File

@ -0,0 +1,87 @@
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import PoolerOutput
class LlamaEmbeddingModel(nn.Module):
"""A model that uses Llama with additional embedding functionalities.
This class encapsulates the LlamaModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of LlamaModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(
self,
**kwargs,
) -> None:
super().__init__()
self.model = LlamaModel(**kwargs)
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model.forward(input_ids, positions, kv_caches,
attn_metadata, inputs_embeds)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.model.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
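To make the stacked-parameter remapping above concrete, here is a standalone sketch; the mapping table is copied from the loader, while the checkpoint key is hypothetical.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]
name = "layers.0.self_attn.k_proj.weight"  # hypothetical checkpoint key
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        fused_name = name.replace(weight_name, param_name)
        print(fused_name, shard_id)  # layers.0.self_attn.qkv_proj.weight k
        break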

View File

@ -0,0 +1,69 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
import torch
from vllm.pooling_params import PoolingParams
from vllm.utils import is_pin_memory_available
class PoolingMetadata:
"""Metadata for pooling operations in the Pooler layer.
This class holds the necessary information for pooling operations,
providing context for how to perform pooling and other related operations.
Attributes:
seq_groups: List of (seq_ids, pooling_params).
seq_data: A mapping of sequence ID to additional sequence data.
prompt_lens: List of the lengths of each prompt.
"""
def __init__(
self,
seq_groups: List[Tuple[List[int], PoolingParams]],
seq_data: Dict[int, Any], # Specific data related to sequences
prompt_lens: List[int],
) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
def __repr__(self) -> str:
return ("PoolingMetadata("
f"seq_groups={self.seq_groups}, "
f"seq_data={self.seq_data}, "
f"prompt_lens={self.prompt_lens})")
@dataclass
class PoolingTensors:
"""Tensors for pooling."""
prompt_lens: torch.Tensor
@classmethod
def from_pooling_metadata(
cls,
pooling_metadata: "PoolingMetadata",
device: torch.device,
) -> "PoolingTensors":
"""
Create PoolingTensors from PoolingMetadata.
Args:
pooling_metadata: PoolingMetadata instance to convert.
device: Device to store the tensors.
"""
# Convert prompt lengths to tensor
pin_memory = is_pin_memory_available()
prompt_lens_t = torch.tensor(
pooling_metadata.prompt_lens,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
return cls(prompt_lens=prompt_lens_t.to(device=device,
non_blocking=True))
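In isolation, the pinned CPU tensor plus non-blocking device copy above amounts to the following sketch; the prompt lengths are made up and pinning is skipped on CPU-only hosts.
import torch
prompt_lens = [3, 2, 7]            # stand-in for PoolingMetadata.prompt_lens
use_cuda = torch.cuda.is_available()
prompt_lens_t = torch.tensor(prompt_lens, device="cpu", dtype=torch.long,
                             pin_memory=use_cuda)
device = torch.device("cuda" if use_cuda else "cpu")
prompt_lens_t = prompt_lens_t.to(device=device, non_blocking=True)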

View File

@ -57,8 +57,27 @@ class CompletionOutput:
f"stop_reason={self.stop_reason})")
class EmbeddingOutput:
"""The output data of one completion output of a request.
Args:
embedding: The embedding vector, which is a list of floats. The
length of vector depends on the model as listed in the embedding guide.
"""
def __init__(
self,
embedding: List[float],
) -> None:
self.embedding = embedding
def __repr__(self) -> str:
return (f"EmbeddingOutput("
f"embedding={len(self.embedding)}")
class RequestOutput:
"""The output data of a request to the LLM.
"""The output data of a completion request to the LLM.
Args:
request_id: The unique ID of the request.
@ -93,6 +112,9 @@ class RequestOutput:
@classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
if seq_group.sampling_params is None:
raise ValueError(
"Sampling parameters are missing for a CompletionRequest.")
seqs = seq_group.get_seqs()
if len(seqs) == 1:
top_n_seqs = seqs
@ -148,3 +170,61 @@ class RequestOutput:
f"finished={self.finished}, "
f"metrics={self.metrics}, "
f"lora_request={self.lora_request})")
class EmbeddingRequestOutput:
"""
The output data of an embedding request to the LLM.
Args:
request_id (str): A unique identifier for the embedding request.
outputs (EmbeddingOutput): The embedding results for the given input.
prompt_token_ids (List[int]): A list of token IDs used in the prompt.
finished (bool): A flag indicating whether the embedding is completed.
"""
def __init__(self, request_id: str, outputs: 'EmbeddingOutput',
prompt_token_ids: List[int], finished: bool):
self.request_id = request_id
self.prompt_token_ids = prompt_token_ids
self.finished = finished
self.outputs = outputs
@classmethod
def from_seq_group(cls,
seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
if seq_group.embeddings is None:
raise ValueError(
"Embeddings are missing in seq_group for EmbeddingRequest.")
output = EmbeddingOutput(seq_group.embeddings)
prompt_token_ids = seq_group.prompt_token_ids
finished = seq_group.is_finished()
return cls(seq_group.request_id, output, prompt_token_ids, finished)
def __repr__(self):
"""
Returns a string representation of an EmbeddingRequestOutput instance.
The representation includes the request_id and the number of outputs,
providing a quick overview of the embedding request's results.
Returns:
str: A string representation of the EmbeddingRequestOutput instance.
"""
return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
f"outputs={repr(self.outputs)}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"finished={self.finished})")
class RequestOutputFactory:
@staticmethod
def create(seq_group):
# Determine the output type based on whether the sequence group carries embeddings:
if hasattr(seq_group,
'embeddings') and seq_group.embeddings is not None:
return EmbeddingRequestOutput.from_seq_group(seq_group)
else:
return RequestOutput.from_seq_group(seq_group)
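The dispatch above only inspects seq_group.embeddings; a toy stand-in (not a real SequenceGroup) makes the branching visible:
class _FakeEmbeddingGroup:
    embeddings = [0.1] * 4096   # set by the embedding pipeline
class _FakeCompletionGroup:
    embeddings = None           # generation requests leave this None
def _output_kind(seq_group) -> str:
    if getattr(seq_group, "embeddings", None) is not None:
        return "EmbeddingRequestOutput"
    return "RequestOutput"
assert _output_kind(_FakeEmbeddingGroup()) == "EmbeddingRequestOutput"
assert _output_kind(_FakeCompletionGroup()) == "RequestOutput"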

vllm/pooling_params.py Normal file
View File

@ -0,0 +1,20 @@
from typing import Any, Optional
class PoolingParams:
"""Pooling parameters for pooling.
Attributes:
additional_data: Any additional data needed for pooling.
"""
def __init__(self, additional_data: Optional[Any] = None):
self.additional_data = additional_data
def clone(self) -> "PoolingParams":
"""Returns a copy of the PoolingParams instance."""
return PoolingParams(additional_data=self.additional_data)
def __repr__(self) -> str:
return (f"PoolingParams("
f"additional_data={self.additional_data})")

View File

@ -1,11 +1,13 @@
"""Sequence and its related classes."""
import copy
import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from vllm.block import LogicalTokenBlock
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
if TYPE_CHECKING:
@ -375,12 +377,12 @@ class SequenceGroupState:
class MultiModalData:
"""Multi modal request.
Args:
type: The data type.
data: The actual data.
The required shape and semantic meaning of it depends on the vision
language config of the hosted model.
language config of the hosted model.
See `VisionLanguageConfig` in `config.py`.
"""
@ -402,16 +404,22 @@ class SequenceGroup:
arrival_time: The arrival time of the request.
lora_request: LoRA request.
multi_modal_data: Multi modal data associated with the request.
embeddings: The embedding vectors of the prompt of the sequence group
for an embedding model.
pooling_params: The pooling parameters used to generate the embeddings
for an embedding model.
"""
def __init__(
self,
request_id: str,
seqs: List[Sequence],
sampling_params: SamplingParams,
arrival_time: float,
sampling_params: Optional[SamplingParams] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
embeddings: Optional[List[float]] = None,
pooling_params: Optional[PoolingParams] = None,
) -> None:
self.request_id = request_id
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
@ -425,6 +433,8 @@ class SequenceGroup:
self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()
self.multi_modal_data = multi_modal_data
self.embeddings = embeddings
self.pooling_params = pooling_params
@property
def prompt(self) -> str:
@ -479,12 +489,13 @@ class SequenceGroup:
def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
if self.sampling_params.use_beam_search:
if self.sampling_params and self.sampling_params.use_beam_search:
# For beam search, maximally there will always be `best_of` beam
# candidates running in the future.
return self.sampling_params.best_of
else:
if self.sampling_params.best_of > self.num_seqs():
if (self.sampling_params
and self.sampling_params.best_of > self.num_seqs()):
# At prompt stage, the sequence group is not yet filled up
# and only have one sequence running. However, in the
# generation stage, we will have `best_of` sequences running.
@ -555,7 +566,7 @@ class SequenceGroup:
return all(seq.is_finished() for seq in self.get_seqs())
def is_prefill(self) -> bool:
# Every sequences should be in the same stage.
# Every sequence should be in the same stage.
return self.get_seqs()[0].is_prefill()
def __repr__(self) -> str:
@ -594,6 +605,7 @@ class SequenceGroupMetadata:
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]],
do_sample: bool = True,
pooling_params: Optional[PoolingParams] = None,
token_chunk_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
@ -605,6 +617,7 @@ class SequenceGroupMetadata:
self.seq_data = seq_data
self.sampling_params = sampling_params
self.block_tables = block_tables
self.pooling_params = pooling_params
self.lora_request = lora_request
self.computed_block_nums = computed_block_nums
self.multi_modal_data = multi_modal_data
@ -669,8 +682,20 @@ class SequenceOutput:
return equal and log_probs_equal
class SequenceGroupOutput:
"""The model output associated with a sequence group."""
class SequenceGroupOutput(ABC):
"""The base class for model outputs associated with a sequence group."""
@abstractmethod
def __repr__(self) -> str:
pass
@abstractmethod
def __eq__(self, other: object) -> bool:
pass
class CompletionSequenceGroupOutput(SequenceGroupOutput):
"""The model output associated with a completion sequence group."""
def __init__(
self,
@ -682,26 +707,45 @@ class SequenceGroupOutput:
self.prompt_logprobs = prompt_logprobs
def __repr__(self) -> str:
return (f"SequenceGroupOutput(samples={self.samples}, "
return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
f"prompt_logprobs={self.prompt_logprobs})")
def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceGroupOutput):
if not isinstance(other, CompletionSequenceGroupOutput):
raise NotImplementedError()
return (self.samples == other.samples
and self.prompt_logprobs == other.prompt_logprobs)
class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
"""The model output associated with an embedding sequence group."""
def __init__(
self,
embeddings: List[float],
) -> None:
self.embeddings = embeddings
def __repr__(self) -> str:
return (f"EmbeddingSequenceGroupOutput("
f"embeddings_shape={len(self.embeddings)})")
def __eq__(self, other: object) -> bool:
if not isinstance(other, EmbeddingSequenceGroupOutput):
raise NotImplementedError()
return self.embeddings == other.embeddings
@dataclass
class SamplerOutput:
"""For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token.
This datastructure implements methods so it can be used like a list, but
This data structure implements methods, so it can be used like a list, but
also has optional fields for device tensors.
"""
outputs: List[SequenceGroupOutput]
outputs: List[CompletionSequenceGroupOutput]
# On-device tensor containing probabilities of each token.
sampled_token_probs: Optional["torch.Tensor"] = None
@ -742,6 +786,27 @@ class SamplerOutput:
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
@dataclass
class PoolerOutput:
"""The output from a pooling operation in the embedding model."""
outputs: List[EmbeddingSequenceGroupOutput]
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
def __getitem__(self, idx: int):
return self.outputs[idx]
def __setitem__(self, idx: int, value):
self.outputs[idx] = value
def __len__(self):
return len(self.outputs)
def __eq__(self, other: object):
return isinstance(other,
self.__class__) and self.outputs == other.outputs
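Assuming vLLM at this commit is installed, the PoolerOutput added above behaves like a list of per-sequence-group embedding outputs:
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
out = PoolerOutput(outputs=[
    EmbeddingSequenceGroupOutput([0.0] * 4096),
    EmbeddingSequenceGroupOutput([1.0] * 4096),
])
assert len(out) == 2
first = out[0]                                        # __getitem__
out[1] = EmbeddingSequenceGroupOutput([0.5] * 4096)   # __setitem__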
@dataclass
class ExecuteModelRequest:
"""The model execution request."""

View File

@ -4,7 +4,8 @@ from typing import Dict, List, Tuple
import torch
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput)
SeqId = int
@ -94,7 +95,7 @@ def create_sequence_group_output(
for topk_logprob_index, _ in enumerate(topk_token_ids)
})
return SequenceGroupOutput(
return CompletionSequenceGroupOutput(
samples=[
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,

View File

@ -0,0 +1,266 @@
from typing import Dict, List, Optional, Set, Tuple
import torch
from vllm.attention import AttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import BatchType, ModelRunner
logger = init_logger(__name__)
class EmbeddingModelRunner(ModelRunner):
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None,
):
super().__init__(model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config)
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[torch.Tensor],
) -> Optional[PoolerOutput]:
(input_tokens, input_positions, attn_metadata, pooling_metadata,
lora_requests, lora_mapping, multi_modal_input
) = self.prepare_input_tensors(seq_group_metadata_list)
if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)
# Currently cuda graph is only supported by the decode phase.
prefill_meta = attn_metadata.prefill_metadata
decode_meta = attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
model_executable = self.model
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
}
if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input})
hidden_states = model_executable(**execute_model_kwargs)
return self.model.pooler(hidden_states=hidden_states,
pooling_metadata=pooling_metadata)
def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
Set[LoRARequest], LoRAMapping, torch.Tensor]:
if self.is_driver_worker:
prefill_reqs = []
decode_reqs = []
for seq_group_meta in seq_group_metadata_list:
if seq_group_meta.is_prompt:
prefill_reqs.append(seq_group_meta)
else:
decode_reqs.append(seq_group_meta)
# Prepare input tensors.
(
input_tokens,
input_positions,
prefill_attn_metadata,
prompt_lens,
subquery_lens,
lora_index_mapping,
lora_prompt_mapping,
lora_requests,
multi_modal_input,
slot_mapping,
) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
decode_input_positions,
decode_attn_metadata,
decode_lora_index_mapping,
decode_lora_prompt_mapping,
decode_lora_requests,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
# Prepare PoolingMetadata
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
prompt_lens)
if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0
num_prefills = len(prompt_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)
# Coalesce tensors. Note that attn_metadata is currently not
# coalesced for simplicity.
input_tokens.extend(decode_input_tokens)
input_positions.extend(decode_input_positions)
slot_mapping.extend(decode_slot_mapping)
lora_index_mapping.extend(decode_lora_index_mapping)
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
lora_requests.update(decode_lora_requests)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
if self.lora_config:
lora_mapping = LoRAMapping(
lora_index_mapping,
lora_prompt_mapping,
)
else:
lora_mapping = None
# Broadcast the metadata.
# If batch contains both prefill and decode, it sends 2 broadcasts.
# If it only contains 1 type, it triggers a single broadcast.
if (prefill_attn_metadata is not None
and decode_attn_metadata is not None):
batch_type = BatchType.MIXED
elif prefill_attn_metadata is not None:
batch_type = BatchType.PREFILL
else:
batch_type = BatchType.DECODE
metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
"multi_modal_input": multi_modal_input,
"num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
"batch_type": batch_type,
}
if prefill_attn_metadata is not None:
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
else:
assert decode_attn_metadata is not None
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
# Broadcast decode attn metadata for mixed batch type.
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
# We can potentially reduce the overhead by coalescing tensors.
if batch_type == BatchType.MIXED:
assert decode_attn_metadata is not None
metadata_dict = decode_attn_metadata.asdict_zerocopy()
broadcast_tensor_dict(metadata_dict, src=0)
else:
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions")
slot_mapping = metadata_dict.pop("slot_mapping")
num_prefills = metadata_dict.pop("num_prefills")
lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input")
num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
batch_type = metadata_dict.pop("batch_type")
# Create an attention metadata.
prefill_attn_metadata = None
decode_attn_metadata = None
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
prefill_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
else:
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
pooling_metadata = PoolingMetadata(seq_groups=None,
seq_data=None,
prompt_lens=None)
# if it is a mixed batch, decode attn_metadata is broadcasted
# separately.
if batch_type == BatchType.MIXED:
metadata_dict = broadcast_tensor_dict(src=0)
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
attn_metadata = AttentionMetadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
prefill_metadata=prefill_attn_metadata,
decode_metadata=decode_attn_metadata,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata, pooling_metadata,
lora_requests, lora_mapping, multi_modal_input)
def _prepare_pooling(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> PoolingMetadata:
"""Prepare PoolingMetadata for the sequence group metadata list."""
seq_groups: List[Tuple[List[int], PoolingParams]] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
pooling_params = seq_group_metadata.pooling_params
seq_groups.append((seq_ids, pooling_params))
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
pooling_metadata = PoolingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
)
return pooling_metadata
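The structures _prepare_pooling builds are easy to picture with toy stand-ins (SimpleNamespace objects instead of real SequenceGroupMetadata; the data values are made up):
from types import SimpleNamespace
metadata_list = [
    SimpleNamespace(seq_data={0: "data-0"}, pooling_params=None),
    SimpleNamespace(seq_data={1: "data-1"}, pooling_params=None),
]
seq_groups = [(list(m.seq_data.keys()), m.pooling_params) for m in metadata_list]
seq_data = {}
for m in metadata_list:
    seq_data.update(m.seq_data)
# seq_groups == [([0], None), ([1], None)]; seq_data == {0: 'data-0', 1: 'data-1'}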

View File

@ -1,6 +1,6 @@
import time
from enum import IntEnum
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
import numpy as np
import torch
@ -287,18 +287,18 @@ class ModelRunner:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping += [lora_id] * (seq_len - context_len)
lora_prompt_mapping.extend(
[lora_id] *
(seq_len - context_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1))
lora_prompt_mapping.extend([lora_id] * (
seq_len - context_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs else 1))
if seq_group_metadata.multi_modal_data:
multi_modal_input_list.append(
seq_group_metadata.multi_modal_data.data)
if seq_group_metadata.block_tables is None:
if _is_block_tables_empty(seq_group_metadata.block_tables):
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
continue
@ -813,7 +813,6 @@ class ModelRunner:
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, and therefore the max amount of memory
# consumption. Create dummy lora request copies from the lora request
@ -1139,3 +1138,15 @@ def _prepare_fake_inputs(
prompt_tokens = [0] * seq_len
fake_image_input = None
return SequenceData(prompt_tokens), fake_image_input
def _is_block_tables_empty(block_tables: Union[None, Dict]):
"""
Check if block_tables is None or a dictionary with all None values.
"""
if block_tables is None:
return True
if isinstance(block_tables, dict) and all(
value is None for value in block_tables.values()):
return True
return False
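Assuming vLLM at this commit is installed, the helper covers three cases: memory profiling, embedding runs whose block tables are {seq_id: None}, and populated tables.
from vllm.worker.model_runner import _is_block_tables_empty
assert _is_block_tables_empty(None)                # memory profiling
assert _is_block_tables_empty({7: None})           # embedding model run
assert not _is_block_tables_empty({7: [0, 1, 2]})  # populated block table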

View File

@ -1,7 +1,7 @@
"""A GPU worker class."""
import gc
import os
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.distributed
@ -16,8 +16,9 @@ from vllm.distributed.device_communicators.custom_all_reduce import (
init_custom_ar)
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker_base import WorkerBase
@ -68,7 +69,9 @@ class Worker(WorkerBase):
assert not self.lora_config, (
"To be tested: vision language model with LoRA settings.")
self.model_runner = ModelRunner(
ModelRunnerClass = (EmbeddingModelRunner if
self.model_config.embedding_mode else ModelRunner)
self.model_runner = ModelRunnerClass(
model_config,
parallel_config,
scheduler_config,
@ -83,7 +86,8 @@ class Worker(WorkerBase):
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: CacheEngine
self.gpu_cache: List[torch.Tensor]
# Initialize gpu_cache to None; embedding models do not initialize kv_caches.
self.gpu_cache: Optional[List[torch.Tensor]] = None
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
@ -209,7 +213,7 @@ class Worker(WorkerBase):
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
) -> List[Union[SamplerOutput, PoolerOutput]]:
if execute_model_req is None:
seq_group_metadata_list = None
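With the widened return type, a caller can distinguish the two output kinds. This is a hedged sketch, with `outputs` standing in for whatever Worker.execute_model returns; the helper name is hypothetical.
from vllm.sequence import PoolerOutput, SamplerOutput
def summarize(outputs):
    for out in outputs:
        if isinstance(out, PoolerOutput):
            print("embedding batch:", len(out), "sequence groups")
        elif isinstance(out, SamplerOutput):
            print("generation batch:", len(out.outputs), "sequence groups")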