[FlashInfer] Upgrade to 0.2.0 (#11194)
Signed-off-by: Bowen Wang <abmfy@icloud.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
3f1fc7425a
commit
2bc3fbba0c
@ -183,7 +183,16 @@ steps:
|
||||
- vllm/
|
||||
- tests/v1
|
||||
commands:
|
||||
- VLLM_USE_V1=1 pytest -v -s v1
|
||||
# split the test to avoid interference
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/core
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/engine
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/sample
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/worker
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/test_stats.py
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/test_utils.py
|
||||
# TODO: accuracy does not match, whether setting
|
||||
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/e2e
|
||||
|
||||
- label: Examples Test # 25min
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
|
23
Dockerfile
23
Dockerfile
@ -149,7 +149,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
# image with vLLM installed
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base
|
||||
# TODO: Restore to base image after FlashInfer AOT wheel fixed
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG PYTHON_VERSION=3.12
|
||||
WORKDIR /vllm-workspace
|
||||
@ -194,12 +195,30 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install dist/*.whl --verbose
|
||||
|
||||
# How to build this FlashInfer wheel:
|
||||
# $ export FLASHINFER_ENABLE_AOT=1
|
||||
# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+
|
||||
# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX'
|
||||
# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
|
||||
# $ cd flashinfer
|
||||
# $ git checkout 524304395bd1d8cd7d07db083859523fcaa246a4
|
||||
# $ python3 setup.py bdist_wheel --dist-dir=dist --verbose
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
. /etc/environment && \
|
||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
|
||||
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
|
||||
python3 -m pip install https://wheels.vllm.ai/flashinfer/524304395bd1d8cd7d07db083859523fcaa246a4/flashinfer_python-0.2.0.post1-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
|
||||
fi
|
||||
COPY examples examples
|
||||
|
||||
# Although we build Flashinfer with AOT mode, there's still
|
||||
# some issues w.r.t. JIT compilation. Therefore we need to
|
||||
# install build dependencies for JIT compilation.
|
||||
# TODO: Remove this once FlashInfer AOT wheel is fixed
|
||||
COPY requirements-build.txt requirements-build.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install -r requirements-build.txt
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
#################### TEST IMAGE ####################
|
||||
|
@ -61,9 +61,10 @@ def test_models(
|
||||
if backend == "FLASHINFER" and current_platform.is_rocm():
|
||||
pytest.skip("Flashinfer does not support ROCm/HIP.")
|
||||
|
||||
if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
|
||||
if backend in ("XFORMERS",
|
||||
"FLASHINFER") and model == "google/gemma-2-2b-it":
|
||||
pytest.skip(
|
||||
"XFORMERS does not support gemma2 with full context length.")
|
||||
f"{backend} does not support gemma2 with full context length.")
|
||||
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = backend
|
||||
|
||||
|
@ -58,7 +58,7 @@ test_settings = [
|
||||
model_args=["--task", "embed"],
|
||||
pp_size=1,
|
||||
tp_size=1,
|
||||
attn_backend="FLASHINFER",
|
||||
attn_backend="FLASH_ATTN",
|
||||
method="encode",
|
||||
fullgraph=True,
|
||||
),
|
||||
|
@ -133,17 +133,19 @@ def test_flashinfer_decode_with_paged_kv(
|
||||
use_tensor_cores=(
|
||||
(num_query_heads//num_kv_heads) > 4)
|
||||
)
|
||||
wrapper.begin_forward(kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
data_type=dtype)
|
||||
wrapper.plan(kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
logits_soft_cap=soft_cap)
|
||||
|
||||
output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap)
|
||||
output = wrapper.run(query, key_value_cache)
|
||||
|
||||
ref_output = ref_paged_attn(query=query,
|
||||
key_cache=key_cache,
|
||||
@ -228,7 +230,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD")
|
||||
wrapper.begin_forward(
|
||||
wrapper.plan(
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
@ -237,12 +239,14 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
logits_soft_cap=soft_cap,
|
||||
)
|
||||
|
||||
output = wrapper.forward(
|
||||
output = wrapper.run(
|
||||
query,
|
||||
key_value_cache,
|
||||
logits_soft_cap=soft_cap,
|
||||
)
|
||||
|
||||
ref_output = ref_paged_attn(query=query,
|
||||
@ -253,7 +257,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap)
|
||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
|
||||
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
|
||||
|
||||
@ -332,7 +336,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD")
|
||||
wrapper.begin_forward(
|
||||
wrapper.plan(
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
@ -341,13 +345,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=kv_cache_dtype,
|
||||
logits_soft_cap=soft_cap,
|
||||
)
|
||||
|
||||
output = wrapper.forward(query,
|
||||
kv_cache_fp8,
|
||||
logits_soft_cap=soft_cap,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale)
|
||||
output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
|
||||
|
||||
ref_output = ref_paged_attn(query=query,
|
||||
key_cache=key_cache.squeeze(1),
|
||||
@ -360,7 +363,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
|
||||
del query
|
||||
del block_tables
|
||||
# verify prefill fp8
|
||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
|
||||
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
|
||||
|
||||
@ -439,21 +442,18 @@ def test_flashinfer_decode_with_paged_fp8_kv(
|
||||
wrapper = flashinfer.\
|
||||
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
|
||||
use_tensor_cores=use_tensor_cores)
|
||||
wrapper.begin_forward(kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
data_type=dtype,
|
||||
q_data_type=dtype)
|
||||
output = wrapper.forward(query,
|
||||
kv_cache_fp8,
|
||||
logits_soft_cap=soft_cap,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale)
|
||||
wrapper.plan(kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
q_data_type=dtype,
|
||||
kv_data_type=kv_cache_dtype,
|
||||
logits_soft_cap=soft_cap)
|
||||
output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
|
||||
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
|
||||
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import dataclasses
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
@ -13,9 +14,11 @@ try:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||
except ImportError:
|
||||
BatchDecodeWithPagedKVCacheWrapper = None
|
||||
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
|
||||
BatchPrefillWithPagedKVCacheWrapper = None
|
||||
# Avoid turning these types into variables during type checking
|
||||
if not TYPE_CHECKING:
|
||||
BatchDecodeWithPagedKVCacheWrapper = None
|
||||
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
|
||||
BatchPrefillWithPagedKVCacheWrapper = None
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
|
||||
|
||||
import torch
|
||||
@ -30,7 +33,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
|
||||
make_tensor_with_pad)
|
||||
|
||||
@ -99,6 +104,72 @@ class FlashInferBackend(AttentionBackend):
|
||||
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerLayerParameters:
|
||||
"""
|
||||
Currently, FlashInfer backend only support models in which all layers share
|
||||
the same values for the following hyperparameters.
|
||||
"""
|
||||
|
||||
window_left: int
|
||||
logits_soft_cap: Optional[float]
|
||||
sm_scale: float
|
||||
|
||||
|
||||
def get_per_layer_parameters(
|
||||
vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]:
|
||||
"""
|
||||
Scan all attention layers and determine some hyperparameters
|
||||
to use during `plan`.
|
||||
"""
|
||||
|
||||
layers = vllm_config.compilation_config.static_forward_context
|
||||
per_layer_params: Dict[str, PerLayerParameters] = {}
|
||||
|
||||
for key, layer in layers.items():
|
||||
assert isinstance(layer, Attention)
|
||||
|
||||
impl = layer.impl
|
||||
assert isinstance(impl, FlashInferImpl)
|
||||
|
||||
# Infer hyperparameters from the attention layer
|
||||
window_size = impl.sliding_window
|
||||
window_left = window_size[0] if window_size is not None else -1
|
||||
logits_soft_cap = impl.logits_soft_cap
|
||||
sm_scale = impl.scale
|
||||
|
||||
per_layer_params[key] = PerLayerParameters(window_left,
|
||||
logits_soft_cap, sm_scale)
|
||||
|
||||
return per_layer_params
|
||||
|
||||
|
||||
def infer_global_hyperparameters(
|
||||
per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters:
|
||||
"""
|
||||
Currently, FlashInfer backend only support models in which all layers share
|
||||
the same values for the following hyperparameters:
|
||||
- `window_left`
|
||||
- `logits_soft_cap`
|
||||
- `sm_scale`
|
||||
|
||||
So this function asserts that all layers share the same values for these
|
||||
hyperparameters and returns the global values.
|
||||
"""
|
||||
|
||||
assert len(per_layer_params) > 0, "No attention layers found in the model."
|
||||
|
||||
param_sets = list(per_layer_params.values())
|
||||
global_params = param_sets[0]
|
||||
for params in param_sets:
|
||||
assert params == global_params, (
|
||||
"FlashInfer backend currently only supports models in which all "
|
||||
"layers share the same values for the following hyperparameters: "
|
||||
"`window_left`, `logits_soft_cap`, `sm_scale`.")
|
||||
|
||||
return global_params
|
||||
|
||||
|
||||
class FlashInferState(AttentionState):
|
||||
|
||||
def __init__(self, runner):
|
||||
@ -108,6 +179,11 @@ class FlashInferState(AttentionState):
|
||||
self._decode_wrapper = None
|
||||
self._prefill_wrapper = None
|
||||
|
||||
# Global hyperparameters shared by all attention layers
|
||||
self.global_hyperparameters: Optional[PerLayerParameters] = None
|
||||
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
|
||||
def _get_workspace_buffer(self):
|
||||
if self._workspace_buffer is None:
|
||||
self._workspace_buffer = torch.empty(
|
||||
@ -215,6 +291,9 @@ class FlashInferState(AttentionState):
|
||||
batch_size + 1,
|
||||
dtype=torch.int32)
|
||||
|
||||
global_params = infer_global_hyperparameters(
|
||||
get_per_layer_parameters(self.vllm_config))
|
||||
|
||||
attn_metadata = self.runner.attn_backend.make_metadata(
|
||||
num_prefills=0,
|
||||
slot_mapping=self._graph_slot_mapping[:batch_size],
|
||||
@ -238,7 +317,9 @@ class FlashInferState(AttentionState):
|
||||
q_data_type=self.runner.model_config.dtype,
|
||||
use_cuda_graph=True,
|
||||
decode_wrapper=self._graph_decode_wrapper,
|
||||
prefill_wrapper=None)
|
||||
prefill_wrapper=None,
|
||||
**dataclasses.asdict(global_params),
|
||||
)
|
||||
attn_metadata.begin_forward()
|
||||
return attn_metadata
|
||||
|
||||
@ -325,9 +406,28 @@ class FlashInferMetadata(AttentionMetadata):
|
||||
data_type: torch.dtype = None
|
||||
# The data type of the query
|
||||
q_data_type: torch.dtype = None
|
||||
device: torch.device = torch.device("cuda")
|
||||
# FlashInfer 0.2 encourages passing host tensors
|
||||
device: torch.device = torch.device("cpu")
|
||||
is_profile_run: bool = False
|
||||
|
||||
# The FlashInfer backend currently supports only models in which all layers
|
||||
# share the same following hyperparameters:
|
||||
|
||||
# The left (inclusive) window size for the attention window, when
|
||||
# set to `-1`, the window size will be set to the full length of
|
||||
# the sequence. Defaults to `-1`.
|
||||
window_left: int = -1
|
||||
# The attention logits soft capping value (used in Gemini, Grok and
|
||||
# Gemma-2, etc.), if not provided, will be set to `0`. If greater
|
||||
# than 0, the logits will be capped according to formula:
|
||||
# $$\texttt{logits\_soft\_cap} \times
|
||||
# \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$,
|
||||
# where $x$ is the input logits.
|
||||
logits_soft_cap: Optional[float] = None
|
||||
# The scale used in softmax, if not provided, will be set to
|
||||
# `1.0 / sqrt(head_dim)`.
|
||||
sm_scale: Optional[float] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Refer to
|
||||
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
|
||||
@ -363,14 +463,21 @@ class FlashInferMetadata(AttentionMetadata):
|
||||
self.block_table_bound = self.block_table_bound.to(self.device)
|
||||
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
|
||||
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
|
||||
self.prefill_wrapper.end_forward()
|
||||
self.prefill_wrapper.begin_forward(
|
||||
self.prefill_wrapper.plan(
|
||||
self.query_start_loc,
|
||||
self.paged_kv_indptr[:self.num_prefills + 1],
|
||||
self.paged_kv_indices,
|
||||
self.paged_kv_last_page_len[:self.num_prefills],
|
||||
self.num_qo_heads, self.num_kv_heads, self.head_dim,
|
||||
self.page_size)
|
||||
self.num_qo_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
self.page_size,
|
||||
causal=True,
|
||||
sm_scale=self.sm_scale,
|
||||
window_left=self.window_left,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
q_data_type=self.q_data_type,
|
||||
kv_data_type=self.data_type)
|
||||
if self.num_decode_tokens > 0:
|
||||
assert self.paged_kv_indices is not None
|
||||
assert self.paged_kv_indptr is not None
|
||||
@ -386,8 +493,7 @@ class FlashInferMetadata(AttentionMetadata):
|
||||
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
|
||||
|
||||
assert self.decode_wrapper is not None
|
||||
self.decode_wrapper.end_forward()
|
||||
self.decode_wrapper.begin_forward(
|
||||
self.decode_wrapper.plan(
|
||||
self.paged_kv_indptr[self.num_prefills:],
|
||||
self.paged_kv_indices,
|
||||
self.paged_kv_last_page_len[self.num_prefills:],
|
||||
@ -397,8 +503,11 @@ class FlashInferMetadata(AttentionMetadata):
|
||||
self.page_size,
|
||||
# Disable flashinfer's pos encoding and use vllm's rope.
|
||||
pos_encoding_mode="NONE",
|
||||
window_left=self.window_left,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
sm_scale=self.sm_scale,
|
||||
# kv-cache data type.
|
||||
data_type=self.data_type,
|
||||
kv_data_type=self.data_type,
|
||||
# query data type.
|
||||
q_data_type=self.q_data_type)
|
||||
|
||||
@ -496,6 +605,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
# Global hyperparameters shared by all attention layers
|
||||
self.global_hyperparameters: Optional[PerLayerParameters] = None
|
||||
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
|
||||
def prepare(self):
|
||||
self.slot_mapping: List[int] = []
|
||||
self.prefill_seq_lens: List[int] = []
|
||||
@ -528,6 +642,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self.total_blocks = 0
|
||||
self.is_profile_run: bool = False
|
||||
|
||||
if self.global_hyperparameters is None:
|
||||
# Infer global hyperparameters, since currently we only support
|
||||
# models in which all layers share the same values for the
|
||||
# following hyperparameters:
|
||||
# - `window_left`
|
||||
# - `logits_soft_cap`
|
||||
# - `sm_scale`
|
||||
inferred_params = infer_global_hyperparameters(
|
||||
get_per_layer_parameters(self.vllm_config))
|
||||
self.global_hyperparameters = inferred_params
|
||||
self.window_left = inferred_params.window_left
|
||||
self.logits_soft_cap = inferred_params.logits_soft_cap
|
||||
self.sm_scale = inferred_params.sm_scale
|
||||
|
||||
def _add_seq_group(
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool):
|
||||
@ -756,7 +884,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
data_type=kv_cache_dtype,
|
||||
q_data_type=self.runner.model_config.dtype,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
is_profile_run=self.is_profile_run)
|
||||
is_profile_run=self.is_profile_run,
|
||||
window_left=self.window_left,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
sm_scale=self.sm_scale,
|
||||
)
|
||||
|
||||
|
||||
class FlashInferImpl(AttentionImpl):
|
||||
@ -885,25 +1017,34 @@ class FlashInferImpl(AttentionImpl):
|
||||
else:
|
||||
assert prefill_meta is not None
|
||||
assert prefill_meta.prefill_wrapper is not None
|
||||
prefill_output = prefill_meta.prefill_wrapper.forward(
|
||||
|
||||
assert prefill_meta.prefill_wrapper._causal
|
||||
assert prefill_meta.prefill_wrapper._window_left == window_left
|
||||
assert prefill_meta.prefill_wrapper._logits_soft_cap == (
|
||||
logits_soft_cap or 0.0)
|
||||
assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale
|
||||
|
||||
prefill_output = prefill_meta.prefill_wrapper.run(
|
||||
query,
|
||||
kv_cache,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
causal=True,
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
window_left=window_left)
|
||||
)
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
assert decode_meta is not None
|
||||
assert decode_meta.decode_wrapper is not None
|
||||
decode_output = decode_meta.decode_wrapper.forward(
|
||||
|
||||
assert decode_meta.decode_wrapper._window_left == window_left
|
||||
assert decode_meta.decode_wrapper._logits_soft_cap == (
|
||||
logits_soft_cap or 0.0)
|
||||
assert decode_meta.decode_wrapper._sm_scale == softmax_scale
|
||||
|
||||
decode_output = decode_meta.decode_wrapper.run(
|
||||
decode_query,
|
||||
kv_cache,
|
||||
sm_scale=softmax_scale,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
window_left=window_left)
|
||||
)
|
||||
|
||||
if prefill_output is None and decode_output is not None:
|
||||
# Decode only batch.
|
||||
|
@ -310,14 +310,15 @@ class ModelConfig:
|
||||
(self.hf_text_config.model_type in ["gemma2", "cohere2"]))
|
||||
|
||||
if (not self.disable_sliding_window and has_interleaved_attention):
|
||||
if envs.VLLM_ATTENTION_BACKEND == "XFORMERS":
|
||||
if (backend :=
|
||||
envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
|
||||
sliding_window_len_min = get_min_sliding_window(
|
||||
self.hf_text_config.sliding_window)
|
||||
|
||||
logger.warning_once(
|
||||
f"{self.hf_text_config.model_type} has interleaved "
|
||||
"attention, which is currently not supported by the "
|
||||
"XFORMERS backend. Disabling sliding window and capping "
|
||||
f"{backend} backend. Disabling sliding window and capping "
|
||||
"the max length to the sliding window size "
|
||||
f"({sliding_window_len_min}).")
|
||||
self.disable_sliding_window = True
|
||||
@ -3310,7 +3311,7 @@ _current_vllm_config: Optional[VllmConfig] = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_current_vllm_config(vllm_config: VllmConfig):
|
||||
def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
|
||||
"""
|
||||
Temporarily set the current VLLM config.
|
||||
Used during model initialization.
|
||||
@ -3330,7 +3331,8 @@ def set_current_vllm_config(vllm_config: VllmConfig):
|
||||
vllm_config.compilation_config.enabled_custom_ops)
|
||||
logger.debug("disabled custom ops: %s",
|
||||
vllm_config.compilation_config.disabled_custom_ops)
|
||||
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
|
||||
if check_compile and \
|
||||
vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
|
||||
and compilation_counter.num_models_seen == num_models_seen:
|
||||
# If the model supports compilation,
|
||||
# compilation_counter.num_models_seen should be increased
|
||||
|
@ -114,7 +114,7 @@ def _initialize_model(
|
||||
all_params = [param.name for param in signatures.parameters.values()]
|
||||
if "vllm_config" in all_params and "prefix" in all_params:
|
||||
# new-style model class
|
||||
with set_current_vllm_config(vllm_config):
|
||||
with set_current_vllm_config(vllm_config, check_compile=True):
|
||||
return model_class(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
|
||||
@ -142,7 +142,7 @@ def _initialize_model(
|
||||
kwargs["lora_config"] = vllm_config.lora_config
|
||||
if "scheduler_config" in all_params:
|
||||
kwargs["scheduler_config"] = vllm_config.scheduler_config
|
||||
with set_current_vllm_config(vllm_config):
|
||||
with set_current_vllm_config(vllm_config, check_compile=True):
|
||||
return model_class(**kwargs)
|
||||
|
||||
|
||||
|
@ -288,7 +288,8 @@ class TensorizerAgent:
|
||||
model_args.torch_dtype = self.tensorizer_config.dtype
|
||||
assert self.tensorizer_config.model_class is not None
|
||||
# TODO: Do we need to consider old-style model class?
|
||||
with no_init_or_tensor(), set_current_vllm_config(self.vllm_config):
|
||||
with no_init_or_tensor(), set_current_vllm_config(self.vllm_config,
|
||||
check_compile=True):
|
||||
return self.tensorizer_config.model_class(
|
||||
vllm_config=self.vllm_config, )
|
||||
|
||||
|
@ -8,7 +8,8 @@ import cloudpickle
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import ObservabilityConfig, VllmConfig
|
||||
from vllm.config import (ObservabilityConfig, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -498,8 +499,11 @@ class WorkerWrapperBase:
|
||||
group.
|
||||
"""
|
||||
self.rpc_rank = rpc_rank
|
||||
self.vllm_config = vllm_config
|
||||
self.worker: Optional[WorkerBase] = None
|
||||
# do not store this `vllm_config`, `init_worker` will set the final
|
||||
# one. TODO: investigate if we can remove this field in
|
||||
# `WorkerWrapperBase`, `init_cached_hf_modules` should be
|
||||
# unnecessary now.
|
||||
if vllm_config.model_config is not None:
|
||||
# it can be None in tests
|
||||
trust_remote_code = vllm_config.model_config.trust_remote_code
|
||||
@ -533,6 +537,9 @@ class WorkerWrapperBase:
|
||||
Arguments are passed to the worker class constructor.
|
||||
"""
|
||||
kwargs = all_kwargs[self.rpc_rank]
|
||||
self.vllm_config = kwargs.get("vllm_config", None)
|
||||
assert self.vllm_config is not None, (
|
||||
"vllm_config is required to initialize the worker")
|
||||
enable_trace_function_call_for_thread(self.vllm_config)
|
||||
|
||||
from vllm.plugins import load_general_plugins
|
||||
@ -546,8 +553,10 @@ class WorkerWrapperBase:
|
||||
bytes)
|
||||
worker_class = cloudpickle.loads(
|
||||
self.vllm_config.parallel_config.worker_cls)
|
||||
self.worker = worker_class(**kwargs)
|
||||
assert self.worker is not None
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
# To make vLLM config available during worker initialization
|
||||
self.worker = worker_class(**kwargs)
|
||||
assert self.worker is not None
|
||||
|
||||
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
|
||||
try:
|
||||
|
Loading…
x
Reference in New Issue
Block a user