[core] separate builder init and builder prepare for each batch (#12253)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent 222a9dc350, commit 66818e5b63
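In short: attention metadata builders and model-input builders used to be constructed from scratch for every batch. After this change each model runner constructs its builder once, and a new `prepare()` method resets only the per-batch state. A minimal sketch of the resulting lifecycle, using hypothetical `Builder`/`Runner` names rather than the real classes in the diffs below:

# Sketch only: `Builder` and `Runner` are illustrative stand-ins for the
# builder and model-runner classes changed in this commit.
from typing import List, Optional


class Builder:

    def __init__(self, runner) -> None:
        # One-time setup: remember configuration for the builder's lifetime.
        self.runner = runner

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        # Per-batch setup: reset the accumulators that build() consumes.
        self.finished_requests_ids = finished_requests_ids
        self.seq_groups: List[object] = []

    def add_seq_group(self, seq_group) -> None:
        self.seq_groups.append(seq_group)

    def build(self) -> List[object]:
        return self.seq_groups


class Runner:

    def __init__(self) -> None:
        self.builder = Builder(self)  # constructed once, not per batch

    def prepare_model_input(self, seq_group_metadata_list,
                            finished_requests_ids=None):
        self.builder.prepare(finished_requests_ids)  # reset, don't rebuild
        for seq_group_metadata in seq_group_metadata_list:
            self.builder.add_seq_group(seq_group_metadata)
        return self.builder.build()

The per-file diffs below apply this split to the abstract interfaces, to each backend's metadata builder, and to the CPU, GPU, and XPU model runners.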
@@ -65,11 +65,6 @@ class AttentionBackend(ABC):
     def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
         raise NotImplementedError
 
-    @classmethod
-    def make_metadata_builder(cls, *args,
-                              **kwargs) -> "AttentionMetadataBuilder":
-        return cls.get_builder_cls()(*args, **kwargs)
-
     @staticmethod
     @abstractmethod
     def get_kv_cache_shape(
@@ -214,6 +209,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
 
     @abstractmethod
     def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
+        """Create the builder, remember some configuration and parameters."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def prepare(self) -> None:
+        """Prepare for one batch."""
         raise NotImplementedError
 
     @abstractmethod
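Each concrete builder below splits its old `__init__` along exactly this line: configuration captured from the input builder stays in `__init__`, while per-batch accumulators move into the new `prepare()`. A minimal sketch of a subclass satisfying the contract, with a hypothetical `DummyBuilder` and illustrative fields:

# Sketch of the new two-phase contract; Base mirrors the abstract methods
# added above, and DummyBuilder is a hypothetical concrete subclass.
from abc import ABC, abstractmethod
from typing import List


class Base(ABC):

    @abstractmethod
    def __init__(self, input_builder) -> None:
        raise NotImplementedError

    @abstractmethod
    def prepare(self) -> None:
        raise NotImplementedError


class DummyBuilder(Base):

    def __init__(self, input_builder) -> None:
        # Lives as long as the runner: configuration only.
        self.input_builder = input_builder
        self.block_size = input_builder.block_size

    def prepare(self) -> None:
        # Reset before every batch: accumulators filled while adding
        # sequence groups, consumed by build().
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.num_prefills = 0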
@@ -375,6 +375,12 @@ class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
+    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -388,11 +394,6 @@ class FlashAttentionMetadataBuilder(
         self.num_decode_tokens = 0
         self.has_prefix_cache_hit = False
 
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool, prefix_cache_hit: bool):
@@ -488,6 +488,14 @@ class FlashInferMetadata(AttentionMetadata):
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
+    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -500,12 +508,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
 
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
     # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
     # for the precise definition of the following fields.
     # An example:
@@ -253,6 +253,11 @@ class PlaceholderAttentionMetadataBuilder(
         AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+    def prepare(self):
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
         self.curr_seq_lens: List[int] = []
@@ -263,9 +268,6 @@ class PlaceholderAttentionMetadataBuilder(
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
 
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool):
@@ -282,7 +282,10 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
 
     def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
         self.chunked_prefill = input_builder.chunked_prefill
-        self.input_data = input_builder.input_data
+        self.input_builder = input_builder
+
+    def prepare(self):
+        self.input_data = self.input_builder.input_data
 
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
@@ -122,6 +122,13 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
     _metadata_cls: Type[TAttentionMetadata]
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
+    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -134,12 +141,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
 
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool):
@@ -144,9 +144,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
                  runner: "CPUModelRunner",
                  finished_requests_ids: Optional[List[str]] = None) -> None:
         super().__init__()
-        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.runner = runner
-
         self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
                                 or runner.cache_config.enable_prefix_caching)
         self.model_input_cls = self.runner._model_input_cls
@@ -156,10 +154,17 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
         self.device = self.runner.device
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
         self.enable_lora = self.runner.lora_config is not None
+        if self.runner.attn_backend is not None:
+            # spec decode (e.g. Medusa) does not have atten backend
+            attn_backend = self.runner.attn_backend
+            self.att_metadata_builder = attn_backend.get_builder_cls()(self)
+
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.input_data = ModelInputForCPUBuilder.ModelInputData(
             self.runner.model_config.uses_mrope)
-        self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
-            self)
+        self.att_metadata_builder.prepare()
 
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)
@@ -431,6 +436,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
     """
     _model_input_cls: Type[TModelInputForCPU]
     _builder_cls: Type[ModelInputForCPUBuilder]
+    builder: ModelInputForCPUBuilder
 
     def __init__(
         self,
@@ -477,6 +483,10 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
         # Set after load_model.
         self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
 
+        if hasattr(self, "_builder_cls"):
+            # multi-step model runner does not have `_builder_cls`
+            self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         self.model = get_model(vllm_config=self.vllm_config)
 
@@ -522,10 +532,10 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
         metadata for possible additional steps, e.g., sampling.
 
         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
-        builder.set_seq_group_list(seq_group_metadata_list)
+        self.builder.prepare(finished_requests_ids)
+        self.builder.set_seq_group_list(seq_group_metadata_list)
 
-        return builder.build()  # type: ignore
+        return self.builder.build()  # type: ignore
 
     # sampler property will be used by spec_decode_worker
     @property
@@ -457,17 +457,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         self.enable_prompt_adapter = (self.runner.prompt_adapter_config
                                       is not None)
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
-        self.finished_requests_ids = finished_requests_ids
         self.decode_only = True
 
-        # Intermediate data (data in CPU before going to GPU) for
-        # the current sequence group.
-        self.inter_data_list: List[
-            ModelInputForGPUBuilder.InterDataForSeqGroup] = []
-
         # Attention metadata inputs.
-        self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
-            weakref.proxy(self))
+        if self.attn_backend is not None:
+            # spec decode (e.g. Medusa) does not have atten backend
+            self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
+                weakref.proxy(self))
 
         # Engine/Model configurations.
         self.chunked_prefill_enabled = (
@@ -479,6 +475,17 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         self.block_aligned_sliding_window = \
             self.sliding_window_blocks * self.block_size
 
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        self.finished_requests_ids = finished_requests_ids
+
+        # Intermediate data (data in CPU before going to GPU) for
+        # the current sequence group.
+        self.inter_data_list: List[
+            ModelInputForGPUBuilder.InterDataForSeqGroup] = []
+
+        self.attn_metadata_builder.prepare()
+
     def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
                       seq_group_metadata: SequenceGroupMetadata):
         """Compute context length, sequence length and tokens
@@ -993,6 +1000,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
     """
    _model_input_cls: Type[TModelInputForGPU]
    _builder_cls: Type[ModelInputForGPUBuilder]
+    builder: ModelInputForGPUBuilder
 
     def __init__(
         self,
@@ -1093,6 +1101,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             SamplingMetadataCache() \
                 if self.parallel_config.pipeline_parallel_size == 1 else None
 
+        if hasattr(self, "_builder_cls"):
+            # multi-step model runner does not have `_builder_cls`
+            self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:
@@ -1226,13 +1238,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
 
         If cuda graph is required, this API automatically pads inputs.
         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
+        self.builder.prepare(finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
-            builder.add_seq_group(seq_group_metadata)
+            self.builder.add_seq_group(seq_group_metadata)
 
-        builder.reset_cached_inter_data()
+        self.builder.reset_cached_inter_data()
 
-        return builder.build()  # type: ignore
+        return self.builder.build()  # type: ignore
 
     @contextmanager
     def set_in_profile_run(self):
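Both runner bases guard the builder construction with `hasattr(self, "_builder_cls")` because, per the comment in the diff, the multi-step model runner inherits from them without defining `_builder_cls`. A self-contained illustration of that guard, with hypothetical class names:

# Why the hasattr() guard: _builder_cls is an optional class attribute
# that only some subclasses define. Names here are illustrative.
class BaseRunner:

    def __init__(self) -> None:
        if hasattr(self, "_builder_cls"):
            # Only runners that declare a builder class get a builder.
            self.builder = self._builder_cls()


class NormalRunner(BaseRunner):
    _builder_cls = list  # stand-in for a real builder class


class MultiStepRunner(BaseRunner):
    pass  # no _builder_cls, so no builder is constructed


assert hasattr(NormalRunner(), "builder")
assert not hasattr(MultiStepRunner(), "builder")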
@@ -200,6 +200,11 @@ class ModelRunnerInputBuilderBase(ABC, Generic[T]):
     """A builder to create ModelRunnerInputBase objects.
     """
 
+    @abstractmethod
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def add_seq_group(self, seq_group_metadata):
         """TBA"""
@@ -113,7 +113,6 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
                  runner: "XPUModelRunner",
                  finished_requests_ids: Optional[List[str]] = None) -> None:
         super().__init__()
-        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.runner = runner
         self.model_input_cls = self.runner._model_input_cls
         self.attn_backend = self.runner.attn_backend
@@ -121,6 +120,10 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
         self.block_size = self.runner.block_size
         self.device = self.runner.device
 
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
+
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)
 
@@ -408,6 +411,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
             SamplingMetadataCache() \
                 if self.parallel_config.pipeline_parallel_size == 1 else None
 
+        self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         with DeviceMemoryProfiler() as m:
             self.model = get_model(vllm_config=self.vllm_config)
@@ -517,7 +522,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
         metadata for possible additional steps, e.g., sampling.
 
         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
+        builder = self.builder
+        builder.prepare(finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
             builder.add_seq_group(seq_group_metadata)
 