[FlashInfer] Upgrade to 0.2.0 (#11194)
Signed-off-by: Bowen Wang <abmfy@icloud.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
commit 2bc3fbba0c
parent 3f1fc7425a
@@ -183,7 +183,16 @@ steps:
   - vllm/
   - tests/v1
   commands:
-  - VLLM_USE_V1=1 pytest -v -s v1
+  # split the test to avoid interference
+  - VLLM_USE_V1=1 pytest -v -s v1/core
+  - VLLM_USE_V1=1 pytest -v -s v1/engine
+  - VLLM_USE_V1=1 pytest -v -s v1/sample
+  - VLLM_USE_V1=1 pytest -v -s v1/worker
+  - VLLM_USE_V1=1 pytest -v -s v1/test_stats.py
+  - VLLM_USE_V1=1 pytest -v -s v1/test_utils.py
+  # TODO: accuracy does not match, whether setting
+  # VLLM_USE_FLASHINFER_SAMPLER or not on H100.
+  - VLLM_USE_V1=1 pytest -v -s v1/e2e
 
 - label: Examples Test # 25min
   working_dir: "/vllm-workspace/examples"
Dockerfile (23 lines changed)
@@ -149,7 +149,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \
 
 #################### vLLM installation IMAGE ####################
 # image with vLLM installed
-FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base
+# TODO: Restore to base image after FlashInfer AOT wheel fixed
+FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
 ARG CUDA_VERSION=12.4.1
 ARG PYTHON_VERSION=3.12
 WORKDIR /vllm-workspace
@@ -194,12 +195,30 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
     --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install dist/*.whl --verbose
 
+# How to build this FlashInfer wheel:
+# $ export FLASHINFER_ENABLE_AOT=1
+# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+
+# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX'
+# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
+# $ cd flashinfer
+# $ git checkout 524304395bd1d8cd7d07db083859523fcaa246a4
+# $ python3 setup.py bdist_wheel --dist-dir=dist --verbose
+
 RUN --mount=type=cache,target=/root/.cache/pip \
     . /etc/environment && \
     if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
-        python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
+        python3 -m pip install https://wheels.vllm.ai/flashinfer/524304395bd1d8cd7d07db083859523fcaa246a4/flashinfer_python-0.2.0.post1-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
     fi
 COPY examples examples
 
+# Although we build Flashinfer with AOT mode, there's still
+# some issues w.r.t. JIT compilation. Therefore we need to
+# install build dependencies for JIT compilation.
+# TODO: Remove this once FlashInfer AOT wheel is fixed
+COPY requirements-build.txt requirements-build.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    python3 -m pip install -r requirements-build.txt
+
 #################### vLLM installation IMAGE ####################
 
 #################### TEST IMAGE ####################
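Note: the Dockerfile hunk above pins a FlashInfer 0.2.0.post1 wheel built from commit 524304395bd1d8cd7d07db083859523fcaa246a4. The snippet below is an optional sanity check one could run inside the built image, not part of the diff; the distribution name "flashinfer-python" is read off the wheel filename above and may need adjusting.

import sys
from importlib.metadata import PackageNotFoundError, version

try:
    installed = version("flashinfer-python")  # name taken from the wheel filename above
except PackageNotFoundError:
    sys.exit("FlashInfer is not installed in this image")

# The image is expected to carry the pinned 0.2.0.post1 build.
assert installed.startswith("0.2.0"), f"unexpected FlashInfer version: {installed}"
print("FlashInfer", installed)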
@@ -61,9 +61,10 @@ def test_models(
     if backend == "FLASHINFER" and current_platform.is_rocm():
         pytest.skip("Flashinfer does not support ROCm/HIP.")
 
-    if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
+    if backend in ("XFORMERS",
+                   "FLASHINFER") and model == "google/gemma-2-2b-it":
         pytest.skip(
-            "XFORMERS does not support gemma2 with full context length.")
+            f"{backend} does not support gemma2 with full context length.")
 
     os.environ["VLLM_ATTENTION_BACKEND"] = backend
 
@@ -58,7 +58,7 @@ test_settings = [
         model_args=["--task", "embed"],
         pp_size=1,
         tp_size=1,
-        attn_backend="FLASHINFER",
+        attn_backend="FLASH_ATTN",
         method="encode",
         fullgraph=True,
     ),
@@ -133,7 +133,7 @@ def test_flashinfer_decode_with_paged_kv(
         use_tensor_cores=(
             (num_query_heads//num_kv_heads) > 4)
     )
-    wrapper.begin_forward(kv_indptr,
+    wrapper.plan(kv_indptr,
                  kv_indices,
                  kv_last_page_lens,
                  num_query_heads,
@@ -141,9 +141,11 @@ def test_flashinfer_decode_with_paged_kv(
                  head_size,
                  block_size,
                  "NONE",
-                 data_type=dtype)
+                 q_data_type=dtype,
+                 kv_data_type=dtype,
+                 logits_soft_cap=soft_cap)
 
-    output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap)
+    output = wrapper.run(query, key_value_cache)
 
     ref_output = ref_paged_attn(query=query,
                                 key_cache=key_cache,
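The hunks above migrate the decode test from FlashInfer 0.1's begin_forward()/forward() API to 0.2's plan()/run(). The standalone sketch below (not part of the diff) shows the same call pattern end to end; the tensor shapes follow the paged-KV "NHD" layout used in the test, the concrete sizes are made up for illustration, and it assumes a CUDA GPU with flashinfer >= 0.2 installed.

import torch
import flashinfer

num_query_heads, num_kv_heads, head_size, block_size = 16, 4, 128, 16
num_blocks, seq_len, dtype = 8, 100, torch.float16

# The workspace lives on the GPU; index tensors may stay on the host in 0.2.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer, "NHD",
    use_tensor_cores=(num_query_heads // num_kv_heads) > 4)

# One sequence spread over ceil(seq_len / block_size) pages.
num_pages = (seq_len + block_size - 1) // block_size
kv_indptr = torch.tensor([0, num_pages], dtype=torch.int32)
kv_indices = torch.arange(num_pages, dtype=torch.int32)
kv_last_page_lens = torch.tensor([seq_len - (num_pages - 1) * block_size],
                                 dtype=torch.int32)

# In 0.2 all planning arguments, including dtypes and the soft cap, go to plan().
wrapper.plan(kv_indptr, kv_indices, kv_last_page_lens,
             num_query_heads, num_kv_heads, head_size, block_size,
             "NONE",  # pos_encoding_mode: vLLM applies RoPE outside the kernel
             q_data_type=dtype, kv_data_type=dtype, logits_soft_cap=30.0)

query = torch.randn(1, num_query_heads, head_size, dtype=dtype, device="cuda")
kv_cache = torch.randn(num_blocks, 2, block_size, num_kv_heads, head_size,
                       dtype=dtype, device="cuda")
out = wrapper.run(query, kv_cache)  # per-call soft cap is gone; plan() owns it
print(out.shape)  # (batch, num_query_heads, head_size)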
@@ -228,7 +230,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
         workspace_buffer, "NHD")
-    wrapper.begin_forward(
+    wrapper.plan(
         qo_indptr,
         kv_indptr,
         kv_indices,
@@ -237,12 +239,14 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
         num_kv_heads,
         head_size,
         block_size,
+        q_data_type=dtype,
+        kv_data_type=dtype,
+        logits_soft_cap=soft_cap,
     )
 
-    output = wrapper.forward(
+    output = wrapper.run(
         query,
         key_value_cache,
-        logits_soft_cap=soft_cap,
     )
 
     ref_output = ref_paged_attn(query=query,
@@ -253,7 +257,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
                                 block_tables=block_tables,
                                 scale=scale,
                                 soft_cap=soft_cap)
-    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
+    torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
 
 
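For prefill the migration has the same shape: plan() now takes the query/KV dtypes and logits_soft_cap (the test above drops them from the forward call), and run() receives only the ragged query plus the paged KV cache. A compressed sketch reusing the assumptions and made-up sizes of the decode example above; the positional argument order mirrors the prefill_wrapper.plan() call in the backend hunk further down.

import torch
import flashinfer

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
prefill = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")

# Two sequences of 7 and 5 query tokens, packed into one ragged batch.
qo_indptr = torch.tensor([0, 7, 12], dtype=torch.int32)
kv_indptr = torch.tensor([0, 1, 2], dtype=torch.int32)  # one KV page per sequence
kv_indices = torch.tensor([0, 1], dtype=torch.int32)
kv_last_page_len = torch.tensor([7, 5], dtype=torch.int32)

prefill.plan(qo_indptr, kv_indptr, kv_indices, kv_last_page_len,
             16, 4, 128, 16,  # num_qo_heads, num_kv_heads, head_size, block_size
             causal=True,      # the backend hunk below plans causal prefill the same way
             q_data_type=torch.float16, kv_data_type=torch.float16,
             logits_soft_cap=30.0)

query = torch.randn(12, 16, 128, dtype=torch.float16, device="cuda")
kv_cache = torch.randn(2, 2, 16, 4, 128, dtype=torch.float16, device="cuda")
print(prefill.run(query, kv_cache).shape)  # (total_query_tokens, num_qo_heads, head_size)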
@@ -332,7 +336,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
         workspace_buffer, "NHD")
-    wrapper.begin_forward(
+    wrapper.plan(
         qo_indptr,
         kv_indptr,
         kv_indices,
@@ -341,13 +345,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
         num_kv_heads,
         head_size,
         block_size,
+        q_data_type=dtype,
+        kv_data_type=kv_cache_dtype,
+        logits_soft_cap=soft_cap,
     )
 
-    output = wrapper.forward(query,
-                             kv_cache_fp8,
-                             logits_soft_cap=soft_cap,
-                             k_scale=k_scale,
-                             v_scale=v_scale)
+    output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
 
     ref_output = ref_paged_attn(query=query,
                                 key_cache=key_cache.squeeze(1),
@@ -360,7 +363,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
     del query
     del block_tables
     # verify prefill fp8
-    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
+    torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
 
 
@@ -439,7 +442,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
     wrapper = flashinfer.\
         BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
                                            use_tensor_cores=use_tensor_cores)
-    wrapper.begin_forward(kv_indptr,
+    wrapper.plan(kv_indptr,
                  kv_indices,
                  kv_last_page_lens,
                  num_query_heads,
@@ -447,13 +450,10 @@ def test_flashinfer_decode_with_paged_fp8_kv(
                  head_size,
                  block_size,
                  "NONE",
-                 data_type=dtype,
-                 q_data_type=dtype)
-    output = wrapper.forward(query,
-                             kv_cache_fp8,
-                             logits_soft_cap=soft_cap,
-                             k_scale=k_scale,
-                             v_scale=v_scale)
+                 q_data_type=dtype,
+                 kv_data_type=kv_cache_dtype,
+                 logits_soft_cap=soft_cap)
+    output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
     key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
     value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
 
@@ -1,3 +1,4 @@
+import dataclasses
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -13,6 +14,8 @@ try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
     FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
 except ImportError:
+    # Avoid turning these types into variables during type checking
+    if not TYPE_CHECKING:
         BatchDecodeWithPagedKVCacheWrapper = None
         CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
         BatchPrefillWithPagedKVCacheWrapper = None
@@ -30,7 +33,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
+from vllm.attention.layer import Attention
 from vllm.attention.ops.paged_attn import PagedAttention
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
 
@@ -99,6 +104,72 @@ class FlashInferBackend(AttentionBackend):
             raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
 
 
+@dataclass
+class PerLayerParameters:
+    """
+    Currently, FlashInfer backend only support models in which all layers share
+    the same values for the following hyperparameters.
+    """
+
+    window_left: int
+    logits_soft_cap: Optional[float]
+    sm_scale: float
+
+
+def get_per_layer_parameters(
+        vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]:
+    """
+    Scan all attention layers and determine some hyperparameters
+    to use during `plan`.
+    """
+
+    layers = vllm_config.compilation_config.static_forward_context
+    per_layer_params: Dict[str, PerLayerParameters] = {}
+
+    for key, layer in layers.items():
+        assert isinstance(layer, Attention)
+
+        impl = layer.impl
+        assert isinstance(impl, FlashInferImpl)
+
+        # Infer hyperparameters from the attention layer
+        window_size = impl.sliding_window
+        window_left = window_size[0] if window_size is not None else -1
+        logits_soft_cap = impl.logits_soft_cap
+        sm_scale = impl.scale
+
+        per_layer_params[key] = PerLayerParameters(window_left,
+                                                   logits_soft_cap, sm_scale)
+
+    return per_layer_params
+
+
+def infer_global_hyperparameters(
+        per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters:
+    """
+    Currently, FlashInfer backend only support models in which all layers share
+    the same values for the following hyperparameters:
+    - `window_left`
+    - `logits_soft_cap`
+    - `sm_scale`
+
+    So this function asserts that all layers share the same values for these
+    hyperparameters and returns the global values.
+    """
+
+    assert len(per_layer_params) > 0, "No attention layers found in the model."
+
+    param_sets = list(per_layer_params.values())
+    global_params = param_sets[0]
+    for params in param_sets:
+        assert params == global_params, (
+            "FlashInfer backend currently only supports models in which all "
+            "layers share the same values for the following hyperparameters: "
+            "`window_left`, `logits_soft_cap`, `sm_scale`.")
+
+    return global_params
+
+
 class FlashInferState(AttentionState):
 
     def __init__(self, runner):
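A hedged usage sketch of the helpers added in this hunk (the import path vllm.attention.backends.flashinfer is assumed from the hunk's content, not stated in the diff): every attention layer must report the same window_left, logits_soft_cap and sm_scale, otherwise the assertion rejects the model.

from vllm.attention.backends.flashinfer import (PerLayerParameters,
                                                infer_global_hyperparameters)

uniform = {
    "layers.0.attn": PerLayerParameters(window_left=-1,
                                        logits_soft_cap=30.0,
                                        sm_scale=0.125),
    "layers.1.attn": PerLayerParameters(window_left=-1,
                                        logits_soft_cap=30.0,
                                        sm_scale=0.125),
}
print(infer_global_hyperparameters(uniform))  # shared values are returned as-is

mixed = dict(uniform)
mixed["layers.2.attn"] = PerLayerParameters(window_left=4096,
                                            logits_soft_cap=None,
                                            sm_scale=0.125)
try:
    infer_global_hyperparameters(mixed)
except AssertionError as err:
    print("rejected:", err)  # mixed sliding-window / soft-cap layers are unsupported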
@@ -108,6 +179,11 @@ class FlashInferState(AttentionState):
         self._decode_wrapper = None
         self._prefill_wrapper = None
 
+        # Global hyperparameters shared by all attention layers
+        self.global_hyperparameters: Optional[PerLayerParameters] = None
+
+        self.vllm_config = get_current_vllm_config()
+
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
             self._workspace_buffer = torch.empty(
@@ -215,6 +291,9 @@ class FlashInferState(AttentionState):
             batch_size + 1,
             dtype=torch.int32)
 
+        global_params = infer_global_hyperparameters(
+            get_per_layer_parameters(self.vllm_config))
+
         attn_metadata = self.runner.attn_backend.make_metadata(
             num_prefills=0,
             slot_mapping=self._graph_slot_mapping[:batch_size],
@@ -238,7 +317,9 @@ class FlashInferState(AttentionState):
             q_data_type=self.runner.model_config.dtype,
             use_cuda_graph=True,
             decode_wrapper=self._graph_decode_wrapper,
-            prefill_wrapper=None)
+            prefill_wrapper=None,
+            **dataclasses.asdict(global_params),
+        )
         attn_metadata.begin_forward()
         return attn_metadata
 
@@ -325,9 +406,28 @@ class FlashInferMetadata(AttentionMetadata):
     data_type: torch.dtype = None
     # The data type of the query
     q_data_type: torch.dtype = None
-    device: torch.device = torch.device("cuda")
+    # FlashInfer 0.2 encourages passing host tensors
+    device: torch.device = torch.device("cpu")
     is_profile_run: bool = False
 
+    # The FlashInfer backend currently supports only models in which all layers
+    # share the same following hyperparameters:
+
+    # The left (inclusive) window size for the attention window, when
+    # set to `-1`, the window size will be set to the full length of
+    # the sequence. Defaults to `-1`.
+    window_left: int = -1
+    # The attention logits soft capping value (used in Gemini, Grok and
+    # Gemma-2, etc.), if not provided, will be set to `0`. If greater
+    # than 0, the logits will be capped according to formula:
+    # $$\texttt{logits\_soft\_cap} \times
+    # \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$,
+    # where $x$ is the input logits.
+    logits_soft_cap: Optional[float] = None
+    # The scale used in softmax, if not provided, will be set to
+    # `1.0 / sqrt(head_dim)`.
+    sm_scale: Optional[float] = None
+
     def __post_init__(self):
         # Refer to
         # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
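The new field comments above describe the soft cap as logits_soft_cap * tanh(x / logits_soft_cap). A tiny illustrative sketch of that formula (not code from the diff), showing how it bounds the attention logits:

import torch

def soft_cap(logits: torch.Tensor, logits_soft_cap: float) -> torch.Tensor:
    # Squashes raw attention logits into (-logits_soft_cap, +logits_soft_cap),
    # as used by Gemma-2 / Grok style models.
    return logits_soft_cap * torch.tanh(logits / logits_soft_cap)

x = torch.tensor([0.0, 10.0, 100.0, 1000.0])
print(soft_cap(x, 30.0))  # large logits saturate toward +/-30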
@@ -363,14 +463,21 @@ class FlashInferMetadata(AttentionMetadata):
                 self.block_table_bound = self.block_table_bound.to(self.device)
                 self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
                 self.paged_kv_indices = self.paged_kv_indices.to(self.device)
-            self.prefill_wrapper.end_forward()
-            self.prefill_wrapper.begin_forward(
+            self.prefill_wrapper.plan(
                 self.query_start_loc,
                 self.paged_kv_indptr[:self.num_prefills + 1],
                 self.paged_kv_indices,
                 self.paged_kv_last_page_len[:self.num_prefills],
-                self.num_qo_heads, self.num_kv_heads, self.head_dim,
-                self.page_size)
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                self.page_size,
+                causal=True,
+                sm_scale=self.sm_scale,
+                window_left=self.window_left,
+                logits_soft_cap=self.logits_soft_cap,
+                q_data_type=self.q_data_type,
+                kv_data_type=self.data_type)
         if self.num_decode_tokens > 0:
             assert self.paged_kv_indices is not None
             assert self.paged_kv_indptr is not None
@@ -386,8 +493,7 @@ class FlashInferMetadata(AttentionMetadata):
                 self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
 
             assert self.decode_wrapper is not None
-            self.decode_wrapper.end_forward()
-            self.decode_wrapper.begin_forward(
+            self.decode_wrapper.plan(
                 self.paged_kv_indptr[self.num_prefills:],
                 self.paged_kv_indices,
                 self.paged_kv_last_page_len[self.num_prefills:],
@@ -397,8 +503,11 @@ class FlashInferMetadata(AttentionMetadata):
                 self.page_size,
                 # Disable flashinfer's pos encoding and use vllm's rope.
                 pos_encoding_mode="NONE",
+                window_left=self.window_left,
+                logits_soft_cap=self.logits_soft_cap,
+                sm_scale=self.sm_scale,
                 # kv-cache data type.
-                data_type=self.data_type,
+                kv_data_type=self.data_type,
                 # query data type.
                 q_data_type=self.q_data_type)
 
@@ -496,6 +605,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.sliding_window = input_builder.sliding_window
         self.block_size = input_builder.block_size
 
+        # Global hyperparameters shared by all attention layers
+        self.global_hyperparameters: Optional[PerLayerParameters] = None
+
+        self.vllm_config = get_current_vllm_config()
+
     def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
@@ -528,6 +642,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.total_blocks = 0
         self.is_profile_run: bool = False
 
+        if self.global_hyperparameters is None:
+            # Infer global hyperparameters, since currently we only support
+            # models in which all layers share the same values for the
+            # following hyperparameters:
+            # - `window_left`
+            # - `logits_soft_cap`
+            # - `sm_scale`
+            inferred_params = infer_global_hyperparameters(
+                get_per_layer_parameters(self.vllm_config))
+            self.global_hyperparameters = inferred_params
+            self.window_left = inferred_params.window_left
+            self.logits_soft_cap = inferred_params.logits_soft_cap
+            self.sm_scale = inferred_params.sm_scale
+
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool):
@@ -756,7 +884,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             data_type=kv_cache_dtype,
             q_data_type=self.runner.model_config.dtype,
             use_cuda_graph=use_captured_graph,
-            is_profile_run=self.is_profile_run)
+            is_profile_run=self.is_profile_run,
+            window_left=self.window_left,
+            logits_soft_cap=self.logits_soft_cap,
+            sm_scale=self.sm_scale,
+        )
 
 
 class FlashInferImpl(AttentionImpl):
@@ -885,25 +1017,34 @@ class FlashInferImpl(AttentionImpl):
         else:
             assert prefill_meta is not None
             assert prefill_meta.prefill_wrapper is not None
-            prefill_output = prefill_meta.prefill_wrapper.forward(
+
+            assert prefill_meta.prefill_wrapper._causal
+            assert prefill_meta.prefill_wrapper._window_left == window_left
+            assert prefill_meta.prefill_wrapper._logits_soft_cap == (
+                logits_soft_cap or 0.0)
+            assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale
+
+            prefill_output = prefill_meta.prefill_wrapper.run(
                 query,
                 kv_cache,
-                logits_soft_cap=logits_soft_cap,
-                causal=True,
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
-                window_left=window_left)
+            )
         if decode_meta := attn_metadata.decode_metadata:
             assert decode_meta is not None
             assert decode_meta.decode_wrapper is not None
-            decode_output = decode_meta.decode_wrapper.forward(
+
+            assert decode_meta.decode_wrapper._window_left == window_left
+            assert decode_meta.decode_wrapper._logits_soft_cap == (
+                logits_soft_cap or 0.0)
+            assert decode_meta.decode_wrapper._sm_scale == softmax_scale
+
+            decode_output = decode_meta.decode_wrapper.run(
                 decode_query,
                 kv_cache,
-                sm_scale=softmax_scale,
-                logits_soft_cap=logits_soft_cap,
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
-                window_left=window_left)
+            )
 
         if prefill_output is None and decode_output is not None:
             # Decode only batch.
@@ -310,14 +310,15 @@ class ModelConfig:
             (self.hf_text_config.model_type in ["gemma2", "cohere2"]))
 
         if (not self.disable_sliding_window and has_interleaved_attention):
-            if envs.VLLM_ATTENTION_BACKEND == "XFORMERS":
+            if (backend :=
+                    envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
                 sliding_window_len_min = get_min_sliding_window(
                     self.hf_text_config.sliding_window)
 
                 logger.warning_once(
                     f"{self.hf_text_config.model_type} has interleaved "
                     "attention, which is currently not supported by the "
-                    "XFORMERS backend. Disabling sliding window and capping "
+                    f"{backend} backend. Disabling sliding window and capping "
                     "the max length to the sliding window size "
                     f"({sliding_window_len_min}).")
                 self.disable_sliding_window = True
@@ -3310,7 +3311,7 @@ _current_vllm_config: Optional[VllmConfig] = None
 
 
 @contextmanager
-def set_current_vllm_config(vllm_config: VllmConfig):
+def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
     """
     Temporarily set the current VLLM config.
     Used during model initialization.
@@ -3330,7 +3331,8 @@ def set_current_vllm_config(vllm_config: VllmConfig):
                      vllm_config.compilation_config.enabled_custom_ops)
         logger.debug("disabled custom ops: %s",
                      vllm_config.compilation_config.disabled_custom_ops)
-        if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
+        if check_compile and \
+            vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
             and compilation_counter.num_models_seen == num_models_seen:
             # If the model supports compilation,
             # compilation_counter.num_models_seen should be increased
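The check_compile flag above only controls when the compilation sanity check runs; the piece that matters for this upgrade is that set_current_vllm_config() is also what lets the FlashInfer backend call get_current_vllm_config() while building attention metadata. Below is a simplified stand-in (not vLLM's actual implementation) of that set/get pattern, for orientation only.

from contextlib import contextmanager

_current_vllm_config = None  # stand-in for the module-level global in vllm.config

@contextmanager
def set_current_vllm_config(vllm_config, check_compile=False):
    # Publish the config for the duration of model/worker construction.
    global _current_vllm_config
    old = _current_vllm_config
    _current_vllm_config = vllm_config
    try:
        yield
    finally:
        _current_vllm_config = old

def get_current_vllm_config():
    # Read back the config published by the enclosing context manager.
    assert _current_vllm_config is not None, "no vLLM config has been set"
    return _current_vllm_config

# The worker_base.py hunk below uses the same pattern:
#   with set_current_vllm_config(self.vllm_config):
#       self.worker = worker_class(**kwargs)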
@@ -114,7 +114,7 @@ def _initialize_model(
     all_params = [param.name for param in signatures.parameters.values()]
     if "vllm_config" in all_params and "prefix" in all_params:
         # new-style model class
-        with set_current_vllm_config(vllm_config):
+        with set_current_vllm_config(vllm_config, check_compile=True):
             return model_class(vllm_config=vllm_config, prefix=prefix)
 
     msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
@@ -142,7 +142,7 @@ def _initialize_model(
         kwargs["lora_config"] = vllm_config.lora_config
     if "scheduler_config" in all_params:
         kwargs["scheduler_config"] = vllm_config.scheduler_config
-    with set_current_vllm_config(vllm_config):
+    with set_current_vllm_config(vllm_config, check_compile=True):
         return model_class(**kwargs)
 
 
@@ -288,7 +288,8 @@ class TensorizerAgent:
         model_args.torch_dtype = self.tensorizer_config.dtype
         assert self.tensorizer_config.model_class is not None
         # TODO: Do we need to consider old-style model class?
-        with no_init_or_tensor(), set_current_vllm_config(self.vllm_config):
+        with no_init_or_tensor(), set_current_vllm_config(self.vllm_config,
+                                                          check_compile=True):
             return self.tensorizer_config.model_class(
                 vllm_config=self.vllm_config, )
 
@@ -8,7 +8,8 @@ import cloudpickle
 import torch
 import torch.nn as nn
 
-from vllm.config import ObservabilityConfig, VllmConfig
+from vllm.config import (ObservabilityConfig, VllmConfig,
+                         set_current_vllm_config)
 from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -498,8 +499,11 @@ class WorkerWrapperBase:
         group.
         """
         self.rpc_rank = rpc_rank
-        self.vllm_config = vllm_config
         self.worker: Optional[WorkerBase] = None
+        # do not store this `vllm_config`, `init_worker` will set the final
+        # one. TODO: investigate if we can remove this field in
+        # `WorkerWrapperBase`, `init_cached_hf_modules` should be
+        # unnecessary now.
         if vllm_config.model_config is not None:
             # it can be None in tests
             trust_remote_code = vllm_config.model_config.trust_remote_code
@@ -533,6 +537,9 @@ class WorkerWrapperBase:
         Arguments are passed to the worker class constructor.
         """
         kwargs = all_kwargs[self.rpc_rank]
+        self.vllm_config = kwargs.get("vllm_config", None)
+        assert self.vllm_config is not None, (
+            "vllm_config is required to initialize the worker")
         enable_trace_function_call_for_thread(self.vllm_config)
 
         from vllm.plugins import load_general_plugins
@@ -546,6 +553,8 @@ class WorkerWrapperBase:
             bytes)
         worker_class = cloudpickle.loads(
             self.vllm_config.parallel_config.worker_cls)
+        with set_current_vllm_config(self.vllm_config):
+            # To make vLLM config available during worker initialization
             self.worker = worker_class(**kwargs)
         assert self.worker is not None
 