[Misc] Use dataclass for InputMetadata (#3452)

Co-authored-by: youkaichao <youkaichao@126.com>

parent 6b78837b29
commit abfc4f3387

setup.py
@@ -2,7 +2,6 @@ import contextlib
 import io
 import os
 import re
-import shutil
 import subprocess
 import warnings
 from pathlib import Path
vllm/model_executor/input_metadata.py
@@ -1,8 +1,10 @@
+from dataclasses import dataclass
 from typing import Optional
 
 import torch
 
 
+@dataclass
 class InputMetadata:
     """Metadata for input sequences. Used in PagedAttention.
 
@@ -15,40 +17,17 @@ class InputMetadata:
         kv_cache_dtype: Data type to store kv cache.
     """
 
-    def __init__(
-        self,
-        is_prompt: bool,
-        slot_mapping: torch.Tensor,
-        prompt_lens: Optional[torch.Tensor],
-        max_seq_len: Optional[int],
-        start_loc: Optional[torch.Tensor],
-        max_context_len: Optional[int],
-        context_lens: Optional[torch.Tensor],
-        block_tables: Optional[torch.Tensor],
-        use_cuda_graph: bool,
-        kv_cache_dtype: str,
-    ) -> None:
-        self.is_prompt = is_prompt
-        self.prompt_lens = prompt_lens
-        self.max_seq_len = max_seq_len
-        self.start_loc = start_loc
-        self.max_context_len = max_context_len
-        self.slot_mapping = slot_mapping
-        self.context_lens = context_lens
-        self.block_tables = block_tables
-        self.use_cuda_graph = use_cuda_graph
-        self.kv_cache_dtype = kv_cache_dtype
+    is_prompt: bool
+    slot_mapping: torch.Tensor
+    prompt_lens: Optional[torch.Tensor]
+    max_seq_len: Optional[int]
+    start_loc: Optional[torch.Tensor]
+    max_context_len: Optional[int]
+    context_lens: Optional[torch.Tensor]
+    block_tables: Optional[torch.Tensor]
+    use_cuda_graph: bool
+    kv_cache_dtype: str
 
-        # Set during the execution of the first attention op.
-        # FIXME(woosuk): This is a hack.
+    def __post_init__(self):
+        # will not appear in the __repr__ and __init__
         self.attn_bias = None
-
-    def __repr__(self) -> str:
-        return ("InputMetadata("
-                f"is_prompt={self.is_prompt}, "
-                f"max_context_len={self.max_context_len}, "
-                f"slot_mapping={self.slot_mapping}, "
-                f"context_lens={self.context_lens}, "
-                f"block_tables={self.block_tables}, "
-                f"use_cuda_graph={self.use_cuda_graph}, "
-                f"kv_cache_dtype={self.kv_cache_dtype})")
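Outside the diff, the mechanics of the change: @dataclass generates __init__, __repr__, and __eq__ from the field declarations, which is what allows the hand-written __init__ and __repr__ above to be deleted, while __post_init__ runs immediately after the generated __init__, so attn_bias is still attached to each instance without becoming a declared field. A minimal standalone sketch of that behavior, using a hypothetical trimmed field set in place of the real class:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Meta:
    # Hypothetical stand-in for InputMetadata: three of its fields,
    # with tensor-typed fields elided. @dataclass generates __init__,
    # __repr__, and __eq__ from these declarations.
    is_prompt: bool
    max_context_len: Optional[int]
    kv_cache_dtype: str

    def __post_init__(self):
        # Runs right after the generated __init__. attn_bias is not a
        # declared field, so it stays out of the generated __init__,
        # __repr__, and dataclasses.asdict().
        self.attn_bias = None


m = Meta(is_prompt=True, max_context_len=256, kv_cache_dtype="auto")
print(m)            # Meta(is_prompt=True, max_context_len=256, kv_cache_dtype='auto')
print(m.attn_bias)  # None -- present on the instance, invisible to __repr__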
vllm/worker/model_runner.py
@@ -1,4 +1,5 @@
 import contextlib
+import dataclasses
 import time
 from typing import Dict, List, Optional, Tuple, Set, Union
 
@@ -521,45 +522,27 @@ class ModelRunner:
             metadata_dict = {
                 "input_tokens": input_tokens,
                 "input_positions": input_positions,
-                "is_prompt": input_metadata.is_prompt,
-                "slot_mapping": input_metadata.slot_mapping,
-                "prompt_lens": input_metadata.prompt_lens,
-                "max_seq_len": input_metadata.max_seq_len,
-                "start_loc": input_metadata.start_loc,
-                "max_context_len": input_metadata.max_context_len,
-                "context_lens": input_metadata.context_lens,
-                "block_tables": input_metadata.block_tables,
-                "use_cuda_graph": input_metadata.use_cuda_graph,
-                "kv_cache_dtype": input_metadata.kv_cache_dtype,
                 "selected_token_indices":
                 sampling_metadata.selected_token_indices,
                 "lora_requests": lora_requests,
                 "lora_mapping": lora_mapping,
             }
+            metadata_dict.update(dataclasses.asdict(input_metadata))
             broadcast_tensor_dict(metadata_dict, src=0)
         else:
             metadata_dict = broadcast_tensor_dict(src=0)
-            input_tokens = metadata_dict["input_tokens"]
-            input_positions = metadata_dict["input_positions"]
-            lora_mapping = metadata_dict["lora_mapping"]
-            lora_requests = metadata_dict["lora_requests"]
-            input_metadata = InputMetadata(
-                is_prompt=metadata_dict["is_prompt"],
-                slot_mapping=metadata_dict["slot_mapping"],
-                prompt_lens=metadata_dict["prompt_lens"],
-                max_seq_len=metadata_dict["max_seq_len"],
-                start_loc=metadata_dict["start_loc"],
-                max_context_len=metadata_dict["max_context_len"],
-                context_lens=metadata_dict["context_lens"],
-                block_tables=metadata_dict["block_tables"],
-                use_cuda_graph=metadata_dict["use_cuda_graph"],
-                kv_cache_dtype=metadata_dict["kv_cache_dtype"],
-            )
+            input_tokens = metadata_dict.pop("input_tokens")
+            input_positions = metadata_dict.pop("input_positions")
+            selected_token_indices = metadata_dict.pop(
+                "selected_token_indices")
+            lora_mapping = metadata_dict.pop("lora_mapping")
+            lora_requests = metadata_dict.pop("lora_requests")
+            input_metadata = InputMetadata(**metadata_dict)
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
                 seq_data=None,
                 prompt_lens=None,
-                selected_token_indices=metadata_dict["selected_token_indices"],
+                selected_token_indices=selected_token_indices,
                 categorized_sample_indices=None,
                 generators=None,
                 perform_sampling=False,
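The payoff of the dataclass shows up on both sides of the broadcast: the driver no longer enumerates every InputMetadata field by hand, and the worker rebuilds the object from whatever keys remain after popping the non-field entries, so adding or removing a field no longer requires editing this function. A self-contained sketch of the round-trip, with a hypothetical trimmed Meta standing in for InputMetadata and a plain dict in place of the actual broadcast_tensor_dict call:

import dataclasses
from dataclasses import dataclass


@dataclass
class Meta:
    # Hypothetical trimmed stand-in for InputMetadata (tensors elided).
    is_prompt: bool
    max_context_len: int
    kv_cache_dtype: str


# Driver side: start with the non-dataclass extras, then flatten the
# dataclass fields into the same dict via asdict().
meta = Meta(is_prompt=False, max_context_len=512, kv_cache_dtype="auto")
metadata_dict = {"input_tokens": [1, 2, 3]}
metadata_dict.update(dataclasses.asdict(meta))
# broadcast_tensor_dict(metadata_dict, src=0) would ship this dict.

# Worker side: pop the extras first, then rebuild the dataclass from
# the remaining keys.
input_tokens = metadata_dict.pop("input_tokens")
rebuilt = Meta(**metadata_dict)
assert rebuilt == meta  # the generated __eq__ compares declared fields

One caveat: dataclasses.asdict recurses into dataclasses, dicts, lists, and tuples and applies copy.deepcopy to everything else, so the broadcast dict holds copies of the field values rather than the original references.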