[Misc] Use dataclass for InputMetadata (#3452)
Co-authored-by: youkaichao <youkaichao@126.com>
parent 6b78837b29
commit abfc4f3387
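
The commit converts InputMetadata from a hand-written class into a @dataclass and rewrites ModelRunner's tensor-dict broadcast to take advantage of it; setup.py drops one import line along the way. The core pattern is sketched below with hypothetical toy classes (PlainMeta/DataclassMeta and their fields are illustrative, not vLLM's real ones):

from dataclasses import dataclass
from typing import Optional


# Before: every field is spelled out three times
# (parameter, assignment, and again in __repr__).
class PlainMeta:

    def __init__(self, is_prompt: bool, max_len: Optional[int]) -> None:
        self.is_prompt = is_prompt
        self.max_len = max_len

    def __repr__(self) -> str:
        return f"PlainMeta(is_prompt={self.is_prompt}, max_len={self.max_len})"


# After: each field is declared once; __init__, __repr__,
# and __eq__ are generated by the decorator.
@dataclass
class DataclassMeta:
    is_prompt: bool
    max_len: Optional[int]


print(DataclassMeta(is_prompt=True, max_len=2048))
# DataclassMeta(is_prompt=True, max_len=2048)
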
setup.py
@@ -2,7 +2,6 @@ import contextlib
 import io
 import os
 import re
-import shutil
 import subprocess
 import warnings
 from pathlib import Path

vllm/model_executor/input_metadata.py
@@ -1,8 +1,10 @@
+from dataclasses import dataclass
 from typing import Optional

 import torch


+@dataclass
 class InputMetadata:
     """Metadata for input sequences. Used in PagedAttention.
@@ -15,40 +17,17 @@ class InputMetadata:
         kv_cache_dtype: Data type to store kv cache.
     """

-    def __init__(
-        self,
-        is_prompt: bool,
-        slot_mapping: torch.Tensor,
-        prompt_lens: Optional[torch.Tensor],
-        max_seq_len: Optional[int],
-        start_loc: Optional[torch.Tensor],
-        max_context_len: Optional[int],
-        context_lens: Optional[torch.Tensor],
-        block_tables: Optional[torch.Tensor],
-        use_cuda_graph: bool,
-        kv_cache_dtype: str,
-    ) -> None:
-        self.is_prompt = is_prompt
-        self.prompt_lens = prompt_lens
-        self.max_seq_len = max_seq_len
-        self.start_loc = start_loc
-        self.max_context_len = max_context_len
-        self.slot_mapping = slot_mapping
-        self.context_lens = context_lens
-        self.block_tables = block_tables
-        self.use_cuda_graph = use_cuda_graph
-        self.kv_cache_dtype = kv_cache_dtype
+    is_prompt: bool
+    slot_mapping: torch.Tensor
+    prompt_lens: Optional[torch.Tensor]
+    max_seq_len: Optional[int]
+    start_loc: Optional[torch.Tensor]
+    max_context_len: Optional[int]
+    context_lens: Optional[torch.Tensor]
+    block_tables: Optional[torch.Tensor]
+    use_cuda_graph: bool
+    kv_cache_dtype: str

-        # Set during the execution of the first attention op.
-        # FIXME(woosuk): This is a hack.
-        self.attn_bias = None
-
-    def __repr__(self) -> str:
-        return ("InputMetadata("
-                f"is_prompt={self.is_prompt}, "
-                f"max_context_len={self.max_context_len}, "
-                f"slot_mapping={self.slot_mapping}, "
-                f"context_lens={self.context_lens}, "
-                f"block_tables={self.block_tables}, "
-                f"use_cuda_graph={self.use_cuda_graph}, "
-                f"kv_cache_dtype={self.kv_cache_dtype})")
+    def __post_init__(self):
+        # will not appear in the __repr__ and __init__
+        self.attn_bias = None
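
Two details of this hunk are worth spelling out. First, the hand-written __init__ and __repr__ can simply be deleted because @dataclass generates both from the field declarations; the generated __repr__ also covers all ten fields, whereas the old hand-written one silently omitted prompt_lens, max_seq_len, and start_loc. Second, attn_bias moves into __post_init__ because attributes assigned there are plain instance attributes, not dataclass fields, so they stay out of the generated __init__, __repr__, and dataclasses.asdict(). A minimal sketch of that behavior (Meta is a hypothetical toy class, not vLLM code):

import dataclasses
from dataclasses import dataclass
from typing import Optional


@dataclass
class Meta:
    is_prompt: bool
    max_context_len: Optional[int]

    def __post_init__(self):
        # Plain attribute, not a field: invisible to the generated
        # __init__/__repr__ and to dataclasses.asdict()/fields().
        self.attn_bias = None


m = Meta(is_prompt=False, max_context_len=512)
print(m)                      # Meta(is_prompt=False, max_context_len=512)
print(m.attn_bias)            # None -- still set on the instance
print(dataclasses.asdict(m))  # {'is_prompt': False, 'max_context_len': 512}
print([f.name for f in dataclasses.fields(m)])  # ['is_prompt', 'max_context_len']
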

vllm/worker/model_runner.py
@@ -1,4 +1,5 @@
 import contextlib
+import dataclasses
 import time
 from typing import Dict, List, Optional, Tuple, Set, Union
@@ -521,45 +522,27 @@ class ModelRunner:
             metadata_dict = {
                 "input_tokens": input_tokens,
                 "input_positions": input_positions,
-                "is_prompt": input_metadata.is_prompt,
-                "slot_mapping": input_metadata.slot_mapping,
-                "prompt_lens": input_metadata.prompt_lens,
-                "max_seq_len": input_metadata.max_seq_len,
-                "start_loc": input_metadata.start_loc,
-                "max_context_len": input_metadata.max_context_len,
-                "context_lens": input_metadata.context_lens,
-                "block_tables": input_metadata.block_tables,
-                "use_cuda_graph": input_metadata.use_cuda_graph,
-                "kv_cache_dtype": input_metadata.kv_cache_dtype,
                 "selected_token_indices":
                 sampling_metadata.selected_token_indices,
                 "lora_requests": lora_requests,
                 "lora_mapping": lora_mapping,
             }
+            metadata_dict.update(dataclasses.asdict(input_metadata))
             broadcast_tensor_dict(metadata_dict, src=0)
         else:
             metadata_dict = broadcast_tensor_dict(src=0)
-            input_tokens = metadata_dict["input_tokens"]
-            input_positions = metadata_dict["input_positions"]
-            lora_mapping = metadata_dict["lora_mapping"]
-            lora_requests = metadata_dict["lora_requests"]
-            input_metadata = InputMetadata(
-                is_prompt=metadata_dict["is_prompt"],
-                slot_mapping=metadata_dict["slot_mapping"],
-                prompt_lens=metadata_dict["prompt_lens"],
-                max_seq_len=metadata_dict["max_seq_len"],
-                start_loc=metadata_dict["start_loc"],
-                max_context_len=metadata_dict["max_context_len"],
-                context_lens=metadata_dict["context_lens"],
-                block_tables=metadata_dict["block_tables"],
-                use_cuda_graph=metadata_dict["use_cuda_graph"],
-                kv_cache_dtype=metadata_dict["kv_cache_dtype"],
-            )
+            input_tokens = metadata_dict.pop("input_tokens")
+            input_positions = metadata_dict.pop("input_positions")
+            selected_token_indices = metadata_dict.pop(
+                "selected_token_indices")
+            lora_mapping = metadata_dict.pop("lora_mapping")
+            lora_requests = metadata_dict.pop("lora_requests")
+            input_metadata = InputMetadata(**metadata_dict)
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
                 seq_data=None,
                 prompt_lens=None,
-                selected_token_indices=metadata_dict["selected_token_indices"],
+                selected_token_indices=selected_token_indices,
                 categorized_sample_indices=None,
                 generators=None,
                 perform_sampling=False,
             )
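
The payoff on the worker side: instead of naming every field twice, the driver flattens the dataclass with dataclasses.asdict() and workers rebuild it with InputMetadata(**metadata_dict) after popping the keys that are not InputMetadata fields; the ** call requires the remaining keys to match the field names exactly, which is why every extra entry is popped first. A sketch of the round trip with a plain dict standing in for broadcast_tensor_dict (no torch.distributed involved; the trimmed three-field InputMetadata is a stand-in for the real class):

import dataclasses
from dataclasses import dataclass
from typing import Optional


@dataclass
class InputMetadata:  # trimmed stand-in for the real class
    is_prompt: bool
    max_context_len: Optional[int]
    kv_cache_dtype: str


# Driver side: field names become dict keys automatically, so adding a
# field to InputMetadata no longer requires touching this code.
src = InputMetadata(is_prompt=True, max_context_len=None, kv_cache_dtype="auto")
metadata_dict = {"input_tokens": [1, 2, 3]}  # non-dataclass extras
metadata_dict.update(dataclasses.asdict(src))

# (broadcast_tensor_dict(metadata_dict, src=0) would ship this dict here.)

# Worker side: pop the extras first; whatever remains must be exactly
# the InputMetadata fields, or the ** call raises a TypeError.
input_tokens = metadata_dict.pop("input_tokens")
rebuilt = InputMetadata(**metadata_dict)
assert rebuilt == src  # __eq__ is also generated by @dataclass

One caveat worth noting: dataclasses.asdict() recursively deep-copies field values, which is wasted work when the fields are large tensors; a shallow {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)} would avoid the copies if that ever matters.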
|
Loading…
x
Reference in New Issue
Block a user