###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################

import collections
import contextlib
import dataclasses
import functools
import gc
import itertools
import math
import operator
import os
import time
from array import array
from dataclasses import dataclass, field
from enum import IntEnum
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
                    Optional, Set, Tuple, Type, TypeVar, Union)

import habana_frameworks.torch as htorch
import habana_frameworks.torch.internal.bridge_config as bc
import torch
from vllm_hpu_extension.ops import LoraMask as LoraMask
from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
                                         HabanaMemoryProfiler, format_bytes)

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import DeviceConfig, VllmConfig
from vllm.distributed.parallel_state import get_world_group
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalInputs)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SequenceData,
                           SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)

_TYPE_CACHE = {}
# These values are assumed to be zero in several places.
# Use caution when updating them!
_PAD_SLOT_ID = 0
_PAD_BLOCK_ID = 0

LORA_WARMUP_RANK = 8


class Singleton(type):
    _instances: Dict[type, object] = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


@dataclass
class HPUBucketingGlobalState(metaclass=Singleton):
    prompt_bs_bucket_cfg: Tuple[int, int, int] = field(init=False)
    decode_bs_bucket_cfg: Tuple[int, int, int] = field(init=False)
    prompt_seq_bucket_cfg: Tuple[int, int, int] = field(init=False)
    decode_block_bucket_cfg: Tuple[int, int, int] = field(init=False)
    prompt_buckets: List[Tuple[int, int]] = field(init=False)
    decode_buckets: List[Tuple[int, int]] = field(init=False)


def subtuple(obj: object,
             typename: str,
             to_copy: List[str],
             to_override: Optional[Dict[str, object]] = None):
    if obj is None:
        return None
    if to_override is None:
        to_override = {}
    fields = set(to_copy) | set(to_override.keys())
    values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
    if typename not in _TYPE_CACHE:
        _TYPE_CACHE[typename] = collections.namedtuple(typename,
                                                       ' '.join(fields))
    return _TYPE_CACHE[typename](**values)

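# A minimal illustration of subtuple (made-up inputs, not real attention
# metadata): it copies the requested attributes of an arbitrary object into a
# cached namedtuple type, optionally overriding some of them.
#
#   Point = collections.namedtuple('Point', 'x y z')
#   p = Point(x=1, y=2, z=3)
#   subtuple(p, 'TrimmedPoint', ['x', 'y'], to_override={'y': 20})
#   -> TrimmedPoint(x=1, y=20)   (field order may vary; `fields` is a set)

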
def read_bucket_settings(phase: str, dim: str, **defaults):
    """Read bucketing configuration from env variables.

    phase is either 'prompt' or 'decode'
    dim is either 'bs', 'seq' or 'block'
    param is either 'min', 'step' or 'max'
    example env variable: VLLM_DECODE_BS_BUCKET_STEP=128
    """
    params = ['min', 'step', 'max']
    env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params]
    default_values = [defaults[p] for p in params]
    values = [
        int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values)
    ]
    for e, v, d in zip(env_vars, values, default_values):
        logger.info('%s=%s (default:%s)', e, v, d)
    return values

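# Illustrative example of the env-var lookup in read_bucket_settings above
# (the numbers are assumptions passed by the caller, not shipped defaults):
# with no VLLM_* bucket variables exported,
#   read_bucket_settings('decode', 'bs', min=1, step=32, max=256)
# returns [1, 32, 256]; exporting VLLM_DECODE_BS_BUCKET_STEP=128 changes the
# result to [1, 128, 256].

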
def warmup_range(config: Tuple[int, int, int]):
    """Generate a warmup range.

    Start from bmin and multiply by 2 until you reach bstep.
    Then, increase the values in the range by the value of bstep until you
    reach bmax.

    Example:
    bmin = 2, bstep = 32, bmax = 64
    => ramp_up = (2, 4, 8, 16)
    => stable = (32, 64)
    => return ramp_up + stable => (2, 4, 8, 16, 32, 64)
    """
    bmin, bstep, bmax = config
    assert bmin <= bmax, ("Min. batch size cannot be greater than max. "
                          "batch size. If you want to skip warmup, "
                          "set VLLM_SKIP_WARMUP=true")
    base = itertools.repeat(2)
    ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin)
    ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax,
                                     ramp_up_acc)
    stable = range(bstep, bmax + 1, bstep)
    buckets = list(ramp_up_tw) + list(stable)
    return list(filter(lambda bucket: bucket >= bmin, buckets))


def generate_prompt_buckets(bs_bucket_config,
|
|
seq_bucket_config,
|
|
max_num_batched_tokens=None):
|
|
buckets = list(
|
|
itertools.product(warmup_range(bs_bucket_config),
|
|
warmup_range(seq_bucket_config)))
|
|
if len(buckets) == 0:
|
|
msg = ("No buckets could be captured with following config "
|
|
f"(min, step, max_warmup): "
|
|
f"bs:{bs_bucket_config}, "
|
|
f"seq:{seq_bucket_config}")
|
|
raise ValueError(msg)
|
|
|
|
filtered_buckets = buckets
|
|
if max_num_batched_tokens is not None:
|
|
# Remove buckets exceeding batch token budget
|
|
filtered_buckets = list(
|
|
filter(
|
|
lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens,
|
|
buckets))
|
|
|
|
if len(filtered_buckets) == 0:
|
|
# we can handle this if we ignore max_num_batched_tokens
|
|
min_bucket_bs, min_bucket_seq = min(buckets,
|
|
key=lambda b: (b[0] * b[1]))
|
|
min_reqd_budget = min_bucket_bs * min_bucket_seq
|
|
msg = (
|
|
"The current bucketing configuration "
|
|
f"(min, step, max_warmup): "
|
|
f"bs:{bs_bucket_config}, "
|
|
f"seq:{seq_bucket_config} cannot be used with specified "
|
|
f"max_num_batched_tokens ({max_num_batched_tokens}), as the "
|
|
f"smallest bucket ({min_reqd_budget}) would exceed token "
|
|
"budget. Please increase max_num_batched_tokens or decrease "
|
|
"bucket minimum Ignoring max_num_batched_tokens at risk of "
|
|
"out-of-memory errors.")
|
|
logger.error(msg)
|
|
return list(
|
|
sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))), []
|
|
|
|
captured_buckets = list(
|
|
sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
|
|
omitted_buckets = list(
|
|
sorted([x for x in buckets if x not in filtered_buckets]))
|
|
return captured_buckets, omitted_buckets
|
|
|
|
|
|
def generate_decode_buckets(bs_bucket_config, blocks_bucket_config,
|
|
max_blocks):
|
|
buckets = []
|
|
bs_buckets = warmup_range(bs_bucket_config)
|
|
block_buckets = warmup_range(blocks_bucket_config)
|
|
bmin, bstep, bmax = blocks_bucket_config
|
|
last_bucket = round_up(max_blocks, bstep)
|
|
for bs in bs_buckets:
|
|
for blocks in block_buckets:
|
|
if blocks < bs:
|
|
continue
|
|
if blocks > last_bucket:
|
|
break
|
|
buckets.append((bs, blocks))
|
|
return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
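# Worked example for generate_decode_buckets above (hypothetical configs):
#   bs_bucket_config=(1, 2, 4), blocks_bucket_config=(128, 128, 256),
#   max_blocks=300
# gives bs buckets [1, 2, 4], block buckets [128, 256] and
# last_bucket = round_up(300, 128) = 384, so every (bs, blocks) pair is kept:
#   [(1, 128), (2, 128), (1, 256), (4, 128), (2, 256), (4, 256)]
# (sorted by bs * blocks, then blocks, then bs).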
|
|
|
|
|
|
def next_pow2(value: int, base: int):
    res = base
    while value > 1:
        value = (value + 1) // 2
        res *= 2
    return res


def round_up(value: int, k: int):
    return (value + k - 1) // k * k


def find_bucket(value: int, config: Tuple[int, int, int]):
    bmin, bstep, _ = config
    next_step = round_up(value, bstep)
    next_pow = next_pow2(value, bmin)
    return max(bmin, min(next_step, next_pow))

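# Worked examples for find_bucket above (illustrative config, not a shipped
# default): with config = (min=1, step=32, max=128),
#   find_bucket(5, (1, 32, 128))  -> 8    (next_pow2 wins over round_up to 32)
#   find_bucket(47, (1, 32, 128)) -> 64   (round_up(47, 32) == next_pow2(47, 1))

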
def align_workers(value, op):
|
|
group = get_world_group().cpu_group
|
|
world_size = torch.distributed.get_world_size()
|
|
if world_size <= 1:
|
|
return value
|
|
value_t = torch.tensor(value, device='cpu')
|
|
torch.distributed.all_reduce(value_t, op=op, group=group)
|
|
return value_t.item()
|
|
|
|
|
|
def setup_profiler():
|
|
schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1)
|
|
DEVICE = 'hpu'
|
|
activities = [torch.profiler.ProfilerActivity.CPU]
|
|
activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE ==
|
|
'hpu' else [])
|
|
#from habana_frameworks.torch.activity_profiler import DebugActivity
|
|
#debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS]
|
|
|
|
profiler = torch.profiler.profile(
|
|
schedule=schedule,
|
|
activities=activities,
|
|
#debug_activities=debug_activities,
|
|
on_trace_ready=torch.profiler.tensorboard_trace_handler('.',
|
|
use_gzip=True),
|
|
record_shapes=False,
|
|
with_stack=True)
|
|
return profiler
|
|
|
|
|
|
def pad_list(input_list, k, v):
    target_len = round_up(len(input_list), k)
    padding = target_len - len(input_list)
    return input_list + [v] * padding


def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):
    slot_mapping = slot_mapping.flatten()
    indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    if is_prompt:
        indices = indices.unflatten(0, (-1, block_size))[:, 0]
        offsets = None
    else:
        offsets = torch.fmod(slot_mapping, block_size)
    return indices, offsets

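# Sketch of the decode-path behaviour of precompute_indices_and_offsets above
# (assumed block_size=4 and a made-up slot mapping):
#   slot_mapping = torch.tensor([[5], [6], [13]])   # one slot per sequence
#   indices -> tensor([1, 1, 3])    (slot // block_size)
#   offsets -> tensor([1, 2, 1])    (slot %  block_size)
# For prompts (is_prompt=True) only the first index of each block_size-long
# chunk is kept and offsets is None.

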
class HpuModelAdapter:
|
|
|
|
def __init__(self, model, block_size, dtype, enforce_eager):
|
|
self.model = model
|
|
self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
|
|
'0').lower() in ['1', 'true']
|
|
self.block_size = block_size
|
|
self.dtype = dtype
|
|
if not htorch.utils.internal.is_lazy() and not enforce_eager:
|
|
self.model = torch.compile(self.model,
|
|
backend='hpu_backend',
|
|
dynamic=False)
|
|
|
|
def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
|
|
dtype):
|
|
prefill_metadata = attn_metadata
|
|
if prefill_metadata is None or self.prefill_use_fusedsdpa:
|
|
return attn_metadata
|
|
|
|
seq_lens_t = prefill_metadata.seq_lens_tensor
|
|
len_mask = (torch.arange(0, seq_len, device=device,
|
|
dtype=torch.int32).view(1, seq_len).ge(
|
|
seq_lens_t.unsqueeze(-1)).view(
|
|
batch_size, 1, 1, seq_len))
|
|
causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len),
|
|
device=device,
|
|
dtype=torch.bool),
|
|
diagonal=1)
|
|
mask = causal_mask.logical_or(len_mask)
|
|
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
|
|
mask, -math.inf))
|
|
attn_metadata = prefill_metadata._replace(attn_bias=attn_bias)
|
|
return attn_metadata
|
|
|
|
def _set_block_mapping(self, metadata, batch_size, device, dtype):
|
|
mask = torch.arange(0,
|
|
self.block_size,
|
|
device=device,
|
|
dtype=torch.int32).unsqueeze(0)
|
|
mask = mask >= metadata.block_usage.unsqueeze(-1)
|
|
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
|
|
mask, -math.inf))
|
|
block_mapping = torch.nn.functional.one_hot(metadata.block_mapping,
|
|
num_classes=batch_size)
|
|
block_mapping = block_mapping.to(dtype)
|
|
metadata = metadata._replace(block_mapping=block_mapping,
|
|
attn_bias=attn_bias)
|
|
return metadata
|
|
|
|
def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
|
|
dtype):
|
|
if attn_metadata.is_prompt:
|
|
meta = attn_metadata
|
|
attn_metadata = self._set_attn_bias(meta, batch_size, seq_len,
|
|
device, dtype)
|
|
else:
|
|
meta = attn_metadata
|
|
attn_metadata = self._set_block_mapping(meta, batch_size, device,
|
|
dtype)
|
|
return attn_metadata
|
|
|
|
def forward(self, *args, **kwargs):
|
|
kwargs = kwargs.copy()
|
|
selected_token_indices = kwargs.pop('selected_token_indices')
|
|
if 'warmup_mode' in kwargs:
|
|
kwargs.pop('warmup_mode')
|
|
input_ids = kwargs['input_ids']
|
|
kwargs['attn_metadata'] = self._update_metadata(
|
|
kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
|
|
input_ids.device, self.dtype)
|
|
LoraMask.setLoraMask(kwargs.pop('lora_mask'))
|
|
hidden_states = self.model(*args, **kwargs)
|
|
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
|
hidden_states = hidden_states.index_select(0, selected_token_indices)
|
|
return hidden_states
|
|
|
|
def compute_logits(self, *args, **kwargs):
|
|
return self.model.compute_logits(*args, **kwargs)
|
|
|
|
def sample(self, *args, **kwargs):
|
|
return self.model.sample(*args, **kwargs)
|
|
|
|
|
|
class PreparePromptMetadata(NamedTuple):
|
|
input_tokens: torch.Tensor
|
|
input_positions: List[List[int]]
|
|
attn_metadata: Optional[AttentionMetadata]
|
|
seq_lens: List[int]
|
|
query_lens: List[int]
|
|
lora_index_mapping: List[List[int]]
|
|
lora_prompt_mapping: List[List[int]]
|
|
lora_requests: Set[LoRARequest]
|
|
multi_modal_kwargs: Optional[Dict[str, BatchedTensorInputs]]
|
|
slot_mapping: List[List[int]]
|
|
lora_ids: List[int]
|
|
|
|
@classmethod
|
|
def empty(cls):
|
|
return PreparePromptMetadata(input_tokens=[],
|
|
input_positions=[],
|
|
attn_metadata=None,
|
|
seq_lens=[],
|
|
query_lens=[],
|
|
lora_index_mapping=[],
|
|
lora_prompt_mapping=[],
|
|
lora_requests=set(),
|
|
multi_modal_kwargs=None,
|
|
slot_mapping=[],
|
|
lora_ids=[])
|
|
|
|
|
|
class PrepareDecodeMetadata(NamedTuple):
|
|
input_tokens: torch.Tensor
|
|
input_positions: List[List[int]]
|
|
attn_metadata: Optional[AttentionMetadata]
|
|
lora_index_mapping: List[List[int]]
|
|
lora_prompt_mapping: List[List[int]]
|
|
lora_requests: Set[LoRARequest]
|
|
slot_mapping: List[List[int]]
|
|
lora_ids: List[int]
|
|
|
|
@classmethod
|
|
def empty(cls):
|
|
return PrepareDecodeMetadata(input_tokens=[],
|
|
input_positions=[],
|
|
attn_metadata=None,
|
|
lora_index_mapping=[],
|
|
lora_prompt_mapping=[],
|
|
lora_requests=set(),
|
|
slot_mapping=[],
|
|
lora_ids=[])
|
|
|
|
|
|
# How batches are constructed.
|
|
class BatchType(IntEnum):
|
|
# Every batch is prefill.
|
|
PREFILL = 0
|
|
# Every batch is decode.
|
|
DECODE = 1
|
|
# Batch is a mixture of prefill and decode.
|
|
MIXED = 2
|
|
|
|
|
|
TModelInputForHPU = TypeVar('TModelInputForHPU', bound="ModelInputForHPU")
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ModelInputForHPU(ModelRunnerInputBase):
|
|
"""
|
|
This base class contains metadata needed for the base model forward pass
|
|
but not metadata for possible additional steps, e.g., sampling. Model
|
|
    runners that run additional steps should subclass this class to add
    additional fields.
|
|
"""
|
|
input_tokens: Optional[torch.Tensor] = None
|
|
input_positions: Optional[torch.Tensor] = None
|
|
seq_lens: Optional[List[int]] = None
|
|
query_lens: Optional[List[int]] = None
|
|
lora_mapping: Optional["LoRAMapping"] = None
|
|
lora_requests: Optional[Set[LoRARequest]] = None
|
|
attn_metadata: Optional["AttentionMetadata"] = None
|
|
multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
|
|
real_batch_size: Optional[int] = None
|
|
batch_size_padded: Optional[int] = None
|
|
virtual_engine: int = 0
|
|
lora_ids: Optional[List[int]] = None
|
|
async_callback: Optional[Callable] = None
|
|
|
|
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
|
tensor_dict = {
|
|
"input_tokens": self.input_tokens,
|
|
"input_positions": self.input_positions,
|
|
"lora_requests": self.lora_requests,
|
|
"lora_mapping": self.lora_mapping,
|
|
"multi_modal_kwargs": self.multi_modal_kwargs,
|
|
"real_batch_size": self.real_batch_size,
|
|
"batch_size_padded": self.batch_size_padded,
|
|
"virtual_engine": self.virtual_engine,
|
|
"lora_ids": self.lora_ids,
|
|
}
|
|
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
|
return tensor_dict
|
|
|
|
@classmethod
|
|
def from_broadcasted_tensor_dict(
|
|
cls: Type[TModelInputForHPU],
|
|
tensor_dict: Dict[str, Any],
|
|
attn_backend: Optional["AttentionBackend"] = None,
|
|
) -> TModelInputForHPU:
|
|
if attn_backend is not None:
|
|
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
|
attn_backend, tensor_dict)
|
|
return cls(**tensor_dict)
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ModelInputForHPUWithSamplingMetadata(ModelInputForHPU):
|
|
"""
|
|
Used by the ModelRunner.
|
|
"""
|
|
sampling_metadata: Optional["SamplingMetadata"] = None
|
|
# Used for speculative decoding. We do not broadcast it because it is only
|
|
# used by the driver worker.
|
|
is_prompt: Optional[bool] = None
|
|
|
|
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
|
tensor_dict = {
|
|
"input_tokens": self.input_tokens,
|
|
"input_positions": self.input_positions,
|
|
"lora_requests": self.lora_requests,
|
|
"lora_mapping": self.lora_mapping,
|
|
"multi_modal_kwargs": self.multi_modal_kwargs,
|
|
"lora_ids": self.lora_ids,
|
|
}
|
|
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
|
_add_sampling_metadata_broadcastable_dict(tensor_dict,
|
|
self.sampling_metadata)
|
|
return tensor_dict
|
|
|
|
@classmethod
|
|
def from_broadcasted_tensor_dict(
|
|
cls,
|
|
tensor_dict: Dict[str, Any],
|
|
attn_backend: Optional["AttentionBackend"] = None,
|
|
) -> "ModelInputForHPUWithSamplingMetadata":
|
|
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
|
|
# FIXME(kzawora): this fails for whatever reason - why?
|
|
if attn_backend is not None:
|
|
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
|
attn_backend, tensor_dict)
|
|
return cls(**tensor_dict)
|
|
|
|
|
|
class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
|
|
"""
|
|
    Helper class for shared methods between HPU model runners.
|
|
"""
|
|
_model_input_cls: Type[TModelInputForHPU]
|
|
|
|
def __init__(
|
|
self,
|
|
vllm_config: VllmConfig,
|
|
is_driver_worker: bool = False,
|
|
return_hidden_states: bool = False,
|
|
):
|
|
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
|
|
self.is_driver_worker = is_driver_worker
|
|
self.return_hidden_states = return_hidden_states
|
|
|
|
self.sliding_window = (self.model_config.get_sliding_window()
|
|
if self.model_config is not None else None)
|
|
self.device_config = (self.device_config if self.device_config
|
|
is not None else DeviceConfig())
|
|
self.device = self.device_config.device
|
|
self.enforce_eager = self.model_config.enforce_eager
|
|
self.max_num_seqs = self.scheduler_config.max_num_seqs
|
|
# NOTE(kzawora): Change that to scheduler_config.max_num_prefill_seqs
|
|
# once padding-aware scheduling gets merged
|
|
self.max_num_prefill_seqs = 64
|
|
self.max_model_len = self.scheduler_config.max_model_len
|
|
self.max_num_batched_tokens = \
|
|
self.scheduler_config.max_num_batched_tokens
|
|
self.block_size = self.cache_config.block_size
|
|
|
|
self.pin_memory = is_pin_memory_available()
|
|
self.kv_cache_dtype = self.cache_config.cache_dtype
|
|
|
|
self.attn_backend = get_attn_backend(
|
|
self.model_config.get_head_size(),
|
|
self.model_config.dtype,
|
|
self.kv_cache_dtype,
|
|
self.block_size,
|
|
self.model_config.is_attention_free,
|
|
)
|
|
|
|
# Lazy initialization
|
|
self.lora_manager: LRUCacheWorkerLoRAManager = None
|
|
self.model: torch.nn.Module = None
|
|
self.inc_initialized_successfully = False
|
|
|
|
# Profiler stats
|
|
self.profiler = HabanaHighLevelProfiler()
|
|
self.profiler_counter_helper = HabanaProfilerCounterHelper()
|
|
self.seen_configs: set = set()
|
|
self._mem_margin: Optional[int] = None
|
|
self.bucketing_global_state = HPUBucketingGlobalState()
|
|
self._setup_buckets()
|
|
self._set_gc_threshold()
|
|
|
|
def _set_gc_threshold(self) -> None:
|
|
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold
|
|
# for comprehensive description of gc generations.
|
|
# We can either use VLLM_GC_THR_GEN[0-2] (this has higher priority)
|
|
# to set particular generation threshold or use simpler
|
|
# VLLM_GC_THR_MULTIPLIER to multiply default values.
|
|
default_gc_thrs = list(gc.get_threshold())
|
|
requested_gc_thrs = [0] * len(default_gc_thrs)
|
|
for i in range(len(default_gc_thrs)):
|
|
requested_gc_thrs[i] = int(
|
|
os.environ.get(f'VLLM_GC_THR_GEN{i}', default_gc_thrs[i]))
|
|
if requested_gc_thrs == default_gc_thrs:
|
|
gc_thr_multiplier = int(os.environ.get('VLLM_GC_THR_MULTIPLIER',
|
|
2))
|
|
requested_gc_thrs = [
|
|
t * gc_thr_multiplier for t in default_gc_thrs
|
|
]
|
|
gc.set_threshold(*requested_gc_thrs)
|
|
|
|
# Multi-modal data support
|
|
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
|
|
.create_input_mapper(self.model_config)
|
|
|
|
self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP',
|
|
'false').lower() == 'true'
|
|
|
|
def load_model(self) -> None:
|
|
import habana_frameworks.torch.core as htcore
|
|
if self.model_config.quantization == 'inc' or \
|
|
self.model_config.quantization == 'fp8':
|
|
htcore.hpu_set_env()
|
|
with HabanaMemoryProfiler() as m:
|
|
with HabanaMemoryProfiler() as m_getmodel:
|
|
self.model = get_model(vllm_config=self.vllm_config)
|
|
msg = ("Pre-loading model weights on "
|
|
f"{next(self.model.parameters()).device} "
|
|
f"took {m_getmodel.get_summary_string()}")
|
|
logger.info(msg)
|
|
|
|
if self.lora_config:
|
|
assert hasattr(self.model, "supported_lora_modules"
|
|
) and self.model.supported_lora_modules, (
|
|
"Model does not support LoRA")
|
|
assert hasattr(self.model, "embedding_modules"
|
|
), "Model does not have embedding_modules"
|
|
assert hasattr(
|
|
self.model, "embedding_padding_modules"
|
|
), "Model does not have embedding_padding_modules"
|
|
self.lora_manager = LRUCacheWorkerLoRAManager(
|
|
self.scheduler_config.max_num_seqs,
|
|
self.scheduler_config.max_num_batched_tokens,
|
|
self.vocab_size, self.lora_config, self.device,
|
|
self.model.embedding_modules,
|
|
self.model.embedding_padding_modules)
|
|
self.model = self.lora_manager.create_lora_manager(self.model)
|
|
|
|
if self.model_config.quantization == 'inc':
|
|
logger.info("Preparing model with INC..")
|
|
with HabanaMemoryProfiler() as m_inc:
|
|
from neural_compressor.torch.quantization import (
|
|
FP8Config, convert, prepare)
|
|
config = FP8Config.from_json_file(
|
|
os.getenv("QUANT_CONFIG", ""))
|
|
if config.measure:
|
|
self.model = prepare(self.model, config)
|
|
elif config.quantize:
|
|
self.model = convert(self.model, config)
|
|
htcore.hpu_initialize(self.model,
|
|
mark_only_scales_as_const=True)
|
|
self.inc_initialized_successfully = True
|
|
logger.info("Preparing model with INC took %s",
|
|
m_inc.get_summary_string())
|
|
else:
|
|
self.model = self.model.to("hpu")
|
|
htcore.mark_step()
|
|
torch.hpu.synchronize()
|
|
|
|
with HabanaMemoryProfiler() as m_wrap:
|
|
self.model = _maybe_wrap_in_hpu_graph(
|
|
self.model,
|
|
self.block_size,
|
|
dtype=self.model_config.dtype,
|
|
enforce_eager=self.enforce_eager)
|
|
msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
|
|
logger.info(msg)
|
|
|
|
self.model_memory_usage = m.consumed_device_memory
|
|
msg = f"Loading model weights took in total {m.get_summary_string()}"
|
|
logger.info(msg)
|
|
|
|
def _use_graphs(self, batch_size, seq_len, is_prompt):
|
|
if self.enforce_eager:
|
|
return False
|
|
if self.skip_warmup:
|
|
return True
|
|
return (batch_size, seq_len, is_prompt) in self.graphed_buckets
|
|
|
|
def _is_valid_bucket(self, bucket):
|
|
return bucket[0] * bucket[1] <= self.max_num_batched_tokens
|
|
|
|
def _setup_buckets(self) -> None:
|
|
align_bs = lambda x: min(self.max_num_seqs, x)
|
|
#FIXME: The default values should be max_model_len
|
|
max_prompt_seq = 1024
|
|
max_decode_seq = 2048
|
|
self.bucketing_global_state.prompt_bs_bucket_cfg = read_bucket_settings(
|
|
'prompt',
|
|
'bs',
|
|
min=1,
|
|
step=align_bs(32),
|
|
max=self.max_num_prefill_seqs)
|
|
self.bucketing_global_state.decode_bs_bucket_cfg = read_bucket_settings(
|
|
'decode', 'bs', min=1, step=align_bs(32), max=self.max_num_seqs)
|
|
self.bucketing_global_state.prompt_seq_bucket_cfg = \
|
|
read_bucket_settings(
|
|
'prompt',
|
|
'seq',
|
|
min=self.block_size,
|
|
step=self.block_size,
|
|
max=max_prompt_seq)
|
|
self.bucketing_global_state.decode_block_bucket_cfg = \
|
|
read_bucket_settings(
|
|
'decode',
|
|
'block',
|
|
min=self.block_size,
|
|
step=self.block_size,
|
|
max=max(self.block_size,
|
|
self.max_num_seqs * max_decode_seq // self.block_size))
|
|
self.graphed_buckets: Set[Any] = set()
|
|
|
|
msg = ("Prompt bucket config (min, step, max_warmup) "
|
|
f"bs:{self.bucketing_global_state.prompt_bs_bucket_cfg}, "
|
|
f"seq:{self.bucketing_global_state.prompt_seq_bucket_cfg}")
|
|
logger.info(msg)
|
|
|
|
msg = ("Decode bucket config (min, step, max_warmup) "
|
|
f"bs:{self.bucketing_global_state.decode_bs_bucket_cfg}, "
|
|
f"block:{self.bucketing_global_state.decode_block_bucket_cfg}")
|
|
logger.info(msg)
|
|
|
|
def _prepare_prompt(
|
|
self,
|
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
|
) -> PreparePromptMetadata:
|
|
input_tokens: List[List[int]] = []
|
|
input_positions: List[List[int]] = []
|
|
slot_mapping: List[List[int]] = []
|
|
lora_index_mapping: List[List[int]] = []
|
|
lora_prompt_mapping: List[List[int]] = []
|
|
lora_requests: Set[LoRARequest] = set()
|
|
|
|
seq_lens: List[int] = []
|
|
context_lens: List[int] = []
|
|
query_lens: List[int] = []
|
|
prefix_block_tables: List[List[int]] = []
|
|
multi_modal_inputs_list: List[MultiModalInputs] = []
|
|
|
|
if len(seq_group_metadata_list) == 0:
|
|
return PreparePromptMetadata.empty()
|
|
|
|
for seq_group_metadata in seq_group_metadata_list:
|
|
assert seq_group_metadata.is_prompt
|
|
seq_ids = list(seq_group_metadata.seq_data.keys())
|
|
assert len(seq_ids) == 1
|
|
seq_id = seq_ids[0]
|
|
|
|
computed_block_nums = seq_group_metadata.computed_block_nums
|
|
if (self.scheduler_config is not None
|
|
and self.scheduler_config.chunked_prefill_enabled
|
|
and not (computed_block_nums is None
|
|
or computed_block_nums == [])):
|
|
raise RuntimeError(
|
|
"chunked prefill cannot be used with prefix caching "
|
|
"now.")
|
|
|
|
token_chunk_size = seq_group_metadata.token_chunk_size
|
|
seq_data = seq_group_metadata.seq_data[seq_id]
|
|
context_len = seq_data.get_num_computed_tokens()
|
|
# We should use get_len here because in case of preemption
|
|
# it contains output tokens.
|
|
seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
|
|
prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
|
|
seq_lens.append(seq_len)
|
|
|
|
# NOTE: This only works for oooooooxxx style attention.
|
|
if computed_block_nums is not None and len(
|
|
computed_block_nums) > 0 and self.sliding_window is None:
|
|
# Prefix is not supported with sliding_window
|
|
context_len = len(computed_block_nums) * self.block_size
|
|
prompt_tokens = prompt_tokens[context_len:]
|
|
prefix_block_tables.append(computed_block_nums)
|
|
elif self.scheduler_config.chunked_prefill_enabled:
|
|
if seq_group_metadata.block_tables is not None:
|
|
# Prefill has chunked before.
|
|
block_table = seq_group_metadata.block_tables[seq_id]
|
|
prefix_block_tables.append(block_table)
|
|
else:
|
|
# The first prefill.
|
|
prefix_block_tables.append([])
|
|
else:
|
|
prefix_block_tables.append([])
|
|
# Right now, prefill start is always 0. However, this
|
|
# assumption can be changed once chunked prefill is introduced.
|
|
assert context_len == 0
|
|
|
|
# actual prompt lens
|
|
context_lens.append(context_len)
|
|
query_lens.append(seq_len - context_len)
|
|
input_tokens.append(prompt_tokens)
|
|
# NOTE(woosuk): Here we assume that the first token in the prompt
|
|
# is always the first token in the sequence.
|
|
input_positions.append(list(range(context_len, seq_len)))
|
|
|
|
mm_data = seq_group_metadata.multi_modal_data
|
|
if mm_data:
|
|
mm_kwargs = self.multi_modal_input_mapper(mm_data)
|
|
multi_modal_inputs_list.append(mm_kwargs)
|
|
|
|
if seq_group_metadata.block_tables is None:
|
|
# During memory profiling, the block tables are not initialized
|
|
# yet. In this case, we just use a dummy slot mapping.
|
|
slot_mapping.append([_PAD_SLOT_ID] * seq_len)
|
|
continue
|
|
|
|
# Compute the slot mapping.
|
|
slot_mapping.append([])
|
|
block_table = seq_group_metadata.block_tables[seq_id]
|
|
|
|
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be
            # [_PAD_SLOT_ID, _PAD_SLOT_ID, 2, 3, 4, 5, 6, 7, 0, 1].
|
|
start_idx = 0
|
|
if self.sliding_window is not None:
|
|
assert context_len == 0, (
|
|
"Prefix caching is currently not supported with "
|
|
"sliding window attention")
|
|
start_idx = max(0, seq_len - self.sliding_window)
|
|
for i in range(context_len, seq_len):
|
|
if i < start_idx:
|
|
slot_mapping[-1].append(_PAD_SLOT_ID)
|
|
continue
|
|
|
|
block_number = block_table[i // self.block_size]
|
|
block_offset = i % self.block_size
|
|
slot = block_number * self.block_size + block_offset
|
|
slot_mapping[-1].append(slot)
|
|
|
|
max_query_len = max(query_lens)
|
|
sum_query_len = sum(query_lens)
|
|
real_num_seqs = len(query_lens)
|
|
assert max_query_len > 0
|
|
|
|
max_prompt_len = max(
|
|
find_bucket(max(seq_lens),
|
|
self.bucketing_global_state.prompt_seq_bucket_cfg),
|
|
self.block_size)
|
|
|
|
lora_ids: List[int] = []
|
|
for seq_group_metadata, context_len in zip(seq_group_metadata_list,
|
|
context_lens):
|
|
lora_id = seq_group_metadata.lora_int_id
|
|
lora_ids.append(lora_id)
|
|
|
|
if lora_id > 0:
|
|
lora_requests.add(seq_group_metadata.lora_request)
|
|
|
|
lora_index_mapping += [lora_id] * (max_prompt_len - context_len)
|
|
lora_prompt_mapping.extend(
|
|
[lora_id] *
|
|
(max_prompt_len - context_len
|
|
if seq_group_metadata.sampling_params.prompt_logprobs else 1))
|
|
|
|
input_tokens = make_tensor_with_pad(input_tokens,
|
|
max_len=max_prompt_len,
|
|
pad=0,
|
|
dtype=torch.long,
|
|
device=self.device)
|
|
|
|
input_positions = make_tensor_with_pad(input_positions,
|
|
max_len=max_prompt_len,
|
|
pad=0,
|
|
dtype=torch.long,
|
|
device=self.device)
|
|
|
|
slot_mapping = make_tensor_with_pad(slot_mapping,
|
|
max_len=max_prompt_len,
|
|
pad=_PAD_SLOT_ID,
|
|
dtype=torch.long,
|
|
device=self.device)
|
|
|
|
seq_lens_tensor = torch.tensor(seq_lens,
|
|
dtype=torch.long,
|
|
device=self.device)
|
|
|
|
block_indices, block_offsets = precompute_indices_and_offsets(
|
|
self.block_size, slot_mapping, True)
|
|
attn_metadata = self.attn_backend.make_metadata(
|
|
is_prompt=True,
|
|
block_list=None,
|
|
block_mapping=None,
|
|
block_usage=None,
|
|
block_indices=block_indices,
|
|
block_offsets=block_offsets,
|
|
block_scales=None,
|
|
attn_bias=None,
|
|
seq_lens_tensor=seq_lens_tensor,
|
|
num_prefills=real_num_seqs,
|
|
num_prefill_tokens=sum_query_len,
|
|
num_decode_tokens=0,
|
|
slot_mapping=slot_mapping,
|
|
multi_modal_placeholder_index_maps=
|
|
            None  # FIXME(kzawora): multi-modality will not work here
|
|
)
|
|
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
|
|
|
|
return PreparePromptMetadata(input_tokens=input_tokens,
|
|
input_positions=input_positions,
|
|
attn_metadata=attn_metadata,
|
|
seq_lens=seq_lens,
|
|
query_lens=query_lens,
|
|
lora_index_mapping=lora_index_mapping,
|
|
lora_prompt_mapping=lora_prompt_mapping,
|
|
lora_requests=lora_requests,
|
|
multi_modal_kwargs=multi_modal_kwargs,
|
|
slot_mapping=slot_mapping,
|
|
lora_ids=lora_ids)
|
|
|
|
def _prepare_decode(
|
|
self,
|
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
|
) -> PrepareDecodeMetadata:
|
|
input_tokens: List[List[int]] = []
|
|
input_positions: List[List[int]] = []
|
|
slot_mapping: List[List[int]] = []
|
|
seq_lens: List[int] = []
|
|
block_tables: List[List[int]] = []
|
|
lora_index_mapping: List[List[int]] = []
|
|
lora_prompt_mapping: List[List[int]] = []
|
|
lora_requests: Set[LoRARequest] = set()
|
|
|
|
if len(seq_group_metadata_list) == 0:
|
|
return PrepareDecodeMetadata.empty()
|
|
lora_ids: List[int] = []
|
|
|
|
dummy_slots = itertools.cycle(
|
|
range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size))
|
|
|
|
for seq_group_metadata in seq_group_metadata_list:
|
|
assert not seq_group_metadata.is_prompt
|
|
assert seq_group_metadata.token_chunk_size == 1
|
|
|
|
seq_ids = list(seq_group_metadata.seq_data.keys())
|
|
lora_id = seq_group_metadata.lora_int_id
|
|
lora_ids.append(lora_id)
|
|
|
|
if lora_id > 0:
|
|
lora_requests.add(seq_group_metadata.lora_request)
|
|
|
|
for seq_id in seq_ids:
|
|
seq_data = seq_group_metadata.seq_data[seq_id]
|
|
generation_token = seq_data.get_last_token_id()
|
|
input_tokens.append([generation_token])
|
|
|
|
seq_len = seq_data.get_len()
|
|
position = seq_len - 1
|
|
input_positions.append([position])
|
|
|
|
seq_len = seq_len if self.sliding_window is None else min(
|
|
seq_len, self.sliding_window)
|
|
seq_lens.append(seq_len)
|
|
|
|
block_table = seq_group_metadata.block_tables[seq_id]
|
|
if len(block_table) == 0:
|
|
block_number = _PAD_BLOCK_ID
|
|
else:
|
|
block_number = block_table[position // self.block_size]
|
|
if block_number == _PAD_BLOCK_ID:
|
|
slot = next(dummy_slots)
|
|
else:
|
|
block_offset = position % self.block_size
|
|
slot = block_number * self.block_size + block_offset
|
|
slot_mapping.append([slot])
|
|
lora_index_mapping.append(lora_id)
|
|
lora_prompt_mapping.append(lora_id)
|
|
|
|
if self.sliding_window is not None:
|
|
sliding_window_blocks = (self.sliding_window //
|
|
self.block_size)
|
|
block_table = block_table[-sliding_window_blocks:]
|
|
block_tables.append(block_table)
|
|
|
|
input_tokens = torch.tensor(input_tokens,
|
|
dtype=torch.long,
|
|
device=self.device)
|
|
input_positions = torch.tensor(input_positions,
|
|
dtype=torch.long,
|
|
device=self.device)
|
|
|
|
num_decode_tokens = sum(seq_lens)
|
|
|
|
blocks_used = [len(bt) for bt in block_tables if bt]
|
|
block_list = []
|
|
block_scales = []
|
|
for i, bt in enumerate(block_tables):
|
|
block_list.extend(bt)
|
|
blocks_in_group = len(bt)
|
|
if blocks_in_group > 0:
|
|
scale = 1.0 / blocks_in_group
|
|
block_scales.extend([scale] * blocks_in_group)
|
|
|
|
block_mapping_nested: List[List[int]] = [
|
|
[i] * b_u for i, b_u in enumerate(blocks_used)
|
|
]
|
|
block_mapping: List[int] = list(
|
|
itertools.chain.from_iterable(block_mapping_nested))
|
|
|
|
last_block = [
|
|
sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping)
|
|
]
|
|
block_usage = [[self.block_size] * (b_u - 1) + [lb]
|
|
for b_u, lb in zip(blocks_used, last_block)]
|
|
block_usage = list(itertools.chain(*block_usage))
|
|
|
|
block_bucket_size = find_bucket(
|
|
len(block_list),
|
|
self.bucketing_global_state.decode_block_bucket_cfg)
|
|
block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID)
|
|
block_mapping = pad_list(block_mapping, block_bucket_size, -1)
|
|
block_usage = pad_list(block_usage, block_bucket_size, 1)
|
|
block_scales = pad_list(block_scales, block_bucket_size, 0.0)
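        # Illustration (made-up sizes): if the batch currently uses 5 blocks
        # and the selected bucket is 128, block_list grows to length 128 with
        # _PAD_BLOCK_ID entries, block_mapping with -1, block_usage with 1 and
        # block_scales with 0.0, so every decode bucket sees fixed-shape
        # inputs.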
|
|
|
|
block_list = torch.tensor(block_list,
|
|
dtype=torch.int,
|
|
device=self.device)
|
|
block_mapping = torch.tensor(block_mapping,
|
|
dtype=torch.long,
|
|
device=self.device)
|
|
block_usage = torch.tensor(block_usage,
|
|
dtype=self.model_config.dtype,
|
|
device=self.device)
|
|
|
|
slot_mapping = torch.tensor(slot_mapping,
|
|
dtype=torch.long,
|
|
device=self.device)
|
|
|
|
block_indices, block_offsets = precompute_indices_and_offsets(
|
|
self.block_size, slot_mapping, False)
|
|
block_scales = torch.tensor(block_scales,
|
|
dtype=self.model_config.dtype,
|
|
device=self.device)
|
|
|
|
attn_metadata = self.attn_backend.make_metadata(
|
|
is_prompt=False,
|
|
block_list=block_list,
|
|
block_mapping=block_mapping,
|
|
block_usage=block_usage,
|
|
block_indices=block_indices,
|
|
block_offsets=block_offsets,
|
|
block_scales=block_scales,
|
|
attn_bias=None,
|
|
seq_lens_tensor=None,
|
|
num_prefills=0,
|
|
num_prefill_tokens=0,
|
|
num_decode_tokens=num_decode_tokens,
|
|
slot_mapping=slot_mapping,
|
|
multi_modal_placeholder_index_maps=None)
|
|
return PrepareDecodeMetadata(input_tokens=input_tokens,
|
|
input_positions=input_positions,
|
|
attn_metadata=attn_metadata,
|
|
lora_index_mapping=lora_index_mapping,
|
|
lora_prompt_mapping=lora_prompt_mapping,
|
|
lora_requests=lora_requests,
|
|
slot_mapping=slot_mapping,
|
|
lora_ids=lora_ids)
|
|
|
|
def prepare_input_tensors(
|
|
self,
|
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
|
) -> Tuple[TModelInputForHPU, SamplingMetadata]:
|
|
if len(seq_group_metadata_list) == 0:
|
|
return self._model_input_cls(), None
|
|
|
|
input_tokens = None
|
|
input_positions = None
|
|
lora_mapping = None
|
|
lora_requests = None
|
|
multi_modal_kwargs = None
|
|
batch_type = None
|
|
seq_lens = None
|
|
query_lens = None
|
|
real_batch_size = None
|
|
batch_size_padded = None
|
|
|
|
self.event_start = self.profiler.get_timestamp_us()
|
|
is_prompt = seq_group_metadata_list[0].is_prompt
|
|
base_event_name = 'prompt' if is_prompt else 'decode'
|
|
self.profiler.start('internal', base_event_name)
|
|
|
|
real_batch_size = len(seq_group_metadata_list)
|
|
bucket_cfg = self.bucketing_global_state.prompt_bs_bucket_cfg \
|
|
if is_prompt else self.bucketing_global_state.decode_bs_bucket_cfg
|
|
batch_size_padded = find_bucket(real_batch_size, bucket_cfg)
|
|
batch_size_padding = batch_size_padded - real_batch_size
|
|
seq_group_metadata_list = seq_group_metadata_list.copy()
|
|
if batch_size_padding > 0:
|
|
dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
|
|
0, 0, is_prompt)
|
|
seq_group_metadata_list.extend(dummy_seq_group_metadata
|
|
for _ in range(batch_size_padding))
|
|
|
|
prefill_reqs = []
|
|
decode_reqs = []
|
|
for seq_group_meta in seq_group_metadata_list:
|
|
if seq_group_meta.is_prompt:
|
|
prefill_reqs.append(seq_group_meta)
|
|
else:
|
|
decode_reqs.append(seq_group_meta)
|
|
|
|
# Prepare input tensors.
|
|
(
|
|
input_tokens,
|
|
input_positions,
|
|
prefill_attn_metadata,
|
|
seq_lens,
|
|
query_lens,
|
|
lora_index_mapping,
|
|
lora_prompt_mapping,
|
|
lora_requests,
|
|
multi_modal_kwargs,
|
|
slot_mapping,
|
|
lora_ids,
|
|
) = self._prepare_prompt(prefill_reqs)
|
|
(
|
|
decode_input_tokens,
|
|
decode_input_positions,
|
|
decode_attn_metadata,
|
|
decode_lora_index_mapping,
|
|
decode_lora_prompt_mapping,
|
|
decode_lora_requests,
|
|
decode_slot_mapping,
|
|
decode_lora_ids,
|
|
) = self._prepare_decode(decode_reqs)
|
|
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
|
|
seq_lens, query_lens,
|
|
self.device,
|
|
self.pin_memory)
|
|
|
|
if not self.scheduler_config.chunked_prefill_enabled:
|
|
            assert not (len(prefill_reqs) and len(decode_reqs))
|
|
|
|
num_prefills = len(seq_lens)
|
|
num_prefill_tokens = len(input_tokens)
|
|
num_decode_tokens = len(decode_input_tokens)
|
|
|
|
# NOTE(kzawora): Here we diverge from GPU code - we don't
|
|
# support mixed batches, so we either use decode or prefill
|
|
# inputs, without coalescing.
|
|
assert (num_prefills == 0 and num_decode_tokens > 0) or (
|
|
num_prefills > 0
|
|
and num_decode_tokens == 0), "HPU does not support mixed batches!"
|
|
if num_decode_tokens > 0:
|
|
input_tokens = decode_input_tokens
|
|
input_positions = decode_input_positions
|
|
slot_mapping = decode_slot_mapping
|
|
lora_index_mapping = decode_lora_index_mapping
|
|
lora_prompt_mapping = decode_lora_prompt_mapping
|
|
lora_requests = decode_lora_requests
|
|
lora_ids = decode_lora_ids
|
|
|
|
# FIXME: We need to adjust selected_token_indices to accommodate
|
|
# for padding
|
|
max_len = input_tokens.size(1)
|
|
paddings = [max_len - s for s in seq_lens]
|
|
paddings = [0] + paddings[:-1]
|
|
paddings = list(itertools.accumulate(paddings))
|
|
paddings_prompt_logprobs = []
|
|
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
|
if seq_group_metadata.sampling_params.prompt_logprobs is not None \
|
|
and seq_group_metadata.is_prompt:
|
|
paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i])
|
|
paddings = torch.tensor(
|
|
paddings_prompt_logprobs if paddings_prompt_logprobs else paddings,
|
|
dtype=sampling_metadata.selected_token_indices.dtype,
|
|
device=sampling_metadata.selected_token_indices.device)
|
|
sampling_metadata.selected_token_indices.add_(paddings)
|
|
|
|
if self.lora_config:
|
|
lora_mapping = LoRAMapping(
|
|
**dict(index_mapping=lora_index_mapping,
|
|
prompt_mapping=lora_prompt_mapping,
|
|
is_prefill=(num_prefills > 0)))
|
|
else:
|
|
lora_mapping = None
|
|
|
|
if (prefill_attn_metadata is not None
|
|
and decode_attn_metadata is not None):
|
|
batch_type = BatchType.MIXED
|
|
raise NotImplementedError("Mixed batch is not supported on HPU")
|
|
elif prefill_attn_metadata is not None:
|
|
batch_type = BatchType.PREFILL
|
|
else:
|
|
batch_type = BatchType.DECODE
|
|
|
|
metadata_dict = {
|
|
"input_tokens": input_tokens,
|
|
"input_positions": input_positions,
|
|
"selected_token_indices": sampling_metadata.selected_token_indices,
|
|
"lora_requests": lora_requests,
|
|
"lora_mapping": lora_mapping,
|
|
"multi_modal_kwargs": multi_modal_kwargs,
|
|
"num_prefill_tokens": num_prefill_tokens,
|
|
"num_decode_tokens": num_decode_tokens,
|
|
"slot_mapping": slot_mapping,
|
|
"num_prefills": num_prefills,
|
|
"batch_type": batch_type,
|
|
"seq_lens": seq_lens,
|
|
"query_lens": query_lens
|
|
}
|
|
if prefill_attn_metadata is not None:
|
|
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
|
|
else:
|
|
assert decode_attn_metadata is not None
|
|
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
|
|
|
|
attn_metadata = prefill_attn_metadata if \
|
|
prefill_attn_metadata is not None else decode_attn_metadata
|
|
|
|
return self._model_input_cls(input_tokens=input_tokens,
|
|
seq_lens=seq_lens,
|
|
query_lens=query_lens,
|
|
input_positions=input_positions,
|
|
attn_metadata=attn_metadata,
|
|
lora_requests=lora_requests,
|
|
lora_mapping=lora_mapping,
|
|
multi_modal_kwargs=multi_modal_kwargs,
|
|
real_batch_size=real_batch_size,
|
|
batch_size_padded=batch_size_padded,
|
|
lora_ids=lora_ids), \
|
|
sampling_metadata
|
|
|
|
def _seq_len(self, attn_metadata):
|
|
if attn_metadata.num_prefills != 0:
|
|
return attn_metadata.slot_mapping.size(1)
|
|
else:
|
|
return attn_metadata.block_list.numel()
|
|
|
|
def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
|
|
# NOTE(kzawora): To anyone working on this in the future:
|
|
# Trimming metadata is required when using HPUGraphs.
|
|
# Attention metadata is going to be hashed by PT bridge, and
|
|
# appropriate HPUGraphs will be matched based on all inputs' hash.
|
|
|
|
# Before you put more keys in here, make sure you know their
|
|
# value type and make sure you know how it's going to be hashed.
|
|
# You can find that information in input_hash function
|
|
# in habana_frameworks/torch/hpu/graphs.py. You can also hash
|
|
# it manually with torch.hpu.graphs.input_hash(attention_metadata)
|
|
|
|
# If you use primitive types here - they will get hashed based
|
|
# on their value. You *will* get lots of excessive graph captures
|
|
# (and an OOM eventually) if you decide to put something like
|
|
# seq_len int here.
|
|
# If you absolutely need a scalar, put it in a tensor. Tensors
|
|
# get hashed using their metadata, not their values:
|
|
# input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
|
|
# input_hash(123) != input_hash(321)
|
|
# input_hash("abc") != input_hash("cba")
|
|
attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
|
|
'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
|
|
'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
|
|
'block_offsets', 'block_scales'
|
|
])
|
|
return attention_metadata
|
|
|
|
def create_dummy_seq_group_metadata(self,
|
|
group_id,
|
|
seq_len,
|
|
is_prompt,
|
|
lora_request=None):
|
|
sampling_params = SamplingParams(temperature=0)
|
|
num_blocks = math.ceil(seq_len / self.block_size)
|
|
seq_len = max(seq_len, 1)
|
|
if is_prompt:
|
|
input_len = seq_len
|
|
output_len = 0
|
|
block_tables = None
|
|
else:
|
|
input_len = seq_len - 1
|
|
output_len = 1
|
|
block_tables = {group_id: [_PAD_BLOCK_ID] * num_blocks}
|
|
prompt_token_ids = [0] * input_len
|
|
output_token_ids = [1] * output_len
|
|
prompt_token_ids_array = array('l', prompt_token_ids) # noqa: F821
|
|
seq_data = SequenceData(prompt_token_ids_array)
|
|
seq_data.output_token_ids = output_token_ids
|
|
return SequenceGroupMetadata(request_id=str(group_id),
|
|
is_prompt=(output_len == 0),
|
|
seq_data={group_id: seq_data},
|
|
sampling_params=sampling_params,
|
|
block_tables=block_tables,
|
|
lora_request=lora_request)
|
|
|
|
def profile_run(self) -> None:
|
|
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
|
kv_caches = [None] * num_layers
|
|
max_batch_size = self.bucketing_global_state.prompt_bs_bucket_cfg[-1]
|
|
max_seq_len = min(
|
|
self.bucketing_global_state.prompt_seq_bucket_cfg[-1],
|
|
self.max_num_batched_tokens // max_batch_size)
|
|
|
|
self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
|
|
False, True)
|
|
return
|
|
|
|
def warmup_scenario(self,
|
|
batch_size,
|
|
seq_len,
|
|
is_prompt,
|
|
kv_caches,
|
|
is_pt_profiler_run=False,
|
|
is_lora_profile_run=False) -> None:
|
|
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
|
|
scenario_name = ("warmup_"
|
|
f"{'prompt' if is_prompt else 'decode'}_"
|
|
f"bs{batch_size}_"
|
|
f"seq{seq_len}_"
|
|
f"graphs{'T' if use_graphs else 'F'}")
|
|
max_num_seqs = self.scheduler_config.max_num_seqs
|
|
        # This represents the maximum number of different requests
        # that will have unique loras, and therefore the max amount of
        # memory consumption. Create dummy lora request copies from the
        # lora request passed in, which contains a lora from the lora
        # warmup path.
|
|
dummy_lora_requests: List[LoRARequest] = []
|
|
dummy_lora_requests_per_seq: List[LoRARequest] = []
|
|
if self.lora_config and is_lora_profile_run:
|
|
assert self.lora_manager is not None
|
|
with self.lora_manager.dummy_lora_cache():
|
|
for idx in range(self.lora_config.max_loras):
|
|
lora_id = idx + 1
|
|
dummy_lora_request = LoRARequest(
|
|
lora_name=f"warmup_{lora_id}",
|
|
lora_int_id=lora_id,
|
|
lora_local_path="/not/a/real/path",
|
|
)
|
|
self.lora_manager.add_dummy_lora(dummy_lora_request,
|
|
rank=LORA_WARMUP_RANK)
|
|
dummy_lora_requests.append(dummy_lora_request)
|
|
dummy_lora_requests_per_seq = [
|
|
dummy_lora_requests[idx % len(dummy_lora_requests)]
|
|
for idx in range(max_num_seqs)
|
|
]
|
|
self.profiler.start('internal', scenario_name)
|
|
times = 3 if use_graphs or is_pt_profiler_run else 1
|
|
if self.lora_config and not is_lora_profile_run:
|
|
lora_mapping = LoRAMapping(
|
|
**dict(index_mapping=[0] * batch_size * seq_len,
|
|
prompt_mapping=[0] * batch_size * seq_len,
|
|
is_prefill=is_prompt))
|
|
self.set_active_loras(set(), lora_mapping)
|
|
if is_prompt:
|
|
seqs = [
|
|
self.create_dummy_seq_group_metadata(
|
|
i,
|
|
seq_len,
|
|
is_prompt,
|
|
lora_request=dummy_lora_requests_per_seq[i]
|
|
if dummy_lora_requests_per_seq else None)
|
|
for i in range(batch_size)
|
|
]
|
|
else:
|
|
# FIXME: seq_len is actually number of blocks
|
|
blocks = [seq_len // batch_size for _ in range(batch_size)]
|
|
blocks[0] += seq_len % batch_size
|
|
seqs = [
|
|
self.create_dummy_seq_group_metadata(
|
|
i,
|
|
b * self.block_size - 1,
|
|
is_prompt,
|
|
lora_request=dummy_lora_requests_per_seq[i]
|
|
if dummy_lora_requests_per_seq else None)
|
|
for i, b in enumerate(blocks)
|
|
]
|
|
torch.hpu.synchronize()
|
|
profiler = None
|
|
if is_pt_profiler_run and self.is_driver_worker:
|
|
profiler = setup_profiler()
|
|
profiler.start()
|
|
for _ in range(times):
|
|
inputs = self.prepare_model_input(seqs)
|
|
self.execute_model(inputs, kv_caches, warmup_mode=True)
|
|
torch.hpu.synchronize()
|
|
if profiler:
|
|
profiler.step()
|
|
if profiler:
|
|
profiler.stop()
|
|
self.profiler.end()
|
|
gc.collect()
|
|
|
|
def remove_all_loras(self):
|
|
if not self.lora_manager:
|
|
raise RuntimeError("LoRA is not enabled.")
|
|
self.lora_manager.remove_all_adapters()
|
|
|
|
def set_active_loras(self, lora_requests: Set[LoRARequest],
|
|
lora_mapping: LoRAMapping) -> None:
|
|
if not self.lora_manager:
|
|
raise RuntimeError("LoRA is not enabled.")
|
|
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
|
|
|
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
|
if not self.lora_manager:
|
|
raise RuntimeError("LoRA is not enabled.")
|
|
return self.lora_manager.add_adapter(lora_request)
|
|
|
|
def remove_lora(self, lora_id: int) -> bool:
|
|
if not self.lora_manager:
|
|
raise RuntimeError("LoRA is not enabled.")
|
|
return self.lora_manager.remove_adapter(lora_id)
|
|
|
|
def pin_lora(self, lora_id: int) -> bool:
|
|
if not self.lora_manager:
|
|
raise RuntimeError("LoRA is not enabled.")
|
|
return self.lora_manager.pin_adapter(lora_id)
|
|
|
|
def list_loras(self) -> Set[int]:
|
|
if not self.lora_manager:
|
|
raise RuntimeError("LoRA is not enabled.")
|
|
return self.lora_manager.list_adapters()
|
|
|
|
def log_warmup(self, phase, i, max_i, batch_size, seq_len):
|
|
free_mem = format_bytes(
|
|
HabanaMemoryProfiler.current_free_device_memory())
|
|
dim = "num_blocks"
|
|
if phase == "Prompt":
|
|
dim = "seq_len"
|
|
msg = (f"[Warmup][{phase}][{i+1}/{max_i}] "
|
|
f"batch_size:{batch_size} "
|
|
f"{dim}:{seq_len} "
|
|
f"free_mem:{free_mem}")
|
|
logger.info(msg)
|
|
|
|
def warmup_all_buckets(self, buckets, is_prompt, kv_caches):
|
|
for i, (batch_size, seq_len) in enumerate(reversed(buckets)):
|
|
self.log_warmup('Prompt' if is_prompt else 'Decode', i,
|
|
len(buckets), batch_size, seq_len)
|
|
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
|
|
|
|
def warmup_graphs(self,
|
|
strategy,
|
|
buckets,
|
|
is_prompt,
|
|
kv_caches,
|
|
available_mem,
|
|
starting_mem=0,
|
|
total_batch_seq=0.001):
|
|
total_mem = starting_mem
|
|
idx = 0
|
|
phase = f'Graph/{"Prompt" if is_prompt else "Decode"}'
|
|
num_candidates = len(buckets)
|
|
ordering : Union[Callable[[Any], Tuple[Any, Any]], \
|
|
Callable[[Any], Tuple[Any, Any, Any]]]
|
|
if strategy == 'min_tokens':
|
|
ordering = lambda b: (b[0] * b[1], b[1], b[0])
|
|
elif strategy == 'max_bs':
|
|
ordering = lambda b: (-b[0], b[1])
|
|
else:
|
|
raise NotImplementedError(
|
|
f'Unsupported graph allocation strategy: {strategy}')
|
|
buckets = list(sorted(buckets, key=ordering))
|
|
captured_all = True
|
|
for idx, (batch_size, seq_len) in enumerate(buckets):
|
|
# Graph memory usage is proportional to seq dimension in a batch
|
|
batch_seq = batch_size * seq_len if is_prompt else batch_size
|
|
mem_estimate = batch_seq / total_batch_seq * total_mem
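            # Rough illustration with assumed numbers: if buckets captured so
            # far sum to total_batch_seq=4096 and consumed total_mem=2 GiB, a
            # new prompt bucket with batch_seq=1024 is estimated at
            # 1024/4096 * 2 GiB = 0.5 GiB and is skipped if that exceeds
            # available_mem.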
|
|
if mem_estimate >= available_mem:
|
|
captured_all = False
|
|
continue
|
|
graphed_bucket = (batch_size, seq_len, is_prompt)
|
|
if graphed_bucket in self.graphed_buckets:
|
|
continue
|
|
self.graphed_buckets.add(graphed_bucket)
|
|
self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
|
|
with HabanaMemoryProfiler() as mem_prof:
|
|
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
|
|
used_mem = align_workers(mem_prof.consumed_device_memory,
|
|
torch.distributed.ReduceOp.MAX)
|
|
available_mem -= used_mem
|
|
total_mem += used_mem
|
|
total_batch_seq += batch_seq
|
|
|
|
return total_mem, total_batch_seq, captured_all
|
|
|
|
def log_graph_warmup_summary(self, buckets, is_prompt, total_mem):
|
|
num_candidates = len(buckets)
|
|
phase = f'Graph/{"Prompt" if is_prompt else "Decode"}'
|
|
graphed = list(c[:2] for c in self.graphed_buckets
|
|
if c[2] == is_prompt)
|
|
if num_candidates == 0:
|
|
num_candidates = 1
|
|
msg = (f'{phase} captured:{len(graphed)} '
|
|
f'({100 * len(graphed) / num_candidates:.1f}%) '
|
|
f'used_mem:{format_bytes(total_mem)} '
|
|
f'buckets:{sorted(list(graphed))}')
|
|
logger.info(msg)
|
|
|
|
@torch.inference_mode()
|
|
def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
|
|
if profile := os.environ.get('VLLM_PT_PROFILE', None):
|
|
phase, bs, seq_len, graph = profile.split('_')
|
|
is_prompt = phase == 'prompt'
|
|
graphs = graph == 't'
|
|
if graphs:
|
|
self.graphed_buckets.add((int(bs), int(seq_len), is_prompt))
|
|
self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches,
|
|
True)
|
|
raise AssertionError("Finished profiling")
|
|
if self.skip_warmup:
|
|
logger.info("Skipping warmup...")
|
|
return
|
|
self.profiler.start('internal', 'warmup')
|
|
max_blocks = kv_caches[0][0].size(0)
|
|
|
|
self.bucketing_global_state.prompt_buckets, prompt_omitted_buckets = \
|
|
generate_prompt_buckets(
|
|
self.bucketing_global_state.prompt_bs_bucket_cfg,
|
|
self.bucketing_global_state.prompt_seq_bucket_cfg,
|
|
self.max_num_batched_tokens)
|
|
|
|
msg = (f"Generated {len(self.bucketing_global_state.prompt_buckets)} "
|
|
f"prompt buckets [bs, seq]: \
|
|
{list(sorted(self.bucketing_global_state.prompt_buckets))}")
|
|
logger.info(msg)
|
|
|
|
msg = (f"Omitted {len(prompt_omitted_buckets)} "
|
|
"prompt buckets due to exceeded token budget "
|
|
f"(max_num_batched_tokens={self.max_num_batched_tokens})")
|
|
logger.info(msg)
|
|
|
|
msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}"
|
|
logger.debug(msg)
|
|
|
|
self.bucketing_global_state.decode_buckets = generate_decode_buckets(
|
|
self.bucketing_global_state.decode_bs_bucket_cfg,
|
|
self.bucketing_global_state.decode_block_bucket_cfg, max_blocks)
|
|
logger.info("Generated %d decode buckets [bs, total_blocks]: %s",
|
|
len(self.bucketing_global_state.decode_buckets),
|
|
list(sorted(self.bucketing_global_state.decode_buckets)))
|
|
|
|
if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
|
|
cache_size_limit = len(
|
|
self.bucketing_global_state.prompt_buckets) + len(
|
|
self.bucketing_global_state.decode_buckets) + 1
|
|
torch._dynamo.config.cache_size_limit = max(
|
|
cache_size_limit, torch._dynamo.config.cache_size_limit)
|
|
# Multiply by 8 to follow the original default ratio between
|
|
# the cache_size_limit and accumulated_cache_size_limit
|
|
torch._dynamo.config.accumulated_cache_size_limit = max(
|
|
cache_size_limit * 8,
|
|
torch._dynamo.config.accumulated_cache_size_limit)
|
|
|
|
start_mem = HabanaMemoryProfiler.current_device_memory_usage()
|
|
start_time = time.perf_counter()
|
|
|
|
compile_only_mode_context = functools.partial(bc.env_setting,
|
|
"PT_COMPILE_ONLY_MODE",
|
|
True)
|
|
can_use_compile_only_mode = True
|
|
try:
|
|
with compile_only_mode_context():
|
|
pass
|
|
logger.debug("Using PT_COMPILE_ONLY_MODE.")
|
|
except KeyError:
|
|
can_use_compile_only_mode = False
|
|
logger.warning('Cannot use PT_COMPILE_ONLY_MODE. '
|
|
'Warmup time will be negatively impacted. '
|
|
'Please update Gaudi Software Suite.')
|
|
with compile_only_mode_context(
|
|
) if can_use_compile_only_mode else contextlib.nullcontext():
|
|
self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets,
|
|
True, kv_caches)
|
|
self.warmup_all_buckets(self.bucketing_global_state.decode_buckets,
|
|
False, kv_caches)
|
|
|
|
if not self.enforce_eager and htorch.utils.internal.is_lazy():
|
|
assert self.mem_margin is not None, \
|
|
("HabanaWorker.determine_num_available_blocks needs "
|
|
"to be called before warming up the model.")
|
|
free_mem = HabanaMemoryProfiler.current_free_device_memory()
|
|
graph_free_mem = free_mem - self.mem_margin
|
|
graph_free_mem = align_workers(graph_free_mem,
|
|
torch.distributed.ReduceOp.MIN)
|
|
prompt_graph_mem_ratio = float(
|
|
os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.3'))
|
|
prompt_available_memory = (prompt_graph_mem_ratio *
|
|
graph_free_mem)
|
|
decode_available_memory = (graph_free_mem -
|
|
prompt_available_memory)
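            # Example with assumed numbers: graph_free_mem=10 GiB and the
            # default VLLM_GRAPH_PROMPT_RATIO=0.3 give 3 GiB for prompt
            # graphs and 7 GiB for decode graphs.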
|
|
msg = (
|
|
f"Using {format_bytes(graph_free_mem)}"
|
|
f"/{format_bytes(free_mem)} "
|
|
"of free device memory for HPUGraphs, "
|
|
f"{format_bytes(prompt_available_memory)} for prompt and "
|
|
f"{format_bytes(decode_available_memory)} for decode "
|
|
f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})")
|
|
logger.info(msg)
|
|
prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY',
|
|
'min_tokens')
|
|
decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY',
|
|
'max_bs')
|
|
mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
|
|
self.warmup_graphs(
|
|
prompt_strategy, self.bucketing_global_state.prompt_buckets,
|
|
True, kv_caches, prompt_available_memory)
|
|
mem_post_decode, decode_batch_seq, decode_captured_all = \
|
|
self.warmup_graphs(
|
|
decode_strategy, self.bucketing_global_state.decode_buckets,
|
|
False, kv_caches, decode_available_memory)

                # If not all prompt buckets were captured, but all decode
                # buckets were and some graph-allocated memory is left, try to
                # use the leftover budget to capture more prompt buckets.
                if (mem_post_decode + mem_post_prompt < graph_free_mem
                        and not prompt_captured_all and decode_captured_all):
                    mem_post_prompt, _, prompt_captured_all = (
                        self.warmup_graphs(
                            prompt_strategy,
                            self.bucketing_global_state.prompt_buckets, True,
                            kv_caches,
                            graph_free_mem - mem_post_prompt - mem_post_decode,
                            mem_post_prompt, prompt_batch_seq))

                # If not all decode buckets were captured, but all prompt
                # buckets were and some graph-allocated memory is left, try to
                # use the leftover budget to capture more decode buckets.
                if mem_post_decode + mem_post_prompt < graph_free_mem \
                        and not decode_captured_all \
                        and prompt_captured_all:
                    mem_post_decode, _, _ = self.warmup_graphs(
                        decode_strategy,
                        self.bucketing_global_state.decode_buckets, False,
                        kv_caches,
                        graph_free_mem - mem_post_prompt - mem_post_decode,
                        mem_post_decode, decode_batch_seq)

                self.log_graph_warmup_summary(
                    self.bucketing_global_state.prompt_buckets, True,
                    mem_post_prompt)
                self.log_graph_warmup_summary(
                    self.bucketing_global_state.decode_buckets, False,
                    mem_post_decode)

        end_time = time.perf_counter()
        end_mem = HabanaMemoryProfiler.current_device_memory_usage()
        elapsed_time = end_time - start_time
        msg = (
            f"Warmup finished in {elapsed_time:.0f} secs, "
            f"allocated {format_bytes(end_mem - start_mem)} of device memory")
        logger.info(msg)
        self.profiler.end()

    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()

    @property
    def mem_margin(self) -> Optional[int]:
        return self._mem_margin

    @mem_margin.setter
    def mem_margin(self, value):
        self._mem_margin = value
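

# When running in lazy mode, the adapter below gets wrapped with
# htorch.hpu.wrap_in_hpu_graph so repeated bucket shapes replay captured HPU
# graphs; disable_tensor_cache=True presumably keeps the wrapper from holding
# on to input/output tensors between replays. In eager/torch.compile mode the
# plain HpuModelAdapter is returned.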
def _maybe_wrap_in_hpu_graph(*args, **kwargs):
    return htorch.hpu.wrap_in_hpu_graph(
        HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
    ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs)


class HabanaProfilerCounterHelper:

    def __init__(self):
        self.niter = 0
        self.average_real_throughput = None
        self.logged_once = False
        self.real_seq_lens = []
        self.prompt_seq_lens = []

    def capture_seq_group_metadata_stats(self, seq_group_metadata_list):
        self.real_seq_lens = [
            len(seq_data.prompt_token_ids) + len(seq_data.output_token_ids)
            for seq_group_metadata in seq_group_metadata_list
            for seq_data in seq_group_metadata.seq_data.values()
        ]
        self.prompt_seq_lens = [
            len(seq_data.prompt_token_ids)
            for seq_group_metadata in seq_group_metadata_list
            for seq_data in seq_group_metadata.seq_data.values()
        ]

    def get_counter_dict(self, cache_config, duration, seq_len,
                         batch_size_padded, real_batch_size, is_prompt):
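        # `duration` is expressed in microseconds (profiler timestamps), so
        # dividing by 1e6 converts it to seconds and the throughputs below
        # are per-second rates.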
        throughput = batch_size_padded / (duration / 1e6)
        throughput_effective = real_batch_size / (duration / 1e6)

        real_max_seq_len = max(self.real_seq_lens)
        real_num_tokens = sum(self.real_seq_lens)
        padded_num_tokens = batch_size_padded * seq_len
        batch_token_utilization = real_num_tokens / padded_num_tokens
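        # Incremental mean (see the link below): avg_{n+1} = avg_n +
        # (x_{n+1} - avg_n) / (n + 1), which equals the arithmetic mean of all
        # effective throughputs observed so far.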
        if self.average_real_throughput is None:
            self.average_real_throughput = throughput_effective
        else:  # https://www.heikohoffmann.de/htmlthesis/node134.html
            self.average_real_throughput = self.average_real_throughput + 1 / (
                self.niter + 1) * (throughput_effective -
                                   self.average_real_throughput)
        phase = "prompt" if is_prompt else "decode"
        counters = {
            f'{phase}_bucket_batch_size': batch_size_padded,
            f'{phase}_batch_size': real_batch_size,
            f'{phase}_bucket_seq_len': seq_len,
            f'{phase}_seq_len': real_max_seq_len,
            f'{phase}_bucket_gen_throughput': throughput,
            f'{phase}_real_gen_throughput': throughput_effective,
            f'{phase}_batch_token_utilization': batch_token_utilization,
            'average_real_throughput': self.average_real_throughput,
            'engine_iteration': self.niter,
        }
        self.niter += 1
        if is_prompt:
            prompt_bucket_in_throughput = (seq_len * batch_size_padded) / (
                duration / 1e6)
            prompt_real_in_throughput = sum(
                self.prompt_seq_lens) / (duration / 1e6)
            counters[
                f'{phase}_bucket_in_throughput'] = prompt_bucket_in_throughput
            counters[f'{phase}_real_in_throughput'] = prompt_real_in_throughput

        # KV cache might not be created yet (e.g. for profiling run)
        if cache_config.num_gpu_blocks is not None and \
            cache_config.num_gpu_blocks != 0:
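            # Hypothetical example: with block_size=128 and real_seq_lens of
            # [100, 300], the sequences occupy ceil(100/128)=1 and
            # ceil(300/128)=3 blocks, so cache_total_num_blocks_used == 4.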
            cache_num_blocks_used = [
                math.ceil(sl / cache_config.block_size)
                for sl in self.real_seq_lens
            ]
            cache_total_num_blocks_used = sum(cache_num_blocks_used)
            num_cache_blocks = cache_config.num_gpu_blocks
            cache_total_num_free_blocks = \
                num_cache_blocks - cache_total_num_blocks_used
            cache_computed_utilization = \
                cache_total_num_blocks_used / num_cache_blocks
            max_blocks_per_seq = math.ceil(seq_len / cache_config.block_size)
            batch_block_utilization = cache_total_num_blocks_used / (
                batch_size_padded * max_blocks_per_seq)
            counters['cache_num_blocks_used'] = cache_total_num_blocks_used
            counters['cache_num_free_blocks'] = cache_total_num_free_blocks
            counters['cache_computed_utilization'] = cache_computed_utilization
            counters[
                f'{phase}_batch_block_utilization'] = batch_block_utilization
        if not self.logged_once:
            counters['const_cache_num_blocks'] = cache_config.num_gpu_blocks
            counters[
                'const_gpu_memory_utilization'] = \
                    cache_config.gpu_memory_utilization
            counters['const_block_size'] = cache_config.block_size
            self.logged_once = True
        return counters
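

# unwrap_model strips the wrappers added at load time: torch.compile's
# OptimizedModule is unwrapped recursively; otherwise the HPU-graph/adapter
# wrapper is assumed to keep the real model as its first submodule, and that
# model's child modules are returned.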
def unwrap_model(model):
    if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
        return unwrap_model(model._orig_mod)
    else:
        model = list(vars(model)['_modules'].values())[0]
        modules = list(vars(model)['_modules'].values())
        return modules


class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
    """
    HPU model runner with sampling step.
    """
    _model_input_cls: Type[ModelInputForHPUWithSamplingMetadata] = (
        ModelInputForHPUWithSamplingMetadata)

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForHPUWithSamplingMetadata:
        return (
            ModelInputForHPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            ))

    @torch.inference_mode()
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForHPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.
        The API assumes seq_group_metadata_list is sorted by prefill -> decode.
        The result tensors and data structure also batch input in prefill
        -> decode order. For example,
        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.
        If HPU graphs are required, this API automatically pads inputs.
        """
        with self.profiler.record_event('internal', 'prepare_input_tensors'):
            assert seq_group_metadata_list is not None
            if self.profiler.enabled:
                self.profiler_counter_helper.capture_seq_group_metadata_stats(
                    seq_group_metadata_list=seq_group_metadata_list)
            model_input, sampling_metadata = self.prepare_input_tensors(
                seq_group_metadata_list)
            assert model_input.attn_metadata is not None
            is_prompt = model_input.attn_metadata.is_prompt

        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    def finish_measurements(self):
        from neural_compressor.torch.quantization import finalize_calibration
        finalize_calibration(self.model.model)

    def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode):
        cfg = (batch_size, seq_len, is_prompt)
        seen = cfg in self.seen_configs
        self.seen_configs.add(cfg)
        if not seen and not warmup_mode:
            phase = 'prompt' if is_prompt else 'decode'
            logger.warning("Configuration: (%s, %s, %s) was not warmed-up!",
                           phase, batch_size, seq_len)

    def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
                         is_prompt: bool):
        '''
        This is a helper function to create the mask for LoRA computations.
        The LoRA mask is needed to match the correct LoRA weights to each
        request.
        For the prompt phase we have
        lora_mask with shape (batch_size * seq_len, max_loras * max_rank)
        lora_logits_mask with shape (batch_size, max_loras * max_rank)
        For the decode phase we have both
        lora_mask and lora_logits_mask with shape
        (batch_size, max_loras * max_rank)
        '''
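        # Hypothetical shapes: batch_size=2, seq_len=3, max_loras=2,
        # max_lora_rank=8 -> in the prompt phase lora_mask is (6, 16) and
        # lora_logits_mask is (2, 16); in decode both are (2, 16). A request
        # served by adapter slot 1 gets ones written into columns 8:16 of its
        # rows, while requests with lora_id == 0 keep all-zero rows.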
        lora_mask: Optional[torch.Tensor] = None
        lora_logits_mask: Optional[torch.Tensor] = None
        lora_index = 0

        if self.lora_config:
            if is_prompt:
                lora_mask = torch.zeros(
                    input_tokens.shape[0] * input_tokens.shape[1],
                    (self.lora_config.max_loras) *
                    self.lora_config.max_lora_rank,
                    dtype=self.lora_config.lora_dtype)
                lora_logits_mask = torch.zeros(
                    input_tokens.shape[0], (self.lora_config.max_loras) *
                    self.lora_config.max_lora_rank,
                    dtype=self.lora_config.lora_dtype)

                ones = torch.ones(input_tokens.shape[1],
                                  self.lora_config.max_lora_rank,
                                  dtype=self.lora_config.lora_dtype)
                logit_ones = torch.ones(1,
                                        self.lora_config.max_lora_rank,
                                        dtype=self.lora_config.lora_dtype)

                for i in range(len(lora_ids)):
                    if lora_ids[i] == 0:
                        continue
                    lora_index = self.lora_manager._adapter_manager.\
                        lora_index_to_id.index(lora_ids[i])
                    start_row = i * input_tokens.shape[1]
                    end_row = start_row + input_tokens.shape[1]
                    start_col = lora_index * self.lora_config.max_lora_rank
                    end_col = start_col + self.lora_config.max_lora_rank
                    lora_mask[start_row:end_row, start_col:end_col] = ones
                    lora_logits_mask[i, start_col:end_col] = logit_ones
                lora_mask = lora_mask.to('hpu')
                lora_logits_mask = lora_logits_mask.to('hpu')
            else:
                lora_mask = torch.zeros(input_tokens.shape[0],
                                        (self.lora_config.max_loras) *
                                        self.lora_config.max_lora_rank,
                                        dtype=self.lora_config.lora_dtype)
                ones = torch.ones(1,
                                  self.lora_config.max_lora_rank,
                                  dtype=self.lora_config.lora_dtype)
                for i in range(len(lora_ids)):
                    if lora_ids[i] == 0:
                        continue
                    lora_index = self.lora_manager._adapter_manager.\
                        lora_index_to_id.index(lora_ids[i])
                    start_pos = lora_index * self.lora_config.max_lora_rank
                    end_pos = start_pos + self.lora_config.max_lora_rank
                    lora_mask[i, start_pos:end_pos] = ones
                lora_mask = lora_mask.to('hpu')
                lora_logits_mask = lora_mask

        return lora_mask, lora_logits_mask

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForHPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        warmup_mode=False,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        if num_steps > 1:
            raise ValueError(
                "num_steps > 1 is not supported in HPUModelRunner")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)
        input_tokens = model_input.input_tokens
        input_positions = model_input.input_positions
        attn_metadata = model_input.attn_metadata
        sampling_metadata = model_input.sampling_metadata
        real_batch_size = model_input.real_batch_size
        batch_size_padded = model_input.batch_size_padded
        assert input_tokens is not None
        assert input_positions is not None
        assert sampling_metadata is not None
        assert attn_metadata is not None
        is_prompt = attn_metadata.is_prompt
        assert is_prompt is not None
        batch_size = input_tokens.size(0)
        seq_len = self._seq_len(attn_metadata)
        use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
        self._check_config(batch_size, seq_len, is_prompt, warmup_mode)

        lora_mask: Optional[torch.Tensor] = None
        lora_logits_mask: Optional[torch.Tensor] = None
        if self.lora_config:
            assert model_input.lora_ids is not None
            lora_mask, lora_logits_mask = self.create_lora_mask(
                input_tokens, model_input.lora_ids, attn_metadata.is_prompt)

        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": self.trim_attn_metadata(attn_metadata),
            "intermediate_tensors": intermediate_tensors,
            "lora_mask": lora_mask,
            **(model_input.multi_modal_kwargs or {}),
        }
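        # In lazy mode, shapes that were not warmed up set
        # bypass_hpu_graphs=True so they run through regular lazy execution
        # instead of being captured and replayed as HPU graphs.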
        if htorch.utils.internal.is_lazy():
            execute_model_kwargs.update({"bypass_hpu_graphs": not use_graphs})

        htorch.core.mark_step()
        if self.is_driver_worker:
            model_event_name = ("model_"
                                f"{'prompt' if is_prompt else 'decode'}_"
                                f"bs{batch_size}_"
                                f"seq{seq_len}_"
                                f"graphs{'T' if use_graphs else 'F'}")
        else:
            model_event_name = 'model_executable'
        with self.profiler.record_event('internal', model_event_name):
            hidden_states = self.model.forward(
                **execute_model_kwargs,
                selected_token_indices=sampling_metadata.selected_token_indices
            )

        if self.lora_config:
            LoraMask.setLoraMask(
                lora_logits_mask.index_select(
                    0, sampling_metadata.selected_token_indices))

        # Compute the logits.
        with self.profiler.record_event(
                'internal', ('compute_logits_'
                             f'{"prompt" if is_prompt else "decode"}_bs'
                             f'{batch_size}_'
                             f'seq{seq_len}')):
            sampling_metadata.selected_token_indices = None
            logits = self.model.compute_logits(hidden_states,
                                               sampling_metadata)
        htorch.core.mark_step()
        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        if model_input.async_callback is not None:
            model_input.async_callback()

        # Sample the next token.
        with self.profiler.record_event(
                'internal', ('sample_'
                             f'{"prompt" if is_prompt else "decode"}_'
                             f'bs{batch_size}_'
                             f'seq{seq_len}')):
            output = self.model.sample(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
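            # The batch was padded up to a bucket size; drop outputs that
            # correspond to padding sequences before returning.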
            output.outputs = output.outputs[:real_batch_size]
        htorch.core.mark_step()

        if self.is_driver_worker and self.profiler.enabled:
            # Stop recording 'execute_model' event
            self.profiler.end()
            event_end = self.profiler.get_timestamp_us()
            counters = self.profiler_counter_helper.get_counter_dict(
                cache_config=self.cache_config,
                duration=event_end - self.event_start,
                seq_len=seq_len,
                batch_size_padded=batch_size_padded,
                real_batch_size=real_batch_size,
                is_prompt=is_prompt)
            self.profiler.record_counter(self.event_start, counters)
        return [output]
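
    # shutdown_inc finalizes Intel Neural Compressor (INC) calibration at most
    # once, and only when quantization == 'inc'; AttributeError is suppressed
    # because the model may never have been fully initialized.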
    def shutdown_inc(self):
        can_finalize_inc = False
        from contextlib import suppress
        with suppress(AttributeError):
            can_finalize_inc = (self.model_config.quantization == 'inc') and \
                (self.model.model is not None) and \
                self.inc_initialized_successfully and \
                not getattr(self, "_is_inc_finalized", False)
        if can_finalize_inc:
            from neural_compressor.torch.quantization import (
                finalize_calibration)
            finalize_calibration(self.model.model)
            self._is_inc_finalized = True

    def __del__(self):
        self.shutdown_inc()