[Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) (#4942)
Co-authored-by: Andrew Feldman <afeld2012@gmail.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>

parent 660470e5a3
commit fd95e026e0
@@ -148,8 +149,9 @@ steps:
   - python3 cpu_offload.py
   - python3 offline_inference_with_prefix.py
   - python3 llm_engine_example.py
-  - python3 llava_example.py
+  - python3 offline_inference_vision_language.py
   - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
+  - python3 offline_inference_encoder_decoder.py
 
 - label: Models Test # 1hr10min
   source_file_dependencies:

@@ -289,6 +290,7 @@ steps:
   commands:
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
   - TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py
+  - pytest -v -s distributed/test_basic_distributed_correctness_enc_dec.py
   - pytest -v -s distributed/test_chunked_prefill_distributed.py
   - pytest -v -s distributed/test_multimodal_broadcast.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
examples/offline_inference_encoder_decoder.py (new file, 99 lines)

'''
Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART
'''

from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.utils import zip_enc_dec_prompt_lists

dtype = "float"

# Create a BART encoder/decoder model instance
llm = LLM(
    model="facebook/bart-large-cnn",
    dtype=dtype,
)

# Get BART tokenizer
tokenizer = llm.llm_engine.get_tokenizer_group()

# Test prompts
#
# This section shows all of the valid ways to prompt an
# encoder/decoder model.
#
# - Helpers for building prompts
text_prompt_raw = "Hello, my name is"
text_prompt = TextPrompt(prompt="The president of the United States is")
tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode(
    prompt="The capital of France is"))

# - Pass a single prompt to encoder/decoder model
#   (implicitly encoder input prompt);
#   decoder input prompt is assumed to be None
single_text_prompt_raw = text_prompt_raw  # Pass a string directly
single_text_prompt = text_prompt  # Pass a TextPrompt
single_tokens_prompt = tokens_prompt  # Pass a TokensPrompt

# - Pass explicit encoder and decoder input prompts within one data structure.
#   Encoder and decoder prompts can both independently be text or tokens, with
#   no requirement that they be the same prompt type. Some example prompt-type
#   combinations are shown below; note that these are not exhaustive.
enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
    # Pass encoder prompt string directly, &
    # pass decoder prompt tokens
    encoder_prompt=single_text_prompt_raw,
    decoder_prompt=single_tokens_prompt,
)
enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
    # Pass TextPrompt to encoder, and
    # pass decoder prompt string directly
    encoder_prompt=single_text_prompt,
    decoder_prompt=single_text_prompt_raw,
)
enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
    # Pass encoder prompt tokens directly, and
    # pass TextPrompt to decoder
    encoder_prompt=single_tokens_prompt,
    decoder_prompt=single_text_prompt,
)

# - Finally, here's a useful helper function for zipping encoder and
#   decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt
#   instances
zipped_prompt_list = zip_enc_dec_prompt_lists(
    ['An encoder prompt', 'Another encoder prompt'],
    ['A decoder prompt', 'Another decoder prompt'])

# - Let's put all of the above example prompts together into one list
#   which we will pass to the encoder/decoder LLM.
prompts = [
    single_text_prompt_raw, single_text_prompt, single_tokens_prompt,
    enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3
] + zipped_prompt_list

print(prompts)

# Create a sampling params object.
sampling_params = SamplingParams(
    temperature=0,
    top_p=1.0,
    min_tokens=0,
    max_tokens=20,
)

# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    encoder_prompt = output.encoder_prompt
    generated_text = output.outputs[0].text
    print(f"Encoder prompt: {encoder_prompt!r}, "
          f"Decoder prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")
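For reference, the zip_enc_dec_prompt_lists helper used above is expected to simply pair same-index encoder and decoder prompts into ExplicitEncoderDecoderPrompt instances; a minimal sketch of that behavior (an illustration, not vLLM's actual implementation) is:

from typing import List, Optional

from vllm.inputs import ExplicitEncoderDecoderPrompt


def zip_enc_dec_prompt_lists_sketch(
        enc_prompts: List[str],
        dec_prompts: List[Optional[str]],
) -> List[ExplicitEncoderDecoderPrompt]:
    # Pair each encoder prompt with the decoder prompt at the same index.
    return [
        ExplicitEncoderDecoderPrompt(encoder_prompt=enc, decoder_prompt=dec)
        for enc, dec in zip(enc_prompts, dec_prompts)
    ]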
@@ -10,9 +10,11 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
-                          AutoTokenizer, BatchEncoding, BatchFeature)
+from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
+                          AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
+                          BatchFeature)
 
+from tests.models.utils import DecoderPromptType
 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.config import TokenizerPoolConfig

@@ -21,9 +23,11 @@ from vllm.distributed import (destroy_distributed_environment,
                               destroy_model_parallel)
 from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
+from vllm.outputs import RequestOutput
 from vllm.sequence import SampleLogprobs
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
-                        is_cpu)
+                        is_cpu, to_enc_dec_tuple_list,
+                        zip_enc_dec_prompt_lists)
 
 logger = init_logger(__name__)
 
@@ -120,6 +124,40 @@ def example_prompts() -> List[str]:
     return prompts
 
 
+@pytest.fixture
+def example_encoder_decoder_prompts() \
+    -> Dict[DecoderPromptType,
+            Tuple[List[str], List[Optional[str]]]]:
+    '''
+    Returns an encoder prompt list and a decoder prompt list, wherein each pair
+    of same-index entries in both lists corresponds to an (encoder prompt,
+    decoder prompt) tuple.
+
+    Returns:
+
+    * Dictionary of zipped (encoder prompt, decoder prompt) lists, keyed by
+      decoder prompt type; the custom decoder prompt list is the reverse of
+      the encoder prompt list
+    '''
+
+    encoder_prompts = []
+    for filename in _TEST_PROMPTS:
+        encoder_prompts += _read_prompts(filename)
+
+    custom_decoder_prompts = encoder_prompts[::-1]
+    empty_str_decoder_prompts = [""] * len(encoder_prompts)
+    none_decoder_prompts = [None] * len(encoder_prompts)
+
+    # NONE decoder prompt type
+    return {
+        DecoderPromptType.NONE:
+        zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts),
+        DecoderPromptType.EMPTY_STR:
+        zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts),
+        DecoderPromptType.CUSTOM:
+        zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts),
+    }
+
+
 @pytest.fixture
 def example_long_prompts() -> List[str]:
     prompts = []
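The DecoderPromptType enum used to key this fixture lives in tests/models/utils.py; its definition is not shown in this diff, but given the members referenced here it is presumably along these lines (a sketch under that assumption):

from enum import Enum


class DecoderPromptType(Enum):
    # How the decoder prompt accompanies each encoder prompt in the fixture
    # (member values here are illustrative; the real definition may differ).
    CUSTOM = 1     # an explicit, non-empty decoder prompt
    NONE = 2       # decoder prompt omitted (None)
    EMPTY_STR = 3  # decoder prompt given as an empty string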
@@ -152,6 +190,7 @@ class HfRunner:
         model_kwargs: Optional[Dict[str, Any]] = None,
         is_embedding_model: bool = False,
         is_vision_model: bool = False,
+        is_encoder_decoder_model: bool = False,
     ) -> None:
         torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
 

@@ -168,6 +207,8 @@ class HfRunner:
         else:
             if is_vision_model:
                 auto_cls = AutoModelForVision2Seq
+            elif is_encoder_decoder_model:
+                auto_cls = AutoModelForSeq2SeqLM
             else:
                 auto_cls = AutoModelForCausalLM
 

@@ -314,6 +355,44 @@ class HfRunner:
             all_logprobs.append(seq_logprobs)
         return all_logprobs
 
+    def _hidden_states_to_logprobs(
+        self,
+        hidden_states,
+        num_logprobs,
+    ) -> Tuple[List[Dict[int, float]], int]:
+        seq_logprobs: List[torch.Tensor] = []
+        output_len = len(hidden_states)
+        for _, hidden_state in enumerate(hidden_states):
+            last_hidden_states = hidden_state[-1][0]
+            logits = torch.matmul(
+                last_hidden_states,
+                self.model.get_output_embeddings().weight.t(),
+            )
+            if getattr(self.model.get_output_embeddings(), "bias",
+                       None) is not None:
+                logits += self.model.get_output_embeddings().bias.unsqueeze(0)
+            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+            seq_logprobs.append(logprobs)
+
+        # convert to dict
+        seq_logprobs_lst: List[Dict[int, float]] = []
+        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
+            # drop prompt logprobs
+            if tok_idx == 0:
+                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
+            topk = tok_logprobs.topk(num_logprobs)
+
+            tok_logprobs_dct = {}
+            for token_id, logprob in zip(topk.indices[0], topk.values[0]):
+                tok_logprobs_dct[token_id.item()] = logprob.item()
+
+            seq_logprobs_lst.append(tok_logprobs_dct)
+
+        return (
+            seq_logprobs_lst,
+            output_len,
+        )
+
     def generate_greedy_logprobs_limit(
         self,
         prompts: List[str],
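The conversion performed by _hidden_states_to_logprobs above (project the last-layer hidden state through the output embedding, take a log-softmax, then keep the top-k entries as a {token_id: logprob} dict) can be reproduced on dummy tensors; the shapes and values below are made up purely for illustration:

import torch
import torch.nn.functional as F

hidden_size, vocab_size, num_logprobs = 8, 32, 5
last_hidden_state = torch.randn(1, hidden_size)          # one generated position
output_embedding = torch.randn(vocab_size, hidden_size)  # LM head weight (often tied)

logits = last_hidden_state @ output_embedding.t()        # shape (1, vocab_size)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
topk = logprobs.topk(num_logprobs)
logprob_dict = {
    token_id.item(): logprob.item()
    for token_id, logprob in zip(topk.indices[0], topk.values[0])
}
print(logprob_dict)  # e.g. {17: -2.1, 3: -2.4, ...}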
@@ -346,33 +425,11 @@ class HfRunner:
                 **kwargs,
             )
 
-            seq_logprobs: List[torch.Tensor] = []
-            for _, hidden_states in enumerate(output.hidden_states):
-                last_hidden_states = hidden_states[-1][0]
-                logits = torch.matmul(
-                    last_hidden_states,
-                    self.model.get_output_embeddings().weight.t(),
-                )
-                if getattr(self.model.get_output_embeddings(), "bias",
-                           None) is not None:
-                    logits += self.model.get_output_embeddings(
-                    ).bias.unsqueeze(0)
-                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
-                seq_logprobs.append(logprobs)
-
-            # convert to dict
-            seq_logprobs_lst: List[Dict[int, float]] = []
-            for tok_idx, tok_logprobs in enumerate(seq_logprobs):
-                # drop prompt logprobs
-                if tok_idx == 0:
-                    tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
-                topk = tok_logprobs.topk(num_logprobs)
-
-                tok_logprobs_dct = {}
-                for token_id, logprob in zip(topk.indices[0], topk.values[0]):
-                    tok_logprobs_dct[token_id.item()] = logprob.item()
-
-                seq_logprobs_lst.append(tok_logprobs_dct)
-
+            (
+                seq_logprobs_lst,
+                output_len,
+            ) = self._hidden_states_to_logprobs(output.hidden_states,
+                                                num_logprobs)
+
             all_logprobs.append(seq_logprobs_lst)
             seq_ids = output.sequences[0]

@@ -385,6 +442,57 @@ class HfRunner:
         return [(output_ids, output_str, output_logprobs)
                 for output_ids, output_str, output_logprobs in outputs]
 
+    def generate_encoder_decoder_greedy_logprobs_limit(
+        self,
+        encoder_decoder_prompts: Tuple[List[str], List[str]],
+        max_tokens: int,
+        num_logprobs: int,
+        **kwargs: Any,
+    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+        '''
+        Greedy logprobs generation for vLLM encoder/decoder models
+        '''
+
+        all_logprobs: List[List[Dict[int, float]]] = []
+        all_output_ids: List[List[int]] = []
+        all_output_strs: List[str] = []
+
+        for (encoder_prompt,
+             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
+            encoder_input_ids = self.wrap_device(
+                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids)
+            decoder_input_ids = (
+                None if decoder_prompt is None else self.wrap_device(
+                    self.tokenizer(decoder_prompt,
+                                   return_tensors="pt").input_ids))
+
+            output = self.model.generate(
+                encoder_input_ids,
+                decoder_input_ids=decoder_input_ids,
+                use_cache=True,
+                do_sample=False,
+                max_new_tokens=max_tokens,
+                output_hidden_states=True,
+                return_dict_in_generate=True,
+                **kwargs,
+            )
+
+            (
+                seq_logprobs_lst,
+                output_len,
+            ) = self._hidden_states_to_logprobs(output.decoder_hidden_states,
+                                                num_logprobs)
+
+            all_logprobs.append(seq_logprobs_lst)
+            seq_ids = output.sequences[0]
+            output_ids = seq_ids[-output_len:]
+            all_output_ids.append(output_ids.tolist())
+            all_output_strs.append(self.tokenizer.decode(output_ids))
+
+        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
+        return [(output_ids, output_str, output_logprobs)
+                for output_ids, output_str, output_logprobs in outputs]
+
     def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
         return self.model.encode(prompts)
 
@@ -416,7 +524,7 @@ class VllmRunner:
         block_size: int = 16,
         enable_chunked_prefill: bool = False,
         swap_space: int = 4,
-        enforce_eager: bool = False,
+        enforce_eager: Optional[bool] = False,
         **kwargs,
     ) -> None:
         self.model = LLM(

@@ -465,6 +573,19 @@ class VllmRunner:
             outputs.append((req_sample_output_ids, req_sample_output_strs))
         return outputs
 
+    def _final_steps_generate_w_logprobs(
+        self,
+        req_outputs: List[RequestOutput],
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
+        for req_output in req_outputs:
+            for sample in req_output.outputs:
+                output_str = sample.text
+                output_ids = sample.token_ids
+                output_logprobs = sample.logprobs
+                outputs.append((output_ids, output_str, output_logprobs))
+        return outputs
+
     def generate_w_logprobs(
         self,
         prompts: List[str],

@@ -483,14 +604,21 @@ class VllmRunner:
 
         req_outputs = self.model.generate(inputs,
                                           sampling_params=sampling_params)
-        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
-        for req_output in req_outputs:
-            for sample in req_output.outputs:
-                output_str = sample.text
-                output_ids = sample.token_ids
-                output_logprobs = sample.logprobs
-                outputs.append((output_ids, output_str, output_logprobs))
-        return outputs
+        return self._final_steps_generate_w_logprobs(req_outputs)
+
+    def generate_encoder_decoder_w_logprobs(
+        self,
+        encoder_decoder_prompts: Tuple[List[str], List[str]],
+        sampling_params: SamplingParams,
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+        '''
+        Logprobs generation for vLLM encoder/decoder models
+        '''
+
+        assert sampling_params.logprobs is not None
+        req_outputs = self.model.generate(encoder_decoder_prompts,
+                                          sampling_params=sampling_params)
+        return self._final_steps_generate_w_logprobs(req_outputs)
 
     def generate_greedy(
         self,

@@ -523,6 +651,26 @@ class VllmRunner:
         return [(output_ids, output_str, output_logprobs)
                 for output_ids, output_str, output_logprobs in outputs]
 
+    def generate_encoder_decoder_greedy_logprobs(
+        self,
+        encoder_decoder_prompts: Tuple[List[str], List[str]],
+        max_tokens: int,
+        num_logprobs: int,
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+        greedy_logprobs_params = SamplingParams(temperature=0.0,
+                                                use_beam_search=False,
+                                                max_tokens=max_tokens,
+                                                logprobs=num_logprobs)
+        '''
+        Greedy logprobs generation for vLLM encoder/decoder models
+        '''
+
+        outputs = self.generate_encoder_decoder_w_logprobs(
+            encoder_decoder_prompts, greedy_logprobs_params)
+
+        return [(output_ids, output_str, output_logprobs)
+                for output_ids, output_str, output_logprobs in outputs]
+
     def generate_beam_search(
         self,
         prompts: List[str],
@@ -9,33 +9,11 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.core.interfaces import AllocStatus
 from vllm.core.scheduler import Scheduler, SchedulingBudget
 from vllm.lora.request import LoRARequest
-from vllm.sequence import Logprob, SequenceGroup, SequenceStatus
+from vllm.sequence import SequenceGroup, SequenceStatus
 
-from .utils import create_dummy_prompt
-
-
-def get_sequence_groups(scheduler_output):
-    return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
-
-
-def append_new_token(out, token_id: int):
-    seq_groups = get_sequence_groups(out)
-    for seq_group in seq_groups:
-        for seq in seq_group.get_seqs():
-            seq.append_token_id(token_id, {token_id: Logprob(token_id)})
-
-
-def schedule_and_update_computed_tokens(scheduler):
-    metas, out = scheduler.schedule()
-    for s, meta in zip(out.scheduled_seq_groups, metas):
-        s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
-    return metas, out
-
-
-def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
-    seq_group.update_num_computed_tokens(token_chunk_size)
-    for seq in seq_group.get_seqs():
-        seq.append_token_id(token_id, {token_id: Logprob(token_id)})
+from .utils import (append_new_token, append_new_token_seq_group,
+                    create_dummy_prompt, get_sequence_groups,
+                    schedule_and_update_computed_tokens)
 
 
 def test_scheduler_add_seq_group():
tests/core/test_scheduler_encoder_decoder.py (new file, 99 lines)

from typing import List

import pytest  # noqa

from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.sequence import SequenceGroup

from .utils import (append_new_token, create_dummy_prompt_encoder_decoder,
                    get_sequence_groups, schedule_and_update_computed_tokens)


def test_scheduler_schedule_simple_encoder_decoder():
    '''
    Test basic scheduler functionality in the context
    of an encoder/decoder model. Focus on testing
    enc/dec-specific functionality, since tests already
    exist for decoder-only functionality.

    Test behavior:

    * Construct Scheduler
    * Construct dummy encoder/decoder sequence groups
    * Add dummy seq groups to scheduler backlog
    * Schedule the next seq group & validate:
        * Cross-attn block tables
        * Updated states of seq groups
        * Number of batched tokens
        * Number of blocks to copy/swap-in/swap-out
        * Number of scheduled seq groups
    * Repeat for both prefill- and decode-phase
    * Abort scheduled seq groups
    * Assert that aborted seq groups no longer appear in
      cross-attention block table
    '''

    block_size = 4
    num_seq_group = 4
    max_model_len = 16
    scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 16  # enc and dec prompts per seq_group
    cache_config.num_gpu_blocks = 16  # enc and dec prompts per seq_group
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    req_id_list = []
    for i in range(num_seq_group):
        req_id = str(i)
        req_id_list.append(req_id)
        _, _, seq_group = create_dummy_prompt_encoder_decoder(
            req_id, block_size, block_size, block_size)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Schedule seq groups prefill.
    num_tokens = block_size * num_seq_group
    seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler)
    # - Verify that sequence group cross-attention block tables are
    #   registered with the block manager
    assert all([(req_id in scheduler.block_manager.cross_block_tables)
                for req_id in req_id_list])
    # - Validate sequence-group status
    assert set(get_sequence_groups(out)) == set(running)
    # - Validate number of batched tokens
    assert out.num_batched_tokens == num_tokens
    # - Validate there are no remaining blocks to swap
    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
            and not out.blocks_to_swap_out)
    # - Validate all seq groups were scheduled
    assert len(seq_group_meta_list) == num_seq_group
    append_new_token(out, 1)

    # Schedule seq groups decode.
    seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler)
    # - Verify that sequence group metadata includes encoder attention
    #   and cross-attention metadata
    assert all([
        not ((seq_group_meta.encoder_seq_data is None) or
             (seq_group_meta.cross_block_table is None))
        for seq_group_meta in seq_group_meta_list
    ])
    # - Validate sequence-group status
    assert set(get_sequence_groups(out)) == set(running)
    # - Validate there is one batched token per seq group
    assert out.num_batched_tokens == num_seq_group
    # - Validate there are no remaining blocks to swap
    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
            and not out.blocks_to_swap_out)
    # - Validate that all seq groups were scheduled
    assert len(seq_group_meta_list) == num_seq_group
    append_new_token(out, 1)

    # Abort sequences
    for req_id in req_id_list:
        scheduler.abort_seq_group(req_id)
        # - Verify that sequence group cross-attention block tables are
        #   NO LONGER registered with the block manager
        assert req_id not in scheduler.block_manager.cross_block_tables
@@ -53,27 +53,30 @@ def create_dummy_prompt_encoder_decoder(
         block_size = decoder_prompt_length
 
     # Create dummy prompt sequence with tokens 0...block_size-1
-    # and prompt "0 ... block_size".
+    # and prompt "0 ... block_size". Note that the prompt string
+    # doesn't actually match the tokens
     decoder_prompt_tokens = list(range(decoder_prompt_length))
     decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens])
+    encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
+    encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])
 
-    decoder_prompt = Sequence(int(request_id),
-                              inputs={
-                                  "prompt": decoder_prompt_str,
-                                  "prompt_token_ids": decoder_prompt_tokens,
-                                  "multi_modal_data": None,
-                              },
-                              block_size=block_size)
+    inputs = {
+        "prompt": decoder_prompt_str,
+        "prompt_token_ids": decoder_prompt_tokens,
+        "encoder_prompt": encoder_prompt_str,
+        "encoder_prompt_token_ids": encoder_prompt_tokens,
+        "multi_modal_data": None,
+    }
+
+    decoder_prompt = Sequence(int(request_id),
+                              inputs=inputs,
+                              block_size=block_size,
+                              from_decoder_prompt=True)
 
-    encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
-    encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])
     encoder_prompt = Sequence(int(request_id),
-                              inputs={
-                                  "prompt": encoder_prompt_str,
-                                  "prompt_token_ids": encoder_prompt_tokens,
-                                  "multi_modal_data": None,
-                              },
-                              block_size=block_size)
+                              inputs=inputs,
+                              block_size=block_size,
+                              from_decoder_prompt=False)
     seq_group = SequenceGroup(request_id=request_id,
                               seqs=[decoder_prompt],
                               sampling_params=SamplingParams(
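Both sequences above now share one inputs dict, and the new from_decoder_prompt flag presumably tells Sequence which pair of fields to read from it; a hypothetical helper capturing that assumed behavior (not vLLM's actual implementation):

from typing import Any, Dict, List, Tuple


def select_prompt_fields(inputs: Dict[str, Any],
                         from_decoder_prompt: bool) -> Tuple[str, List[int]]:
    # Assumed semantics of Sequence(..., from_decoder_prompt=...):
    # decoder sequences read the regular prompt fields, while encoder
    # sequences read the encoder_* fields of the same dict.
    if from_decoder_prompt:
        return inputs["prompt"], inputs["prompt_token_ids"]
    return inputs["encoder_prompt"], inputs["encoder_prompt_token_ids"]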
@@ -139,17 +142,21 @@ def create_seq_group_encoder_decoder(
 
     prompt_token_ids = [0] * seq_prompt_len
 
-    seqs = []
-    for seq_id_offset, output_len in enumerate(seq_output_lens):
-        seq = Sequence(
-            seq_id=seq_id_start + seq_id_offset,
-            inputs={
-                "prompt": "",
-                "prompt_token_ids": prompt_token_ids,
-                "multi_modal_data": None,
-            },
-            block_size=16,
-        )
+    inputs = {
+        "prompt": "",
+        "prompt_token_ids": prompt_token_ids,
+        "encoder_prompt": "",
+        "encoder_prompt_token_ids": prompt_token_ids,
+        "multi_modal_data": None,
+    }
+
+    seqs = []
+    for seq_id_offset, output_len in enumerate(seq_output_lens):
+        # Construct decoder input sequences
+        seq = Sequence(seq_id=seq_id_start + seq_id_offset,
+                       inputs=inputs,
+                       block_size=16,
+                       from_decoder_prompt=True)
 
         for i in range(output_len):
             seq.append_token_id(

@@ -158,16 +165,11 @@ def create_seq_group_encoder_decoder(
             )
         seqs.append(seq)
 
-    # Encoder sequence
-    encoder_seq = Sequence(
-        seq_id=seq_id_start + len(seq_output_lens),
-        inputs={
-            "prompt": "",
-            "prompt_token_ids": prompt_token_ids,
-            "multi_modal_data": None,
-        },
-        block_size=16,
-    )
+    # Encoder input sequence
+    encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens),
+                           inputs=inputs,
+                           block_size=16,
+                           from_decoder_prompt=False)
 
     return SequenceGroup(request_id=request_id,
                          seqs=seqs,

@@ -178,3 +180,30 @@ def create_seq_group_encoder_decoder(
 
 def round_up_to_next_block(seq_len: int, block_size: int) -> int:
     return (seq_len + block_size - 1) // block_size
+
+
+# Helper functions for scheduler tests
+
+
+def get_sequence_groups(scheduler_output):
+    return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
+
+
+def append_new_token(out, token_id: int):
+    seq_groups = get_sequence_groups(out)
+    for seq_group in seq_groups:
+        for seq in seq_group.get_seqs():
+            seq.append_token_id(token_id, {token_id: Logprob(token_id)})
+
+
+def schedule_and_update_computed_tokens(scheduler):
+    metas, out = scheduler.schedule()
+    for s, meta in zip(out.scheduled_seq_groups, metas):
+        s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
+    return metas, out
+
+
+def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
+    seq_group.update_num_computed_tokens(token_chunk_size)
+    for seq in seq_group.get_seqs():
+        seq.append_token_id(token_id, {token_id: Logprob(token_id)})
tests/distributed/test_basic_distributed_correctness_enc_dec.py (new file, 101 lines)

"""For encoder/decoder models only:
Compare the outputs of HF and distributed vLLM when using greedy sampling.

Run:
```sh
cd $VLLM_PATH/tests

pytest distributed/test_basic_distributed_correctness_enc_dec.py
```
"""

import pytest

from tests.models.utils import DecoderPromptType
from vllm.utils import cuda_device_count_stateless

from ..models.utils import check_logprobs_close
from ..utils import fork_new_process_for_each_test


@pytest.mark.skipif(cuda_device_count_stateless() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model, distributed_executor_backend", [
    ("facebook/bart-large-cnn", "ray"),
    ("facebook/bart-large-cnn", "mp"),
])
@fork_new_process_for_each_test
def test_models(
    model: str,
    distributed_executor_backend: str,
    hf_runner,
    vllm_runner,
    example_encoder_decoder_prompts,
) -> None:
    '''
    Test vLLM BART inference on more than one GPU, comparing
    outputs against HF as a baseline.

    Fork a new process for each test, to prevent CUDA from
    being re-initialized by successive tests within the same
    process.

    Arguments:

    * model: the HF ID of the specific BART variant under test
    * distributed_executor_backend
    * hf_runner: HuggingFace (HF) test model runner
    * vllm_runner: vLLM test model runner
    * example_encoder_decoder_prompts: test fixture which provides a
      dictionary of dummy prompts
    '''

    dtype = "float"
    max_tokens = 64
    num_logprobs = 5

    # Example inputs with non-trivial (i.e. not None/empty) encoder &
    # decoder prompts.
    test_prompts = example_encoder_decoder_prompts[DecoderPromptType.CUSTOM]

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    with vllm_runner(
            model,
            dtype=dtype,
            tensor_parallel_size=2,
            distributed_executor_backend=distributed_executor_backend,
            enforce_eager=True,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
            test_prompts, max_tokens, num_logprobs)

    # Configuration settings for HF baseline
    hf_kwargs = {
        "top_k": None,
        "num_beams": 1,
        "repetition_penalty": 1.0,
        "top_p": 1.0,
        "length_penalty": 1.0,
        "early_stopping": False,
        "no_repeat_ngram_size": None,
        "min_length": 0
    }

    with hf_runner(model, dtype=dtype,
                   is_encoder_decoder_model=True) as hf_model:
        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
            test_prompts,
            max_tokens,
            num_logprobs,
            **hf_kwargs,
        ))

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
@@ -3,9 +3,9 @@ from unittest.mock import patch
 import pytest
 import torch
 
-from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL,
-                                 override_backend_env_variable)
+from tests.kernels.utils import override_backend_env_variable
 from vllm.attention.selector import which_attn_to_use
+from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
 
 
 @pytest.mark.parametrize(
@@ -4,8 +4,6 @@ Tests:
 * E2E test of Encoder attention + Decoder self-attention +
   Encoder/decoder cross-attention (collectively
   "encoder/decoder attention")
-* Confirm enc/dec models will fail for chunked prefill
-* Confirm enc/dec models will fail for prefix caching
-
 """
 

@@ -15,19 +13,22 @@
 import torch
 
 from tests.kernels.utils import *
-from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor
-from vllm.attention import Attention, AttentionMetadata
-from vllm.attention.backends.abstract import AttentionBackend, AttentionType
+from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
+                            AttentionType)
 from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
+from vllm.attention.selector import (_Backend,
+                                     global_force_attn_backend_context_manager)
 from vllm.utils import is_hip
 
+# List of supported backends for encoder/decoder models
+LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
+
 HEAD_SIZES = [64, 256]
 
 NUM_HEADS = [1, 16]
 
 BATCH_SIZES = [1, 16]
 BLOCK_SIZES = [16]
-BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL]
 CUDA_DEVICE = "cuda:0"
 
 MAX_DEC_SEQ_LENS = [128]
@@ -724,23 +725,58 @@ def _run_encoder_decoder_cross_attention_test(
 @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
+@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
 @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
-def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
-                      batch_size: int, block_size: int, max_dec_seq_len: int,
-                      max_enc_seq_len: int, monkeypatch):
+def test_encoder_only(
+    num_heads: int,
+    head_size: int,
+    attn_backend: _Backend,
+    batch_size: int,
+    block_size: int,
+    max_dec_seq_len: int,
+    max_enc_seq_len: int,
+):
+    '''
+    End-to-end encoder-only attention test:
+
+    * Construct fake test vectors for (1) encoder attention
+    * Construct (1) attention metadata structure with prefill-phase
+      encoder attention, and (2) an analogous attention metadata
+      structure but for decode-phase
+    * Test & validate encoder attention against ideal output
+
+    No KV cache is required for encoder-only attention.
+
+    Note on ROCm/HIP: currently encoder/decoder models are not supported on
+    AMD GPUs, therefore this test simply is skipped if is_hip().
+
+    This test globally forces an override of the usual backend
+    auto-selection process, forcing the specific backend-under-test
+    to be utilized.
+
+    Arguments:
+
+    * num_heads
+    * head_size
+    * attn_backend: The attention backend to employ for testing
+    * batch_size
+    * block_size: KV cache block size
+    * max_dec_seq_len: max length of decoder input sequences
+    * max_enc_seq_len: max length of encoder input sequences
+    '''
+
     # Force Attention wrapper backend
-    override_backend_env_variable(monkeypatch, backend_name)
+    with global_force_attn_backend_context_manager(attn_backend):
 
-    # Note: KV cache size of 4096 is arbitrary & chosen intentionally
-    # to be more than necessary, since exceeding the kv cache size
-    # is not part of this test
-    test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
-                        block_size, max_dec_seq_len, max_enc_seq_len, 4096)
+        # Note: KV cache size of 4096 is arbitrary & chosen intentionally
+        # to be more than necessary, since exceeding the kv cache size
+        # is not part of this test
+        test_pt = TestPoint(num_heads, head_size, attn_backend.name,
+                            batch_size, block_size, max_dec_seq_len,
+                            max_enc_seq_len, 4096)
 
         # Attention scale factor, attention backend instance, attention wrapper
         # instance, KV cache init
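global_force_attn_backend_context_manager is used here to pin the backend for the duration of the test and then restore the previous state; its implementation is not part of this diff, but the general shape of such a context manager is roughly as follows (a sketch, with a made-up module-level variable):

from contextlib import contextmanager

_FORCED_ATTN_BACKEND = None  # hypothetical global override slot


@contextmanager
def force_attn_backend(backend):
    # Temporarily force a specific attention backend, then restore the
    # previous override when the with-block exits (even on error).
    global _FORCED_ATTN_BACKEND
    prior, _FORCED_ATTN_BACKEND = _FORCED_ATTN_BACKEND, backend
    try:
        yield
    finally:
        _FORCED_ATTN_BACKEND = prior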
@@ -774,7 +810,7 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
 @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
+@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)

@@ -782,12 +818,11 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
 def test_e2e_enc_dec_attn(
     num_heads: int,
     head_size: int,
-    backend_name: str,
+    attn_backend: _Backend,
     batch_size: int,
     block_size: int,
     max_dec_seq_len: int,
     max_enc_seq_len: int,
-    monkeypatch,
 ) -> None:
     '''
     End-to-end encoder/decoder test:

@@ -820,8 +855,9 @@ def test_e2e_enc_dec_attn(
     cross-attention K/Vs are allowed to differ in seq len, as is often the case
     for cross-attention.
 
-    This test utilizes PyTest monkey patching to force the attention backend
-    via an environment variable.
+    This test globally forces an override of the usual backend
+    auto-selection process, forcing the specific backend-under-test
+    to be utilized.
 
     Note on ROCm/HIP: currently encoder/decoder models are not supported on
     AMD GPUs, therefore this test simply is skipped if is_hip().

@@ -830,23 +866,34 @@
     all prefill-phase attention operations (encoder, decoder, enc/dec cross),
     and a single one shared by all decode-phase attention operations
     (decoder & enc/dec cross.) This is intended to reflect the behavior
-    of ModelRunner, which constructs a single attention metadata structure for
-    each prefill or decode run. A realistic scenario would rely on the
-    attention backend to utilize the appropriate attention metadata fields
-    according to the value of attn_metadata.attention_type. Thus, this test is
-    organized so as to confirm that the backend-under-test can handle a
-    shared prefill attention metadata structure & a shared decode attention
-    metadata structure.
+    of EncoderDecoderModelRunner, which constructs a single attention metadata
+    structure for each prefill or decode run. A realistic scenario would rely
+    on the attention backend to utilize the appropriate attention metadata
+    fields according to the value of attn_metadata.attention_type. Thus,
+    this test is organized so as to confirm that the backend-under-test can
+    handle a shared prefill attention metadata structure & a shared decode
+    attention metadata structure.
+
+    Arguments:
+
+    * num_heads
+    * head_size
+    * attn_backend: The attention backend to employ for testing
+    * batch_size
+    * block_size: KV cache block size
+    * max_dec_seq_len: max length of decoder input sequences
+    * max_enc_seq_len: max length of encoder input sequences
     '''
 
     # Force Attention wrapper backend
-    override_backend_env_variable(monkeypatch, backend_name)
+    with global_force_attn_backend_context_manager(attn_backend):
 
-    # Note: KV cache size of 4096 is arbitrary & chosen intentionally
-    # to be more than necessary, since exceeding the kv cache size
-    # is not part of this test
-    test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
-                        block_size, max_dec_seq_len, max_enc_seq_len, 4096)
+        # Note: KV cache size of 4096 is arbitrary & chosen intentionally
+        # to be more than necessary, since exceeding the kv cache size
+        # is not part of this test
+        test_pt = TestPoint(num_heads, head_size, attn_backend.name,
+                            batch_size, block_size, max_dec_seq_len,
+                            max_enc_seq_len, 4096)
 
         # Attention scale factor, attention backend instance, attention wrapper
         # instance, KV cache init
@@ -870,8 +917,9 @@ def test_e2e_enc_dec_attn(
             cross_block_base_addr,
         ) = _decoder_attn_setup(test_pt, test_rsrcs)
 
-    # Construct encoder/decoder cross-attention prefill-phase & decode-phase
-    # test params, including key/value tensors, cross-attention memory-mapping
+        # Construct encoder/decoder cross-attention prefill-phase
+        # & decode-phase test params, including key/value tensors,
+        # cross-attention memory-mapping
 
         (
             prephase_cross_test_params,

@@ -211,5 +211,5 @@ def test_varlen_with_paged_kv(
         sliding_window=sliding_window,
         soft_cap=soft_cap,
     )
-    assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
+    assert torch.allclose(output, ref_output, atol=2e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
@@ -8,24 +8,10 @@ from typing import Any, List, NamedTuple, Optional, Tuple, Union
 import pytest
 import torch
 
-from vllm.attention.backends.abstract import (AttentionBackend,
-                                              AttentionMetadata, AttentionType)
+from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
 from vllm.attention.backends.xformers import XFormersBackend
-from vllm.utils import make_tensor_with_pad
-
-# String name of register which may be set in order to
-# force auto-selection of attention backend by Attention
-# wrapper
-STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
-
-# Possible string values of STR_BACKEND_ENV_VAR
-# register, corresponding to possible backends
-STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
-STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
-STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
-STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
-STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
-STR_INVALID_VAL: str = "INVALID"
+from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
+                        make_tensor_with_pad)
 
 
 class QKVInputs(NamedTuple):
tests/models/test_bart.py (new file, 153 lines)

"""Compare the outputs of HF and vLLM for BART models using greedy sampling.

Run `pytest tests/models/test_bart.py`.
"""
from vllm.utils import is_cpu

if not is_cpu():
    # CPU backend is not currently supported with encoder/decoder models
    # skip test definitions entirely to avoid importing GPU kernel libs
    # (xFormers, etc.)

    import pytest

    from tests.models.utils import DecoderPromptType

    from .utils import check_logprobs_close

    MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]

    DECODER_PROMPT_TYPES = ([
        DecoderPromptType.CUSTOM, DecoderPromptType.EMPTY_STR,
        DecoderPromptType.NONE
    ])

    @pytest.mark.parametrize("model", MODELS)
    @pytest.mark.parametrize("dtype", ["float", "bfloat16"])
    @pytest.mark.parametrize("max_tokens", [64])
    @pytest.mark.parametrize("num_logprobs", [5])
    @pytest.mark.parametrize("decoder_prompt_type", DECODER_PROMPT_TYPES)
    def test_models(
        hf_runner,
        vllm_runner,
        example_encoder_decoder_prompts,
        model: str,
        dtype: str,
        max_tokens: int,
        num_logprobs: int,
        decoder_prompt_type: DecoderPromptType,
    ) -> None:
        '''
        Test the vLLM BART model for a variety of encoder/decoder input
        prompts, by validating it against HuggingFace (HF) BART.

        Arguments:

        * hf_runner: HuggingFace (HF) test model runner
        * vllm_runner: vLLM test model runner
        * example_encoder_decoder_prompts: test fixture which provides a
          dictionary of dummy prompts
        * model: the HF ID of the specific BART variant under test
        * dtype: the tensor datatype to employ
        * max_tokens
        * num_logprobs
        * decoder_prompt_type: key into the example_encoder_decoder_prompts
          dictionary; selects specific encoder/decoder prompt scenarios to
          test

        A note on using HF BART as a baseline for validating vLLM BART,
        specifically when the decoder prompt is None.

        The HF GenerationMixin's default behavior is to force the first
        decoded token to be <BOS> if the prompt does not already contain
        <BOS> (this is accomplished using a logit processor setting.)

        So when we use HF BART as our baseline for comparison, note that
        when the user provides a request with a None decoder prompt
        (i.e. a singleton encoder prompt, or else an explicit encoder/
        decoder prompt with the decoder sub-prompt set to None), HF and
        vLLM handle this in different ways:

        * HF will (1) tokenize the None prompt as an empty token-list,
          (2) append <decoder-start-token> to the beginning, yielding
          [<decoder-start-token>], (3) pass this token list to the model, and
          then (4) after computing logits during prefill, override the model
          logits & force <BOS> to be the first generated token.

        * vLLM will (1) tokenize the None prompt as [<BOS>], (2) append
          decoder-start-token to the beginning, yielding
          [<decoder-start-token><BOS>], (3) pass these tokens to the model &
          proceed with generation.

        The net effect is that compared to vLLM, the list of HF *decoded*
        tokens will contain one more initial <BOS> than the vLLM generated
        tokens, because vLLM's <BOS> token is injected into the prompt rather
        than into the generated output. This is in spite of the fact that
        overall, the complete sequences (prompt + decoded tokens) produced by
        vLLM will match HF.

        So when we use HF decoded token output to validate vLLM's decoded
        token output, the testing process must account for the difference in
        decoded token sequences between vLLM and HF specifically in the
        decoder-prompt-is-None case.

        One option is to disable the logit processor feature that forces the
        <BOS> token to be decoded (forced_bos_token_id = None), eliminating
        the problem entirely. However this is not "normal" BART usage.

        The other option is - only in the decoder-prompt-is-None case - to
        discard the first decoded token from the HF output before comparing
        it to vLLM.

        To that end, when testing the scenario where the decoder prompt is
        None (and only in that one scenario), this test skips the first HF
        decoded token during the process of validating the vLLM decoded
        output.
        '''

        test_case_prompts = example_encoder_decoder_prompts[
            decoder_prompt_type]

        # Configuration settings for HF baseline
        hf_kwargs = {
            "top_k": None,
            "num_beams": 1,
            "repetition_penalty": 1.0,
            "top_p": 1.0,
            "length_penalty": 1.0,
            "early_stopping": False,
            "no_repeat_ngram_size": None,
            "min_length": 0
        }

        with hf_runner(model, dtype=dtype,
                       is_encoder_decoder_model=True) as hf_model:
            hf_outputs = (
                hf_model.generate_encoder_decoder_greedy_logprobs_limit(
                    test_case_prompts,
                    max_tokens,
                    num_logprobs,
                    **hf_kwargs,
                ))

        # Note: currently encoder/decoder models are only compatible with
        # enforce_eager=True. Normally this is not a problem because for
        # encoder/decoder models vLLM will default to enforce_eager=True if
        # enforce_eager is left unspecified. However, the VllmRunner test
        # fixture (which wraps around the LLM class) defaults to
        # enforce_eager=False (a behavior which a number of already-existing
        # decoder-only unit tests expect), so when testing an encoder/decoder
        # model we must explicitly specify enforce_eager=True in the
        # VllmRunner constructor.
        with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
            vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
                test_case_prompts, max_tokens, num_logprobs)

        hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
                          else 0)

        check_logprobs_close(outputs_0_lst=hf_outputs,
                             outputs_1_lst=vllm_outputs,
                             name_0="hf",
                             name_1="vllm",
                             num_outputs_0_skip_tokens=hf_skip_tokens)
@ -1,4 +1,5 @@
|
|||||||
import warnings
|
import warnings
|
||||||
|
from enum import Enum
|
||||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
from vllm.sequence import SampleLogprobs
|
from vllm.sequence import SampleLogprobs
|
||||||
@ -45,11 +46,27 @@ def check_logprobs_close(
|
|||||||
outputs_1_lst: Sequence[TokensTextLogprobs],
|
outputs_1_lst: Sequence[TokensTextLogprobs],
|
||||||
name_0: str,
|
name_0: str,
|
||||||
name_1: str,
|
name_1: str,
|
||||||
|
num_outputs_0_skip_tokens: int = 0,
|
||||||
warn_on_mismatch: bool = True,
|
warn_on_mismatch: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Compare the logprobs of two sequences generated by different models,
|
Compare the logprobs of two sequences generated by different models,
|
||||||
which should be similar but not necessarily equal.
|
which should be similar but not necessarily equal.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* outputs_0_lst: First sequence to compare
|
||||||
|
* outputs_0_lst: Second sequence to compare
|
||||||
|
* name_0: sequence #0 name
|
||||||
|
* name_1: sequence #1 name
|
||||||
|
* num_outputs_0_skip_tokens: If > 0, specifies the number of initial
|
||||||
|
sequence #0 tokens & logprobs to discard
|
||||||
|
before comparison, i.e. all
|
||||||
|
of sequence #1 will be compared to
|
||||||
|
sequence #0 beginning at index
|
||||||
|
num_outputs_0_skip_tokens
|
||||||
|
* warn_on_mismatch: Issue a warning if there is token-wise or text-wise
|
||||||
|
mismatch between the two sequences
|
||||||
"""
|
"""
|
||||||
assert len(outputs_0_lst) == len(outputs_1_lst)
|
assert len(outputs_0_lst) == len(outputs_1_lst)
|
||||||
|
|
||||||
@ -65,6 +82,15 @@ def check_logprobs_close(
|
|||||||
if logprobs_1 is None:
|
if logprobs_1 is None:
|
||||||
logprobs_1 = [None] * len(output_ids_1)
|
logprobs_1 = [None] * len(output_ids_1)
|
||||||
|
|
||||||
|
# Skip specified number of initial sequence #0 tokens
|
||||||
|
# & logprobs, leaving output text as-is for simplicity
|
||||||
|
# (text mismatches may generate warnings but do not
|
||||||
|
# cause the test to fail.)
|
||||||
|
if num_outputs_0_skip_tokens < 0:
|
||||||
|
raise ValueError("num_outputs_0_skip_tokens must be non-negative")
|
||||||
|
output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:]
|
||||||
|
logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]
|
||||||
|
|
||||||
# Loop through generated tokens.
|
# Loop through generated tokens.
|
||||||
for idx, (output_id_0,
|
for idx, (output_id_0,
|
||||||
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
|
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
|
||||||
@ -110,3 +136,13 @@ def check_logprobs_close(
|
|||||||
warnings.simplefilter("always")
|
warnings.simplefilter("always")
|
||||||
|
|
||||||
warnings.warn(fail_msg, stacklevel=2)
|
warnings.warn(fail_msg, stacklevel=2)
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderPromptType(Enum):
|
||||||
|
'''
|
||||||
|
For encoder/decoder models only -
|
||||||
|
|
||||||
|
'''
|
||||||
|
CUSTOM = 1
|
||||||
|
NONE = 2
|
||||||
|
EMPTY_STR = 3
|
||||||
|
480
tests/worker/test_encoder_decoder_model_runner.py
Normal file
480
tests/worker/test_encoder_decoder_model_runner.py
Normal file
@ -0,0 +1,480 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
|
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||||
|
from vllm.utils import is_cpu
|
||||||
|
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
|
||||||
|
|
||||||
|
# CUDA graph scenarios to test
|
||||||
|
#
|
||||||
|
# Currently CUDA graph is not supported
|
||||||
|
ENFORCE_EAGER = [True]
|
||||||
|
|
||||||
|
BATCH_SIZES = [1, 4, 16, 64, 256]
|
||||||
|
|
||||||
|
|
||||||
|
def _create_model_runner(model: str, *args,
|
||||||
|
**kwargs) -> EncoderDecoderModelRunner:
|
||||||
|
engine_args = EngineArgs(model, *args, **kwargs)
|
||||||
|
engine_config = engine_args.create_engine_config()
|
||||||
|
model_runner = EncoderDecoderModelRunner(
|
||||||
|
model_config=engine_config.model_config,
|
||||||
|
parallel_config=engine_config.parallel_config,
|
||||||
|
scheduler_config=engine_config.scheduler_config,
|
||||||
|
device_config=engine_config.device_config,
|
||||||
|
cache_config=engine_config.cache_config,
|
||||||
|
load_config=engine_config.load_config,
|
||||||
|
lora_config=engine_config.lora_config,
|
||||||
|
prompt_adapter_config=engine_config.prompt_adapter_config,
|
||||||
|
is_driver_worker=True,
|
||||||
|
)
|
||||||
|
return model_runner
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(condition=is_cpu(),
|
||||||
|
reason="CPU backend is currently "
|
||||||
|
"unsupported for encoder/ "
|
||||||
|
"decoder models")
|
||||||
|
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
|
||||||
|
def test_empty_seq_group(enforce_eager, ):
|
||||||
|
"""Verify prepare prompt and decode returns empty output
|
||||||
|
for empty seq group list"""
|
||||||
|
|
||||||
|
model_runner = _create_model_runner(
|
||||||
|
"facebook/bart-base",
|
||||||
|
seed=0,
|
||||||
|
dtype="float16",
|
||||||
|
max_num_batched_tokens=100000,
|
||||||
|
max_num_seqs=100000,
|
||||||
|
enable_chunked_prefill=False,
|
||||||
|
enforce_eager=enforce_eager,
|
||||||
|
)
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||||
|
model_input = model_runner._prepare_model_input_tensors(
|
||||||
|
seq_group_metadata_list)
|
||||||
|
(
|
||||||
|
input_tokens,
|
||||||
|
input_positions,
|
||||||
|
encoder_input_tokens,
|
||||||
|
encoder_input_positions,
|
||||||
|
attn_metadata,
|
||||||
|
return_seq_lens,
|
||||||
|
) = (
|
||||||
|
model_input.input_tokens,
|
||||||
|
model_input.input_positions,
|
||||||
|
model_input.encoder_input_tokens,
|
||||||
|
model_input.encoder_input_positions,
|
||||||
|
model_input.attn_metadata,
|
||||||
|
model_input.seq_lens,
|
||||||
|
)
|
||||||
|
assert input_tokens is None
|
||||||
|
assert input_positions is None
|
||||||
|
assert encoder_input_tokens is None
|
||||||
|
assert encoder_input_positions is None
|
||||||
|
assert attn_metadata is None
|
||||||
|
assert return_seq_lens is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(condition=is_cpu(),
|
||||||
|
reason="CPU backend is currently "
|
||||||
|
"unsupported for encoder/ "
|
||||||
|
"decoder models")
|
||||||
|
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||||
|
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
|
||||||
|
def test_prepare_prompt(
|
||||||
|
batch_size,
|
||||||
|
enforce_eager,
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
Test the ability of the encoder/decoder model runner subclass to
|
||||||
|
produce prefill-phase model inputs & attention metadata.
|
||||||
|
|
||||||
|
Test behavior:
|
||||||
|
|
||||||
|
* Instantiate BART base model & enc/dec model runner
|
||||||
|
* Construct sequence-group metadata for dummy prompts
|
||||||
|
* Test that encoder attention, decoder self-attention,
|
||||||
|
and encoder/decoder cross-attention inputs are correct
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* batch_size
|
||||||
|
* backend_name: The attention backend under test
|
||||||
|
* enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
|
||||||
|
'''
|
||||||
|
|
||||||
|
model_runner = _create_model_runner(
|
||||||
|
"facebook/bart-base",
|
||||||
|
seed=0,
|
||||||
|
dtype="float16",
|
||||||
|
max_num_batched_tokens=100000,
|
||||||
|
max_num_seqs=100000,
|
||||||
|
enable_chunked_prefill=False,
|
||||||
|
enforce_eager=enforce_eager,
|
||||||
|
)
|
||||||
|
|
||||||
|
seq_lens: List[int] = []
|
||||||
|
encoder_seq_lens: List[int] = []
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||||
|
block_tables = {0: [1]}
|
||||||
|
cross_block_table = [2]
|
||||||
|
for i in range(batch_size):
|
||||||
|
# make sure all tokens fit into one block
|
||||||
|
seq_len = i % (model_runner.block_size - 1) + 1
|
||||||
|
seq_lens.append(seq_len)
|
||||||
|
seq_data = SequenceData(list(range(seq_len)))
|
||||||
|
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
|
||||||
|
encoder_seq_lens.append(encoder_seq_len)
|
||||||
|
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
|
||||||
|
seq_group_metadata = SequenceGroupMetadata(
|
||||||
|
request_id=f"test_{i}",
|
||||||
|
is_prompt=True,
|
||||||
|
seq_data={0: seq_data},
|
||||||
|
sampling_params=SamplingParams(temperature=0),
|
||||||
|
block_tables=block_tables,
|
||||||
|
encoder_seq_data=encoder_seq_data,
|
||||||
|
cross_block_table=cross_block_table,
|
||||||
|
)
|
||||||
|
assert seq_group_metadata.token_chunk_size == seq_data.get_len()
|
||||||
|
seq_group_metadata_list.append(seq_group_metadata)
|
||||||
|
|
||||||
|
# Build
|
||||||
|
# * Decoder model inputs
|
||||||
|
# * Decoder self-attention KV caching data structures
|
||||||
|
# * Encoder model inputs
|
||||||
|
# * Encoder/decoder cross-attention KV caching data structures
|
||||||
|
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
|
||||||
|
|
||||||
|
input_tokens = model_input.input_tokens
|
||||||
|
input_positions = model_input.input_positions
|
||||||
|
attn_metadata = model_input.attn_metadata
|
||||||
|
return_seq_lens = model_input.seq_lens
|
||||||
|
slot_mapping = attn_metadata.slot_mapping
|
||||||
|
encoder_input_tokens = model_input.encoder_input_tokens
|
||||||
|
encoder_input_positions = model_input.encoder_input_positions
|
||||||
|
cross_slot_mapping = attn_metadata.cross_slot_mapping
|
||||||
|
assert return_seq_lens == seq_lens
|
||||||
|
assert len(slot_mapping) == len(input_tokens)
|
||||||
|
assert len(cross_slot_mapping) == len(encoder_input_tokens)
|
||||||
|
|
||||||
|
# Verify input metadata is correct for prompts.
|
||||||
|
# - Decoder attention metadata
|
||||||
|
device = model_runner.device
|
||||||
|
assert attn_metadata.num_prefills > 0
|
||||||
|
assert attn_metadata.num_decode_tokens == 0
|
||||||
|
assert torch.equal(attn_metadata.seq_lens_tensor,
|
||||||
|
torch.tensor(seq_lens, device=device, dtype=torch.int))
|
||||||
|
assert attn_metadata.seq_lens == seq_lens
|
||||||
|
assert attn_metadata.max_prefill_seq_len == max(seq_lens)
|
||||||
|
assert attn_metadata.max_decode_seq_len == 0
|
||||||
|
# - Encoder attention metadata
|
||||||
|
assert attn_metadata.encoder_seq_lens == encoder_seq_lens
|
||||||
|
assert torch.equal(
|
||||||
|
attn_metadata.encoder_seq_lens_tensor,
|
||||||
|
torch.tensor(encoder_seq_lens, device=device, dtype=torch.int))
|
||||||
|
assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
|
||||||
|
assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)
|
||||||
|
|
||||||
|
# Test decoder subquery start locs.
|
||||||
|
start_idx = 0
|
||||||
|
start_loc = [start_idx]
|
||||||
|
for seq_len in seq_lens:
|
||||||
|
start_idx += seq_len
|
||||||
|
start_loc.append(start_idx)
|
||||||
|
assert torch.equal(
|
||||||
|
attn_metadata.query_start_loc,
|
||||||
|
torch.tensor(start_loc, dtype=torch.int32, device=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test decoder seq start locs & context lengths
|
||||||
|
|
||||||
|
assert torch.equal(
|
||||||
|
attn_metadata.seq_start_loc,
|
||||||
|
torch.tensor(start_loc, dtype=torch.int32, device=device),
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
attn_metadata.context_lens_tensor,
|
||||||
|
torch.zeros(attn_metadata.context_lens_tensor.shape[0],
|
||||||
|
dtype=torch.int,
|
||||||
|
device=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify block tables are correct for prompts
|
||||||
|
# - Decoder self-attention
|
||||||
|
expected = torch.tensor(
|
||||||
|
[[] for _ in range(len(seq_group_metadata_list))],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=model_runner.device,
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
attn_metadata.block_tables,
|
||||||
|
expected,
|
||||||
|
)
|
||||||
|
# - Encoder/decoder cross-attention
|
||||||
|
assert torch.equal(
|
||||||
|
attn_metadata.cross_block_tables,
|
||||||
|
expected,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cuda graph should not be used for prefill.
|
||||||
|
assert attn_metadata.use_cuda_graph is False
|
||||||
|
|
||||||
|
# Verify the lengths of input tokens & positions
|
||||||
|
# - Decoder
|
||||||
|
assert len(input_tokens) == sum(seq_lens)
|
||||||
|
assert len(input_positions) == sum(seq_lens)
|
||||||
|
# -- An indirect check that model_input.input_tokens
|
||||||
|
# and model_input.input_positions are correct -
|
||||||
|
# by design of the test, the input tokens are
|
||||||
|
# equal to the input position values, so if
|
||||||
|
# the model_input data structure has the correct
|
||||||
|
# values then these two should be equal
|
||||||
|
assert torch.equal(
|
||||||
|
input_tokens,
|
||||||
|
input_positions,
|
||||||
|
)
|
||||||
|
# - Encoder
|
||||||
|
assert len(encoder_input_tokens) == sum(encoder_seq_lens)
|
||||||
|
# -- An indirect check that model_input.encoder_input_tokens
|
||||||
|
# and model_input.encoder_input_positions are correct -
|
||||||
|
# by design of the test, the input tokens are
|
||||||
|
# equal to the input position values, so if
|
||||||
|
# the model_input data structure has the correct
|
||||||
|
# values then these two should be equal
|
||||||
|
assert torch.equal(
|
||||||
|
encoder_input_tokens,
|
||||||
|
encoder_input_positions,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that vLLM sampling infrastructure chooses the correct
|
||||||
|
# sequence positions at which to sample (i.e. the end of
|
||||||
|
# each sequence) in the prefill phase
|
||||||
|
|
||||||
|
expected_selected_token_indices = []
|
||||||
|
selected_token_start_idx = 0
|
||||||
|
for seq_len in seq_lens:
|
||||||
|
# Compute the index offset of the final token in each
|
||||||
|
# prompt (recall that the prompts are concatenated)
|
||||||
|
expected_selected_token_indices.append(selected_token_start_idx +
|
||||||
|
seq_len - 1)
|
||||||
|
selected_token_start_idx += seq_len
|
||||||
|
|
||||||
|
sampling_metadata = model_input.sampling_metadata
|
||||||
|
actual = sampling_metadata.selected_token_indices
|
||||||
|
expected = torch.tensor(
|
||||||
|
expected_selected_token_indices,
|
||||||
|
device=actual.device,
|
||||||
|
dtype=actual.dtype,
|
||||||
|
)
|
||||||
|
assert torch.equal(actual, expected)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(condition=is_cpu(),
|
||||||
|
reason="CPU backend is currently "
|
||||||
|
"unsupported for encoder/ "
|
||||||
|
"decoder models")
|
||||||
|
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||||
|
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
|
||||||
|
def test_prepare_decode(
|
||||||
|
batch_size,
|
||||||
|
enforce_eager,
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
Test the ability of the encoder/decoder model runner subclass to
|
||||||
|
produce decode-phase model inputs & attention metadata.
|
||||||
|
|
||||||
|
Test behavior:
|
||||||
|
|
||||||
|
* Instantiate BART base model & enc/dec model runner
|
||||||
|
* Construct sequence-group metadata for dummy prompts
|
||||||
|
* Test that encoder attention, decoder self-attention,
|
||||||
|
and encoder/decoder cross-attention inputs are correct
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* batch_size
|
||||||
|
* backend_name: The attention backend under test
|
||||||
|
* enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
|
||||||
|
'''
|
||||||
|
|
||||||
|
model_runner = _create_model_runner(
|
||||||
|
"facebook/bart-base",
|
||||||
|
seed=0,
|
||||||
|
dtype="float16",
|
||||||
|
max_num_batched_tokens=100000,
|
||||||
|
max_num_seqs=100000,
|
||||||
|
enable_chunked_prefill=False,
|
||||||
|
enforce_eager=enforce_eager,
|
||||||
|
)
|
||||||
|
|
||||||
|
seq_lens: List[int] = []
|
||||||
|
encoder_seq_lens: List[int] = []
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||||
|
block_tables = {0: [1]}
|
||||||
|
cross_block_table = [2]
|
||||||
|
for i in range(batch_size):
|
||||||
|
# make sure all tokens fit into one block
|
||||||
|
seq_len = i % (model_runner.block_size - 1) + 1
|
||||||
|
seq_lens.append(seq_len)
|
||||||
|
seq_data = SequenceData(list(range(seq_len)))
|
||||||
|
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
|
||||||
|
encoder_seq_lens.append(encoder_seq_len)
|
||||||
|
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
|
||||||
|
seq_group_metadata = SequenceGroupMetadata(
|
||||||
|
request_id=f"test_{i}",
|
||||||
|
is_prompt=False,
|
||||||
|
seq_data={0: seq_data},
|
||||||
|
sampling_params=SamplingParams(temperature=0),
|
||||||
|
block_tables=block_tables,
|
||||||
|
encoder_seq_data=encoder_seq_data,
|
||||||
|
cross_block_table=cross_block_table,
|
||||||
|
)
|
||||||
|
assert seq_group_metadata.token_chunk_size == 1
|
||||||
|
seq_group_metadata_list.append(seq_group_metadata)
|
||||||
|
|
||||||
|
# Build
|
||||||
|
# * Decoder model inputs
|
||||||
|
# * Decoder self-attention KV caching data structures
|
||||||
|
# * Encoder model inputs
|
||||||
|
# * Encoder/decoder cross-attention KV caching data structures
|
||||||
|
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
|
||||||
|
input_tokens = model_input.input_tokens
|
||||||
|
input_positions = model_input.input_positions
|
||||||
|
attn_metadata = model_input.attn_metadata
|
||||||
|
return_seq_lens = model_input.seq_lens
|
||||||
|
slot_mapping = attn_metadata.slot_mapping
|
||||||
|
encoder_input_tokens = model_input.encoder_input_tokens
|
||||||
|
encoder_input_positions = model_input.encoder_input_positions
|
||||||
|
cross_slot_mapping = attn_metadata.cross_slot_mapping
|
||||||
|
assert return_seq_lens == seq_lens
|
||||||
|
assert len(slot_mapping) == len(input_tokens)
|
||||||
|
assert len(cross_slot_mapping) == len(encoder_input_tokens)
|
||||||
|
|
||||||
|
# Verify input metadata is correct for decode phase.
|
||||||
|
# - Decoder attention metadata
|
||||||
|
device = model_runner.device
|
||||||
|
assert attn_metadata.num_prefills == 0
|
||||||
|
assert attn_metadata.num_decode_tokens > 0
|
||||||
|
assert torch.equal(attn_metadata.seq_lens_tensor,
|
||||||
|
torch.tensor(seq_lens, device=device, dtype=torch.int))
|
||||||
|
assert attn_metadata.seq_lens == seq_lens
|
||||||
|
assert attn_metadata.max_prefill_seq_len == 0
|
||||||
|
assert attn_metadata.max_decode_seq_len == max(seq_lens)
|
||||||
|
# - Encoder attention metadata
|
||||||
|
assert attn_metadata.encoder_seq_lens == encoder_seq_lens
|
||||||
|
assert torch.equal(
|
||||||
|
attn_metadata.encoder_seq_lens_tensor,
|
||||||
|
torch.tensor(encoder_seq_lens, device=device, dtype=torch.int))
|
||||||
|
assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
|
||||||
|
assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)
|
||||||
|
|
||||||
|
# Test decoder subquery start locs.
|
||||||
|
start_idx = 0
|
||||||
|
start_loc = [start_idx]
|
||||||
|
for seq_len in seq_lens:
|
||||||
|
start_idx += 1
|
||||||
|
start_loc.append(start_idx)
|
||||||
|
assert torch.equal(
|
||||||
|
attn_metadata.query_start_loc,
|
||||||
|
torch.tensor(start_loc, dtype=torch.int32, device=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test decoder seq start locs. Note that for normal prefill it is
|
||||||
|
# equivalent to query_start_loc.
|
||||||
|
start_idx = 0
|
||||||
|
seq_start_loc = [start_idx]
|
||||||
|
for seq_len in seq_lens:
|
||||||
|
start_idx += seq_len
|
||||||
|
seq_start_loc.append(start_idx)
|
||||||
|
|
||||||
|
# Test seq_start_loc and context lengths
|
||||||
|
|
||||||
|
assert torch.equal(
|
||||||
|
attn_metadata.seq_start_loc,
|
||||||
|
torch.tensor(seq_start_loc, dtype=torch.int32, device=device),
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
attn_metadata.context_lens_tensor,
|
||||||
|
torch.tensor([seq_len - 1 for seq_len in seq_lens],
|
||||||
|
dtype=torch.int,
|
||||||
|
device=device))
|
||||||
|
|
||||||
|
# Verify block tables are correct for prompts
|
||||||
|
# - Decoder self-attention
|
||||||
|
expected = torch.tensor(
|
||||||
|
[block_tables[0] for _ in range(len(seq_group_metadata_list))],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=model_runner.device)
|
||||||
|
assert torch.equal(
|
||||||
|
attn_metadata.block_tables,
|
||||||
|
expected,
|
||||||
|
)
|
||||||
|
# - Encoder/decoder cross-attention
|
||||||
|
expected = torch.tensor(
|
||||||
|
[cross_block_table for _ in range(len(seq_group_metadata_list))],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=model_runner.device)
|
||||||
|
assert torch.equal(
|
||||||
|
attn_metadata.cross_block_tables,
|
||||||
|
expected,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cuda graph should is currently not supported for encoder/decoer.
|
||||||
|
assert attn_metadata.use_cuda_graph is False
|
||||||
|
|
||||||
|
# Verify the lengths of input tokens & positions
|
||||||
|
# - Decoder
|
||||||
|
assert len(input_tokens) == len(seq_lens)
|
||||||
|
assert len(input_positions) == len(seq_lens)
|
||||||
|
# -- An indirect check that model_input.input_tokens
|
||||||
|
# and model_input.input_positions are correct -
|
||||||
|
# by design of the test, the input tokens are
|
||||||
|
# equal to the input position values, so if
|
||||||
|
# the model_input data structure has the correct
|
||||||
|
# values then these two should be equal
|
||||||
|
assert torch.equal(
|
||||||
|
input_tokens,
|
||||||
|
input_positions,
|
||||||
|
)
|
||||||
|
# - Encoder
|
||||||
|
assert len(encoder_input_tokens) == 0
|
||||||
|
assert len(encoder_input_tokens) == 0
|
||||||
|
# -- An indirect check that model_input.encoder_input_tokens
|
||||||
|
# and model_input.encoder_input_positions are correct -
|
||||||
|
# by design of the test, the input tokens are
|
||||||
|
# equal to the input position values, so if
|
||||||
|
# the model_input data structure has the correct
|
||||||
|
# values then these two should be equal
|
||||||
|
assert torch.equal(
|
||||||
|
encoder_input_tokens,
|
||||||
|
encoder_input_positions,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that vLLM sampling infrastructure chooses the correct
|
||||||
|
# sequence positions at which to sample (i.e. the end of
|
||||||
|
# each sequence) in the decode phase
|
||||||
|
|
||||||
|
expected_selected_token_indices = []
|
||||||
|
selected_token_start_idx = 0
|
||||||
|
for seq_len in seq_lens:
|
||||||
|
# Compute the index offset of the final token in each
|
||||||
|
# sequence's decoded outputs; since a single token is
|
||||||
|
# decoded per iteration per sequence, then the length
|
||||||
|
# of the decoded tokens for a given sequence is 1 and
|
||||||
|
# the final index offset into a given sequence's
|
||||||
|
# generated tokens is 0 (i.e. the expected sampling index
|
||||||
|
# for a given sequence is just `selected_token_start_idx`)
|
||||||
|
expected_selected_token_indices.append(selected_token_start_idx)
|
||||||
|
selected_token_start_idx += 1
|
||||||
|
|
||||||
|
sampling_metadata = model_input.sampling_metadata
|
||||||
|
actual = sampling_metadata.selected_token_indices
|
||||||
|
expected = torch.tensor(
|
||||||
|
expected_selected_token_indices,
|
||||||
|
device=actual.device,
|
||||||
|
dtype=actual.dtype,
|
||||||
|
)
|
||||||
|
assert torch.equal(actual, expected)
|
@ -1,6 +1,7 @@
|
|||||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
AttentionMetadataBuilder)
|
AttentionMetadataBuilder,
|
||||||
|
AttentionType)
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
|
||||||
@ -8,6 +9,7 @@ __all__ = [
|
|||||||
"Attention",
|
"Attention",
|
||||||
"AttentionBackend",
|
"AttentionBackend",
|
||||||
"AttentionMetadata",
|
"AttentionMetadata",
|
||||||
|
"AttentionType",
|
||||||
"AttentionMetadataBuilder",
|
"AttentionMetadataBuilder",
|
||||||
"Attention",
|
"Attention",
|
||||||
"get_attn_backend",
|
"get_attn_backend",
|
||||||
|
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
|
from vllm.attention import AttentionMetadata, AttentionType
|
||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import enum
|
import enum
|
||||||
|
import os
|
||||||
|
from contextlib import contextmanager
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Optional, Type
|
from typing import Generator, Optional, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -8,7 +10,8 @@ import vllm.envs as envs
|
|||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
|
from vllm.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino,
|
||||||
|
is_tpu, is_xpu)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -24,6 +27,66 @@ class _Backend(enum.Enum):
|
|||||||
IPEX = enum.auto()
|
IPEX = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
|
def backend_name_to_enum(backend_name: str) -> _Backend:
|
||||||
|
assert backend_name is not None
|
||||||
|
|
||||||
|
backend_members = _Backend.__members__
|
||||||
|
if backend_name not in backend_members:
|
||||||
|
raise ValueError(f"Invalid attention backend '{backend_name}'. "
|
||||||
|
f"Available backends: {', '.join(backend_members)} "
|
||||||
|
"(case-sensitive).")
|
||||||
|
|
||||||
|
return _Backend[backend_name]
|
||||||
|
|
||||||
|
|
||||||
|
def get_env_variable_attn_backend() -> Optional[_Backend]:
|
||||||
|
'''
|
||||||
|
Get the backend override specified by the vLLM attention
|
||||||
|
backend environment variable, if one is specified.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
* _Backend enum value if an override is specified
|
||||||
|
* None otherwise
|
||||||
|
'''
|
||||||
|
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
|
||||||
|
return (None
|
||||||
|
if backend_name is None else backend_name_to_enum(backend_name))
|
||||||
|
|
||||||
|
|
||||||
|
# Global state allows a particular choice of backend
|
||||||
|
# to be forced, overriding the logic which auto-selects
|
||||||
|
# a backend based on system & workload configuration
|
||||||
|
# (default behavior if this variable is None)
|
||||||
|
#
|
||||||
|
# THIS SELECTION TAKES PRECEDENCE OVER THE
|
||||||
|
# VLLM ATTENTION BACKEND ENVIRONMENT VARIABLE
|
||||||
|
forced_attn_backend: Optional[_Backend] = None
|
||||||
|
|
||||||
|
|
||||||
|
def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
|
||||||
|
'''
|
||||||
|
Force all attention operations to use a specified backend.
|
||||||
|
|
||||||
|
Passing `None` for the argument re-enables automatic
|
||||||
|
backend selection.,
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* attn_backend: backend selection (None to revert to auto)
|
||||||
|
'''
|
||||||
|
global forced_attn_backend
|
||||||
|
forced_attn_backend = attn_backend
|
||||||
|
|
||||||
|
|
||||||
|
def get_global_forced_attn_backend() -> Optional[_Backend]:
|
||||||
|
'''
|
||||||
|
Get the currently-forced choice of attention backend,
|
||||||
|
or None if auto-selection is currently enabled.
|
||||||
|
'''
|
||||||
|
return forced_attn_backend
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def get_attn_backend(
|
def get_attn_backend(
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
@ -101,16 +164,20 @@ def which_attn_to_use(
|
|||||||
# Default case.
|
# Default case.
|
||||||
selected_backend = _Backend.FLASH_ATTN
|
selected_backend = _Backend.FLASH_ATTN
|
||||||
|
|
||||||
|
# Check whether a particular choice of backend was
|
||||||
|
# previously forced.
|
||||||
|
#
|
||||||
|
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
|
||||||
|
# ENVIRONMENT VARIABLE.
|
||||||
|
backend_by_global_setting: Optional[_Backend] = (
|
||||||
|
get_global_forced_attn_backend())
|
||||||
|
if backend_by_global_setting is not None:
|
||||||
|
selected_backend = backend_by_global_setting
|
||||||
|
else:
|
||||||
# Check the environment variable and override if specified
|
# Check the environment variable and override if specified
|
||||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||||
if backend_by_env_var is not None:
|
if backend_by_env_var is not None:
|
||||||
backend_members = _Backend.__members__
|
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||||
if backend_by_env_var not in backend_members:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid attention backend '{backend_by_env_var}'. "
|
|
||||||
f"Available backends: {', '.join(backend_members)} "
|
|
||||||
"(case-sensitive).")
|
|
||||||
selected_backend = _Backend[backend_by_env_var]
|
|
||||||
|
|
||||||
if is_cpu():
|
if is_cpu():
|
||||||
if selected_backend != _Backend.TORCH_SDPA:
|
if selected_backend != _Backend.TORCH_SDPA:
|
||||||
@ -193,3 +260,35 @@ def which_attn_to_use(
|
|||||||
selected_backend = _Backend.XFORMERS
|
selected_backend = _Backend.XFORMERS
|
||||||
|
|
||||||
return selected_backend
|
return selected_backend
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def global_force_attn_backend_context_manager(
|
||||||
|
attn_backend: _Backend) -> Generator[None, None, None]:
|
||||||
|
'''
|
||||||
|
Globally force a vLLM attention backend override within a
|
||||||
|
context manager, reverting the global attention backend
|
||||||
|
override to its prior state upon exiting the context
|
||||||
|
manager.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* attn_backend: attention backend to force
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
* Generator
|
||||||
|
'''
|
||||||
|
|
||||||
|
# Save the current state of the global backend override (if any)
|
||||||
|
original_value = get_global_forced_attn_backend()
|
||||||
|
|
||||||
|
# Globally force the new backend override
|
||||||
|
global_force_attn_backend(attn_backend)
|
||||||
|
|
||||||
|
# Yield control back to the enclosed code block
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
# Revert the original global backend override, if any
|
||||||
|
global_force_attn_backend(original_value)
|
||||||
|
@ -12,7 +12,8 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
|||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
from vllm.tracing import is_otel_installed
|
from vllm.tracing import is_otel_installed
|
||||||
from vllm.transformers_utils.config import get_config, get_hf_text_config
|
from vllm.transformers_utils.config import get_config, get_hf_text_config
|
||||||
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
|
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
|
||||||
|
cuda_device_count_stateless, get_cpu_memory, is_cpu,
|
||||||
is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
|
is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
|
||||||
print_warning_once)
|
print_warning_once)
|
||||||
|
|
||||||
@ -87,6 +88,9 @@ class ModelConfig:
|
|||||||
enforce_eager: Whether to enforce eager execution. If True, we will
|
enforce_eager: Whether to enforce eager execution. If True, we will
|
||||||
disable CUDA graph and always execute the model in eager mode.
|
disable CUDA graph and always execute the model in eager mode.
|
||||||
If False, we will use CUDA graph and eager execution in hybrid.
|
If False, we will use CUDA graph and eager execution in hybrid.
|
||||||
|
If None, the user did not specify, so default to False -
|
||||||
|
except for encoder/decoder models, which currently require
|
||||||
|
eager mode.
|
||||||
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
||||||
When a sequence has context length larger than this, we fall back
|
When a sequence has context length larger than this, we fall back
|
||||||
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
|
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
|
||||||
@ -121,7 +125,7 @@ class ModelConfig:
|
|||||||
max_model_len: Optional[int] = None,
|
max_model_len: Optional[int] = None,
|
||||||
quantization: Optional[str] = None,
|
quantization: Optional[str] = None,
|
||||||
quantization_param_path: Optional[str] = None,
|
quantization_param_path: Optional[str] = None,
|
||||||
enforce_eager: bool = False,
|
enforce_eager: Optional[bool] = None,
|
||||||
max_context_len_to_capture: Optional[int] = None,
|
max_context_len_to_capture: Optional[int] = None,
|
||||||
max_seq_len_to_capture: Optional[int] = None,
|
max_seq_len_to_capture: Optional[int] = None,
|
||||||
max_logprobs: int = 20,
|
max_logprobs: int = 20,
|
||||||
@ -160,6 +164,34 @@ class ModelConfig:
|
|||||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||||
|
|
||||||
|
# Choose a default enforce_eager value if the user did not specify
|
||||||
|
# a value (enforce_eager is None)
|
||||||
|
if getattr(self.hf_config, 'is_encoder_decoder', False):
|
||||||
|
if self.enforce_eager is None:
|
||||||
|
# *Only for encoder/decoder models* and
|
||||||
|
# *only if enforce_eager is unset*, override
|
||||||
|
# to enforce_eager=True
|
||||||
|
#
|
||||||
|
# Add a logger message since it is *somewhat* non-intuitive that
|
||||||
|
# enforce_eager is True when the user has not specified its
|
||||||
|
# value.
|
||||||
|
logger.info("Forcing enforce_eager == True because "
|
||||||
|
"enforce_eager setting was unspecified and "
|
||||||
|
"CUDAGraph is not supported with encoder/ "
|
||||||
|
"decoder models.")
|
||||||
|
self.enforce_eager = True
|
||||||
|
|
||||||
|
if not self.enforce_eager:
|
||||||
|
# Eager mode explicitly disabled by user for an encoder/
|
||||||
|
# decoder model; however CUDAGRAPH + encoder/decoder is
|
||||||
|
# not currently supported
|
||||||
|
raise ValueError(STR_NOT_IMPL_ENC_DEC_CUDAGRAPH)
|
||||||
|
elif self.enforce_eager is None:
|
||||||
|
# *Only for decoder-only models*, enforce_eager
|
||||||
|
# defaults to False if unset. This is intuitive
|
||||||
|
# so no logging message needed.
|
||||||
|
self.enforce_eager = False
|
||||||
|
|
||||||
if (not self.disable_sliding_window
|
if (not self.disable_sliding_window
|
||||||
and self.hf_text_config.model_type == "gemma2"
|
and self.hf_text_config.model_type == "gemma2"
|
||||||
and self.hf_text_config.sliding_window is not None):
|
and self.hf_text_config.sliding_window is not None):
|
||||||
|
@ -1,15 +1,7 @@
|
|||||||
"""Block manager utils."""
|
"""Block manager utils."""
|
||||||
from vllm.sequence import SequenceGroup
|
from vllm.sequence import SequenceGroup
|
||||||
|
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
|
||||||
# Exception strings for non-implemented block manager enc/dec scenarios
|
STR_NOT_IMPL_ENC_DEC_SWA)
|
||||||
|
|
||||||
STR_NOT_IMPL_ENC_DEC_SWA = \
|
|
||||||
"Sliding window attention for encoder/decoder models " + \
|
|
||||||
"is not currently supported."
|
|
||||||
|
|
||||||
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
|
|
||||||
"Prefix caching for encoder/decoder models " + \
|
|
||||||
"is not currently supported."
|
|
||||||
|
|
||||||
|
|
||||||
def _get_block_mgr_sliding_window_attr(block_mgr):
|
def _get_block_mgr_sliding_window_attr(block_mgr):
|
||||||
|
@ -392,6 +392,19 @@ class Scheduler:
|
|||||||
seq.status = SequenceStatus.FINISHED_ABORTED
|
seq.status = SequenceStatus.FINISHED_ABORTED
|
||||||
self.free_seq(seq)
|
self.free_seq(seq)
|
||||||
|
|
||||||
|
self._free_seq_group_cross_attn_blocks(aborted_group)
|
||||||
|
|
||||||
|
def _free_seq_group_cross_attn_blocks(
|
||||||
|
self,
|
||||||
|
seq_group: SequenceGroup,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Free a sequence group from a cross-attention block table.
|
||||||
|
Has no effect on decoder-only models.
|
||||||
|
"""
|
||||||
|
if seq_group.is_encoder_decoder():
|
||||||
|
self.block_manager.free_cross(seq_group)
|
||||||
|
|
||||||
def has_unfinished_seqs(self) -> bool:
|
def has_unfinished_seqs(self) -> bool:
|
||||||
return len(self.waiting) != 0 or len(self.running) != 0 or len(
|
return len(self.waiting) != 0 or len(self.running) != 0 or len(
|
||||||
self.swapped) != 0
|
self.swapped) != 0
|
||||||
@ -963,6 +976,17 @@ class Scheduler:
|
|||||||
# seq_id -> physical block numbers
|
# seq_id -> physical block numbers
|
||||||
block_tables: Dict[int, List[int]] = {}
|
block_tables: Dict[int, List[int]] = {}
|
||||||
|
|
||||||
|
if seq_group.is_encoder_decoder():
|
||||||
|
# Encoder associated with SequenceGroup
|
||||||
|
encoder_seq_data = seq_group.get_encoder_seq().data
|
||||||
|
# Block table for cross-attention
|
||||||
|
# Also managed at SequenceGroup level
|
||||||
|
cross_block_table = self.block_manager.get_cross_block_table(
|
||||||
|
seq_group)
|
||||||
|
else:
|
||||||
|
encoder_seq_data = None
|
||||||
|
cross_block_table = None
|
||||||
|
|
||||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||||
seq_id = seq.seq_id
|
seq_id = seq.seq_id
|
||||||
seq_data[seq_id] = seq.data
|
seq_data[seq_id] = seq.data
|
||||||
@ -1001,6 +1025,8 @@ class Scheduler:
|
|||||||
token_chunk_size=token_chunk_size,
|
token_chunk_size=token_chunk_size,
|
||||||
lora_request=seq_group.lora_request,
|
lora_request=seq_group.lora_request,
|
||||||
computed_block_nums=common_computed_block_nums,
|
computed_block_nums=common_computed_block_nums,
|
||||||
|
encoder_seq_data=encoder_seq_data,
|
||||||
|
cross_block_table=cross_block_table,
|
||||||
# `multi_modal_data` will only be present for the 1st comm
|
# `multi_modal_data` will only be present for the 1st comm
|
||||||
# between engine and worker.
|
# between engine and worker.
|
||||||
# the subsequent comms can still use delta, but
|
# the subsequent comms can still use delta, but
|
||||||
@ -1032,6 +1058,8 @@ class Scheduler:
|
|||||||
remaining: Deque[SequenceGroup] = deque()
|
remaining: Deque[SequenceGroup] = deque()
|
||||||
for seq_group in self.running:
|
for seq_group in self.running:
|
||||||
if seq_group.is_finished():
|
if seq_group.is_finished():
|
||||||
|
# Free cross-attention block table, if it exists
|
||||||
|
self._free_seq_group_cross_attn_blocks(seq_group)
|
||||||
# Add the finished requests to the finished requests list.
|
# Add the finished requests to the finished requests list.
|
||||||
# This list will be used to update the Mamba cache in the
|
# This list will be used to update the Mamba cache in the
|
||||||
# next step.
|
# next step.
|
||||||
|
@ -69,7 +69,7 @@ class EngineArgs:
|
|||||||
rope_theta: Optional[float] = None
|
rope_theta: Optional[float] = None
|
||||||
tokenizer_revision: Optional[str] = None
|
tokenizer_revision: Optional[str] = None
|
||||||
quantization: Optional[str] = None
|
quantization: Optional[str] = None
|
||||||
enforce_eager: bool = False
|
enforce_eager: Optional[bool] = None
|
||||||
max_context_len_to_capture: Optional[int] = None
|
max_context_len_to_capture: Optional[int] = None
|
||||||
max_seq_len_to_capture: int = 8192
|
max_seq_len_to_capture: int = 8192
|
||||||
disable_custom_all_reduce: bool = False
|
disable_custom_all_reduce: bool = False
|
||||||
|
@ -3,7 +3,7 @@ from contextlib import contextmanager
|
|||||||
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
|
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
|
||||||
Mapping, Optional)
|
Mapping, Optional)
|
||||||
from typing import Sequence as GenericSequence
|
from typing import Sequence as GenericSequence
|
||||||
from typing import Set, Type, TypeVar, Union
|
from typing import Set, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
|
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
|
||||||
@ -22,7 +22,8 @@ from vllm.engine.output_processor.stop_checker import StopChecker
|
|||||||
from vllm.engine.output_processor.util import create_output_by_sequence_group
|
from vllm.engine.output_processor.util import create_output_by_sequence_group
|
||||||
from vllm.executor.executor_base import ExecutorBase
|
from vllm.executor.executor_base import ExecutorBase
|
||||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||||
from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs
|
from vllm.inputs import (INPUT_REGISTRY, LLMInputs, PromptInputs,
|
||||||
|
get_prompt_type)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
|
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
|
||||||
@ -42,7 +43,8 @@ from vllm.transformers_utils.tokenizer_group import (
|
|||||||
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
|
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
|
||||||
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||||
usage_message)
|
usage_message)
|
||||||
from vllm.utils import Counter
|
from vllm.utils import (Counter, is_embedding_model_config,
|
||||||
|
is_encoder_decoder_model_config)
|
||||||
from vllm.version import __version__ as VLLM_VERSION
|
from vllm.version import __version__ as VLLM_VERSION
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -502,8 +504,19 @@ class LLMEngine:
|
|||||||
self.prompt_adapter_config.verify_with_model_config(
|
self.prompt_adapter_config.verify_with_model_config(
|
||||||
self.model_config)
|
self.model_config)
|
||||||
|
|
||||||
def _get_eos_token_id(
|
def _get_bos_token_id(self,
|
||||||
self, lora_request: Optional[LoRARequest]) -> Optional[int]:
|
lora_request: Optional[LoRARequest] = None
|
||||||
|
) -> Optional[int]:
|
||||||
|
if self.tokenizer is None:
|
||||||
|
logger.warning("Using None for BOS token id because tokenizer "
|
||||||
|
"is not initialized")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
|
||||||
|
|
||||||
|
def _get_eos_token_id(self,
|
||||||
|
lora_request: Optional[LoRARequest] = None
|
||||||
|
) -> Optional[int]:
|
||||||
if self.tokenizer is None:
|
if self.tokenizer is None:
|
||||||
logger.warning("Using None for EOS token id because tokenizer "
|
logger.warning("Using None for EOS token id because tokenizer "
|
||||||
"is not initialized")
|
"is not initialized")
|
||||||
@ -511,6 +524,32 @@ class LLMEngine:
|
|||||||
|
|
||||||
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
|
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
|
||||||
|
|
||||||
|
def _get_decoder_start_token_id(self, ) -> Optional[int]:
|
||||||
|
'''
|
||||||
|
Obtain the decoder start token id employed by an encoder/decoder
|
||||||
|
model. Returns None for non-encoder/decoder models or if the
|
||||||
|
model config is unavailable.
|
||||||
|
'''
|
||||||
|
|
||||||
|
if not self.is_encoder_decoder_model():
|
||||||
|
logger.warning("Using None for decoder start token id because "
|
||||||
|
"this is not an encoder/decoder model.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if (self.model_config is None or self.model_config.hf_config is None):
|
||||||
|
logger.warning("Using None for decoder start token id because "
|
||||||
|
"model config is not available.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
dec_start_token_id = getattr(self.model_config.hf_config,
|
||||||
|
'decoder_start_token_id', None)
|
||||||
|
if dec_start_token_id is None:
|
||||||
|
logger.warning("Falling back on <BOS> for decoder start token id "
|
||||||
|
"because decoder start token id is not available.")
|
||||||
|
dec_start_token_id = self._get_bos_token_id()
|
||||||
|
|
||||||
|
return dec_start_token_id
|
||||||
|
|
||||||
def _add_processed_request(
|
def _add_processed_request(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
@ -529,6 +568,16 @@ class LLMEngine:
|
|||||||
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
|
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
|
||||||
lora_request, prompt_adapter_request)
|
lora_request, prompt_adapter_request)
|
||||||
|
|
||||||
|
encoder_seq = None
|
||||||
|
if 'encoder_prompt_token_ids' in processed_inputs:
|
||||||
|
encoder_seq = Sequence(seq_id,
|
||||||
|
processed_inputs,
|
||||||
|
block_size,
|
||||||
|
eos_token_id,
|
||||||
|
lora_request,
|
||||||
|
prompt_adapter_request,
|
||||||
|
from_decoder_prompt=False)
|
||||||
|
|
||||||
# Create a SequenceGroup based on SamplingParams or PoolingParams
|
# Create a SequenceGroup based on SamplingParams or PoolingParams
|
||||||
if isinstance(params, SamplingParams):
|
if isinstance(params, SamplingParams):
|
||||||
seq_group = self._create_sequence_group_with_sampling(
|
seq_group = self._create_sequence_group_with_sampling(
|
||||||
@ -538,7 +587,8 @@ class LLMEngine:
|
|||||||
arrival_time=arrival_time,
|
arrival_time=arrival_time,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
trace_headers=trace_headers,
|
trace_headers=trace_headers,
|
||||||
prompt_adapter_request=prompt_adapter_request)
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
|
encoder_seq=encoder_seq)
|
||||||
elif isinstance(params, PoolingParams):
|
elif isinstance(params, PoolingParams):
|
||||||
seq_group = self._create_sequence_group_with_pooling(
|
seq_group = self._create_sequence_group_with_pooling(
|
||||||
request_id,
|
request_id,
|
||||||
@ -546,7 +596,8 @@ class LLMEngine:
|
|||||||
params,
|
params,
|
||||||
arrival_time=arrival_time,
|
arrival_time=arrival_time,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
prompt_adapter_request=prompt_adapter_request)
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
|
encoder_seq=encoder_seq)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Either SamplingParams or PoolingParams must be provided.")
|
"Either SamplingParams or PoolingParams must be provided.")
|
||||||
@ -562,6 +613,336 @@ class LLMEngine:
|
|||||||
def stop_remote_worker_execution_loop(self) -> None:
|
def stop_remote_worker_execution_loop(self) -> None:
|
||||||
self.model_executor.stop_remote_worker_execution_loop()
|
self.model_executor.stop_remote_worker_execution_loop()
|
||||||
|
|
||||||
|
_LLMInputComponentsType = Tuple[str, List[int], ]
|
||||||
|
|
||||||
|
def _prepare_decoder_input_ids_for_generation(
|
||||||
|
self,
|
||||||
|
decoder_input_ids: Optional[List[int]] = None,
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Prepares `decoder_input_ids` for generation with encoder-decoder models.
|
||||||
|
|
||||||
|
Based on
|
||||||
|
|
||||||
|
https://github.com/huggingface/transformers/blob/
|
||||||
|
4037a2b5b1278736e566aec12e169100275545ea/
|
||||||
|
src/transformers/generation/utils.py
|
||||||
|
|
||||||
|
specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* decoder_input_ids: input token ids to preprocess
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
* Processed token list
|
||||||
|
"""
|
||||||
|
|
||||||
|
decoder_start_token_id: Optional[int] = (
|
||||||
|
self._get_decoder_start_token_id())
|
||||||
|
assert decoder_start_token_id is not None
|
||||||
|
|
||||||
|
if decoder_input_ids is None:
|
||||||
|
# no decoder prompt input ->
|
||||||
|
# use decoder_start_token_id as decoder_input_ids
|
||||||
|
(decoder_input_ids) = self._get_default_enc_dec_decoder_prompt()
|
||||||
|
|
||||||
|
if (len(decoder_input_ids) == 0
|
||||||
|
or decoder_input_ids[0] != decoder_start_token_id):
|
||||||
|
decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
|
||||||
|
|
||||||
|
return decoder_input_ids
|
||||||
|
|
||||||
|
def _tokenize_prompt(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
request_id: Optional[str] = None,
|
||||||
|
lora_request: Optional[str] = None,
|
||||||
|
) -> List[int]:
|
||||||
|
'''
|
||||||
|
Wrapper around application of the model's
|
||||||
|
tokenizer.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* prompt
|
||||||
|
* request_id
|
||||||
|
* lora_request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
* prompt token ids
|
||||||
|
'''
|
||||||
|
|
||||||
|
tokenizer = self.get_tokenizer_group("prompts must be None if "
|
||||||
|
"skip_tokenizer_init is True")
|
||||||
|
|
||||||
|
prompt_token_ids = tokenizer.encode(request_id=request_id,
|
||||||
|
prompt=prompt,
|
||||||
|
lora_request=lora_request)
|
||||||
|
|
||||||
|
return prompt_token_ids
|
||||||
|
|
||||||
|
def _extract_single_prompt_for_enc_dec_input(
|
||||||
|
self,
|
||||||
|
inputs: Optional[PromptInputs],
|
||||||
|
request_id: Optional[str] = None,
|
||||||
|
ptype: Optional[str] = None,
|
||||||
|
is_encoder_prompt: bool = False,
|
||||||
|
) -> Tuple[Optional[str], List[int]]:
|
||||||
|
'''
|
||||||
|
Only for encoder/decoder models:
|
||||||
|
Extract prompt & prompt_token_ids from any single
|
||||||
|
encoder or decoder input prompt. For encoder input prompts
|
||||||
|
in particular, also extract multi-modal data.
|
||||||
|
|
||||||
|
This function handles the following scenarios:
|
||||||
|
1. The user supplied a singleton encoder prompt
|
||||||
|
& the prompt/prompt-token-ids must be extracted.
|
||||||
|
2. The user supplied an explicit encoder/decoder
|
||||||
|
prompt & the prompt/prompt-token-ids must be
|
||||||
|
extracted from either the encoder and decoder prompts.
|
||||||
|
|
||||||
|
For decoder prompts in particular (scenario 2), special
|
||||||
|
processing is applied to the returned decoder token ids.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* request_id
|
||||||
|
* ptype: str representation of the input prompt type.
|
||||||
|
If `ptype` is `None`, assume that the prompt
|
||||||
|
type is unknown and must be inferred. This is the
|
||||||
|
case for ExplicitEncoderDecoder sub-prompts.
|
||||||
|
* inputs: single encoder or decoder input prompt
|
||||||
|
* is_encoder_prompt: True if encoder input prompt.
|
||||||
|
If False, decoder prompt tokens
|
||||||
|
are preprocessed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
* prompt
|
||||||
|
* prompt_token_ids
|
||||||
|
'''
|
||||||
|
prompt_token_ids = None
|
||||||
|
ptype = (get_prompt_type(inputs) if ptype is None else ptype)
|
||||||
|
|
||||||
|
if inputs is None:
|
||||||
|
prompt = None
|
||||||
|
elif ptype == 'str':
|
||||||
|
prompt = inputs
|
||||||
|
prompt_token_ids = self._tokenize_prompt(
|
||||||
|
prompt,
|
||||||
|
request_id=request_id,
|
||||||
|
)
|
||||||
|
elif ptype == 'TokensPrompt':
|
||||||
|
prompt = None
|
||||||
|
prompt_token_ids = inputs['prompt_token_ids']
|
||||||
|
else:
|
||||||
|
prompt = inputs['prompt']
|
||||||
|
prompt_token_ids = self._tokenize_prompt(
|
||||||
|
prompt,
|
||||||
|
request_id=request_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_encoder_prompt:
|
||||||
|
# Apply special pre-processing to
|
||||||
|
# decoder prompts
|
||||||
|
prompt_token_ids = (self._prepare_decoder_input_ids_for_generation(
|
||||||
|
prompt_token_ids, ))
|
||||||
|
|
||||||
|
assert prompt_token_ids is not None
|
||||||
|
|
||||||
|
return (
|
||||||
|
prompt,
|
||||||
|
prompt_token_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_default_enc_dec_decoder_prompt(self, ) -> List[int]:
|
||||||
|
'''
|
||||||
|
Specifically for encoder/decoder models:
|
||||||
|
generate a default decoder prompt for when
|
||||||
|
the user specifies only the encoder prompt.
|
||||||
|
|
||||||
|
Encoder/decoder models utilize the decoder
|
||||||
|
prompt in different ways; as new models are
|
||||||
|
added, it is intended that this function
|
||||||
|
will be extended to produce differing
|
||||||
|
default decoder prompts, depending on the
|
||||||
|
model variety.
|
||||||
|
|
||||||
|
Absent a special case, the default behavior
|
||||||
|
of this method is to mirror the behavior of
|
||||||
|
the HuggingFace (HF) GenerationMixin for a None
|
||||||
|
decoder prompt, which is to employ a logit processor
|
||||||
|
setting to force the first decoded token to be <BOS>.
|
||||||
|
Here, this behavior is approximated by having the
|
||||||
|
"default" decoder prompt be <BOS>.
|
||||||
|
|
||||||
|
However, it is possible that in the future
|
||||||
|
other models may have different or more
|
||||||
|
complex logic for the default decoder prompt.
|
||||||
|
This motivates having a special helper method
|
||||||
|
for default decoder prompts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
* prompt_token_ids
|
||||||
|
'''
|
||||||
|
|
||||||
|
bos_token_id = self._get_bos_token_id()
|
||||||
|
assert bos_token_id is not None
|
||||||
|
prompt_token_ids: List[int] = [bos_token_id]
|
||||||
|
return prompt_token_ids
|
||||||
|
|
||||||
|
    def _process_encoder_decoder_prompt(
        self,
        inputs: PromptInputs,
        request_id: Optional[str] = None,
    ) -> LLMInputs:
        '''
        For encoder/decoder models only:
        process an input prompt into an `LLMInputs` instance.

        There are two types of input prompts:
        singleton prompts which carry only the
        encoder prompt, and explicit encoder/decoder
        prompts which carry both the encoder and the
        decoder prompts as member variables.

        This function handles the following scenarios:
        * Singleton encoder prompt: extract encoder prompt
          token ids & infer default decoder prompt token ids
        * Explicit encoder/decoder prompt: extract encoder
          and decoder prompt token ids

        Note that for explicit encoder/decoder prompts,
        each sub-prompt (encoder or decoder prompt) can
        have any possible singleton type; thus this
        method relies on helper functions to obtain
        token ids for the sub-prompts.

        Arguments:

        * inputs: an input prompt
        * request_id

        Returns:

        * `LLMInputs` instance
        '''

        ptype = get_prompt_type(inputs)

        # Obtain encoder and decoder prompt tokens. Note
        # that, no matter what, the decoder
        # prompt type is unknown.
        if ptype == "ExplicitEncoderDecoder":
            # If input is an explicit encoder/decoder prompt,
            # then it remains to be determined what type
            # of encoder prompt we have
            extracted_encoder_prompt = inputs.get('encoder_prompt')
            encoder_ptype = None
            # Extract decoder prompt from explicit
            # encoder/decoder prompt
            extracted_decoder_prompt = inputs.get('decoder_prompt')
        else:
            # If input is a singleton encoder prompt, then
            # we know the encoder prompt type
            extracted_encoder_prompt = inputs
            encoder_ptype = ptype
            # Decoder prompt is always unknown if
            # encoder/decoder prompt is not explicit
            extracted_decoder_prompt = None

        # Invoke helper function to obtain encoder
        # prompt and prompt token ids, either from a
        # singleton encoder prompt or from the
        # encoder sub-prompt of an explicit
        # encoder/decoder prompt.
        (
            encoder_prompt,
            encoder_prompt_token_ids,
        ) = self._extract_single_prompt_for_enc_dec_input(
            extracted_encoder_prompt,
            request_id=request_id,
            ptype=encoder_ptype,
            is_encoder_prompt=True,
        )

        # Invoke helper method to obtain
        # decoder prompt and prompt token ids.
        #
        # The helper method will detect the decoder
        # prompt type.
        #
        # The helper method will also apply special
        # preprocessing, unique to decoder prompts,
        # to the returned decoder token ids.
        (
            decoder_prompt,
            decoder_prompt_token_ids,
        ) = self._extract_single_prompt_for_enc_dec_input(
            extracted_decoder_prompt,
            request_id=request_id,
            ptype=None,
            is_encoder_prompt=False,
        )

        return LLMInputs(
            prompt_token_ids=decoder_prompt_token_ids,
            prompt=decoder_prompt,
            encoder_prompt_token_ids=encoder_prompt_token_ids,
            encoder_prompt=encoder_prompt,
        )
    def _process_decoder_only_prompt(
        self,
        inputs: PromptInputs,
        lora_request: Optional[LoRARequest] = None,
        request_id: Optional[str] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        '''
        For decoder-only models:
        process an input prompt into an `LLMInputs` instance.

        Arguments:

        * inputs: input prompt
        * lora_request
        * request_id
        * prompt_adapter_request

        Returns:

        * `LLMInputs` instance
        '''

        if isinstance(inputs, str):
            inputs = {"prompt": inputs}
        prompt = inputs.get("prompt")

        if "prompt_token_ids" not in inputs:
            prompt_token_ids = self._tokenize_prompt(
                prompt,
                request_id=request_id,
                lora_request=lora_request,
            )
        else:
            prompt_token_ids = inputs["prompt_token_ids"]

        if prompt_adapter_request:
            prompt_token_ids = (
                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
                + prompt_token_ids)

        return LLMInputs(prompt_token_ids=prompt_token_ids,
                         prompt=prompt,
                         multi_modal_data=inputs.get("multi_modal_data"))
     def process_model_inputs(
         self,
         request_id: str,
@ -569,29 +950,25 @@ class LLMEngine:
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> LLMInputs:
-        if isinstance(inputs, str):
-            inputs = {"prompt": inputs}
-
-        if "prompt_token_ids" not in inputs:
-            tokenizer = self.get_tokenizer_group("prompts must be None if "
-                                                 "skip_tokenizer_init is True")
-
-            prompt_token_ids = tokenizer.encode(request_id=request_id,
-                                                prompt=inputs["prompt"],
-                                                lora_request=lora_request)
+        if self.is_encoder_decoder_model():
+            # Encoder-decoder model requires special mapping of
+            # input prompts to encoder & decoder
+            model_inputs = self._process_encoder_decoder_prompt(
+                inputs,
+                request_id=request_id,
+            )
         else:
-            prompt_token_ids = inputs["prompt_token_ids"]
-
-        if prompt_adapter_request:
-            prompt_token_ids = \
-                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\
-                + prompt_token_ids
-
-        llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
-                               prompt=inputs.get("prompt"),
-                               multi_modal_data=inputs.get("multi_modal_data"))
-
-        return self.input_processor(llm_inputs)
+            # Decoder-only operation
+            model_inputs = self._process_decoder_only_prompt(
+                inputs,
+                request_id=request_id,
+                lora_request=lora_request,
+                prompt_adapter_request=prompt_adapter_request,
+            )
+
+        return self.input_processor(model_inputs)

     def add_request(
         self,
@ -676,6 +1053,7 @@ class LLMEngine:
         lora_request: Optional[LoRARequest],
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        encoder_seq: Optional[Sequence] = None,
     ) -> SequenceGroup:
         """Creates a SequenceGroup with SamplingParams."""
         max_logprobs = self.get_model_config().max_logprobs
@ -701,7 +1079,8 @@ class LLMEngine:
             sampling_params=sampling_params,
             lora_request=lora_request,
             trace_headers=trace_headers,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+            encoder_seq=encoder_seq)

         return seq_group

@ -713,6 +1092,7 @@ class LLMEngine:
         arrival_time: float,
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest],
+        encoder_seq: Optional[Sequence] = None,
     ) -> SequenceGroup:
         """Creates a SequenceGroup with PoolingParams."""
         # Defensive copy of PoolingParams, which are used by the pooler
@ -724,7 +1104,8 @@ class LLMEngine:
             arrival_time=arrival_time,
             lora_request=lora_request,
             pooling_params=pooling_params,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+            encoder_seq=encoder_seq)
         return seq_group

     def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
@ -1214,3 +1595,9 @@ class LLMEngine:
             seq_span.set_attribute(
                 SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
             seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
+
+    def is_encoder_decoder_model(self):
+        return is_encoder_decoder_model_config(self.model_config)
+
+    def is_embedding_model(self):
+        return is_embedding_model_config(self.model_config)
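For orientation, here is a minimal, hedged sketch of how the two branches of `process_model_inputs` are reached from user code. The model name, sampling settings, and prompt text are illustrative assumptions; only names introduced by this commit (`ExplicitEncoderDecoderPrompt`, `TextPrompt`) plus the existing `LLM` / `SamplingParams` entry points are used.

# Sketch only; model name, sampling settings and prompt text are
# illustrative assumptions.
from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt

llm = LLM(model="facebook/bart-large-cnn")  # any registered enc/dec model
params = SamplingParams(temperature=0.0, max_tokens=16)

# Singleton prompt: only the encoder prompt is supplied; the decoder prompt
# falls back to the default produced by _get_default_enc_dec_decoder_prompt().
outputs = llm.generate("The president of the United States is", params)

# Explicit prompt: encoder and decoder sub-prompts may use different schemas.
enc_dec_prompt = ExplicitEncoderDecoderPrompt(
    encoder_prompt=TextPrompt(prompt="The president of the United States is"),
    decoder_prompt="The president",
)
outputs = llm.generate(enc_dec_prompt, params)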
@ -121,12 +121,21 @@ class LLM:
         gpu_memory_utilization: float = 0.9,
         swap_space: int = 4,
         cpu_offload_gb: float = 0,
-        enforce_eager: bool = False,
+        enforce_eager: Optional[bool] = None,
         max_context_len_to_capture: Optional[int] = None,
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         **kwargs,
     ) -> None:
+        '''
+        LLM constructor.
+
+        Note: if enforce_eager is unset (enforce_eager is None)
+        it defaults to False for decoder-only models and True
+        for encoder/decoder models, since encoder/decoder models
+        do not currently support CUDAGraph.
+        '''
+
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
         removed_vision_keys = ("image_token_id", "image_feature_size",
@ -297,8 +306,8 @@ class LLM:
         """
         if self.llm_engine.model_config.embedding_mode:
             raise ValueError(
-                "LLM.generate() is only supported for generation models "
-                "(XForCausalLM).")
+                "LLM.generate() is only supported for (conditional) generation "
+                "models (XForCausalLM, XForConditionalGeneration).")

         if prompt_token_ids is not None:
             inputs = self._convert_v1_inputs(
@ -631,3 +640,9 @@ class LLM:
         # This is necessary because some requests may be finished earlier than
         # its previous requests.
         return sorted(outputs, key=lambda x: int(x.request_id))
+
+    def _is_encoder_decoder_model(self):
+        return self.llm_engine.is_encoder_decoder_model()
+
+    def _is_embedding_model(self):
+        return self.llm_engine.is_embedding_model()
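The new `enforce_eager: Optional[bool] = None` default is resolved per model type as described in the added docstring. Below is a small sketch of that resolution rule; the helper name is hypothetical and not part of this diff.

from typing import Optional

def resolve_enforce_eager(enforce_eager: Optional[bool],
                          is_encoder_decoder: bool) -> bool:
    # Hypothetical helper (not part of this diff) mirroring the documented
    # default: unset resolves to False for decoder-only models and True for
    # encoder/decoder models, which do not yet support CUDAGraph.
    if enforce_eager is None:
        return is_encoder_decoder
    return enforce_eager

assert resolve_enforce_eager(None, is_encoder_decoder=True) is True
assert resolve_enforce_eager(None, is_encoder_decoder=False) is False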
@ -1,5 +1,7 @@
-from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
-                   TextPrompt, TokensPrompt, parse_and_batch_prompt)
+from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, ParsedText,
+                   ParsedTokens, PromptInputs, SingletonPromptInputs,
+                   TextPrompt, TokensPrompt, get_prompt_type,
+                   is_valid_encoder_decoder_llm_inputs, parse_and_batch_prompt)
 from .registry import InputContext, InputRegistry

 INPUT_REGISTRY = InputRegistry()
@ -12,7 +14,18 @@ See also:
 """

 __all__ = [
-    "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
-    "TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY",
-    "InputContext", "InputRegistry"
+    "ParsedText",
+    "ParsedTokens",
+    "parse_and_batch_prompt",
+    "TextPrompt",
+    "TokensPrompt",
+    "PromptInputs",
+    "LLMInputs",
+    "INPUT_REGISTRY",
+    "InputContext",
+    "InputRegistry",
+    "get_prompt_type",
+    "is_valid_encoder_decoder_llm_inputs",
+    "ExplicitEncoderDecoderPrompt",
+    "SingletonPromptInputs",
 ]
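With the expanded `__all__`, the new prompt types are importable directly from `vllm.inputs`; a minimal sketch (prompt text is an illustrative assumption):

# Sketch: these names are exactly the ones added to __all__ above.
from vllm.inputs import (ExplicitEncoderDecoderPrompt, SingletonPromptInputs,
                         TextPrompt)

encoder_prompt: SingletonPromptInputs = TextPrompt(prompt="Hello, my name is")
enc_dec_prompt = ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
                                              decoder_prompt="Hello")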
@ -92,15 +92,114 @@ class TokensPrompt(TypedDict):
     """


-PromptInputs = Union[str, TextPrompt, TokensPrompt]
+SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt]
 """
-The inputs to the LLM, which can take one of the following forms:
+Set of possible schemas for a single LLM input:

 - A text prompt (:class:`str` or :class:`TextPrompt`)
 - A tokenized prompt (:class:`TokensPrompt`)
+
+Note that "singleton" is as opposed to a data structure
+which encapsulates multiple prompts, i.e. of the sort
+which may be utilized for encoder/decoder models when
+the user desires to express both the encoder & decoder
+prompts explicitly, i.e. ExplicitEncoderDecoderPrompt
+
+A prompt of type SingletonPromptInputs may be employed
+as (1) input to a decoder-only model, (2) input to
+the encoder of an encoder/decoder model, in the scenario
+where the decoder prompt is not specified explicitly, or
+(3) as a member of a larger data structure encapsulating
+more than one prompt, i.e. ExplicitEncoderDecoderPrompt
 """
+
+
+class ExplicitEncoderDecoderPrompt(TypedDict):
+    """Represents an encoder/decoder model input prompt,
+    comprising an explicit encoder prompt and a
+    decoder prompt.
+
+    The encoder and decoder prompts, respectively,
+    may be formatted according to any of the
+    SingletonPromptInputs schemas, and are not
+    required to have the same schema.
+
+    Only the encoder prompt may have multi-modal data.
+
+    Note that an ExplicitEncoderDecoderPrompt may not
+    be used as an input to a decoder-only model,
+    and that the `encoder_prompt` and `decoder_prompt`
+    fields of this data structure must themselves be
+    SingletonPromptInputs instances.
+    """
+
+    encoder_prompt: SingletonPromptInputs
+
+    decoder_prompt: SingletonPromptInputs
+
+
+PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
+"""
+Set of possible schemas for an LLM input, including
+both decoder-only and encoder/decoder input types:
+
+- A text prompt (:class:`str` or :class:`TextPrompt`)
+- A tokenized prompt (:class:`TokensPrompt`)
+- A single data structure containing both an encoder and a decoder prompt
+  (:class:`ExplicitEncoderDecoderPrompt`)
+"""
+
+
+def _has_required_keys(
+    d: dict,
+    required_keys: set,
+) -> bool:
+    return required_keys.issubset(d.keys())
+
+
+def get_prompt_type(prompt: Optional[PromptInputs]) -> Optional[str]:
+    """
+    Get the type-name of the prompt argument instance, given that
+    isinstance() cannot apply to TypedDict subclasses directly.
+    If the prompt is None, return 'None' as the type name.
+
+    Arguments:
+
+    * prompt: LLM input prompt or None
+
+    Returns:
+
+    * String representation of prompt type
+    """
+
+    if prompt is None:
+        return 'None'
+
+    required_keys_dict = {
+        'TextPrompt': {'prompt'},
+        'TokensPrompt': {'prompt_token_ids'},
+        'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'},
+    }
+
+    if isinstance(prompt, dict):
+        for (ptype, required_keys) in required_keys_dict.items():
+            # Ignore type checking in the conditional below because type
+            # checker does not understand that is_dict(prompt) narrows
+            # down the possible types
+            if _has_required_keys(
+                    prompt,  # type: ignore
+                    required_keys):
+                return ptype
+
+        raise ValueError(f"Invalid prompt {prompt}, valid types are "
+                         f"required_keys_dict={required_keys_dict}")
+
+    if isinstance(prompt, str):
+        return "str"
+
+    raise ValueError(f"Invalid prompt {prompt}")
+
+
 class LLMInputs(TypedDict):
     """
     The inputs in :class:`~vllm.LLMEngine` before they are
@ -114,8 +213,29 @@ class LLMInputs(TypedDict):
     The original prompt text corresponding to the token IDs, if available.
     """

+    encoder_prompt_token_ids: NotRequired[List[int]]
+    """The token IDs of the encoder prompt."""
+
+    encoder_prompt: NotRequired[Optional[str]]
+    """
+    The original encoder prompt text corresponding to the token IDs, if
+    available.
+    """
+
     multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
     """
     Optional multi-modal data to pass to the model,
     if the model supports it.
     """
+
+
+def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs) -> bool:
+    """
+    Return True if the LLMInputs instance has the correct configuration
+    for encoder/decoder.
+    """
+
+    # True if encoder prompt token ids field exists &
+    # is not None
+    return ('encoder_prompt_token_ids' in inputs
+            and inputs['encoder_prompt_token_ids'] is not None)
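As a quick check on the classification logic above, here is a hedged sketch of the values `get_prompt_type` and `is_valid_encoder_decoder_llm_inputs` return for each schema, assuming they are imported from `vllm.inputs.data`; the literal prompts are illustrative.

# Sketch: expected classifications, following the required-keys table above.
from vllm.inputs.data import (LLMInputs, get_prompt_type,
                              is_valid_encoder_decoder_llm_inputs)

assert get_prompt_type(None) == 'None'
assert get_prompt_type("Hello, my name is") == 'str'
assert get_prompt_type({'prompt': 'Hello'}) == 'TextPrompt'
assert get_prompt_type({'prompt_token_ids': [0, 1, 2]}) == 'TokensPrompt'
assert get_prompt_type({'encoder_prompt': 'Hi',
                        'decoder_prompt': ''}) == 'ExplicitEncoderDecoder'

# An LLMInputs dict only counts as encoder/decoder input when the encoder
# token ids are present and non-None.
assert is_valid_encoder_decoder_llm_inputs(
    LLMInputs(prompt_token_ids=[0], encoder_prompt_token_ids=[0]))
assert not is_valid_encoder_decoder_llm_inputs(LLMInputs(prompt_token_ids=[0]))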
@ -83,7 +83,16 @@ _EMBEDDING_MODELS = {
     "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
 }

-_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS}
+_CONDITIONAL_GENERATION_MODELS = {
+    "BartModel": ("bart", "BartForConditionalGeneration"),
+    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
+}
+
+_MODELS = {
+    **_GENERATION_MODELS,
+    **_EMBEDDING_MODELS,
+    **_CONDITIONAL_GENERATION_MODELS
+}

 # Architecture -> type.
 # out of tree models
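The registry key is the architecture name reported by a checkpoint's HuggingFace config, which is presumably why both "BartModel" and "BartForConditionalGeneration" map to the same implementation. A hedged sketch of how that name can be inspected (the checkpoint name is an illustrative assumption):

# Sketch: the key looked up in _MODELS comes from the checkpoint's HF config.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
print(config.architectures)  # e.g. ['BartForConditionalGeneration']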
996 vllm/model_executor/models/bart.py Normal file
@ -0,0 +1,996 @@
# Derived from BART implementation posted on HuggingFace; license below:
#
# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model."""
import math
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import BartConfig
from transformers.utils import logging

from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput

logger = logging.get_logger(__name__)


def get_bsz_seq_len(input_ids):
    shp = input_ids.shape
    ndim = len(shp)
    if ndim == 1:
        return 1, input_ids.numel()
    else:
        return shp[:2]


class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Bart is set up so that if padding_idx is
        # specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately.
        # Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(
        self,
        positions: torch.Tensor,
        attn_type: AttentionType,
    ) -> torch.Tensor:
        """`input_ids' shape is expected to be [bsz x seqlen]."""

        assert attn_type != AttentionType.ENCODER_DECODER

        return super().forward(positions + self.offset)


class BartScaledWordEmbedding(VocabParallelEmbedding):
    """
    This module overrides VocabParallelEmbedding's
    forward by multiplying with embeddings scale.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return super().forward(input_ids) * self.embed_scale


class BartParallelLMHead(ParallelLMHead):
    """
    This module overrides ParallelLMHead's
    forward by dividing by embeddings scale,
    yielding effectively the inverse of
    BartScaledWordEmbedding
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return super().forward(input_ids) / self.embed_scale

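A note on the `offset = 2` hack above: as the comment explains, BART shifts position ids by two and enlarges the table accordingly. A toy illustration with a plain `nn.Embedding` standing in for the tensor-parallel class; the sizes are made up.

# Toy illustration of the offset-by-2 lookup; nn.Embedding stands in for
# VocabParallelEmbedding and the sizes are arbitrary.
import torch
from torch import nn

offset = 2
max_positions, embed_dim = 16, 4
table = nn.Embedding(max_positions + offset, embed_dim)  # extra rows for the offset

positions = torch.arange(0, 6)
pos_embeds = table(positions + offset)  # mirrors BartLearnedPositionalEmbedding.forward
assert pos_embeds.shape == (6, embed_dim)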
class BartEncoderAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
bias: bool = True,
|
||||||
|
config: Optional[BartConfig] = None,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.d_model = config.d_model
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
self.total_num_kv_heads = self.total_num_heads
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
if (self.head_dim * num_heads) != self.embed_dim:
|
||||||
|
raise ValueError(f"embed_dim must be divisible by num_heads "
|
||||||
|
f"(got `embed_dim`: {self.embed_dim}"
|
||||||
|
f" and `num_heads`: {num_heads}).")
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
self.d_model,
|
||||||
|
self.d_model // self.total_num_heads,
|
||||||
|
self.total_num_heads,
|
||||||
|
self.total_num_kv_heads,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out_proj = RowParallelLinear(
|
||||||
|
embed_dim,
|
||||||
|
embed_dim,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
tp_world_size = get_tensor_model_parallel_world_size()
|
||||||
|
assert self.total_num_heads % tp_world_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_world_size
|
||||||
|
|
||||||
|
if self.total_num_kv_heads >= tp_world_size:
|
||||||
|
# Number of KV heads is greater than TP size, so we partition
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert self.total_num_kv_heads % tp_world_size == 0
|
||||||
|
else:
|
||||||
|
# Number of KV heads is less than TP size, so we replicate
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert tp_world_size % self.total_num_kv_heads == 0
|
||||||
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
|
||||||
|
self.attn = Attention(self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||||
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
|
||||||
|
attn_output = self.attn(q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
kv_cache,
|
||||||
|
attn_metadata,
|
||||||
|
attn_type=AttentionType.ENCODER)
|
||||||
|
|
||||||
|
output, _ = self.out_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class BartDecoderSelfAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
bias: bool = True,
|
||||||
|
config: Optional[BartConfig] = None,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.d_model = config.d_model
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
self.total_num_kv_heads = self.total_num_heads
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
if (self.head_dim * num_heads) != self.embed_dim:
|
||||||
|
raise ValueError(f"embed_dim must be divisible by num_heads "
|
||||||
|
f"(got `embed_dim`: {self.embed_dim}"
|
||||||
|
f" and `num_heads`: {num_heads}).")
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
self.d_model,
|
||||||
|
self.d_model // self.total_num_heads,
|
||||||
|
self.total_num_heads,
|
||||||
|
self.total_num_kv_heads,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out_proj = RowParallelLinear(
|
||||||
|
embed_dim,
|
||||||
|
embed_dim,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
tp_world_size = get_tensor_model_parallel_world_size()
|
||||||
|
assert self.total_num_heads % tp_world_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_world_size
|
||||||
|
|
||||||
|
if self.total_num_kv_heads >= tp_world_size:
|
||||||
|
# Number of KV heads is greater than TP size, so we partition
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert self.total_num_kv_heads % tp_world_size == 0
|
||||||
|
else:
|
||||||
|
# Number of KV heads is less than TP size, so we replicate
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert tp_world_size % self.total_num_kv_heads == 0
|
||||||
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
|
||||||
|
self.attn = Attention(self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||||
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
|
||||||
|
attn_output = self.attn(q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
kv_cache,
|
||||||
|
attn_metadata,
|
||||||
|
attn_type=AttentionType.DECODER)
|
||||||
|
|
||||||
|
output, _ = self.out_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class BartCrossAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
bias: bool = True,
|
||||||
|
config: Optional[BartConfig] = None,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.d_model = config.d_model
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
self.total_num_kv_heads = self.total_num_heads
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
if (self.head_dim * num_heads) != self.embed_dim:
|
||||||
|
raise ValueError(f"embed_dim must be divisible by num_heads "
|
||||||
|
f"(got `embed_dim`: {self.embed_dim}"
|
||||||
|
f" and `num_heads`: {num_heads}).")
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
self.d_model,
|
||||||
|
self.d_model // self.total_num_heads,
|
||||||
|
self.total_num_heads,
|
||||||
|
self.total_num_kv_heads,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out_proj = RowParallelLinear(
|
||||||
|
embed_dim,
|
||||||
|
embed_dim,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
tp_world_size = get_tensor_model_parallel_world_size()
|
||||||
|
assert self.total_num_heads % tp_world_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_world_size
|
||||||
|
|
||||||
|
if self.total_num_kv_heads >= tp_world_size:
|
||||||
|
# Number of KV heads is greater than TP size, so we partition
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert self.total_num_kv_heads % tp_world_size == 0
|
||||||
|
else:
|
||||||
|
# Number of KV heads is less than TP size, so we replicate
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert tp_world_size % self.total_num_kv_heads == 0
|
||||||
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
|
||||||
|
self.attn = Attention(self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
decoder_hidden_states: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
|
# (afeldman-nm 2024/07/22) TODO:
|
||||||
|
# Need a more efficient solution for q/k/v
|
||||||
|
qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
|
||||||
|
q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
|
||||||
|
dim=-1)
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
k = None
|
||||||
|
v = None
|
||||||
|
else:
|
||||||
|
qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
|
||||||
|
_, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
|
||||||
|
dim=-1)
|
||||||
|
|
||||||
|
attn_output = self.attn(q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
kv_cache,
|
||||||
|
attn_metadata,
|
||||||
|
attn_type=AttentionType.ENCODER_DECODER)
|
||||||
|
|
||||||
|
output, _ = self.out_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class BartEncoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: BartConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = config.d_model
|
||||||
|
|
||||||
|
self.self_attn = BartEncoderAttention(
|
||||||
|
embed_dim=self.embed_dim,
|
||||||
|
num_heads=config.encoder_attention_heads,
|
||||||
|
config=config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
|
self.activation_fn = get_act_fn(config.activation_function,
|
||||||
|
quant_config)
|
||||||
|
|
||||||
|
ffn_hidden_size = self.embed_dim
|
||||||
|
ffn_intermediate_size = config.encoder_ffn_dim
|
||||||
|
ffn_has_bias = True
|
||||||
|
self.fc1 = ColumnParallelLinear(
|
||||||
|
ffn_hidden_size,
|
||||||
|
ffn_intermediate_size,
|
||||||
|
bias=ffn_has_bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.act = get_act_fn("gelu", quant_config, ffn_intermediate_size)
|
||||||
|
self.fc2 = RowParallelLinear(
|
||||||
|
ffn_intermediate_size,
|
||||||
|
ffn_hidden_size,
|
||||||
|
bias=ffn_has_bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
hidden_states
|
||||||
|
torch.Tensor of *encoder* input embeddings.
|
||||||
|
kv_cache:
|
||||||
|
Layer-wise list of KV cache tensors
|
||||||
|
attn_metadata:
|
||||||
|
vLLM Attention metadata structure
|
||||||
|
Returns:
|
||||||
|
Encoder layer output torch.Tensor
|
||||||
|
"""
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.self_attn(hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
attn_metadata=attn_metadata)
|
||||||
|
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
fc1_out, _ = self.fc1(hidden_states)
|
||||||
|
hidden_states = self.activation_fn(fc1_out)
|
||||||
|
|
||||||
|
hidden_states, _ = self.fc2(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
hidden_states = self.final_layer_norm(hidden_states)
|
||||||
|
|
||||||
|
if hidden_states.dtype == torch.float16 and (
|
||||||
|
torch.isinf(hidden_states).any()
|
||||||
|
or torch.isnan(hidden_states).any()):
|
||||||
|
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||||
|
hidden_states = torch.clamp(hidden_states,
|
||||||
|
min=-clamp_value,
|
||||||
|
max=clamp_value)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BartDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: BartConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = config.d_model
|
||||||
|
|
||||||
|
self.self_attn = BartDecoderSelfAttention(
|
||||||
|
embed_dim=self.embed_dim,
|
||||||
|
num_heads=config.decoder_attention_heads,
|
||||||
|
config=config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
self.activation_fn = get_act_fn(config.activation_function,
|
||||||
|
quant_config)
|
||||||
|
|
||||||
|
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
|
'''
|
||||||
|
afeldman-nm: personally I would call this "cross-attention",
|
||||||
|
however I left the name as "encoder_attn" to maintain consistency
|
||||||
|
with the name of the pretrained weights.
|
||||||
|
'''
|
||||||
|
self.encoder_attn = BartCrossAttention(
|
||||||
|
self.embed_dim,
|
||||||
|
config.decoder_attention_heads,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
|
|
||||||
|
ffn_hidden_size = self.embed_dim
|
||||||
|
ffn_intermediate_size = config.encoder_ffn_dim
|
||||||
|
ffn_has_bias = True
|
||||||
|
self.fc1 = ColumnParallelLinear(
|
||||||
|
ffn_hidden_size,
|
||||||
|
ffn_intermediate_size,
|
||||||
|
bias=ffn_has_bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.fc2 = RowParallelLinear(
|
||||||
|
ffn_intermediate_size,
|
||||||
|
ffn_hidden_size,
|
||||||
|
bias=ffn_has_bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
decoder_hidden_states: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
decoder_hidden_states
|
||||||
|
torch.Tensor of *decoder* input embeddings.
|
||||||
|
kv_cache:
|
||||||
|
KV cache tensor
|
||||||
|
attn_metadata:
|
||||||
|
vLLM Attention metadata structure
|
||||||
|
encoder_hidden_states
|
||||||
|
torch.Tensor of *encoder* input embeddings.
|
||||||
|
Returns:
|
||||||
|
Decoder layer output torch.Tensor
|
||||||
|
"""
|
||||||
|
residual = decoder_hidden_states
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states = self.self_attn(hidden_states=decoder_hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
attn_metadata=attn_metadata)
|
||||||
|
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||||
|
|
||||||
|
# Cross-Attention Block
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.encoder_attn(
|
||||||
|
decoder_hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
fc1_out, _ = self.fc1(hidden_states)
|
||||||
|
hidden_states = self.activation_fn(fc1_out)
|
||||||
|
|
||||||
|
hidden_states, _ = self.fc2(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
hidden_states = self.final_layer_norm(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BartEncoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Transformer encoder consisting of *config.encoder_layers*
|
||||||
|
self attention layers. Each layer is a [`BartEncoderLayer`].
|
||||||
|
Args:
|
||||||
|
config: BartConfig
|
||||||
|
embed_tokens (nn.Embedding): output embedding
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: BartConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
|
embed_tokens: Optional[nn.Embedding] = None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cache_config = cache_config
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.lora_config = lora_config
|
||||||
|
embed_dim = config.d_model
|
||||||
|
self.max_source_positions = config.max_position_embeddings
|
||||||
|
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||||
|
|
||||||
|
self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
|
||||||
|
embed_dim,
|
||||||
|
embed_scale=embed_scale)
|
||||||
|
|
||||||
|
if embed_tokens is not None:
|
||||||
|
self.embed_tokens.weight = embed_tokens.weight
|
||||||
|
|
||||||
|
self.embed_positions = BartLearnedPositionalEmbedding(
|
||||||
|
config.max_position_embeddings,
|
||||||
|
embed_dim,
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[BartEncoderLayer(config,cache_config,quant_config) \
|
||||||
|
for _ in range(config.encoder_layers)])
|
||||||
|
|
||||||
|
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
||||||
|
|
||||||
|
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
input_ids
|
||||||
|
Indices of *encoder* input sequence tokens in the vocabulary.
|
||||||
|
Padding will be ignored by default should you
|
||||||
|
provide it.
|
||||||
|
positions
|
||||||
|
Positions of *encoder* input sequence tokens.
|
||||||
|
kv_caches:
|
||||||
|
Layer-wise list of KV cache tensors
|
||||||
|
attn_metadata:
|
||||||
|
vLLM Attention metadata structure
|
||||||
|
Returns:
|
||||||
|
Decoder output torch.Tensor
|
||||||
|
"""
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
|
||||||
|
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
embed_pos = self.embed_positions(
|
||||||
|
positions,
|
||||||
|
AttentionType.ENCODER,
|
||||||
|
)
|
||||||
|
embed_pos = embed_pos.to(inputs_embeds.device)
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds + embed_pos
|
||||||
|
hidden_states = self.layernorm_embedding(hidden_states)
|
||||||
|
|
||||||
|
for idx, encoder_layer in enumerate(self.layers):
|
||||||
|
hidden_states = encoder_layer(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_caches[idx],
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BartDecoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Transformer decoder consisting of *config.decoder_layers* layers.
|
||||||
|
Each layer is a [`BartDecoderLayer`]
|
||||||
|
Args:
|
||||||
|
config: BartConfig
|
||||||
|
embed_tokens (nn.Embedding): output embedding
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: BartConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
|
embed_tokens: Optional[nn.Embedding] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.cache_config = cache_config
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.lora_config = lora_config
|
||||||
|
self.max_target_positions = config.max_position_embeddings
|
||||||
|
embed_scale = math.sqrt(
|
||||||
|
config.d_model) if config.scale_embedding else 1.0
|
||||||
|
|
||||||
|
self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
|
||||||
|
config.d_model,
|
||||||
|
embed_scale=embed_scale)
|
||||||
|
|
||||||
|
if embed_tokens is not None:
|
||||||
|
self.embed_tokens.weight = embed_tokens.weight
|
||||||
|
|
||||||
|
self.embed_positions = BartLearnedPositionalEmbedding(
|
||||||
|
config.max_position_embeddings,
|
||||||
|
config.d_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[BartDecoderLayer(config,cache_config,quant_config) \
|
||||||
|
for _ in range(config.decoder_layers)])
|
||||||
|
|
||||||
|
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
||||||
|
|
||||||
|
def forward(self, decoder_input_ids: torch.Tensor,
|
||||||
|
decoder_positions: torch.Tensor,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor],
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
decoder_input_ids
|
||||||
|
Indices of *decoder* input sequence tokens in the vocabulary.
|
||||||
|
Padding will be ignored by default should you
|
||||||
|
provide it.
|
||||||
|
decoder_positions
|
||||||
|
Positions of *decoder* input sequence tokens.
|
||||||
|
encoder_hidden_states:
|
||||||
|
Tensor of encoder output embeddings
|
||||||
|
kv_caches:
|
||||||
|
Layer-wise list of KV cache tensors
|
||||||
|
attn_metadata:
|
||||||
|
vLLM Attention metadata structure
|
||||||
|
Returns:
|
||||||
|
Decoder output torch.Tensor
|
||||||
|
"""
|
||||||
|
|
||||||
|
inputs_embeds = self.embed_tokens(decoder_input_ids)
|
||||||
|
|
||||||
|
# embed positions
|
||||||
|
embed_pos = self.embed_positions(
|
||||||
|
decoder_positions,
|
||||||
|
AttentionType.DECODER,
|
||||||
|
)
|
||||||
|
embed_pos = embed_pos.to(inputs_embeds.device)
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds + embed_pos
|
||||||
|
hidden_states = self.layernorm_embedding(hidden_states)
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
|
||||||
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
|
hidden_states = decoder_layer(
|
||||||
|
decoder_hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_caches[idx],
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BartModel(nn.Module):
|
||||||
|
_tied_weights_keys = [
|
||||||
|
"encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: BartConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
lora_config: Optional[LoRAConfig] = None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.padding_idx = config.pad_token_id
|
||||||
|
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||||
|
(lora_config.max_loras or 1)) if lora_config else 0
|
||||||
|
self.vocab_size = config.vocab_size + lora_vocab
|
||||||
|
self.org_vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
self.encoder = BartEncoder(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
self.decoder = BartDecoder(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
|
||||||
|
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
|
||||||
|
encoder_input_ids: torch.Tensor,
|
||||||
|
encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata) -> torch.Tensor:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
input_ids
|
||||||
|
Indices of *decoder* input sequence tokens in the vocabulary.
|
||||||
|
Padding will be ignored by default should you
|
||||||
|
provide it.
|
||||||
|
positions
|
||||||
|
Positions of *decoder* input sequence tokens.
|
||||||
|
encoder_input_ids
|
||||||
|
Indices of *encoder* input sequence tokens in the vocabulary.
|
||||||
|
encoder_positions:
|
||||||
|
Positions of *encoder* input sequence tokens.
|
||||||
|
kv_caches:
|
||||||
|
Layer-wise list of KV cache tensors
|
||||||
|
attn_metadata:
|
||||||
|
vLLM Attention metadata structure
|
||||||
|
Returns:
|
||||||
|
Model output torch.Tensor
|
||||||
|
"""
|
||||||
|
|
||||||
|
encoder_hidden_states = None
|
||||||
|
|
||||||
|
if encoder_input_ids.numel() > 0:
|
||||||
|
# Run encoder attention if a non-zero number of encoder tokens
|
||||||
|
# are provided as input
|
||||||
|
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
|
||||||
|
positions=encoder_positions,
|
||||||
|
kv_caches=kv_caches,
|
||||||
|
attn_metadata=attn_metadata)
|
||||||
|
|
||||||
|
# decoder outputs consists of
|
||||||
|
# (dec_features, past_key_value, dec_hidden, dec_attn)
|
||||||
|
decoder_outputs = self.decoder(
|
||||||
|
decoder_input_ids=input_ids,
|
||||||
|
decoder_positions=positions,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
kv_caches=kv_caches,
|
||||||
|
attn_metadata=attn_metadata)
|
||||||
|
|
||||||
|
return decoder_outputs
|
||||||
|
|
||||||
|
|
||||||
|
class BartForConditionalGeneration(nn.Module):
|
||||||
|
base_model_prefix = "model"
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: BartConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
lora_config: Optional[LoRAConfig] = None):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.model = BartModel(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
lora_config=lora_config)
|
||||||
|
|
||||||
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
|
if lora_config:
|
||||||
|
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||||
|
|
||||||
|
embed_scale = math.sqrt(
|
||||||
|
config.d_model) if config.scale_embedding else 1.0
|
||||||
|
|
||||||
|
self.lm_head = BartParallelLMHead(config.vocab_size,
|
||||||
|
config.d_model,
|
||||||
|
embed_scale=embed_scale)
|
||||||
|
|
||||||
|
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||||
|
config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
encoder_input_ids: torch.Tensor,
|
||||||
|
encoder_positions: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
input_ids
|
||||||
|
torch.Tensor of *decoder* input token ids.
|
||||||
|
positions
|
||||||
|
torch.Tensor of *decoder* position indices.
|
||||||
|
encoder_input_ids
|
||||||
|
torch.Tensor of *encoder* input token ids.
|
||||||
|
encoder_positions
|
||||||
|
torch.Tensor of *encoder* position indices
|
||||||
|
kv_caches:
|
||||||
|
Layer-wise list of KV cache tensors
|
||||||
|
attn_metadata:
|
||||||
|
vLLM Attention metadata structure
|
||||||
|
Returns:
|
||||||
|
Output torch.Tensor
|
||||||
|
"""
|
||||||
|
return self.model(input_ids, positions, encoder_input_ids,
|
||||||
|
encoder_positions, kv_caches, attn_metadata)
|
||||||
|
|
||||||
|
    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    stacked_params_mapping = {
        "q_proj": {
            "param_name": "qkv_proj",
            "shard_id": "q",
        },
        "k_proj": {
            "param_name": "qkv_proj",
            "shard_id": "k",
        },
        "v_proj": {
            "param_name": "qkv_proj",
            "shard_id": "v",
        },
    }

    params_mapping = {
        "beta": "bias",
        "gamma": "weight",
        "LayerNorm": "layernorm",
    }

    def _rename_key(self, key: str):
        prefix = f"{self.base_model_prefix}."
        key = key[len(prefix):] if key.startswith(prefix) else key

        for src, dst in self.params_mapping.items():
            key = key.replace(src, dst)

        return key

    def _rename_stacked_param(
        self,
        name: str,
    ) -> Tuple[str, Optional[str]]:
        for key, mapping in self.stacked_params_mapping.items():
            if key in name:
                name = name.replace(key, mapping["param_name"])
                return name, mapping["shard_id"]
        return name, None

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

        model_params_dict = dict(self.model.named_parameters())
        top_params_dict = dict(self.named_parameters())

        weights_tuple_list = list(weights)

        shared_embedding_weight = None
        shared_embedding_shard_id = None

        for name, loaded_weight in weights_tuple_list:

            name = self._rename_key(name)
            name, shard_id = self._rename_stacked_param(name)

            if ('shared.weight' in name
                    or 'encoder.embed_tokens.weight' in name
                    or 'decoder.embed_tokens.weight' in name
                    or 'lm_head.weight' in name):
                assert shared_embedding_weight is None, (
                    "Conflicting embedding weights.")
                shared_embedding_weight = loaded_weight
                shared_embedding_shard_id = shard_id
            else:
                # Skip the specific downstream task weight.
                if name.startswith('cls.'):
                    continue
                # use Pooler instead.
                if name.startswith('pooler.'):
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in model_params_dict:
                    continue

                param = model_params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if shard_id:
                    weight_loader(param, loaded_weight, shard_id)
                else:
                    weight_loader(param, loaded_weight)

        # Assign shared weight values
        encoder_in_param = model_params_dict['encoder.embed_tokens.weight']
        encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader",
                                           default_weight_loader)

        decoder_in_param = model_params_dict['decoder.embed_tokens.weight']
        decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader",
                                           default_weight_loader)

        lm_head_in_param = top_params_dict['lm_head.weight']
        lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader",
                                           default_weight_loader)

        assert shared_embedding_weight is not None

        if shared_embedding_shard_id:
            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight,
                                     shared_embedding_shard_id)
            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight,
                                     shared_embedding_shard_id)
            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight,
                                     shared_embedding_shard_id)
        else:
            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight)
            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight)
            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight)
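The weight-loading path above renames Hugging Face checkpoint keys and routes the single shared BART embedding tensor to the encoder embedding, decoder embedding, and LM head. Below is a minimal, self-contained sketch of just the renaming step, using plain dicts in place of vLLM parameters; the two mapping tables are copied from the code above, while the toy checkpoint keys and the "model" base prefix are illustrative assumptions, not part of the diff.

from typing import Optional, Tuple

import torch

# Mappings copied from the model code above.
PARAMS_MAPPING = {"beta": "bias", "gamma": "weight", "LayerNorm": "layernorm"}
STACKED_PARAMS_MAPPING = {
    "q_proj": {"param_name": "qkv_proj", "shard_id": "q"},
    "k_proj": {"param_name": "qkv_proj", "shard_id": "k"},
    "v_proj": {"param_name": "qkv_proj", "shard_id": "v"},
}


def rename_key(key: str, base_model_prefix: str = "model") -> str:
    # Strip the HF base-model prefix, then apply the substring renames.
    prefix = f"{base_model_prefix}."
    key = key[len(prefix):] if key.startswith(prefix) else key
    for src, dst in PARAMS_MAPPING.items():
        key = key.replace(src, dst)
    return key


def rename_stacked(name: str) -> Tuple[str, Optional[str]]:
    # Map separate q/k/v projections onto the fused qkv_proj parameter.
    for key, mapping in STACKED_PARAMS_MAPPING.items():
        if key in name:
            return name.replace(key, mapping["param_name"]), mapping["shard_id"]
    return name, None


# Hypothetical checkpoint entries: the shared embedding and one q projection.
checkpoint = {
    "model.shared.weight": torch.zeros(8, 4),
    "model.encoder.layers.0.self_attn.q_proj.weight": torch.zeros(4, 4),
}

for raw_name, tensor in checkpoint.items():
    name = rename_key(raw_name)
    name, shard_id = rename_stacked(name)
    print(name, shard_id)
# shared.weight None
# encoder.layers.0.self_attn.qkv_proj.weight q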
@ -70,12 +70,20 @@ class RequestOutput:
    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
                For encoder/decoder models, this is the
                decoder input prompt.
        prompt_token_ids: The token IDs of the prompt.
                          For encoder/decoder models, this is the
                          decoder input prompt token ids.
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the output.
        encoder_prompt: The encoder prompt string of the request;
                        None if decoder-only
        encoder_prompt_token_ids: The token IDs of the encoder prompt;
                                  None if decoder-only
    """

    def __init__(
@ -88,6 +96,8 @@ class RequestOutput:
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
        encoder_prompt: Optional[str] = None,
        encoder_prompt_token_ids: Optional[List[int]] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
@ -97,6 +107,8 @@ class RequestOutput:
        self.finished = finished
        self.metrics = metrics
        self.lora_request = lora_request
        self.encoder_prompt = encoder_prompt
        self.encoder_prompt_token_ids = encoder_prompt_token_ids

    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
@ -137,6 +149,8 @@ class RequestOutput:
        # Every sequence in the sequence group should have the same prompt.
        prompt = seq_group.prompt
        prompt_token_ids = seq_group.prompt_token_ids
        encoder_prompt = seq_group.encoder_prompt
        encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
        prompt_logprobs = seq_group.prompt_logprobs
        finished = seq_group.is_finished()
        finished_time = time.time() if finished else None
@ -148,12 +162,16 @@ class RequestOutput:
            outputs,
            finished,
            seq_group.metrics,
            lora_request=seq_group.lora_request,
            encoder_prompt=encoder_prompt,
            encoder_prompt_token_ids=encoder_prompt_token_ids)

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"encoder_prompt={self.encoder_prompt!r}, "
                f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
                f"finished={self.finished}, "
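With the new fields in place, an encoder/decoder generation run reports both prompts on each output: prompt and prompt_token_ids now describe the decoder side, while the encoder inputs surface through encoder_prompt and encoder_prompt_token_ids. A hedged usage sketch follows; the model name, the dict-literal explicit encoder/decoder prompt form, and the sampling settings are assumptions for illustration, not requirements of the diff.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/bart-large-cnn", dtype="float", enforce_eager=True)

outputs = llm.generate(
    {
        "encoder_prompt": "The quick brown fox jumps over the lazy dog.",
        "decoder_prompt": "",
    },
    SamplingParams(temperature=0.0, max_tokens=20),
)

for output in outputs:
    # The decoder-side prompt and the encoder-side prompt are now
    # reported separately on the same RequestOutput.
    print(output.encoder_prompt)
    print(output.prompt)
    print(output.outputs[0].text)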
@ -7,10 +7,11 @@ from array import array
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
                    Union, cast)

import torch

from vllm.inputs import is_valid_encoder_decoder_llm_inputs
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
@ -244,13 +245,26 @@ class SequenceData:
class Sequence:
    """Stores the data, status, and block information of a sequence.

    The sequence is constructed from the LLMInputs instance passed
    in through the `inputs` constructor argument.

    For encoder/decoder models, LLMInputs encapsulates both a
    decoder and encoder prompt, creating an ambiguity about which
    prompt to construct the sequence from. The `from_decoder_prompt`
    constructor argument signals whether to construct the Sequence
    from the LLMInputs decoder prompt, or encoder prompt.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
        from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
                             (True) or encoder prompt (False). Must be True
                             for decoder-only models.

    """
@ -261,7 +275,8 @@ class Sequence:
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        from_decoder_prompt: bool = True,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
@ -269,6 +284,36 @@ class Sequence:
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request
        self.from_decoder_prompt = from_decoder_prompt
        self._prompt: Optional[str] = None
        self._prompt_token_ids: Optional[List[int]] = None

        # For decoder-only models, a Sequence is constructed
        # from an LLMInputs instance (the `inputs` arg.)
        #
        # For encoder/decoder models the same `inputs`
        # instance could be utilized to construct either an
        # encoder sequence or a decoder sequence, because
        # `LLMInputs` has both decoder- and encoder-oriented
        # member variables (i.e. it encapsulates both an encoder
        # and a decoder prompt.) The decision of which type of sequence
        # to generate is determined by the `from_decoder_prompt` argument.
        #
        # When constructing an encoder sequence
        # (`from_decoder_prompt` False) it matters that
        # the `LLMInputs` instance stored in `inputs` is valid
        # in the sense that its encoder-related member variables are
        # populated; below, an exception is raised if this is
        # not the case.
        #
        # When constructing a decoder sequence (`from_decoder_prompt` True)
        # it does not matter whether `inputs` has its encoder-related
        # member variables populated.
        if not (from_decoder_prompt
                or is_valid_encoder_decoder_llm_inputs(inputs)):
            raise ValueError("Cannot extract encoder input prompt from "
                             f"invalid input {inputs}; did you forget the "
                             "encoder input prompt fields?")

        self.data = SequenceData(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
@ -289,11 +334,35 @@ class Sequence:

    @property
    def prompt(self) -> Optional[str]:
        if self._prompt is not None:
            # Reuse precomputed prompt string
            return self._prompt

        # Select decoder or encoder input prompt str,
        # as appropriate
        prompt_key: str = ("prompt"
                           if self.from_decoder_prompt else "encoder_prompt")

        # Cache prompt
        self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
        return self._prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        if self._prompt_token_ids is not None:
            # Reuse precomputed prompt token ids
            return self._prompt_token_ids

        # Select decoder or encoder input prompt
        # token ids, as appropriate
        prompt_token_ids_key: str = ("prompt_token_ids"
                                     if self.from_decoder_prompt else
                                     "encoder_prompt_token_ids")

        # Cache computed prompt token ids
        self._prompt_token_ids = cast(List[int],
                                      self.inputs.get(prompt_token_ids_key))
        return self._prompt_token_ids

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
@ -472,6 +541,22 @@ class SequenceGroup:
        # We use the prompt of an arbitrary sequence.
        return self.seqs[0].prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt is distinct
        # from the decoder's.
        return (self.encoder_seq.prompt
                if self.encoder_seq is not None else None)

    @property
    def encoder_prompt_token_ids(self) -> Optional[List[int]]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt token ids are
        # distinct from the decoder's.
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        # All sequences in the group should have the same multi-modal data.
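The from_decoder_prompt flag lets one LLMInputs instance back both the decoder sequence and the encoder sequence of a group: the same dict is read through a different key depending on which side is being built. A small standalone sketch of that selection logic follows, with a plain dict standing in for LLMInputs and the example prompts and token ids chosen purely for illustration.

from typing import List, cast

# Hypothetical LLMInputs-style dict with both decoder- and encoder-side fields.
inputs = {
    "prompt": "<s>",                                  # decoder prompt
    "prompt_token_ids": [0],
    "encoder_prompt": "Summarize this article ...",   # encoder prompt
    "encoder_prompt_token_ids": [0, 713, 2225, 2],
}


def select_prompt_token_ids(inputs: dict, from_decoder_prompt: bool) -> List[int]:
    # Mirrors Sequence.prompt_token_ids: pick the decoder- or encoder-side key.
    key = ("prompt_token_ids"
           if from_decoder_prompt else "encoder_prompt_token_ids")
    return cast(List[int], inputs.get(key))


print(select_prompt_token_ids(inputs, from_decoder_prompt=True))   # [0]
print(select_prompt_token_ids(inputs, from_decoder_prompt=False))  # [0, 713, 2225, 2]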
130 vllm/utils.py
@ -27,10 +27,93 @@ from typing_extensions import ParamSpec

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs,
                         SingletonPromptInputs)
from vllm.logger import enable_trace_function_call, init_logger

logger = init_logger(__name__)

# Exception strings for non-implemented encoder/decoder scenarios

STR_NOT_IMPL_ENC_DEC_SWA = \
    "Sliding window attention for encoder/decoder models " + \
    "is not currently supported."

STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
    "Prefix caching for encoder/decoder models " + \
    "is not currently supported."

STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
    "Chunked prefill for encoder/decoder models " + \
    "is not currently supported."

STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
    "Models with logits_soft_cap "
    "require FlashInfer backend, which is "
    "currently not supported for encoder/decoder "
    "models.")

STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is not currently "
                             "supported with encoder/decoder "
                             "models.")

STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not "
                           "currently supported with "
                           "encoder/decoder models.")

STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently "
                           "supported with encoder/decoder "
                           "models.")

STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
                                 "currently supported with encoder/"
                                 "decoder models.")

STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = ("CUDAGraph is not "
                                  "currently supported with encoder/"
                                  "decoder models.")

STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
                                "currently supported with encoder/"
                                "decoder models.")

STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
                                       "currently supported with encoder/"
                                       "decoder models.")

# Efficiently import all enc/dec error strings
# rather than having to import all of the above
STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
    "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA,
    "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
    "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL":
    STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL,
    "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP,
    "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA,
    "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP,
    "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
    "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
    "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
    "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
    "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
}

# Constants related to forcing the attention backend selection

# String name of register which may be set in order to
# force auto-selection of attention backend by Attention
# wrapper
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"

# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"

STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
@ -1029,3 +1112,50 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
    """Utility function to run async task in a lock"""
    async with lock:
        return await task(*args, **kwargs)


def is_encoder_decoder_model_config(model_config) -> bool:
    '''
    Extract the HF encoder/decoder model flag from the ModelConfig instance.
    Return False if model_config is None.
    '''
    return model_config is not None and \
        getattr(model_config.hf_config,
                "is_encoder_decoder",
                False)


def is_embedding_model_config(model_config) -> bool:
    '''
    Extract the embedding model flag from the ModelConfig instance.
    Return False if model_config is None.
    '''
    return model_config is not None and \
        model_config.embedding_mode


def build_explicit_enc_dec_prompt(
    encoder_prompt: SingletonPromptInputs,
    decoder_prompt: SingletonPromptInputs,
) -> ExplicitEncoderDecoderPrompt:
    return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
                                        decoder_prompt=decoder_prompt)


def zip_enc_dec_prompt_lists(
    enc_prompt_list: List[SingletonPromptInputs],
    dec_prompt_list: List[SingletonPromptInputs],
) -> List[ExplicitEncoderDecoderPrompt]:
    return [
        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
        for (encoder_prompt,
             decoder_prompt) in zip(enc_prompt_list, dec_prompt_list)
    ]


def to_enc_dec_tuple_list(
    enc_dec_prompts: List[ExplicitEncoderDecoderPrompt],
) -> List[Tuple[PromptInputs, PromptInputs]]:
    return [(enc_dec_prompt['encoder_prompt'],
             enc_dec_prompt['decoder_prompt'])
            for enc_dec_prompt in enc_dec_prompts]
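The prompt helpers added above are small list utilities: build_explicit_enc_dec_prompt pairs one encoder prompt with one decoder prompt, zip_enc_dec_prompt_lists does so element-wise over two lists, and to_enc_dec_tuple_list flattens the result back into tuples. A short usage sketch follows, assuming vLLM at this commit is importable; the prompt strings themselves are arbitrary examples.

from vllm.utils import (build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
                        zip_enc_dec_prompt_lists)

enc_prompts = ["An encoder prompt", "Another encoder prompt"]
dec_prompts = ["", "A decoder prompt"]

# Pair up encoder and decoder prompts element-wise.
enc_dec_prompts = zip_enc_dec_prompt_lists(enc_prompts, dec_prompts)
assert enc_dec_prompts[0] == build_explicit_enc_dec_prompt(
    "An encoder prompt", "")

# Flatten back into (encoder, decoder) tuples, e.g. for test parametrization.
print(to_enc_dec_tuple_list(enc_dec_prompts))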
472 vllm/worker/enc_dec_model_runner.py Normal file
@ -0,0 +1,472 @@
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, cast

import torch
import torch.distributed

from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata)
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
                                     get_global_forced_attn_backend,
                                     global_force_attn_backend)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, MultiModalConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig)
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
from vllm.worker.model_runner import (_PAD_SLOT_ID, GPUModelRunnerBase,
                                      ModelInputForGPUBuilder,
                                      ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict)
from vllm.worker.utils import assert_enc_dec_mr_supported_scenario

logger = init_logger(__name__)


@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
    """
    Used by the EncoderDecoderModelRunner.
    """
    encoder_input_tokens: Optional[torch.Tensor] = None
    encoder_input_positions: Optional[torch.Tensor] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "encoder_input_tokens": self.encoder_input_tokens,
            "encoder_input_positions": self.encoder_input_positions,
            "virtual_engine": self.virtual_engine,
            "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
            "finished_requests_ids": self.finished_requests_ids,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "EncoderDecoderModelInput":
        return cast(
            EncoderDecoderModelInput,
            super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))


class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
    _model_input_cls: Type[EncoderDecoderModelInput] = (
        EncoderDecoderModelInput)
    _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder)

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        multimodal_config: Optional[MultiModalConfig] = None,
    ):
        '''
        EncoderDecoderModelRunner constructor.

        `lora_config`, `multimodal_config`, and prompt_adapter_config are
        unused (since these features are not yet supported for encoder/decoder
        models) but these arguments are present here for compatibility with
        the base-class constructor.
        '''

        self._maybe_force_supported_attention_backend()

        super().__init__(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config,
            lora_config=None,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
        )

        # Crash for unsupported encoder/decoder scenarios
        assert_enc_dec_mr_supported_scenario(self)

    def _maybe_force_supported_attention_backend(self):
        '''
        Force vLLM to use the XFormers attention backend,
        which is currently the only supported option.
        '''

        def raise_backend_err():
            # The user has specified an attention backend override
            # which is invalid for encoder/decoder models
            raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND)

        maybe_env_var_forced_backend = get_env_variable_attn_backend()
        maybe_global_forced_backend = get_global_forced_attn_backend()
        is_forced_by_global = maybe_global_forced_backend is not None
        is_forced_by_env_var = maybe_env_var_forced_backend is not None

        if not (is_forced_by_global or is_forced_by_env_var):
            # The user has not already specified an attention backend
            # override
            logger.info("EncoderDecoderModelRunner requires "
                        "XFormers backend; overriding backend "
                        "auto-selection and forcing XFormers.")
            global_force_attn_backend(_Backend.XFORMERS)
        elif is_forced_by_global:
            # Backend override enforced by global variable takes
            # precedence over vLLM backend environment variable.
            if maybe_global_forced_backend != _Backend.XFORMERS:
                raise_backend_err()
        elif is_forced_by_env_var:
            # Backend override enforced by vLLM backend
            # environment variable
            if maybe_env_var_forced_backend != _Backend.XFORMERS:
                raise_backend_err()

    def _list_to_int32_tensor(
        self,
        _list: List[int],
    ) -> torch.Tensor:
        return torch.tensor(_list, dtype=torch.int32, device=self.device)

    def _list_to_long_tensor(
        self,
        _list: List[int],
    ) -> torch.Tensor:
        return torch.tensor(_list, dtype=torch.long, device=self.device)

    def _empty_int32_tensor(self) -> torch.Tensor:
        return self._list_to_int32_tensor([])

    def _empty_long_tensor(self) -> torch.Tensor:
        return self._list_to_long_tensor([])

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: EncoderDecoderModelInput,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[PoolerOutput]]:
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in "
                             "EncoderDecoderModelRunner")

        model_executable = self.model

        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_seqlen_agnostic else {}
        hidden_or_intermediate_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            encoder_input_ids=model_input.encoder_input_tokens,
            encoder_positions=model_input.encoder_input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **seqlen_agnostic_kwargs)

        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        return [output]

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput:
        return EncoderDecoderModelInput.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> EncoderDecoderModelInput:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        Since chunked prefill is not supported for encoder/decoder models,
        `input_tokens` is assumed to be either entirely prefill tokens or
        entirely decode tokens.

        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)

        (
            attn_metadata,
            encoder_input_tokens_tensor,
            encoder_input_positions_tensor,
        ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
                                                       model_input))

        # Inject attn_metadata encoder/cross-attention fields &
        # encoder input tokens/positions into model_input.
        # Frozen dataclass fields cannot be modified, so use
        # dataclasses.replace to construct a new model input
        # instance.
        model_input = dataclasses.replace(
            model_input,
            attn_metadata=attn_metadata,
            encoder_input_tokens=encoder_input_tokens_tensor,
            encoder_input_positions=encoder_input_positions_tensor,
        )

        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     self.pin_memory)
        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    def profile_run(self) -> None:
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []

        model_config = self.model_config

        batch_size = 0
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

            seq_data, _ = INPUT_REGISTRY \
                .dummy_data_for_profiling(model_config, seq_len)

            # Having more tokens is over-conservative but otherwise fine
            assert len(seq_data.prompt_token_ids) >= seq_len, (
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but got: {len(seq_data.prompt_token_ids)}")

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                encoder_seq_data=seq_data,
                cross_block_table=None,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.cuda.synchronize()
        return

    def _prepare_encoder_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        model_input: EncoderDecoderModelInput,
    ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
        """Helper method to prepare the encoder- and cross-attn-related
        model inputs based on a given sequence group. These additional inputs
        are used to augment an already-computed `EncoderDecoderModelInput`
        data structure which already has decoder-related model inputs
        populated.

        Sets the following attn_metadata fields:
        * `num_encoder_tokens`
        * `encoder_seq_lens`
        * `encoder_seq_lens_tensor`
        * `max_encoder_seq_len`
        * `cross_slot_mapping`
        * `cross_block_tables`

        Constructs a new model inputs data structure, based on
        (1) the existing fields in the `model_inputs` argument,
        and (2) the following additional fields which are
        computed (or in the case of `attn_metadata`, updated)
        by this function:
        * attn_metadata
        * encoder_input_tokens
        * encoder_input_positions

        Arguments:

        * seq_group_metadata_list: list of sequence groups for which to
          compute inputs
        * model_inputs: model inputs data structure with decoder-oriented
          fields already computed.

        Return:

        * Updated model inputs data structure
        """

        if len(seq_group_metadata_list) == 0:
            return (model_input.attn_metadata, None, None)

        # Since we are not supporting chunked prefill either the entire
        # batch is prefill or it is decode
        is_prompt = seq_group_metadata_list[0].is_prompt

        # Build encoder inputs
        encoder_seq_lens: List[int] = []
        if is_prompt:
            # Prefill phase.
            cross_block_tables = self._empty_int32_tensor().view(
                len(seq_group_metadata_list), -1)

            # Extract input tokens/positions, cross-attention slot-mapping,
            # & seq len from each sequence group metadata
            (
                encoder_input_tokens,
                encoder_input_positions,
                cross_slot_mapping,
            ) = (
                [],
                [],
                [],
            )
            for seq_group_metadata in seq_group_metadata_list:
                # Build seq lens
                seq_len = seq_group_metadata.encoder_seq_data.get_len()
                token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
                encoder_seq_lens.append(seq_len)

                # Build slot mapping
                is_profile_run = (seq_group_metadata.block_tables is None)
                if is_profile_run:
                    # During memory profiling, the block tables are not
                    # initialized yet. In this case, we just use a dummy
                    # slot mapping.
                    # In embeddings, the block tables are {seq_id: None}.
                    cross_slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                else:
                    for i in range(0, seq_len):
                        block_number = seq_group_metadata.cross_block_table[
                            i // self.block_size]
                        block_offset = i % self.block_size
                        slot = block_number * self.block_size + block_offset
                        cross_slot_mapping.append(slot)

                # Build encoder input tokens
                encoder_input_tokens.extend(token_ids)
                encoder_input_positions.extend(list(range(0, seq_len)))

            # Convert tokens/positions & cross-attention
            # slot-mapping to encoder input tensors
            encoder_input_tokens_tensor = self._list_to_long_tensor(
                encoder_input_tokens)
            encoder_input_positions_tensor = self._list_to_long_tensor(
                encoder_input_positions)
            cross_slot_mapping_tensor = self._list_to_long_tensor(
                cross_slot_mapping)

        else:
            # Decode phase.
            encoder_input_tokens_tensor = self._empty_long_tensor()
            encoder_input_positions_tensor = self._empty_long_tensor()
            cross_slot_mapping_tensor = self._empty_long_tensor()

            # Extract cross-attention block tables &
            # seq len from each sequence group metadata.
            # Cross-attention block tables are empty
            # during vLLM memory profiling.
            cross_block_tables = []
            for seq_group_metadata in seq_group_metadata_list:
                encoder_seq_lens.append(
                    seq_group_metadata.encoder_seq_data.get_len())
                cross_block_table = seq_group_metadata.cross_block_table
                cross_block_tables.append([] if (
                    cross_block_table is None) else cross_block_table)

            # Convert cross-attention block tables to encoder input tensor
            cross_block_tables = make_tensor_with_pad(
                cross_block_tables,
                max_len=max(
                    len(block_table) for block_table in cross_block_tables),
                pad=0,
                dtype=torch.int32,
                device=self.device,
            )

        # Compute encoder sequence lengths & encoder
        # sequence starting offset tensors
        max_encoder_seq_len = max(encoder_seq_lens, default=0)
        encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
        encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
                                            1,
                                            dtype=torch.int32,
                                            device=self.device)
        torch.cumsum(encoder_seq_lens_tensor,
                     dim=0,
                     dtype=encoder_seq_start_loc.dtype,
                     out=encoder_seq_start_loc[1:])

        # Update attention metadata with encoder-oriented attributes
        attn_metadata = model_input.attn_metadata
        assert attn_metadata is not None
        (
            attn_metadata.num_encoder_tokens,
            attn_metadata.encoder_seq_lens,
            attn_metadata.encoder_seq_lens_tensor,
            attn_metadata.max_encoder_seq_len,
            attn_metadata.cross_slot_mapping,
            attn_metadata.cross_block_tables,
        ) = (
            sum(encoder_seq_lens),
            encoder_seq_lens,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
            cross_slot_mapping_tensor,
            cross_block_tables,
        )

        return (attn_metadata, encoder_input_tokens_tensor,
                encoder_input_positions_tensor)
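The cross-attention slot mapping built during prefill is plain paged-KV arithmetic: encoder token i lands in block cross_block_table[i // block_size] at offset i % block_size, and profiling runs (no block tables yet) fall back to a pad sentinel. A small standalone illustration of that indexing follows; the block-table values, block size, and sentinel value are illustrative assumptions, not taken from the runner.

from typing import List, Optional

_PAD_SLOT_ID = -1  # stand-in sentinel, mirroring the dummy-slot idea above


def cross_slot_mapping(cross_block_table: Optional[List[int]], seq_len: int,
                       block_size: int) -> List[int]:
    # Map each encoder token index to its physical slot in the paged KV cache.
    if cross_block_table is None:
        # Profiling run: block tables are not allocated yet.
        return [_PAD_SLOT_ID] * seq_len
    return [
        cross_block_table[i // block_size] * block_size + i % block_size
        for i in range(seq_len)
    ]


print(cross_slot_mapping([7, 3], seq_len=6, block_size=4))
# [28, 29, 30, 31, 12, 13]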
56 vllm/worker/utils.py Normal file
@ -0,0 +1,56 @@
'''
Worker-related helper functions.
'''

from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS
from vllm.worker.model_runner import GPUModelRunnerBase


def assert_enc_dec_mr_supported_scenario(
        enc_dec_mr: GPUModelRunnerBase) -> None:
    '''
    Assert that the provided encoder/decoder model runner instance reflects
    a supported scenario.
    '''

    if enc_dec_mr.cache_config.enable_prefix_caching:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE'])

    if enc_dec_mr.sliding_window is not None:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SWA'])

    if enc_dec_mr.scheduler_config.chunked_prefill_enabled:
        raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
            'STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL'])

    if getattr(enc_dec_mr.model_config.hf_config, 'attn_logit_softcapping',
               None) is not None:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP']
        )

    if enc_dec_mr.lora_config is not None:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LORA'])

    if enc_dec_mr.parallel_config.pipeline_parallel_size > 1:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])

    if enc_dec_mr.multimodal_config is not None:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])

    if enc_dec_mr.scheduler_config.num_lookahead_slots > 0:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])

    if not enc_dec_mr.model_config.enforce_eager:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH'])

    if enc_dec_mr.prompt_adapter_config is not None:
        raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
            'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER'])
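Each check above fails fast with the matching message from STR_NOT_IMPL_ENC_DEC_ERR_STRS rather than letting an unsupported feature silently misbehave. A hedged illustration of that guard pattern in isolation follows; the SimpleNamespace stub stands in for the model runner, and the error text is copied from the chunked-prefill string above.

from types import SimpleNamespace

# Message text matches STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL above.
ERR = ("Chunked prefill for encoder/decoder models "
       "is not currently supported.")

# Stand-in for the single config field this particular check inspects.
runner = SimpleNamespace(
    scheduler_config=SimpleNamespace(chunked_prefill_enabled=True))

try:
    if runner.scheduler_config.chunked_prefill_enabled:
        raise NotImplementedError(ERR)
except NotImplementedError as exc:
    print(f"rejected unsupported scenario: {exc}")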
@ -19,8 +19,11 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (is_embedding_model_config,
                        is_encoder_decoder_model_config)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput

@ -85,8 +88,10 @@ class Worker(LocalOrDistributedWorkerBase):
        ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
        if model_runner_cls is not None:
            ModelRunnerClass = model_runner_cls
        elif self._is_embedding_model():
            ModelRunnerClass = EmbeddingModelRunner
        elif self._is_encoder_decoder_model():
            ModelRunnerClass = EncoderDecoderModelRunner
        self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
            model_config,
            parallel_config,
@ -107,6 +112,12 @@ class Worker(LocalOrDistributedWorkerBase):
        # Initialize gpu_cache as embedding models don't initialize kv_caches
        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None

    def _is_encoder_decoder_model(self):
        return is_encoder_decoder_model_config(self.model_config)

    def _is_embedding_model(self):
        return is_embedding_model_config(self.model_config)

    def init_device(self) -> None:
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
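The worker now picks its runner class from the model config: an explicit override wins, then embedding mode, then the HF is_encoder_decoder flag, with ModelRunner as the default. A condensed sketch of that dispatch order follows; the helper function name is hypothetical, strings stand in for the actual runner classes, and the config objects are stubbed with SimpleNamespace.

from types import SimpleNamespace


def select_model_runner_cls(model_config, model_runner_cls=None) -> str:
    # Mirrors the selection order in Worker.__init__ above.
    if model_runner_cls is not None:
        return model_runner_cls
    if getattr(model_config, "embedding_mode", False):
        return "EmbeddingModelRunner"
    if getattr(model_config.hf_config, "is_encoder_decoder", False):
        return "EncoderDecoderModelRunner"
    return "ModelRunner"


bart_like = SimpleNamespace(embedding_mode=False,
                            hf_config=SimpleNamespace(is_encoder_decoder=True))
print(select_model_runner_cls(bart_like))  # EncoderDecoderModelRunner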