[Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) (#4942)
Co-authored-by: Andrew Feldman <afeld2012@gmail.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>

parent 660470e5a3
commit fd95e026e0
@@ -148,8 +148,9 @@ steps:
  - python3 cpu_offload.py
  - python3 offline_inference_with_prefix.py
  - python3 llm_engine_example.py
  - python3 llava_example.py
  - python3 offline_inference_vision_language.py
  - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
  - python3 offline_inference_encoder_decoder.py

- label: Models Test # 1hr10min
  source_file_dependencies:
@@ -289,6 +290,7 @@ steps:
  commands:
  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
  - TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py
  - pytest -v -s distributed/test_basic_distributed_correctness_enc_dec.py
  - pytest -v -s distributed/test_chunked_prefill_distributed.py
  - pytest -v -s distributed/test_multimodal_broadcast.py
  - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
99	examples/offline_inference_encoder_decoder.py	Normal file
@@ -0,0 +1,99 @@
'''
Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART
'''

from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.utils import zip_enc_dec_prompt_lists

dtype = "float"

# Create a BART encoder/decoder model instance
llm = LLM(
    model="facebook/bart-large-cnn",
    dtype=dtype,
)

# Get BART tokenizer
tokenizer = llm.llm_engine.get_tokenizer_group()

# Test prompts
#
# This section shows all of the valid ways to prompt an
# encoder/decoder model.
#
# - Helpers for building prompts
text_prompt_raw = "Hello, my name is"
text_prompt = TextPrompt(prompt="The president of the United States is")
tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode(
    prompt="The capital of France is"))
# - Pass a single prompt to encoder/decoder model
#   (implicitly encoder input prompt);
#   decoder input prompt is assumed to be None

single_text_prompt_raw = text_prompt_raw  # Pass a string directly
single_text_prompt = text_prompt  # Pass a TextPrompt
single_tokens_prompt = tokens_prompt  # Pass a TokensPrompt

# - Pass explicit encoder and decoder input prompts within one data structure.
#   Encoder and decoder prompts can both independently be text or tokens, with
#   no requirement that they be the same prompt type. Some example prompt-type
#   combinations are shown below, note that these are not exhaustive.

enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
    # Pass encoder prompt string directly, &
    # pass decoder prompt tokens
    encoder_prompt=single_text_prompt_raw,
    decoder_prompt=single_tokens_prompt,
)
enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
    # Pass TextPrompt to encoder, and
    # pass decoder prompt string directly
    encoder_prompt=single_text_prompt,
    decoder_prompt=single_text_prompt_raw,
)
enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
    # Pass encoder prompt tokens directly, and
    # pass TextPrompt to decoder
    encoder_prompt=single_tokens_prompt,
    decoder_prompt=single_text_prompt,
)

# - Finally, here's a useful helper function for zipping encoder and
#   decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt
#   instances
zipped_prompt_list = zip_enc_dec_prompt_lists(
    ['An encoder prompt', 'Another encoder prompt'],
    ['A decoder prompt', 'Another decoder prompt'])

# - Let's put all of the above example prompts together into one list
#   which we will pass to the encoder/decoder LLM.
prompts = [
    single_text_prompt_raw, single_text_prompt, single_tokens_prompt,
    enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3
] + zipped_prompt_list

print(prompts)

# Create a sampling params object.
sampling_params = SamplingParams(
    temperature=0,
    top_p=1.0,
    min_tokens=0,
    max_tokens=20,
)

# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    encoder_prompt = output.encoder_prompt
    generated_text = output.outputs[0].text
    print(f"Encoder prompt: {encoder_prompt!r}, "
          f"Decoder prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")
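Editorial note: the example above leans on two prompt helpers whose bodies are not shown in this commit, `zip_enc_dec_prompt_lists` (used here) and `to_enc_dec_tuple_list` (used later in the test runners). The sketch below is my own minimal illustration of what they are presumably expected to do, namely pair up same-index entries and unpack them again; the actual implementations in `vllm.utils` may differ (e.g. in validation or in how `ExplicitEncoderDecoderPrompt` entries are accessed).

```python
from typing import List, Optional, Tuple

from vllm.inputs import ExplicitEncoderDecoderPrompt


def zip_enc_dec_prompt_lists_sketch(
    enc_prompts: List[str],
    dec_prompts: List[Optional[str]],
) -> List[ExplicitEncoderDecoderPrompt]:
    # Pair each encoder prompt with the decoder prompt at the same index.
    return [
        ExplicitEncoderDecoderPrompt(encoder_prompt=enc, decoder_prompt=dec)
        for enc, dec in zip(enc_prompts, dec_prompts)
    ]


def to_enc_dec_tuple_list_sketch(
    enc_dec_prompts: List[ExplicitEncoderDecoderPrompt],
) -> List[Tuple[str, Optional[str]]]:
    # Unpack each paired prompt back into an (encoder, decoder) tuple;
    # dict-style access assumes the prompt type behaves like a TypedDict.
    return [(p["encoder_prompt"], p["decoder_prompt"])
            for p in enc_dec_prompts]
```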
@@ -10,9 +10,11 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
                          AutoTokenizer, BatchEncoding, BatchFeature)
from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
                          AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
                          BatchFeature)

from tests.models.utils import DecoderPromptType
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import TokenizerPoolConfig
@@ -21,9 +23,11 @@ from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sequence import SampleLogprobs
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
                        is_cpu)
                        is_cpu, to_enc_dec_tuple_list,
                        zip_enc_dec_prompt_lists)

logger = init_logger(__name__)

@@ -120,6 +124,40 @@ def example_prompts() -> List[str]:
    return prompts


@pytest.fixture
def example_encoder_decoder_prompts() \
        -> Dict[DecoderPromptType,
                Tuple[List[str], List[Optional[str]]]]:
    '''
    Returns a dictionary, keyed by decoder prompt type, of zipped encoder and
    decoder prompt lists; within each list, same-index entries correspond to
    an (encoder prompt, decoder prompt) pair.

    Returns:

    * Dict mapping DecoderPromptType to zipped prompt lists, where the decoder
      prompts are None, empty strings, or the reverse of the encoder prompt
      list, respectively.
    '''

    encoder_prompts = []
    for filename in _TEST_PROMPTS:
        encoder_prompts += _read_prompts(filename)

    custom_decoder_prompts = encoder_prompts[::-1]
    empty_str_decoder_prompts = [""] * len(encoder_prompts)
    none_decoder_prompts = [None] * len(encoder_prompts)

    # NONE decoder prompt type
    return {
        DecoderPromptType.NONE:
        zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts),
        DecoderPromptType.EMPTY_STR:
        zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts),
        DecoderPromptType.CUSTOM:
        zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts),
    }


@pytest.fixture
def example_long_prompts() -> List[str]:
    prompts = []
@@ -152,6 +190,7 @@ class HfRunner:
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        is_vision_model: bool = False,
        is_encoder_decoder_model: bool = False,
    ) -> None:
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

@@ -168,6 +207,8 @@ class HfRunner:
        else:
            if is_vision_model:
                auto_cls = AutoModelForVision2Seq
            elif is_encoder_decoder_model:
                auto_cls = AutoModelForSeq2SeqLM
            else:
                auto_cls = AutoModelForCausalLM

@@ -314,6 +355,44 @@ class HfRunner:
            all_logprobs.append(seq_logprobs)
        return all_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states,
        num_logprobs,
    ) -> Tuple[List[Dict[int, float]], int]:
        seq_logprobs: List[torch.Tensor] = []
        output_len = len(hidden_states)
        for _, hidden_state in enumerate(hidden_states):
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
                last_hidden_states,
                self.model.get_output_embeddings().weight.t(),
            )
            if getattr(self.model.get_output_embeddings(), "bias",
                       None) is not None:
                logits += self.model.get_output_embeddings().bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)

        # convert to dict
        seq_logprobs_lst: List[Dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            tok_logprobs_dct = {}
            for token_id, logprob in zip(topk.indices[0], topk.values[0]):
                tok_logprobs_dct[token_id.item()] = logprob.item()

            seq_logprobs_lst.append(tok_logprobs_dct)

        return (
            seq_logprobs_lst,
            output_len,
        )
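Editorial note: as context for `_hidden_states_to_logprobs`, HuggingFace `generate(..., output_hidden_states=True, return_dict_in_generate=True)` returns hidden states as a tuple with one entry per generated step, each entry itself a tuple of per-layer tensors. The snippet below is my own illustration of that shape assumption and of the reduction the helper performs for a single step; shapes and values are made up.

```python
import torch
import torch.nn.functional as F

# Dummy stand-ins for one generated step of a 2-layer model with a hidden
# size of 8 and a vocabulary of 16 (shapes are illustrative only).
hidden_size, vocab_size, num_logprobs = 8, 16, 5
step_hidden_states = (torch.randn(1, 1, hidden_size),   # layer 0
                      torch.randn(1, 1, hidden_size))   # last layer
output_embedding_weight = torch.randn(vocab_size, hidden_size)

# Same reduction as _hidden_states_to_logprobs, for a single step:
last_hidden = step_hidden_states[-1][0]               # [seq, hidden]
logits = last_hidden @ output_embedding_weight.t()    # [seq, vocab]
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
topk = logprobs[-1, :].topk(num_logprobs)             # keep top-k of last position
tok_logprobs = {i.item(): v.item() for i, v in zip(topk.indices, topk.values)}
```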

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
@@ -346,33 +425,11 @@ class HfRunner:
                **kwargs,
            )

            seq_logprobs: List[torch.Tensor] = []
            for _, hidden_states in enumerate(output.hidden_states):
                last_hidden_states = hidden_states[-1][0]
                logits = torch.matmul(
                    last_hidden_states,
                    self.model.get_output_embeddings().weight.t(),
                )
                if getattr(self.model.get_output_embeddings(), "bias",
                           None) is not None:
                    logits += self.model.get_output_embeddings(
                    ).bias.unsqueeze(0)
                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                seq_logprobs.append(logprobs)

            # convert to dict
            seq_logprobs_lst: List[Dict[int, float]] = []
            for tok_idx, tok_logprobs in enumerate(seq_logprobs):
                # drop prompt logprobs
                if tok_idx == 0:
                    tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
                topk = tok_logprobs.topk(num_logprobs)

                tok_logprobs_dct = {}
                for token_id, logprob in zip(topk.indices[0], topk.values[0]):
                    tok_logprobs_dct[token_id.item()] = logprob.item()

                seq_logprobs_lst.append(tok_logprobs_dct)
            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
@@ -385,6 +442,57 @@ class HfRunner:
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: Tuple[List[str], List[str]],
        max_tokens: int,
        num_logprobs: int,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        '''
        Greedy logprobs generation for HuggingFace encoder/decoder models
        '''

        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for (encoder_prompt,
             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
            encoder_input_ids = self.wrap_device(
                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids)
            decoder_input_ids = (
                None if decoder_prompt is None else self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids))

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.decoder_hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        return self.model.encode(prompts)

@@ -416,7 +524,7 @@ class VllmRunner:
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: bool = False,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        self.model = LLM(
@@ -465,6 +573,19 @@ class VllmRunner:
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def _final_steps_generate_w_logprobs(
        self,
        req_outputs: List[RequestOutput],
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for req_output in req_outputs:
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = sample.token_ids
                output_logprobs = sample.logprobs
                outputs.append((output_ids, output_str, output_logprobs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
@@ -483,14 +604,21 @@ class VllmRunner:

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)
        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for req_output in req_outputs:
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = sample.token_ids
                output_logprobs = sample.logprobs
                outputs.append((output_ids, output_str, output_logprobs))
        return outputs
        return self._final_steps_generate_w_logprobs(req_outputs)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: Tuple[List[str], List[str]],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        '''
        Logprobs generation for vLLM encoder/decoder models
        '''

        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        return self._final_steps_generate_w_logprobs(req_outputs)

    def generate_greedy(
        self,
@@ -523,6 +651,26 @@ class VllmRunner:
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: Tuple[List[str], List[str]],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models
        '''
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                use_beam_search=False,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)

        outputs = self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
@@ -9,33 +9,11 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, SequenceGroup, SequenceStatus
from vllm.sequence import SequenceGroup, SequenceStatus

from .utils import create_dummy_prompt


def get_sequence_groups(scheduler_output):
    return [s.seq_group for s in scheduler_output.scheduled_seq_groups]


def append_new_token(out, token_id: int):
    seq_groups = get_sequence_groups(out)
    for seq_group in seq_groups:
        for seq in seq_group.get_seqs():
            seq.append_token_id(token_id, {token_id: Logprob(token_id)})


def schedule_and_update_computed_tokens(scheduler):
    metas, out = scheduler.schedule()
    for s, meta in zip(out.scheduled_seq_groups, metas):
        s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
    return metas, out


def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
    seq_group.update_num_computed_tokens(token_chunk_size)
    for seq in seq_group.get_seqs():
        seq.append_token_id(token_id, {token_id: Logprob(token_id)})
from .utils import (append_new_token, append_new_token_seq_group,
                    create_dummy_prompt, get_sequence_groups,
                    schedule_and_update_computed_tokens)


def test_scheduler_add_seq_group():
99	tests/core/test_scheduler_encoder_decoder.py	Normal file
@@ -0,0 +1,99 @@
from typing import List

import pytest  # noqa

from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.sequence import SequenceGroup

from .utils import (append_new_token, create_dummy_prompt_encoder_decoder,
                    get_sequence_groups, schedule_and_update_computed_tokens)


def test_scheduler_schedule_simple_encoder_decoder():
    '''
    Test basic scheduler functionality in the context
    of an encoder/decoder model. Focus on testing
    enc/dec-specific functionality, since tests already
    exist for decoder-only functionality

    Test behavior:

    * Construct Scheduler
    * Construct dummy encoder/decoder sequence groups
    * Add dummy seq groups to scheduler backlog
    * Schedule the next seq group & validate:
        * Cross-attn block tables
        * Updated states of seq groups
        * Number of batched tokens
        * Number of blocks to copy/swap-in/swap-out
        * Number of scheduled seq groups
    * Repeat for both prefill- and decode-phase
    * Abort scheduled seq groups
    * Assert that aborted seq groups no longer appear in
      cross-attention block table
    '''

    block_size = 4
    num_seq_group = 4
    max_model_len = 16
    scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 16  # enc and dec prompts per seq_group
    cache_config.num_gpu_blocks = 16  # enc and dec prompts per seq_group
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    req_id_list = []
    for i in range(num_seq_group):
        req_id = str(i)
        req_id_list.append(req_id)
        _, _, seq_group = create_dummy_prompt_encoder_decoder(
            req_id, block_size, block_size, block_size)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Schedule seq groups prefill.
    num_tokens = block_size * num_seq_group
    seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler)
    # - Verify that sequence group cross-attention block tables are
    #   registered with the block manager
    assert all([(req_id in scheduler.block_manager.cross_block_tables)
                for req_id in req_id_list])
    # - Validate sequence-group status
    assert set(get_sequence_groups(out)) == set(running)
    # - Validate number of batched tokens
    assert out.num_batched_tokens == num_tokens
    # - Validate there are no remaining blocks to swap
    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
            and not out.blocks_to_swap_out)
    # - Validate all seq groups were scheduled
    assert len(seq_group_meta_list) == num_seq_group
    append_new_token(out, 1)

    # Schedule seq groups decode.
    seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler)
    # - Verify that sequence group metadata includes encoder attention
    #   and cross-attention metadata
    assert all([
        not ((seq_group_meta.encoder_seq_data is None) or
             (seq_group_meta.cross_block_table is None))
        for seq_group_meta in seq_group_meta_list
    ])
    # - Validate sequence-group status
    assert set(get_sequence_groups(out)) == set(running)
    # - Validate there is one batched token per seq group
    assert out.num_batched_tokens == num_seq_group
    # - Validate there are no remaining blocks to swap
    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
            and not out.blocks_to_swap_out)
    # - Validate that all seq groups were scheduled
    assert len(seq_group_meta_list) == num_seq_group
    append_new_token(out, 1)

    # Abort sequences
    for req_id in req_id_list:
        scheduler.abort_seq_group(req_id)
        # - Verify that sequence group cross-attention block tables are
        #   NO LONGER registered with the block manager
        assert req_id not in scheduler.block_manager.cross_block_tables
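Editorial note: the "enc and dec prompts per seq_group" comments hint at why 16 KV-cache blocks are enough for this test. A quick back-of-the-envelope check (my own arithmetic, not part of the test) is sketched below.

```python
# Capacity check for the test configuration above (illustrative arithmetic):
block_size = 4
num_seq_group = 4
prompt_len = block_size                              # enc and dec prompts are 4 tokens each

blocks_per_prompt = -(-prompt_len // block_size)     # ceil division -> 1 block
blocks_per_group = 2 * blocks_per_prompt             # decoder block table + cross-attn (encoder) block table
prefill_blocks = num_seq_group * blocks_per_group    # 8 blocks at prefill time

assert prefill_blocks <= 16  # num_gpu_blocks = 16 leaves headroom for decode growth
```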
@@ -53,27 +53,30 @@ def create_dummy_prompt_encoder_decoder(
        block_size = decoder_prompt_length

    # Create dummy prompt sequence with tokens 0...block_size-1
    # and prompt "0 ... block_size".
    # and prompt "0 ... block_size". Note that the prompt string
    # doesn't actually match the tokens
    decoder_prompt_tokens = list(range(decoder_prompt_length))
    decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens])

    decoder_prompt = Sequence(int(request_id),
                              inputs={
                                  "prompt": decoder_prompt_str,
                                  "prompt_token_ids": decoder_prompt_tokens,
                                  "multi_modal_data": None,
                              },
                              block_size=block_size)

    encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
    encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])

    inputs = {
        "prompt": decoder_prompt_str,
        "prompt_token_ids": decoder_prompt_tokens,
        "encoder_prompt": encoder_prompt_str,
        "encoder_prompt_token_ids": encoder_prompt_tokens,
        "multi_modal_data": None,
    }

    decoder_prompt = Sequence(int(request_id),
                              inputs=inputs,
                              block_size=block_size,
                              from_decoder_prompt=True)

    encoder_prompt = Sequence(int(request_id),
                              inputs={
                                  "prompt": encoder_prompt_str,
                                  "prompt_token_ids": encoder_prompt_tokens,
                                  "multi_modal_data": None,
                              },
                              block_size=block_size)
                              inputs=inputs,
                              block_size=block_size,
                              from_decoder_prompt=False)
    seq_group = SequenceGroup(request_id=request_id,
                              seqs=[decoder_prompt],
                              sampling_params=SamplingParams(
@@ -139,17 +142,21 @@ def create_seq_group_encoder_decoder(

    prompt_token_ids = [0] * seq_prompt_len

    inputs = {
        "prompt": "",
        "prompt_token_ids": prompt_token_ids,
        "encoder_prompt": "",
        "encoder_prompt_token_ids": prompt_token_ids,
        "multi_modal_data": None,
    }

    seqs = []
    for seq_id_offset, output_len in enumerate(seq_output_lens):
        seq = Sequence(
            seq_id=seq_id_start + seq_id_offset,
            inputs={
                "prompt": "",
                "prompt_token_ids": prompt_token_ids,
                "multi_modal_data": None,
            },
            block_size=16,
        )
        # Construct decoder input sequences
        seq = Sequence(seq_id=seq_id_start + seq_id_offset,
                       inputs=inputs,
                       block_size=16,
                       from_decoder_prompt=True)

        for i in range(output_len):
            seq.append_token_id(
@@ -158,16 +165,11 @@ def create_seq_group_encoder_decoder(
            )
        seqs.append(seq)

    # Encoder sequence
    encoder_seq = Sequence(
        seq_id=seq_id_start + len(seq_output_lens),
        inputs={
            "prompt": "",
            "prompt_token_ids": prompt_token_ids,
            "multi_modal_data": None,
        },
        block_size=16,
    )
    # Encoder input sequence
    encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens),
                           inputs=inputs,
                           block_size=16,
                           from_decoder_prompt=False)

    return SequenceGroup(request_id=request_id,
                         seqs=seqs,
@@ -177,4 +179,31 @@ def create_seq_group_encoder_decoder(


def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    return (seq_len + block_size - 1) // block_size


# Helper functions for scheduler tests


def get_sequence_groups(scheduler_output):
    return [s.seq_group for s in scheduler_output.scheduled_seq_groups]


def append_new_token(out, token_id: int):
    seq_groups = get_sequence_groups(out)
    for seq_group in seq_groups:
        for seq in seq_group.get_seqs():
            seq.append_token_id(token_id, {token_id: Logprob(token_id)})


def schedule_and_update_computed_tokens(scheduler):
    metas, out = scheduler.schedule()
    for s, meta in zip(out.scheduled_seq_groups, metas):
        s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
    return metas, out


def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
    seq_group.update_num_computed_tokens(token_chunk_size)
    for seq in seq_group.get_seqs():
        seq.append_token_id(token_id, {token_id: Logprob(token_id)})
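Editorial note: the key pattern in the updated helpers above is that a single `inputs` dict now carries both decoder fields (`prompt`, `prompt_token_ids`) and encoder fields (`encoder_prompt`, `encoder_prompt_token_ids`), and `from_decoder_prompt` selects which view a `Sequence` exposes. Below is a minimal stand-in illustration of that idea, using a hypothetical class rather than vLLM's actual `Sequence`.

```python
from typing import Dict, List


class SequenceViewSketch:
    """Stand-in illustrating the from_decoder_prompt pattern (not vLLM's Sequence)."""

    def __init__(self, inputs: Dict, from_decoder_prompt: bool = True):
        self.inputs = inputs
        self.from_decoder_prompt = from_decoder_prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # A decoder-view sequence reads the decoder fields; an encoder-view
        # sequence reads the encoder_* fields of the same shared dict.
        key = ("prompt_token_ids"
               if self.from_decoder_prompt else "encoder_prompt_token_ids")
        return self.inputs[key]


shared_inputs = {
    "prompt": "0 1 2 3",
    "prompt_token_ids": [0, 1, 2, 3],
    "encoder_prompt": "3 2 1 0",
    "encoder_prompt_token_ids": [3, 2, 1, 0],
    "multi_modal_data": None,
}
decoder_view = SequenceViewSketch(shared_inputs, from_decoder_prompt=True)
encoder_view = SequenceViewSketch(shared_inputs, from_decoder_prompt=False)
assert decoder_view.prompt_token_ids == [0, 1, 2, 3]
assert encoder_view.prompt_token_ids == [3, 2, 1, 0]
```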
101	tests/distributed/test_basic_distributed_correctness_enc_dec.py	Normal file
@@ -0,0 +1,101 @@
"""For encoder/decoder models only:
Compare the outputs of HF and distributed vLLM when using greedy sampling.

Run:
```sh
cd $VLLM_PATH/tests

pytest distributed/test_basic_distributed_correctness_enc_dec.py
```
"""

import pytest

from tests.models.utils import DecoderPromptType
from vllm.utils import cuda_device_count_stateless

from ..models.utils import check_logprobs_close
from ..utils import fork_new_process_for_each_test


@pytest.mark.skipif(cuda_device_count_stateless() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model, distributed_executor_backend", [
    ("facebook/bart-large-cnn", "ray"),
    ("facebook/bart-large-cnn", "mp"),
])
@fork_new_process_for_each_test
def test_models(
    model: str,
    distributed_executor_backend: str,
    hf_runner,
    vllm_runner,
    example_encoder_decoder_prompts,
) -> None:
    '''
    Test vLLM BART inference on more than one GPU, comparing
    outputs against HF as a baseline.

    Fork a new process for each test, to prevent CUDA from
    being re-initialized by successive tests within the same
    process.

    Arguments:

    * model: the HF ID of the specific BART variant under test
    * distributed_executor_backend
    * hf_runner: HuggingFace (HF) test model runner
    * vllm_runner: vLLM test model runner
    * example_encoder_decoder_prompts: test fixture which provides a
      dictionary of dummy prompts
    '''

    dtype = "float"
    max_tokens = 64
    num_logprobs = 5

    # Example inputs with non-trivial (i.e. not None/empty) encoder &
    # decoder prompts.
    test_prompts = example_encoder_decoder_prompts[DecoderPromptType.CUSTOM]

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    with vllm_runner(
            model,
            dtype=dtype,
            tensor_parallel_size=2,
            distributed_executor_backend=distributed_executor_backend,
            enforce_eager=True,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
            test_prompts, max_tokens, num_logprobs)

    # Configuration settings for HF baseline
    hf_kwargs = {
        "top_k": None,
        "num_beams": 1,
        "repetition_penalty": 1.0,
        "top_p": 1.0,
        "length_penalty": 1.0,
        "early_stopping": False,
        "no_repeat_ngram_size": None,
        "min_length": 0
    }

    with hf_runner(model, dtype=dtype,
                   is_encoder_decoder_model=True) as hf_model:
        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
            test_prompts,
            max_tokens,
            num_logprobs,
            **hf_kwargs,
        ))

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
@@ -3,9 +3,9 @@ from unittest.mock import patch
import pytest
import torch

from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL,
                                 override_backend_env_variable)
from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import which_attn_to_use
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL


@pytest.mark.parametrize(
|
||||
* E2E test of Encoder attention + Decoder self-attention +
|
||||
Encoder/decoder cross-attention (collectively
|
||||
"encoder/decoder attention")
|
||||
* Confirm enc/dec models will fail for chunked prefill
|
||||
* Confirm enc/dec models will fail for prefix caching
|
||||
|
||||
"""
|
||||
|
||||
@ -15,19 +13,22 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import *
|
||||
from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
|
||||
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
|
||||
AttentionType)
|
||||
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
|
||||
from vllm.attention.selector import (_Backend,
|
||||
global_force_attn_backend_context_manager)
|
||||
from vllm.utils import is_hip
|
||||
|
||||
# List of support backends for encoder/decoder models
|
||||
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
|
||||
|
||||
HEAD_SIZES = [64, 256]
|
||||
|
||||
NUM_HEADS = [1, 16]
|
||||
|
||||
BATCH_SIZES = [1, 16]
|
||||
BLOCK_SIZES = [16]
|
||||
BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL]
|
||||
CUDA_DEVICE = "cuda:0"
|
||||
|
||||
MAX_DEC_SEQ_LENS = [128]
|
||||
@ -724,57 +725,92 @@ def _run_encoder_decoder_cross_attention_test(
|
||||
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
|
||||
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
|
||||
def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
|
||||
batch_size: int, block_size: int, max_dec_seq_len: int,
|
||||
max_enc_seq_len: int, monkeypatch):
|
||||
def test_encoder_only(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
attn_backend: _Backend,
|
||||
batch_size: int,
|
||||
block_size: int,
|
||||
max_dec_seq_len: int,
|
||||
max_enc_seq_len: int,
|
||||
):
|
||||
'''
|
||||
End-to-end encoder-only attention test:
|
||||
|
||||
* Construct fake test vectors for (1) encoder attention
|
||||
* Construct (1) attention metadata structure with prefill-phase
|
||||
encoder attention, and (2) an analogous attention metadata
|
||||
structure but for decode-phase
|
||||
* Test & validate encoder attention against ideal output
|
||||
|
||||
No KV cache is required for encoder-only attention.
|
||||
|
||||
Note on ROCm/HIP: currently encoder/decoder models are not supported on
|
||||
AMD GPUs, therefore this test simply is skipped if is_hip().
|
||||
|
||||
This test globally forces an override of the usual backend
|
||||
auto-selection process, forcing the specific backend-under-test
|
||||
to be utilized.
|
||||
|
||||
Arguments:
|
||||
|
||||
* num_heads
|
||||
* head_size,
|
||||
* attn_backend: The attention backend to employ for testing
|
||||
* batch_size
|
||||
* block_size: KV cache block size
|
||||
* max_dec_seq_len: max length of decoder input sequences
|
||||
* max_enc_seq_len: max length of encoder input sequences
|
||||
'''
|
||||
|
||||
# Force Attention wrapper backend
|
||||
override_backend_env_variable(monkeypatch, backend_name)
|
||||
with global_force_attn_backend_context_manager(attn_backend):
|
||||
|
||||
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||
# to be more than necessary, since exceeding the kv cache size
|
||||
# is not part of this test
|
||||
test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
|
||||
block_size, max_dec_seq_len, max_enc_seq_len, 4096)
|
||||
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||
# to be more than necessary, since exceeding the kv cache size
|
||||
# is not part of this test
|
||||
test_pt = TestPoint(num_heads, head_size, attn_backend.name,
|
||||
batch_size, block_size, max_dec_seq_len,
|
||||
max_enc_seq_len, 4096)
|
||||
|
||||
# Attention scale factor, attention backend instance, attention wrapper
|
||||
# instance, KV cache init
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
# Attention scale factor, attention backend instance, attention wrapper
|
||||
# instance, KV cache init
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
|
||||
# Construct encoder attention test params (only used
|
||||
# during prefill)
|
||||
# Construct encoder attention test params (only used
|
||||
# during prefill)
|
||||
|
||||
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
|
||||
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
|
||||
|
||||
# Shared prefill metadata structure
|
||||
# Shared prefill metadata structure
|
||||
|
||||
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
True,
|
||||
None,
|
||||
decoder_test_params=None,
|
||||
encoder_test_params=enc_test_params,
|
||||
cross_test_params=None,
|
||||
device=CUDA_DEVICE)
|
||||
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
True,
|
||||
None,
|
||||
decoder_test_params=None,
|
||||
encoder_test_params=enc_test_params,
|
||||
cross_test_params=None,
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
# PREFILL: encoder attention
|
||||
# PREFILL: encoder attention
|
||||
|
||||
enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
|
||||
test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
|
||||
enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
|
||||
test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
|
||||
|
||||
# - Is encoder attention result correct?
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
||||
# - Is encoder attention result correct?
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
|
||||
@ -782,12 +818,11 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
|
||||
def test_e2e_enc_dec_attn(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
backend_name: str,
|
||||
attn_backend: _Backend,
|
||||
batch_size: int,
|
||||
block_size: int,
|
||||
max_dec_seq_len: int,
|
||||
max_enc_seq_len: int,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
'''
|
||||
End-to-end encoder/decoder test:
|
||||
@ -820,8 +855,9 @@ def test_e2e_enc_dec_attn(
|
||||
cross-attention K/Vs are allowed to differ in seq len, as is often the case
|
||||
for cross-attention.
|
||||
|
||||
This test utilizes PyTest monkey patching to force the attention backend
|
||||
via an environment variable.
|
||||
This test globally forces an override of the usual backend
|
||||
auto-selection process, forcing the specific backend-under-test
|
||||
to be utilized.
|
||||
|
||||
Note on ROCm/HIP: currently encoder/decoder models are not supported on
|
||||
AMD GPUs, therefore this test simply is skipped if is_hip().
|
||||
@ -830,124 +866,136 @@ def test_e2e_enc_dec_attn(
|
||||
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
|
||||
and a single one shared by all decode-phase attention operations
|
||||
(decoder & enc/dec cross.) This is intended to reflect the behavior
|
||||
of ModelRunner, which constructs a single attention metadata structure for
|
||||
each prefill or decode run. A realistic scenario would rely on the
|
||||
attention backend to utilize the appropriate attention metadata fields
|
||||
according to the value of attn_metadata.attention_type. Thus, this test is
|
||||
organized so as to confirm that the backend-under-test can handle a
|
||||
shared prefill attention metadata structure & a shared decode attention
|
||||
metadata structure.
|
||||
of EncoderDecoderModelRunner, which constructs a single attention metadata
|
||||
structure for each prefill or decode run. A realistic scenario would rely
|
||||
on the attention backend to utilize the appropriate attention metadata
|
||||
fields according to the value of attn_metadata.attention_type. Thus,
|
||||
this test is organized so as to confirm that the backend-under-test can
|
||||
handle a shared prefill attention metadata structure & a shared decode\
|
||||
attention metadata structure.
|
||||
|
||||
Arguments:
|
||||
|
||||
* num_heads
|
||||
* head_size,
|
||||
* attn_backend: The attention backend to employ for testing
|
||||
* batch_size
|
||||
* block_size: KV cache block size
|
||||
* max_dec_seq_len: max length of decoder input sequences
|
||||
* max_enc_seq_len: max length of encoder input sequences
|
||||
'''
|
||||
|
||||
# Force Attention wrapper backend
|
||||
override_backend_env_variable(monkeypatch, backend_name)
|
||||
with global_force_attn_backend_context_manager(attn_backend):
|
||||
|
||||
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||
# to be more than necessary, since exceeding the kv cache size
|
||||
# is not part of this test
|
||||
test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
|
||||
block_size, max_dec_seq_len, max_enc_seq_len, 4096)
|
||||
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
|
||||
# to be more than necessary, since exceeding the kv cache size
|
||||
# is not part of this test
|
||||
test_pt = TestPoint(num_heads, head_size, attn_backend.name,
|
||||
batch_size, block_size, max_dec_seq_len,
|
||||
max_enc_seq_len, 4096)
|
||||
|
||||
# Attention scale factor, attention backend instance, attention wrapper
|
||||
# instance, KV cache init
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
# Attention scale factor, attention backend instance, attention wrapper
|
||||
# instance, KV cache init
|
||||
test_rsrcs = _make_test_resources(test_pt)
|
||||
|
||||
# Construct encoder attention test params (only used
|
||||
# during prefill)
|
||||
# Construct encoder attention test params (only used
|
||||
# during prefill)
|
||||
|
||||
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
|
||||
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
|
||||
|
||||
# Construct Decoder self-attention prefill-phase & decode-phase
|
||||
# test params, including query/key/value tensors, decoder self-attention
|
||||
# memory-mapping. cross_block_base_addr is the uppermost address in the
|
||||
# decoder self-attention block-table, i.e. a base address which the
|
||||
# encoder/decoder cross-attention block-table may build downward toward.
|
||||
# Construct Decoder self-attention prefill-phase & decode-phase
|
||||
# test params, including query/key/value tensors, decoder self-attention
|
||||
# memory-mapping. cross_block_base_addr is the uppermost address in the
|
||||
# decoder self-attention block-table, i.e. a base address which the
|
||||
# encoder/decoder cross-attention block-table may build downward toward.
|
||||
|
||||
(
|
||||
dec_qkv,
|
||||
prephase_dec_test_params,
|
||||
decphase_dec_test_params,
|
||||
cross_block_base_addr,
|
||||
) = _decoder_attn_setup(test_pt, test_rsrcs)
|
||||
(
|
||||
dec_qkv,
|
||||
prephase_dec_test_params,
|
||||
decphase_dec_test_params,
|
||||
cross_block_base_addr,
|
||||
) = _decoder_attn_setup(test_pt, test_rsrcs)
|
||||
|
||||
# Construct encoder/decoder cross-attention prefill-phase & decode-phase
|
||||
# test params, including key/value tensors, cross-attention memory-mapping
|
||||
# Construct encoder/decoder cross-attention prefill-phase
|
||||
# & decode-phase test params, including key/value tensors,
|
||||
# cross-attention memory-mapping
|
||||
|
||||
(
|
||||
prephase_cross_test_params,
|
||||
decphase_cross_test_params,
|
||||
) = _enc_dec_cross_attn_setup_reuses_query(
|
||||
dec_qkv,
|
||||
enc_test_params,
|
||||
prephase_dec_test_params,
|
||||
test_pt,
|
||||
test_rsrcs,
|
||||
block_base_addr=cross_block_base_addr)
|
||||
(
|
||||
prephase_cross_test_params,
|
||||
decphase_cross_test_params,
|
||||
) = _enc_dec_cross_attn_setup_reuses_query(
|
||||
dec_qkv,
|
||||
enc_test_params,
|
||||
prephase_dec_test_params,
|
||||
test_pt,
|
||||
test_rsrcs,
|
||||
block_base_addr=cross_block_base_addr)
|
||||
|
||||
# Shared prefill metadata structure
|
||||
assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
|
||||
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
True,
|
||||
prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
|
||||
decoder_test_params=prephase_dec_test_params,
|
||||
encoder_test_params=enc_test_params,
|
||||
cross_test_params=prephase_cross_test_params,
|
||||
device=CUDA_DEVICE)
|
||||
# Shared prefill metadata structure
|
||||
assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
|
||||
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
True,
|
||||
prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
|
||||
decoder_test_params=prephase_dec_test_params,
|
||||
encoder_test_params=enc_test_params,
|
||||
cross_test_params=prephase_cross_test_params,
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
# PREFILL: encoder attention
|
||||
# PREFILL: encoder attention
|
||||
|
||||
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
|
||||
enc_test_params,
|
||||
prephase_attn_metadata)
|
||||
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
|
||||
enc_test_params,
|
||||
prephase_attn_metadata)
|
||||
|
||||
# - Is encoder attention result correct?
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
||||
# - Is encoder attention result correct?
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
||||
|
||||
# PREFILL: decoder self-attention test
|
||||
# PREFILL: decoder self-attention test
|
||||
|
||||
prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||
test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
|
||||
prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||
test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
|
||||
|
||||
# - Is prefill decoder self-attention correct?
|
||||
assert_actual_matches_ideal(prephase_dec_test_params,
|
||||
prephase_dec_pckd_act_out)
|
||||
# - Is prefill decoder self-attention correct?
|
||||
assert_actual_matches_ideal(prephase_dec_test_params,
|
||||
prephase_dec_pckd_act_out)
|
||||
|
||||
# PREFILL: encoder/decoder cross-attention test
|
||||
# PREFILL: encoder/decoder cross-attention test
|
||||
|
||||
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||
test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
|
||||
prephase_attn_metadata)
|
||||
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||
test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
|
||||
prephase_attn_metadata)
|
||||
|
||||
# - Is prefill encoder/decoder cross-attention correct?
|
||||
assert_actual_matches_ideal(prephase_cross_test_params,
|
||||
prephase_cross_pckd_act_out)
|
||||
# - Is prefill encoder/decoder cross-attention correct?
|
||||
assert_actual_matches_ideal(prephase_cross_test_params,
|
||||
prephase_cross_pckd_act_out)
|
||||
|
||||
# DECODE: build decode-phase attention metadata
|
||||
# DECODE: build decode-phase attention metadata
|
||||
|
||||
decphase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
False,
|
||||
dec_qkv.q_seq_lens,
|
||||
decoder_test_params=decphase_dec_test_params,
|
||||
encoder_test_params=enc_test_params,
|
||||
cross_test_params=decphase_cross_test_params,
|
||||
device=CUDA_DEVICE)
|
||||
decphase_attn_metadata: AttentionMetadata = make_test_metadata(
|
||||
test_rsrcs.attn_backend,
|
||||
False,
|
||||
dec_qkv.q_seq_lens,
|
||||
decoder_test_params=decphase_dec_test_params,
|
||||
encoder_test_params=enc_test_params,
|
||||
cross_test_params=decphase_cross_test_params,
|
||||
device=CUDA_DEVICE)
|
||||
|
||||
# DECODE: decoder self-attention test
|
||||
# DECODE: decoder self-attention test
|
||||
|
||||
decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||
test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
|
||||
decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
|
||||
test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
|
||||
|
||||
# - Is decode-phase decoder self-attention correct?
|
||||
assert_actual_matches_ideal(decphase_dec_test_params,
|
||||
decphase_dec_pckd_act_out)
|
||||
# - Is decode-phase decoder self-attention correct?
|
||||
assert_actual_matches_ideal(decphase_dec_test_params,
|
||||
decphase_dec_pckd_act_out)
|
||||
|
||||
# DECODE: encoder/decoder cross-attention test
|
||||
# DECODE: encoder/decoder cross-attention test
|
||||
|
||||
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||
test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
|
||||
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
|
||||
test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
|
||||
|
||||
# - Is decode-phase encoder/decoder cross-attention correct?
|
||||
assert_actual_matches_ideal(decphase_cross_test_params,
|
||||
decphase_cross_pckd_act_out)
|
||||
# - Is decode-phase encoder/decoder cross-attention correct?
|
||||
assert_actual_matches_ideal(decphase_cross_test_params,
|
||||
decphase_cross_pckd_act_out)
|
||||
|
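Editorial note: the test above replaces environment-variable monkeypatching with a global backend override. The sketch below isolates that pattern, assuming only the names visible in this diff (`_Backend.XFORMERS`, `global_force_attn_backend_context_manager`); the context manager is presumably expected to restore normal backend auto-selection on exit.

```python
from vllm.attention.selector import (_Backend,
                                     global_force_attn_backend_context_manager)


def run_with_forced_backend():
    # Everything constructed inside the block sees the forced backend;
    # auto-selection resumes once the block exits.
    with global_force_attn_backend_context_manager(_Backend.XFORMERS):
        ...  # build Attention wrappers / metadata and run the checks here
```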
@@ -211,5 +211,5 @@ def test_varlen_with_paged_kv(
        sliding_window=sliding_window,
        soft_cap=soft_cap,
    )
    assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
    assert torch.allclose(output, ref_output, atol=2e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - ref_output))}"
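Editorial note: for context on the relaxed tolerance above, `torch.allclose` accepts element-wise differences up to `atol + rtol * |ref|`, so doubling `atol` from 1e-2 to 2e-2 loosens only the absolute floor of the bound. A small illustration with made-up numbers:

```python
import torch

ref = torch.tensor([0.0, 1.0])
out = ref + 0.015  # element-wise error of 1.5e-2

# Each |out - ref| must satisfy <= atol + rtol * |ref|.
print(torch.allclose(out, ref, atol=1e-2, rtol=1e-2))  # False: 0.015 > 0.01 + 0.01 * 0.0
print(torch.allclose(out, ref, atol=2e-2, rtol=1e-2))  # True:  0.015 <= 0.02 everywhere
```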
@@ -8,24 +8,10 @@ from typing import Any, List, NamedTuple, Optional, Tuple, Union
import pytest
import torch

from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata, AttentionType)
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.attention.backends.xformers import XFormersBackend
from vllm.utils import make_tensor_with_pad

# String name of register which may be set in order to
# force auto-selection of attention backend by Attention
# wrapper
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"

# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
                        make_tensor_with_pad)


class QKVInputs(NamedTuple):
153	tests/models/test_bart.py	Normal file
@@ -0,0 +1,153 @@
"""Compare the outputs of HF and vLLM for BART models using greedy sampling.

Run `pytest tests/models/test_bart.py`.
"""
from vllm.utils import is_cpu

if not is_cpu():
    # CPU backend is not currently supported with encoder/decoder models
    # skip test definitions entirely to avoid importing GPU kernel libs
    # (xFormers, etc.)

    import pytest

    from tests.models.utils import DecoderPromptType

    from .utils import check_logprobs_close

    MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]

    DECODER_PROMPT_TYPES = ([
        DecoderPromptType.CUSTOM, DecoderPromptType.EMPTY_STR,
        DecoderPromptType.NONE
    ])

    @pytest.mark.parametrize("model", MODELS)
    @pytest.mark.parametrize("dtype", ["float", "bfloat16"])
    @pytest.mark.parametrize("max_tokens", [64])
    @pytest.mark.parametrize("num_logprobs", [5])
    @pytest.mark.parametrize("decoder_prompt_type", DECODER_PROMPT_TYPES)
    def test_models(
        hf_runner,
        vllm_runner,
        example_encoder_decoder_prompts,
        model: str,
        dtype: str,
        max_tokens: int,
        num_logprobs: int,
        decoder_prompt_type: DecoderPromptType,
    ) -> None:
        '''
        Test the vLLM BART model for a variety of encoder/decoder input
        prompts, by validating it against HuggingFace (HF) BART.

        Arguments:

        * hf_runner: HuggingFace (HF) test model runner
        * vllm_runner: vLLM test model runner
        * example_encoder_decoder_prompts: test fixture which provides a
          dictionary of dummy prompts
        * model: the HF ID of the specific BART variant under test
        * dtype: the tensor datatype to employ
        * max_tokens
        * num_logprobs
        * decoder_prompt_type: key into the example_encoder_decoder_prompts
          dictionary; selects specific encoder/decoder prompt scenarios to
          test

        A note on using HF BART as a baseline for validating vLLM BART,
        specifically when the decoder prompt is None.

        The HF GenerationMixin's default behavior is to force the first
        decoded token to be <BOS> if the prompt does not already contain
        <BOS> (this is accomplished using a logit processor setting.)

        So when we use HF BART as our baseline for comparison, note that
        when the user provides a request with a None decoder prompt
        (i.e. a singleton encoder prompt, or else an explicit encoder/
        decoder prompt with the decoder sub-prompt set to None), HF and
        vLLM handle this in different ways:

        * HF will (1) tokenize the None prompt as an empty token-list,
          (2) append <decoder-start-token> to the beginning, yielding
          [<decoder-start-token>], (3) pass this token list to the model, and
          then (4) after computing logits during prefill, override the model
          logits & force <BOS> to be the first generated token.

        * vLLM will (1) tokenize the None prompt as [<BOS>], (2) append
          <decoder-start-token> to the beginning, yielding
          [<decoder-start-token><BOS>], (3) pass these tokens to the model &
          proceed with generation.

        The net effect is that compared to vLLM, the list of HF *decoded*
        tokens will contain one more initial <BOS> than the vLLM generated
        tokens, because vLLM's <BOS> token is injected into the prompt rather
        than into the generated output. This is in spite of the fact that
        overall, the complete sequences (prompt + decoded tokens) produced by
        vLLM will match HF.

        So when we use HF decoded token output to validate vLLM's decoded
        token output, the testing process must account for the difference in
        decoded token sequences between vLLM and HF specifically in the
        decoder-prompt-is-None case.

        One option is to disable the logit processor feature that forces the
        <BOS> token to be decoded (forced_bos_token_id = None), eliminating
        the problem entirely. However this is not "normal" BART usage.

        The other option is - only in the decoder-prompt-is-None case - to
        discard the first decoded token from the HF output before comparing
        it to vLLM.

        To that end, when testing the scenario where the decoder prompt is
        None (and only in that one scenario), this test skips the first HF
        decoded token during the process of validating the vLLM decoded
        output.
        '''

        test_case_prompts = example_encoder_decoder_prompts[
            decoder_prompt_type]

        # Configuration settings for HF baseline
        hf_kwargs = {
            "top_k": None,
            "num_beams": 1,
            "repetition_penalty": 1.0,
            "top_p": 1.0,
            "length_penalty": 1.0,
            "early_stopping": False,
            "no_repeat_ngram_size": None,
            "min_length": 0
        }

        with hf_runner(model, dtype=dtype,
                       is_encoder_decoder_model=True) as hf_model:
            hf_outputs = (
                hf_model.generate_encoder_decoder_greedy_logprobs_limit(
                    test_case_prompts,
                    max_tokens,
                    num_logprobs,
                    **hf_kwargs,
                ))

        # Note: currently encoder/decoder models are only compatible with
        # enforce_eager=True. Normally this is not a problem because for
        # encoder/decoder models vLLM will default to enforce_eager=True if
        # enforce_eager is left unspecified. However, the VllmRunner test
        # fixture (which wraps around the LLM class) defaults to
        # enforce_eager=False (a behavior which a number of already-existing
        # decoder-only unit tests expect), so when testing an encoder/decoder
        # model we must explicitly specify enforce_eager=True in the
        # VllmRunner constructor.
        with vllm_runner(model, dtype=dtype,
                         enforce_eager=True) as vllm_model:
            vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
                test_case_prompts, max_tokens, num_logprobs)

        hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
                          else 0)

        check_logprobs_close(outputs_0_lst=hf_outputs,
                             outputs_1_lst=vllm_outputs,
                             name_0="hf",
                             name_1="vllm",
                             num_outputs_0_skip_tokens=hf_skip_tokens)
@ -1,4 +1,5 @@
import warnings
from enum import Enum
from typing import Dict, List, Optional, Sequence, Tuple, Union

from vllm.sequence import SampleLogprobs
@ -45,11 +46,27 @@ def check_logprobs_close(
    outputs_1_lst: Sequence[TokensTextLogprobs],
    name_0: str,
    name_1: str,
    num_outputs_0_skip_tokens: int = 0,
    warn_on_mismatch: bool = True,
):
    """
    Compare the logprobs of two sequences generated by different models,
    which should be similar but not necessarily equal.

    Arguments:

    * outputs_0_lst: First sequence to compare
    * outputs_1_lst: Second sequence to compare
    * name_0: sequence #0 name
    * name_1: sequence #1 name
    * num_outputs_0_skip_tokens: If > 0, specifies the number of initial
                                 sequence #0 tokens & logprobs to discard
                                 before comparison, i.e. all of sequence #1
                                 will be compared to sequence #0 beginning
                                 at index num_outputs_0_skip_tokens
    * warn_on_mismatch: Issue a warning if there is token-wise or text-wise
                        mismatch between the two sequences
    """
    assert len(outputs_0_lst) == len(outputs_1_lst)

@ -65,6 +82,15 @@ def check_logprobs_close(
        if logprobs_1 is None:
            logprobs_1 = [None] * len(output_ids_1)

        # Skip the specified number of initial sequence #0 tokens
        # & logprobs, leaving output text as-is for simplicity
        # (text mismatches may generate warnings but do not
        # cause the test to fail.)
        if num_outputs_0_skip_tokens < 0:
            raise ValueError("num_outputs_0_skip_tokens must be non-negative")
        output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:]
        logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]

        # Loop through generated tokens.
        for idx, (output_id_0,
                  output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
@ -110,3 +136,13 @@ def check_logprobs_close(
            warnings.simplefilter("always")

            warnings.warn(fail_msg, stacklevel=2)


class DecoderPromptType(Enum):
    '''
    For encoder/decoder models only: the type of decoder prompt
    supplied with a test request (a custom prompt, None, or an
    empty string).
    '''
    CUSTOM = 1
    NONE = 2
    EMPTY_STR = 3

480
tests/worker/test_encoder_decoder_model_runner.py
Normal file
@ -0,0 +1,480 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.utils import is_cpu
|
||||
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
|
||||
|
||||
# CUDA graph scenarios to test
|
||||
#
|
||||
# Currently CUDA graph is not supported
|
||||
ENFORCE_EAGER = [True]
|
||||
|
||||
BATCH_SIZES = [1, 4, 16, 64, 256]
|
||||
|
||||
|
||||
def _create_model_runner(model: str, *args,
|
||||
**kwargs) -> EncoderDecoderModelRunner:
|
||||
engine_args = EngineArgs(model, *args, **kwargs)
|
||||
engine_config = engine_args.create_engine_config()
|
||||
model_runner = EncoderDecoderModelRunner(
|
||||
model_config=engine_config.model_config,
|
||||
parallel_config=engine_config.parallel_config,
|
||||
scheduler_config=engine_config.scheduler_config,
|
||||
device_config=engine_config.device_config,
|
||||
cache_config=engine_config.cache_config,
|
||||
load_config=engine_config.load_config,
|
||||
lora_config=engine_config.lora_config,
|
||||
prompt_adapter_config=engine_config.prompt_adapter_config,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
return model_runner
|
||||
|
||||
|
||||
@pytest.mark.skipif(condition=is_cpu(),
|
||||
reason="CPU backend is currently "
|
||||
"unsupported for encoder/ "
|
||||
"decoder models")
|
||||
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
|
||||
def test_empty_seq_group(enforce_eager, ):
|
||||
"""Verify prepare prompt and decode returns empty output
|
||||
for empty seq group list"""
|
||||
|
||||
model_runner = _create_model_runner(
|
||||
"facebook/bart-base",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
max_num_batched_tokens=100000,
|
||||
max_num_seqs=100000,
|
||||
enable_chunked_prefill=False,
|
||||
enforce_eager=enforce_eager,
|
||||
)
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
model_input = model_runner._prepare_model_input_tensors(
|
||||
seq_group_metadata_list)
|
||||
(
|
||||
input_tokens,
|
||||
input_positions,
|
||||
encoder_input_tokens,
|
||||
encoder_input_positions,
|
||||
attn_metadata,
|
||||
return_seq_lens,
|
||||
) = (
|
||||
model_input.input_tokens,
|
||||
model_input.input_positions,
|
||||
model_input.encoder_input_tokens,
|
||||
model_input.encoder_input_positions,
|
||||
model_input.attn_metadata,
|
||||
model_input.seq_lens,
|
||||
)
|
||||
assert input_tokens is None
|
||||
assert input_positions is None
|
||||
assert encoder_input_tokens is None
|
||||
assert encoder_input_positions is None
|
||||
assert attn_metadata is None
|
||||
assert return_seq_lens is None
|
||||
|
||||
|
||||
@pytest.mark.skipif(condition=is_cpu(),
|
||||
reason="CPU backend is currently "
|
||||
"unsupported for encoder/ "
|
||||
"decoder models")
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
|
||||
def test_prepare_prompt(
|
||||
batch_size,
|
||||
enforce_eager,
|
||||
):
|
||||
'''
|
||||
Test the ability of the encoder/decoder model runner subclass to
|
||||
produce prefill-phase model inputs & attention metadata.
|
||||
|
||||
Test behavior:
|
||||
|
||||
* Instantiate BART base model & enc/dec model runner
|
||||
* Construct sequence-group metadata for dummy prompts
|
||||
* Test that encoder attention, decoder self-attention,
|
||||
and encoder/decoder cross-attention inputs are correct
|
||||
|
||||
Arguments:
|
||||
|
||||
* batch_size
|
||||
* backend_name: The attention backend under test
|
||||
* enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
|
||||
'''
|
||||
|
||||
model_runner = _create_model_runner(
|
||||
"facebook/bart-base",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
max_num_batched_tokens=100000,
|
||||
max_num_seqs=100000,
|
||||
enable_chunked_prefill=False,
|
||||
enforce_eager=enforce_eager,
|
||||
)
|
||||
|
||||
seq_lens: List[int] = []
|
||||
encoder_seq_lens: List[int] = []
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
block_tables = {0: [1]}
|
||||
cross_block_table = [2]
|
||||
for i in range(batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
seq_len = i % (model_runner.block_size - 1) + 1
|
||||
seq_lens.append(seq_len)
|
||||
seq_data = SequenceData(list(range(seq_len)))
|
||||
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
|
||||
encoder_seq_lens.append(encoder_seq_len)
|
||||
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: seq_data},
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables=block_tables,
|
||||
encoder_seq_data=encoder_seq_data,
|
||||
cross_block_table=cross_block_table,
|
||||
)
|
||||
assert seq_group_metadata.token_chunk_size == seq_data.get_len()
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
# Build
|
||||
# * Decoder model inputs
|
||||
# * Decoder self-attention KV caching data structures
|
||||
# * Encoder model inputs
|
||||
# * Encoder/decoder cross-attention KV caching data structures
|
||||
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
|
||||
|
||||
input_tokens = model_input.input_tokens
|
||||
input_positions = model_input.input_positions
|
||||
attn_metadata = model_input.attn_metadata
|
||||
return_seq_lens = model_input.seq_lens
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
encoder_input_tokens = model_input.encoder_input_tokens
|
||||
encoder_input_positions = model_input.encoder_input_positions
|
||||
cross_slot_mapping = attn_metadata.cross_slot_mapping
|
||||
assert return_seq_lens == seq_lens
|
||||
assert len(slot_mapping) == len(input_tokens)
|
||||
assert len(cross_slot_mapping) == len(encoder_input_tokens)
|
||||
|
||||
# Verify input metadata is correct for prompts.
|
||||
# - Decoder attention metadata
|
||||
device = model_runner.device
|
||||
assert attn_metadata.num_prefills > 0
|
||||
assert attn_metadata.num_decode_tokens == 0
|
||||
assert torch.equal(attn_metadata.seq_lens_tensor,
|
||||
torch.tensor(seq_lens, device=device, dtype=torch.int))
|
||||
assert attn_metadata.seq_lens == seq_lens
|
||||
assert attn_metadata.max_prefill_seq_len == max(seq_lens)
|
||||
assert attn_metadata.max_decode_seq_len == 0
|
||||
# - Encoder attention metadata
|
||||
assert attn_metadata.encoder_seq_lens == encoder_seq_lens
|
||||
assert torch.equal(
|
||||
attn_metadata.encoder_seq_lens_tensor,
|
||||
torch.tensor(encoder_seq_lens, device=device, dtype=torch.int))
|
||||
assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
|
||||
assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)
|
||||
|
||||
# Test decoder subquery start locs.
|
||||
start_idx = 0
|
||||
start_loc = [start_idx]
|
||||
for seq_len in seq_lens:
|
||||
start_idx += seq_len
|
||||
start_loc.append(start_idx)
|
||||
assert torch.equal(
|
||||
attn_metadata.query_start_loc,
|
||||
torch.tensor(start_loc, dtype=torch.int32, device=device),
|
||||
)
|
||||
|
||||
# Test decoder seq start locs & context lengths
|
||||
|
||||
assert torch.equal(
|
||||
attn_metadata.seq_start_loc,
|
||||
torch.tensor(start_loc, dtype=torch.int32, device=device),
|
||||
)
|
||||
assert torch.equal(
|
||||
attn_metadata.context_lens_tensor,
|
||||
torch.zeros(attn_metadata.context_lens_tensor.shape[0],
|
||||
dtype=torch.int,
|
||||
device=device),
|
||||
)
|
||||
|
||||
# Verify block tables are correct for prompts
|
||||
# - Decoder self-attention
|
||||
expected = torch.tensor(
|
||||
[[] for _ in range(len(seq_group_metadata_list))],
|
||||
dtype=torch.int32,
|
||||
device=model_runner.device,
|
||||
)
|
||||
assert torch.equal(
|
||||
attn_metadata.block_tables,
|
||||
expected,
|
||||
)
|
||||
# - Encoder/decoder cross-attention
|
||||
assert torch.equal(
|
||||
attn_metadata.cross_block_tables,
|
||||
expected,
|
||||
)
|
||||
|
||||
# Cuda graph should not be used for prefill.
|
||||
assert attn_metadata.use_cuda_graph is False
|
||||
|
||||
# Verify the lengths of input tokens & positions
|
||||
# - Decoder
|
||||
assert len(input_tokens) == sum(seq_lens)
|
||||
assert len(input_positions) == sum(seq_lens)
|
||||
# -- An indirect check that model_input.input_tokens
|
||||
# and model_input.input_positions are correct -
|
||||
# by design of the test, the input tokens are
|
||||
# equal to the input position values, so if
|
||||
# the model_input data structure has the correct
|
||||
# values then these two should be equal
|
||||
assert torch.equal(
|
||||
input_tokens,
|
||||
input_positions,
|
||||
)
|
||||
# - Encoder
|
||||
assert len(encoder_input_tokens) == sum(encoder_seq_lens)
|
||||
# -- An indirect check that model_input.encoder_input_tokens
|
||||
# and model_input.encoder_input_positions are correct -
|
||||
# by design of the test, the input tokens are
|
||||
# equal to the input position values, so if
|
||||
# the model_input data structure has the correct
|
||||
# values then these two should be equal
|
||||
assert torch.equal(
|
||||
encoder_input_tokens,
|
||||
encoder_input_positions,
|
||||
)
|
||||
|
||||
# Test that vLLM sampling infrastructure chooses the correct
|
||||
# sequence positions at which to sample (i.e. the end of
|
||||
# each sequence) in the prefill phase
|
||||
|
||||
expected_selected_token_indices = []
|
||||
selected_token_start_idx = 0
|
||||
for seq_len in seq_lens:
|
||||
# Compute the index offset of the final token in each
|
||||
# prompt (recall that the prompts are concatenated)
|
||||
expected_selected_token_indices.append(selected_token_start_idx +
|
||||
seq_len - 1)
|
||||
selected_token_start_idx += seq_len
|
||||
|
||||
sampling_metadata = model_input.sampling_metadata
|
||||
actual = sampling_metadata.selected_token_indices
|
||||
expected = torch.tensor(
|
||||
expected_selected_token_indices,
|
||||
device=actual.device,
|
||||
dtype=actual.dtype,
|
||||
)
|
||||
assert torch.equal(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.skipif(condition=is_cpu(),
|
||||
reason="CPU backend is currently "
|
||||
"unsupported for encoder/ "
|
||||
"decoder models")
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
|
||||
def test_prepare_decode(
|
||||
batch_size,
|
||||
enforce_eager,
|
||||
):
|
||||
'''
|
||||
Test the ability of the encoder/decoder model runner subclass to
|
||||
produce decode-phase model inputs & attention metadata.
|
||||
|
||||
Test behavior:
|
||||
|
||||
* Instantiate BART base model & enc/dec model runner
|
||||
* Construct sequence-group metadata for dummy prompts
|
||||
* Test that encoder attention, decoder self-attention,
|
||||
and encoder/decoder cross-attention inputs are correct
|
||||
|
||||
Arguments:
|
||||
|
||||
* batch_size
|
||||
* backend_name: The attention backend under test
|
||||
* enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
|
||||
'''
|
||||
|
||||
model_runner = _create_model_runner(
|
||||
"facebook/bart-base",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
max_num_batched_tokens=100000,
|
||||
max_num_seqs=100000,
|
||||
enable_chunked_prefill=False,
|
||||
enforce_eager=enforce_eager,
|
||||
)
|
||||
|
||||
seq_lens: List[int] = []
|
||||
encoder_seq_lens: List[int] = []
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
block_tables = {0: [1]}
|
||||
cross_block_table = [2]
|
||||
for i in range(batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
seq_len = i % (model_runner.block_size - 1) + 1
|
||||
seq_lens.append(seq_len)
|
||||
seq_data = SequenceData(list(range(seq_len)))
|
||||
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
|
||||
encoder_seq_lens.append(encoder_seq_len)
|
||||
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=False,
|
||||
seq_data={0: seq_data},
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables=block_tables,
|
||||
encoder_seq_data=encoder_seq_data,
|
||||
cross_block_table=cross_block_table,
|
||||
)
|
||||
assert seq_group_metadata.token_chunk_size == 1
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
# Build
|
||||
# * Decoder model inputs
|
||||
# * Decoder self-attention KV caching data structures
|
||||
# * Encoder model inputs
|
||||
# * Encoder/decoder cross-attention KV caching data structures
|
||||
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
|
||||
input_tokens = model_input.input_tokens
|
||||
input_positions = model_input.input_positions
|
||||
attn_metadata = model_input.attn_metadata
|
||||
return_seq_lens = model_input.seq_lens
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
encoder_input_tokens = model_input.encoder_input_tokens
|
||||
encoder_input_positions = model_input.encoder_input_positions
|
||||
cross_slot_mapping = attn_metadata.cross_slot_mapping
|
||||
assert return_seq_lens == seq_lens
|
||||
assert len(slot_mapping) == len(input_tokens)
|
||||
assert len(cross_slot_mapping) == len(encoder_input_tokens)
|
||||
|
||||
# Verify input metadata is correct for decode phase.
|
||||
# - Decoder attention metadata
|
||||
device = model_runner.device
|
||||
assert attn_metadata.num_prefills == 0
|
||||
assert attn_metadata.num_decode_tokens > 0
|
||||
assert torch.equal(attn_metadata.seq_lens_tensor,
|
||||
torch.tensor(seq_lens, device=device, dtype=torch.int))
|
||||
assert attn_metadata.seq_lens == seq_lens
|
||||
assert attn_metadata.max_prefill_seq_len == 0
|
||||
assert attn_metadata.max_decode_seq_len == max(seq_lens)
|
||||
# - Encoder attention metadata
|
||||
assert attn_metadata.encoder_seq_lens == encoder_seq_lens
|
||||
assert torch.equal(
|
||||
attn_metadata.encoder_seq_lens_tensor,
|
||||
torch.tensor(encoder_seq_lens, device=device, dtype=torch.int))
|
||||
assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
|
||||
assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)
|
||||
|
||||
# Test decoder subquery start locs.
|
||||
start_idx = 0
|
||||
start_loc = [start_idx]
|
||||
for seq_len in seq_lens:
|
||||
start_idx += 1
|
||||
start_loc.append(start_idx)
|
||||
assert torch.equal(
|
||||
attn_metadata.query_start_loc,
|
||||
torch.tensor(start_loc, dtype=torch.int32, device=device),
|
||||
)
|
||||
|
||||
# Test decoder seq start locs. Note that for normal prefill it is
|
||||
# equivalent to query_start_loc.
|
||||
start_idx = 0
|
||||
seq_start_loc = [start_idx]
|
||||
for seq_len in seq_lens:
|
||||
start_idx += seq_len
|
||||
seq_start_loc.append(start_idx)
|
||||
|
||||
# Test seq_start_loc and context lengths
|
||||
|
||||
assert torch.equal(
|
||||
attn_metadata.seq_start_loc,
|
||||
torch.tensor(seq_start_loc, dtype=torch.int32, device=device),
|
||||
)
|
||||
assert torch.equal(
|
||||
attn_metadata.context_lens_tensor,
|
||||
torch.tensor([seq_len - 1 for seq_len in seq_lens],
|
||||
dtype=torch.int,
|
||||
device=device))
|
||||
|
||||
# Verify block tables are correct for prompts
|
||||
# - Decoder self-attention
|
||||
expected = torch.tensor(
|
||||
[block_tables[0] for _ in range(len(seq_group_metadata_list))],
|
||||
dtype=torch.int32,
|
||||
device=model_runner.device)
|
||||
assert torch.equal(
|
||||
attn_metadata.block_tables,
|
||||
expected,
|
||||
)
|
||||
# - Encoder/decoder cross-attention
|
||||
expected = torch.tensor(
|
||||
[cross_block_table for _ in range(len(seq_group_metadata_list))],
|
||||
dtype=torch.int32,
|
||||
device=model_runner.device)
|
||||
assert torch.equal(
|
||||
attn_metadata.cross_block_tables,
|
||||
expected,
|
||||
)
|
||||
|
||||
# CUDA graph is currently not supported for encoder/decoder models.
|
||||
assert attn_metadata.use_cuda_graph is False
|
||||
|
||||
# Verify the lengths of input tokens & positions
|
||||
# - Decoder
|
||||
assert len(input_tokens) == len(seq_lens)
|
||||
assert len(input_positions) == len(seq_lens)
|
||||
# -- An indirect check that model_input.input_tokens
|
||||
# and model_input.input_positions are correct -
|
||||
# by design of the test, the input tokens are
|
||||
# equal to the input position values, so if
|
||||
# the model_input data structure has the correct
|
||||
# values then these two should be equal
|
||||
assert torch.equal(
|
||||
input_tokens,
|
||||
input_positions,
|
||||
)
|
||||
# - Encoder
|
||||
assert len(encoder_input_tokens) == 0
|
||||
assert len(encoder_input_positions) == 0
|
||||
# -- An indirect check that model_input.encoder_input_tokens
|
||||
# and model_input.encoder_input_positions are correct -
|
||||
# by design of the test, the input tokens are
|
||||
# equal to the input position values, so if
|
||||
# the model_input data structure has the correct
|
||||
# values then these two should be equal
|
||||
assert torch.equal(
|
||||
encoder_input_tokens,
|
||||
encoder_input_positions,
|
||||
)
|
||||
|
||||
# Test that vLLM sampling infrastructure chooses the correct
|
||||
# sequence positions at which to sample (i.e. the end of
|
||||
# each sequence) in the decode phase
|
||||
|
||||
expected_selected_token_indices = []
|
||||
selected_token_start_idx = 0
|
||||
for seq_len in seq_lens:
|
||||
# Compute the index offset of the final token in each
|
||||
# sequence's decoded outputs; since a single token is
|
||||
# decoded per iteration per sequence, then the length
|
||||
# of the decoded tokens for a given sequence is 1 and
|
||||
# the final index offset into a given sequence's
|
||||
# generated tokens is 0 (i.e. the expected sampling index
|
||||
# for a given sequence is just `selected_token_start_idx`)
|
||||
expected_selected_token_indices.append(selected_token_start_idx)
|
||||
selected_token_start_idx += 1
|
||||
|
||||
sampling_metadata = model_input.sampling_metadata
|
||||
actual = sampling_metadata.selected_token_indices
|
||||
expected = torch.tensor(
|
||||
expected_selected_token_indices,
|
||||
device=actual.device,
|
||||
dtype=actual.dtype,
|
||||
)
|
||||
assert torch.equal(actual, expected)
|
@ -1,6 +1,7 @@
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder)
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
|
||||
@ -8,6 +9,7 @@ __all__ = [
|
||||
"Attention",
|
||||
"AttentionBackend",
|
||||
"AttentionMetadata",
|
||||
"AttentionType",
|
||||
"AttentionMetadataBuilder",
|
||||
"Attention",
|
||||
"get_attn_backend",
|
||||
|
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
|
||||
from vllm.attention import AttentionMetadata, AttentionType
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
|
@ -1,6 +1,8 @@
|
||||
import enum
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Type
|
||||
from typing import Generator, Optional, Type
|
||||
|
||||
import torch
|
||||
|
||||
@ -8,7 +10,8 @@ import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
|
||||
from vllm.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino,
|
||||
is_tpu, is_xpu)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -24,6 +27,66 @@ class _Backend(enum.Enum):
|
||||
IPEX = enum.auto()
|
||||
|
||||
|
||||
def backend_name_to_enum(backend_name: str) -> _Backend:
|
||||
assert backend_name is not None
|
||||
|
||||
backend_members = _Backend.__members__
|
||||
if backend_name not in backend_members:
|
||||
raise ValueError(f"Invalid attention backend '{backend_name}'. "
|
||||
f"Available backends: {', '.join(backend_members)} "
|
||||
"(case-sensitive).")
|
||||
|
||||
return _Backend[backend_name]
|
||||
|
||||
|
||||
def get_env_variable_attn_backend() -> Optional[_Backend]:
|
||||
'''
|
||||
Get the backend override specified by the vLLM attention
|
||||
backend environment variable, if one is specified.
|
||||
|
||||
Returns:
|
||||
|
||||
* _Backend enum value if an override is specified
|
||||
* None otherwise
|
||||
'''
|
||||
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
|
||||
return (None
|
||||
if backend_name is None else backend_name_to_enum(backend_name))
|
||||
|
||||
|
||||
# Global state allows a particular choice of backend
|
||||
# to be forced, overriding the logic which auto-selects
|
||||
# a backend based on system & workload configuration
|
||||
# (default behavior if this variable is None)
|
||||
#
|
||||
# THIS SELECTION TAKES PRECEDENCE OVER THE
|
||||
# VLLM ATTENTION BACKEND ENVIRONMENT VARIABLE
|
||||
forced_attn_backend: Optional[_Backend] = None
|
||||
|
||||
|
||||
def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
|
||||
'''
|
||||
Force all attention operations to use a specified backend.
|
||||
|
||||
Passing `None` for the argument re-enables automatic
|
||||
backend selection.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_backend: backend selection (None to revert to auto)
|
||||
'''
|
||||
global forced_attn_backend
|
||||
forced_attn_backend = attn_backend
|
||||
|
||||
|
||||
def get_global_forced_attn_backend() -> Optional[_Backend]:
|
||||
'''
|
||||
Get the currently-forced choice of attention backend,
|
||||
or None if auto-selection is currently enabled.
|
||||
'''
|
||||
return forced_attn_backend
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_attn_backend(
|
||||
num_heads: int,
|
||||
@ -101,16 +164,20 @@ def which_attn_to_use(
|
||||
# Default case.
|
||||
selected_backend = _Backend.FLASH_ATTN
|
||||
|
||||
# Check the environment variable and override if specified
|
||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||
if backend_by_env_var is not None:
|
||||
backend_members = _Backend.__members__
|
||||
if backend_by_env_var not in backend_members:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend '{backend_by_env_var}'. "
|
||||
f"Available backends: {', '.join(backend_members)} "
|
||||
"(case-sensitive).")
|
||||
selected_backend = _Backend[backend_by_env_var]
|
||||
# Check whether a particular choice of backend was
|
||||
# previously forced.
|
||||
#
|
||||
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
|
||||
# ENVIRONMENT VARIABLE.
|
||||
backend_by_global_setting: Optional[_Backend] = (
|
||||
get_global_forced_attn_backend())
|
||||
if backend_by_global_setting is not None:
|
||||
selected_backend = backend_by_global_setting
|
||||
else:
|
||||
# Check the environment variable and override if specified
|
||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||
if backend_by_env_var is not None:
|
||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||
|
||||
if is_cpu():
|
||||
if selected_backend != _Backend.TORCH_SDPA:
|
||||
@ -193,3 +260,35 @@ def which_attn_to_use(
|
||||
selected_backend = _Backend.XFORMERS
|
||||
|
||||
return selected_backend
|
||||
|
||||
|
||||
@contextmanager
def global_force_attn_backend_context_manager(
        attn_backend: _Backend) -> Generator[None, None, None]:
    '''
    Globally force a vLLM attention backend override within a
    context manager, reverting the global attention backend
    override to its prior state upon exiting the context
    manager.

    Arguments:

    * attn_backend: attention backend to force

    Returns:

    * Generator
    '''

    # Save the current state of the global backend override (if any)
    original_value = get_global_forced_attn_backend()

    # Globally force the new backend override
    global_force_attn_backend(attn_backend)

    # Yield control back to the enclosed code block
    try:
        yield
    finally:
        # Revert the original global backend override, if any
        global_force_attn_backend(original_value)

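# Usage sketch (an illustrative assumption, not part of this diff): force
# the XFORMERS backend for everything constructed inside the block, then
# restore the prior override (or auto-selection) on exit.
#
#     from vllm.attention.selector import (
#         _Backend, global_force_attn_backend_context_manager)
#
#     with global_force_attn_backend_context_manager(_Backend.XFORMERS):
#         ...  # attention backend selection here resolves to XFORMERS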
@ -12,7 +12,8 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.tracing import is_otel_installed
|
||||
from vllm.transformers_utils.config import get_config, get_hf_text_config
|
||||
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
|
||||
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
|
||||
cuda_device_count_stateless, get_cpu_memory, is_cpu,
|
||||
is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
|
||||
print_warning_once)
|
||||
|
||||
@ -87,6 +88,9 @@ class ModelConfig:
|
||||
enforce_eager: Whether to enforce eager execution. If True, we will
|
||||
disable CUDA graph and always execute the model in eager mode.
|
||||
If False, we will use CUDA graph and eager execution in hybrid.
|
||||
If None, the user did not specify, so default to False -
|
||||
except for encoder/decoder models, which currently require
|
||||
eager mode.
|
||||
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
||||
When a sequence has context length larger than this, we fall back
|
||||
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
|
||||
@ -121,7 +125,7 @@ class ModelConfig:
|
||||
max_model_len: Optional[int] = None,
|
||||
quantization: Optional[str] = None,
|
||||
quantization_param_path: Optional[str] = None,
|
||||
enforce_eager: bool = False,
|
||||
enforce_eager: Optional[bool] = None,
|
||||
max_context_len_to_capture: Optional[int] = None,
|
||||
max_seq_len_to_capture: Optional[int] = None,
|
||||
max_logprobs: int = 20,
|
||||
@ -160,6 +164,34 @@ class ModelConfig:
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||
|
||||
# Choose a default enforce_eager value if the user did not specify
|
||||
# a value (enforce_eager is None)
|
||||
if getattr(self.hf_config, 'is_encoder_decoder', False):
|
||||
if self.enforce_eager is None:
|
||||
# *Only for encoder/decoder models* and
|
||||
# *only if enforce_eager is unset*, override
|
||||
# to enforce_eager=True
|
||||
#
|
||||
# Add a logger message since it is *somewhat* non-intuitive that
|
||||
# enforce_eager is True when the user has not specified its
|
||||
# value.
|
||||
logger.info("Forcing enforce_eager == True because "
|
||||
"enforce_eager setting was unspecified and "
|
||||
"CUDAGraph is not supported with encoder/ "
|
||||
"decoder models.")
|
||||
self.enforce_eager = True
|
||||
|
||||
if not self.enforce_eager:
|
||||
# Eager mode explicitly disabled by user for an encoder/
|
||||
# decoder model; however CUDAGRAPH + encoder/decoder is
|
||||
# not currently supported
|
||||
raise ValueError(STR_NOT_IMPL_ENC_DEC_CUDAGRAPH)
|
||||
elif self.enforce_eager is None:
|
||||
# *Only for decoder-only models*, enforce_eager
|
||||
# defaults to False if unset. This is intuitive
|
||||
# so no logging message needed.
|
||||
self.enforce_eager = False
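# Summary of the resulting defaults (a sketch of the logic above):
#   decoder-only model,    enforce_eager unset (None) -> False
#   encoder/decoder model, enforce_eager unset (None) -> True (with log message)
#   encoder/decoder model, enforce_eager=False        -> ValueError (CUDAGraph
#                                                        not yet supported)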
|
||||
|
||||
if (not self.disable_sliding_window
|
||||
and self.hf_text_config.model_type == "gemma2"
|
||||
and self.hf_text_config.sliding_window is not None):
|
||||
|
@ -1,15 +1,7 @@
|
||||
"""Block manager utils."""
|
||||
from vllm.sequence import SequenceGroup
|
||||
|
||||
# Exception strings for non-implemented block manager enc/dec scenarios
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_SWA = \
|
||||
"Sliding window attention for encoder/decoder models " + \
|
||||
"is not currently supported."
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
|
||||
"Prefix caching for encoder/decoder models " + \
|
||||
"is not currently supported."
|
||||
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
|
||||
STR_NOT_IMPL_ENC_DEC_SWA)
|
||||
|
||||
|
||||
def _get_block_mgr_sliding_window_attr(block_mgr):
|
||||
|
@ -392,6 +392,19 @@ class Scheduler:
|
||||
seq.status = SequenceStatus.FINISHED_ABORTED
|
||||
self.free_seq(seq)
|
||||
|
||||
self._free_seq_group_cross_attn_blocks(aborted_group)
|
||||
|
||||
def _free_seq_group_cross_attn_blocks(
|
||||
self,
|
||||
seq_group: SequenceGroup,
|
||||
) -> None:
|
||||
"""
|
||||
Free a sequence group from a cross-attention block table.
|
||||
Has no effect on decoder-only models.
|
||||
"""
|
||||
if seq_group.is_encoder_decoder():
|
||||
self.block_manager.free_cross(seq_group)
|
||||
|
||||
def has_unfinished_seqs(self) -> bool:
|
||||
return len(self.waiting) != 0 or len(self.running) != 0 or len(
|
||||
self.swapped) != 0
|
||||
@ -963,6 +976,17 @@ class Scheduler:
|
||||
# seq_id -> physical block numbers
|
||||
block_tables: Dict[int, List[int]] = {}
|
||||
|
||||
if seq_group.is_encoder_decoder():
|
||||
# Encoder associated with SequenceGroup
|
||||
encoder_seq_data = seq_group.get_encoder_seq().data
|
||||
# Block table for cross-attention
|
||||
# Also managed at SequenceGroup level
|
||||
cross_block_table = self.block_manager.get_cross_block_table(
|
||||
seq_group)
|
||||
else:
|
||||
encoder_seq_data = None
|
||||
cross_block_table = None
|
||||
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
seq_id = seq.seq_id
|
||||
seq_data[seq_id] = seq.data
|
||||
@ -1001,6 +1025,8 @@ class Scheduler:
|
||||
token_chunk_size=token_chunk_size,
|
||||
lora_request=seq_group.lora_request,
|
||||
computed_block_nums=common_computed_block_nums,
|
||||
encoder_seq_data=encoder_seq_data,
|
||||
cross_block_table=cross_block_table,
|
||||
# `multi_modal_data` will only be present for the 1st comm
|
||||
# between engine and worker.
|
||||
# the subsequent comms can still use delta, but
|
||||
@ -1032,6 +1058,8 @@ class Scheduler:
|
||||
remaining: Deque[SequenceGroup] = deque()
|
||||
for seq_group in self.running:
|
||||
if seq_group.is_finished():
|
||||
# Free cross-attention block table, if it exists
|
||||
self._free_seq_group_cross_attn_blocks(seq_group)
|
||||
# Add the finished requests to the finished requests list.
|
||||
# This list will be used to update the Mamba cache in the
|
||||
# next step.
|
||||
|
@ -69,7 +69,7 @@ class EngineArgs:
|
||||
rope_theta: Optional[float] = None
|
||||
tokenizer_revision: Optional[str] = None
|
||||
quantization: Optional[str] = None
|
||||
enforce_eager: bool = False
|
||||
enforce_eager: Optional[bool] = None
|
||||
max_context_len_to_capture: Optional[int] = None
|
||||
max_seq_len_to_capture: int = 8192
|
||||
disable_custom_all_reduce: bool = False
|
||||
|
@ -3,7 +3,7 @@ from contextlib import contextmanager
|
||||
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
|
||||
Mapping, Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set, Type, TypeVar, Union
|
||||
from typing import Set, Tuple, Type, TypeVar, Union
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
|
||||
@ -22,7 +22,8 @@ from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.engine.output_processor.util import create_output_by_sequence_group
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs
|
||||
from vllm.inputs import (INPUT_REGISTRY, LLMInputs, PromptInputs,
|
||||
get_prompt_type)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
|
||||
@ -42,7 +43,8 @@ from vllm.transformers_utils.tokenizer_group import (
|
||||
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
|
||||
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||
usage_message)
|
||||
from vllm.utils import Counter
|
||||
from vllm.utils import (Counter, is_embedding_model_config,
|
||||
is_encoder_decoder_model_config)
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -502,8 +504,19 @@ class LLMEngine:
|
||||
self.prompt_adapter_config.verify_with_model_config(
|
||||
self.model_config)
|
||||
|
||||
def _get_eos_token_id(
|
||||
self, lora_request: Optional[LoRARequest]) -> Optional[int]:
|
||||
def _get_bos_token_id(self,
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
) -> Optional[int]:
|
||||
if self.tokenizer is None:
|
||||
logger.warning("Using None for BOS token id because tokenizer "
|
||||
"is not initialized")
|
||||
return None
|
||||
|
||||
return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
|
||||
|
||||
def _get_eos_token_id(self,
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
) -> Optional[int]:
|
||||
if self.tokenizer is None:
|
||||
logger.warning("Using None for EOS token id because tokenizer "
|
||||
"is not initialized")
|
||||
@ -511,6 +524,32 @@ class LLMEngine:
|
||||
|
||||
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
|
||||
|
||||
def _get_decoder_start_token_id(self, ) -> Optional[int]:
|
||||
'''
|
||||
Obtain the decoder start token id employed by an encoder/decoder
|
||||
model. Returns None for non-encoder/decoder models or if the
|
||||
model config is unavailable.
|
||||
'''
|
||||
|
||||
if not self.is_encoder_decoder_model():
|
||||
logger.warning("Using None for decoder start token id because "
|
||||
"this is not an encoder/decoder model.")
|
||||
return None
|
||||
|
||||
if (self.model_config is None or self.model_config.hf_config is None):
|
||||
logger.warning("Using None for decoder start token id because "
|
||||
"model config is not available.")
|
||||
return None
|
||||
|
||||
dec_start_token_id = getattr(self.model_config.hf_config,
|
||||
'decoder_start_token_id', None)
|
||||
if dec_start_token_id is None:
|
||||
logger.warning("Falling back on <BOS> for decoder start token id "
|
||||
"because decoder start token id is not available.")
|
||||
dec_start_token_id = self._get_bos_token_id()
|
||||
|
||||
return dec_start_token_id
|
||||
|
||||
def _add_processed_request(
|
||||
self,
|
||||
request_id: str,
|
||||
@ -529,6 +568,16 @@ class LLMEngine:
|
||||
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
|
||||
lora_request, prompt_adapter_request)
|
||||
|
||||
encoder_seq = None
|
||||
if 'encoder_prompt_token_ids' in processed_inputs:
|
||||
encoder_seq = Sequence(seq_id,
|
||||
processed_inputs,
|
||||
block_size,
|
||||
eos_token_id,
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
from_decoder_prompt=False)
|
||||
|
||||
# Create a SequenceGroup based on SamplingParams or PoolingParams
|
||||
if isinstance(params, SamplingParams):
|
||||
seq_group = self._create_sequence_group_with_sampling(
|
||||
@ -538,7 +587,8 @@ class LLMEngine:
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
encoder_seq=encoder_seq)
|
||||
elif isinstance(params, PoolingParams):
|
||||
seq_group = self._create_sequence_group_with_pooling(
|
||||
request_id,
|
||||
@ -546,7 +596,8 @@ class LLMEngine:
|
||||
params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
encoder_seq=encoder_seq)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either SamplingParams or PoolingParams must be provided.")
|
||||
@ -562,6 +613,336 @@ class LLMEngine:
|
||||
def stop_remote_worker_execution_loop(self) -> None:
|
||||
self.model_executor.stop_remote_worker_execution_loop()
|
||||
|
||||
_LLMInputComponentsType = Tuple[str, List[int], ]
|
||||
|
||||
def _prepare_decoder_input_ids_for_generation(
|
||||
self,
|
||||
decoder_input_ids: Optional[List[int]] = None,
|
||||
) -> List[int]:
|
||||
"""
|
||||
Prepares `decoder_input_ids` for generation with encoder-decoder models.
|
||||
|
||||
Based on
|
||||
|
||||
https://github.com/huggingface/transformers/blob/
|
||||
4037a2b5b1278736e566aec12e169100275545ea/
|
||||
src/transformers/generation/utils.py
|
||||
|
||||
specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
|
||||
|
||||
Arguments:
|
||||
|
||||
* decoder_input_ids: input token ids to preprocess
|
||||
|
||||
Returns:
|
||||
|
||||
* Processed token list
|
||||
"""
|
||||
|
||||
decoder_start_token_id: Optional[int] = (
|
||||
self._get_decoder_start_token_id())
|
||||
assert decoder_start_token_id is not None
|
||||
|
||||
if decoder_input_ids is None:
|
||||
# no decoder prompt input ->
|
||||
# use decoder_start_token_id as decoder_input_ids
|
||||
(decoder_input_ids) = self._get_default_enc_dec_decoder_prompt()
|
||||
|
||||
if (len(decoder_input_ids) == 0
|
||||
or decoder_input_ids[0] != decoder_start_token_id):
|
||||
decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
|
||||
|
||||
return decoder_input_ids
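# Worked example (hypothetical token ids, assuming
# decoder_start_token_id == 2 and <BOS> == 0):
#   decoder_input_ids=None      -> default prompt [0] -> returns [2, 0]
#   decoder_input_ids=[5, 6]    -> returns [2, 5, 6]
#   decoder_input_ids=[2, 5, 6] -> returned unchanged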
|
||||
|
||||
def _tokenize_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
request_id: Optional[str] = None,
|
||||
lora_request: Optional[str] = None,
|
||||
) -> List[int]:
|
||||
'''
|
||||
Wrapper around application of the model's
|
||||
tokenizer.
|
||||
|
||||
Arguments:
|
||||
|
||||
* prompt
|
||||
* request_id
|
||||
* lora_request
|
||||
|
||||
Returns:
|
||||
|
||||
* prompt token ids
|
||||
'''
|
||||
|
||||
tokenizer = self.get_tokenizer_group("prompts must be None if "
|
||||
"skip_tokenizer_init is True")
|
||||
|
||||
prompt_token_ids = tokenizer.encode(request_id=request_id,
|
||||
prompt=prompt,
|
||||
lora_request=lora_request)
|
||||
|
||||
return prompt_token_ids
|
||||
|
||||
def _extract_single_prompt_for_enc_dec_input(
|
||||
self,
|
||||
inputs: Optional[PromptInputs],
|
||||
request_id: Optional[str] = None,
|
||||
ptype: Optional[str] = None,
|
||||
is_encoder_prompt: bool = False,
|
||||
) -> Tuple[Optional[str], List[int]]:
|
||||
'''
|
||||
Only for encoder/decoder models:
|
||||
Extract prompt & prompt_token_ids from any single
|
||||
encoder or decoder input prompt. For encoder input prompts
|
||||
in particular, also extract multi-modal data.
|
||||
|
||||
This function handles the following scenarios:
|
||||
1. The user supplied a singleton encoder prompt
|
||||
& the prompt/prompt-token-ids must be extracted.
|
||||
2. The user supplied an explicit encoder/decoder
|
||||
prompt & the prompt/prompt-token-ids must be
|
||||
extracted from either the encoder and decoder prompts.
|
||||
|
||||
For decoder prompts in particular (scenario 2), special
|
||||
processing is applied to the returned decoder token ids.
|
||||
|
||||
Arguments:
|
||||
|
||||
* request_id
|
||||
* ptype: str representation of the input prompt type.
|
||||
If `ptype` is `None`, assume that the prompt
|
||||
type is unknown and must be inferred. This is the
|
||||
case for ExplicitEncoderDecoder sub-prompts.
|
||||
* inputs: single encoder or decoder input prompt
|
||||
* is_encoder_prompt: True if encoder input prompt.
|
||||
If False, decoder prompt tokens
|
||||
are preprocessed.
|
||||
|
||||
Returns:
|
||||
|
||||
* prompt
|
||||
* prompt_token_ids
|
||||
'''
|
||||
prompt_token_ids = None
|
||||
ptype = (get_prompt_type(inputs) if ptype is None else ptype)
|
||||
|
||||
if inputs is None:
|
||||
prompt = None
|
||||
elif ptype == 'str':
|
||||
prompt = inputs
|
||||
prompt_token_ids = self._tokenize_prompt(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
)
|
||||
elif ptype == 'TokensPrompt':
|
||||
prompt = None
|
||||
prompt_token_ids = inputs['prompt_token_ids']
|
||||
else:
|
||||
prompt = inputs['prompt']
|
||||
prompt_token_ids = self._tokenize_prompt(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
if not is_encoder_prompt:
|
||||
# Apply special pre-processing to
|
||||
# decoder prompts
|
||||
prompt_token_ids = (self._prepare_decoder_input_ids_for_generation(
|
||||
prompt_token_ids, ))
|
||||
|
||||
assert prompt_token_ids is not None
|
||||
|
||||
return (
|
||||
prompt,
|
||||
prompt_token_ids,
|
||||
)
|
||||
|
||||
def _get_default_enc_dec_decoder_prompt(self, ) -> List[int]:
|
||||
'''
|
||||
Specifically for encoder/decoder models:
|
||||
generate a default decoder prompt for when
|
||||
the user specifies only the encoder prompt.
|
||||
|
||||
Encoder/decoder models utilize the decoder
|
||||
prompt in different ways; as new models are
|
||||
added, it is intended that this function
|
||||
will be extended to produce differing
|
||||
default decoder prompts, depending on the
|
||||
model variety.
|
||||
|
||||
Absent a special case, the default behavior
|
||||
of this method is to mirror the behavior of
|
||||
the HuggingFace (HF) GenerationMixin for a None
|
||||
decoder prompt, which is to employ a logit processor
|
||||
setting to force the first decoded token to be <BOS>.
|
||||
Here, this behavior is approximated by having the
|
||||
"default" decoder prompt be <BOS>.
|
||||
|
||||
However, it is possible that in the future
|
||||
other models may have different or more
|
||||
complex logic for the default decoder prompt.
|
||||
This motivates having a special helper method
|
||||
for default decoder prompts.
|
||||
|
||||
Returns:
|
||||
|
||||
* prompt_token_ids
|
||||
'''
|
||||
|
||||
bos_token_id = self._get_bos_token_id()
|
||||
assert bos_token_id is not None
|
||||
prompt_token_ids: List[int] = [bos_token_id]
|
||||
return prompt_token_ids
|
||||
|
||||
def _process_encoder_decoder_prompt(
|
||||
self,
|
||||
inputs: PromptInputs,
|
||||
request_id: Optional[str] = None,
|
||||
) -> LLMInputs:
|
||||
'''
|
||||
For encoder/decoder models only:
|
||||
Process an input prompt
|
||||
into an `LLMInputs` instance.
|
||||
|
||||
There are two types of input prompts:
|
||||
singleton prompts which carry only the
|
||||
encoder prompt, and explicit encoder/decoder
|
||||
prompts which carry both the encoder and the
|
||||
decoder prompts as member variables.
|
||||
|
||||
This function handles the following scenarios:
|
||||
* Singleton encoder prompt: extract encoder prompt
|
||||
token ids & infer default decoder prompt token ids
|
||||
* Explicit encoder/decoder prompt: extract encoder
|
||||
and decoder prompt token ids
|
||||
|
||||
Note that for Explicit encoder/decoder prompts,
|
||||
each sub-prompt (encoder or decoder prompt) can
|
||||
have any possible singleton type; thus this
|
||||
method relies on helper functions to obtain
|
||||
token ids for the sub-prompts.
|
||||
|
||||
Arguments:
|
||||
|
||||
* inputs: an input prompt
|
||||
* request_id
|
||||
|
||||
Returns:
|
||||
|
||||
* `LLMInputs` instance
|
||||
'''
|
||||
|
||||
ptype = get_prompt_type(inputs)
|
||||
|
||||
# Obtain encoder and decoder prompt tokens. Note
|
||||
# that, no matter what, the decoder
|
||||
# prompt type is unknown.
|
||||
if ptype == "ExplicitEncoderDecoder":
|
||||
# If input is explicit encoder/decoder prompt,
|
||||
# then it remains to be determined what type
|
||||
# of encoder prompt we have
|
||||
extracted_encoder_prompt = inputs.get('encoder_prompt')
|
||||
encoder_ptype = None
|
||||
# Extract decoder prompt from explicit
|
||||
# encoder/decoder prompt
|
||||
extracted_decoder_prompt = inputs.get('decoder_prompt')
|
||||
else:
|
||||
# If input is singleton encoder prompt, then
|
||||
# we know the encoder prompt type
|
||||
extracted_encoder_prompt = inputs
|
||||
encoder_ptype = ptype
|
||||
# Decoder prompt is always unknown if
|
||||
# encoder/decoder prompt is not explicit
|
||||
extracted_decoder_prompt = None
|
||||
|
||||
# Invoke helper function to obtain encoder
|
||||
# prompt and prompt token ids, either from
|
||||
# singleton encoder prompt or from the
|
||||
# encoder sub-prompt of an explicit
|
||||
# encoder/decoder prompt.
|
||||
(
|
||||
encoder_prompt,
|
||||
encoder_prompt_token_ids,
|
||||
) = self._extract_single_prompt_for_enc_dec_input(
|
||||
extracted_encoder_prompt,
|
||||
request_id=request_id,
|
||||
ptype=encoder_ptype,
|
||||
is_encoder_prompt=True,
|
||||
)
|
||||
|
||||
# Invoke helper method to obtain
|
||||
# decoder prompt and prompt token ids.
|
||||
#
|
||||
# The helper method will detect the decoder
|
||||
# prompt type.
|
||||
#
|
||||
# Helper method will also apply special
|
||||
# preprocessing unique to decoder prompts.
|
||||
(
|
||||
decoder_prompt,
|
||||
decoder_prompt_token_ids,
|
||||
) = self._extract_single_prompt_for_enc_dec_input(
|
||||
extracted_decoder_prompt,
|
||||
request_id=request_id,
|
||||
ptype=None,
|
||||
is_encoder_prompt=False,
|
||||
)
|
||||
|
||||
return LLMInputs(
|
||||
prompt_token_ids=decoder_prompt_token_ids,
|
||||
prompt=decoder_prompt,
|
||||
encoder_prompt_token_ids=encoder_prompt_token_ids,
|
||||
encoder_prompt=encoder_prompt,
|
||||
)
|
||||
|
||||
def _process_decoder_only_prompt(
|
||||
self,
|
||||
inputs: PromptInputs,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
request_id: Optional[str] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> LLMInputs:
|
||||
'''
|
||||
For decoder-only models:
|
||||
Process an input prompt
|
||||
into an `LLMInputs` instance.
|
||||
|
||||
Arguments:
|
||||
|
||||
* inputs: input prompt
|
||||
* lora_request
|
||||
* request_id
|
||||
* prompt_adapter_request
|
||||
|
||||
Returns:
|
||||
|
||||
* `LLMInputs` instance
|
||||
'''
|
||||
|
||||
if isinstance(inputs, str):
|
||||
inputs = {"prompt": inputs}
|
||||
prompt = inputs.get("prompt")
|
||||
|
||||
if "prompt_token_ids" not in inputs:
|
||||
prompt_token_ids = self._tokenize_prompt(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
else:
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
|
||||
if prompt_adapter_request:
|
||||
prompt_token_ids = (
|
||||
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
|
||||
+ prompt_token_ids)
|
||||
|
||||
return LLMInputs(prompt_token_ids=prompt_token_ids,
|
||||
prompt=prompt,
|
||||
multi_modal_data=inputs.get("multi_modal_data"))
|
||||
|
||||
def process_model_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
@ -569,29 +950,25 @@ class LLMEngine:
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> LLMInputs:
|
||||
if isinstance(inputs, str):
|
||||
inputs = {"prompt": inputs}
|
||||
|
||||
if "prompt_token_ids" not in inputs:
|
||||
tokenizer = self.get_tokenizer_group("prompts must be None if "
|
||||
"skip_tokenizer_init is True")
|
||||
if self.is_encoder_decoder_model():
|
||||
# Encoder-decoder model requires special mapping of
|
||||
# input prompts to encoder & decoder
|
||||
|
||||
prompt_token_ids = tokenizer.encode(request_id=request_id,
|
||||
prompt=inputs["prompt"],
|
||||
lora_request=lora_request)
|
||||
model_inputs = self._process_encoder_decoder_prompt(
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
)
|
||||
else:
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
# Decoder-only operation
|
||||
model_inputs = self._process_decoder_only_prompt(
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
if prompt_adapter_request:
|
||||
prompt_token_ids = \
|
||||
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\
|
||||
+ prompt_token_ids
|
||||
|
||||
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
|
||||
prompt=inputs.get("prompt"),
|
||||
multi_modal_data=inputs.get("multi_modal_data"))
|
||||
|
||||
return self.input_processor(llm_inputs)
|
||||
return self.input_processor(model_inputs)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
@ -676,6 +1053,7 @@ class LLMEngine:
|
||||
lora_request: Optional[LoRARequest],
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
encoder_seq: Optional[Sequence] = None,
|
||||
) -> SequenceGroup:
|
||||
"""Creates a SequenceGroup with SamplingParams."""
|
||||
max_logprobs = self.get_model_config().max_logprobs
|
||||
@ -701,7 +1079,8 @@ class LLMEngine:
|
||||
sampling_params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
encoder_seq=encoder_seq)
|
||||
|
||||
return seq_group
|
||||
|
||||
@ -713,6 +1092,7 @@ class LLMEngine:
|
||||
arrival_time: float,
|
||||
lora_request: Optional[LoRARequest],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
encoder_seq: Optional[Sequence] = None,
|
||||
) -> SequenceGroup:
|
||||
"""Creates a SequenceGroup with PoolingParams."""
|
||||
# Defensive copy of PoolingParams, which are used by the pooler
|
||||
@ -724,7 +1104,8 @@ class LLMEngine:
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
pooling_params=pooling_params,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
encoder_seq=encoder_seq)
|
||||
return seq_group
|
||||
|
||||
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||
@ -1214,3 +1595,9 @@ class LLMEngine:
|
||||
seq_span.set_attribute(
|
||||
SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
|
||||
seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
|
||||
|
||||
def is_encoder_decoder_model(self):
|
||||
return is_encoder_decoder_model_config(self.model_config)
|
||||
|
||||
def is_embedding_model(self):
|
||||
return is_embedding_model_config(self.model_config)
|
||||
|
@ -121,12 +121,21 @@ class LLM:
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
swap_space: int = 4,
|
||||
cpu_offload_gb: float = 0,
|
||||
enforce_eager: bool = False,
|
||||
enforce_eager: Optional[bool] = None,
|
||||
max_context_len_to_capture: Optional[int] = None,
|
||||
max_seq_len_to_capture: int = 8192,
|
||||
disable_custom_all_reduce: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
'''
|
||||
LLM constructor.
|
||||
|
||||
Note: if enforce_eager is unset (enforce_eager is None)
|
||||
it defaults to False for decoder-only models and True
|
||||
for encoder/decoder models, since encoder/decoder models
|
||||
do not currently support CUDAGraph.
|
||||
'''
|
||||
|
||||
if "disable_log_stats" not in kwargs:
|
||||
kwargs["disable_log_stats"] = True
|
||||
removed_vision_keys = ("image_token_id", "image_feature_size",
|
||||
@ -297,8 +306,8 @@ class LLM:
|
||||
"""
|
||||
if self.llm_engine.model_config.embedding_mode:
|
||||
raise ValueError(
|
||||
"LLM.generate() is only supported for generation models "
|
||||
"(XForCausalLM).")
|
||||
"LLM.generate() is only supported for (conditional) generation "
|
||||
"models (XForCausalLM, XForConditionalGeneration).")
|
||||
|
||||
if prompt_token_ids is not None:
|
||||
inputs = self._convert_v1_inputs(
|
||||
@ -631,3 +640,9 @@ class LLM:
|
||||
# This is necessary because some requests may be finished earlier than
|
||||
# its previous requests.
|
||||
return sorted(outputs, key=lambda x: int(x.request_id))
|
||||
|
||||
def _is_encoder_decoder_model(self):
|
||||
return self.llm_engine.is_encoder_decoder_model()
|
||||
|
||||
def _is_embedding_model(self):
|
||||
return self.llm_engine.is_embedding_model()
|
||||
|
@ -1,5 +1,7 @@
|
||||
from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
|
||||
TextPrompt, TokensPrompt, parse_and_batch_prompt)
|
||||
from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, ParsedText,
|
||||
ParsedTokens, PromptInputs, SingletonPromptInputs,
|
||||
TextPrompt, TokensPrompt, get_prompt_type,
|
||||
is_valid_encoder_decoder_llm_inputs, parse_and_batch_prompt)
|
||||
from .registry import InputContext, InputRegistry
|
||||
|
||||
INPUT_REGISTRY = InputRegistry()
|
||||
@ -12,7 +14,18 @@ See also:
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
|
||||
"TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY",
|
||||
"InputContext", "InputRegistry"
|
||||
"ParsedText",
|
||||
"ParsedTokens",
|
||||
"parse_and_batch_prompt",
|
||||
"TextPrompt",
|
||||
"TokensPrompt",
|
||||
"PromptInputs",
|
||||
"LLMInputs",
|
||||
"INPUT_REGISTRY",
|
||||
"InputContext",
|
||||
"InputRegistry",
|
||||
"get_prompt_type",
|
||||
"is_valid_encoder_decoder_llm_inputs",
|
||||
"ExplicitEncoderDecoderPrompt",
|
||||
"SingletonPromptInputs",
|
||||
]
|
||||
|
@ -92,15 +92,114 @@ class TokensPrompt(TypedDict):
|
||||
"""
|
||||
|
||||
|
||||
PromptInputs = Union[str, TextPrompt, TokensPrompt]
|
||||
SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt]
|
||||
"""
|
||||
The inputs to the LLM, which can take one of the following forms:
|
||||
Set of possible schemas for a single LLM input:
|
||||
|
||||
- A text prompt (:class:`str` or :class:`TextPrompt`)
|
||||
- A tokenized prompt (:class:`TokensPrompt`)
|
||||
|
||||
Note that "singleton" is as opposed to a data structure
|
||||
which encapsulates multiple prompts, i.e. of the sort
|
||||
which may be utilized for encoder/decoder models when
|
||||
the user desires to express both the encoder & decoder
|
||||
prompts explicitly, i.e. ExplicitEncoderDecoderPrompt
|
||||
|
||||
A prompt of type SingletonPromptInputs may be employed
|
||||
as (1) input to a decoder-only model, (2) input to
|
||||
the encoder of an encoder/decoder model, in the scenario
|
||||
where the decoder-prompt is not specified explicitly, or
|
||||
(3) as a member of a larger data structure encapsulating
|
||||
more than one prompt, i.e. ExplicitEncoderDecoderPrompt
|
||||
"""
|
||||
|
||||
|
||||
class ExplicitEncoderDecoderPrompt(TypedDict):
|
||||
"""Represents an encoder/decoder model input prompt,
|
||||
comprising an explicit encoder prompt and a
|
||||
decoder prompt.
|
||||
|
||||
The encoder and decoder prompts, respectively,
|
||||
may be formatted according to any of the
|
||||
SingletonPromptInputs schemas, and are not
|
||||
required to have the same schema.
|
||||
|
||||
Only the encoder prompt may have multi-modal data.
|
||||
|
||||
Note that an ExplicitEncoderDecoderPrompt may not
|
||||
be used as an input to a decoder-only model,
|
||||
and that the `encoder_prompt` and `decoder_prompt`
|
||||
fields of this data structure must themselves
be SingletonPromptInputs instances.
|
||||
"""
|
||||
|
||||
encoder_prompt: SingletonPromptInputs
|
||||
|
||||
decoder_prompt: SingletonPromptInputs
|
||||
|
||||
|
||||
PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
|
||||
"""
|
||||
Set of possible schemas for an LLM input, including
|
||||
both decoder-only and encoder/decoder input types:
|
||||
|
||||
- A text prompt (:class:`str` or :class:`TextPrompt`)
|
||||
- A tokenized prompt (:class:`TokensPrompt`)
|
||||
- A single data structure containing both an encoder and a decoder prompt
|
||||
(:class:`ExplicitEncoderDecoderPrompt`)
|
||||
"""
|
||||
|
||||
|
||||
def _has_required_keys(
|
||||
d: dict,
|
||||
required_keys: set,
|
||||
) -> bool:
|
||||
return required_keys.issubset(d.keys())
|
||||
|
||||
|
||||
def get_prompt_type(prompt: Optional[PromptInputs]) -> Optional[str]:
|
||||
"""
|
||||
Get the type-name of the prompt argument instance, given that
|
||||
isinstance() cannot apply to TypedDict subclasses directly.
|
||||
If the prompt is None, return 'None' as the type name.
|
||||
|
||||
Arguments:
|
||||
|
||||
* prompt: LLM input prompt or None
|
||||
|
||||
Returns:
|
||||
|
||||
* String representation of prompt type
|
||||
"""
|
||||
|
||||
if prompt is None:
|
||||
return 'None'
|
||||
|
||||
required_keys_dict = {
|
||||
'TextPrompt': {'prompt'},
|
||||
'TokensPrompt': {'prompt_token_ids'},
|
||||
'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'},
|
||||
}
|
||||
|
||||
if isinstance(prompt, dict):
|
||||
for (ptype, required_keys) in required_keys_dict.items():
|
||||
# Ignore type checking in the conditional below because type
|
||||
# checker does not understand that is_dict(prompt) narrows
|
||||
# down the possible types
|
||||
if _has_required_keys(
|
||||
prompt, # type: ignore
|
||||
required_keys):
|
||||
return ptype
|
||||
|
||||
raise ValueError(f"Invalid prompt {prompt}, valid types are "
|
||||
"required_keys_dict={required_keys_dict}")
|
||||
|
||||
if isinstance(prompt, str):
|
||||
return "str"
|
||||
|
||||
raise ValueError(f"Invalid prompt {prompt}")
|
||||
|
||||
|
||||
class LLMInputs(TypedDict):
|
||||
"""
|
||||
The inputs in :class:`~vllm.LLMEngine` before they are
|
||||
@ -114,8 +213,29 @@ class LLMInputs(TypedDict):
|
||||
The original prompt text corresponding to the token IDs, if available.
|
||||
"""
|
||||
|
||||
encoder_prompt_token_ids: NotRequired[List[int]]
|
||||
"""The token IDs of the encoder prompt."""
|
||||
|
||||
encoder_prompt: NotRequired[Optional[str]]
|
||||
"""
|
||||
The original encoder prompt text corresponding to the token IDs, if
|
||||
available.
|
||||
"""
|
||||
|
||||
multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
|
||||
"""
|
||||
Optional multi-modal data to pass to the model,
|
||||
if the model supports it.
|
||||
"""
|
||||
|
||||
|
||||
def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs) -> bool:
|
||||
"""
|
||||
Return True if the LLMInputs instance has the correct configuration
|
||||
for encoder/decoder.
|
||||
"""
|
||||
|
||||
# True if encoder prompt token ids field exists &
|
||||
# is not None
|
||||
return ('encoder_prompt_token_ids' in inputs
|
||||
and inputs['encoder_prompt_token_ids'] is not None)
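A usage sketch, with plain dicts standing in for LLMInputs instances:
decoder_only_inputs = {"prompt_token_ids": [0, 31414, 2]}
enc_dec_inputs = {
    "prompt_token_ids": [2, 0],
    "encoder_prompt_token_ids": [0, 31414, 2],
}
assert not is_valid_encoder_decoder_llm_inputs(decoder_only_inputs)
assert is_valid_encoder_decoder_llm_inputs(enc_dec_inputs)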
|
||||
|
@ -83,7 +83,16 @@ _EMBEDDING_MODELS = {
|
||||
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
|
||||
}
|
||||
|
||||
_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS}
|
||||
_CONDITIONAL_GENERATION_MODELS = {
|
||||
"BartModel": ("bart", "BartForConditionalGeneration"),
|
||||
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
|
||||
}
|
||||
|
||||
_MODELS = {
|
||||
**_GENERATION_MODELS,
|
||||
**_EMBEDDING_MODELS,
|
||||
**_CONDITIONAL_GENERATION_MODELS
|
||||
}
|
||||
|
||||
# Architecture -> type.
|
||||
# out of tree models
|
||||
|
996 vllm/model_executor/models/bart.py Normal file
@ -0,0 +1,996 @@
|
||||
# Derived from BART implementation posted on HuggingFace; license below:
|
||||
#
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch BART model."""
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import BartConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_bsz_seq_len(input_ids):
|
||||
shp = input_ids.shape
|
||||
ndim = len(shp)
|
||||
if ndim == 1:
|
||||
return 1, input_ids.numel()
|
||||
else:
|
||||
return shp[:2]
|
||||
|
||||
|
||||
class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
|
||||
"""
|
||||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int):
|
||||
# Bart is set up so that if padding_idx is
|
||||
# specified then offset the embedding ids by 2
|
||||
# and adjust num_embeddings appropriately.
|
||||
# Other models don't have this hack
|
||||
self.offset = 2
|
||||
super().__init__(num_embeddings + self.offset, embedding_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
attn_type: AttentionType,
|
||||
) -> torch.Tensor:
|
||||
"""`input_ids' shape is expected to be [bsz x seqlen]."""
|
||||
|
||||
assert attn_type != AttentionType.ENCODER_DECODER
|
||||
|
||||
return super().forward(positions + self.offset)
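A numeric sketch of the offset applied above: position i is looked up at row i + 2 of the enlarged table, matching the HuggingFace BART convention that reserves the first two rows:
import torch

offset = 2                            # same constant as self.offset above
positions = torch.arange(4)           # token positions 0..3
print((positions + offset).tolist())  # embedding rows [2, 3, 4, 5]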
|
||||
|
||||
|
||||
class BartScaledWordEmbedding(VocabParallelEmbedding):
|
||||
"""
|
||||
This module overrides VocabParallelEmbedding's
|
||||
forward by multiplying with embeddings scale.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
embed_scale: float = 1.0):
|
||||
super().__init__(num_embeddings, embedding_dim)
|
||||
self.embed_scale = embed_scale
|
||||
|
||||
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return super().forward(input_ids) * self.embed_scale
|
||||
|
||||
|
||||
class BartParallelLMHead(ParallelLMHead):
|
||||
"""
|
||||
This module overrides ParallelLMHead's
|
||||
forward by dividing by embeddings scale,
|
||||
yielding effectively the inverse of
|
||||
BartScaledWordEmbedding
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
embed_scale: float = 1.0):
|
||||
super().__init__(num_embeddings, embedding_dim)
|
||||
self.embed_scale = embed_scale
|
||||
|
||||
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return super().forward(input_ids) / self.embed_scale
|
||||
|
||||
|
||||
class BartEncoderAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
bias: bool = True,
|
||||
config: Optional[BartConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
self.embed_dim = embed_dim
|
||||
self.total_num_heads = num_heads
|
||||
self.total_num_kv_heads = self.total_num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.config = config
|
||||
|
||||
if (self.head_dim * num_heads) != self.embed_dim:
|
||||
raise ValueError(f"embed_dim must be divisible by num_heads "
|
||||
f"(got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {num_heads}).")
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.d_model,
|
||||
self.d_model // self.total_num_heads,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
assert self.total_num_heads % tp_world_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_world_size
|
||||
|
||||
if self.total_num_kv_heads >= tp_world_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_world_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_world_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
attn_output = self.attn(q,
|
||||
k,
|
||||
v,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
attn_type=AttentionType.ENCODER)
|
||||
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class BartDecoderSelfAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
bias: bool = True,
|
||||
config: Optional[BartConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
self.embed_dim = embed_dim
|
||||
self.total_num_heads = num_heads
|
||||
self.total_num_kv_heads = self.total_num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.config = config
|
||||
|
||||
if (self.head_dim * num_heads) != self.embed_dim:
|
||||
raise ValueError(f"embed_dim must be divisible by num_heads "
|
||||
f"(got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {num_heads}).")
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.d_model,
|
||||
self.d_model // self.total_num_heads,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
assert self.total_num_heads % tp_world_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_world_size
|
||||
|
||||
if self.total_num_kv_heads >= tp_world_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_world_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_world_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
attn_output = self.attn(q,
|
||||
k,
|
||||
v,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
attn_type=AttentionType.DECODER)
|
||||
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class BartCrossAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
bias: bool = True,
|
||||
config: Optional[BartConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
self.embed_dim = embed_dim
|
||||
self.total_num_heads = num_heads
|
||||
self.total_num_kv_heads = self.total_num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.config = config
|
||||
|
||||
if (self.head_dim * num_heads) != self.embed_dim:
|
||||
raise ValueError(f"embed_dim must be divisible by num_heads "
|
||||
f"(got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {num_heads}).")
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.d_model,
|
||||
self.d_model // self.total_num_heads,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
assert self.total_num_heads % tp_world_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_world_size
|
||||
|
||||
if self.total_num_kv_heads >= tp_world_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_world_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_world_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
decoder_hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# (afeldman-nm 2024/07/22) TODO:
|
||||
# Need a more efficient solution for q/k/v
|
||||
qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
|
||||
q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
|
||||
dim=-1)
|
||||
if encoder_hidden_states is None:
|
||||
k = None
|
||||
v = None
|
||||
else:
|
||||
qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
|
||||
_, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
|
||||
dim=-1)
|
||||
|
||||
attn_output = self.attn(q,
|
||||
k,
|
||||
v,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
attn_type=AttentionType.ENCODER_DECODER)
|
||||
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
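A condensed sketch of the q/k/v selection in the forward pass above; project_kv is a hypothetical stand-in for the qkv_proj-and-split step:
def select_cross_attn_kv(encoder_hidden_states, project_kv):
    # Prefill: the encoder output is available, so fresh cross-attention K/V
    # are projected from it. Decode: it is None, and the K/V cached in
    # kv_cache during prefill are reused inside the attention kernel.
    if encoder_hidden_states is None:
        return None, None
    return project_kv(encoder_hidden_states)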
|
||||
|
||||
|
||||
class BartEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BartConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
|
||||
self.self_attn = BartEncoderAttention(
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.activation_fn = get_act_fn(config.activation_function,
|
||||
quant_config)
|
||||
|
||||
ffn_hidden_size = self.embed_dim
|
||||
ffn_intermediate_size = config.encoder_ffn_dim
|
||||
ffn_has_bias = True
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
ffn_hidden_size,
|
||||
ffn_intermediate_size,
|
||||
bias=ffn_has_bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.act = get_act_fn("gelu", quant_config, ffn_intermediate_size)
|
||||
self.fc2 = RowParallelLinear(
|
||||
ffn_intermediate_size,
|
||||
ffn_hidden_size,
|
||||
bias=ffn_has_bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
hidden_states
|
||||
torch.Tensor of *encoder* input embeddings.
|
||||
kv_cache:
|
||||
Layer-wise list of KV cache tensors
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
Returns:
|
||||
Encoder layer output torch.Tensor
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
|
||||
residual = hidden_states
|
||||
fc1_out, _ = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(fc1_out)
|
||||
|
||||
hidden_states, _ = self.fc2(hidden_states)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
|
||||
if hidden_states.dtype == torch.float16 and (
|
||||
torch.isinf(hidden_states).any()
|
||||
or torch.isnan(hidden_states).any()):
|
||||
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||
hidden_states = torch.clamp(hidden_states,
|
||||
min=-clamp_value,
|
||||
max=clamp_value)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BartDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BartConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
|
||||
self.self_attn = BartDecoderSelfAttention(
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.decoder_attention_heads,
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
self.activation_fn = get_act_fn(config.activation_function,
|
||||
quant_config)
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
'''
|
||||
afeldman-nm: personally I would call this "cross-attention",
|
||||
however I left the name as "encoder_attn" to maintain consistency
|
||||
with the name of the pretrained weights.
|
||||
'''
|
||||
self.encoder_attn = BartCrossAttention(
|
||||
self.embed_dim,
|
||||
config.decoder_attention_heads,
|
||||
config=config,
|
||||
)
|
||||
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
ffn_hidden_size = self.embed_dim
|
||||
ffn_intermediate_size = config.encoder_ffn_dim
|
||||
ffn_has_bias = True
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
ffn_hidden_size,
|
||||
ffn_intermediate_size,
|
||||
bias=ffn_has_bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
ffn_intermediate_size,
|
||||
ffn_hidden_size,
|
||||
bias=ffn_has_bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
decoder_hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
decoder_hidden_states
|
||||
torch.Tensor of *decoder* input embeddings.
|
||||
kv_cache:
|
||||
KV cache tensor
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
encoder_hidden_states
|
||||
torch.Tensor of *encoder* output hidden states.
|
||||
Returns:
|
||||
Decoder layer output torch.Tensor
|
||||
"""
|
||||
residual = decoder_hidden_states
|
||||
|
||||
# Self Attention
|
||||
hidden_states = self.self_attn(hidden_states=decoder_hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
|
||||
# Cross-Attention Block
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.encoder_attn(
|
||||
decoder_hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
fc1_out, _ = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(fc1_out)
|
||||
|
||||
hidden_states, _ = self.fc2(hidden_states)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BartEncoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder consisting of *config.encoder_layers*
|
||||
self attention layers. Each layer is a [`BartEncoderLayer`].
|
||||
Args:
|
||||
config: BartConfig
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: BartConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
embed_tokens: Optional[nn.Embedding] = None):
|
||||
super().__init__()
|
||||
|
||||
self.cache_config = cache_config
|
||||
self.quant_config = quant_config
|
||||
self.lora_config = lora_config
|
||||
embed_dim = config.d_model
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
|
||||
embed_dim,
|
||||
embed_scale=embed_scale)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
||||
self.embed_positions = BartLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
embed_dim,
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[BartEncoderLayer(config,cache_config,quant_config) \
|
||||
for _ in range(config.encoder_layers)])
|
||||
|
||||
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
input_ids
|
||||
Indices of *encoder* input sequence tokens in the vocabulary.
|
||||
Padding will be ignored by default should you
|
||||
provide it.
|
||||
positions
|
||||
Positions of *encoder* input sequence tokens.
|
||||
kv_caches:
|
||||
Layer-wise list of KV cache tensors
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
Returns:
|
||||
Encoder output torch.Tensor
|
||||
"""
|
||||
# retrieve input_ids and inputs_embeds
|
||||
|
||||
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
embed_pos = self.embed_positions(
|
||||
positions,
|
||||
AttentionType.ENCODER,
|
||||
)
|
||||
embed_pos = embed_pos.to(inputs_embeds.device)
|
||||
|
||||
hidden_states = inputs_embeds + embed_pos
|
||||
hidden_states = self.layernorm_embedding(hidden_states)
|
||||
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
hidden_states = encoder_layer(
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_caches[idx],
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BartDecoder(nn.Module):
|
||||
"""
|
||||
Transformer decoder consisting of *config.decoder_layers* layers.
|
||||
Each layer is a [`BartDecoderLayer`]
|
||||
Args:
|
||||
config: BartConfig
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BartConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
embed_tokens: Optional[nn.Embedding] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.cache_config = cache_config
|
||||
self.quant_config = quant_config
|
||||
self.lora_config = lora_config
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
embed_scale = math.sqrt(
|
||||
config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
|
||||
config.d_model,
|
||||
embed_scale=embed_scale)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
||||
self.embed_positions = BartLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[BartDecoderLayer(config,cache_config,quant_config) \
|
||||
for _ in range(config.decoder_layers)])
|
||||
|
||||
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
||||
|
||||
def forward(self, decoder_input_ids: torch.Tensor,
|
||||
decoder_positions: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor],
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
decoder_input_ids
|
||||
Indices of *decoder* input sequence tokens in the vocabulary.
|
||||
Padding will be ignored by default should you
|
||||
provide it.
|
||||
decoder_positions
|
||||
Positions of *decoder* input sequence tokens.
|
||||
encoder_hidden_states:
|
||||
Tensor of encoder output embeddings
|
||||
kv_caches:
|
||||
Layer-wise list of KV cache tensors
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
Returns:
|
||||
Decoder output torch.Tensor
|
||||
"""
|
||||
|
||||
inputs_embeds = self.embed_tokens(decoder_input_ids)
|
||||
|
||||
# embed positions
|
||||
embed_pos = self.embed_positions(
|
||||
decoder_positions,
|
||||
AttentionType.DECODER,
|
||||
)
|
||||
embed_pos = embed_pos.to(inputs_embeds.device)
|
||||
|
||||
hidden_states = inputs_embeds + embed_pos
|
||||
hidden_states = self.layernorm_embedding(hidden_states)
|
||||
|
||||
# decoder layers
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
hidden_states = decoder_layer(
|
||||
decoder_hidden_states=hidden_states,
|
||||
kv_cache=kv_caches[idx],
|
||||
attn_metadata=attn_metadata,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BartModel(nn.Module):
|
||||
_tied_weights_keys = [
|
||||
"encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
|
||||
]
|
||||
|
||||
def __init__(self,
|
||||
config: BartConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
|
||||
self.encoder = BartEncoder(config,
|
||||
cache_config,
|
||||
quant_config=quant_config)
|
||||
self.decoder = BartDecoder(config,
|
||||
cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
input_ids
|
||||
Indices of *decoder* input sequence tokens in the vocabulary.
|
||||
Padding will be ignored by default should you
|
||||
provide it.
|
||||
positions
|
||||
Positions of *decoder* input sequence tokens.
|
||||
encoder_input_ids
|
||||
Indices of *encoder* input sequence tokens in the vocabulary.
|
||||
encoder_positions:
|
||||
Positions of *encoder* input sequence tokens.
|
||||
kv_caches:
|
||||
Layer-wise list of KV cache tensors
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
Returns:
|
||||
Model output torch.Tensor
|
||||
"""
|
||||
|
||||
encoder_hidden_states = None
|
||||
|
||||
if encoder_input_ids.numel() > 0:
|
||||
# Run encoder attention if a non-zero number of encoder tokens
|
||||
# are provided as input
|
||||
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
|
||||
positions=encoder_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
# decoder outputs consists of
|
||||
# (dec_features, past_key_value, dec_hidden, dec_attn)
|
||||
decoder_outputs = self.decoder(
|
||||
decoder_input_ids=input_ids,
|
||||
decoder_positions=positions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
return decoder_outputs
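A small illustration of the encoder gate above: the encoder only runs when the scheduler supplies encoder tokens (prefill); at decode steps encoder_input_ids is empty and cross-attention relies on the cached encoder K/V instead:
import torch

prefill_encoder_ids = torch.tensor([0, 31414, 2])
decode_encoder_ids = torch.empty(0, dtype=torch.long)
assert prefill_encoder_ids.numel() > 0   # encoder forward pass happens
assert decode_encoder_ids.numel() == 0   # encoder is skipped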
|
||||
|
||||
|
||||
class BartForConditionalGeneration(nn.Module):
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self,
|
||||
config: BartConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None):
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = BartModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config=lora_config)
|
||||
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
|
||||
embed_scale = math.sqrt(
|
||||
config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
self.lm_head = BartParallelLMHead(config.vocab_size,
|
||||
config.d_model,
|
||||
embed_scale=embed_scale)
|
||||
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
input_ids
|
||||
torch.Tensor of *decoder* input token ids.
|
||||
positions
|
||||
torch.Tensor of *decoder* position indices.
|
||||
encoder_input_ids
|
||||
torch.Tensor of *encoder* input token ids.
|
||||
encoder_positions
|
||||
torch.Tensor of *encoder* position indices
|
||||
kv_caches:
|
||||
Layer-wise list of KV cache tensors
|
||||
attn_metadata:
|
||||
vLLM Attention metadata structure
|
||||
Returns:
|
||||
Output torch.Tensor
|
||||
"""
|
||||
return self.model(input_ids, positions, encoder_input_ids,
|
||||
encoder_positions, kv_caches, attn_metadata)
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: Optional[torch.Tensor],
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
stacked_params_mapping = {
|
||||
"q_proj": {
|
||||
"param_name": "qkv_proj",
|
||||
"shard_id": "q",
|
||||
},
|
||||
"k_proj": {
|
||||
"param_name": "qkv_proj",
|
||||
"shard_id": "k",
|
||||
},
|
||||
"v_proj": {
|
||||
"param_name": "qkv_proj",
|
||||
"shard_id": "v",
|
||||
},
|
||||
}
|
||||
|
||||
params_mapping = {
|
||||
"beta": "bias",
|
||||
"gamma": "weight",
|
||||
"LayerNorm": "layernorm",
|
||||
}
|
||||
|
||||
def _rename_key(self, key: str):
|
||||
prefix = f"{self.base_model_prefix}."
|
||||
key = key[len(prefix):] if key.startswith(prefix) else key
|
||||
|
||||
for src, dst in self.params_mapping.items():
|
||||
key = key.replace(src, dst)
|
||||
|
||||
return key
|
||||
|
||||
def _rename_stacked_param(
|
||||
self,
|
||||
name: str,
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
for key, mapping in self.stacked_params_mapping.items():
|
||||
if key in name:
|
||||
name = name.replace(key, mapping["param_name"])
|
||||
return name, mapping["shard_id"]
|
||||
return name, None
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
model_params_dict = dict(self.model.named_parameters())
|
||||
top_params_dict = dict(self.named_parameters())
|
||||
|
||||
weights_tuple_list = list(weights)
|
||||
|
||||
shared_embedding_weight = None
|
||||
shared_embedding_shard_id = None
|
||||
|
||||
for name, loaded_weight in weights_tuple_list:
|
||||
|
||||
name = self._rename_key(name)
|
||||
name, shard_id = self._rename_stacked_param(name)
|
||||
|
||||
if ('shared.weight' in name
|
||||
or 'encoder.embed_tokens.weight' in name
|
||||
or 'decoder.embed_tokens.weight' in name
|
||||
or 'lm_head.weight' in name):
|
||||
assert shared_embedding_weight is None, (
|
||||
"Conflicting embedding weights.")
|
||||
shared_embedding_weight = loaded_weight
|
||||
shared_embedding_shard_id = shard_id
|
||||
else:
|
||||
# Skip the specific downstream task weight.
|
||||
if name.startswith('cls.'):
|
||||
continue
|
||||
# use Pooler instead.
|
||||
if name.startswith('pooler.'):
|
||||
continue
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in model_params_dict:
|
||||
continue
|
||||
|
||||
param = model_params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
if shard_id:
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
else:
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
# Assign shared weight values
|
||||
encoder_in_param = model_params_dict['encoder.embed_tokens.weight']
|
||||
encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
decoder_in_param = model_params_dict['decoder.embed_tokens.weight']
|
||||
decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
lm_head_in_param = top_params_dict['lm_head.weight']
|
||||
lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
assert shared_embedding_weight is not None
|
||||
|
||||
if shared_embedding_shard_id:
|
||||
encoder_in_weight_loader(encoder_in_param, shared_embedding_weight,
|
||||
shared_embedding_shard_id)
|
||||
decoder_in_weight_loader(decoder_in_param, shared_embedding_weight,
|
||||
shared_embedding_shard_id)
|
||||
lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight,
|
||||
shared_embedding_shard_id)
|
||||
else:
|
||||
encoder_in_weight_loader(encoder_in_param, shared_embedding_weight)
|
||||
decoder_in_weight_loader(decoder_in_param, shared_embedding_weight)
|
||||
lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight)
|
@ -70,12 +70,20 @@ class RequestOutput:
|
||||
Args:
|
||||
request_id: The unique ID of the request.
|
||||
prompt: The prompt string of the request.
|
||||
For encoder/decoder models, this is the
|
||||
decoder input prompt.
|
||||
prompt_token_ids: The token IDs of the prompt.
|
||||
For encoder/decoder models, this is the
|
||||
decoder input prompt token ids.
|
||||
prompt_logprobs: The log probabilities to return per prompt token.
|
||||
outputs: The output sequences of the request.
|
||||
finished: Whether the whole request is finished.
|
||||
metrics: Metrics associated with the request.
|
||||
lora_request: The LoRA request that was used to generate the output.
|
||||
encoder_prompt: The encoder prompt string of the request;
|
||||
None if decoder-only
|
||||
encoder_prompt_token_ids: The token IDs of the encoder prompt;
|
||||
None if decoder-only
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -88,6 +96,8 @@ class RequestOutput:
|
||||
finished: bool,
|
||||
metrics: Optional[RequestMetrics] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
encoder_prompt: Optional[str] = None,
|
||||
encoder_prompt_token_ids: Optional[List[int]] = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.prompt = prompt
|
||||
@ -97,6 +107,8 @@ class RequestOutput:
|
||||
self.finished = finished
|
||||
self.metrics = metrics
|
||||
self.lora_request = lora_request
|
||||
self.encoder_prompt = encoder_prompt
|
||||
self.encoder_prompt_token_ids = encoder_prompt_token_ids
|
||||
|
||||
@classmethod
|
||||
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
|
||||
@ -137,6 +149,8 @@ class RequestOutput:
|
||||
# Every sequence in the sequence group should have the same prompt.
|
||||
prompt = seq_group.prompt
|
||||
prompt_token_ids = seq_group.prompt_token_ids
|
||||
encoder_prompt = seq_group.encoder_prompt
|
||||
encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
|
||||
prompt_logprobs = seq_group.prompt_logprobs
|
||||
finished = seq_group.is_finished()
|
||||
finished_time = time.time() if finished else None
|
||||
@ -148,12 +162,16 @@ class RequestOutput:
|
||||
outputs,
|
||||
finished,
|
||||
seq_group.metrics,
|
||||
lora_request=seq_group.lora_request)
|
||||
lora_request=seq_group.lora_request,
|
||||
encoder_prompt=encoder_prompt,
|
||||
encoder_prompt_token_ids=encoder_prompt_token_ids)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"RequestOutput(request_id={self.request_id}, "
|
||||
f"prompt={self.prompt!r}, "
|
||||
f"prompt_token_ids={self.prompt_token_ids}, "
|
||||
f"encoder_prompt={self.encoder_prompt!r}, "
|
||||
f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
|
||||
f"prompt_logprobs={self.prompt_logprobs}, "
|
||||
f"outputs={self.outputs}, "
|
||||
f"finished={self.finished}, "
|
||||
|
105 vllm/sequence.py
@ -7,10 +7,11 @@ from array import array
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
|
||||
Union)
|
||||
Union, cast)
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.inputs import is_valid_encoder_decoder_llm_inputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
@ -244,24 +245,38 @@ class SequenceData:
|
||||
class Sequence:
|
||||
"""Stores the data, status, and block information of a sequence.
|
||||
|
||||
The sequence is constructed from the LLMInputs instance passed
|
||||
in through the `inputs` constructor argument.
|
||||
|
||||
For encoder/decoder models, LLMInputs encapsulates both a
|
||||
decoder and encoder prompt, creating an ambiguity about which
|
||||
prompt to construct the sequence from. The `from_decoder_prompt`
|
||||
constructor argument signals whether to construct the Sequence
|
||||
from the LLMInputs decoder prompt, or encoder prompt.
|
||||
|
||||
Args:
|
||||
seq_id: The ID of the sequence.
|
||||
inputs: The inputs of the sequence.
|
||||
block_size: The block size of the sequence. Should be the same as the
|
||||
block size used by the block manager and cache engine.
|
||||
eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
|
||||
lora_request: LoRA request.
|
||||
prompt_adapter_request: Prompt Adapter request.
|
||||
from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
|
||||
(True) or encoder prompt (False). Must be True
for decoder-only models.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seq_id: int,
|
||||
inputs: "LLMInputs",
|
||||
block_size: int,
|
||||
eos_token_id: Optional[int] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
self,
|
||||
seq_id: int,
|
||||
inputs: "LLMInputs",
|
||||
block_size: int,
|
||||
eos_token_id: Optional[int] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
from_decoder_prompt: bool = True,
|
||||
) -> None:
|
||||
self.seq_id = seq_id
|
||||
self.inputs = inputs
|
||||
@ -269,6 +284,36 @@ class Sequence:
|
||||
self.eos_token_id = eos_token_id
|
||||
self.lora_request = lora_request
|
||||
self.prompt_adapter_request = prompt_adapter_request
|
||||
self.from_decoder_prompt = from_decoder_prompt
|
||||
self._prompt: Optional[str] = None
|
||||
self._prompt_token_ids: Optional[List[int]] = None
|
||||
|
||||
# For decoder-only models, a Sequence is constructed
|
||||
# from an LLMInputs instance (the `inputs` arg.)
|
||||
#
|
||||
# For encoder/decoder models the same `inputs`
|
||||
# instance could be utilized to construct either an
|
||||
# encoder sequence or a decoder sequence, because
|
||||
# `LLMInputs` has both decoder- and encoder-oriented
|
||||
# member variables (i.e. it encapsulates both an encoder
|
||||
# and a decoder prompt.) The decision of which type of sequence
|
||||
# to generate is determined by the `from_decoder_prompt` argument.
|
||||
#
|
||||
# When constructing an encoder sequence
|
||||
# (`from_decoder_prompt` False) it matters that
|
||||
# the `LLMInputs` instance stored in `inputs` is valid
|
||||
# in the sense that its encoder-related member variables are
|
||||
# populated; below, an exception is raised if this is
|
||||
# not the case.
|
||||
#
|
||||
# When constructing a decoder sequence (`from_decoder_prompt` True)
|
||||
# it does not matter whether `inputs` has its encoder-related
|
||||
# member variables populated.
|
||||
if not (from_decoder_prompt
|
||||
or is_valid_encoder_decoder_llm_inputs(inputs)):
|
||||
raise ValueError("Cannot extract encoder input prompt from "
|
||||
f"invalid input {inputs}; did you forget the "
|
||||
"encoder input prompt fields?")
|
||||
|
||||
self.data = SequenceData(self.prompt_token_ids)
|
||||
self.output_logprobs: SampleLogprobs = []
|
||||
@ -289,11 +334,35 @@ class Sequence:
|
||||
|
||||
@property
|
||||
def prompt(self) -> Optional[str]:
|
||||
return self.inputs.get("prompt")
|
||||
if self._prompt is not None:
|
||||
# Reuse precomputed prompt string
|
||||
return self._prompt
|
||||
|
||||
# Select decoder or encoder input prompt str,
|
||||
# as appropriate
|
||||
prompt_key: str = ("prompt"
|
||||
if self.from_decoder_prompt else "encoder_prompt")
|
||||
|
||||
# Cache prompt
|
||||
self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
|
||||
return self._prompt
|
||||
|
||||
@property
|
||||
def prompt_token_ids(self) -> List[int]:
|
||||
return self.inputs["prompt_token_ids"]
|
||||
if self._prompt_token_ids is not None:
|
||||
# Reuse precomputed prompt token ids
|
||||
return self._prompt_token_ids
|
||||
|
||||
# Select decoder or encoder input prompt
|
||||
# token ids, as appropriate
|
||||
prompt_token_ids_key: str = ("prompt_token_ids"
|
||||
if self.from_decoder_prompt else
|
||||
"encoder_prompt_token_ids")
|
||||
|
||||
# Cache computed prompt token ids
|
||||
self._prompt_token_ids = cast(List[int],
|
||||
self.inputs.get(prompt_token_ids_key))
|
||||
return self._prompt_token_ids
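A sketch of the key selection both properties perform, using a plain dict in place of LLMInputs:
def select_prompt_token_ids(inputs: dict, from_decoder_prompt: bool):
    key = ("prompt_token_ids"
           if from_decoder_prompt else "encoder_prompt_token_ids")
    return inputs[key]

inputs = {
    "prompt": "<decoder prompt>",
    "prompt_token_ids": [2, 0],
    "encoder_prompt": "Hello, my name is",
    "encoder_prompt_token_ids": [0, 31414, 2],
}
assert select_prompt_token_ids(inputs, True) == [2, 0]          # decoder view
assert select_prompt_token_ids(inputs, False) == [0, 31414, 2]  # encoder view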
|
||||
|
||||
@property
|
||||
def multi_modal_data(self) -> "MultiModalDataDict":
|
||||
@ -472,6 +541,22 @@ class SequenceGroup:
|
||||
# We use the prompt of an arbitrary sequence.
|
||||
return self.seqs[0].prompt_token_ids
|
||||
|
||||
@property
|
||||
def encoder_prompt(self) -> Optional[str]:
|
||||
# There are either 0 or 1 encoder sequences
|
||||
# If one is present, its prompt is distinct
|
||||
# from the decoder's.
|
||||
return (self.encoder_seq.prompt
|
||||
if self.encoder_seq is not None else None)
|
||||
|
||||
@property
|
||||
def encoder_prompt_token_ids(self) -> Optional[List[int]]:
|
||||
# There are either 0 or 1 encoder sequences
|
||||
# If one is present, its prompt token ids are
|
||||
# distinct from the decoder's.
|
||||
return (self.encoder_seq.prompt_token_ids
|
||||
if self.encoder_seq is not None else None)
|
||||
|
||||
@property
|
||||
def multi_modal_data(self) -> "MultiModalDataDict":
|
||||
# All sequences in the group should have the same multi-modal data.
|
||||
|
130 vllm/utils.py
@ -27,10 +27,93 @@ from typing_extensions import ParamSpec
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs,
|
||||
SingletonPromptInputs)
|
||||
from vllm.logger import enable_trace_function_call, init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Exception strings for non-implemented encoder/decoder scenarios
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_SWA = \
|
||||
"Sliding window attention for encoder/decoder models " + \
|
||||
"is not currently supported."
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
|
||||
"Prefix caching for encoder/decoder models " + \
|
||||
"is not currently supported."
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
|
||||
"Chunked prefill for encoder/decoder models " + \
|
||||
"is not currently supported."
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
|
||||
"Models with logits_soft_cap "
|
||||
"require FlashInfer backend, which is "
|
||||
"currently not supported for encoder/decoder "
|
||||
"models.")
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is not currently "
|
||||
"supported with encoder/decoder "
|
||||
"models.")
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not "
|
||||
"currently supported with "
|
||||
"encoder/decoder models.")
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently "
|
||||
"supported with encoder/decoder "
|
||||
"models.")
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
|
||||
"currently supported with encoder/"
|
||||
"decoder models.")
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = ("CUDAGraph is not "
|
||||
"currently supported with encoder/"
|
||||
"decoder models.")
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
|
||||
"currently supported with encoder/"
|
||||
"decoder models.")
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
|
||||
"currently supported with encoder/"
|
||||
"decoder models.")
|
||||
|
||||
# Efficiently import all enc/dec error strings
|
||||
# rather than having to import all of the above
|
||||
STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
|
||||
"STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA,
|
||||
"STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
|
||||
"STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL":
|
||||
STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL,
|
||||
"STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP,
|
||||
"STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA,
|
||||
"STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP,
|
||||
"STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
|
||||
"STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
|
||||
"STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
|
||||
"STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
|
||||
"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
|
||||
}
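A brief usage sketch; the flag is hypothetical and only shows how a caller would look up one of the strings above:
requested_lora = False  # hypothetical flag, for illustration only
if requested_lora:
    raise NotImplementedError(
        STR_NOT_IMPL_ENC_DEC_ERR_STRS["STR_NOT_IMPL_ENC_DEC_LORA"])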
|
||||
|
||||
# Constants related to forcing the attention backend selection
|
||||
|
||||
# String name of register which may be set in order to
|
||||
# force auto-selection of attention backend by Attention
|
||||
# wrapper
|
||||
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
|
||||
|
||||
# Possible string values of STR_BACKEND_ENV_VAR
|
||||
# register, corresponding to possible backends
|
||||
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
|
||||
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
|
||||
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
|
||||
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
|
||||
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
|
||||
STR_INVALID_VAL: str = "INVALID"
|
||||
|
||||
STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
@ -1029,3 +1112,50 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
|
||||
"""Utility function to run async task in a lock"""
|
||||
async with lock:
|
||||
return await task(*args, **kwargs)
|
||||
|
||||
|
||||
def is_encoder_decoder_model_config(model_config) -> bool:
|
||||
'''
|
||||
Extract the HF encoder/decoder model flag from the ModelConfig instance.
|
||||
Return False if model_config is None.
|
||||
'''
|
||||
return model_config is not None and \
|
||||
getattr(model_config.hf_config,
|
||||
"is_encoder_decoder",
|
||||
False)
|
||||
|
||||
|
||||
def is_embedding_model_config(model_config) -> bool:
|
||||
'''
|
||||
Extract the embedding model flag from the ModelConfig instance.
|
||||
Return False if model_config is None.
|
||||
'''
|
||||
return model_config is not None and \
|
||||
model_config.embedding_mode
|
||||
|
||||
|
||||
def build_explicit_enc_dec_prompt(
|
||||
encoder_prompt: SingletonPromptInputs,
|
||||
decoder_prompt: SingletonPromptInputs,
|
||||
) -> ExplicitEncoderDecoderPrompt:
|
||||
return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
|
||||
decoder_prompt=decoder_prompt)
|
||||
|
||||
|
||||
def zip_enc_dec_prompt_lists(
|
||||
enc_prompt_list: List[SingletonPromptInputs],
|
||||
dec_prompt_list: List[SingletonPromptInputs],
|
||||
) -> List[ExplicitEncoderDecoderPrompt]:
|
||||
return [
|
||||
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
|
||||
for (encoder_prompt,
|
||||
decoder_prompt) in zip(enc_prompt_list, dec_prompt_list)
|
||||
]
|
||||
|
||||
|
||||
def to_enc_dec_tuple_list(
|
||||
enc_dec_prompts: List[ExplicitEncoderDecoderPrompt],
|
||||
) -> List[Tuple[PromptInputs, PromptInputs]]:
|
||||
return [(enc_dec_prompt['encoder_prompt'],
|
||||
enc_dec_prompt['decoder_prompt'])
|
||||
for enc_dec_prompt in enc_dec_prompts]
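A usage sketch for the three helpers above, using plain string prompts:
enc_prompts = ["Hello, my name is", "The capital of France is"]
dec_prompts = ["", "Paris"]

enc_dec_prompts = zip_enc_dec_prompt_lists(enc_prompts, dec_prompts)
assert enc_dec_prompts[0]["encoder_prompt"] == "Hello, my name is"
assert to_enc_dec_tuple_list(enc_dec_prompts)[1] == (
    "The capital of France is", "Paris")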
|
||||
|
472 vllm/worker/enc_dec_model_runner.py Normal file
@ -0,0 +1,472 @@
|
||||
import dataclasses
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, cast
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata)
|
||||
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
|
||||
get_global_forced_attn_backend,
|
||||
global_force_attn_backend)
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
|
||||
from vllm.worker.model_runner import (_PAD_SLOT_ID, GPUModelRunnerBase,
|
||||
ModelInputForGPUBuilder,
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
from vllm.worker.model_runner_base import (
|
||||
_add_attn_metadata_broadcastable_dict,
|
||||
_add_sampling_metadata_broadcastable_dict)
|
||||
from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
|
||||
"""
|
||||
Used by the EncoderDecoderModelRunner.
|
||||
"""
|
||||
encoder_input_tokens: Optional[torch.Tensor] = None
|
||||
encoder_input_positions: Optional[torch.Tensor] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"input_positions": self.input_positions,
|
||||
"encoder_input_tokens": self.encoder_input_tokens,
|
||||
"encoder_input_positions": self.encoder_input_positions,
|
||||
"virtual_engine": self.virtual_engine,
|
||||
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
|
||||
"finished_requests_ids": self.finished_requests_ids,
|
||||
}
|
||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||
_add_sampling_metadata_broadcastable_dict(tensor_dict,
|
||||
self.sampling_metadata)
|
||||
return tensor_dict
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls,
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> "EncoderDecoderModelInput":
|
||||
return cast(
|
||||
EncoderDecoderModelInput,
|
||||
super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
|
||||
|
||||
|
||||
class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
    _model_input_cls: Type[EncoderDecoderModelInput] = (
        EncoderDecoderModelInput)
    _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder)

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        multimodal_config: Optional[MultiModalConfig] = None,
    ):
        '''
        EncoderDecoderModelRunner constructor.

        `lora_config`, `multimodal_config`, and `prompt_adapter_config` are
        unused (since these features are not yet supported for encoder/decoder
        models) but these arguments are present here for compatibility with
        the base-class constructor.
        '''

        self._maybe_force_supported_attention_backend()

        super().__init__(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config,
            lora_config=None,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
        )

        # Crash for unsupported encoder/decoder scenarios
        assert_enc_dec_mr_supported_scenario(self)

    def _maybe_force_supported_attention_backend(self):
        '''
        Force vLLM to use the XFormers attention backend,
        which is currently the only supported option.
        '''

        def raise_backend_err():
            # The user has specified an attention backend override
            # which is invalid for encoder/decoder models
            raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND)

        maybe_env_var_forced_backend = get_env_variable_attn_backend()
        maybe_global_forced_backend = get_global_forced_attn_backend()
        is_forced_by_global = maybe_global_forced_backend is not None
        is_forced_by_env_var = maybe_env_var_forced_backend is not None

        if not (is_forced_by_global or is_forced_by_env_var):
            # The user has not already specified an attention backend
            # override
            logger.info("EncoderDecoderModelRunner requires "
                        "XFormers backend; overriding backend "
                        "auto-selection and forcing XFormers.")
            global_force_attn_backend(_Backend.XFORMERS)
        elif is_forced_by_global:
            # Backend override enforced by global variable takes
            # precedence over vLLM backend environment variable.
            if maybe_global_forced_backend != _Backend.XFORMERS:
                raise_backend_err()
        elif is_forced_by_env_var:
            # Backend override enforced by vLLM backend
            # environment variable
            if maybe_env_var_forced_backend != _Backend.XFORMERS:
                raise_backend_err()
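For context on how an override reaches the check above: a backend can be forced either through vLLM's attention-backend environment variable (which `get_env_variable_attn_backend` is expected to read; the variable name below is an assumption based on vLLM conventions) or programmatically via `global_force_attn_backend`. A minimal illustrative sketch, not part of this file:

import os

# Assumed env-var name; get_env_variable_attn_backend() should pick this up.
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

# Or force the backend programmatically before the runner is constructed:
from vllm.attention.selector import _Backend, global_force_attn_backend

global_force_attn_backend(_Backend.XFORMERS)

# Forcing any other backend (e.g. _Backend.FLASH_ATTN) would make
# _maybe_force_supported_attention_backend() raise NotImplementedError.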
    def _list_to_int32_tensor(
        self,
        _list: List[int],
    ) -> torch.Tensor:
        return torch.tensor(_list, dtype=torch.int32, device=self.device)

    def _list_to_long_tensor(
        self,
        _list: List[int],
    ) -> torch.Tensor:
        return torch.tensor(_list, dtype=torch.long, device=self.device)

    def _empty_int32_tensor(self) -> torch.Tensor:
        return self._list_to_int32_tensor([])

    def _empty_long_tensor(self) -> torch.Tensor:
        return self._list_to_long_tensor([])

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: EncoderDecoderModelInput,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in "
                             "EncoderDecoderModelRunner")

        model_executable = self.model

        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_seqlen_agnostic else {}
        hidden_or_intermediate_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            encoder_input_ids=model_input.encoder_input_tokens,
            encoder_positions=model_input.encoder_input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **seqlen_agnostic_kwargs)

        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        return [output]
    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput:
        return EncoderDecoderModelInput.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> EncoderDecoderModelInput:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        Since chunked prefill is not supported for encoder/decoder models,
        `input_tokens` is assumed to be either entirely prefill tokens or
        entirely decode tokens.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)

        (
            attn_metadata,
            encoder_input_tokens_tensor,
            encoder_input_positions_tensor,
        ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
                                                       model_input))

        # Inject attn_metadata encoder/cross-attention fields &
        # encoder input tokens/positions into model_input.
        # Frozen dataclass fields cannot be modified, so use
        # dataclasses.replace to construct a new model input
        # instance.
        model_input = dataclasses.replace(
            model_input,
            attn_metadata=attn_metadata,
            encoder_input_tokens=encoder_input_tokens_tensor,
            encoder_input_positions=encoder_input_positions_tensor,
        )

        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     self.pin_memory)
        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    def profile_run(self) -> None:
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []

        model_config = self.model_config

        batch_size = 0
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

            seq_data, _ = INPUT_REGISTRY \
                .dummy_data_for_profiling(model_config, seq_len)

            # Having more tokens is over-conservative but otherwise fine
            assert len(seq_data.prompt_token_ids) >= seq_len, (
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but got: {len(seq_data.prompt_token_ids)}")

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                encoder_seq_data=seq_data,
                cross_block_table=None,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.cuda.synchronize()
        return
    def _prepare_encoder_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        model_input: EncoderDecoderModelInput,
    ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
        """Helper method to prepare the encoder- and cross-attn-related
        model inputs based on a given sequence group. These additional inputs
        are used to augment an already-computed `EncoderDecoderModelInput`
        data structure which already has decoder-related model inputs
        populated.

        Sets the following attn_metadata fields:
        * `num_encoder_tokens`
        * `encoder_seq_lens`
        * `encoder_seq_lens_tensor`
        * `max_encoder_seq_len`
        * `cross_slot_mapping`
        * `cross_block_tables`

        Constructs a new model inputs data structure, based on
        (1) the existing fields in the `model_input` argument,
        and (2) the following additional fields which are
        computed (or in the case of `attn_metadata`, updated)
        by this function:
        * attn_metadata
        * encoder_input_tokens
        * encoder_input_positions

        Arguments:

        * seq_group_metadata_list: list of sequence groups for which to
                                   compute inputs
        * model_input: model inputs data structure with decoder-oriented
                       fields already computed.

        Return:

        * Updated attention metadata, plus the encoder input tokens and
          encoder input positions tensors (both None for an empty batch).
        """

        if len(seq_group_metadata_list) == 0:
            return (model_input.attn_metadata, None, None)

        # Since we are not supporting chunked prefill, either the entire
        # batch is prefill or it is decode
        is_prompt = seq_group_metadata_list[0].is_prompt

        # Build encoder inputs
        encoder_seq_lens: List[int] = []
        if is_prompt:
            # Prefill phase.
            cross_block_tables = self._empty_int32_tensor().view(
                len(seq_group_metadata_list), -1)

            # Extract input tokens/positions, cross-attention slot-mapping,
            # & seq len from each sequence group metadata
            (
                encoder_input_tokens,
                encoder_input_positions,
                cross_slot_mapping,
            ) = (
                [],
                [],
                [],
            )
            for seq_group_metadata in seq_group_metadata_list:
                # Build seq lens
                seq_len = seq_group_metadata.encoder_seq_data.get_len()
                token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
                encoder_seq_lens.append(seq_len)

                # Build slot mapping
                is_profile_run = (seq_group_metadata.block_tables is None)
                if is_profile_run:
                    # During memory profiling, the block tables are not
                    # initialized yet. In this case, we just use a dummy
                    # slot mapping.
                    # In embeddings, the block tables are {seq_id: None}.
                    cross_slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                else:
                    for i in range(0, seq_len):
                        block_number = seq_group_metadata.cross_block_table[
                            i // self.block_size]
                        block_offset = i % self.block_size
                        slot = block_number * self.block_size + block_offset
                        cross_slot_mapping.append(slot)

                # Build encoder input tokens
                encoder_input_tokens.extend(token_ids)
                encoder_input_positions.extend(list(range(0, seq_len)))

            # Convert tokens/positions & cross-attention
            # slot-mapping to encoder input tensors
            encoder_input_tokens_tensor = self._list_to_long_tensor(
                encoder_input_tokens)
            encoder_input_positions_tensor = self._list_to_long_tensor(
                encoder_input_positions)
            cross_slot_mapping_tensor = self._list_to_long_tensor(
                cross_slot_mapping)

        else:
            # Decode phase.
            encoder_input_tokens_tensor = self._empty_long_tensor()
            encoder_input_positions_tensor = self._empty_long_tensor()
            cross_slot_mapping_tensor = self._empty_long_tensor()

            # Extract cross-attention block tables &
            # seq len from each sequence group metadata.
            # Cross-attention block tables are empty
            # during vLLM memory profiling.
            cross_block_tables = []
            for seq_group_metadata in seq_group_metadata_list:
                encoder_seq_lens.append(
                    seq_group_metadata.encoder_seq_data.get_len())
                cross_block_table = seq_group_metadata.cross_block_table
                cross_block_tables.append([] if (
                    cross_block_table is None) else cross_block_table)

            # Convert cross-attention block tables to encoder input tensor
            cross_block_tables = make_tensor_with_pad(
                cross_block_tables,
                max_len=max(
                    len(block_table) for block_table in cross_block_tables),
                pad=0,
                dtype=torch.int32,
                device=self.device,
            )

        # Compute encoder sequence lengths & encoder
        # sequence starting offset tensors
        max_encoder_seq_len = max(encoder_seq_lens, default=0)
        encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
        encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
                                            1,
                                            dtype=torch.int32,
                                            device=self.device)
        torch.cumsum(encoder_seq_lens_tensor,
                     dim=0,
                     dtype=encoder_seq_start_loc.dtype,
                     out=encoder_seq_start_loc[1:])

        # Update attention metadata with encoder-oriented attributes
        attn_metadata = model_input.attn_metadata
        assert attn_metadata is not None
        (
            attn_metadata.num_encoder_tokens,
            attn_metadata.encoder_seq_lens,
            attn_metadata.encoder_seq_lens_tensor,
            attn_metadata.max_encoder_seq_len,
            attn_metadata.cross_slot_mapping,
            attn_metadata.cross_block_tables,
        ) = (
            sum(encoder_seq_lens),
            encoder_seq_lens,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
            cross_slot_mapping_tensor,
            cross_block_tables,
        )

        return (attn_metadata, encoder_input_tokens_tensor,
                encoder_input_positions_tensor)
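To make the cross-attention slot mapping built in `_prepare_encoder_model_input_tensors` above concrete, here is a small worked example of the same arithmetic with assumed values (a block size of 16 and a two-block cross-attention block table); it is illustrative only:

# Worked example of the slot-mapping arithmetic used above
# (block_size and the block table are assumed illustrative values).
block_size = 16
cross_block_table = [7, 3]   # physical KV-cache blocks assigned to this sequence
encoder_seq_len = 20         # encoder tokens needing cross-attention KV slots

cross_slot_mapping = []
for i in range(encoder_seq_len):
    block_number = cross_block_table[i // block_size]
    block_offset = i % block_size
    cross_slot_mapping.append(block_number * block_size + block_offset)

# Tokens 0..15 land in block 7 (slots 112..127);
# tokens 16..19 land in block 3 (slots 48..51).
print(cross_slot_mapping[:3], cross_slot_mapping[16:])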
56
vllm/worker/utils.py
Normal file
@ -0,0 +1,56 @@
'''
Worker-related helper functions.
'''

from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS
from vllm.worker.model_runner import GPUModelRunnerBase


def assert_enc_dec_mr_supported_scenario(
        enc_dec_mr: GPUModelRunnerBase) -> None:
    '''
    Assert that the provided encoder/decoder model runner instance reflects
    a supported scenario.
    '''

    if enc_dec_mr.cache_config.enable_prefix_caching:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE'])

    if enc_dec_mr.sliding_window is not None:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SWA'])

    if enc_dec_mr.scheduler_config.chunked_prefill_enabled:
        raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
            'STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL'])

    if getattr(enc_dec_mr.model_config.hf_config, 'attn_logit_softcapping',
               None) is not None:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP']
        )

    if enc_dec_mr.lora_config is not None:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LORA'])

    if enc_dec_mr.parallel_config.pipeline_parallel_size > 1:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])

    if enc_dec_mr.multimodal_config is not None:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])

    if enc_dec_mr.scheduler_config.num_lookahead_slots > 0:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])

    if not enc_dec_mr.model_config.enforce_eager:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH'])

    if enc_dec_mr.prompt_adapter_config is not None:
        raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
            'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER'])
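As a rough illustration of how these guards surface to users (a hedged sketch, not a test from this commit): requesting a feature the encoder/decoder runner rejects, such as prefix caching, should fail at engine construction with the corresponding NotImplementedError.

from vllm import LLM

try:
    # enable_prefix_caching is unsupported for encoder/decoder models,
    # so assert_enc_dec_mr_supported_scenario() should reject this setup.
    llm = LLM(model="facebook/bart-large-cnn",
              enable_prefix_caching=True,
              enforce_eager=True)
except NotImplementedError as exc:
    print(f"Rejected as expected: {exc}")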
@ -19,8 +19,11 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (is_embedding_model_config,
                        is_encoder_decoder_model_config)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput

@ -85,8 +88,10 @@ class Worker(LocalOrDistributedWorkerBase):
        ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
        if model_runner_cls is not None:
            ModelRunnerClass = model_runner_cls
        elif self.model_config.embedding_mode:
        elif self._is_embedding_model():
            ModelRunnerClass = EmbeddingModelRunner
        elif self._is_encoder_decoder_model():
            ModelRunnerClass = EncoderDecoderModelRunner
        self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
            model_config,
            parallel_config,
@ -107,6 +112,12 @@ class Worker(LocalOrDistributedWorkerBase):
        # Initialize gpu_cache as embedding models don't initialize kv_caches
        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None

    def _is_encoder_decoder_model(self):
        return is_encoder_decoder_model_config(self.model_config)

    def _is_embedding_model(self):
        return is_embedding_model_config(self.model_config)

    def init_device(self) -> None:
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until