[Misc] Use dataclass for InputMetadata (#3452)

Co-authored-by: youkaichao <youkaichao@126.com>
Woosuk Kwon 2024-03-17 03:02:46 -07:00 committed by GitHub
parent 6b78837b29
commit abfc4f3387
3 changed files with 24 additions and 63 deletions

setup.py

@@ -2,7 +2,6 @@ import contextlib
 import io
 import os
 import re
-import shutil
 import subprocess
 import warnings
 from pathlib import Path

vllm/model_executor/input_metadata.py

@@ -1,8 +1,10 @@
+from dataclasses import dataclass
 from typing import Optional

 import torch


+@dataclass
 class InputMetadata:
     """Metadata for input sequences. Used in PagedAttention.
@@ -15,40 +17,17 @@ class InputMetadata:
         kv_cache_dtype: Data type to store kv cache.
     """

-    def __init__(
-        self,
-        is_prompt: bool,
-        slot_mapping: torch.Tensor,
-        prompt_lens: Optional[torch.Tensor],
-        max_seq_len: Optional[int],
-        start_loc: Optional[torch.Tensor],
-        max_context_len: Optional[int],
-        context_lens: Optional[torch.Tensor],
-        block_tables: Optional[torch.Tensor],
-        use_cuda_graph: bool,
-        kv_cache_dtype: str,
-    ) -> None:
-        self.is_prompt = is_prompt
-        self.prompt_lens = prompt_lens
-        self.max_seq_len = max_seq_len
-        self.start_loc = start_loc
-        self.max_context_len = max_context_len
-        self.slot_mapping = slot_mapping
-        self.context_lens = context_lens
-        self.block_tables = block_tables
-        self.use_cuda_graph = use_cuda_graph
-        self.kv_cache_dtype = kv_cache_dtype
-        # Set during the execution of the first attention op.
-        # FIXME(woosuk): This is a hack.
-        self.attn_bias = None
-
-    def __repr__(self) -> str:
-        return ("InputMetadata("
-                f"is_prompt={self.is_prompt}, "
-                f"max_context_len={self.max_context_len}, "
-                f"slot_mapping={self.slot_mapping}, "
-                f"context_lens={self.context_lens}, "
-                f"block_tables={self.block_tables}, "
-                f"use_cuda_graph={self.use_cuda_graph}, "
-                f"kv_cache_dtype={self.kv_cache_dtype})")
+    is_prompt: bool
+    slot_mapping: torch.Tensor
+    prompt_lens: Optional[torch.Tensor]
+    max_seq_len: Optional[int]
+    start_loc: Optional[torch.Tensor]
+    max_context_len: Optional[int]
+    context_lens: Optional[torch.Tensor]
+    block_tables: Optional[torch.Tensor]
+    use_cuda_graph: bool
+    kv_cache_dtype: str
+
+    def __post_init__(self):
+        # will not appear in the __repr__ and __init__
+        self.attn_bias = None

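The hunk above is the heart of the change: the hand-written __init__ and __repr__ are dropped in favor of the @dataclass-generated ones, and the attn_bias assignment moves into __post_init__. A standalone sketch of that pattern (not vLLM code; Meta is a made-up two-field stand-in for InputMetadata) shows why an attribute set in __post_init__ stays out of the generated __init__, __repr__, and dataclasses.asdict():

from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class Meta:
    # Declared fields become parameters of the generated __init__
    # and appear in the generated __repr__ and __eq__.
    is_prompt: bool
    max_context_len: Optional[int]

    def __post_init__(self):
        # Plain instance attribute, not a dataclass field: excluded
        # from __init__, __repr__, and dataclasses.asdict().
        self.attn_bias = None


m = Meta(is_prompt=True, max_context_len=None)
print(m)            # Meta(is_prompt=True, max_context_len=None)
print(asdict(m))    # {'is_prompt': True, 'max_context_len': None}
print(m.attn_bias)  # None (still readable as a normal attribute)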
vllm/worker/model_runner.py

@@ -1,4 +1,5 @@
 import contextlib
+import dataclasses
 import time
 from typing import Dict, List, Optional, Tuple, Set, Union
@@ -521,45 +522,27 @@ class ModelRunner:
             metadata_dict = {
                 "input_tokens": input_tokens,
                 "input_positions": input_positions,
-                "is_prompt": input_metadata.is_prompt,
-                "slot_mapping": input_metadata.slot_mapping,
-                "prompt_lens": input_metadata.prompt_lens,
-                "max_seq_len": input_metadata.max_seq_len,
-                "start_loc": input_metadata.start_loc,
-                "max_context_len": input_metadata.max_context_len,
-                "context_lens": input_metadata.context_lens,
-                "block_tables": input_metadata.block_tables,
-                "use_cuda_graph": input_metadata.use_cuda_graph,
-                "kv_cache_dtype": input_metadata.kv_cache_dtype,
                 "selected_token_indices":
                 sampling_metadata.selected_token_indices,
                 "lora_requests": lora_requests,
                 "lora_mapping": lora_mapping,
             }
+            metadata_dict.update(dataclasses.asdict(input_metadata))
             broadcast_tensor_dict(metadata_dict, src=0)
         else:
             metadata_dict = broadcast_tensor_dict(src=0)
-            input_tokens = metadata_dict["input_tokens"]
-            input_positions = metadata_dict["input_positions"]
-            lora_mapping = metadata_dict["lora_mapping"]
-            lora_requests = metadata_dict["lora_requests"]
-            input_metadata = InputMetadata(
-                is_prompt=metadata_dict["is_prompt"],
-                slot_mapping=metadata_dict["slot_mapping"],
-                prompt_lens=metadata_dict["prompt_lens"],
-                max_seq_len=metadata_dict["max_seq_len"],
-                start_loc=metadata_dict["start_loc"],
-                max_context_len=metadata_dict["max_context_len"],
-                context_lens=metadata_dict["context_lens"],
-                block_tables=metadata_dict["block_tables"],
-                use_cuda_graph=metadata_dict["use_cuda_graph"],
-                kv_cache_dtype=metadata_dict["kv_cache_dtype"],
-            )
+            input_tokens = metadata_dict.pop("input_tokens")
+            input_positions = metadata_dict.pop("input_positions")
+            selected_token_indices = metadata_dict.pop(
+                "selected_token_indices")
+            lora_mapping = metadata_dict.pop("lora_mapping")
+            lora_requests = metadata_dict.pop("lora_requests")
+            input_metadata = InputMetadata(**metadata_dict)
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
                 seq_data=None,
                 prompt_lens=None,
-                selected_token_indices=metadata_dict["selected_token_indices"],
+                selected_token_indices=selected_token_indices,
                 categorized_sample_indices=None,
                 generators=None,
                 perform_sampling=False,