[core] separate builder init and builder prepare for each batch (#12253)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2025-01-22 14:13:52 +08:00 committed by GitHub
parent 222a9dc350
commit 66818e5b63
10 changed files with 90 additions and 47 deletions

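In outline, the change splits each input builder's lifecycle in two: __init__ keeps only one-time configuration, while a new prepare() hook resets per-batch state, so every model runner constructs its builder once and reuses it for each batch instead of constructing a fresh builder per call. Below is a minimal, self-contained sketch of that lifecycle with toy classes (illustrative names, not the actual vLLM ones):

from typing import List, Optional

class _ToyBuilder:
    def __init__(self, runner) -> None:
        # One-time setup: remember configuration that is stable across batches.
        # (vLLM passes weakref.proxy(runner) here; a plain reference keeps the toy simple.)
        self.runner = runner

    def prepare(self, finished_requests_ids: Optional[List[str]] = None) -> None:
        # Per-batch setup: reset every mutable accumulator.
        self.finished_requests_ids = finished_requests_ids
        self.seq_group_metadata_list: List[str] = []

    def add_seq_group(self, seq_group_metadata: str) -> None:
        self.seq_group_metadata_list.append(seq_group_metadata)

    def build(self) -> List[str]:
        return list(self.seq_group_metadata_list)

class _ToyRunner:
    def __init__(self) -> None:
        self.builder = _ToyBuilder(self)  # constructed once, not per batch

    def prepare_model_input(self, seq_group_metadata_list: List[str]) -> List[str]:
        self.builder.prepare()  # reset state instead of reallocating the builder
        for metadata in seq_group_metadata_list:
            self.builder.add_seq_group(metadata)
        return self.builder.build()

runner = _ToyRunner()
assert runner.prepare_model_input(["g0"]) == ["g0"]
assert runner.prepare_model_input(["g1"]) == ["g1"]  # nothing leaks across batches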
View File

@@ -65,11 +65,6 @@ class AttentionBackend(ABC):
     def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
         raise NotImplementedError

-    @classmethod
-    def make_metadata_builder(cls, *args,
-                              **kwargs) -> "AttentionMetadataBuilder":
-        return cls.get_builder_cls()(*args, **kwargs)
-
     @staticmethod
     @abstractmethod
     def get_kv_cache_shape(
@@ -214,6 +209,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
     @abstractmethod
     def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
+        """Create the builder, remember some configuration and parameters."""
         raise NotImplementedError

+    @abstractmethod
+    def prepare(self) -> None:
+        """Prepare for one batch."""
+        raise NotImplementedError
+
     @abstractmethod

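Per the new docstrings above, subclasses put stable configuration in __init__ and reset per-batch state in prepare(). A hedged sketch of a conforming implementation (hypothetical backend; the real subclasses follow in the hunks below):

from abc import ABC, abstractmethod
from typing import List

class _MetadataBuilderContract(ABC):
    # Mirrors the abstract __init__/prepare pair introduced above.
    @abstractmethod
    def __init__(self, input_builder) -> None:
        raise NotImplementedError

    @abstractmethod
    def prepare(self) -> None:
        raise NotImplementedError

class _ToyBackendBuilder(_MetadataBuilderContract):
    def __init__(self, input_builder) -> None:
        self.input_builder = input_builder   # configuration only

    def prepare(self) -> None:
        self.slot_mapping: List[int] = []    # fresh accumulators per batch
        self.num_decode_tokens = 0

builder = _ToyBackendBuilder(input_builder=None)
builder.prepare()                            # must run before each batch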
View File

@@ -375,6 +375,12 @@ class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):

     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
+    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -388,11 +394,6 @@ class FlashAttentionMetadataBuilder(
         self.num_decode_tokens = 0
         self.has_prefix_cache_hit = False

-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool, prefix_cache_hit: bool):

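The same mechanical move recurs in the FlashInfer, placeholder, and common builders below: configuration rises into __init__, per-batch accumulators sink into prepare(). One consequence, sketched with a hypothetical builder: the per-batch attributes only come into existence when prepare() runs, so touching them before the first call fails loudly rather than silently reusing a previous batch's lists.

from typing import List

class _FlashLikeBuilder:
    def __init__(self, block_size: int) -> None:
        self.block_size = block_size       # configuration, set once

    def prepare(self) -> None:
        self.slot_mapping: List[int] = []  # per-batch state, reset here

builder = _FlashLikeBuilder(block_size=16)
try:
    builder.slot_mapping.append(0)         # touched before prepare()
except AttributeError:
    pass                                   # fails loudly, no stale state
builder.prepare()
builder.slot_mapping.append(0)             # fine once prepare() has run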
View File

@@ -488,6 +488,14 @@ class FlashInferMetadata(AttentionMetadata):

 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
+    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -500,12 +508,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0

-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
         # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
         # for the precise definition of the following fields.
         # An example:

View File

@@ -253,6 +253,11 @@ class PlaceholderAttentionMetadataBuilder(
         AttentionMetadataBuilder[PlaceholderAttentionMetadata]):

     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+    def prepare(self):
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
         self.curr_seq_lens: List[int] = []
@@ -263,9 +268,6 @@ class PlaceholderAttentionMetadataBuilder(
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0

-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool):

View File

@@ -282,7 +282,10 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):

     def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
         self.chunked_prefill = input_builder.chunked_prefill
-        self.input_data = input_builder.input_data
+        self.input_builder = input_builder
+
+    def prepare(self):
+        self.input_data = self.input_builder.input_data

     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:

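Note that TorchSDPAMetadataBuilder.prepare() re-reads input_data instead of caching it at construction time: the CPU input builder allocates a fresh ModelInputData object in its own prepare() (see the cpu_model_runner.py hunk below), so a reference captured in __init__ would keep pointing at the previous batch. A toy illustration with hypothetical classes:

class _InputBuilder:
    def prepare(self) -> None:
        self.input_data = {}  # a fresh object is allocated every batch

class _MetadataBuilder:
    def __init__(self, input_builder: _InputBuilder) -> None:
        self.input_builder = input_builder  # keep the builder, not its data

    def prepare(self) -> None:
        # Re-fetch so we always see the current batch's object.
        self.input_data = self.input_builder.input_data

input_builder = _InputBuilder()
metadata_builder = _MetadataBuilder(input_builder)
for _ in range(2):  # two batches
    input_builder.prepare()
    metadata_builder.prepare()
    assert metadata_builder.input_data is input_builder.input_data  # never stale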
View File

@@ -122,6 +122,13 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
     _metadata_cls: Type[TAttentionMetadata]

     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
+    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -134,12 +141,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0

-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool):

View File

@@ -144,9 +144,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
                  runner: "CPUModelRunner",
                  finished_requests_ids: Optional[List[str]] = None) -> None:
         super().__init__()
-        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.runner = runner
         self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
                                 or runner.cache_config.enable_prefix_caching)
         self.model_input_cls = self.runner._model_input_cls
@@ -156,10 +154,17 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
         self.device = self.runner.device
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
         self.enable_lora = self.runner.lora_config is not None
+        if self.runner.attn_backend is not None:
+            # spec decode (e.g. Medusa) does not have attn backend
+            attn_backend = self.runner.attn_backend
+            self.att_metadata_builder = attn_backend.get_builder_cls()(self)
+
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.input_data = ModelInputForCPUBuilder.ModelInputData(
             self.runner.model_config.uses_mrope)
-        self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
-            self)
+        self.att_metadata_builder.prepare()

     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)
@@ -431,6 +436,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
     """
     _model_input_cls: Type[TModelInputForCPU]
     _builder_cls: Type[ModelInputForCPUBuilder]
+    builder: ModelInputForCPUBuilder

     def __init__(
         self,
@@ -477,6 +483,10 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
         # Set after load_model.
         self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

+        if hasattr(self, "_builder_cls"):
+            # multi-step model runner does not have `_builder_cls`
+            self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         self.model = get_model(vllm_config=self.vllm_config)
@@ -522,10 +532,10 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
         metadata for possible additional steps, e.g., sampling.
         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
-        builder.set_seq_group_list(seq_group_metadata_list)
+        self.builder.prepare(finished_requests_ids)
+        self.builder.set_seq_group_list(seq_group_metadata_list)

-        return builder.build()  # type: ignore
+        return self.builder.build()  # type: ignore

     # sampler property will be used by spec_decode_worker
     @property

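The hasattr(self, "_builder_cls") guard matters because, per the inline comment, the multi-step model runner never declares _builder_cls; such runners simply end up without a builder attribute. A minimal sketch of the guard with toy classes:

import weakref

class _Builder:
    def __init__(self, runner) -> None:
        self.runner = runner

class _RunnerBase:
    def __init__(self) -> None:
        if hasattr(self, "_builder_cls"):
            # Subclasses without _builder_cls (e.g. multi-step) skip this.
            self.builder = self._builder_cls(weakref.proxy(self))

class _NormalRunner(_RunnerBase):
    _builder_cls = _Builder

class _MultiStepRunner(_RunnerBase):
    pass  # no _builder_cls declared

assert hasattr(_NormalRunner(), "builder")
assert not hasattr(_MultiStepRunner(), "builder")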
View File

@@ -457,17 +457,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         self.enable_prompt_adapter = (self.runner.prompt_adapter_config
                                       is not None)
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
-        self.finished_requests_ids = finished_requests_ids
         self.decode_only = True

-        # Intermediate data (data in CPU before going to GPU) for
-        # the current sequence group.
-        self.inter_data_list: List[
-            ModelInputForGPUBuilder.InterDataForSeqGroup] = []
-
         # Attention metadata inputs.
-        self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
-            weakref.proxy(self))
+        if self.attn_backend is not None:
+            # spec decode (e.g. Medusa) does not have attn backend
+            self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
+                weakref.proxy(self))

         # Engine/Model configurations.
         self.chunked_prefill_enabled = (
@@ -479,6 +475,17 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         self.block_aligned_sliding_window = \
             self.sliding_window_blocks * self.block_size

+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        self.finished_requests_ids = finished_requests_ids
+
+        # Intermediate data (data in CPU before going to GPU) for
+        # the current sequence group.
+        self.inter_data_list: List[
+            ModelInputForGPUBuilder.InterDataForSeqGroup] = []
+
+        self.attn_metadata_builder.prepare()
+
     def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
                       seq_group_metadata: SequenceGroupMetadata):
         """Compute context length, sequence length and tokens
@@ -993,6 +1000,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
     """
     _model_input_cls: Type[TModelInputForGPU]
     _builder_cls: Type[ModelInputForGPUBuilder]
+    builder: ModelInputForGPUBuilder

     def __init__(
         self,
@@ -1093,6 +1101,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             SamplingMetadataCache() \
                 if self.parallel_config.pipeline_parallel_size == 1 else None

+        if hasattr(self, "_builder_cls"):
+            # multi-step model runner does not have `_builder_cls`
+            self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:
@@ -1226,13 +1238,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         If cuda graph is required, this API automatically pads inputs.
         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
+        self.builder.prepare(finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
-            builder.add_seq_group(seq_group_metadata)
+            self.builder.add_seq_group(seq_group_metadata)

-        builder.reset_cached_inter_data()
+        self.builder.reset_cached_inter_data()

-        return builder.build()  # type: ignore
+        return self.builder.build()  # type: ignore

     @contextmanager
     def set_in_profile_run(self):

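The builder receives weakref.proxy(self) rather than the runner itself. The likely rationale (an inference, not stated in the commit): the runner now owns its builder for the runner's entire lifetime while the builder keeps a back-reference to the runner, and the proxy prevents that pair from forming a strong reference cycle while still behaving like the runner for attribute access:

import weakref

class _Builder:
    def __init__(self, runner) -> None:
        self.runner = runner  # a weak proxy, not a strong reference

class _Runner:
    def __init__(self) -> None:
        self.block_size = 16
        self.builder = _Builder(weakref.proxy(self))

runner = _Runner()
assert runner.builder.runner.block_size == 16  # the proxy forwards attribute access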
View File

@@ -200,6 +200,11 @@ class ModelRunnerInputBuilderBase(ABC, Generic[T]):
     """A builder to create ModelRunnerInputBase objects.
     """

+    @abstractmethod
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def add_seq_group(self, seq_group_metadata):
         """TBA"""

View File

@@ -113,7 +113,6 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
                  runner: "XPUModelRunner",
                  finished_requests_ids: Optional[List[str]] = None) -> None:
         super().__init__()
-        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.runner = runner
         self.model_input_cls = self.runner._model_input_cls
         self.attn_backend = self.runner.attn_backend
@@ -121,6 +120,10 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
         self.block_size = self.runner.block_size
         self.device = self.runner.device

+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
+
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)
@@ -408,6 +411,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
             SamplingMetadataCache() \
                 if self.parallel_config.pipeline_parallel_size == 1 else None

+        self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         with DeviceMemoryProfiler() as m:
             self.model = get_model(vllm_config=self.vllm_config)
@@ -517,7 +522,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
         metadata for possible additional steps, e.g., sampling.
         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
+        builder = self.builder
+        builder.prepare(finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
             builder.add_seq_group(seq_group_metadata)