[Misc] Add multistep chunked-prefill support for FlashInfer (#10467)
Parent: b7ee940a82
Commit: 0794e7446e
@@ -95,6 +95,16 @@ __global__ void advance_step_flashinfer_kernel(
     long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
     int const* block_tables_ptr, int64_t const block_tables_stride,
     int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
+  int const n_pad = num_seqs - num_queries;
+  if (n_pad && blockIdx.x == 0) {
+    // Handle cuda graph padding
+    int const offset = num_queries;
+    for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
+      input_tokens_ptr[offset + i] = 0;
+      input_positions_ptr[offset + i] = 0;
+      slot_mapping_ptr[offset + i] = -1;
+    }
+  }
   int num_query_blocks = div_ceil(num_queries, num_threads);
 
   if (blockIdx.x < num_query_blocks) {
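For orientation, the kernel change above handles CUDA-graph padding: sequence slots past `num_queries` are padding, so their tokens and positions are zeroed and their `slot_mapping` entries are set to -1 so no real KV-cache slot is written. A minimal host-side sketch of the same bookkeeping, assuming 1-D tensors; the function and tensor names are illustrative, not the kernel's API:

import torch

def pad_for_cuda_graph(input_tokens: torch.Tensor,
                       input_positions: torch.Tensor,
                       slot_mapping: torch.Tensor,
                       num_queries: int,
                       num_seqs: int) -> None:
    """Zero out the CUDA-graph padding slots and point their slot_mapping
    at -1 ("write nowhere"), mirroring what the kernel does on device."""
    n_pad = num_seqs - num_queries
    if n_pad > 0:
        input_tokens[num_queries:num_seqs] = 0
        input_positions[num_queries:num_seqs] = 0
        slot_mapping[num_queries:num_seqs] = -1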
@@ -5,6 +5,8 @@ from typing import Optional
 
 import pytest
 
+from tests.kernels.utils import override_backend_env_variable
+
 from ..models.utils import check_logprobs_close, check_outputs_equal
 
 MODELS = [
@@ -19,10 +21,11 @@ NUM_PROMPTS = [10]
 @pytest.mark.parametrize("tp_size", [1])
 @pytest.mark.parametrize("enable_chunked_prefill", [False, True])
 @pytest.mark.parametrize("max_tokens", [5])
-@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("enforce_eager", [True, False])
 @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
 @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
 @pytest.mark.parametrize("num_logprobs", [None, 5])
+@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"])
 def test_multi_step_llm(
     hf_runner,
     vllm_runner,
@@ -36,6 +39,8 @@ def test_multi_step_llm(
     num_scheduler_steps: int,
     num_prompts: int,
     num_logprobs: Optional[int],
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
     """Test vLLM engine with multi-step scheduling via sync LLM Engine.
 
@@ -63,6 +68,7 @@ def test_multi_step_llm(
         num_logprobs: corresponds to the `logprobs` argument to the OpenAI
             completions endpoint; `None` -> 1 logprob returned.
     """
+    override_backend_env_variable(monkeypatch, attention_backend)
 
     prompts = example_prompts
     if len(prompts) < num_prompts:
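The tests are now parametrized over `attention_backend` and pin the backend per test case via `override_backend_env_variable(monkeypatch, attention_backend)`. A rough standalone sketch of that pattern, assuming the helper simply sets vLLM's backend-selection environment variable for the duration of one test (the body below and the exact variable name are assumptions, not taken from this diff):

import os

import pytest

# Hedged stand-in for tests.kernels.utils.override_backend_env_variable.
def override_backend_env_variable(monkeypatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    # monkeypatch restores the environment automatically after the test.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend_name)

@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"])
def test_backend_is_pinned(monkeypatch, attention_backend):
    override_backend_env_variable(monkeypatch, attention_backend)
    assert os.environ["VLLM_ATTENTION_BACKEND"] == attention_backend

Parametrizing the backend this way lets the same multi-step correctness tests exercise both FLASH_ATTN and FLASHINFER without duplicating test bodies.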
@@ -114,6 +120,7 @@ def test_multi_step_llm(
 @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
 @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
 @pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)])
+@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"])
 def test_multi_step_llm_w_prompt_logprobs(
     vllm_runner,
     example_prompts,
@@ -126,6 +133,8 @@ def test_multi_step_llm_w_prompt_logprobs(
     num_prompts: int,
     num_logprobs: Optional[int],
     num_prompt_logprobs: Optional[int],
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
     """Test prompt logprobs with multi-step scheduling via sync LLM Engine.
 
@@ -155,6 +164,7 @@ def test_multi_step_llm_w_prompt_logprobs(
             note that this argument is not supported by the
             OpenAI completions endpoint.
     """
+    override_backend_env_variable(monkeypatch, attention_backend)
 
     prompts = example_prompts
     if len(prompts) < num_prompts:
@@ -205,6 +215,7 @@ def test_multi_step_llm_w_prompt_logprobs(
 @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
 @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
 @pytest.mark.parametrize("num_logprobs", [None, 5])
+@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"])
 def test_multi_step_llm_chunked_prefill_prefix_cache(
     vllm_runner,
     example_prompts,
@@ -216,6 +227,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
     num_scheduler_steps: int,
     num_prompts: int,
     num_logprobs: Optional[int],
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
     """Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
 
@@ -278,6 +291,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
     #
     # The Incorrect scheduling behavior - if it occurs - will cause an exception
     # in the model runner resulting from `do_sample=False`.
+    override_backend_env_variable(monkeypatch, attention_backend)
+
     assert len(example_prompts) >= 2
     challenge_prompts = copy.deepcopy(example_prompts)
     challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient '
@@ -256,7 +256,12 @@ class FlashInferState(AttentionState):
     def begin_forward(self, model_input):
         assert not self._is_graph_capturing
         state = self
-        if model_input.attn_metadata.use_cuda_graph:
+        use_cuda_graph = model_input.attn_metadata.use_cuda_graph
+        is_decode = model_input.attn_metadata.num_prefills == 0
+        # In case of multistep chunked-prefill, there might be prefill requests
+        # scheduled while CUDA graph mode is enabled. We don't run graph in that
+        # case.
+        if use_cuda_graph and is_decode:
             batch_size = model_input.input_tokens.shape[0]
             state = (self.runner.graph_runners[model_input.virtual_engine]
                      [batch_size].attn_state)
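The `begin_forward` change gates the captured CUDA-graph attention state on the batch being decode-only: with multi-step chunked prefill, a batch can be flagged for CUDA graphs yet still carry prefill requests, and those must take the eager path. A simplified standalone sketch of that selection logic, with a toy metadata type and illustrative names:

from dataclasses import dataclass

@dataclass
class _Meta:
    use_cuda_graph: bool
    num_prefills: int

def pick_attn_state(default_state, graph_states, meta: _Meta, batch_size: int):
    """Return the CUDA-graph attention state only for decode-only batches;
    otherwise fall back to the eager-path state."""
    if meta.use_cuda_graph and meta.num_prefills == 0:
        # graph_states is assumed to map padded batch size -> captured state.
        return graph_states[batch_size]
    return default_state

# Example: a batch flagged for CUDA graphs but containing one prefill
# still runs eagerly.
graph_states = {8: "captured-state-bs8"}
meta = _Meta(use_cuda_graph=True, num_prefills=1)
assert pick_attn_state("eager-state", graph_states, meta, 8) == "eager-state"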
@@ -429,10 +434,24 @@ class FlashInferMetadata(AttentionMetadata):
         Update metadata in-place to advance one decode step.
         """
 
-        assert not turn_prefills_into_decodes, \
-            ("Chunked prefill is not supported with flashinfer yet."
-             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
-             "specific parameter.")
+        if turn_prefills_into_decodes:
+            # When Multi-Step is enabled with Chunked-Prefill, prefills and
+            # decodes are scheduled together. In the first step, all the
+            # prefills turn into decodes. This update reflects that
+            # conversion.
+            assert self.num_decode_tokens + self.num_prefills == num_seqs
+            # Flashinfer doesn't support speculative decoding + chunked-prefill
+            # + multi-step scheduling yet.
+            assert self.decode_query_len == 1
+            self.num_decode_tokens += self.num_prefills
+            self.num_prefills = 0
+            self.num_prefill_tokens = 0
+            self.max_prefill_seq_len = 0
+            self.max_query_len = 1
+
+            self.slot_mapping = self.slot_mapping[:num_seqs]
+        else:
+            assert self.seq_lens_tensor is not None
 
         assert num_seqs > 0
         assert num_queries > 0
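This hunk implements `turn_prefills_into_decodes` for FlashInfer: after the first multi-step iteration every scheduled prefill has produced its first token, so from then on it is accounted for as a decode. A toy restatement of that bookkeeping, using a stand-in dataclass rather than the real metadata class:

from dataclasses import dataclass

@dataclass
class StepMeta:
    """Toy stand-in for the metadata fields touched above."""
    num_prefills: int
    num_prefill_tokens: int
    num_decode_tokens: int
    max_prefill_seq_len: int
    max_query_len: int
    decode_query_len: int

def turn_prefills_into_decodes(meta: StepMeta, num_seqs: int) -> None:
    # Every sequence is either a prefill or a decode before the flip.
    assert meta.num_decode_tokens + meta.num_prefills == num_seqs
    # One query token per decode (no spec-decode + chunked-prefill + multi-step).
    assert meta.decode_query_len == 1
    meta.num_decode_tokens += meta.num_prefills
    meta.num_prefills = 0
    meta.num_prefill_tokens = 0
    meta.max_prefill_seq_len = 0
    meta.max_query_len = 1

# Example: 2 prefills and 6 decodes become 8 decodes.
meta = StepMeta(num_prefills=2, num_prefill_tokens=37, num_decode_tokens=6,
                max_prefill_seq_len=512, max_query_len=19, decode_query_len=1)
turn_prefills_into_decodes(meta, num_seqs=8)
assert meta.num_decode_tokens == 8 and meta.num_prefills == 0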
@@ -5,6 +5,7 @@ import itertools
 import time
 import warnings
 import weakref
+from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
                     Tuple, Type, TypeVar, Union)
@@ -1028,6 +1029,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
 
         self.has_inner_state = model_config.has_inner_state
 
+        self.in_profile_run = False
+
         # When using CUDA graph, the input block tables must be padded to
         # max_seq_len_to_capture. However, creating the block table in
         # Python can be expensive. To optimize this, we cache the block table
@@ -1228,110 +1231,123 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
 
         return builder.build()  # type: ignore
 
+    @contextmanager
+    def set_in_profile_run(self):
+        self.in_profile_run = True
+        try:
+            yield
+        finally:
+            self.in_profile_run = False
+
     @torch.inference_mode()
     def profile_run(self) -> None:
-        # Enable top-k sampling to reflect the accurate memory usage.
-        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
-        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
-        max_num_seqs = self.scheduler_config.max_num_seqs
-        # This represents the maximum number of different requests
-        # that will have unique loras, an therefore the max amount of memory
-        # consumption create dummy lora request copies from the lora request
-        # passed in, which contains a lora from the lora warmup path.
-        dummy_lora_requests: List[LoRARequest] = []
-        dummy_lora_requests_per_seq: List[LoRARequest] = []
-        if self.lora_config:
-            assert self.lora_manager is not None
-            with self.lora_manager.dummy_lora_cache():
-                for idx in range(self.lora_config.max_loras):
-                    lora_id = idx + 1
-                    dummy_lora_request = LoRARequest(
-                        lora_name=f"warmup_{lora_id}",
-                        lora_int_id=lora_id,
-                        lora_path="/not/a/real/path",
-                    )
-                    self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                     rank=LORA_WARMUP_RANK)
-                    dummy_lora_requests.append(dummy_lora_request)
-                dummy_lora_requests_per_seq = [
-                    dummy_lora_requests[idx % len(dummy_lora_requests)]
-                    for idx in range(max_num_seqs)
-                ]
+        with self.set_in_profile_run():
+            # Enable top-k sampling to reflect the accurate memory usage.
+            sampling_params = \
+                SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
+            max_num_batched_tokens = \
+                self.scheduler_config.max_num_batched_tokens
+            max_num_seqs = self.scheduler_config.max_num_seqs
+            # This represents the maximum number of different requests
+            # that will have unique loras, an therefore the max amount of memory
+            # consumption create dummy lora request copies from the lora request
+            # passed in, which contains a lora from the lora warmup path.
+            dummy_lora_requests: List[LoRARequest] = []
+            dummy_lora_requests_per_seq: List[LoRARequest] = []
+            if self.lora_config:
+                assert self.lora_manager is not None
+                with self.lora_manager.dummy_lora_cache():
+                    for idx in range(self.lora_config.max_loras):
+                        lora_id = idx + 1
+                        dummy_lora_request = LoRARequest(
+                            lora_name=f"warmup_{lora_id}",
+                            lora_int_id=lora_id,
+                            lora_path="/not/a/real/path",
+                        )
+                        self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                         rank=LORA_WARMUP_RANK)
+                        dummy_lora_requests.append(dummy_lora_request)
+                    dummy_lora_requests_per_seq = [
+                        dummy_lora_requests[idx % len(dummy_lora_requests)]
+                        for idx in range(max_num_seqs)
+                    ]
 
-        # Profile memory usage with max_num_sequences sequences and the total
-        # number of tokens equal to max_num_batched_tokens.
-        seqs: List[SequenceGroupMetadata] = []
-        # Additional GPU memory may be needed for multi-modal encoding, which
-        # needs to be accounted for when calculating the GPU blocks for
-        # vLLM blocker manager.
-        # To exercise the worst scenario for GPU memory consumption,
-        # the number of seqs (batch_size) is chosen to maximize the number
-        # of images processed.
+            # Profile memory usage with max_num_sequences sequences and the
+            # total number of tokens equal to max_num_batched_tokens.
+            seqs: List[SequenceGroupMetadata] = []
+            # Additional GPU memory may be needed for multi-modal encoding,
+            # which needs to be accounted for when calculating the GPU blocks
+            # for vLLM blocker manager.
+            # To exercise the worst scenario for GPU memory consumption,
+            # the number of seqs (batch_size) is chosen to maximize the number
+            # of images processed.
 
-        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
-            self.model_config)
-        if max_mm_tokens > 0:
-            max_num_seqs_orig = max_num_seqs
-            max_num_seqs = min(max_num_seqs,
-                               max_num_batched_tokens // max_mm_tokens)
-            if max_num_seqs < 1:
-                expr = (f"min({max_num_seqs_orig}, "
-                        f"{max_num_batched_tokens} // {max_mm_tokens})")
-                logger.warning(
-                    "Computed max_num_seqs (%s) to be less than 1. "
-                    "Setting it to the minimum value of 1.", expr)
-                max_num_seqs = 1
+            max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
+                self.model_config)
+            if max_mm_tokens > 0:
+                max_num_seqs_orig = max_num_seqs
+                max_num_seqs = min(max_num_seqs,
+                                   max_num_batched_tokens // max_mm_tokens)
+                if max_num_seqs < 1:
+                    expr = (f"min({max_num_seqs_orig}, "
+                            f"{max_num_batched_tokens} // {max_mm_tokens})")
+                    logger.warning(
+                        "Computed max_num_seqs (%s) to be less than 1. "
+                        "Setting it to the minimum value of 1.", expr)
+                    max_num_seqs = 1
 
-        batch_size = 0
-        for group_id in range(max_num_seqs):
-            seq_len = (max_num_batched_tokens // max_num_seqs +
-                       (group_id < max_num_batched_tokens % max_num_seqs))
-            batch_size += seq_len
+            batch_size = 0
+            for group_id in range(max_num_seqs):
+                seq_len = (max_num_batched_tokens // max_num_seqs +
+                           (group_id < max_num_batched_tokens % max_num_seqs))
+                batch_size += seq_len
 
-            dummy_data = self.input_registry \
-                .dummy_data_for_profiling(self.model_config,
-                                          seq_len,
-                                          self.mm_registry)
+                dummy_data = self.input_registry \
+                    .dummy_data_for_profiling(self.model_config,
+                                              seq_len,
+                                              self.mm_registry)
 
-            seq = SequenceGroupMetadata(
-                request_id=str(group_id),
-                is_prompt=True,
-                seq_data={group_id: dummy_data.seq_data},
-                sampling_params=sampling_params,
-                block_tables=None,
-                lora_request=dummy_lora_requests_per_seq[group_id]
-                if dummy_lora_requests_per_seq else None,
-                multi_modal_data=dummy_data.multi_modal_data,
-                multi_modal_placeholders=dummy_data.multi_modal_placeholders,
-            )
-            seqs.append(seq)
+                seq = SequenceGroupMetadata(
+                    request_id=str(group_id),
+                    is_prompt=True,
+                    seq_data={group_id: dummy_data.seq_data},
+                    sampling_params=sampling_params,
+                    block_tables=None,
+                    lora_request=dummy_lora_requests_per_seq[group_id]
+                    if dummy_lora_requests_per_seq else None,
+                    multi_modal_data=dummy_data.multi_modal_data,
+                    multi_modal_placeholders=dummy_data.
+                    multi_modal_placeholders,
+                )
+                seqs.append(seq)
 
-        # Run the model with the dummy inputs.
-        num_layers = self.model_config.get_num_layers(self.parallel_config)
-        # use an empty tensor instead of `None`` to force Dynamo to pass
-        # it by reference, rather by specializing on the value ``None``.
-        # the `dtype` argument does not matter, and we use `float32` as
-        # a placeholder (it has wide hardware support).
-        # it is important to create tensors inside the loop, rather than
-        # multiplying the list, to avoid Dynamo from treating them as
-        # tensor aliasing.
-        kv_caches = [
-            torch.tensor([], dtype=torch.float32, device=self.device)
-            for _ in range(num_layers)
-        ]
-        finished_requests_ids = [seq.request_id for seq in seqs]
-        model_input = self.prepare_model_input(
-            seqs, finished_requests_ids=finished_requests_ids)
-        intermediate_tensors = None
-        if not get_pp_group().is_first_rank:
-            intermediate_tensors = self.model.make_empty_intermediate_tensors(
-                batch_size=batch_size,
-                dtype=self.model_config.dtype,
-                device=self.device)
+            # Run the model with the dummy inputs.
+            num_layers = self.model_config.get_num_layers(self.parallel_config)
+            # use an empty tensor instead of `None`` to force Dynamo to pass
+            # it by reference, rather by specializing on the value ``None``.
+            # the `dtype` argument does not matter, and we use `float32` as
+            # a placeholder (it has wide hardware support).
+            # it is important to create tensors inside the loop, rather than
+            # multiplying the list, to avoid Dynamo from treating them as
+            # tensor aliasing.
+            kv_caches = [
+                torch.tensor([], dtype=torch.float32, device=self.device)
+                for _ in range(num_layers)
+            ]
+            finished_requests_ids = [seq.request_id for seq in seqs]
+            model_input = self.prepare_model_input(
+                seqs, finished_requests_ids=finished_requests_ids)
+            intermediate_tensors = None
+            if not get_pp_group().is_first_rank:
+                intermediate_tensors = \
+                    self.model.make_empty_intermediate_tensors(
+                        batch_size=batch_size,
+                        dtype=self.model_config.dtype,
+                        device=self.device)
 
-        self.execute_model(model_input, kv_caches, intermediate_tensors)
-        torch.cuda.synchronize()
-        return
+            self.execute_model(model_input, kv_caches, intermediate_tensors)
+            torch.cuda.synchronize()
+            return
 
     def remove_all_loras(self):
         if not self.lora_manager:
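Most of the hunk above is re-indentation: the body of `profile_run` now executes inside a `set_in_profile_run()` context manager, so `self.in_profile_run` is True for the duration of profiling and is always cleared afterwards, even if the dummy forward pass raises. A minimal sketch of the pattern with a hypothetical class, not the model runner itself:

from contextlib import contextmanager

class ProfileAwareRunner:
    """Flag-guarding pattern: the flag is set only while profile_run() is
    executing and is reset in a finally block."""

    def __init__(self) -> None:
        self.in_profile_run = False

    @contextmanager
    def set_in_profile_run(self):
        self.in_profile_run = True
        try:
            yield
        finally:
            self.in_profile_run = False

    def profile_run(self) -> None:
        with self.set_in_profile_run():
            # Downstream code (e.g. attention metadata builders) can branch
            # on this flag to behave differently during memory profiling.
            assert self.in_profile_run

runner = ProfileAwareRunner()
runner.profile_run()
assert not runner.in_profile_run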
@@ -32,7 +32,7 @@ logger = init_logger(__name__)
 MULTI_STEP_ATTENTION_BACKENDS = [
     "FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION"
 ]
-MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"]
+MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN", "FLASHINFER"]
 
 def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
         -> List[str]:
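Finally, FLASHINFER is added to `MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS`, which is what allows the multi-step model runner to accept chunked prefill with the FlashInfer backend. The body of `_get_supported_attention_backends` is not part of this diff; a plausible sketch, consistent with the two lists above, would simply return the narrower list when chunked prefill is enabled:

from typing import List

MULTI_STEP_ATTENTION_BACKENDS = [
    "FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION"
]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN", "FLASHINFER"]

def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
        -> List[str]:
    # Assumed behavior (not shown in the diff): only backends that also
    # support multi-step + chunked prefill qualify when it is enabled.
    if chunked_prefill_enabled:
        return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
    return MULTI_STEP_ATTENTION_BACKENDS

assert "FLASHINFER" in _get_supported_attention_backends(True)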