[FlashInfer] Upgrade to 0.2.0 (#11194)

Signed-off-by: Bowen Wang <abmfy@icloud.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
Bowen Wang 2025-01-28 02:19:24 +08:00 committed by GitHub
parent 3f1fc7425a
commit 2bc3fbba0c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 256 additions and 74 deletions

View File

@ -183,7 +183,16 @@ steps:
- vllm/
- tests/v1
commands:
- VLLM_USE_V1=1 pytest -v -s v1
# split the test to avoid interference
- VLLM_USE_V1=1 pytest -v -s v1/core
- VLLM_USE_V1=1 pytest -v -s v1/engine
- VLLM_USE_V1=1 pytest -v -s v1/sample
- VLLM_USE_V1=1 pytest -v -s v1/worker
- VLLM_USE_V1=1 pytest -v -s v1/test_stats.py
- VLLM_USE_V1=1 pytest -v -s v1/test_utils.py
# TODO: accuracy does not match, whether setting
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
- VLLM_USE_V1=1 pytest -v -s v1/e2e
- label: Examples Test # 25min
working_dir: "/vllm-workspace/examples"

View File

@ -149,7 +149,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \
#################### vLLM installation IMAGE ####################
# image with vLLM installed
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base
# TODO: Restore to base image after FlashInfer AOT wheel fixed
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
ARG CUDA_VERSION=12.4.1
ARG PYTHON_VERSION=3.12
WORKDIR /vllm-workspace
@ -194,12 +195,30 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install dist/*.whl --verbose
# How to build this FlashInfer wheel:
# $ export FLASHINFER_ENABLE_AOT=1
# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+
# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX'
# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
# $ cd flashinfer
# $ git checkout 524304395bd1d8cd7d07db083859523fcaa246a4
# $ python3 setup.py bdist_wheel --dist-dir=dist --verbose
RUN --mount=type=cache,target=/root/.cache/pip \
. /etc/environment && \
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
python3 -m pip install https://wheels.vllm.ai/flashinfer/524304395bd1d8cd7d07db083859523fcaa246a4/flashinfer_python-0.2.0.post1-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
fi
COPY examples examples
# Although we build Flashinfer with AOT mode, there's still
# some issues w.r.t. JIT compilation. Therefore we need to
# install build dependencies for JIT compilation.
# TODO: Remove this once FlashInfer AOT wheel is fixed
COPY requirements-build.txt requirements-build.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-build.txt
#################### vLLM installation IMAGE ####################
#################### TEST IMAGE ####################

View File

@ -61,9 +61,10 @@ def test_models(
if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.")
if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
if backend in ("XFORMERS",
"FLASHINFER") and model == "google/gemma-2-2b-it":
pytest.skip(
"XFORMERS does not support gemma2 with full context length.")
f"{backend} does not support gemma2 with full context length.")
os.environ["VLLM_ATTENTION_BACKEND"] = backend

View File

@ -58,7 +58,7 @@ test_settings = [
model_args=["--task", "embed"],
pp_size=1,
tp_size=1,
attn_backend="FLASHINFER",
attn_backend="FLASH_ATTN",
method="encode",
fullgraph=True,
),

View File

@ -133,17 +133,19 @@ def test_flashinfer_decode_with_paged_kv(
use_tensor_cores=(
(num_query_heads//num_kv_heads) > 4)
)
wrapper.begin_forward(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
data_type=dtype)
wrapper.plan(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
q_data_type=dtype,
kv_data_type=dtype,
logits_soft_cap=soft_cap)
output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap)
output = wrapper.run(query, key_value_cache)
ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
@ -228,7 +230,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD")
wrapper.begin_forward(
wrapper.plan(
qo_indptr,
kv_indptr,
kv_indices,
@ -237,12 +239,14 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
num_kv_heads,
head_size,
block_size,
q_data_type=dtype,
kv_data_type=dtype,
logits_soft_cap=soft_cap,
)
output = wrapper.forward(
output = wrapper.run(
query,
key_value_cache,
logits_soft_cap=soft_cap,
)
ref_output = ref_paged_attn(query=query,
@ -253,7 +257,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
@ -332,7 +336,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD")
wrapper.begin_forward(
wrapper.plan(
qo_indptr,
kv_indptr,
kv_indices,
@ -341,13 +345,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
num_kv_heads,
head_size,
block_size,
q_data_type=dtype,
kv_data_type=kv_cache_dtype,
logits_soft_cap=soft_cap,
)
output = wrapper.forward(query,
kv_cache_fp8,
logits_soft_cap=soft_cap,
k_scale=k_scale,
v_scale=v_scale)
output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
ref_output = ref_paged_attn(query=query,
key_cache=key_cache.squeeze(1),
@ -360,7 +363,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
del query
del block_tables
# verify prefill fp8
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
@ -439,21 +442,18 @@ def test_flashinfer_decode_with_paged_fp8_kv(
wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
use_tensor_cores=use_tensor_cores)
wrapper.begin_forward(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
data_type=dtype,
q_data_type=dtype)
output = wrapper.forward(query,
kv_cache_fp8,
logits_soft_cap=soft_cap,
k_scale=k_scale,
v_scale=v_scale)
wrapper.plan(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
q_data_type=dtype,
kv_data_type=kv_cache_dtype,
logits_soft_cap=soft_cap)
output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)

View File

@ -1,3 +1,4 @@
import dataclasses
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
@ -13,9 +14,11 @@ try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
# Avoid turning these types into variables during type checking
if not TYPE_CHECKING:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
import torch
@ -30,7 +33,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.layer import Attention
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)
@ -99,6 +104,72 @@ class FlashInferBackend(AttentionBackend):
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@dataclass
class PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters.
"""
window_left: int
logits_soft_cap: Optional[float]
sm_scale: float
def get_per_layer_parameters(
vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]:
"""
Scan all attention layers and determine some hyperparameters
to use during `plan`.
"""
layers = vllm_config.compilation_config.static_forward_context
per_layer_params: Dict[str, PerLayerParameters] = {}
for key, layer in layers.items():
assert isinstance(layer, Attention)
impl = layer.impl
assert isinstance(impl, FlashInferImpl)
# Infer hyperparameters from the attention layer
window_size = impl.sliding_window
window_left = window_size[0] if window_size is not None else -1
logits_soft_cap = impl.logits_soft_cap
sm_scale = impl.scale
per_layer_params[key] = PerLayerParameters(window_left,
logits_soft_cap, sm_scale)
return per_layer_params
def infer_global_hyperparameters(
per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
- `sm_scale`
So this function asserts that all layers share the same values for these
hyperparameters and returns the global values.
"""
assert len(per_layer_params) > 0, "No attention layers found in the model."
param_sets = list(per_layer_params.values())
global_params = param_sets[0]
for params in param_sets:
assert params == global_params, (
"FlashInfer backend currently only supports models in which all "
"layers share the same values for the following hyperparameters: "
"`window_left`, `logits_soft_cap`, `sm_scale`.")
return global_params
class FlashInferState(AttentionState):
def __init__(self, runner):
@ -108,6 +179,11 @@ class FlashInferState(AttentionState):
self._decode_wrapper = None
self._prefill_wrapper = None
# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = get_current_vllm_config()
def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.empty(
@ -215,6 +291,9 @@ class FlashInferState(AttentionState):
batch_size + 1,
dtype=torch.int32)
global_params = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config))
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
slot_mapping=self._graph_slot_mapping[:batch_size],
@ -238,7 +317,9 @@ class FlashInferState(AttentionState):
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=True,
decode_wrapper=self._graph_decode_wrapper,
prefill_wrapper=None)
prefill_wrapper=None,
**dataclasses.asdict(global_params),
)
attn_metadata.begin_forward()
return attn_metadata
@ -325,9 +406,28 @@ class FlashInferMetadata(AttentionMetadata):
data_type: torch.dtype = None
# The data type of the query
q_data_type: torch.dtype = None
device: torch.device = torch.device("cuda")
# FlashInfer 0.2 encourages passing host tensors
device: torch.device = torch.device("cpu")
is_profile_run: bool = False
# The FlashInfer backend currently supports only models in which all layers
# share the same following hyperparameters:
# The left (inclusive) window size for the attention window, when
# set to `-1`, the window size will be set to the full length of
# the sequence. Defaults to `-1`.
window_left: int = -1
# The attention logits soft capping value (used in Gemini, Grok and
# Gemma-2, etc.), if not provided, will be set to `0`. If greater
# than 0, the logits will be capped according to formula:
# $$\texttt{logits\_soft\_cap} \times
# \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$,
# where $x$ is the input logits.
logits_soft_cap: Optional[float] = None
# The scale used in softmax, if not provided, will be set to
# `1.0 / sqrt(head_dim)`.
sm_scale: Optional[float] = None
def __post_init__(self):
# Refer to
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
@ -363,14 +463,21 @@ class FlashInferMetadata(AttentionMetadata):
self.block_table_bound = self.block_table_bound.to(self.device)
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward(
self.prefill_wrapper.plan(
self.query_start_loc,
self.paged_kv_indptr[:self.num_prefills + 1],
self.paged_kv_indices,
self.paged_kv_last_page_len[:self.num_prefills],
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.data_type)
if self.num_decode_tokens > 0:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
@ -386,8 +493,7 @@ class FlashInferMetadata(AttentionMetadata):
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
assert self.decode_wrapper is not None
self.decode_wrapper.end_forward()
self.decode_wrapper.begin_forward(
self.decode_wrapper.plan(
self.paged_kv_indptr[self.num_prefills:],
self.paged_kv_indices,
self.paged_kv_last_page_len[self.num_prefills:],
@ -397,8 +503,11 @@ class FlashInferMetadata(AttentionMetadata):
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
sm_scale=self.sm_scale,
# kv-cache data type.
data_type=self.data_type,
kv_data_type=self.data_type,
# query data type.
q_data_type=self.q_data_type)
@ -496,6 +605,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = get_current_vllm_config()
def prepare(self):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
@ -528,6 +642,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.total_blocks = 0
self.is_profile_run: bool = False
if self.global_hyperparameters is None:
# Infer global hyperparameters, since currently we only support
# models in which all layers share the same values for the
# following hyperparameters:
# - `window_left`
# - `logits_soft_cap`
# - `sm_scale`
inferred_params = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config))
self.global_hyperparameters = inferred_params
self.window_left = inferred_params.window_left
self.logits_soft_cap = inferred_params.logits_soft_cap
self.sm_scale = inferred_params.sm_scale
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
@ -756,7 +884,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
data_type=kv_cache_dtype,
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=use_captured_graph,
is_profile_run=self.is_profile_run)
is_profile_run=self.is_profile_run,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
sm_scale=self.sm_scale,
)
class FlashInferImpl(AttentionImpl):
@ -885,25 +1017,34 @@ class FlashInferImpl(AttentionImpl):
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
prefill_output = prefill_meta.prefill_wrapper.forward(
assert prefill_meta.prefill_wrapper._causal
assert prefill_meta.prefill_wrapper._window_left == window_left
assert prefill_meta.prefill_wrapper._logits_soft_cap == (
logits_soft_cap or 0.0)
assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale
prefill_output = prefill_meta.prefill_wrapper.run(
query,
kv_cache,
logits_soft_cap=logits_soft_cap,
causal=True,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
window_left=window_left)
)
if decode_meta := attn_metadata.decode_metadata:
assert decode_meta is not None
assert decode_meta.decode_wrapper is not None
decode_output = decode_meta.decode_wrapper.forward(
assert decode_meta.decode_wrapper._window_left == window_left
assert decode_meta.decode_wrapper._logits_soft_cap == (
logits_soft_cap or 0.0)
assert decode_meta.decode_wrapper._sm_scale == softmax_scale
decode_output = decode_meta.decode_wrapper.run(
decode_query,
kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
window_left=window_left)
)
if prefill_output is None and decode_output is not None:
# Decode only batch.

View File

@ -310,14 +310,15 @@ class ModelConfig:
(self.hf_text_config.model_type in ["gemma2", "cohere2"]))
if (not self.disable_sliding_window and has_interleaved_attention):
if envs.VLLM_ATTENTION_BACKEND == "XFORMERS":
if (backend :=
envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
sliding_window_len_min = get_min_sliding_window(
self.hf_text_config.sliding_window)
logger.warning_once(
f"{self.hf_text_config.model_type} has interleaved "
"attention, which is currently not supported by the "
"XFORMERS backend. Disabling sliding window and capping "
f"{backend} backend. Disabling sliding window and capping "
"the max length to the sliding window size "
f"({sliding_window_len_min}).")
self.disable_sliding_window = True
@ -3310,7 +3311,7 @@ _current_vllm_config: Optional[VllmConfig] = None
@contextmanager
def set_current_vllm_config(vllm_config: VllmConfig):
def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
"""
Temporarily set the current VLLM config.
Used during model initialization.
@ -3330,7 +3331,8 @@ def set_current_vllm_config(vllm_config: VllmConfig):
vllm_config.compilation_config.enabled_custom_ops)
logger.debug("disabled custom ops: %s",
vllm_config.compilation_config.disabled_custom_ops)
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
if check_compile and \
vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
and compilation_counter.num_models_seen == num_models_seen:
# If the model supports compilation,
# compilation_counter.num_models_seen should be increased

View File

@ -114,7 +114,7 @@ def _initialize_model(
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class
with set_current_vllm_config(vllm_config):
with set_current_vllm_config(vllm_config, check_compile=True):
return model_class(vllm_config=vllm_config, prefix=prefix)
msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
@ -142,7 +142,7 @@ def _initialize_model(
kwargs["lora_config"] = vllm_config.lora_config
if "scheduler_config" in all_params:
kwargs["scheduler_config"] = vllm_config.scheduler_config
with set_current_vllm_config(vllm_config):
with set_current_vllm_config(vllm_config, check_compile=True):
return model_class(**kwargs)

View File

@ -288,7 +288,8 @@ class TensorizerAgent:
model_args.torch_dtype = self.tensorizer_config.dtype
assert self.tensorizer_config.model_class is not None
# TODO: Do we need to consider old-style model class?
with no_init_or_tensor(), set_current_vllm_config(self.vllm_config):
with no_init_or_tensor(), set_current_vllm_config(self.vllm_config,
check_compile=True):
return self.tensorizer_config.model_class(
vllm_config=self.vllm_config, )

View File

@ -8,7 +8,8 @@ import cloudpickle
import torch
import torch.nn as nn
from vllm.config import ObservabilityConfig, VllmConfig
from vllm.config import (ObservabilityConfig, VllmConfig,
set_current_vllm_config)
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -498,8 +499,11 @@ class WorkerWrapperBase:
group.
"""
self.rpc_rank = rpc_rank
self.vllm_config = vllm_config
self.worker: Optional[WorkerBase] = None
# do not store this `vllm_config`, `init_worker` will set the final
# one. TODO: investigate if we can remove this field in
# `WorkerWrapperBase`, `init_cached_hf_modules` should be
# unnecessary now.
if vllm_config.model_config is not None:
# it can be None in tests
trust_remote_code = vllm_config.model_config.trust_remote_code
@ -533,6 +537,9 @@ class WorkerWrapperBase:
Arguments are passed to the worker class constructor.
"""
kwargs = all_kwargs[self.rpc_rank]
self.vllm_config = kwargs.get("vllm_config", None)
assert self.vllm_config is not None, (
"vllm_config is required to initialize the worker")
enable_trace_function_call_for_thread(self.vllm_config)
from vllm.plugins import load_general_plugins
@ -546,8 +553,10 @@ class WorkerWrapperBase:
bytes)
worker_class = cloudpickle.loads(
self.vllm_config.parallel_config.worker_cls)
self.worker = worker_class(**kwargs)
assert self.worker is not None
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs)
assert self.worker is not None
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
try: