[Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) (#4942)

Co-authored-by: Andrew Feldman <afeld2012@gmail.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
afeldman-nm 2024-08-06 16:51:47 -04:00 committed by GitHub
parent 660470e5a3
commit fd95e026e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 3957 additions and 333 deletions

View File

@ -148,8 +148,9 @@ steps:
- python3 cpu_offload.py - python3 cpu_offload.py
- python3 offline_inference_with_prefix.py - python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py - python3 llm_engine_example.py
- python3 llava_example.py - python3 offline_inference_vision_language.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference_encoder_decoder.py
- label: Models Test # 1hr10min - label: Models Test # 1hr10min
source_file_dependencies: source_file_dependencies:
@ -289,6 +290,7 @@ steps:
commands: commands:
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
- TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py - TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s distributed/test_basic_distributed_correctness_enc_dec.py
- pytest -v -s distributed/test_chunked_prefill_distributed.py - pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s distributed/test_multimodal_broadcast.py - pytest -v -s distributed/test_multimodal_broadcast.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py

View File

@ -0,0 +1,99 @@
'''
Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART
'''
from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.utils import zip_enc_dec_prompt_lists
dtype = "float"
# Create a BART encoder/decoder model instance
llm = LLM(
model="facebook/bart-large-cnn",
dtype=dtype,
)
# Get BART tokenizer
tokenizer = llm.llm_engine.get_tokenizer_group()
# Test prompts
#
# This section shows all of the valid ways to prompt an
# encoder/decoder model.
#
# - Helpers for building prompts
text_prompt_raw = "Hello, my name is"
text_prompt = TextPrompt(prompt="The president of the United States is")
tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode(
prompt="The capital of France is"))
# - Pass a single prompt to encoder/decoder model
# (implicitly encoder input prompt);
# decoder input prompt is assumed to be None
single_text_prompt_raw = text_prompt_raw # Pass a string directly
single_text_prompt = text_prompt # Pass a TextPrompt
single_tokens_prompt = tokens_prompt # Pass a TokensPrompt
# - Pass explicit encoder and decoder input prompts within one data structure.
# Encoder and decoder prompts can both independently be text or tokens, with
# no requirement that they be the same prompt type. Some example prompt-type
# combinations are shown below, note that these are not exhaustive.
enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
# Pass encoder prompt string directly, &
# pass decoder prompt tokens
encoder_prompt=single_text_prompt_raw,
decoder_prompt=single_tokens_prompt,
)
enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
# Pass TextPrompt to encoder, and
# pass decoder prompt string directly
encoder_prompt=single_text_prompt,
decoder_prompt=single_text_prompt_raw,
)
enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
# Pass encoder prompt tokens directly, and
# pass TextPrompt to decoder
encoder_prompt=single_tokens_prompt,
decoder_prompt=single_text_prompt,
)
# - Finally, here's a useful helper function for zipping encoder and
# decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt
# instances
zipped_prompt_list = zip_enc_dec_prompt_lists(
['An encoder prompt', 'Another encoder prompt'],
['A decoder prompt', 'Another decoder prompt'])
# - Let's put all of the above example prompts together into one list
# which we will pass to the encoder/decoder LLM.
prompts = [
single_text_prompt_raw, single_text_prompt, single_tokens_prompt,
enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3
] + zipped_prompt_list
print(prompts)
# Create a sampling params object.
sampling_params = SamplingParams(
temperature=0,
top_p=1.0,
min_tokens=0,
max_tokens=20,
)
# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
encoder_prompt = output.encoder_prompt
generated_text = output.outputs[0].text
print(f"Encoder prompt: {encoder_prompt!r}, "
f"Decoder prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")

View File

@ -10,9 +10,11 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from PIL import Image from PIL import Image
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq, from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoTokenizer, BatchEncoding, BatchFeature) AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
BatchFeature)
from tests.models.utils import DecoderPromptType
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.config import TokenizerPoolConfig from vllm.config import TokenizerPoolConfig
@ -21,9 +23,11 @@ from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel) destroy_model_parallel)
from vllm.inputs import TextPrompt from vllm.inputs import TextPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
is_cpu) is_cpu, to_enc_dec_tuple_list,
zip_enc_dec_prompt_lists)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -120,6 +124,40 @@ def example_prompts() -> List[str]:
return prompts return prompts
@pytest.fixture
def example_encoder_decoder_prompts() \
-> Dict[DecoderPromptType,
Tuple[List[str], List[Optional[str]]]]:
'''
Returns an encoder prompt list and a decoder prompt list, wherein each pair
of same-index entries in both lists corresponds to an (encoder prompt,
decoder prompt) tuple.
Returns:
* Encoder prompt list
* Decoder prompt list (reverse of encoder prompt list)
'''
encoder_prompts = []
for filename in _TEST_PROMPTS:
encoder_prompts += _read_prompts(filename)
custom_decoder_prompts = encoder_prompts[::-1]
empty_str_decoder_prompts = [""] * len(encoder_prompts)
none_decoder_prompts = [None] * len(encoder_prompts)
# NONE decoder prompt type
return {
DecoderPromptType.NONE:
zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts),
DecoderPromptType.EMPTY_STR:
zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts),
DecoderPromptType.CUSTOM:
zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts),
}
@pytest.fixture @pytest.fixture
def example_long_prompts() -> List[str]: def example_long_prompts() -> List[str]:
prompts = [] prompts = []
@ -152,6 +190,7 @@ class HfRunner:
model_kwargs: Optional[Dict[str, Any]] = None, model_kwargs: Optional[Dict[str, Any]] = None,
is_embedding_model: bool = False, is_embedding_model: bool = False,
is_vision_model: bool = False, is_vision_model: bool = False,
is_encoder_decoder_model: bool = False,
) -> None: ) -> None:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
@ -168,6 +207,8 @@ class HfRunner:
else: else:
if is_vision_model: if is_vision_model:
auto_cls = AutoModelForVision2Seq auto_cls = AutoModelForVision2Seq
elif is_encoder_decoder_model:
auto_cls = AutoModelForSeq2SeqLM
else: else:
auto_cls = AutoModelForCausalLM auto_cls = AutoModelForCausalLM
@ -314,6 +355,44 @@ class HfRunner:
all_logprobs.append(seq_logprobs) all_logprobs.append(seq_logprobs)
return all_logprobs return all_logprobs
def _hidden_states_to_logprobs(
self,
hidden_states,
num_logprobs,
) -> Tuple[List[Dict[int, float]], int]:
seq_logprobs: List[torch.Tensor] = []
output_len = len(hidden_states)
for _, hidden_state in enumerate(hidden_states):
last_hidden_states = hidden_state[-1][0]
logits = torch.matmul(
last_hidden_states,
self.model.get_output_embeddings().weight.t(),
)
if getattr(self.model.get_output_embeddings(), "bias",
None) is not None:
logits += self.model.get_output_embeddings().bias.unsqueeze(0)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
seq_logprobs.append(logprobs)
# convert to dict
seq_logprobs_lst: List[Dict[int, float]] = []
for tok_idx, tok_logprobs in enumerate(seq_logprobs):
# drop prompt logprobs
if tok_idx == 0:
tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
topk = tok_logprobs.topk(num_logprobs)
tok_logprobs_dct = {}
for token_id, logprob in zip(topk.indices[0], topk.values[0]):
tok_logprobs_dct[token_id.item()] = logprob.item()
seq_logprobs_lst.append(tok_logprobs_dct)
return (
seq_logprobs_lst,
output_len,
)
def generate_greedy_logprobs_limit( def generate_greedy_logprobs_limit(
self, self,
prompts: List[str], prompts: List[str],
@ -346,33 +425,11 @@ class HfRunner:
**kwargs, **kwargs,
) )
seq_logprobs: List[torch.Tensor] = [] (
for _, hidden_states in enumerate(output.hidden_states): seq_logprobs_lst,
last_hidden_states = hidden_states[-1][0] output_len,
logits = torch.matmul( ) = self._hidden_states_to_logprobs(output.hidden_states,
last_hidden_states, num_logprobs)
self.model.get_output_embeddings().weight.t(),
)
if getattr(self.model.get_output_embeddings(), "bias",
None) is not None:
logits += self.model.get_output_embeddings(
).bias.unsqueeze(0)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
seq_logprobs.append(logprobs)
# convert to dict
seq_logprobs_lst: List[Dict[int, float]] = []
for tok_idx, tok_logprobs in enumerate(seq_logprobs):
# drop prompt logprobs
if tok_idx == 0:
tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
topk = tok_logprobs.topk(num_logprobs)
tok_logprobs_dct = {}
for token_id, logprob in zip(topk.indices[0], topk.values[0]):
tok_logprobs_dct[token_id.item()] = logprob.item()
seq_logprobs_lst.append(tok_logprobs_dct)
all_logprobs.append(seq_logprobs_lst) all_logprobs.append(seq_logprobs_lst)
seq_ids = output.sequences[0] seq_ids = output.sequences[0]
@ -385,6 +442,57 @@ class HfRunner:
return [(output_ids, output_str, output_logprobs) return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs] for output_ids, output_str, output_logprobs in outputs]
def generate_encoder_decoder_greedy_logprobs_limit(
self,
encoder_decoder_prompts: Tuple[List[str], List[str]],
max_tokens: int,
num_logprobs: int,
**kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
'''
Greedy logprobs generation for vLLM encoder/decoder models
'''
all_logprobs: List[List[Dict[int, float]]] = []
all_output_ids: List[List[int]] = []
all_output_strs: List[str] = []
for (encoder_prompt,
decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
encoder_input_ids = self.wrap_device(
self.tokenizer(encoder_prompt, return_tensors="pt").input_ids)
decoder_input_ids = (
None if decoder_prompt is None else self.wrap_device(
self.tokenizer(decoder_prompt,
return_tensors="pt").input_ids))
output = self.model.generate(
encoder_input_ids,
decoder_input_ids=decoder_input_ids,
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
output_hidden_states=True,
return_dict_in_generate=True,
**kwargs,
)
(
seq_logprobs_lst,
output_len,
) = self._hidden_states_to_logprobs(output.decoder_hidden_states,
num_logprobs)
all_logprobs.append(seq_logprobs_lst)
seq_ids = output.sequences[0]
output_ids = seq_ids[-output_len:]
all_output_ids.append(output_ids.tolist())
all_output_strs.append(self.tokenizer.decode(output_ids))
outputs = zip(all_output_ids, all_output_strs, all_logprobs)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
return self.model.encode(prompts) return self.model.encode(prompts)
@ -416,7 +524,7 @@ class VllmRunner:
block_size: int = 16, block_size: int = 16,
enable_chunked_prefill: bool = False, enable_chunked_prefill: bool = False,
swap_space: int = 4, swap_space: int = 4,
enforce_eager: bool = False, enforce_eager: Optional[bool] = False,
**kwargs, **kwargs,
) -> None: ) -> None:
self.model = LLM( self.model = LLM(
@ -465,6 +573,19 @@ class VllmRunner:
outputs.append((req_sample_output_ids, req_sample_output_strs)) outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs return outputs
def _final_steps_generate_w_logprobs(
self,
req_outputs: List[RequestOutput],
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
for req_output in req_outputs:
for sample in req_output.outputs:
output_str = sample.text
output_ids = sample.token_ids
output_logprobs = sample.logprobs
outputs.append((output_ids, output_str, output_logprobs))
return outputs
def generate_w_logprobs( def generate_w_logprobs(
self, self,
prompts: List[str], prompts: List[str],
@ -483,14 +604,21 @@ class VllmRunner:
req_outputs = self.model.generate(inputs, req_outputs = self.model.generate(inputs,
sampling_params=sampling_params) sampling_params=sampling_params)
outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] return self._final_steps_generate_w_logprobs(req_outputs)
for req_output in req_outputs:
for sample in req_output.outputs: def generate_encoder_decoder_w_logprobs(
output_str = sample.text self,
output_ids = sample.token_ids encoder_decoder_prompts: Tuple[List[str], List[str]],
output_logprobs = sample.logprobs sampling_params: SamplingParams,
outputs.append((output_ids, output_str, output_logprobs)) ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
return outputs '''
Logprobs generation for vLLM encoder/decoder models
'''
assert sampling_params.logprobs is not None
req_outputs = self.model.generate(encoder_decoder_prompts,
sampling_params=sampling_params)
return self._final_steps_generate_w_logprobs(req_outputs)
def generate_greedy( def generate_greedy(
self, self,
@ -523,6 +651,26 @@ class VllmRunner:
return [(output_ids, output_str, output_logprobs) return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs] for output_ids, output_str, output_logprobs in outputs]
def generate_encoder_decoder_greedy_logprobs(
self,
encoder_decoder_prompts: Tuple[List[str], List[str]],
max_tokens: int,
num_logprobs: int,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
use_beam_search=False,
max_tokens=max_tokens,
logprobs=num_logprobs)
'''
Greedy logprobs generation for vLLM encoder/decoder models
'''
outputs = self.generate_encoder_decoder_w_logprobs(
encoder_decoder_prompts, greedy_logprobs_params)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
def generate_beam_search( def generate_beam_search(
self, self,
prompts: List[str], prompts: List[str],

View File

@ -9,33 +9,11 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler, SchedulingBudget from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, SequenceGroup, SequenceStatus from vllm.sequence import SequenceGroup, SequenceStatus
from .utils import create_dummy_prompt from .utils import (append_new_token, append_new_token_seq_group,
create_dummy_prompt, get_sequence_groups,
schedule_and_update_computed_tokens)
def get_sequence_groups(scheduler_output):
return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
def append_new_token(out, token_id: int):
seq_groups = get_sequence_groups(out)
for seq_group in seq_groups:
for seq in seq_group.get_seqs():
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
def schedule_and_update_computed_tokens(scheduler):
metas, out = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out
def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
seq_group.update_num_computed_tokens(token_chunk_size)
for seq in seq_group.get_seqs():
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
def test_scheduler_add_seq_group(): def test_scheduler_add_seq_group():

View File

@ -0,0 +1,99 @@
from typing import List
import pytest # noqa
from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.sequence import SequenceGroup
from .utils import (append_new_token, create_dummy_prompt_encoder_decoder,
get_sequence_groups, schedule_and_update_computed_tokens)
def test_scheduler_schedule_simple_encoder_decoder():
'''
Test basic scheduler functionality in the context
of an encoder/decoder model. Focus on testing
enc/dec-specific functionality sense tests already
exist for decoder-only functionality
Test behavior:
* Construct Scheduler
* Construct dummy encoder/decoder sequence groups
* Add dummy seq groups to scheduler backlog
* Schedule the next seq group & validate:
* Cross-attn block tables
* Updated states of seq groups
* Number of batched tokens
* Number of blocks to copy/swap-in/swap-out
* Number of scheduled seq groups
* Repeat for both prefill- and decode-phase
* Abort scheduled seq groups
* Assert that aborted seq groups no longer appear in
cross-attention block table
'''
block_size = 4
num_seq_group = 4
max_model_len = 16
scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group
cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []
# Add seq groups to scheduler.
req_id_list = []
for i in range(num_seq_group):
req_id = str(i)
req_id_list.append(req_id)
_, _, seq_group = create_dummy_prompt_encoder_decoder(
req_id, block_size, block_size, block_size)
scheduler.add_seq_group(seq_group)
running.append(seq_group)
# Schedule seq groups prefill.
num_tokens = block_size * num_seq_group
seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler)
# - Verify that sequence group cross-attention block tables are
# registered with the block manager
assert all([(req_id in scheduler.block_manager.cross_block_tables)
for req_id in req_id_list])
# - Validate sequence-group status
assert set(get_sequence_groups(out)) == set(running)
# - Validate number of batched tokens
assert out.num_batched_tokens == num_tokens
# - Validate there are no remaining blocks to swap
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
# - Validate all seq groups were scheduled
assert len(seq_group_meta_list) == num_seq_group
append_new_token(out, 1)
# Schedule seq groups decode.
seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler)
# - Verify that sequence group metadata includes encoder attention
# and cross-attention metadata
assert all([
not ((seq_group_meta.encoder_seq_data is None) or
(seq_group_meta.cross_block_table is None))
for seq_group_meta in seq_group_meta_list
])
# - Validate sequence-group status
assert set(get_sequence_groups(out)) == set(running)
# - Validate there is one batched token per seq group
assert out.num_batched_tokens == num_seq_group
# - Validate there are no remaining blocks to swap
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
# - Validate that all seq groups were scheduled
assert len(seq_group_meta_list) == num_seq_group
append_new_token(out, 1)
# Abort sequences
for req_id in req_id_list:
scheduler.abort_seq_group(req_id)
# - Verify that sequence group cross-attention block tables are
# NO LONGER registered with the block manager
assert req_id not in scheduler.block_manager.cross_block_tables

View File

@ -53,27 +53,30 @@ def create_dummy_prompt_encoder_decoder(
block_size = decoder_prompt_length block_size = decoder_prompt_length
# Create dummy prompt sequence with tokens 0...block_size-1 # Create dummy prompt sequence with tokens 0...block_size-1
# and prompt "0 ... block_size". # and prompt "0 ... block_size". Note that the prompt string
# doesn't actually match the tokens
decoder_prompt_tokens = list(range(decoder_prompt_length)) decoder_prompt_tokens = list(range(decoder_prompt_length))
decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens]) decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens])
encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])
decoder_prompt = Sequence(int(request_id),
inputs = { inputs = {
"prompt": decoder_prompt_str, "prompt": decoder_prompt_str,
"prompt_token_ids": decoder_prompt_tokens, "prompt_token_ids": decoder_prompt_tokens,
"encoder_prompt": encoder_prompt_str,
"encoder_prompt_token_ids": encoder_prompt_tokens,
"multi_modal_data": None, "multi_modal_data": None,
}, }
block_size=block_size)
decoder_prompt = Sequence(int(request_id),
inputs=inputs,
block_size=block_size,
from_decoder_prompt=True)
encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])
encoder_prompt = Sequence(int(request_id), encoder_prompt = Sequence(int(request_id),
inputs={ inputs=inputs,
"prompt": encoder_prompt_str, block_size=block_size,
"prompt_token_ids": encoder_prompt_tokens, from_decoder_prompt=False)
"multi_modal_data": None,
},
block_size=block_size)
seq_group = SequenceGroup(request_id=request_id, seq_group = SequenceGroup(request_id=request_id,
seqs=[decoder_prompt], seqs=[decoder_prompt],
sampling_params=SamplingParams( sampling_params=SamplingParams(
@ -139,17 +142,21 @@ def create_seq_group_encoder_decoder(
prompt_token_ids = [0] * seq_prompt_len prompt_token_ids = [0] * seq_prompt_len
seqs = []
for seq_id_offset, output_len in enumerate(seq_output_lens):
seq = Sequence(
seq_id=seq_id_start + seq_id_offset,
inputs = { inputs = {
"prompt": "", "prompt": "",
"prompt_token_ids": prompt_token_ids, "prompt_token_ids": prompt_token_ids,
"encoder_prompt": "",
"encoder_prompt_token_ids": prompt_token_ids,
"multi_modal_data": None, "multi_modal_data": None,
}, }
seqs = []
for seq_id_offset, output_len in enumerate(seq_output_lens):
# Construct decoder input sequences
seq = Sequence(seq_id=seq_id_start + seq_id_offset,
inputs=inputs,
block_size=16, block_size=16,
) from_decoder_prompt=True)
for i in range(output_len): for i in range(output_len):
seq.append_token_id( seq.append_token_id(
@ -158,16 +165,11 @@ def create_seq_group_encoder_decoder(
) )
seqs.append(seq) seqs.append(seq)
# Encoder sequence # Encoder input sequence
encoder_seq = Sequence( encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens),
seq_id=seq_id_start + len(seq_output_lens), inputs=inputs,
inputs={
"prompt": "",
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=16, block_size=16,
) from_decoder_prompt=False)
return SequenceGroup(request_id=request_id, return SequenceGroup(request_id=request_id,
seqs=seqs, seqs=seqs,
@ -178,3 +180,30 @@ def create_seq_group_encoder_decoder(
def round_up_to_next_block(seq_len: int, block_size: int) -> int: def round_up_to_next_block(seq_len: int, block_size: int) -> int:
return (seq_len + block_size - 1) // block_size return (seq_len + block_size - 1) // block_size
# Helper functions for scheduler tests
def get_sequence_groups(scheduler_output):
return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
def append_new_token(out, token_id: int):
seq_groups = get_sequence_groups(out)
for seq_group in seq_groups:
for seq in seq_group.get_seqs():
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
def schedule_and_update_computed_tokens(scheduler):
metas, out = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out
def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
seq_group.update_num_computed_tokens(token_chunk_size)
for seq in seq_group.get_seqs():
seq.append_token_id(token_id, {token_id: Logprob(token_id)})

View File

@ -0,0 +1,101 @@
"""For encoder/decoder models only:
Compare the outputs of HF and distributed vLLM when using greedy sampling.
Run:
```sh
cd $VLLM_PATH/tests
pytest distributed/test_basic_distributed_correctness_enc_dec.py
```
"""
import pytest
from tests.models.utils import DecoderPromptType
from vllm.utils import cuda_device_count_stateless
from ..models.utils import check_logprobs_close
from ..utils import fork_new_process_for_each_test
@pytest.mark.skipif(cuda_device_count_stateless() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model, distributed_executor_backend", [
("facebook/bart-large-cnn", "ray"),
("facebook/bart-large-cnn", "mp"),
])
@fork_new_process_for_each_test
def test_models(
model: str,
distributed_executor_backend: str,
hf_runner,
vllm_runner,
example_encoder_decoder_prompts,
) -> None:
'''
Test vLLM BART inference on more than one GPU, comparing
outputs against HF as a baseline.
Fork a new process for each test, to prevent CUDA from
being re-initialized by successive tests within the same
process.
Arguments:
* model: the HF ID of the specific BART variant under test
* distributed_executor_backend
* hf_runner: HuggingFace (HF) test model runner
* vllm_runner: vLLM test model runner
* example_encoder_decoder_prompts: test fixture which provides a
dictionary of dummy prompts
'''
dtype = "float"
max_tokens = 64
num_logprobs = 5
# Example inputs with non-trivial (i.e. not None/empty) encoder &
# decoder prompts.
test_prompts = example_encoder_decoder_prompts[DecoderPromptType.CUSTOM]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=2,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
) as vllm_model:
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
test_prompts, max_tokens, num_logprobs)
# Configuration settings for HF baseline
hf_kwargs = {
"top_k": None,
"num_beams": 1,
"repetition_penalty": 1.0,
"top_p": 1.0,
"length_penalty": 1.0,
"early_stopping": False,
"no_repeat_ngram_size": None,
"min_length": 0
}
with hf_runner(model, dtype=dtype,
is_encoder_decoder_model=True) as hf_model:
hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
test_prompts,
max_tokens,
num_logprobs,
**hf_kwargs,
))
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)

View File

@ -3,9 +3,9 @@ from unittest.mock import patch
import pytest import pytest
import torch import torch
from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL, from tests.kernels.utils import override_backend_env_variable
override_backend_env_variable)
from vllm.attention.selector import which_attn_to_use from vllm.attention.selector import which_attn_to_use
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -4,8 +4,6 @@ Tests:
* E2E test of Encoder attention + Decoder self-attention + * E2E test of Encoder attention + Decoder self-attention +
Encoder/decoder cross-attention (collectively Encoder/decoder cross-attention (collectively
"encoder/decoder attention") "encoder/decoder attention")
* Confirm enc/dec models will fail for chunked prefill
* Confirm enc/dec models will fail for prefix caching
""" """
@ -15,19 +13,22 @@ import pytest
import torch import torch
from tests.kernels.utils import * from tests.kernels.utils import *
from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
from vllm.attention import Attention, AttentionMetadata AttentionType)
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend,
global_force_attn_backend_context_manager)
from vllm.utils import is_hip from vllm.utils import is_hip
# List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
HEAD_SIZES = [64, 256] HEAD_SIZES = [64, 256]
NUM_HEADS = [1, 16] NUM_HEADS = [1, 16]
BATCH_SIZES = [1, 16] BATCH_SIZES = [1, 16]
BLOCK_SIZES = [16] BLOCK_SIZES = [16]
BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL]
CUDA_DEVICE = "cuda:0" CUDA_DEVICE = "cuda:0"
MAX_DEC_SEQ_LENS = [128] MAX_DEC_SEQ_LENS = [128]
@ -724,23 +725,58 @@ def _run_encoder_decoder_cross_attention_test(
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_encoder_only(num_heads: int, head_size: int, backend_name: str, def test_encoder_only(
batch_size: int, block_size: int, max_dec_seq_len: int, num_heads: int,
max_enc_seq_len: int, monkeypatch): head_size: int,
attn_backend: _Backend,
batch_size: int,
block_size: int,
max_dec_seq_len: int,
max_enc_seq_len: int,
):
'''
End-to-end encoder-only attention test:
* Construct fake test vectors for (1) encoder attention
* Construct (1) attention metadata structure with prefill-phase
encoder attention, and (2) an analogous attention metadata
structure but for decode-phase
* Test & validate encoder attention against ideal output
No KV cache is required for encoder-only attention.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip().
This test globally forces an override of the usual backend
auto-selection process, forcing the specific backend-under-test
to be utilized.
Arguments:
* num_heads
* head_size,
* attn_backend: The attention backend to employ for testing
* batch_size
* block_size: KV cache block size
* max_dec_seq_len: max length of decoder input sequences
* max_enc_seq_len: max length of encoder input sequences
'''
# Force Attention wrapper backend # Force Attention wrapper backend
override_backend_env_variable(monkeypatch, backend_name) with global_force_attn_backend_context_manager(attn_backend):
# Note: KV cache size of 4096 is arbitrary & chosen intentionally # Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size # to be more than necessary, since exceeding the kv cache size
# is not part of this test # is not part of this test
test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, test_pt = TestPoint(num_heads, head_size, attn_backend.name,
block_size, max_dec_seq_len, max_enc_seq_len, 4096) batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096)
# Attention scale factor, attention backend instance, attention wrapper # Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init # instance, KV cache init
@ -774,7 +810,7 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@ -782,12 +818,11 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
def test_e2e_enc_dec_attn( def test_e2e_enc_dec_attn(
num_heads: int, num_heads: int,
head_size: int, head_size: int,
backend_name: str, attn_backend: _Backend,
batch_size: int, batch_size: int,
block_size: int, block_size: int,
max_dec_seq_len: int, max_dec_seq_len: int,
max_enc_seq_len: int, max_enc_seq_len: int,
monkeypatch,
) -> None: ) -> None:
''' '''
End-to-end encoder/decoder test: End-to-end encoder/decoder test:
@ -820,8 +855,9 @@ def test_e2e_enc_dec_attn(
cross-attention K/Vs are allowed to differ in seq len, as is often the case cross-attention K/Vs are allowed to differ in seq len, as is often the case
for cross-attention. for cross-attention.
This test utilizes PyTest monkey patching to force the attention backend This test globally forces an override of the usual backend
via an environment variable. auto-selection process, forcing the specific backend-under-test
to be utilized.
Note on ROCm/HIP: currently encoder/decoder models are not supported on Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip(). AMD GPUs, therefore this test simply is skipped if is_hip().
@ -830,23 +866,34 @@ def test_e2e_enc_dec_attn(
all prefill-phase attention operations (encoder, decoder, enc/dec cross), all prefill-phase attention operations (encoder, decoder, enc/dec cross),
and a single one shared by all decode-phase attention operations and a single one shared by all decode-phase attention operations
(decoder & enc/dec cross.) This is intended to reflect the behavior (decoder & enc/dec cross.) This is intended to reflect the behavior
of ModelRunner, which constructs a single attention metadata structure for of EncoderDecoderModelRunner, which constructs a single attention metadata
each prefill or decode run. A realistic scenario would rely on the structure for each prefill or decode run. A realistic scenario would rely
attention backend to utilize the appropriate attention metadata fields on the attention backend to utilize the appropriate attention metadata
according to the value of attn_metadata.attention_type. Thus, this test is fields according to the value of attn_metadata.attention_type. Thus,
organized so as to confirm that the backend-under-test can handle a this test is organized so as to confirm that the backend-under-test can
shared prefill attention metadata structure & a shared decode attention handle a shared prefill attention metadata structure & a shared decode\
metadata structure. attention metadata structure.
Arguments:
* num_heads
* head_size,
* attn_backend: The attention backend to employ for testing
* batch_size
* block_size: KV cache block size
* max_dec_seq_len: max length of decoder input sequences
* max_enc_seq_len: max length of encoder input sequences
''' '''
# Force Attention wrapper backend # Force Attention wrapper backend
override_backend_env_variable(monkeypatch, backend_name) with global_force_attn_backend_context_manager(attn_backend):
# Note: KV cache size of 4096 is arbitrary & chosen intentionally # Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size # to be more than necessary, since exceeding the kv cache size
# is not part of this test # is not part of this test
test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, test_pt = TestPoint(num_heads, head_size, attn_backend.name,
block_size, max_dec_seq_len, max_enc_seq_len, 4096) batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096)
# Attention scale factor, attention backend instance, attention wrapper # Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init # instance, KV cache init
@ -870,8 +917,9 @@ def test_e2e_enc_dec_attn(
cross_block_base_addr, cross_block_base_addr,
) = _decoder_attn_setup(test_pt, test_rsrcs) ) = _decoder_attn_setup(test_pt, test_rsrcs)
# Construct encoder/decoder cross-attention prefill-phase & decode-phase # Construct encoder/decoder cross-attention prefill-phase
# test params, including key/value tensors, cross-attention memory-mapping # & decode-phase test params, including key/value tensors,
# cross-attention memory-mapping
( (
prephase_cross_test_params, prephase_cross_test_params,

View File

@ -211,5 +211,5 @@ def test_varlen_with_paged_kv(
sliding_window=sliding_window, sliding_window=sliding_window,
soft_cap=soft_cap, soft_cap=soft_cap,
) )
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ assert torch.allclose(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"

View File

@ -8,24 +8,10 @@ from typing import Any, List, NamedTuple, Optional, Tuple, Union
import pytest import pytest
import torch import torch
from vllm.attention.backends.abstract import (AttentionBackend, from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
AttentionMetadata, AttentionType)
from vllm.attention.backends.xformers import XFormersBackend from vllm.attention.backends.xformers import XFormersBackend
from vllm.utils import make_tensor_with_pad from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
make_tensor_with_pad)
# String name of register which may be set in order to
# force auto-selection of attention backend by Attention
# wrapper
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
class QKVInputs(NamedTuple): class QKVInputs(NamedTuple):

153
tests/models/test_bart.py Normal file
View File

@ -0,0 +1,153 @@
"""Compare the outputs of HF and vLLM for BART models using greedy sampling.
Run `pytest tests/models/test_bart.py`.
"""
from vllm.utils import is_cpu
if not is_cpu():
# CPU backend is not currently supported with encoder/decoder models
# skip test definitions entirely to avoid importing GPU kernel libs
# (xFormers, etc.)
import pytest
from tests.models.utils import DecoderPromptType
from .utils import check_logprobs_close
MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]
DECODER_PROMPT_TYPES = ([
DecoderPromptType.CUSTOM, DecoderPromptType.EMPTY_STR,
DecoderPromptType.NONE
])
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", DECODER_PROMPT_TYPES)
def test_models(
hf_runner,
vllm_runner,
example_encoder_decoder_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
decoder_prompt_type: DecoderPromptType,
) -> None:
'''
Test the vLLM BART model for a variety of encoder/decoder input prompts,
by validating it against HuggingFace (HF) BART.
Arguments:
* hf_runner: HuggingFace (HF) test model runner
* vllm_runner: vLLM test model runner
* example_encoder_decoder_prompts: test fixture which provides a
dictionary of dummy prompts
* model: the HF ID of the specific BART variant under test
* dtype: the tensor datatype to employ
* max_tokens
* num_logprobs
* decoder_prompt_type: key into the example_encoder_decoder_prompts
dictionary; selects specific encoder/decoder
prompt scenarios to test
A note on using HF BART as a baseline for validating vLLM BART,
specifically when the decoder prompt is None.
The HF GenerationMixin's default behavior is to force the first
decoded token to be <BOS> if the prompt does not already contain
<BOS> (this is accomplished using a logit
processor setting.)
So when we use HF BART as our baseline for comparison, note that
when the user provides a request with a None decoder prompt
(i.e. a singleton encoder prompt, or else an explicit encoder/
decoder prompt with the decoder sub-prompt set to None), HF and
vLLM handle this in different ways:
* HF will (1) tokenize the None prompt as an empty token-list,
(2) append <decoder-start-token> to the beginning, yielding
[<decoder-start-token>], (3) pass this token list to the model, and
then (4) after computing logits during prefill, override the model
logits & force <BOS> to be the first generated token.
* vLLM will (1) tokenize the None prompt as [<BOS>], (2) append decoder-
start-token to the beginning, yielding [<decoder-start-token><BOS>],
(3) pass these tokens to the model & proceed with generation.
The net effect is that compared to vLLM, the list of HF *decoded* tokens
will contain one more initial <BOS> than the vLLM generated tokens,
because vLLM's <BOS> token is injected into the prompt rather than into
the generated output. This is in spite of the fact that overall, the
complete sequences (prompt + decoded tokens) produced by vLLM will match
HF.
So when we use HF decoded token output to validate vLLM's decoded token
output, the testing process must account for the difference in decoded
token sequences between vLLM and HF specifically in the
decoder-prompt-is-None case.
One option is to disable the logit processor feature that forces the
<BOS> token to be decoded (forced_bos_token_id = None), eliminating
the problem entirely. However this is not "normal" BART usage.
The other option is - only in the decoder-prompt-is-None case - to
discard the first decoded token from the HF output before comparing it
to vLLM.
To that end, when testing the scenario where the decoder prompt is None
(and only in that one scenario), this test skips the first HF decoded
token during the process of validating the vLLM decoded output.
'''
test_case_prompts = example_encoder_decoder_prompts[
decoder_prompt_type]
# Configuration settings for HF baseline
hf_kwargs = {
"top_k": None,
"num_beams": 1,
"repetition_penalty": 1.0,
"top_p": 1.0,
"length_penalty": 1.0,
"early_stopping": False,
"no_repeat_ngram_size": None,
"min_length": 0
}
with hf_runner(model, dtype=dtype,
is_encoder_decoder_model=True) as hf_model:
hf_outputs = (
hf_model.generate_encoder_decoder_greedy_logprobs_limit(
test_case_prompts,
max_tokens,
num_logprobs,
**hf_kwargs,
))
# Note: currently encoder/decoder models are only compatible with
# enforce_eager=True. Normally this is not a problem because
# for encoder/decoder models vLLM will
# default to enforce_eager=True if enforce_eager
# is left unspecified. However, the
# VllmRunner test fixture (which wraps around the LLM class) defaults to
# enforce_eager=False (a behavior which a number of already-exisitng
# decoder-only unit tests expect), so when testing an encoder/decoder
# model we must explicitly specify enforce_eager=True in the VllmRunner
# constructor.
with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
test_case_prompts, max_tokens, num_logprobs)
hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
else 0)
check_logprobs_close(outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
num_outputs_0_skip_tokens=hf_skip_tokens)

View File

@ -1,4 +1,5 @@
import warnings import warnings
from enum import Enum
from typing import Dict, List, Optional, Sequence, Tuple, Union from typing import Dict, List, Optional, Sequence, Tuple, Union
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
@ -45,11 +46,27 @@ def check_logprobs_close(
outputs_1_lst: Sequence[TokensTextLogprobs], outputs_1_lst: Sequence[TokensTextLogprobs],
name_0: str, name_0: str,
name_1: str, name_1: str,
num_outputs_0_skip_tokens: int = 0,
warn_on_mismatch: bool = True, warn_on_mismatch: bool = True,
): ):
""" """
Compare the logprobs of two sequences generated by different models, Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal. which should be similar but not necessarily equal.
Arguments:
* outputs_0_lst: First sequence to compare
* outputs_0_lst: Second sequence to compare
* name_0: sequence #0 name
* name_1: sequence #1 name
* num_outputs_0_skip_tokens: If > 0, specifies the number of initial
sequence #0 tokens & logprobs to discard
before comparison, i.e. all
of sequence #1 will be compared to
sequence #0 beginning at index
num_outputs_0_skip_tokens
* warn_on_mismatch: Issue a warning if there is token-wise or text-wise
mismatch between the two sequences
""" """
assert len(outputs_0_lst) == len(outputs_1_lst) assert len(outputs_0_lst) == len(outputs_1_lst)
@ -65,6 +82,15 @@ def check_logprobs_close(
if logprobs_1 is None: if logprobs_1 is None:
logprobs_1 = [None] * len(output_ids_1) logprobs_1 = [None] * len(output_ids_1)
# Skip specified number of initial sequence #0 tokens
# & logprobs, leaving output text as-is for simplicity
# (text mismatches may generate warnings but do not
# cause the test to fail.)
if num_outputs_0_skip_tokens < 0:
raise ValueError("num_outputs_0_skip_tokens must be non-negative")
output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:]
logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]
# Loop through generated tokens. # Loop through generated tokens.
for idx, (output_id_0, for idx, (output_id_0,
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
@ -110,3 +136,13 @@ def check_logprobs_close(
warnings.simplefilter("always") warnings.simplefilter("always")
warnings.warn(fail_msg, stacklevel=2) warnings.warn(fail_msg, stacklevel=2)
class DecoderPromptType(Enum):
'''
For encoder/decoder models only -
'''
CUSTOM = 1
NONE = 2
EMPTY_STR = 3

View File

@ -0,0 +1,480 @@
from typing import List
import pytest
import torch
from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import is_cpu
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
# CUDA graph scenarios to test
#
# Currently CUDA graph is not supported
ENFORCE_EAGER = [True]
BATCH_SIZES = [1, 4, 16, 64, 256]
def _create_model_runner(model: str, *args,
**kwargs) -> EncoderDecoderModelRunner:
engine_args = EngineArgs(model, *args, **kwargs)
engine_config = engine_args.create_engine_config()
model_runner = EncoderDecoderModelRunner(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
is_driver_worker=True,
)
return model_runner
@pytest.mark.skipif(condition=is_cpu(),
reason="CPU backend is currently "
"unsupported for encoder/ "
"decoder models")
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
def test_empty_seq_group(enforce_eager, ):
"""Verify prepare prompt and decode returns empty output
for empty seq group list"""
model_runner = _create_model_runner(
"facebook/bart-base",
seed=0,
dtype="float16",
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=False,
enforce_eager=enforce_eager,
)
seq_group_metadata_list: List[SequenceGroupMetadata] = []
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
(
input_tokens,
input_positions,
encoder_input_tokens,
encoder_input_positions,
attn_metadata,
return_seq_lens,
) = (
model_input.input_tokens,
model_input.input_positions,
model_input.encoder_input_tokens,
model_input.encoder_input_positions,
model_input.attn_metadata,
model_input.seq_lens,
)
assert input_tokens is None
assert input_positions is None
assert encoder_input_tokens is None
assert encoder_input_positions is None
assert attn_metadata is None
assert return_seq_lens is None
@pytest.mark.skipif(condition=is_cpu(),
reason="CPU backend is currently "
"unsupported for encoder/ "
"decoder models")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
def test_prepare_prompt(
batch_size,
enforce_eager,
):
'''
Test the ability of the encoder/decoder model runner subclass to
produce prefill-phase model inputs & attention metadata.
Test behavior:
* Instantiate BART base model & enc/dec model runner
* Construct sequence-group metadata for dummy prompts
* Test that encoder attention, decoder self-attention,
and encoder/decoder cross-attention inputs are correct
Arguments:
* batch_size
* backend_name: The attention backend under test
* enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
'''
model_runner = _create_model_runner(
"facebook/bart-base",
seed=0,
dtype="float16",
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=False,
enforce_eager=enforce_eager,
)
seq_lens: List[int] = []
encoder_seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
block_tables = {0: [1]}
cross_block_table = [2]
for i in range(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table,
)
assert seq_group_metadata.token_chunk_size == seq_data.get_len()
seq_group_metadata_list.append(seq_group_metadata)
# Build
# * Decoder model inputs
# * Decoder self-attention KV caching data structures
# * Encoder model inputs
# * Encoder/decoder cross-attention KV caching data structures
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = attn_metadata.slot_mapping
encoder_input_tokens = model_input.encoder_input_tokens
encoder_input_positions = model_input.encoder_input_positions
cross_slot_mapping = attn_metadata.cross_slot_mapping
assert return_seq_lens == seq_lens
assert len(slot_mapping) == len(input_tokens)
assert len(cross_slot_mapping) == len(encoder_input_tokens)
# Verify input metadata is correct for prompts.
# - Decoder attention metadata
device = model_runner.device
assert attn_metadata.num_prefills > 0
assert attn_metadata.num_decode_tokens == 0
assert torch.equal(attn_metadata.seq_lens_tensor,
torch.tensor(seq_lens, device=device, dtype=torch.int))
assert attn_metadata.seq_lens == seq_lens
assert attn_metadata.max_prefill_seq_len == max(seq_lens)
assert attn_metadata.max_decode_seq_len == 0
# - Encoder attention metadata
assert attn_metadata.encoder_seq_lens == encoder_seq_lens
assert torch.equal(
attn_metadata.encoder_seq_lens_tensor,
torch.tensor(encoder_seq_lens, device=device, dtype=torch.int))
assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)
# Test decoder subquery start locs.
start_idx = 0
start_loc = [start_idx]
for seq_len in seq_lens:
start_idx += seq_len
start_loc.append(start_idx)
assert torch.equal(
attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device),
)
# Test decoder seq start locs & context lengths
assert torch.equal(
attn_metadata.seq_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device),
)
assert torch.equal(
attn_metadata.context_lens_tensor,
torch.zeros(attn_metadata.context_lens_tensor.shape[0],
dtype=torch.int,
device=device),
)
# Verify block tables are correct for prompts
# - Decoder self-attention
expected = torch.tensor(
[[] for _ in range(len(seq_group_metadata_list))],
dtype=torch.int32,
device=model_runner.device,
)
assert torch.equal(
attn_metadata.block_tables,
expected,
)
# - Encoder/decoder cross-attention
assert torch.equal(
attn_metadata.cross_block_tables,
expected,
)
# Cuda graph should not be used for prefill.
assert attn_metadata.use_cuda_graph is False
# Verify the lengths of input tokens & positions
# - Decoder
assert len(input_tokens) == sum(seq_lens)
assert len(input_positions) == sum(seq_lens)
# -- An indirect check that model_input.input_tokens
# and model_input.input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert torch.equal(
input_tokens,
input_positions,
)
# - Encoder
assert len(encoder_input_tokens) == sum(encoder_seq_lens)
# -- An indirect check that model_input.encoder_input_tokens
# and model_input.encoder_input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert torch.equal(
encoder_input_tokens,
encoder_input_positions,
)
# Test that vLLM sampling infrastructure chooses the correct
# sequence positions at which to sample (i.e. the end of
# each sequence) in the prefill phase
expected_selected_token_indices = []
selected_token_start_idx = 0
for seq_len in seq_lens:
# Compute the index offset of the final token in each
# prompt (recall that the prompts are concatenated)
expected_selected_token_indices.append(selected_token_start_idx +
seq_len - 1)
selected_token_start_idx += seq_len
sampling_metadata = model_input.sampling_metadata
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(
expected_selected_token_indices,
device=actual.device,
dtype=actual.dtype,
)
assert torch.equal(actual, expected)
@pytest.mark.skipif(condition=is_cpu(),
reason="CPU backend is currently "
"unsupported for encoder/ "
"decoder models")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
def test_prepare_decode(
batch_size,
enforce_eager,
):
'''
Test the ability of the encoder/decoder model runner subclass to
produce decode-phase model inputs & attention metadata.
Test behavior:
* Instantiate BART base model & enc/dec model runner
* Construct sequence-group metadata for dummy prompts
* Test that encoder attention, decoder self-attention,
and encoder/decoder cross-attention inputs are correct
Arguments:
* batch_size
* backend_name: The attention backend under test
* enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
'''
model_runner = _create_model_runner(
"facebook/bart-base",
seed=0,
dtype="float16",
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=False,
enforce_eager=enforce_eager,
)
seq_lens: List[int] = []
encoder_seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
block_tables = {0: [1]}
cross_block_table = [2]
for i in range(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table,
)
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)
# Build
# * Decoder model inputs
# * Decoder self-attention KV caching data structures
# * Encoder model inputs
# * Encoder/decoder cross-attention KV caching data structures
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = attn_metadata.slot_mapping
encoder_input_tokens = model_input.encoder_input_tokens
encoder_input_positions = model_input.encoder_input_positions
cross_slot_mapping = attn_metadata.cross_slot_mapping
assert return_seq_lens == seq_lens
assert len(slot_mapping) == len(input_tokens)
assert len(cross_slot_mapping) == len(encoder_input_tokens)
# Verify input metadata is correct for decode phase.
# - Decoder attention metadata
device = model_runner.device
assert attn_metadata.num_prefills == 0
assert attn_metadata.num_decode_tokens > 0
assert torch.equal(attn_metadata.seq_lens_tensor,
torch.tensor(seq_lens, device=device, dtype=torch.int))
assert attn_metadata.seq_lens == seq_lens
assert attn_metadata.max_prefill_seq_len == 0
assert attn_metadata.max_decode_seq_len == max(seq_lens)
# - Encoder attention metadata
assert attn_metadata.encoder_seq_lens == encoder_seq_lens
assert torch.equal(
attn_metadata.encoder_seq_lens_tensor,
torch.tensor(encoder_seq_lens, device=device, dtype=torch.int))
assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)
# Test decoder subquery start locs.
start_idx = 0
start_loc = [start_idx]
for seq_len in seq_lens:
start_idx += 1
start_loc.append(start_idx)
assert torch.equal(
attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device),
)
# Test decoder seq start locs. Note that for normal prefill it is
# equivalent to query_start_loc.
start_idx = 0
seq_start_loc = [start_idx]
for seq_len in seq_lens:
start_idx += seq_len
seq_start_loc.append(start_idx)
# Test seq_start_loc and context lengths
assert torch.equal(
attn_metadata.seq_start_loc,
torch.tensor(seq_start_loc, dtype=torch.int32, device=device),
)
assert torch.equal(
attn_metadata.context_lens_tensor,
torch.tensor([seq_len - 1 for seq_len in seq_lens],
dtype=torch.int,
device=device))
# Verify block tables are correct for prompts
# - Decoder self-attention
expected = torch.tensor(
[block_tables[0] for _ in range(len(seq_group_metadata_list))],
dtype=torch.int32,
device=model_runner.device)
assert torch.equal(
attn_metadata.block_tables,
expected,
)
# - Encoder/decoder cross-attention
expected = torch.tensor(
[cross_block_table for _ in range(len(seq_group_metadata_list))],
dtype=torch.int32,
device=model_runner.device)
assert torch.equal(
attn_metadata.cross_block_tables,
expected,
)
# Cuda graph should is currently not supported for encoder/decoer.
assert attn_metadata.use_cuda_graph is False
# Verify the lengths of input tokens & positions
# - Decoder
assert len(input_tokens) == len(seq_lens)
assert len(input_positions) == len(seq_lens)
# -- An indirect check that model_input.input_tokens
# and model_input.input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert torch.equal(
input_tokens,
input_positions,
)
# - Encoder
assert len(encoder_input_tokens) == 0
assert len(encoder_input_tokens) == 0
# -- An indirect check that model_input.encoder_input_tokens
# and model_input.encoder_input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert torch.equal(
encoder_input_tokens,
encoder_input_positions,
)
# Test that vLLM sampling infrastructure chooses the correct
# sequence positions at which to sample (i.e. the end of
# each sequence) in the decode phase
expected_selected_token_indices = []
selected_token_start_idx = 0
for seq_len in seq_lens:
# Compute the index offset of the final token in each
# sequence's decoded outputs; since a single token is
# decoded per iteration per sequence, then the length
# of the decoded tokens for a given sequence is 1 and
# the final index offset into a given sequence's
# generated tokens is 0 (i.e. the expected sampling index
# for a given sequence is just `selected_token_start_idx`)
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
sampling_metadata = model_input.sampling_metadata
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(
expected_selected_token_indices,
device=actual.device,
dtype=actual.dtype,
)
assert torch.equal(actual, expected)

View File

@ -1,6 +1,7 @@
from vllm.attention.backends.abstract import (AttentionBackend, from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionMetadata,
AttentionMetadataBuilder) AttentionMetadataBuilder,
AttentionType)
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
@ -8,6 +9,7 @@ __all__ = [
"Attention", "Attention",
"AttentionBackend", "AttentionBackend",
"AttentionMetadata", "AttentionMetadata",
"AttentionType",
"AttentionMetadataBuilder", "AttentionMetadataBuilder",
"Attention", "Attention",
"get_attn_backend", "get_attn_backend",

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata, AttentionType from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (

View File

@ -1,6 +1,8 @@
import enum import enum
import os
from contextlib import contextmanager
from functools import lru_cache from functools import lru_cache
from typing import Optional, Type from typing import Generator, Optional, Type
import torch import torch
@ -8,7 +10,8 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu from vllm.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino,
is_tpu, is_xpu)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -24,6 +27,66 @@ class _Backend(enum.Enum):
IPEX = enum.auto() IPEX = enum.auto()
def backend_name_to_enum(backend_name: str) -> _Backend:
assert backend_name is not None
backend_members = _Backend.__members__
if backend_name not in backend_members:
raise ValueError(f"Invalid attention backend '{backend_name}'. "
f"Available backends: {', '.join(backend_members)} "
"(case-sensitive).")
return _Backend[backend_name]
def get_env_variable_attn_backend() -> Optional[_Backend]:
'''
Get the backend override specified by the vLLM attention
backend environment variable, if one is specified.
Returns:
* _Backend enum value if an override is specified
* None otherwise
'''
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
return (None
if backend_name is None else backend_name_to_enum(backend_name))
# Global state allows a particular choice of backend
# to be forced, overriding the logic which auto-selects
# a backend based on system & workload configuration
# (default behavior if this variable is None)
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM ATTENTION BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: Optional[_Backend] = None
def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
'''
Force all attention operations to use a specified backend.
Passing `None` for the argument re-enables automatic
backend selection.,
Arguments:
* attn_backend: backend selection (None to revert to auto)
'''
global forced_attn_backend
forced_attn_backend = attn_backend
def get_global_forced_attn_backend() -> Optional[_Backend]:
'''
Get the currently-forced choice of attention backend,
or None if auto-selection is currently enabled.
'''
return forced_attn_backend
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_attn_backend( def get_attn_backend(
num_heads: int, num_heads: int,
@ -101,16 +164,20 @@ def which_attn_to_use(
# Default case. # Default case.
selected_backend = _Backend.FLASH_ATTN selected_backend = _Backend.FLASH_ATTN
# Check whether a particular choice of backend was
# previously forced.
#
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE.
backend_by_global_setting: Optional[_Backend] = (
get_global_forced_attn_backend())
if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting
else:
# Check the environment variable and override if specified # Check the environment variable and override if specified
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None: if backend_by_env_var is not None:
backend_members = _Backend.__members__ selected_backend = backend_name_to_enum(backend_by_env_var)
if backend_by_env_var not in backend_members:
raise ValueError(
f"Invalid attention backend '{backend_by_env_var}'. "
f"Available backends: {', '.join(backend_members)} "
"(case-sensitive).")
selected_backend = _Backend[backend_by_env_var]
if is_cpu(): if is_cpu():
if selected_backend != _Backend.TORCH_SDPA: if selected_backend != _Backend.TORCH_SDPA:
@ -193,3 +260,35 @@ def which_attn_to_use(
selected_backend = _Backend.XFORMERS selected_backend = _Backend.XFORMERS
return selected_backend return selected_backend
@contextmanager
def global_force_attn_backend_context_manager(
attn_backend: _Backend) -> Generator[None, None, None]:
'''
Globally force a vLLM attention backend override within a
context manager, reverting the global attention backend
override to its prior state upon exiting the context
manager.
Arguments:
* attn_backend: attention backend to force
Returns:
* Generator
'''
# Save the current state of the global backend override (if any)
original_value = get_global_forced_attn_backend()
# Globally force the new backend override
global_force_attn_backend(attn_backend)
# Yield control back to the enclosed code block
try:
yield
finally:
# Revert the original global backend override, if any
global_force_attn_backend(original_value)

View File

@ -12,7 +12,8 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.tracing import is_otel_installed from vllm.tracing import is_otel_installed
from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu, from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
cuda_device_count_stateless, get_cpu_memory, is_cpu,
is_hip, is_neuron, is_openvino, is_tpu, is_xpu, is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
print_warning_once) print_warning_once)
@ -87,6 +88,9 @@ class ModelConfig:
enforce_eager: Whether to enforce eager execution. If True, we will enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode. disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid. If False, we will use CUDA graph and eager execution in hybrid.
If None, the user did not specify, so default to False -
except for encoder/decoder models, which currently require
eager mode.
max_context_len_to_capture: Maximum context len covered by CUDA graphs. max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead). to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
@ -121,7 +125,7 @@ class ModelConfig:
max_model_len: Optional[int] = None, max_model_len: Optional[int] = None,
quantization: Optional[str] = None, quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None, quantization_param_path: Optional[str] = None,
enforce_eager: bool = False, enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None, max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None, max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20, max_logprobs: int = 20,
@ -160,6 +164,34 @@ class ModelConfig:
self.hf_text_config = get_hf_text_config(self.hf_config) self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
# Choose a default enforce_eager value if the user did not specify
# a value (enforce_eager is None)
if getattr(self.hf_config, 'is_encoder_decoder', False):
if self.enforce_eager is None:
# *Only for encoder/decoder models* and
# *only if enforce_eager is unset*, override
# to enforce_eager=True
#
# Add a logger message since it is *somewhat* non-intuitive that
# enforce_eager is True when the user has not specified its
# value.
logger.info("Forcing enforce_eager == True because "
"enforce_eager setting was unspecified and "
"CUDAGraph is not supported with encoder/ "
"decoder models.")
self.enforce_eager = True
if not self.enforce_eager:
# Eager mode explicitly disabled by user for an encoder/
# decoder model; however CUDAGRAPH + encoder/decoder is
# not currently supported
raise ValueError(STR_NOT_IMPL_ENC_DEC_CUDAGRAPH)
elif self.enforce_eager is None:
# *Only for decoder-only models*, enforce_eager
# defaults to False if unset. This is intuitive
# so no logging message needed.
self.enforce_eager = False
if (not self.disable_sliding_window if (not self.disable_sliding_window
and self.hf_text_config.model_type == "gemma2" and self.hf_text_config.model_type == "gemma2"
and self.hf_text_config.sliding_window is not None): and self.hf_text_config.sliding_window is not None):

View File

@ -1,15 +1,7 @@
"""Block manager utils.""" """Block manager utils."""
from vllm.sequence import SequenceGroup from vllm.sequence import SequenceGroup
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
# Exception strings for non-implemented block manager enc/dec scenarios STR_NOT_IMPL_ENC_DEC_SWA)
STR_NOT_IMPL_ENC_DEC_SWA = \
"Sliding window attention for encoder/decoder models " + \
"is not currently supported."
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
"Prefix caching for encoder/decoder models " + \
"is not currently supported."
def _get_block_mgr_sliding_window_attr(block_mgr): def _get_block_mgr_sliding_window_attr(block_mgr):

View File

@ -392,6 +392,19 @@ class Scheduler:
seq.status = SequenceStatus.FINISHED_ABORTED seq.status = SequenceStatus.FINISHED_ABORTED
self.free_seq(seq) self.free_seq(seq)
self._free_seq_group_cross_attn_blocks(aborted_group)
def _free_seq_group_cross_attn_blocks(
self,
seq_group: SequenceGroup,
) -> None:
"""
Free a sequence group from a cross-attention block table.
Has no effect on decoder-only models.
"""
if seq_group.is_encoder_decoder():
self.block_manager.free_cross(seq_group)
def has_unfinished_seqs(self) -> bool: def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len( return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0 self.swapped) != 0
@ -963,6 +976,17 @@ class Scheduler:
# seq_id -> physical block numbers # seq_id -> physical block numbers
block_tables: Dict[int, List[int]] = {} block_tables: Dict[int, List[int]] = {}
if seq_group.is_encoder_decoder():
# Encoder associated with SequenceGroup
encoder_seq_data = seq_group.get_encoder_seq().data
# Block table for cross-attention
# Also managed at SequenceGroup level
cross_block_table = self.block_manager.get_cross_block_table(
seq_group)
else:
encoder_seq_data = None
cross_block_table = None
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id seq_id = seq.seq_id
seq_data[seq_id] = seq.data seq_data[seq_id] = seq.data
@ -1001,6 +1025,8 @@ class Scheduler:
token_chunk_size=token_chunk_size, token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request, lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums, computed_block_nums=common_computed_block_nums,
encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table,
# `multi_modal_data` will only be present for the 1st comm # `multi_modal_data` will only be present for the 1st comm
# between engine and worker. # between engine and worker.
# the subsequent comms can still use delta, but # the subsequent comms can still use delta, but
@ -1032,6 +1058,8 @@ class Scheduler:
remaining: Deque[SequenceGroup] = deque() remaining: Deque[SequenceGroup] = deque()
for seq_group in self.running: for seq_group in self.running:
if seq_group.is_finished(): if seq_group.is_finished():
# Free cross-attention block table, if it exists
self._free_seq_group_cross_attn_blocks(seq_group)
# Add the finished requests to the finished requests list. # Add the finished requests to the finished requests list.
# This list will be used to update the Mamba cache in the # This list will be used to update the Mamba cache in the
# next step. # next step.

View File

@ -69,7 +69,7 @@ class EngineArgs:
rope_theta: Optional[float] = None rope_theta: Optional[float] = None
tokenizer_revision: Optional[str] = None tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None quantization: Optional[str] = None
enforce_eager: bool = False enforce_eager: Optional[bool] = None
max_context_len_to_capture: Optional[int] = None max_context_len_to_capture: Optional[int] = None
max_seq_len_to_capture: int = 8192 max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False disable_custom_all_reduce: bool = False

View File

@ -3,7 +3,7 @@ from contextlib import contextmanager
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
Mapping, Optional) Mapping, Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union from typing import Set, Tuple, Type, TypeVar, Union
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
@ -22,7 +22,8 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs from vllm.inputs import (INPUT_REGISTRY, LLMInputs, PromptInputs,
get_prompt_type)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
@ -42,7 +43,8 @@ from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs) AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import Counter from vllm.utils import (Counter, is_embedding_model_config,
is_encoder_decoder_model_config)
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
@ -502,8 +504,19 @@ class LLMEngine:
self.prompt_adapter_config.verify_with_model_config( self.prompt_adapter_config.verify_with_model_config(
self.model_config) self.model_config)
def _get_eos_token_id( def _get_bos_token_id(self,
self, lora_request: Optional[LoRARequest]) -> Optional[int]: lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for BOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
def _get_eos_token_id(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
if self.tokenizer is None: if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer " logger.warning("Using None for EOS token id because tokenizer "
"is not initialized") "is not initialized")
@ -511,6 +524,32 @@ class LLMEngine:
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
def _get_decoder_start_token_id(self, ) -> Optional[int]:
'''
Obtain the decoder start token id employed by an encoder/decoder
model. Returns None for non-encoder/decoder models or if the
model config is unavailable.
'''
if not self.is_encoder_decoder_model():
logger.warning("Using None for decoder start token id because "
"this is not an encoder/decoder model.")
return None
if (self.model_config is None or self.model_config.hf_config is None):
logger.warning("Using None for decoder start token id because "
"model config is not available.")
return None
dec_start_token_id = getattr(self.model_config.hf_config,
'decoder_start_token_id', None)
if dec_start_token_id is None:
logger.warning("Falling back on <BOS> for decoder start token id "
"because decoder start token id is not available.")
dec_start_token_id = self._get_bos_token_id()
return dec_start_token_id
def _add_processed_request( def _add_processed_request(
self, self,
request_id: str, request_id: str,
@ -529,6 +568,16 @@ class LLMEngine:
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request) lora_request, prompt_adapter_request)
encoder_seq = None
if 'encoder_prompt_token_ids' in processed_inputs:
encoder_seq = Sequence(seq_id,
processed_inputs,
block_size,
eos_token_id,
lora_request,
prompt_adapter_request,
from_decoder_prompt=False)
# Create a SequenceGroup based on SamplingParams or PoolingParams # Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams): if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling( seq_group = self._create_sequence_group_with_sampling(
@ -538,7 +587,8 @@ class LLMEngine:
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq)
elif isinstance(params, PoolingParams): elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling( seq_group = self._create_sequence_group_with_pooling(
request_id, request_id,
@ -546,7 +596,8 @@ class LLMEngine:
params, params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq)
else: else:
raise ValueError( raise ValueError(
"Either SamplingParams or PoolingParams must be provided.") "Either SamplingParams or PoolingParams must be provided.")
@ -562,6 +613,336 @@ class LLMEngine:
def stop_remote_worker_execution_loop(self) -> None: def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop() self.model_executor.stop_remote_worker_execution_loop()
_LLMInputComponentsType = Tuple[str, List[int], ]
def _prepare_decoder_input_ids_for_generation(
self,
decoder_input_ids: Optional[List[int]] = None,
) -> List[int]:
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
Based on
https://github.com/huggingface/transformers/blob/
4037a2b5b1278736e566aec12e169100275545ea/
src/transformers/generation/utils.py
specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
Arguments:
* decoder_input_ids: input token ids to preprocess
Returns:
* Processed token list
"""
decoder_start_token_id: Optional[int] = (
self._get_decoder_start_token_id())
assert decoder_start_token_id is not None
if decoder_input_ids is None:
# no decoder prompt input ->
# use decoder_start_token_id as decoder_input_ids
(decoder_input_ids) = self._get_default_enc_dec_decoder_prompt()
if (len(decoder_input_ids) == 0
or decoder_input_ids[0] != decoder_start_token_id):
decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
return decoder_input_ids
def _tokenize_prompt(
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[str] = None,
) -> List[int]:
'''
Wrapper around application of the model's
tokenizer.
Arguments:
* prompt
* request_id
* lora_request
Returns:
* prompt token ids
'''
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
prompt_token_ids = tokenizer.encode(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return prompt_token_ids
def _extract_single_prompt_for_enc_dec_input(
self,
inputs: Optional[PromptInputs],
request_id: Optional[str] = None,
ptype: Optional[str] = None,
is_encoder_prompt: bool = False,
) -> Tuple[Optional[str], List[int]]:
'''
Only for encoder/decoder models:
Extract prompt & prompt_token_ids from any single
encoder or decoder input prompt. For encoder input prompts
in particular, also extract multi-modal data.
This function handles the following scenarios:
1. The user supplied a singleton encoder prompt
& the prompt/prompt-token-ids must be extracted.
2. The user supplied an explicit encoder/decoder
prompt & the prompt/prompt-token-ids must be
extracted from either the encoder and decoder prompts.
For decoder prompts in particular (scenario 2), special
processing is applied to the returned decoder token ids.
Arguments:
* request_id
* ptype: str representation of the input prompt type.
If `ptype` is `None`, assume that the prompt
type is unknown and must be inferred. This is the
case for ExplicitEncoderDecoder sub-prompts.
* inputs: single encoder or decoder input prompt
* is_encoder_prompt: True if encoder input prompt.
If False, decoder prompt tokens
are preprocessed.
Returns:
* prompt
* prompt_token_ids
'''
prompt_token_ids = None
ptype = (get_prompt_type(inputs) if ptype is None else ptype)
if inputs is None:
prompt = None
elif ptype == 'str':
prompt = inputs
prompt_token_ids = self._tokenize_prompt(
prompt,
request_id=request_id,
)
elif ptype == 'TokensPrompt':
prompt = None
prompt_token_ids = inputs['prompt_token_ids']
else:
prompt = inputs['prompt']
prompt_token_ids = self._tokenize_prompt(
prompt,
request_id=request_id,
)
if not is_encoder_prompt:
# Apply special pre-processing to
# decoder prompts
prompt_token_ids = (self._prepare_decoder_input_ids_for_generation(
prompt_token_ids, ))
assert prompt_token_ids is not None
return (
prompt,
prompt_token_ids,
)
def _get_default_enc_dec_decoder_prompt(self, ) -> List[int]:
'''
Specifically for encoder/decoder models:
generate a default decoder prompt for when
the user specifies only the encoder prompt.
Encoder/decoder models utilize the decoder
prompt in different ways; as new models are
added, it is intended that this function
will be extended to produce differing
default decoder prompts, depending on the
model variety.
Absent a special case, the default behavior
of this method is to mirror the behavior of
the HuggingFace (HF) GenerationMixin for a None
decoder prompt, which is to employ a logit processor
setting to force the first decoded token to be <BOS>.
Here, this behavior is approximated by having the
"default" decoder prompt be <BOS>.
However, it is possible that in the future
other models may have different or more
complex logic for the default decoder prompt.
This motivates having a special helper method
for default decoder prompts.
Returns:
* prompt_token_ids
'''
bos_token_id = self._get_bos_token_id()
assert bos_token_id is not None
prompt_token_ids: List[int] = [bos_token_id]
return prompt_token_ids
def _process_encoder_decoder_prompt(
self,
inputs: PromptInputs,
request_id: Optional[str] = None,
) -> LLMInputs:
'''
For encoder/decoder models only:
Process an input prompt
into an `LLMInputs` instance.
There are two types of input prompts:
singleton prompts which carry only the
encoder prompt, and explicit encoder/decoder
prompts which carry both the encoder and the
decoder prompts as member variables.
This function handles the following scenarios:
* Singleton encoder prompt: extract encoder prompt
token ids & infer default decoder prompt token ids
* Explicit encoder/decoder prompt: extract encoder
and decoder prompt token ids
Note that for Explicit encoder/decoder prompts,
each sub-prompt (encoder or decoder prompt) can
have any possible singleton type; thus this
method relies on helper functions to obtain
token ids for the sub-prompts.
Arguments:
* inputs: an input prompt
* request_id
Returns:
* `LLMInputs` instance
'''
ptype = get_prompt_type(inputs)
# Obtain encoder and decoder prompt tokens. Note
# that, no matter what, the decoder
# prompt type is unknown.
if ptype == "ExplicitEncoderDecoder":
# If input is explicit encoder/decoder prompt,
# then it remains to be determined what type
# of encoder prompt we have
extracted_encoder_prompt = inputs.get('encoder_prompt')
encoder_ptype = None
# Extract decoder prompt from explicit
# encoder/decoder prompt
extracted_decoder_prompt = inputs.get('decoder_prompt')
else:
# If input is singleton encoder prompt, then
# we know the encoder prompt type
extracted_encoder_prompt = inputs
encoder_ptype = ptype
# Decoder prompt is always unknown if
# encoder/decoder prompt is not explicit
extracted_decoder_prompt = None
# Invoke helper function to obtain encoder
# prompt and prompt token ids, either from
# singleton encoder prompt or from the
# encoder sub-prompt of an explicit
# encoder/decode scenario 2), special
# processing is applied to the returned decoder token ids
(
encoder_prompt,
encoder_prompt_token_ids,
) = self._extract_single_prompt_for_enc_dec_input(
extracted_encoder_prompt,
request_id=request_id,
ptype=encoder_ptype,
is_encoder_prompt=True,
)
# Invoke helper method to obtain
# decoder prompt and prompt token ids.
#
# The helper method will detect the decoder
# prompt type.
#
# Helper method will also apply special
# preprocessing unique to decoder prompts.
(
decoder_prompt,
decoder_prompt_token_ids,
) = self._extract_single_prompt_for_enc_dec_input(
extracted_decoder_prompt,
request_id=request_id,
ptype=None,
is_encoder_prompt=False,
)
return LLMInputs(
prompt_token_ids=decoder_prompt_token_ids,
prompt=decoder_prompt,
encoder_prompt_token_ids=encoder_prompt_token_ids,
encoder_prompt=encoder_prompt,
)
def _process_decoder_only_prompt(
self,
inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None,
request_id: Optional[str] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
'''
For decoder-only models:
Process an input prompt
into an `LLMInputs` instance.
Arguments:
* inputs: input prompt
* lora_request
* request_id
* prompt_adapter_request
Returns:
* `LLMInputs` instance
'''
if isinstance(inputs, str):
inputs = {"prompt": inputs}
prompt = inputs.get("prompt")
if "prompt_token_ids" not in inputs:
prompt_token_ids = self._tokenize_prompt(
prompt,
request_id=request_id,
lora_request=lora_request,
)
else:
prompt_token_ids = inputs["prompt_token_ids"]
if prompt_adapter_request:
prompt_token_ids = (
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
+ prompt_token_ids)
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=prompt,
multi_modal_data=inputs.get("multi_modal_data"))
def process_model_inputs( def process_model_inputs(
self, self,
request_id: str, request_id: str,
@ -569,29 +950,25 @@ class LLMEngine:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs: ) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
if "prompt_token_ids" not in inputs: if self.is_encoder_decoder_model():
tokenizer = self.get_tokenizer_group("prompts must be None if " # Encoder-decoder model requires special mapping of
"skip_tokenizer_init is True") # input prompts to encoder & decoder
prompt_token_ids = tokenizer.encode(request_id=request_id, model_inputs = self._process_encoder_decoder_prompt(
prompt=inputs["prompt"], inputs,
lora_request=lora_request) request_id=request_id,
)
else: else:
prompt_token_ids = inputs["prompt_token_ids"] # Decoder-only operation
model_inputs = self._process_decoder_only_prompt(
inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
if prompt_adapter_request: return self.input_processor(model_inputs)
prompt_token_ids = \
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\
+ prompt_token_ids
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
return self.input_processor(llm_inputs)
def add_request( def add_request(
self, self,
@ -676,6 +1053,7 @@ class LLMEngine:
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
encoder_seq: Optional[Sequence] = None,
) -> SequenceGroup: ) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams.""" """Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs max_logprobs = self.get_model_config().max_logprobs
@ -701,7 +1079,8 @@ class LLMEngine:
sampling_params=sampling_params, sampling_params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq)
return seq_group return seq_group
@ -713,6 +1092,7 @@ class LLMEngine:
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
encoder_seq: Optional[Sequence] = None,
) -> SequenceGroup: ) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams.""" """Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler # Defensive copy of PoolingParams, which are used by the pooler
@ -724,7 +1104,8 @@ class LLMEngine:
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
pooling_params=pooling_params, pooling_params=pooling_params,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq)
return seq_group return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
@ -1214,3 +1595,9 @@ class LLMEngine:
seq_span.set_attribute( seq_span.set_attribute(
SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft) SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time) seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
def is_encoder_decoder_model(self):
return is_encoder_decoder_model_config(self.model_config)
def is_embedding_model(self):
return is_embedding_model_config(self.model_config)

View File

@ -121,12 +121,21 @@ class LLM:
gpu_memory_utilization: float = 0.9, gpu_memory_utilization: float = 0.9,
swap_space: int = 4, swap_space: int = 4,
cpu_offload_gb: float = 0, cpu_offload_gb: float = 0,
enforce_eager: bool = False, enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None, max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192, max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False, disable_custom_all_reduce: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
'''
LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False for decoder-only models and True
for encoder/decoder models, since encoder/decoder models
do not currently support CUDAGraph.
'''
if "disable_log_stats" not in kwargs: if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True kwargs["disable_log_stats"] = True
removed_vision_keys = ("image_token_id", "image_feature_size", removed_vision_keys = ("image_token_id", "image_feature_size",
@ -297,8 +306,8 @@ class LLM:
""" """
if self.llm_engine.model_config.embedding_mode: if self.llm_engine.model_config.embedding_mode:
raise ValueError( raise ValueError(
"LLM.generate() is only supported for generation models " "LLM.generate() is only supported for (conditional) generation "
"(XForCausalLM).") "models (XForCausalLM, XForConditionalGeneration).")
if prompt_token_ids is not None: if prompt_token_ids is not None:
inputs = self._convert_v1_inputs( inputs = self._convert_v1_inputs(
@ -631,3 +640,9 @@ class LLM:
# This is necessary because some requests may be finished earlier than # This is necessary because some requests may be finished earlier than
# its previous requests. # its previous requests.
return sorted(outputs, key=lambda x: int(x.request_id)) return sorted(outputs, key=lambda x: int(x.request_id))
def _is_encoder_decoder_model(self):
return self.llm_engine.is_encoder_decoder_model()
def _is_embedding_model(self):
return self.llm_engine.is_embedding_model()

View File

@ -1,5 +1,7 @@
from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs, from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, ParsedText,
TextPrompt, TokensPrompt, parse_and_batch_prompt) ParsedTokens, PromptInputs, SingletonPromptInputs,
TextPrompt, TokensPrompt, get_prompt_type,
is_valid_encoder_decoder_llm_inputs, parse_and_batch_prompt)
from .registry import InputContext, InputRegistry from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry() INPUT_REGISTRY = InputRegistry()
@ -12,7 +14,18 @@ See also:
""" """
__all__ = [ __all__ = [
"ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt", "ParsedText",
"TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY", "ParsedTokens",
"InputContext", "InputRegistry" "parse_and_batch_prompt",
"TextPrompt",
"TokensPrompt",
"PromptInputs",
"LLMInputs",
"INPUT_REGISTRY",
"InputContext",
"InputRegistry",
"get_prompt_type",
"is_valid_encoder_decoder_llm_inputs",
"ExplicitEncoderDecoderPrompt",
"SingletonPromptInputs",
] ]

View File

@ -92,15 +92,114 @@ class TokensPrompt(TypedDict):
""" """
PromptInputs = Union[str, TextPrompt, TokensPrompt] SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt]
""" """
The inputs to the LLM, which can take one of the following forms: Set of possible schemas for a single LLM input:
- A text prompt (:class:`str` or :class:`TextPrompt`) - A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`) - A tokenized prompt (:class:`TokensPrompt`)
Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. ExplicitEncoderDecoderPrompt
A prompt of type SingletonPromptInputs may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. ExplicitEncoderDecoderPrompt
""" """
class ExplicitEncoderDecoderPrompt(TypedDict):
"""Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a
decoder prompt.
The encoder and decoder prompts, respectively,
may formatted according to any of the
SingletonPromptInputs schemas, and are not
required to have the same schema.
Only the encoder prompt may have multi-modal data.
Note that an ExplicitEncoderDecoderPrompt may not
be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure may not themselves
must be SingletonPromptInputs instances.
"""
encoder_prompt: SingletonPromptInputs
decoder_prompt: SingletonPromptInputs
PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
- A single data structure containing both an encoder and a decoder prompt
(:class:`ExplicitEncoderDecoderPrompt`)
"""
def _has_required_keys(
d: dict,
required_keys: set,
) -> bool:
return required_keys.issubset(d.keys())
def get_prompt_type(prompt: Optional[PromptInputs]) -> Optional[str]:
"""
Get the type-name of the prompt argument instance, given that
isinstance() cannot apply to TypedDict subclasses directly.
If the prompt is None, return 'None' as the type name.
Arguments:
* prompt: LLM input prompt or None
Returns:
* String representation of prompt type
"""
if prompt is None:
return 'None'
required_keys_dict = {
'TextPrompt': {'prompt'},
'TokensPrompt': {'prompt_token_ids'},
'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'},
}
if isinstance(prompt, dict):
for (ptype, required_keys) in required_keys_dict.items():
# Ignore type checking in the conditional below because type
# checker does not understand that is_dict(prompt) narrows
# down the possible types
if _has_required_keys(
prompt, # type: ignore
required_keys):
return ptype
raise ValueError(f"Invalid prompt {prompt}, valid types are "
"required_keys_dict={required_keys_dict}")
if isinstance(prompt, str):
return "str"
raise ValueError(f"Invalid prompt {prompt}")
class LLMInputs(TypedDict): class LLMInputs(TypedDict):
""" """
The inputs in :class:`~vllm.LLMEngine` before they are The inputs in :class:`~vllm.LLMEngine` before they are
@ -114,8 +213,29 @@ class LLMInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available. The original prompt text corresponding to the token IDs, if available.
""" """
encoder_prompt_token_ids: NotRequired[List[int]]
"""The token IDs of the encoder prompt."""
encoder_prompt: NotRequired[Optional[str]]
"""
The original encoder prompt text corresponding to the token IDs, if
available.
"""
multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
""" """
Optional multi-modal data to pass to the model, Optional multi-modal data to pass to the model,
if the model supports it. if the model supports it.
""" """
def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs) -> bool:
"""
Return True if the LLMInputs instance has the correct configuration
for encoder/decoder.
"""
# True if encoder prompt token ids field exists &
# is not None
return ('encoder_prompt_token_ids' in inputs
and inputs['encoder_prompt_token_ids'] is not None)

View File

@ -83,7 +83,16 @@ _EMBEDDING_MODELS = {
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
} }
_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS} _CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}
_MODELS = {
**_GENERATION_MODELS,
**_EMBEDDING_MODELS,
**_CONDITIONAL_GENERATION_MODELS
}
# Architecture -> type. # Architecture -> type.
# out of tree models # out of tree models

View File

@ -0,0 +1,996 @@
# Derived from BART implementation posted on HuggingFace; license below:
#
# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model."""
import math
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import BartConfig
from transformers.utils import logging
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
logger = logging.get_logger(__name__)
def get_bsz_seq_len(input_ids):
shp = input_ids.shape
ndim = len(shp)
if ndim == 1:
return 1, input_ids.numel()
else:
return shp[:2]
class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int):
# Bart is set up so that if padding_idx is
# specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately.
# Other models don't have this hack
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)
def forward(
self,
positions: torch.Tensor,
attn_type: AttentionType,
) -> torch.Tensor:
"""`input_ids' shape is expected to be [bsz x seqlen]."""
assert attn_type != AttentionType.ENCODER_DECODER
return super().forward(positions + self.offset)
class BartScaledWordEmbedding(VocabParallelEmbedding):
"""
This module overrides VocabParallelEmbedding's
forward by multiplying with embeddings scale.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
embed_scale: float = 1.0):
super().__init__(num_embeddings, embedding_dim)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
return super().forward(input_ids) * self.embed_scale
class BartParallelLMHead(ParallelLMHead):
"""
This module overrides ParallelLMHead's
forward by dividing by embeddings scale,
yielding effectively the inverse of
BartScaledWordEmbedding
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
embed_scale: float = 1.0):
super().__init__(num_embeddings, embedding_dim)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
return super().forward(input_ids) / self.embed_scale
class BartEncoderAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
bias: bool = True,
config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.embed_dim = embed_dim
self.total_num_heads = num_heads
self.total_num_kv_heads = self.total_num_heads
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads}).")
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor:
"""Input shape: Batch x Time x Channel"""
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.ENCODER)
output, _ = self.out_proj(attn_output)
return output
class BartDecoderSelfAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
bias: bool = True,
config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.embed_dim = embed_dim
self.total_num_heads = num_heads
self.total_num_kv_heads = self.total_num_heads
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads}).")
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor:
"""Input shape: Batch x Time x Channel"""
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.DECODER)
output, _ = self.out_proj(attn_output)
return output
class BartCrossAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
bias: bool = True,
config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.embed_dim = embed_dim
self.total_num_heads = num_heads
self.total_num_kv_heads = self.total_num_heads
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads}).")
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
decoder_hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Input shape: Batch x Time x Channel"""
# (afeldman-nm 2024/07/22) TODO:
# Need a more efficient solution for q/k/v
qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
if encoder_hidden_states is None:
k = None
v = None
else:
qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
_, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.ENCODER_DECODER)
output, _ = self.out_proj(attn_output)
return output
class BartEncoderLayer(nn.Module):
def __init__(
self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = BartEncoderAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
config=config,
cache_config=cache_config,
quant_config=quant_config)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.activation_fn = get_act_fn(config.activation_function,
quant_config)
ffn_hidden_size = self.embed_dim
ffn_intermediate_size = config.encoder_ffn_dim
ffn_has_bias = True
self.fc1 = ColumnParallelLinear(
ffn_hidden_size,
ffn_intermediate_size,
bias=ffn_has_bias,
quant_config=quant_config,
)
self.act = get_act_fn("gelu", quant_config, ffn_intermediate_size)
self.fc2 = RowParallelLinear(
ffn_intermediate_size,
ffn_hidden_size,
bias=ffn_has_bias,
quant_config=quant_config,
)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor:
r"""
Args:
hidden_states
torch.Tensor of *encoder* input embeddings.
kv_cache:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Encoder layer output torch.Tensor
"""
residual = hidden_states
hidden_states = self.self_attn(hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
fc1_out, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(fc1_out)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any()
or torch.isnan(hidden_states).any()):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states,
min=-clamp_value,
max=clamp_value)
return hidden_states
class BartDecoderLayer(nn.Module):
def __init__(
self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = BartDecoderSelfAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
config=config,
cache_config=cache_config,
quant_config=quant_config)
self.activation_fn = get_act_fn(config.activation_function,
quant_config)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
'''
afeldman-nm: personally I would call this "cross-attention",
however I left the name as "encoder_attn" to maintain consistency
with the name of the pretrained weights.
'''
self.encoder_attn = BartCrossAttention(
self.embed_dim,
config.decoder_attention_heads,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
ffn_hidden_size = self.embed_dim
ffn_intermediate_size = config.encoder_ffn_dim
ffn_has_bias = True
self.fc1 = ColumnParallelLinear(
ffn_hidden_size,
ffn_intermediate_size,
bias=ffn_has_bias,
quant_config=quant_config,
)
self.fc2 = RowParallelLinear(
ffn_intermediate_size,
ffn_hidden_size,
bias=ffn_has_bias,
quant_config=quant_config,
)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
decoder_hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Args:
decoder_hidden_states
torch.Tensor of *decoder* input embeddings.
kv_cache:
KV cache tensor
attn_metadata:
vLLM Attention metadata structure
encoder_hidden_states
torch.Tensor of *encoder* input embeddings.
Returns:
Decoder layer output torch.Tensor
"""
residual = decoder_hidden_states
# Self Attention
hidden_states = self.self_attn(hidden_states=decoder_hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Cross-Attention Block
residual = hidden_states
hidden_states = self.encoder_attn(
decoder_hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
encoder_hidden_states=encoder_hidden_states,
)
hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# Fully Connected
residual = hidden_states
fc1_out, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(fc1_out)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
class BartEncoder(nn.Module):
"""
Transformer encoder consisting of *config.encoder_layers*
self attention layers. Each layer is a [`BartEncoderLayer`].
Args:
config: BartConfig
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None):
super().__init__()
self.cache_config = cache_config
self.quant_config = quant_config
self.lora_config = lora_config
embed_dim = config.d_model
self.max_source_positions = config.max_position_embeddings
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
embed_dim,
embed_scale=embed_scale)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings,
embed_dim,
)
self.layers = nn.ModuleList(
[BartEncoderLayer(config,cache_config,quant_config) \
for _ in range(config.encoder_layers)])
self.layernorm_embedding = nn.LayerNorm(embed_dim)
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata) -> torch.Tensor:
r"""
Args:
input_ids
Indices of *encoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
positions
Positions of *encoder* input sequence tokens.
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Decoder output torch.Tensor
"""
# retrieve input_ids and inputs_embeds
input_ids = input_ids.view(-1, input_ids.shape[-1])
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(
positions,
AttentionType.ENCODER,
)
embed_pos = embed_pos.to(inputs_embeds.device)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
for idx, encoder_layer in enumerate(self.layers):
hidden_states = encoder_layer(
hidden_states=hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
)
return hidden_states
class BartDecoder(nn.Module):
"""
Transformer decoder consisting of *config.decoder_layers* layers.
Each layer is a [`BartDecoderLayer`]
Args:
config: BartConfig
embed_tokens (nn.Embedding): output embedding
"""
def __init__(
self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None,
):
super().__init__()
self.cache_config = cache_config
self.quant_config = quant_config
self.lora_config = lora_config
self.max_target_positions = config.max_position_embeddings
embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
config.d_model,
embed_scale=embed_scale)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
)
self.layers = nn.ModuleList(
[BartDecoderLayer(config,cache_config,quant_config) \
for _ in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
def forward(self, decoder_input_ids: torch.Tensor,
decoder_positions: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata) -> torch.Tensor:
r"""
Args:
decoder_input_ids
Indices of *decoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
decoder_positions
Positions of *decoder* input sequence tokens.
encoder_hidden_states:
Tensor of encoder output embeddings
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Decoder output torch.Tensor
"""
inputs_embeds = self.embed_tokens(decoder_input_ids)
# embed positions
embed_pos = self.embed_positions(
decoder_positions,
AttentionType.DECODER,
)
embed_pos = embed_pos.to(inputs_embeds.device)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
# decoder layers
for idx, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
decoder_hidden_states=hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
encoder_hidden_states=encoder_hidden_states,
)
return hidden_states
class BartModel(nn.Module):
_tied_weights_keys = [
"encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
]
def __init__(self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.encoder = BartEncoder(config,
cache_config,
quant_config=quant_config)
self.decoder = BartDecoder(config,
cache_config,
quant_config=quant_config)
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata) -> torch.Tensor:
r"""
Args:
input_ids
Indices of *decoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
positions
Positions of *decoder* input sequence tokens.
encoder_input_ids
Indices of *encoder* input sequence tokens in the vocabulary.
encoder_positions:
Positions of *encoder* input sequence tokens.
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Model output torch.Tensor
"""
encoder_hidden_states = None
if encoder_input_ids.numel() > 0:
# Run encoder attention if a non-zero number of encoder tokens
# are provided as input
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
positions=encoder_positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata)
# decoder outputs consists of
# (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
decoder_input_ids=input_ids,
decoder_positions=positions,
encoder_hidden_states=encoder_hidden_states,
kv_caches=kv_caches,
attn_metadata=attn_metadata)
return decoder_outputs
class BartForConditionalGeneration(nn.Module):
base_model_prefix = "model"
def __init__(self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None):
super().__init__()
self.config = config
self.model = BartModel(config,
cache_config,
quant_config,
lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0
self.lm_head = BartParallelLMHead(config.vocab_size,
config.d_model,
embed_scale=embed_scale)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
r"""
Args:
input_ids
torch.Tensor of *decoder* input token ids.
positions
torch.Tensor of *decoder* position indices.
encoder_input_ids
torch.Tensor of *encoder* input token ids.
encoder_positions
torch.Tensor of *encoder* position indices
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Output torch.Tensor
"""
return self.model(input_ids, positions, encoder_input_ids,
encoder_positions, kv_caches, attn_metadata)
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
stacked_params_mapping = {
"q_proj": {
"param_name": "qkv_proj",
"shard_id": "q",
},
"k_proj": {
"param_name": "qkv_proj",
"shard_id": "k",
},
"v_proj": {
"param_name": "qkv_proj",
"shard_id": "v",
},
}
params_mapping = {
"beta": "bias",
"gamma": "weight",
"LayerNorm": "layernorm",
}
def _rename_key(self, key: str):
prefix = f"{self.base_model_prefix}."
key = key[len(prefix):] if key.startswith(prefix) else key
for src, dst in self.params_mapping.items():
key = key.replace(src, dst)
return key
def _rename_stacked_param(
self,
name: str,
) -> Tuple[str, Optional[str]]:
for key, mapping in self.stacked_params_mapping.items():
if key in name:
name = name.replace(key, mapping["param_name"])
return name, mapping["shard_id"]
return name, None
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
model_params_dict = dict(self.model.named_parameters())
top_params_dict = dict(self.named_parameters())
weights_tuple_list = list(weights)
shared_embedding_weight = None
shared_embedding_shard_id = None
for name, loaded_weight in weights_tuple_list:
name = self._rename_key(name)
name, shard_id = self._rename_stacked_param(name)
if ('shared.weight' in name
or 'encoder.embed_tokens.weight' in name
or 'decoder.embed_tokens.weight' in name
or 'lm_head.weight' in name):
assert shared_embedding_weight is None, (
"Conflicting embedding weights.")
shared_embedding_weight = loaded_weight
shared_embedding_shard_id = shard_id
else:
# Skip the specific downstream task weight.
if name.startswith('cls.'):
continue
# use Pooler instead.
if name.startswith('pooler.'):
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in model_params_dict:
continue
param = model_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if shard_id:
weight_loader(param, loaded_weight, shard_id)
else:
weight_loader(param, loaded_weight)
# Assign shared weight values
encoder_in_param = model_params_dict['encoder.embed_tokens.weight']
encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader",
default_weight_loader)
decoder_in_param = model_params_dict['decoder.embed_tokens.weight']
decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader",
default_weight_loader)
lm_head_in_param = top_params_dict['lm_head.weight']
lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader",
default_weight_loader)
assert shared_embedding_weight is not None
if shared_embedding_shard_id:
encoder_in_weight_loader(encoder_in_param, shared_embedding_weight,
shared_embedding_shard_id)
decoder_in_weight_loader(decoder_in_param, shared_embedding_weight,
shared_embedding_shard_id)
lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight,
shared_embedding_shard_id)
else:
encoder_in_weight_loader(encoder_in_param, shared_embedding_weight)
decoder_in_weight_loader(decoder_in_param, shared_embedding_weight)
lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight)

View File

@ -70,12 +70,20 @@ class RequestOutput:
Args: Args:
request_id: The unique ID of the request. request_id: The unique ID of the request.
prompt: The prompt string of the request. prompt: The prompt string of the request.
For encoder/decoder models, this is the
decoder input prompt.
prompt_token_ids: The token IDs of the prompt. prompt_token_ids: The token IDs of the prompt.
For encoder/decoder models, this is the
decoder input prompt token ids.
prompt_logprobs: The log probabilities to return per prompt token. prompt_logprobs: The log probabilities to return per prompt token.
outputs: The output sequences of the request. outputs: The output sequences of the request.
finished: Whether the whole request is finished. finished: Whether the whole request is finished.
metrics: Metrics associated with the request. metrics: Metrics associated with the request.
lora_request: The LoRA request that was used to generate the output. lora_request: The LoRA request that was used to generate the output.
encoder_prompt: The encoder prompt string of the request;
None if decoder-only
encoder_prompt_token_ids: The token IDs of the encoder prompt;
None if decoder-only
""" """
def __init__( def __init__(
@ -88,6 +96,8 @@ class RequestOutput:
finished: bool, finished: bool,
metrics: Optional[RequestMetrics] = None, metrics: Optional[RequestMetrics] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
encoder_prompt: Optional[str] = None,
encoder_prompt_token_ids: Optional[List[int]] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.prompt = prompt self.prompt = prompt
@ -97,6 +107,8 @@ class RequestOutput:
self.finished = finished self.finished = finished
self.metrics = metrics self.metrics = metrics
self.lora_request = lora_request self.lora_request = lora_request
self.encoder_prompt = encoder_prompt
self.encoder_prompt_token_ids = encoder_prompt_token_ids
@classmethod @classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
@ -137,6 +149,8 @@ class RequestOutput:
# Every sequence in the sequence group should have the same prompt. # Every sequence in the sequence group should have the same prompt.
prompt = seq_group.prompt prompt = seq_group.prompt
prompt_token_ids = seq_group.prompt_token_ids prompt_token_ids = seq_group.prompt_token_ids
encoder_prompt = seq_group.encoder_prompt
encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs prompt_logprobs = seq_group.prompt_logprobs
finished = seq_group.is_finished() finished = seq_group.is_finished()
finished_time = time.time() if finished else None finished_time = time.time() if finished else None
@ -148,12 +162,16 @@ class RequestOutput:
outputs, outputs,
finished, finished,
seq_group.metrics, seq_group.metrics,
lora_request=seq_group.lora_request) lora_request=seq_group.lora_request,
encoder_prompt=encoder_prompt,
encoder_prompt_token_ids=encoder_prompt_token_ids)
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, " return (f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, " f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, " f"prompt_token_ids={self.prompt_token_ids}, "
f"encoder_prompt={self.encoder_prompt!r}, "
f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
f"prompt_logprobs={self.prompt_logprobs}, " f"prompt_logprobs={self.prompt_logprobs}, "
f"outputs={self.outputs}, " f"outputs={self.outputs}, "
f"finished={self.finished}, " f"finished={self.finished}, "

View File

@ -7,10 +7,11 @@ from array import array
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
Union) Union, cast)
import torch import torch
from vllm.inputs import is_valid_encoder_decoder_llm_inputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
@ -244,13 +245,26 @@ class SequenceData:
class Sequence: class Sequence:
"""Stores the data, status, and block information of a sequence. """Stores the data, status, and block information of a sequence.
The sequence is constructed from the LLMInputs instance passed
in through the `inputs` constructor argument.
For encoder/decoder models, LLMInputs encapsulates both a
decoder and encoder prompt, creating an ambiguity about which
prompt to construct the sequence from. The `from_decoder_prompt`
constructor argument signals whether to construct the Sequence
from the LLMInputs decoder prompt, or encoder prompt.
Args: Args:
seq_id: The ID of the sequence. seq_id: The ID of the sequence.
inputs: The inputs of the sequence. inputs: The inputs of the sequence.
block_size: The block size of the sequence. Should be the same as the block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine. block size used by the block manager and cache engine.
eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
lora_request: LoRA request. lora_request: LoRA request.
prompt_adapter_request: Prompt Adapter request. prompt_adapter_request: Prompt Adapter request.
from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
(True) or encoder prompt (False.) Must be True
for decoder-only model.
""" """
@ -261,7 +275,8 @@ class Sequence:
block_size: int, block_size: int,
eos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None,
from_decoder_prompt: bool = True,
) -> None: ) -> None:
self.seq_id = seq_id self.seq_id = seq_id
self.inputs = inputs self.inputs = inputs
@ -269,6 +284,36 @@ class Sequence:
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.lora_request = lora_request self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request self.prompt_adapter_request = prompt_adapter_request
self.from_decoder_prompt = from_decoder_prompt
self._prompt: Optional[str] = None
self._prompt_token_ids: Optional[List[int]] = None
# For decoder-only models, a Sequence is constructed
# from an LLMInputs instance (the `inputs` arg.)
#
# For encoder/decoder models the same `inputs`
# instance could be utilized to construct either an
# encoder sequence or a decoder sequence, because
# `LLMInputs` has both decoder- and encoder-oriented
# member variables (i.e. it encapsulates both an encoder
# and a decoder prompt.) The decision of which type of sequence
# to generate is determined by the `from_decoder_prompt` argument.
#
# When constructing a encoder sequence
# (`from_decoder_prompt` False) it matters that
# the `LLMInputs` instance stored in `inputs` is valid
# in the sense that its encoder-related member variables are
# populated; below, an exception is raised if this is
# not the case.
#
# When constructing a decoder sequence (`from_decoder_prompt` True)
# it does not matter whether `inputs` has its encoder-related
# member variables populated.
if not (from_decoder_prompt
or is_valid_encoder_decoder_llm_inputs(inputs)):
raise ValueError("Cannot extract encoder input prompt from "
f"invalid input {inputs}; did you forget the "
"encoder input prompt fields?")
self.data = SequenceData(self.prompt_token_ids) self.data = SequenceData(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = [] self.output_logprobs: SampleLogprobs = []
@ -289,11 +334,35 @@ class Sequence:
@property @property
def prompt(self) -> Optional[str]: def prompt(self) -> Optional[str]:
return self.inputs.get("prompt") if self._prompt is not None:
# Reuse precomputed prompt string
return self._prompt
# Select decoder or encoder input prompt str,
# as appropriate
prompt_key: str = ("prompt"
if self.from_decoder_prompt else "encoder_prompt")
# Cache prompt
self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
return self._prompt
@property @property
def prompt_token_ids(self) -> List[int]: def prompt_token_ids(self) -> List[int]:
return self.inputs["prompt_token_ids"] if self._prompt_token_ids is not None:
# Reuse precomputed prompt token ids
return self._prompt_token_ids
# Select decoder or encoder input prompt
# token ids, as appropriate
prompt_token_ids_key: str = ("prompt_token_ids"
if self.from_decoder_prompt else
"encoder_prompt_token_ids")
# Cache computed prompt token ids
self._prompt_token_ids = cast(List[int],
self.inputs.get(prompt_token_ids_key))
return self._prompt_token_ids
@property @property
def multi_modal_data(self) -> "MultiModalDataDict": def multi_modal_data(self) -> "MultiModalDataDict":
@ -472,6 +541,22 @@ class SequenceGroup:
# We use the prompt of an arbitrary sequence. # We use the prompt of an arbitrary sequence.
return self.seqs[0].prompt_token_ids return self.seqs[0].prompt_token_ids
@property
def encoder_prompt(self) -> Optional[str]:
# There are either 0 or 1 encoder sequences
# If one is present, its prompt is distinct
# from the decoder's.
return (self.encoder_seq.prompt
if self.encoder_seq is not None else None)
@property
def encoder_prompt_token_ids(self) -> Optional[List[int]]:
# There are either 0 or 1 encoder sequences
# If one is present, its prompt token ids are
# distinct from the decoder's.
return (self.encoder_seq.prompt_token_ids
if self.encoder_seq is not None else None)
@property @property
def multi_modal_data(self) -> "MultiModalDataDict": def multi_modal_data(self) -> "MultiModalDataDict":
# All sequences in the group should have the same multi-modal data. # All sequences in the group should have the same multi-modal data.

View File

@ -27,10 +27,93 @@ from typing_extensions import ParamSpec
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs,
SingletonPromptInputs)
from vllm.logger import enable_trace_function_call, init_logger from vllm.logger import enable_trace_function_call, init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
# Exception strings for non-implemented encoder/decoder scenarios
STR_NOT_IMPL_ENC_DEC_SWA = \
"Sliding window attention for encoder/decoder models " + \
"is not currently supported."
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
"Prefix caching for encoder/decoder models " + \
"is not currently supported."
STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
"Chunked prefill for encoder/decoder models " + \
"is not currently supported."
STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
"Models with logits_soft_cap "
"require FlashInfer backend, which is "
"currently not supported for encoder/decoder "
"models.")
STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is currently not currently "
"supported with encoder/decoder "
"models.")
STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not "
"currently supported with "
"encoder/decoder models.")
STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently "
"supported with encoder/decoder "
"models.")
STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
"currently supported with encoder/"
"decoder models.")
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = ("CUDAGraph is not "
"currently supported with encoder/"
"decoder models.")
STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
"currently supported with encoder/"
"decoder models.")
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
"currently supported with encoder/"
"decoder models.")
# Efficiently import all enc/dec error strings
# rather than having to import all of the above
STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
"STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA,
"STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
"STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL":
STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL,
"STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP,
"STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA,
"STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP,
"STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
"STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
"STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
"STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
}
# Constants related to forcing the attention backend selection
# String name of register which may be set in order to
# force auto-selection of attention backend by Attention
# wrapper
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
STR_DTYPE_TO_TORCH_DTYPE = { STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half, "half": torch.half,
"bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,
@ -1029,3 +1112,50 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
"""Utility function to run async task in a lock""" """Utility function to run async task in a lock"""
async with lock: async with lock:
return await task(*args, **kwargs) return await task(*args, **kwargs)
def is_encoder_decoder_model_config(model_config) -> bool:
'''
Extract the HF encoder/decoder model flag from the ModelConfig instance.
Return False if model_config is None.
'''
return model_config is not None and \
getattr(model_config.hf_config,
"is_encoder_decoder",
False)
def is_embedding_model_config(model_config) -> bool:
'''
Extract the embedding model flag from the ModelConfig instance.
Return False if model_config is None.
'''
return model_config is not None and \
model_config.embedding_mode
def build_explicit_enc_dec_prompt(
encoder_prompt: SingletonPromptInputs,
decoder_prompt: SingletonPromptInputs,
) -> ExplicitEncoderDecoderPrompt:
return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
decoder_prompt=decoder_prompt)
def zip_enc_dec_prompt_lists(
enc_prompt_list: List[SingletonPromptInputs],
dec_prompt_list: List[SingletonPromptInputs],
) -> List[ExplicitEncoderDecoderPrompt]:
return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
for (encoder_prompt,
decoder_prompt) in zip(enc_prompt_list, dec_prompt_list)
]
def to_enc_dec_tuple_list(
enc_dec_prompts: List[ExplicitEncoderDecoderPrompt],
) -> List[Tuple[PromptInputs, PromptInputs]]:
return [(enc_dec_prompt['encoder_prompt'],
enc_dec_prompt['decoder_prompt'])
for enc_dec_prompt in enc_dec_prompts]

View File

@ -0,0 +1,472 @@
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, cast
import torch
import torch.distributed
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
get_global_forced_attn_backend,
global_force_attn_backend)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
from vllm.worker.model_runner import (_PAD_SLOT_ID, GPUModelRunnerBase,
ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict)
from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
logger = init_logger(__name__)
@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
"""
Used by the EncoderDecoderModelRunner.
"""
encoder_input_tokens: Optional[torch.Tensor] = None
encoder_input_positions: Optional[torch.Tensor] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"encoder_input_tokens": self.encoder_input_tokens,
"encoder_input_positions": self.encoder_input_positions,
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "EncoderDecoderModelInput":
return cast(
EncoderDecoderModelInput,
super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
_model_input_cls: Type[EncoderDecoderModelInput] = (
EncoderDecoderModelInput)
_builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder)
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
):
'''
EncoderDecoderModelRunner constructor.
`lora_config`, `multimodal_config`, and prompt_adapter_config are
unused (since these features are not yet supported for encoder/decoder
models) but these arguments are present here for compatibility with
the base-class constructor.
'''
self._maybe_force_supported_attention_backend()
super().__init__(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config=None,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
)
# Crash for unsupported encoder/scenarios
assert_enc_dec_mr_supported_scenario(self)
def _maybe_force_supported_attention_backend(self):
'''
Force vLLM to use the XFormers attention backend,
which is currently the only supported option.
'''
def raise_backend_err():
# The user has specified an attention backend override
# which is invalid for encoder/decoder models
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND)
maybe_env_var_forced_backend = get_env_variable_attn_backend()
maybe_global_forced_backend = get_global_forced_attn_backend()
is_forced_by_global = maybe_global_forced_backend is not None
is_forced_by_env_var = maybe_env_var_forced_backend is not None
if not (is_forced_by_global or is_forced_by_env_var):
# The user has not already specified an attention backend
# override
logger.info("EncoderDecoderModelRunner requires "
"XFormers backend; overriding backend "
"auto-selection and forcing XFormers.")
global_force_attn_backend(_Backend.XFORMERS)
elif is_forced_by_global:
# Backend override enforced by global variable takes
# precedence over vLLM backend environment variable.
if maybe_global_forced_backend != _Backend.XFORMERS:
raise_backend_err()
elif is_forced_by_env_var:
# Backend override enforced by vLLM backend
# environment variable
if maybe_env_var_forced_backend != _Backend.XFORMERS:
raise_backend_err()
def _list_to_int32_tensor(
self,
_list: List[int],
) -> torch.Tensor:
return torch.tensor(_list, dtype=torch.int32, device=self.device)
def _list_to_long_tensor(
self,
_list: List[int],
) -> torch.Tensor:
return torch.tensor(_list, dtype=torch.long, device=self.device)
def _empty_int32_tensor(self) -> torch.Tensor:
return self._list_to_int32_tensor([])
def _empty_long_tensor(self) -> torch.Tensor:
return self._list_to_long_tensor([])
@torch.inference_mode()
def execute_model(
self,
model_input: EncoderDecoderModelInput,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[PoolerOutput]]:
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in "
"EncoderDecoderModelRunner")
model_executable = self.model
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_seqlen_agnostic else {}
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
encoder_input_ids=model_input.encoder_input_tokens,
encoder_positions=model_input.encoder_input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**seqlen_agnostic_kwargs)
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
if not self.is_driver_worker:
return []
# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
return [output]
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput:
return EncoderDecoderModelInput.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> EncoderDecoderModelInput:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
Since chunked prefill is not supported for encoder/decoder models,
`input_tokens` is assumed to be either entirely prefill tokens or
entirely decode tokens.
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
(
attn_metadata,
encoder_input_tokens_tensor,
encoder_input_positions_tensor,
) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
model_input))
# Inject attn_metadata encoder/cross-attention fields &
# encoder input tokens/positions into model_input.
# Frozen dataclass fields cannot be modified, so use
# dataclasses.replace to construct a new model input
# instance.
model_input = dataclasses.replace(
model_input,
attn_metadata=attn_metadata,
encoder_input_tokens=encoder_input_tokens_tensor,
encoder_input_positions=encoder_input_positions_tensor,
)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
self.pin_memory)
is_prompt = (seq_group_metadata_list[0].is_prompt
if seq_group_metadata_list else None)
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
is_prompt=is_prompt,
virtual_engine=virtual_engine)
@torch.inference_mode()
def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = []
model_config = self.model_config
batch_size = 0
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
seq_data, _ = INPUT_REGISTRY \
.dummy_data_for_profiling(model_config, seq_len)
# Having more tokens is over-conservative but otherwise fine
assert len(seq_data.prompt_token_ids) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but got: {len(seq_data.prompt_token_ids)}")
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
sampling_params=sampling_params,
block_tables=None,
encoder_seq_data=seq_data,
cross_block_table=None,
)
seqs.append(seq)
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
intermediate_tensors = None
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize()
return
def _prepare_encoder_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
model_input: EncoderDecoderModelInput,
) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""Helper method to prepare the encoder- and cross-attn-related
model inputs based on a given sequence group. These additional inputs
are used to augment an already-computed `EncoderDecoderModelInput`
data structure which already has decoder-related model inputs
populated.
Sets the following attn_metadata fields:
* `num_encoder_tokens`
* `encoder_seq_lens`
* `encoder_seq_lens_tensor`
* `max_encoder_seq_len`
* `cross_slot_mapping`
* `cross_block_tables`
Constructs a new model inputs data structure, based on
(1) the existing fields in the `model_inputs` argument,
and (2) the following additional fields which are
computed (or in the case of `attn_metadata`, updated)
by this function:
* attn_metadata
* encoder_input_tokens
* encoder_input_positions
Arguments:
* seq_group_metadata_list: list of sequence groups for which to
compute inputs
* model_inputs: model inputs data structure with decoder-oriented
fields already computed.
Return:
* Updated model inputs data structure
"""
if len(seq_group_metadata_list) == 0:
return (model_input.attn_metadata, None, None)
# Since we are not supporting chunked prefill either the entire
# batch is prefill or it is decode
is_prompt = seq_group_metadata_list[0].is_prompt
# Build encoder inputs
encoder_seq_lens: List[int] = []
if is_prompt:
# Prefill phase.
cross_block_tables = self._empty_int32_tensor().view(
len(seq_group_metadata_list), -1)
# Extract input tokens/positions, cross-attention slot-mapping,
# & seq len from each sequence group metadata
(
encoder_input_tokens,
encoder_input_positions,
cross_slot_mapping,
) = (
[],
[],
[],
)
for seq_group_metadata in seq_group_metadata_list:
# Build seq lens
seq_len = seq_group_metadata.encoder_seq_data.get_len()
token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
encoder_seq_lens.append(seq_len)
# Build slot mapping
is_profile_run = (seq_group_metadata.block_tables is None)
if is_profile_run:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
cross_slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
else:
for i in range(0, seq_len):
block_number = seq_group_metadata.cross_block_table[
i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
cross_slot_mapping.append(slot)
# Build encoder input tokens
encoder_input_tokens.extend(token_ids)
encoder_input_positions.extend(list(range(0, seq_len)))
# Convert tokens/positions & cross-attention
# slot-mapping to encoder input tensors
encoder_input_tokens_tensor = self._list_to_long_tensor(
encoder_input_tokens)
encoder_input_positions_tensor = self._list_to_long_tensor(
encoder_input_positions)
cross_slot_mapping_tensor = self._list_to_long_tensor(
cross_slot_mapping)
else:
# Decode phase.
encoder_input_tokens_tensor = self._empty_long_tensor()
encoder_input_positions_tensor = self._empty_long_tensor()
cross_slot_mapping_tensor = self._empty_long_tensor()
# Extract cross-attention block tables &
# seq len from each sequence group metadata.
# Cross-attention block tables are empty
# during vLLM memory profiling.
cross_block_tables = []
for seq_group_metadata in seq_group_metadata_list:
encoder_seq_lens.append(
seq_group_metadata.encoder_seq_data.get_len())
cross_block_table = seq_group_metadata.cross_block_table
cross_block_tables.append([] if (
cross_block_table is None) else cross_block_table)
# Convert cross-attention block tables to encoder input tensor
cross_block_tables = make_tensor_with_pad(
cross_block_tables,
max_len=max(
len(block_table) for block_table in cross_block_tables),
pad=0,
dtype=torch.int32,
device=self.device,
)
# Compute encoder sequence lengths & encoder
# sequence starting offset tensors
max_encoder_seq_len = max(encoder_seq_lens, default=0)
encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
1,
dtype=torch.int32,
device=self.device)
torch.cumsum(encoder_seq_lens_tensor,
dim=0,
dtype=encoder_seq_start_loc.dtype,
out=encoder_seq_start_loc[1:])
# Update attention metadata with encoder-oriented attributes
attn_metadata = model_input.attn_metadata
assert attn_metadata is not None
(
attn_metadata.num_encoder_tokens,
attn_metadata.encoder_seq_lens,
attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_slot_mapping,
attn_metadata.cross_block_tables,
) = (
sum(encoder_seq_lens),
encoder_seq_lens,
encoder_seq_lens_tensor,
max_encoder_seq_len,
cross_slot_mapping_tensor,
cross_block_tables,
)
return (attn_metadata, encoder_input_tokens_tensor,
encoder_input_positions_tensor)

56
vllm/worker/utils.py Normal file
View File

@ -0,0 +1,56 @@
'''
Worker-related helper functions.
'''
from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS
from vllm.worker.model_runner import GPUModelRunnerBase
def assert_enc_dec_mr_supported_scenario(
enc_dec_mr: GPUModelRunnerBase) -> None:
'''
Asserted that the provided encoder/decoder model runner instance reflects
a supported scenario.
'''
if enc_dec_mr.cache_config.enable_prefix_caching:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE'])
if enc_dec_mr.sliding_window is not None:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SWA'])
if enc_dec_mr.scheduler_config.chunked_prefill_enabled:
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
'STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL'])
if getattr(enc_dec_mr.model_config.hf_config, 'attn_logit_softcapping',
None) is not None:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP']
)
if enc_dec_mr.lora_config is not None:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LORA'])
if enc_dec_mr.parallel_config.pipeline_parallel_size > 1:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
if enc_dec_mr.multimodal_config is not None:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])
if enc_dec_mr.scheduler_config.num_lookahead_slots > 0:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])
if not enc_dec_mr.model_config.enforce_eager:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH'])
if enc_dec_mr.prompt_adapter_config is not None:
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER'])

View File

@ -19,8 +19,11 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.utils import (is_embedding_model_config,
is_encoder_decoder_model_config)
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
@ -85,8 +88,10 @@ class Worker(LocalOrDistributedWorkerBase):
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_runner_cls is not None: if model_runner_cls is not None:
ModelRunnerClass = model_runner_cls ModelRunnerClass = model_runner_cls
elif self.model_config.embedding_mode: elif self._is_embedding_model():
ModelRunnerClass = EmbeddingModelRunner ModelRunnerClass = EmbeddingModelRunner
elif self._is_encoder_decoder_model():
ModelRunnerClass = EncoderDecoderModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass( self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
model_config, model_config,
parallel_config, parallel_config,
@ -107,6 +112,12 @@ class Worker(LocalOrDistributedWorkerBase):
# Initialize gpu_cache as embedding models don't initialize kv_caches # Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
def _is_encoder_decoder_model(self):
return is_encoder_decoder_model_config(self.model_config)
def _is_embedding_model(self):
return is_embedding_model_config(self.model_config)
def init_device(self) -> None: def init_device(self) -> None:
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until # torch.distributed.all_reduce does not free the input tensor until