vllm/cacheflow/model_executor/input_metadata.py

from typing import Dict, List, Tuple

import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from cacheflow.sampling_params import SamplingParams


class InputMetadata:
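    """Metadata for one batched forward pass of the model.

    A batch mixes prompt sequences and generation sequences. Prompt
    tokens attend through the block-diagonal causal mask built below;
    generation tokens attend over their cached context via the block
    tables of the paged KV cache, with one entry of `context_lens` and
    one row of `block_tables` per generation sequence.
    """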
    def __init__(
        self,
        seq_groups: List[Tuple[List[int], SamplingParams]],
        seq_logprobs: Dict[int, float],  # Seq id -> cumulative logprobs.
        prompt_lens: List[int],
        slot_mapping: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        block_tables: torch.Tensor,
    ) -> None:
        self.seq_groups = seq_groups
        self.seq_logprobs = seq_logprobs
        self.prompt_lens = prompt_lens
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.max_context_len = max_context_len
        self.block_tables = block_tables

        # All prompt tokens are packed into one attention call; the
        # block-diagonal mask keeps each prompt causal and independent
        # of the other prompts in the batch.
        self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)

        self.num_prompts = len(prompt_lens)
        self.num_prompt_tokens = sum(prompt_lens)
        # Each generation sequence contributes one input token per step,
        # so context_lens has one entry per generation sequence.
        self.num_generation_tokens = context_lens.shape[0]
        # Total tokens in the batch: prompt tokens plus generation tokens.
        self.num_valid_tokens = slot_mapping.shape[0]
        if block_tables.numel() > 0:
            self.max_num_blocks_per_seq = block_tables.shape[1]
        else:
            self.max_num_blocks_per_seq = 0
        assert block_tables.shape[0] == self.num_generation_tokens
        assert context_lens.shape[0] == self.num_generation_tokens

    def __repr__(self) -> str:
        return (f'InputMetadata('
                f'num_valid_tokens={self.num_valid_tokens}, '
                f'num_prompt_tokens={self.num_prompt_tokens}, '
                f'num_prompts={self.num_prompts}, '
                f'prompt_lens={self.prompt_lens}, '
                f'num_generation_tokens={self.num_generation_tokens}, '
                f'context_lens={self.context_lens}, '
                f'max_context_len={self.max_context_len}, '
                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
                f'block_tables={self.block_tables}, '
                f'slot_mapping={self.slot_mapping})')
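

# A minimal construction sketch (illustrative, not part of the original
# module). `sampling_params` stands in for a configured SamplingParams
# instance; the values below are assumptions chosen only to satisfy the
# asserts in __init__: two prompts of lengths 3 and 2, plus one
# generation sequence whose 7-token context lives in a single cache block.
#
#     metadata = InputMetadata(
#         seq_groups=[([0], sampling_params), ([1], sampling_params),
#                     ([2], sampling_params)],
#         seq_logprobs={0: 0.0, 1: 0.0, 2: -1.3},
#         prompt_lens=[3, 2],                        # 5 prompt tokens in total.
#         slot_mapping=torch.tensor([0, 1, 2, 16, 17, 39]),  # 5 prompt + 1 gen.
#         context_lens=torch.tensor([7]),            # One generation sequence.
#         max_context_len=7,
#         block_tables=torch.tensor([[2]]),          # One row per gen sequence.
#     )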