[Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734)
This commit is contained in:
parent
4e12131089
commit
e254497b66
17
examples/offline_inference_embedding.py
Normal file
17
examples/offline_inference_embedding.py
Normal file
@ -0,0 +1,17 @@
|
||||
from vllm import LLM
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
# Create an LLM.
|
||||
model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
|
||||
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
|
||||
outputs = model.encode(prompts)
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
print(output.outputs.embedding) # list of 4096 floats
|
23
examples/openai_embedding_client.py
Normal file
23
examples/openai_embedding_client.py
Normal file
@ -0,0 +1,23 @@
|
||||
from openai import OpenAI
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
responses = client.embeddings.create(input=[
|
||||
"Hello my name is",
|
||||
"The best thing about vLLM is that it supports many different models"
|
||||
],
|
||||
model=model)
|
||||
|
||||
for data in responses.data:
|
||||
print(data.embedding) # list of float of len 4096
|
@ -19,12 +19,15 @@ pytest-forked
|
||||
pytest-asyncio
|
||||
pytest-rerunfailures
|
||||
pytest-shard
|
||||
httpx
|
||||
|
||||
# testing utils
|
||||
awscli
|
||||
einops # required for MPT
|
||||
httpx
|
||||
peft
|
||||
requests
|
||||
ray
|
||||
peft
|
||||
awscli
|
||||
sentence-transformers # required for embedding
|
||||
|
||||
# Benchmarking
|
||||
aiohttp
|
||||
|
@ -133,6 +133,10 @@ _VISION_LANGUAGE_MODELS = {
|
||||
"llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration,
|
||||
}
|
||||
|
||||
_EMBEDDING_MODELS = [
|
||||
"intfloat/e5-mistral-7b-instruct",
|
||||
]
|
||||
|
||||
|
||||
class HfRunner:
|
||||
|
||||
@ -145,14 +149,7 @@ class HfRunner:
|
||||
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
|
||||
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||
self.model_name = model_name
|
||||
if model_name not in _VISION_LANGUAGE_MODELS:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
).cuda()
|
||||
self.processor = None
|
||||
else:
|
||||
if model_name in _VISION_LANGUAGE_MODELS:
|
||||
self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
@ -162,6 +159,20 @@ class HfRunner:
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
elif model_name in _EMBEDDING_MODELS:
|
||||
# Lazy init required for AMD CI
|
||||
from sentence_transformers import SentenceTransformer
|
||||
self.model = SentenceTransformer(
|
||||
model_name,
|
||||
device="cpu",
|
||||
).to(dtype=torch_dtype).cuda()
|
||||
else:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
).cuda()
|
||||
self.processor = None
|
||||
if tokenizer_name is None:
|
||||
tokenizer_name = model_name
|
||||
self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
|
||||
@ -334,6 +345,9 @@ class HfRunner:
|
||||
return [(output_ids, output_str, output_logprobs)
|
||||
for output_ids, output_str, output_logprobs in outputs]
|
||||
|
||||
def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
|
||||
return self.model.encode(prompts)
|
||||
|
||||
def __del__(self):
|
||||
del self.model
|
||||
cleanup()
|
||||
@ -459,6 +473,14 @@ class VllmRunner:
|
||||
outputs = self.generate(prompts, beam_search_params)
|
||||
return outputs
|
||||
|
||||
def encode(self, prompts: List[str]) -> List[List[float]]:
|
||||
req_outputs = self.model.encode(prompts)
|
||||
outputs = []
|
||||
for req_output in req_outputs:
|
||||
embedding = req_output.outputs.embedding
|
||||
outputs.append(embedding)
|
||||
return outputs
|
||||
|
||||
def __del__(self):
|
||||
del self.model
|
||||
cleanup()
|
||||
|
@ -9,8 +9,8 @@ from vllm.core.scheduler import Scheduler
|
||||
from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput,
|
||||
SequenceStatus)
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
SequenceOutput, SequenceStatus)
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.utils import Counter
|
||||
|
||||
@ -51,7 +51,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
|
||||
new_token_ids = list(range(num_new_tokens))
|
||||
|
||||
outputs = [
|
||||
SequenceGroupOutput(
|
||||
CompletionSequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq.seq_id,
|
||||
@ -103,7 +103,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
|
||||
new_token_ids = list(range(num_new_tokens))
|
||||
|
||||
outputs = [
|
||||
SequenceGroupOutput(
|
||||
CompletionSequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq.seq_id,
|
||||
@ -170,7 +170,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
|
||||
new_token_ids[eos_index] = eos_token_id
|
||||
|
||||
outputs = [
|
||||
SequenceGroupOutput(
|
||||
CompletionSequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq.seq_id,
|
||||
@ -239,7 +239,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
|
||||
new_token_ids[eos_index] = eos_token_id
|
||||
|
||||
outputs = [
|
||||
SequenceGroupOutput(
|
||||
CompletionSequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq.seq_id,
|
||||
|
@ -14,6 +14,7 @@ class MockModelConfig:
|
||||
tokenizer_mode = "auto"
|
||||
max_model_len = 100
|
||||
tokenizer_revision = None
|
||||
embedding_mode = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -23,6 +23,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
||||
EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
|
||||
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
|
||||
# generation quality here
|
||||
LORA_NAME = "typeof/zephyr-7b-beta-lora"
|
||||
@ -121,7 +122,7 @@ def zephyr_lora_files():
|
||||
return snapshot_download(repo_id=LORA_NAME)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@pytest.fixture(scope="module")
|
||||
def server(zephyr_lora_files):
|
||||
ray.init()
|
||||
server_runner = ServerRunner.remote([
|
||||
@ -150,6 +151,25 @@ def server(zephyr_lora_files):
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def embedding_server(zephyr_lora_files):
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
server_runner = ServerRunner.remote([
|
||||
"--model",
|
||||
EMBEDDING_MODEL_NAME,
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--enforce-eager",
|
||||
])
|
||||
ray.get(server_runner.ready.remote())
|
||||
yield server_runner
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
client = openai.AsyncOpenAI(
|
||||
@ -890,5 +910,79 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
|
||||
or "less_than_equal" in exc_info.value.message)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[EMBEDDING_MODEL_NAME],
|
||||
)
|
||||
async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
input = [
|
||||
"The chef prepared a delicious meal.",
|
||||
]
|
||||
|
||||
# test single embedding
|
||||
embeddings = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input,
|
||||
encoding_format="float",
|
||||
)
|
||||
assert embeddings.id is not None
|
||||
assert embeddings.data is not None and len(embeddings.data) == 1
|
||||
assert len(embeddings.data[0].embedding) == 4096
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
assert embeddings.usage.prompt_tokens == 9
|
||||
assert embeddings.usage.total_tokens == 9
|
||||
|
||||
# test using token IDs
|
||||
input = [1, 1, 1, 1, 1]
|
||||
embeddings = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input,
|
||||
encoding_format="float",
|
||||
)
|
||||
assert embeddings.id is not None
|
||||
assert embeddings.data is not None and len(embeddings.data) == 1
|
||||
assert len(embeddings.data[0].embedding) == 4096
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
assert embeddings.usage.prompt_tokens == 5
|
||||
assert embeddings.usage.total_tokens == 5
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[EMBEDDING_MODEL_NAME],
|
||||
)
|
||||
async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
# test List[str]
|
||||
inputs = [
|
||||
"The cat sat on the mat.", "A feline was resting on a rug.",
|
||||
"Stars twinkle brightly in the night sky."
|
||||
]
|
||||
embeddings = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=inputs,
|
||||
encoding_format="float",
|
||||
)
|
||||
assert embeddings.id is not None
|
||||
assert embeddings.data is not None and len(embeddings.data) == 3
|
||||
assert len(embeddings.data[0].embedding) == 4096
|
||||
|
||||
# test List[List[int]]
|
||||
inputs = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
|
||||
[25, 32, 64, 77]]
|
||||
embeddings = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=inputs,
|
||||
encoding_format="float",
|
||||
)
|
||||
assert embeddings.id is not None
|
||||
assert embeddings.data is not None and len(embeddings.data) == 4
|
||||
assert len(embeddings.data[0].embedding) == 4096
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
assert embeddings.usage.prompt_tokens == 17
|
||||
assert embeddings.usage.total_tokens == 17
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
44
tests/models/test_embedding.py
Normal file
44
tests/models/test_embedding.py
Normal file
@ -0,0 +1,44 @@
|
||||
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
|
||||
|
||||
Run `pytest tests/models/test_llama_embedding.py`.
|
||||
"""
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
MODELS = [
|
||||
"intfloat/e5-mistral-7b-instruct",
|
||||
]
|
||||
|
||||
|
||||
def compare_embeddings(embeddings1, embeddings2):
|
||||
similarities = [
|
||||
F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0)
|
||||
for e1, e2 in zip(embeddings1, embeddings2)
|
||||
]
|
||||
return similarities
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.encode(example_prompts)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype)
|
||||
vllm_outputs = vllm_model.encode(example_prompts)
|
||||
del vllm_model
|
||||
|
||||
similarities = compare_embeddings(hf_outputs, vllm_outputs)
|
||||
all_similarities = torch.stack(similarities)
|
||||
tolerance = 1e-2
|
||||
assert torch.all((all_similarities <= 1.0 + tolerance)
|
||||
& (all_similarities >= 1.0 - tolerance)
|
||||
), f"Not all values are within {tolerance} of 1.0"
|
@ -36,14 +36,14 @@ def test_logits_processor_force_generate(
|
||||
# test logits_processors when prompt_logprobs is not None
|
||||
vllm_model.model._add_request(
|
||||
prompt=example_prompts[0],
|
||||
sampling_params=params_with_logprobs,
|
||||
params=params_with_logprobs,
|
||||
prompt_token_ids=None,
|
||||
)
|
||||
|
||||
# test prompt_logprobs is not None
|
||||
vllm_model.model._add_request(
|
||||
prompt=example_prompts[1],
|
||||
sampling_params=SamplingParams(
|
||||
params=SamplingParams(
|
||||
prompt_logprobs=3,
|
||||
max_tokens=max_tokens,
|
||||
),
|
||||
@ -53,7 +53,7 @@ def test_logits_processor_force_generate(
|
||||
# test grouped requests
|
||||
vllm_model.model._add_request(
|
||||
prompt=example_prompts[2],
|
||||
sampling_params=SamplingParams(max_tokens=max_tokens),
|
||||
params=SamplingParams(max_tokens=max_tokens),
|
||||
prompt_token_ids=None,
|
||||
)
|
||||
|
||||
|
@ -60,7 +60,7 @@ def test_random_sample_with_seed(
|
||||
llm._add_request(
|
||||
prompt=prompt,
|
||||
prompt_token_ids=None,
|
||||
sampling_params=params,
|
||||
params=params,
|
||||
)
|
||||
|
||||
results = llm._run_engine(use_tqdm=False)
|
||||
|
@ -7,8 +7,8 @@ import torch
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (Logprob, SamplerOutput, SequenceData,
|
||||
SequenceGroupMetadata, SequenceGroupOutput,
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
SamplerOutput, SequenceData, SequenceGroupMetadata,
|
||||
SequenceOutput)
|
||||
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
@ -170,7 +170,7 @@ def create_sampler_output_list(
|
||||
|
||||
return [
|
||||
SamplerOutput(outputs=[
|
||||
SequenceGroupOutput(
|
||||
CompletionSequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
output_token=token_id,
|
||||
|
@ -1,17 +1,17 @@
|
||||
import pytest
|
||||
|
||||
from tests.core.utils import create_dummy_prompt
|
||||
from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput,
|
||||
SequenceOutput)
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput,
|
||||
SequenceData, SequenceOutput)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_outputs():
|
||||
return [
|
||||
SequenceGroupOutput(samples=[
|
||||
CompletionSequenceGroupOutput(samples=[
|
||||
SequenceOutput(parent_seq_id=0, output_token=i, logprobs={})
|
||||
],
|
||||
prompt_logprobs=None) for i in range(5)
|
||||
prompt_logprobs=None) for i in range(5)
|
||||
]
|
||||
|
||||
|
||||
@ -32,10 +32,10 @@ def test_sampler_output_getitem(sampler_output, sample_outputs):
|
||||
|
||||
|
||||
def test_sampler_output_setitem(sampler_output):
|
||||
new_output = SequenceGroupOutput(samples=[
|
||||
new_output = CompletionSequenceGroupOutput(samples=[
|
||||
SequenceOutput(parent_seq_id=0, output_token=99, logprobs={})
|
||||
],
|
||||
prompt_logprobs=None)
|
||||
prompt_logprobs=None)
|
||||
sampler_output[2] = new_output
|
||||
assert sampler_output[2] == new_output
|
||||
|
||||
|
@ -6,7 +6,9 @@ from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
|
||||
EmbeddingRequestOutput, RequestOutput)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
__version__ = "0.4.2"
|
||||
@ -17,9 +19,12 @@ __all__ = [
|
||||
"SamplingParams",
|
||||
"RequestOutput",
|
||||
"CompletionOutput",
|
||||
"EmbeddingOutput",
|
||||
"EmbeddingRequestOutput",
|
||||
"LLMEngine",
|
||||
"EngineArgs",
|
||||
"AsyncLLMEngine",
|
||||
"AsyncEngineArgs",
|
||||
"initialize_ray_cluster",
|
||||
"PoolingParams",
|
||||
]
|
||||
|
@ -9,6 +9,7 @@ from transformers import PretrainedConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
|
||||
get_quantization_config)
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.transformers_utils.config import get_config, get_hf_text_config
|
||||
from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron
|
||||
|
||||
@ -22,6 +23,7 @@ if TYPE_CHECKING:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_GB = 1 << 30
|
||||
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
@ -126,6 +128,7 @@ class ModelConfig:
|
||||
served_model_name)
|
||||
if not self.skip_tokenizer_init:
|
||||
self._verify_tokenizer_mode()
|
||||
self._verify_embedding_mode()
|
||||
self._verify_quantization()
|
||||
self._verify_cuda_graph()
|
||||
|
||||
@ -137,6 +140,11 @@ class ModelConfig:
|
||||
"either 'auto' or 'slow'.")
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
|
||||
def _verify_embedding_mode(self) -> None:
|
||||
architectures = getattr(self.hf_config, "architectures", [])
|
||||
self.embedding_mode = any(
|
||||
ModelRegistry.is_embedding_model(arch) for arch in architectures)
|
||||
|
||||
def _verify_quantization(self) -> None:
|
||||
supported_quantization = [*QUANTIZATION_METHODS]
|
||||
rocm_supported_quantization = ["gptq", "squeezellm"]
|
||||
@ -591,6 +599,7 @@ class SchedulerConfig:
|
||||
prompt latency) before scheduling next prompt.
|
||||
enable_chunked_prefill: If True, prefill requests can be chunked based
|
||||
on the remaining max_num_batched_tokens.
|
||||
embedding_mode: Whether the running model is for embedding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -602,6 +611,7 @@ class SchedulerConfig:
|
||||
num_lookahead_slots: int = 0,
|
||||
delay_factor: float = 0.0,
|
||||
enable_chunked_prefill: bool = False,
|
||||
embedding_mode: Optional[bool] = False,
|
||||
) -> None:
|
||||
if max_num_batched_tokens is not None:
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
@ -610,6 +620,10 @@ class SchedulerConfig:
|
||||
# It is the values that have the best balance between ITL
|
||||
# and TTFT on A100. Note it is not optimized for throughput.
|
||||
self.max_num_batched_tokens = 512
|
||||
elif embedding_mode:
|
||||
# For embedding, choose specific value for higher throughput
|
||||
self.max_num_batched_tokens = max(
|
||||
max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
|
||||
else:
|
||||
# If max_model_len is too short, use 2048 as the default value
|
||||
# for higher throughput.
|
||||
@ -623,6 +637,7 @@ class SchedulerConfig:
|
||||
self.num_lookahead_slots = num_lookahead_slots
|
||||
self.delay_factor = delay_factor
|
||||
self.chunked_prefill_enabled = enable_chunked_prefill
|
||||
self.embedding_mode = embedding_mode
|
||||
|
||||
self._verify_args()
|
||||
|
||||
|
84
vllm/core/embedding_model_block_manager.py
Normal file
84
vllm/core/embedding_model_block_manager.py
Normal file
@ -0,0 +1,84 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
from vllm.sequence import Sequence, SequenceGroup
|
||||
|
||||
|
||||
class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
|
||||
"""An embedding version of BlockSpaceManager for use in environments
|
||||
with embedding models where block management is not required.
|
||||
|
||||
This class provides the same interface as BlockSpaceManager, but its
|
||||
methods perform no actions or return simple values like True in specific
|
||||
actions. It's designed to be used in scenarios where the overhead of
|
||||
block management is unnecessary, such as in an embedding environment.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
|
||||
# Always return OK for dummy purposes
|
||||
return AllocStatus.OK
|
||||
|
||||
def allocate(self, seq_group: SequenceGroup) -> None:
|
||||
# No actual allocation logic needed
|
||||
pass
|
||||
|
||||
def can_append_slots(self, seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int) -> bool:
|
||||
return True
|
||||
|
||||
def append_slots(
|
||||
self,
|
||||
seq: Sequence,
|
||||
num_lookahead_slots: int,
|
||||
) -> List[Tuple[int, int]]:
|
||||
return None # type: ignore
|
||||
|
||||
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
||||
pass
|
||||
|
||||
def can_swap_in(self, seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int) -> AllocStatus:
|
||||
return AllocStatus.OK
|
||||
|
||||
def swap_in(self, seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int) -> List[Tuple[int, int]]:
|
||||
return None # type: ignore
|
||||
|
||||
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
|
||||
return True
|
||||
|
||||
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
|
||||
return None # type: ignore
|
||||
|
||||
def free(self, seq: Sequence) -> None:
|
||||
# No operation on free
|
||||
return
|
||||
|
||||
def get_block_table(self, seq: Sequence) -> List[int]:
|
||||
return None # type: ignore
|
||||
|
||||
def get_num_free_gpu_blocks(self) -> int:
|
||||
return 1
|
||||
|
||||
def get_num_free_cpu_blocks(self) -> int:
|
||||
return 1
|
||||
|
||||
def access_all_blocks_in_seq(
|
||||
self,
|
||||
seq: Sequence,
|
||||
access_time: float,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def get_common_computed_block_ids(self,
|
||||
seq_group: SequenceGroup) -> List[int]:
|
||||
return None # type: ignore
|
||||
|
||||
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
|
||||
pass
|
@ -35,6 +35,11 @@ class BlockSpaceManager(ABC):
|
||||
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
|
||||
return BlockSpaceManagerV2
|
||||
|
||||
if version == "embedding":
|
||||
from vllm.core.embedding_model_block_manager import (
|
||||
EmbeddingModelBlockSpaceManager)
|
||||
return EmbeddingModelBlockSpaceManager
|
||||
|
||||
raise ValueError(f"Unknown version {version=}")
|
||||
|
||||
@abstractmethod
|
||||
|
@ -270,9 +270,14 @@ class Scheduler:
|
||||
self.scheduler_config.max_model_len,
|
||||
self.scheduler_config.max_num_batched_tokens)
|
||||
|
||||
version = "v1"
|
||||
if self.scheduler_config.use_v2_block_manager:
|
||||
version = "v2"
|
||||
if self.scheduler_config.embedding_mode:
|
||||
version = "embedding"
|
||||
|
||||
BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
|
||||
version="v2" if self.scheduler_config.
|
||||
use_v2_block_manager else "v1")
|
||||
version)
|
||||
|
||||
# Create the block space manager.
|
||||
self.block_manager = BlockSpaceManagerImpl(
|
||||
@ -968,6 +973,7 @@ class Scheduler:
|
||||
sampling_params=seq_group.sampling_params,
|
||||
block_tables=block_tables,
|
||||
do_sample=do_sample,
|
||||
pooling_params=seq_group.pooling_params,
|
||||
token_chunk_size=token_chunk_size,
|
||||
lora_request=seq_group.lora_request,
|
||||
computed_block_nums=common_computed_block_nums,
|
||||
|
@ -574,6 +574,7 @@ class EngineArgs:
|
||||
speculative_config.num_lookahead_slots),
|
||||
delay_factor=self.scheduler_delay_factor,
|
||||
enable_chunked_prefill=self.enable_chunked_prefill,
|
||||
embedding_mode=model_config.embedding_mode,
|
||||
)
|
||||
lora_config = LoRAConfig(
|
||||
max_lora_rank=self.max_lora_rank,
|
||||
|
@ -14,7 +14,8 @@ from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster, ray
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
@ -47,15 +48,16 @@ def _raise_exception_on_finish(
|
||||
|
||||
|
||||
class AsyncStream:
|
||||
"""A stream of RequestOutputs for a request that can be
|
||||
iterated over asynchronously."""
|
||||
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
|
||||
that can be iterated over asynchronously."""
|
||||
|
||||
def __init__(self, request_id: str) -> None:
|
||||
self.request_id = request_id
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._finished = False
|
||||
|
||||
def put(self, item: Union[RequestOutput, Exception]) -> None:
|
||||
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
|
||||
Exception]) -> None:
|
||||
if self._finished:
|
||||
return
|
||||
self._queue.put_nowait(item)
|
||||
@ -71,7 +73,7 @@ class AsyncStream:
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> RequestOutput:
|
||||
async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
|
||||
result = await self._queue.get()
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
@ -108,7 +110,8 @@ class RequestTracker:
|
||||
self.abort_request(rid)
|
||||
|
||||
def process_request_output(self,
|
||||
request_output: RequestOutput,
|
||||
request_output: Union[RequestOutput,
|
||||
EmbeddingRequestOutput],
|
||||
*,
|
||||
verbose: bool = False) -> None:
|
||||
"""Process a request output from the engine."""
|
||||
@ -196,7 +199,8 @@ class RequestTracker:
|
||||
class _AsyncLLMEngine(LLMEngine):
|
||||
"""Extension of LLMEngine to add async methods."""
|
||||
|
||||
async def step_async(self) -> List[RequestOutput]:
|
||||
async def step_async(
|
||||
self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
"""Performs one decoding iteration and returns newly generated results.
|
||||
The workers are ran asynchronously if possible.
|
||||
|
||||
@ -251,7 +255,7 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
@ -270,8 +274,8 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
|
||||
return self.add_request(request_id,
|
||||
prompt=prompt,
|
||||
params=params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
multi_modal_data=multi_modal_data)
|
||||
@ -511,7 +515,7 @@ class AsyncLLMEngine:
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
@ -528,9 +532,9 @@ class AsyncLLMEngine:
|
||||
max_log_len]
|
||||
logger.info(
|
||||
"Received request %s: prompt: %r, "
|
||||
"sampling_params: %s, prompt_token_ids: %s, "
|
||||
"lora_request: %s.", request_id, shortened_prompt,
|
||||
sampling_params, shortened_token_ids, lora_request)
|
||||
"params: %s, prompt_token_ids: %s, "
|
||||
"lora_request: %s.", request_id, shortened_prompt, params,
|
||||
shortened_token_ids, lora_request)
|
||||
|
||||
if not self.is_running:
|
||||
if self.start_engine_loop:
|
||||
@ -562,7 +566,7 @@ class AsyncLLMEngine:
|
||||
stream = self._request_tracker.add_request(
|
||||
request_id,
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
params=params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
@ -597,8 +601,8 @@ class AsyncLLMEngine:
|
||||
multi_modal_data: Multi modal data per request.
|
||||
|
||||
Yields:
|
||||
The output `RequestOutput` objects from the LLMEngine for the
|
||||
request.
|
||||
The output `RequestOutput` objects from the LLMEngine
|
||||
for the request.
|
||||
|
||||
Details:
|
||||
- If the engine is not running, start the background loop,
|
||||
@ -643,25 +647,123 @@ class AsyncLLMEngine:
|
||||
>>> # Process and return the final output
|
||||
>>> ...
|
||||
"""
|
||||
# Preprocess the request.
|
||||
arrival_time = time.time()
|
||||
|
||||
try:
|
||||
stream = await self.add_request(
|
||||
async for output in self.process_request(
|
||||
request_id,
|
||||
prompt,
|
||||
sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
prompt_token_ids,
|
||||
lora_request,
|
||||
multi_modal_data,
|
||||
):
|
||||
yield output
|
||||
|
||||
async def encode(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
pooling_params: PoolingParams,
|
||||
request_id: str,
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
multi_modal_data: Optional[MultiModalData] = None
|
||||
) -> AsyncIterator[EmbeddingRequestOutput]:
|
||||
"""Generate outputs for a request from an embedding model.
|
||||
|
||||
Generate outputs for a request. This method is a coroutine. It adds the
|
||||
request into the waiting queue of the LLMEngine and streams the outputs
|
||||
from the LLMEngine to the caller.
|
||||
|
||||
Args:
|
||||
prompt: The prompt string. Can be None if prompt_token_ids is
|
||||
provided.
|
||||
pooling_params: The pooling parameters of the request.
|
||||
request_id: The unique id of the request.
|
||||
prompt_token_ids: The token IDs of the prompt. If None, we
|
||||
use the tokenizer to convert the prompts to token IDs.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
multi_modal_data: Multi modal data per request.
|
||||
|
||||
Yields:
|
||||
The output `EmbeddingRequestOutput` objects from the LLMEngine
|
||||
for the request.
|
||||
|
||||
Details:
|
||||
- If the engine is not running, start the background loop,
|
||||
which iteratively invokes
|
||||
:meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
|
||||
to process the waiting requests.
|
||||
- Add the request to the engine's `RequestTracker`.
|
||||
On the next background loop, this request will be sent to
|
||||
the underlying engine.
|
||||
Also, a corresponding `AsyncStream` will be created.
|
||||
- Wait for the request outputs from `AsyncStream` and yield them.
|
||||
|
||||
Example:
|
||||
>>> # Please refer to entrypoints/api_server.py for
|
||||
>>> # the complete example.
|
||||
>>>
|
||||
>>> # initialize the engine and the example input
|
||||
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
>>> example_input = {
|
||||
>>> "input": "What is LLM?",
|
||||
>>> "request_id": 0,
|
||||
>>> }
|
||||
>>>
|
||||
>>> # start the generation
|
||||
>>> results_generator = engine.encode(
|
||||
>>> example_input["input"],
|
||||
>>> PoolingParams(),
|
||||
>>> example_input["request_id"])
|
||||
>>>
|
||||
>>> # get the results
|
||||
>>> final_output = None
|
||||
>>> async for request_output in results_generator:
|
||||
>>> if await request.is_disconnected():
|
||||
>>> # Abort the request if the client disconnects.
|
||||
>>> await engine.abort(request_id)
|
||||
>>> # Return or raise an error
|
||||
>>> ...
|
||||
>>> final_output = request_output
|
||||
>>>
|
||||
>>> # Process and return the final output
|
||||
>>> ...
|
||||
"""
|
||||
async for output in self.process_request(
|
||||
request_id,
|
||||
prompt,
|
||||
pooling_params,
|
||||
prompt_token_ids,
|
||||
lora_request,
|
||||
multi_modal_data,
|
||||
):
|
||||
yield output
|
||||
|
||||
async def process_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[str],
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
multi_modal_data: Optional[MultiModalData] = None,
|
||||
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
"""Common logic to process requests with SamplingParams or
|
||||
PoolingParams."""
|
||||
arrival_time = time.time()
|
||||
|
||||
stream = await self.add_request(
|
||||
request_id,
|
||||
prompt,
|
||||
params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
|
||||
try:
|
||||
async for request_output in stream:
|
||||
yield request_output
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
# If there is an exception or coroutine is cancelled, abort the
|
||||
# request.
|
||||
self._abort(request_id)
|
||||
raise e
|
||||
|
||||
|
@ -20,9 +20,12 @@ from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
|
||||
RequestOutputFactory)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput,
|
||||
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
|
||||
MultiModalData, PoolerOutput, SamplerOutput,
|
||||
Sequence, SequenceGroup, SequenceGroupMetadata,
|
||||
SequenceStatus)
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
@ -169,7 +172,8 @@ class LLMEngine:
|
||||
load_config=load_config,
|
||||
)
|
||||
|
||||
self._initialize_kv_caches()
|
||||
if not self.model_config.embedding_mode:
|
||||
self._initialize_kv_caches()
|
||||
|
||||
# If usage stat is enabled, collect relevant info.
|
||||
if is_usage_stats_enabled():
|
||||
@ -354,7 +358,7 @@ class LLMEngine:
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
@ -370,7 +374,8 @@ class LLMEngine:
|
||||
request_id: The unique ID of the request.
|
||||
prompt: The prompt string. Can be None if prompt_token_ids is
|
||||
provided.
|
||||
sampling_params: The sampling parameters for text generation.
|
||||
params: Parameters for sampling or pooling. SamplingParams
|
||||
for text generation. PoolingParams for pooling.
|
||||
prompt_token_ids: The token IDs of the prompt. If None, we
|
||||
use the tokenizer to convert the prompts to token IDs.
|
||||
arrival_time: The arrival time of the request. If None, we use
|
||||
@ -404,13 +409,6 @@ class LLMEngine:
|
||||
if lora_request is not None and not self.lora_config:
|
||||
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||
"not enabled!")
|
||||
max_logprobs = self.get_model_config().max_logprobs
|
||||
if (sampling_params.logprobs
|
||||
and sampling_params.logprobs > max_logprobs) or (
|
||||
sampling_params.prompt_logprobs
|
||||
and sampling_params.prompt_logprobs > max_logprobs):
|
||||
raise ValueError(f"Cannot request more than "
|
||||
f"{max_logprobs} logprobs.")
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
prompt_token_ids = self.encode_request(
|
||||
@ -432,6 +430,50 @@ class LLMEngine:
|
||||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
|
||||
eos_token_id, lora_request)
|
||||
|
||||
# Create a SequenceGroup based on SamplingParams or PoolingParams
|
||||
if isinstance(params, SamplingParams):
|
||||
seq_group = self._create_sequence_group_with_sampling(
|
||||
request_id,
|
||||
seq,
|
||||
params,
|
||||
arrival_time,
|
||||
lora_request,
|
||||
multi_modal_data,
|
||||
)
|
||||
elif isinstance(params, PoolingParams):
|
||||
seq_group = self._create_sequence_group_with_pooling(
|
||||
request_id,
|
||||
seq,
|
||||
params,
|
||||
arrival_time,
|
||||
lora_request,
|
||||
multi_modal_data,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either SamplingParams or PoolingParams must be provided.")
|
||||
|
||||
# Add the sequence group to the scheduler.
|
||||
self.scheduler.add_seq_group(seq_group)
|
||||
|
||||
def _create_sequence_group_with_sampling(
|
||||
self,
|
||||
request_id: str,
|
||||
seq: Sequence,
|
||||
sampling_params: SamplingParams,
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
multi_modal_data: Optional[MultiModalData] = None,
|
||||
) -> SequenceGroup:
|
||||
"""Creates a SequenceGroup with SamplingParams."""
|
||||
max_logprobs = self.get_model_config().max_logprobs
|
||||
if (sampling_params.logprobs
|
||||
and sampling_params.logprobs > max_logprobs) or (
|
||||
sampling_params.prompt_logprobs
|
||||
and sampling_params.prompt_logprobs > max_logprobs):
|
||||
raise ValueError(f"Cannot request more than "
|
||||
f"{max_logprobs} logprobs.")
|
||||
|
||||
# Defensive copy of SamplingParams, which are used by the sampler,
|
||||
# this doesn't deep-copy LogitsProcessor objects
|
||||
sampling_params = sampling_params.clone()
|
||||
@ -443,11 +485,35 @@ class LLMEngine:
|
||||
self.generation_config_fields)
|
||||
|
||||
# Create the sequence group.
|
||||
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
||||
arrival_time, lora_request, multi_modal_data)
|
||||
seq_group = SequenceGroup(request_id=request_id,
|
||||
seqs=[seq],
|
||||
arrival_time=arrival_time,
|
||||
sampling_params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
# Add the sequence group to the scheduler.
|
||||
self.scheduler.add_seq_group(seq_group)
|
||||
return seq_group
|
||||
|
||||
def _create_sequence_group_with_pooling(
|
||||
self,
|
||||
request_id: str,
|
||||
seq: Sequence,
|
||||
pooling_params: PoolingParams,
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
multi_modal_data: Optional[MultiModalData] = None,
|
||||
) -> SequenceGroup:
|
||||
"""Creates a SequenceGroup with PoolingParams."""
|
||||
# Defensive copy of PoolingParams, which are used by the pooler
|
||||
pooling_params = pooling_params.clone()
|
||||
# Create the sequence group.
|
||||
seq_group = SequenceGroup(request_id=request_id,
|
||||
seqs=[seq],
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
multi_modal_data=multi_modal_data,
|
||||
pooling_params=pooling_params)
|
||||
return seq_group
|
||||
|
||||
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||
"""Aborts a request(s) with the given ID.
|
||||
@ -484,13 +550,25 @@ class LLMEngine:
|
||||
"""Returns True if there are unfinished requests."""
|
||||
return self.scheduler.has_unfinished_seqs()
|
||||
|
||||
def _process_sequence_group_outputs(
|
||||
self,
|
||||
seq_group: SequenceGroup,
|
||||
outputs: List[EmbeddingSequenceGroupOutput],
|
||||
) -> None:
|
||||
seq_group.embeddings = outputs[0].embeddings
|
||||
|
||||
for seq in seq_group.get_seqs():
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
|
||||
return
|
||||
|
||||
def _process_model_outputs(
|
||||
self,
|
||||
output: List[SamplerOutput],
|
||||
output: List[Union[SamplerOutput, PoolerOutput]],
|
||||
scheduled_seq_groups: List[ScheduledSequenceGroup],
|
||||
ignored_seq_groups: List[SequenceGroup],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> List[RequestOutput]:
|
||||
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
"""Apply the model output to the sequences in the scheduled seq groups.
|
||||
|
||||
Returns RequestOutputs that can be returned to the client.
|
||||
@ -510,6 +588,9 @@ class LLMEngine:
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
seq_group.update_num_computed_tokens(
|
||||
scheduled_seq_group.token_chunk_size)
|
||||
if self.model_config.embedding_mode:
|
||||
self._process_sequence_group_outputs(seq_group, outputs)
|
||||
continue
|
||||
|
||||
self.output_processor.process_prompt_logprob(seq_group, outputs)
|
||||
if seq_group_meta.do_sample:
|
||||
@ -519,18 +600,19 @@ class LLMEngine:
|
||||
self.scheduler.free_finished_seq_groups()
|
||||
|
||||
# Create the outputs.
|
||||
request_outputs: List[RequestOutput] = []
|
||||
request_outputs: List[Union[RequestOutput,
|
||||
EmbeddingRequestOutput]] = []
|
||||
for scheduled_seq_group in scheduled_seq_groups:
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
seq_group.maybe_set_first_token_time(now)
|
||||
request_output = RequestOutput.from_seq_group(seq_group)
|
||||
request_output = RequestOutputFactory.create(seq_group)
|
||||
request_outputs.append(request_output)
|
||||
for seq_group in ignored_seq_groups:
|
||||
request_output = RequestOutput.from_seq_group(seq_group)
|
||||
request_output = RequestOutputFactory.create(seq_group)
|
||||
request_outputs.append(request_output)
|
||||
return request_outputs
|
||||
|
||||
def step(self) -> List[RequestOutput]:
|
||||
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
"""Performs one decoding iteration and returns newly generated results.
|
||||
|
||||
.. figure:: https://i.imgur.com/sv2HssD.png
|
||||
@ -570,7 +652,7 @@ class LLMEngine:
|
||||
>>> while True:
|
||||
>>> if example_inputs:
|
||||
>>> req_id, prompt, sampling_params = example_inputs.pop(0)
|
||||
>>> engine.add_request(str(req_id), prompt, sampling_params)
|
||||
>>> engine.add_request(str(req_id),prompt,sampling_params)
|
||||
>>>
|
||||
>>> # continue the request processing
|
||||
>>> request_outputs = engine.step()
|
||||
@ -637,12 +719,15 @@ class LLMEngine:
|
||||
|
||||
# KV Cache Usage in %
|
||||
num_total_gpu = self.cache_config.num_gpu_blocks
|
||||
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
|
||||
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
|
||||
gpu_cache_usage_sys = 0.
|
||||
if num_total_gpu is not None:
|
||||
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks(
|
||||
)
|
||||
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
|
||||
|
||||
num_total_cpu = self.cache_config.num_cpu_blocks
|
||||
cpu_cache_usage_sys = 0.
|
||||
if num_total_cpu > 0:
|
||||
if num_total_cpu is not None and num_total_cpu > 0:
|
||||
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
|
||||
)
|
||||
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
|
||||
@ -716,8 +801,10 @@ class LLMEngine:
|
||||
seq.get_output_len()
|
||||
for seq in seq_group.get_finished_seqs()
|
||||
])
|
||||
best_of_requests.append(seq_group.sampling_params.best_of)
|
||||
n_requests.append(seq_group.sampling_params.n)
|
||||
if seq_group.sampling_params is not None:
|
||||
best_of_requests.append(
|
||||
seq_group.sampling_params.best_of)
|
||||
n_requests.append(seq_group.sampling_params.n)
|
||||
finished_reason_requests.extend([
|
||||
SequenceStatus.get_finished_reason(seq.status)
|
||||
for seq in seq_group.get_finished_seqs()
|
||||
|
@ -6,13 +6,17 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import MultiModalData
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Counter
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LLM:
|
||||
"""An LLM for generating texts from given prompts and sampling parameters.
|
||||
@ -164,8 +168,89 @@ class LLM:
|
||||
multi_modal_data: Multi modal data.
|
||||
|
||||
Returns:
|
||||
A list of `RequestOutput` objects containing the generated
|
||||
completions in the same order as the input prompts.
|
||||
A list of `RequestOutput` objects containing the
|
||||
generated completions in the same order as the input prompts.
|
||||
"""
|
||||
if sampling_params is None:
|
||||
# Use default sampling params.
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
requests_data = self._validate_and_prepare_requests(
|
||||
prompts,
|
||||
sampling_params,
|
||||
prompt_token_ids,
|
||||
lora_request,
|
||||
multi_modal_data,
|
||||
)
|
||||
|
||||
# Add requests to the engine and run the engine
|
||||
for request_data in requests_data:
|
||||
self._add_request(**request_data)
|
||||
|
||||
return self._run_engine(use_tqdm)
|
||||
|
||||
def encode(
|
||||
self,
|
||||
prompts: Optional[Union[str, List[str]]] = None,
|
||||
pooling_params: Optional[Union[PoolingParams,
|
||||
List[PoolingParams]]] = None,
|
||||
prompt_token_ids: Optional[List[List[int]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
multi_modal_data: Optional[MultiModalData] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
"""Generates the completions for the input prompts.
|
||||
|
||||
NOTE: This class automatically batches the given prompts, considering
|
||||
the memory constraint. For the best performance, put all of your prompts
|
||||
into a single list and pass it to this method.
|
||||
|
||||
Args:
|
||||
prompts: A list of prompts to generate completions for.
|
||||
pooling_params: The pooling parameters for pooling. If None, we
|
||||
use the default pooling parameters.
|
||||
prompt_token_ids: A list of token IDs for the prompts. If None, we
|
||||
use the tokenizer to convert the prompts to token IDs.
|
||||
use_tqdm: Whether to use tqdm to display the progress bar.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
multi_modal_data: Multi modal data.
|
||||
|
||||
Returns:
|
||||
A list of `EmbeddingRequestOutput` objects containing the
|
||||
generated embeddings in the same order as the input prompts.
|
||||
"""
|
||||
if pooling_params is None:
|
||||
# Use default pooling params.
|
||||
pooling_params = PoolingParams()
|
||||
|
||||
requests_data = self._validate_and_prepare_requests(
|
||||
prompts,
|
||||
pooling_params,
|
||||
prompt_token_ids,
|
||||
lora_request,
|
||||
multi_modal_data,
|
||||
)
|
||||
|
||||
# Add requests to the engine and run the engine
|
||||
for request_data in requests_data:
|
||||
self._add_request(**request_data)
|
||||
|
||||
return self._run_engine(use_tqdm)
|
||||
|
||||
def _validate_and_prepare_requests(
|
||||
self,
|
||||
prompts: Optional[Union[str, List[str]]],
|
||||
params: Union[Union[SamplingParams, PoolingParams],
|
||||
List[Union[SamplingParams,
|
||||
PoolingParams]]], # Unified parameter
|
||||
prompt_token_ids: Optional[List[List[int]]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
multi_modal_data: Optional[MultiModalData] = None,
|
||||
) -> List[dict]:
|
||||
"""Validates and prepares request data for adding to the engine.
|
||||
|
||||
Ensures prompts and token IDs are consistent, and returns a list of
|
||||
dictionaries with request data for further processing.
|
||||
"""
|
||||
if prompts is None and prompt_token_ids is None:
|
||||
raise ValueError("Either prompts or prompt_token_ids must be "
|
||||
@ -188,40 +273,43 @@ class LLM:
|
||||
assert prompt_token_ids is not None
|
||||
num_requests = len(prompt_token_ids)
|
||||
|
||||
if sampling_params is None:
|
||||
# Use default sampling params.
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
elif isinstance(sampling_params,
|
||||
list) and len(sampling_params) != num_requests:
|
||||
raise ValueError("The lengths of prompts and sampling_params "
|
||||
if isinstance(params, list) and len(params) != num_requests:
|
||||
raise ValueError("The lengths of prompts and params "
|
||||
"must be the same.")
|
||||
if multi_modal_data:
|
||||
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
|
||||
|
||||
# Add requests to the engine.
|
||||
requests_data = []
|
||||
for i in range(num_requests):
|
||||
prompt = prompts[i] if prompts is not None else None
|
||||
token_ids = None if prompt_token_ids is None else prompt_token_ids[
|
||||
i]
|
||||
self._add_request(
|
||||
|
||||
multi_modal_item = MultiModalData(
|
||||
type=multi_modal_data.type,
|
||||
data=multi_modal_data.data[i].unsqueeze(0),
|
||||
) if multi_modal_data else None
|
||||
|
||||
requests_data.append({
|
||||
"prompt":
|
||||
prompt,
|
||||
sampling_params[i]
|
||||
if isinstance(sampling_params, list) else sampling_params,
|
||||
"params":
|
||||
params[i] if isinstance(params, list) else params,
|
||||
"prompt_token_ids":
|
||||
token_ids,
|
||||
lora_request=lora_request,
|
||||
# Get ith image while maintaining the batch dim.
|
||||
multi_modal_data=MultiModalData(
|
||||
type=multi_modal_data.type,
|
||||
data=multi_modal_data.data[i].unsqueeze(0))
|
||||
if multi_modal_data else None,
|
||||
)
|
||||
return self._run_engine(use_tqdm)
|
||||
"lora_request":
|
||||
lora_request,
|
||||
"multi_modal_data":
|
||||
multi_modal_item,
|
||||
})
|
||||
|
||||
return requests_data
|
||||
|
||||
def _add_request(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
prompt_token_ids: Optional[List[int]],
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
multi_modal_data: Optional[MultiModalData] = None,
|
||||
@ -229,12 +317,14 @@ class LLM:
|
||||
request_id = str(next(self.request_counter))
|
||||
self.llm_engine.add_request(request_id,
|
||||
prompt,
|
||||
sampling_params,
|
||||
params,
|
||||
prompt_token_ids,
|
||||
lora_request=lora_request,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
|
||||
def _run_engine(
|
||||
self, use_tqdm: bool
|
||||
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
# Initialize tqdm.
|
||||
if use_tqdm:
|
||||
num_requests = self.llm_engine.get_num_unfinished_requests()
|
||||
@ -245,7 +335,7 @@ class LLM:
|
||||
postfix=f"Generation Speed: {0:.2f} toks/s",
|
||||
)
|
||||
# Run the engine.
|
||||
outputs: List[RequestOutput] = []
|
||||
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
|
||||
total_toks = 0
|
||||
while self.llm_engine.has_unfinished_requests():
|
||||
step_outputs = self.llm_engine.step()
|
||||
@ -253,10 +343,12 @@ class LLM:
|
||||
if output.finished:
|
||||
outputs.append(output)
|
||||
if use_tqdm:
|
||||
total_toks += (sum(
|
||||
len(stp.token_ids) for stp in output.outputs))
|
||||
spd = total_toks / pbar.format_dict["elapsed"]
|
||||
pbar.postfix = f"Generation Speed: {spd:.2f} toks/s"
|
||||
if isinstance(output, RequestOutput):
|
||||
# Calculate tokens only for RequestOutput
|
||||
total_toks += sum(
|
||||
len(stp.token_ids) for stp in output.outputs)
|
||||
spd = total_toks / pbar.format_dict["elapsed"]
|
||||
pbar.postfix = f"Generation Speed: {spd:.2f} toks/s"
|
||||
pbar.update(1)
|
||||
if use_tqdm:
|
||||
pbar.close()
|
||||
|
@ -22,9 +22,11 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest, ErrorResponse)
|
||||
CompletionRequest,
|
||||
EmbeddingRequest, ErrorResponse)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
@ -32,6 +34,8 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds
|
||||
|
||||
openai_serving_chat: OpenAIServingChat
|
||||
openai_serving_completion: OpenAIServingCompletion
|
||||
openai_serving_embedding: OpenAIServingEmbedding
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_running_tasks: Set[asyncio.Task] = set()
|
||||
@ -123,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
|
||||
@app.post("/v1/embeddings")
|
||||
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
||||
generator = await openai_serving_embedding.create_embedding(
|
||||
request, raw_request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
else:
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
@ -190,7 +205,8 @@ if __name__ == "__main__":
|
||||
args.chat_template)
|
||||
openai_serving_completion = OpenAIServingCompletion(
|
||||
engine, model_config, served_model_names, args.lora_modules)
|
||||
|
||||
openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
|
||||
served_model_names)
|
||||
app.root_path = args.root_path
|
||||
uvicorn.run(app,
|
||||
host=args.host,
|
||||
|
@ -1,13 +1,14 @@
|
||||
# Adapted from
|
||||
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
|
||||
import time
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
@ -363,6 +364,24 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/embeddings
|
||||
model: str
|
||||
input: Union[List[int], List[List[int]], str, List[str]]
|
||||
encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
|
||||
dimensions: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
# doc: begin-embedding-pooling-params
|
||||
additional_data: Optional[Any] = None
|
||||
|
||||
# doc: end-embedding-pooling-params
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(additional_data=self.additional_data)
|
||||
|
||||
|
||||
class LogProbs(OpenAIBaseModel):
|
||||
text_offset: List[int] = Field(default_factory=list)
|
||||
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||
@ -416,6 +435,21 @@ class CompletionStreamResponse(OpenAIBaseModel):
|
||||
usage: Optional[UsageInfo] = Field(default=None)
|
||||
|
||||
|
||||
class EmbeddingResponseData(BaseModel):
|
||||
index: int
|
||||
object: str = "embedding"
|
||||
embedding: List[float]
|
||||
|
||||
|
||||
class EmbeddingResponse(BaseModel):
|
||||
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
data: List[EmbeddingResponseData]
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
class ChatMessage(OpenAIBaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
134
vllm/entrypoints/openai/serving_embedding.py
Normal file
134
vllm/entrypoints/openai/serving_embedding.py
Normal file
@ -0,0 +1,134 @@
|
||||
import time
|
||||
from typing import AsyncIterator, List, Tuple
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingResponseData, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_completion import parse_prompt_format
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import EmbeddingRequestOutput
|
||||
from vllm.utils import merge_async_iterators, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
TypeTokenIDs = List[int]
|
||||
|
||||
|
||||
def request_output_to_embedding_response(
|
||||
final_res_batch: List[EmbeddingRequestOutput],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
) -> EmbeddingResponse:
|
||||
data = []
|
||||
num_prompt_tokens = 0
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
assert final_res is not None
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
embedding_data = EmbeddingResponseData(
|
||||
index=idx, embedding=final_res.outputs.embedding)
|
||||
data.append(embedding_data)
|
||||
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return EmbeddingResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
data=data,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIServingEmbedding(OpenAIServing):
|
||||
|
||||
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
|
||||
served_model_names: List[str]):
|
||||
super().__init__(engine=engine,
|
||||
model_config=model_config,
|
||||
served_model_names=served_model_names,
|
||||
lora_modules=None)
|
||||
self._check_embedding_mode(model_config.embedding_mode)
|
||||
|
||||
async def create_embedding(self, request: EmbeddingRequest,
|
||||
raw_request: Request):
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
for the API specification. This API mimics the OpenAI Embedding API.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# Return error for unsupported features.
|
||||
if request.encoding_format == "base64":
|
||||
return self.create_error_response(
|
||||
"base64 encoding is not currently supported")
|
||||
if request.dimensions is not None:
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported")
|
||||
|
||||
model_name = request.model
|
||||
request_id = f"cmpl-{random_uuid()}"
|
||||
created_time = int(time.monotonic())
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators = []
|
||||
try:
|
||||
prompt_is_tokens, prompts = parse_prompt_format(request.input)
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
for i, prompt in enumerate(prompts):
|
||||
if prompt_is_tokens:
|
||||
prompt_formats = self._validate_prompt_and_tokenize(
|
||||
request, prompt_ids=prompt)
|
||||
else:
|
||||
prompt_formats = self._validate_prompt_and_tokenize(
|
||||
request, prompt=prompt)
|
||||
|
||||
prompt_ids, prompt_text = prompt_formats
|
||||
|
||||
generators.append(
|
||||
self.engine.generate(prompt_text,
|
||||
pooling_params,
|
||||
f"{request_id}-{i}",
|
||||
prompt_token_ids=prompt_ids))
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator: AsyncIterator[Tuple[
|
||||
int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: EmbeddingRequestOutput = [None] * len(prompts)
|
||||
async for i, res in result_generator:
|
||||
if await raw_request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await self.engine.abort(f"{request_id}-{i}")
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response("Client disconnected")
|
||||
final_res_batch[i] = res
|
||||
response = request_output_to_embedding_response(
|
||||
final_res_batch, request_id, created_time, model_name)
|
||||
|
||||
return response
|
||||
|
||||
def _check_embedding_mode(self, embedding_mode: bool):
|
||||
if not embedding_mode:
|
||||
logger.warning(
|
||||
"embedding_mode is False. Embedding API will not work.")
|
||||
else:
|
||||
logger.info("Activating the server engine with embedding enabled.")
|
@ -9,7 +9,8 @@ from typing_extensions import Annotated
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionRequest, ErrorResponse,
|
||||
CompletionRequest,
|
||||
EmbeddingRequest, ErrorResponse,
|
||||
LogProbs, ModelCard, ModelList,
|
||||
ModelPermission)
|
||||
from vllm.logger import init_logger
|
||||
@ -165,7 +166,8 @@ class OpenAIServing:
|
||||
|
||||
def _validate_prompt_and_tokenize(
|
||||
self,
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
request: Union[ChatCompletionRequest, CompletionRequest,
|
||||
EmbeddingRequest],
|
||||
prompt: Optional[str] = None,
|
||||
prompt_ids: Optional[List[int]] = None,
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||
@ -191,6 +193,16 @@ class OpenAIServing:
|
||||
prompt_ids)
|
||||
token_num = len(input_ids)
|
||||
|
||||
# Note: EmbeddingRequest doesn't have max_tokens
|
||||
if isinstance(request, EmbeddingRequest):
|
||||
if token_num > self.max_model_len:
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, you requested "
|
||||
f"{token_num} tokens in the input for embedding "
|
||||
f"generation. Please reduce the length of the input.", )
|
||||
return input_ids, input_text
|
||||
|
||||
if request.max_tokens is None:
|
||||
if token_num >= self.max_model_len:
|
||||
raise ValueError(
|
||||
|
@ -1,9 +1,9 @@
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
make_async)
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
@ -123,8 +123,8 @@ class GPUExecutor(ExecutorBase):
|
||||
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> List[Union[SamplerOutput, PoolerOutput]]:
|
||||
output = self.driver_worker.execute_model(execute_model_req)
|
||||
return output
|
||||
|
||||
@ -150,7 +150,7 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
|
||||
async def execute_model_async(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
) -> List[SamplerOutput]:
|
||||
) -> List[Union[SamplerOutput, PoolerOutput]]:
|
||||
output = await make_async(self.driver_worker.execute_model
|
||||
)(execute_model_req=execute_model_req, )
|
||||
return output
|
||||
|
56
vllm/model_executor/layers/pooler.py
Normal file
56
vllm/model_executor/layers/pooler.py
Normal file
@ -0,0 +1,56 @@
|
||||
from enum import IntEnum
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
|
||||
PoolingTensors)
|
||||
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
|
||||
|
||||
|
||||
class PoolingType(IntEnum):
|
||||
"""Enumeration for different types of pooling methods."""
|
||||
LAST = 0
|
||||
|
||||
|
||||
class Pooler(nn.Module):
|
||||
"""A layer that pools specific information from hidden states.
|
||||
|
||||
This layer does the following:
|
||||
1. Extracts specific tokens or aggregates data based on pooling method.
|
||||
2. Normalizes output if specified.
|
||||
3. Returns structured results as `PoolerOutput`.
|
||||
|
||||
Attributes:
|
||||
pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
|
||||
normalize: Whether to normalize the pooled data.
|
||||
"""
|
||||
|
||||
def __init__(self, pooling_type: PoolingType, normalize: bool):
|
||||
super().__init__()
|
||||
self.pooling_type = pooling_type
|
||||
self.normalize = normalize
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
"""Pools specific information from hidden states based on metadata."""
|
||||
prompt_lens = PoolingTensors.from_pooling_metadata(
|
||||
pooling_metadata, hidden_states.device).prompt_lens
|
||||
|
||||
if self.pooling_type == PoolingType.LAST:
|
||||
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
|
||||
pooled_data = hidden_states[last_token_flat_indices]
|
||||
else:
|
||||
raise ValueError(f"Invalid pooling type: {self.pooling_type}")
|
||||
|
||||
if self.normalize:
|
||||
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
|
||||
|
||||
pooled_outputs = [
|
||||
EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
|
||||
]
|
||||
|
||||
return PoolerOutput(outputs=pooled_outputs)
|
@ -10,8 +10,9 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
|
||||
SamplingTensors,
|
||||
SequenceGroupToSample)
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
|
||||
SamplerOutput, SequenceGroupOutput, SequenceOutput)
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
PromptLogprobs, SampleLogprobs, SamplerOutput,
|
||||
SequenceOutput)
|
||||
|
||||
# (num_token_ids, num_parent_ids) per sequence group.
|
||||
SampleResultType = List[Tuple[List[int], List[int]]]
|
||||
@ -1019,7 +1020,7 @@ def _build_sampler_output(
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
|
||||
sampler_output.append(
|
||||
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
|
||||
CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs))
|
||||
|
||||
# If not specified, store None values in SamplerOutput.
|
||||
if on_device_tensors is not None:
|
||||
|
@ -9,7 +9,7 @@ from vllm.utils import is_hip
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Architecture -> (module, class).
|
||||
_MODELS = {
|
||||
_GENERATION_MODELS = {
|
||||
"AquilaModel": ("llama", "LlamaForCausalLM"),
|
||||
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
|
||||
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
|
||||
@ -58,6 +58,12 @@ _MODELS = {
|
||||
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
|
||||
}
|
||||
|
||||
_EMBEDDING_MODELS = {
|
||||
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
|
||||
}
|
||||
|
||||
_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS}
|
||||
|
||||
# Architecture -> type.
|
||||
# out of tree models
|
||||
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
|
||||
@ -114,6 +120,10 @@ class ModelRegistry:
|
||||
global _OOT_MODELS
|
||||
_OOT_MODELS[model_arch] = model_cls
|
||||
|
||||
@staticmethod
|
||||
def is_embedding_model(model_arch: str) -> bool:
|
||||
return model_arch in _EMBEDDING_MODELS
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ModelRegistry",
|
||||
|
87
vllm/model_executor/models/llama_embedding.py
Normal file
87
vllm/model_executor/models/llama_embedding.py
Normal file
@ -0,0 +1,87 @@
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import PoolerOutput
|
||||
|
||||
|
||||
class LlamaEmbeddingModel(nn.Module):
|
||||
"""A model that uses Llama with additional embedding functionalities.
|
||||
|
||||
This class encapsulates the LlamaModel and provides an interface for
|
||||
embedding operations and customized pooling functions.
|
||||
|
||||
Attributes:
|
||||
model: An instance of LlamaModel used for forward operations.
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model = LlamaModel(**kwargs)
|
||||
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.model.forward(input_ids, positions, kv_caches,
|
||||
attn_metadata, inputs_embeds)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.model.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
or "rotary_emb.sin_cached" in name):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
69
vllm/model_executor/pooling_metadata.py
Normal file
69
vllm/model_executor/pooling_metadata.py
Normal file
@ -0,0 +1,69 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
class PoolingMetadata:
|
||||
"""Metadata for pooling operations in the Pooler layer.
|
||||
|
||||
This class holds the necessary information for pooling operations,
|
||||
providing context for how to perform pooling and other related operations.
|
||||
|
||||
Attributes:
|
||||
seq_groups: List of (seq_ids, pooling_params).
|
||||
seq_data: A mapping of sequence ID to additional sequence data.
|
||||
prompt_lens: List of the lengths of each prompt.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seq_groups: List[Tuple[List[int], PoolingParams]],
|
||||
seq_data: Dict[int, Any], # Specific data related to sequences
|
||||
prompt_lens: List[int],
|
||||
) -> None:
|
||||
self.seq_groups = seq_groups
|
||||
self.seq_data = seq_data
|
||||
self.prompt_lens = prompt_lens
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return ("PoolingMetadata("
|
||||
f"seq_groups={self.seq_groups}, "
|
||||
f"seq_data={self.seq_data}, "
|
||||
f"prompt_lens={self.prompt_lens})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PoolingTensors:
|
||||
"""Tensors for pooling."""
|
||||
|
||||
prompt_lens: torch.Tensor
|
||||
|
||||
@classmethod
|
||||
def from_pooling_metadata(
|
||||
cls,
|
||||
pooling_metadata: "PoolingMetadata",
|
||||
device: torch.device,
|
||||
) -> "PoolingTensors":
|
||||
"""
|
||||
Create PoolingTensors from PoolingMetadata.
|
||||
|
||||
Args:
|
||||
pooling_metadata: PoolingMetadata instance to convert.
|
||||
device: Device to store the tensors.
|
||||
"""
|
||||
# Convert prompt lengths to tensor
|
||||
pin_memory = is_pin_memory_available()
|
||||
|
||||
prompt_lens_t = torch.tensor(
|
||||
pooling_metadata.prompt_lens,
|
||||
device="cpu",
|
||||
dtype=torch.long,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
|
||||
return cls(prompt_lens=prompt_lens_t.to(device=device,
|
||||
non_blocking=True), )
|
@ -57,8 +57,27 @@ class CompletionOutput:
|
||||
f"stop_reason={self.stop_reason})")
|
||||
|
||||
|
||||
class EmbeddingOutput:
|
||||
"""The output data of one completion output of a request.
|
||||
|
||||
Args:
|
||||
embedding: The embedding vector, which is a list of floats. The
|
||||
length of vector depends on the model as listed in the embedding guide.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding: List[float],
|
||||
) -> None:
|
||||
self.embedding = embedding
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"EmbeddingOutput("
|
||||
f"embedding={len(self.embedding)}")
|
||||
|
||||
|
||||
class RequestOutput:
|
||||
"""The output data of a request to the LLM.
|
||||
"""The output data of a completion request to the LLM.
|
||||
|
||||
Args:
|
||||
request_id: The unique ID of the request.
|
||||
@ -93,6 +112,9 @@ class RequestOutput:
|
||||
|
||||
@classmethod
|
||||
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
|
||||
if seq_group.sampling_params is None:
|
||||
raise ValueError(
|
||||
"Sampling parameters are missing for a CompletionRequest.")
|
||||
seqs = seq_group.get_seqs()
|
||||
if len(seqs) == 1:
|
||||
top_n_seqs = seqs
|
||||
@ -148,3 +170,61 @@ class RequestOutput:
|
||||
f"finished={self.finished}, "
|
||||
f"metrics={self.metrics}, "
|
||||
f"lora_request={self.lora_request})")
|
||||
|
||||
|
||||
class EmbeddingRequestOutput:
|
||||
"""
|
||||
The output data of an embedding request to the LLM.
|
||||
|
||||
Args:
|
||||
request_id (str): A unique identifier for the embedding request.
|
||||
outputs (EmbeddingOutput): The embedding results for the given input.
|
||||
prompt_token_ids (List[int]): A list of token IDs used in the prompt.
|
||||
finished (bool): A flag indicating whether the embedding is completed.
|
||||
"""
|
||||
|
||||
def __init__(self, request_id: str, outputs: 'EmbeddingOutput',
|
||||
prompt_token_ids: List[int], finished: bool):
|
||||
self.request_id = request_id
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.finished = finished
|
||||
self.outputs = outputs
|
||||
|
||||
@classmethod
|
||||
def from_seq_group(cls,
|
||||
seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
|
||||
if seq_group.embeddings is None:
|
||||
raise ValueError(
|
||||
"Embeddings are missing in seq_group for EmbeddingRequest.")
|
||||
output = EmbeddingOutput(seq_group.embeddings)
|
||||
prompt_token_ids = seq_group.prompt_token_ids
|
||||
finished = seq_group.is_finished()
|
||||
|
||||
return cls(seq_group.request_id, output, prompt_token_ids, finished)
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
Returns a string representation of an EmbeddingRequestOutput instance.
|
||||
|
||||
The representation includes the request_id and the number of outputs,
|
||||
providing a quick overview of the embedding request's results.
|
||||
|
||||
Returns:
|
||||
str: A string representation of the EmbeddingRequestOutput instance.
|
||||
"""
|
||||
return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
|
||||
f"outputs={repr(self.outputs)}, "
|
||||
f"prompt_token_ids={self.prompt_token_ids}, "
|
||||
f"finished={self.finished})")
|
||||
|
||||
|
||||
class RequestOutputFactory:
|
||||
|
||||
@staticmethod
|
||||
def create(seq_group):
|
||||
# Determine the type based on a condition, for example:
|
||||
if hasattr(seq_group,
|
||||
'embeddings') and seq_group.embeddings is not None:
|
||||
return EmbeddingRequestOutput.from_seq_group(seq_group)
|
||||
else:
|
||||
return RequestOutput.from_seq_group(seq_group)
|
||||
|
20
vllm/pooling_params.py
Normal file
20
vllm/pooling_params.py
Normal file
@ -0,0 +1,20 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class PoolingParams:
|
||||
"""Pooling parameters for pooling.
|
||||
|
||||
Attributes:
|
||||
additional_data: Any additional data needed for pooling.
|
||||
"""
|
||||
|
||||
def __init__(self, additional_data: Optional[Any] = None):
|
||||
self.additional_data = additional_data
|
||||
|
||||
def clone(self) -> "PoolingParams":
|
||||
"""Returns a deep copy of the PoolingParams instance."""
|
||||
return PoolingParams(additional_data=self.additional_data, )
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"PoolingParams("
|
||||
f"additional_metadata={self.additional_data})")
|
@ -1,11 +1,13 @@
|
||||
"""Sequence and its related classes."""
|
||||
import copy
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from vllm.block import LogicalTokenBlock
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -375,12 +377,12 @@ class SequenceGroupState:
|
||||
|
||||
class MultiModalData:
|
||||
"""Multi modal request.
|
||||
|
||||
|
||||
Args:
|
||||
type: The data type.
|
||||
data: The actual data.
|
||||
The required shape and semantic meaning of it depends on the vision
|
||||
language config of the hosted model.
|
||||
language config of the hosted model.
|
||||
See `VisionLanguageConfig` in `config.py`.
|
||||
"""
|
||||
|
||||
@ -402,16 +404,22 @@ class SequenceGroup:
|
||||
arrival_time: The arrival time of the request.
|
||||
lora_request: LoRA request.
|
||||
multi_modal_data: Multi modal data associated with the request.
|
||||
embeddings: The embeddings vectors of the prompt of the sequence group
|
||||
for an embedding model.
|
||||
pooling_params: The pooling parameters used to generate the pooling
|
||||
for an embedding model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
seqs: List[Sequence],
|
||||
sampling_params: SamplingParams,
|
||||
arrival_time: float,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
multi_modal_data: Optional[MultiModalData] = None,
|
||||
embeddings: Optional[List[float]] = None,
|
||||
pooling_params: Optional[PoolingParams] = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
|
||||
@ -425,6 +433,8 @@ class SequenceGroup:
|
||||
self.prompt_logprobs: Optional[PromptLogprobs] = None
|
||||
self.state = SequenceGroupState()
|
||||
self.multi_modal_data = multi_modal_data
|
||||
self.embeddings = embeddings
|
||||
self.pooling_params = pooling_params
|
||||
|
||||
@property
|
||||
def prompt(self) -> str:
|
||||
@ -479,12 +489,13 @@ class SequenceGroup:
|
||||
def get_max_num_running_seqs(self) -> int:
|
||||
"""The maximum number of sequences running in parallel in the remaining
|
||||
lifetime of the request."""
|
||||
if self.sampling_params.use_beam_search:
|
||||
if self.sampling_params and self.sampling_params.use_beam_search:
|
||||
# For beam search, maximally there will always be `best_of` beam
|
||||
# candidates running in the future.
|
||||
return self.sampling_params.best_of
|
||||
else:
|
||||
if self.sampling_params.best_of > self.num_seqs():
|
||||
if (self.sampling_params
|
||||
and self.sampling_params.best_of > self.num_seqs()):
|
||||
# At prompt stage, the sequence group is not yet filled up
|
||||
# and only have one sequence running. However, in the
|
||||
# generation stage, we will have `best_of` sequences running.
|
||||
@ -555,7 +566,7 @@ class SequenceGroup:
|
||||
return all(seq.is_finished() for seq in self.get_seqs())
|
||||
|
||||
def is_prefill(self) -> bool:
|
||||
# Every sequences should be in the same stage.
|
||||
# Every sequence should be in the same stage.
|
||||
return self.get_seqs()[0].is_prefill()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@ -594,6 +605,7 @@ class SequenceGroupMetadata:
|
||||
sampling_params: SamplingParams,
|
||||
block_tables: Dict[int, List[int]],
|
||||
do_sample: bool = True,
|
||||
pooling_params: Optional[PoolingParams] = None,
|
||||
token_chunk_size: Optional[int] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
computed_block_nums: Optional[List[int]] = None,
|
||||
@ -605,6 +617,7 @@ class SequenceGroupMetadata:
|
||||
self.seq_data = seq_data
|
||||
self.sampling_params = sampling_params
|
||||
self.block_tables = block_tables
|
||||
self.pooling_params = pooling_params
|
||||
self.lora_request = lora_request
|
||||
self.computed_block_nums = computed_block_nums
|
||||
self.multi_modal_data = multi_modal_data
|
||||
@ -669,8 +682,20 @@ class SequenceOutput:
|
||||
return equal and log_probs_equal
|
||||
|
||||
|
||||
class SequenceGroupOutput:
|
||||
"""The model output associated with a sequence group."""
|
||||
class SequenceGroupOutput(ABC):
|
||||
"""The base class for model outputs associated with a sequence group."""
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __eq__(self, other: object) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class CompletionSequenceGroupOutput(SequenceGroupOutput):
|
||||
"""The model output associated with a completion sequence group."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -682,26 +707,45 @@ class SequenceGroupOutput:
|
||||
self.prompt_logprobs = prompt_logprobs
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SequenceGroupOutput(samples={self.samples}, "
|
||||
return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
|
||||
f"prompt_logprobs={self.prompt_logprobs})")
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SequenceGroupOutput):
|
||||
if not isinstance(other, CompletionSequenceGroupOutput):
|
||||
raise NotImplementedError()
|
||||
return (self.samples == other.samples
|
||||
and self.prompt_logprobs == other.prompt_logprobs)
|
||||
|
||||
|
||||
class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
|
||||
"""The model output associated with an embedding sequence group."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embeddings: List[float],
|
||||
) -> None:
|
||||
self.embeddings = embeddings
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"EmbeddingSequenceGroupOutput("
|
||||
f"embeddings_shape={len(self.embeddings)})")
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, EmbeddingSequenceGroupOutput):
|
||||
raise NotImplementedError()
|
||||
return self.embeddings == other.embeddings
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplerOutput:
|
||||
"""For each sequence group, we generate a list of SequenceOutput object,
|
||||
each of which contains one possible candidate for the next token.
|
||||
|
||||
This datastructure implements methods so it can be used like a list, but
|
||||
This data structure implements methods, so it can be used like a list, but
|
||||
also has optional fields for device tensors.
|
||||
"""
|
||||
|
||||
outputs: List[SequenceGroupOutput]
|
||||
outputs: List[CompletionSequenceGroupOutput]
|
||||
|
||||
# On-device tensor containing probabilities of each token.
|
||||
sampled_token_probs: Optional["torch.Tensor"] = None
|
||||
@ -742,6 +786,27 @@ class SamplerOutput:
|
||||
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PoolerOutput:
|
||||
"""The output from a pooling operation in the embedding model."""
|
||||
outputs: List[EmbeddingSequenceGroupOutput]
|
||||
|
||||
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self.outputs[idx]
|
||||
|
||||
def __setitem__(self, idx: int, value):
|
||||
self.outputs[idx] = value
|
||||
|
||||
def __len__(self):
|
||||
return len(self.outputs)
|
||||
|
||||
def __eq__(self, other: object):
|
||||
return isinstance(other,
|
||||
self.__class__) and self.outputs == other.outputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecuteModelRequest:
|
||||
"""The model execution request."""
|
||||
|
@ -4,7 +4,8 @@ from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
SamplerOutput, SequenceGroupMetadata,
|
||||
SequenceGroupOutput, SequenceOutput)
|
||||
|
||||
SeqId = int
|
||||
@ -94,7 +95,7 @@ def create_sequence_group_output(
|
||||
for topk_logprob_index, _ in enumerate(topk_token_ids)
|
||||
})
|
||||
|
||||
return SequenceGroupOutput(
|
||||
return CompletionSequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(parent_seq_id=seq_id,
|
||||
output_token=token_id,
|
||||
|
266
vllm/worker/embedding_model_runner.py
Normal file
266
vllm/worker/embedding_model_runner.py
Normal file
@ -0,0 +1,266 @@
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.distributed import broadcast_tensor_dict
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
|
||||
from vllm.worker.model_runner import BatchType, ModelRunner
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class EmbeddingModelRunner(ModelRunner):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
cache_config: CacheConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
vision_language_config: Optional[VisionLanguageConfig] = None,
|
||||
):
|
||||
super().__init__(model_config,
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
device_config,
|
||||
cache_config,
|
||||
load_config,
|
||||
lora_config=lora_config,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
is_driver_worker=is_driver_worker,
|
||||
vision_language_config=vision_language_config)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
kv_caches: List[torch.Tensor],
|
||||
) -> Optional[PoolerOutput]:
|
||||
(input_tokens, input_positions, attn_metadata, pooling_metadata,
|
||||
lora_requests, lora_mapping, multi_modal_input
|
||||
) = self.prepare_input_tensors(seq_group_metadata_list)
|
||||
|
||||
if self.lora_config:
|
||||
self.set_active_loras(lora_requests, lora_mapping)
|
||||
|
||||
# Currently cuda graph is only supported by the decode phase.
|
||||
prefill_meta = attn_metadata.prefill_metadata
|
||||
decode_meta = attn_metadata.decode_metadata
|
||||
if prefill_meta is None and decode_meta.use_cuda_graph:
|
||||
graph_batch_size = input_tokens.shape[0]
|
||||
model_executable = self.graph_runners[graph_batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
kv_caches = [None] * num_layers
|
||||
|
||||
execute_model_kwargs = {
|
||||
"input_ids": input_tokens,
|
||||
"positions": input_positions,
|
||||
"kv_caches": kv_caches,
|
||||
"attn_metadata": attn_metadata,
|
||||
}
|
||||
if self.vision_language_config:
|
||||
execute_model_kwargs.update({"image_input": multi_modal_input})
|
||||
hidden_states = model_executable(**execute_model_kwargs)
|
||||
|
||||
return self.model.pooler(hidden_states=hidden_states,
|
||||
pooling_metadata=pooling_metadata)
|
||||
|
||||
def prepare_input_tensors(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
|
||||
Set[LoRARequest], LoRAMapping, torch.Tensor]:
|
||||
if self.is_driver_worker:
|
||||
prefill_reqs = []
|
||||
decode_reqs = []
|
||||
for seq_group_meta in seq_group_metadata_list:
|
||||
if seq_group_meta.is_prompt:
|
||||
prefill_reqs.append(seq_group_meta)
|
||||
else:
|
||||
decode_reqs.append(seq_group_meta)
|
||||
|
||||
# Prepare input tensors.
|
||||
(
|
||||
input_tokens,
|
||||
input_positions,
|
||||
prefill_attn_metadata,
|
||||
prompt_lens,
|
||||
subquery_lens,
|
||||
lora_index_mapping,
|
||||
lora_prompt_mapping,
|
||||
lora_requests,
|
||||
multi_modal_input,
|
||||
slot_mapping,
|
||||
) = self._prepare_prompt(prefill_reqs)
|
||||
(
|
||||
decode_input_tokens,
|
||||
decode_input_positions,
|
||||
decode_attn_metadata,
|
||||
decode_lora_index_mapping,
|
||||
decode_lora_prompt_mapping,
|
||||
decode_lora_requests,
|
||||
decode_slot_mapping,
|
||||
) = self._prepare_decode(decode_reqs)
|
||||
|
||||
# Prepare PoolingMetadata
|
||||
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
|
||||
if not self.scheduler_config.chunked_prefill_enabled:
|
||||
assert (len(prefill_reqs) and len(decode_reqs)) == 0
|
||||
|
||||
num_prefills = len(prompt_lens)
|
||||
num_prefill_tokens = len(input_tokens)
|
||||
num_decode_tokens = len(decode_input_tokens)
|
||||
|
||||
# Coalesce tensors. Note that attn_metadata is currently not
|
||||
# coalesced for simplicity.
|
||||
input_tokens.extend(decode_input_tokens)
|
||||
input_positions.extend(decode_input_positions)
|
||||
slot_mapping.extend(decode_slot_mapping)
|
||||
lora_index_mapping.extend(decode_lora_index_mapping)
|
||||
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
|
||||
lora_requests.update(decode_lora_requests)
|
||||
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
input_positions = torch.tensor(input_positions,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
slot_mapping = torch.tensor(slot_mapping,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
|
||||
if self.lora_config:
|
||||
lora_mapping = LoRAMapping(
|
||||
lora_index_mapping,
|
||||
lora_prompt_mapping,
|
||||
)
|
||||
else:
|
||||
lora_mapping = None
|
||||
|
||||
# Broadcast the metadata.
|
||||
# If batch contains both prefill and decode, it sends 2 broadcasts.
|
||||
# If it only contains 1 type, it triggers a single broadcast.
|
||||
if (prefill_attn_metadata is not None
|
||||
and decode_attn_metadata is not None):
|
||||
batch_type = BatchType.MIXED
|
||||
elif prefill_attn_metadata is not None:
|
||||
batch_type = BatchType.PREFILL
|
||||
else:
|
||||
batch_type = BatchType.DECODE
|
||||
|
||||
metadata_dict = {
|
||||
"input_tokens": input_tokens,
|
||||
"input_positions": input_positions,
|
||||
"lora_requests": lora_requests,
|
||||
"lora_mapping": lora_mapping,
|
||||
"multi_modal_input": multi_modal_input,
|
||||
"num_prefill_tokens": num_prefill_tokens,
|
||||
"num_decode_tokens": num_decode_tokens,
|
||||
"slot_mapping": slot_mapping,
|
||||
"num_prefills": num_prefills,
|
||||
"batch_type": batch_type,
|
||||
}
|
||||
if prefill_attn_metadata is not None:
|
||||
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
|
||||
else:
|
||||
assert decode_attn_metadata is not None
|
||||
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
|
||||
broadcast_tensor_dict(metadata_dict, src=0)
|
||||
|
||||
# Broadcast decode attn metadata for mixed batch type.
|
||||
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
|
||||
# We can potentially reduce the overhead by coelescing tensors.
|
||||
if batch_type == BatchType.MIXED:
|
||||
assert decode_attn_metadata is not None
|
||||
metadata_dict = decode_attn_metadata.asdict_zerocopy()
|
||||
broadcast_tensor_dict(metadata_dict, src=0)
|
||||
else:
|
||||
metadata_dict = broadcast_tensor_dict(src=0)
|
||||
input_tokens = metadata_dict.pop("input_tokens")
|
||||
input_positions = metadata_dict.pop("input_positions")
|
||||
slot_mapping = metadata_dict.pop("slot_mapping")
|
||||
num_prefills = metadata_dict.pop("num_prefills")
|
||||
lora_mapping = metadata_dict.pop("lora_mapping")
|
||||
lora_requests = metadata_dict.pop("lora_requests")
|
||||
multi_modal_input = metadata_dict.pop("multi_modal_input")
|
||||
num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
|
||||
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
|
||||
batch_type = metadata_dict.pop("batch_type")
|
||||
|
||||
# Create an attention metadata.
|
||||
prefill_attn_metadata = None
|
||||
decode_attn_metadata = None
|
||||
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
|
||||
prefill_attn_metadata = self.attn_backend.make_metadata(
|
||||
**metadata_dict)
|
||||
else:
|
||||
decode_attn_metadata = self.attn_backend.make_metadata(
|
||||
**metadata_dict)
|
||||
|
||||
pooling_metadata = PoolingMetadata(seq_groups=None,
|
||||
seq_data=None,
|
||||
prompt_lens=None)
|
||||
|
||||
# if it is a mixed batch, decode attn_metadata is broadcasted
|
||||
# separately.
|
||||
if batch_type == BatchType.MIXED:
|
||||
metadata_dict = broadcast_tensor_dict(src=0)
|
||||
decode_attn_metadata = self.attn_backend.make_metadata(
|
||||
**metadata_dict)
|
||||
|
||||
attn_metadata = AttentionMetadata(
|
||||
num_prefills=num_prefills,
|
||||
slot_mapping=slot_mapping,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
prefill_metadata=prefill_attn_metadata,
|
||||
decode_metadata=decode_attn_metadata,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
)
|
||||
|
||||
return (input_tokens, input_positions, attn_metadata, pooling_metadata,
|
||||
lora_requests, lora_mapping, multi_modal_input)
|
||||
|
||||
def _prepare_pooling(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
prompt_lens: List[int],
|
||||
) -> PoolingMetadata:
|
||||
"""Prepare PoolingMetadata for the sequence group metadata list."""
|
||||
seq_groups: List[Tuple[List[int], PoolingParams]] = []
|
||||
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
pooling_params = seq_group_metadata.pooling_params
|
||||
seq_groups.append((seq_ids, pooling_params))
|
||||
|
||||
seq_data: Dict[int, SequenceData] = {}
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
seq_data.update(seq_group_metadata.seq_data)
|
||||
|
||||
pooling_metadata = PoolingMetadata(
|
||||
seq_groups=seq_groups,
|
||||
seq_data=seq_data,
|
||||
prompt_lens=prompt_lens,
|
||||
)
|
||||
|
||||
return pooling_metadata
|
@ -1,6 +1,6 @@
|
||||
import time
|
||||
from enum import IntEnum
|
||||
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
|
||||
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -287,18 +287,18 @@ class ModelRunner:
|
||||
lora_requests.add(seq_group_metadata.lora_request)
|
||||
|
||||
lora_index_mapping += [lora_id] * (seq_len - context_len)
|
||||
lora_prompt_mapping.extend(
|
||||
[lora_id] *
|
||||
(seq_len - context_len
|
||||
if seq_group_metadata.sampling_params.prompt_logprobs else 1))
|
||||
lora_prompt_mapping.extend([lora_id] * (
|
||||
seq_len - context_len if seq_group_metadata.sampling_params
|
||||
and seq_group_metadata.sampling_params.prompt_logprobs else 1))
|
||||
|
||||
if seq_group_metadata.multi_modal_data:
|
||||
multi_modal_input_list.append(
|
||||
seq_group_metadata.multi_modal_data.data)
|
||||
|
||||
if seq_group_metadata.block_tables is None:
|
||||
if _is_block_tables_empty(seq_group_metadata.block_tables):
|
||||
# During memory profiling, the block tables are not initialized
|
||||
# yet. In this case, we just use a dummy slot mapping.
|
||||
# In embeddings, the block tables are {seq_id: None}.
|
||||
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
|
||||
continue
|
||||
|
||||
@ -813,7 +813,6 @@ class ModelRunner:
|
||||
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
|
||||
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
max_num_seqs = self.scheduler_config.max_num_seqs
|
||||
|
||||
# This represents the maximum number of different requests
|
||||
# that will have unique loras, an therefore the max amount of memory
|
||||
# consumption create dummy lora request copies from the lora request
|
||||
@ -1139,3 +1138,15 @@ def _prepare_fake_inputs(
|
||||
prompt_tokens = [0] * seq_len
|
||||
fake_image_input = None
|
||||
return SequenceData(prompt_tokens), fake_image_input
|
||||
|
||||
|
||||
def _is_block_tables_empty(block_tables: Union[None, Dict]):
|
||||
"""
|
||||
Check if block_tables is None or a dictionary with all None values.
|
||||
"""
|
||||
if block_tables is None:
|
||||
return True
|
||||
if isinstance(block_tables, dict) and all(
|
||||
value is None for value in block_tables.values()):
|
||||
return True
|
||||
return False
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""A GPU worker class."""
|
||||
import gc
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
@ -16,8 +16,9 @@ from vllm.distributed.device_communicators.custom_all_reduce import (
|
||||
init_custom_ar)
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
|
||||
@ -68,7 +69,9 @@ class Worker(WorkerBase):
|
||||
assert not self.lora_config, (
|
||||
"To be tested: vision language model with LoRA settings.")
|
||||
|
||||
self.model_runner = ModelRunner(
|
||||
ModelRunnerClass = (EmbeddingModelRunner if
|
||||
self.model_config.embedding_mode else ModelRunner)
|
||||
self.model_runner = ModelRunnerClass(
|
||||
model_config,
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
@ -83,7 +86,8 @@ class Worker(WorkerBase):
|
||||
# Uninitialized cache engine. Will be initialized by
|
||||
# initialize_cache.
|
||||
self.cache_engine: CacheEngine
|
||||
self.gpu_cache: List[torch.Tensor]
|
||||
# Initialize gpu_cache as embedding models don't initialize kv_caches
|
||||
self.gpu_cache: Optional[List[torch.tensor]] = None
|
||||
|
||||
def init_device(self) -> None:
|
||||
if self.device_config.device.type == "cuda":
|
||||
@ -209,7 +213,7 @@ class Worker(WorkerBase):
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
) -> List[Union[SamplerOutput, PoolerOutput]]:
|
||||
|
||||
if execute_model_req is None:
|
||||
seq_group_metadata_list = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user