[Bugfix] Fix chunked prefill with model dtype float32 on Turing Devices (#9850)
Signed-off-by: Wallas Santos <wallashss@ibm.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
parent d04b13a380
commit c27df94e1f
@@ -98,4 +98,5 @@ markers = [
     "quant_model: run this model test under Quantized category",
     "distributed_2_gpus: run this test only in distributed tests for 2 GPUs",
     "skip_v1: do not run this test with v1",
+    "optional: optional tests that are automatically skipped, include --optional to run them",
 ]

@@ -1030,3 +1030,22 @@ def dummy_gemma2_embedding_path():
         with open(json_path, "w") as f:
             json.dump(config, f)
     return _dummy_gemma2_embedding_path
+
+
+# Add the flag `--optional` to allow running tests
+# that are marked with @pytest.mark.optional
+def pytest_addoption(parser):
+    parser.addoption("--optional",
+                     action="store_true",
+                     default=False,
+                     help="run optional test")
+
+
+def pytest_collection_modifyitems(config, items):
+    if config.getoption("--optional"):
+        # --optional given in cli: do not skip optional tests
+        return
+    skip_optional = pytest.mark.skip(reason="need --optional option to run")
+    for item in items:
+        if "optional" in item.keywords:
+            item.add_marker(skip_optional)

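For illustration, a minimal sketch (not part of this commit) of how the new marker and flag behave from a test author's point of view; the test name below is hypothetical:

    import pytest

    @pytest.mark.optional
    def test_slow_float32_path():
        # Collected but skipped by default; it only runs when pytest is
        # invoked with the new flag, e.g. `pytest --optional`.
        assert True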
@@ -40,6 +40,13 @@ def test_contexted_kv_attention(
     kv_cache_dtype: str,
     device: str,
 ) -> None:
+
+    if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
+            89):
+        pytest.skip(
+            'Triton limitation: fp8e4nv data type is not supported on CUDA'
+            ' arch < 89')
+
     current_platform.seed_everything(0)
     torch.set_default_device(device)

@@ -235,6 +242,13 @@ def test_contexted_kv_attention_alibi(
     kv_cache_dtype: str,
     device: str,
 ) -> None:
+
+    if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
+            89):
+        pytest.skip(
+            'Triton limitation: fp8e4nv data type is not supported on CUDA'
+            ' arch < 89')
+
     current_platform.seed_everything(0)
     torch.set_default_device(device)

@@ -462,3 +476,52 @@ def test_contexted_kv_attention_alibi(
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
     atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
     torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
+
+
+# These tests are optional and only run when explicitly invoked:
+#
+#     pytest -v -s --optional \
+#         tests/kernels/test_prefix_prefill.py::test_contexted_kv_attention_f32
+#
+# They are useful for exercising model dtype float32 on Turing devices.
+# We skip them by default to keep CI runtime down.
+@pytest.mark.optional
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", [torch.float32])
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
+@torch.inference_mode()
+def test_contexted_kv_attention_f32(
+    num_heads: int,
+    num_queries_per_kv: int,
+    head_size: int,
+    sliding_window: int,
+    dtype: torch.dtype,
+    kv_cache_dtype: str,
+    device: str,
+) -> None:
+    test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size,
+                                sliding_window, dtype, kv_cache_dtype, device)
+
+
+@pytest.mark.optional
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", [torch.float32])
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_contexted_kv_attention_alibi_f32(
+    num_heads: int,
+    num_queries_per_kv: int,
+    head_size: int,
+    dtype: torch.dtype,
+    kv_cache_dtype: str,
+    device: str,
+) -> None:
+    test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
+                                      dtype, kv_cache_dtype, device)

@@ -7,6 +7,13 @@ import triton.language as tl

 from vllm.platforms import current_platform

+# Static kernel parameters
+BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
+NUM_WARPS = 8
+
+# To check compatibility
+IS_TURING = current_platform.get_device_capability() == (7, 5)
+
 if triton.__version__ >= "2.1.0":

     @triton.jit
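As a rough illustration (not part of the commit), on a Turing card such as an RTX 2080 Ti (compute capability 7.5) the gating above resolves as follows; the asserts are only for demonstration:

    from vllm.platforms import current_platform

    # On sm_75 hardware the reported capability compares equal to (7, 5),
    # so IS_TURING above is True; has_device_capability(80) is False,
    # so BASE_BLOCK falls back to 64.
    assert current_platform.get_device_capability() == (7, 5)
    assert not current_platform.has_device_capability(80)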
@@ -50,6 +57,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_d,
         stride_v_cache_bl,
         num_queries_per_kv: int,
+        IN_PRECISION: tl.constexpr,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,  # head size
         BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
@@ -130,7 +138,7 @@ if triton.__version__ >= "2.1.0":
             k = k_load

             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # [M,N]
-            qk += tl.dot(q, k)
+            qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
             qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                           float("-inf"))
             qk *= sm_scale
@@ -178,7 +186,7 @@ if triton.__version__ >= "2.1.0":
             v = v_load
             p = p.to(v.dtype)

-            acc += tl.dot(p, v)
+            acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
             # # update m_i and l_i
             l_i = l_i_new
             m_i = m_i_new
@@ -204,7 +212,7 @@ if triton.__version__ >= "2.1.0":
                          other=0.0)

             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-            qk += tl.dot(q, k)
+            qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
             qk *= sm_scale
             # apply causal mask
             qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
@@ -238,7 +246,7 @@ if triton.__version__ >= "2.1.0":
                          other=0.0)
             p = p.to(v.dtype)

-            acc += tl.dot(p, v)
+            acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
             # update m_i and l_i
             l_i = l_i_new
             m_i = m_i_new
@@ -485,6 +493,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_d,
         stride_v_cache_bl,
         num_queries_per_kv: int,
+        IN_PRECISION: tl.constexpr,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,  # head size
         BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
@@ -560,7 +569,7 @@ if triton.__version__ >= "2.1.0":
             k = k_load

             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-            qk += tl.dot(q, k)
+            qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
             qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                           float("-inf"))
             qk *= sm_scale
@@ -600,7 +609,7 @@ if triton.__version__ >= "2.1.0":
             v = v_load
             p = p.to(v.dtype)

-            acc += tl.dot(p, v, allow_tf32=False)
+            acc = tl.dot(p, v, acc=acc, input_precision='ieee')
             # update m_i and l_i
             l_i = l_i_new
             m_i = m_i_new
@@ -635,7 +644,7 @@ if triton.__version__ >= "2.1.0":
                          other=0.0)

             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-            qk += tl.dot(q, k, allow_tf32=False)
+            qk = tl.dot(q, k, acc=qk, input_precision='ieee')
             qk *= sm_scale
             qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                           float("-inf"))
@@ -673,7 +682,7 @@ if triton.__version__ >= "2.1.0":
                          other=0.0)
             p = p.to(v.dtype)

-            acc += tl.dot(p, v, allow_tf32=False)
+            acc = tl.dot(p, v, acc=acc, input_precision='ieee')
             # update m_i and l_i
             l_i = l_i_new
             m_i = m_i_new
@@ -709,13 +718,17 @@ if triton.__version__ >= "2.1.0":
                               alibi_slopes=None,
                               sliding_window=None):

-        BLOCK = 128 if current_platform.has_device_capability(80) else 64
-        NUM_WARPS = 8
-
+        q_dtype_is_f32 = q.dtype is torch.float32
         # need to reduce num. blocks when using fp32
         # due to increased use of GPU shared memory
-        if q.dtype is torch.float32:
-            BLOCK = BLOCK // 2
+        # if q.dtype is torch.float32:
+        BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
+
+        # Turing does not have tensor cores for float32 multiplication;
+        # use 'ieee' as a fallback so the Triton kernels still work. There is
+        # also a warning in vllm/config.py to inform users about this
+        # fallback implementation.
+        IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None

         # Conversion of FP8 Tensor from uint8 storage to
         # appropriate torch.dtype for interpretation by Triton
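To make the precision fallback concrete, here is a minimal standalone sketch (not part of the commit; the kernel and helper names are made up) of a Triton dot product whose input precision is chosen on the host the same way the launcher above chooses IN_PRECISION:

    import torch
    import triton
    import triton.language as tl

    from vllm.platforms import current_platform


    @triton.jit
    def _dot_demo_kernel(a_ptr, b_ptr, c_ptr, IN_PRECISION: tl.constexpr,
                         BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
        b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
        acc = tl.zeros([BLOCK, BLOCK], dtype=tl.float32)
        # 'ieee' forces the plain FMA path; None lets Triton pick a
        # tensor-core path (e.g. tf32), which Turing lacks for float32.
        acc = tl.dot(a, b, acc=acc, input_precision=IN_PRECISION)
        tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], acc)


    def dot_demo(block: int = 16):
        a = torch.randn(block, block, device="cuda", dtype=torch.float32)
        b = torch.randn(block, block, device="cuda", dtype=torch.float32)
        c = torch.empty_like(a)
        is_turing = current_platform.get_device_capability() == (7, 5)
        in_precision = 'ieee' if is_turing else None
        _dot_demo_kernel[(1, )](a, b, c,
                                IN_PRECISION=in_precision,
                                BLOCK=block)
        # Loose tolerance: the non-'ieee' path may use tf32 on newer GPUs.
        torch.testing.assert_close(c, a @ b, atol=2e-2, rtol=2e-2)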
@@ -799,6 +812,7 @@ if triton.__version__ >= "2.1.0":
             v_cache.stride(
                 3),  #[num_blocks, num_kv_heads, head_size, block_size]
             num_queries_per_kv=num_queries_per_kv,
+            IN_PRECISION=IN_PRECISION,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_DMODEL_PADDED=Lk_padded,
@@ -850,6 +864,7 @@ if triton.__version__ >= "2.1.0":
             v_cache.stride(
                 3),  #[num_blocks, num_kv_heads, head_size, block_size]
             num_queries_per_kv=num_queries_per_kv,
+            IN_PRECISION=IN_PRECISION,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_DMODEL_PADDED=Lk_padded,

@@ -2388,6 +2388,16 @@ class VllmConfig:
             self.quant_config = VllmConfig._get_quantization_config(
                 self.model_config, self.load_config)

+        if self.scheduler_config is not None and \
+            self.model_config is not None and \
+            self.scheduler_config.chunked_prefill_enabled and \
+            self.model_config.dtype == torch.float32 and \
+            current_platform.get_device_capability() == (7, 5):
+            print_warning_once(
+                "Turing devices tensor cores do not support float32 matmul. "
+                "To workaround this limitation, vLLM will set 'ieee' input "
+                "precision for chunked prefill triton kernels.")
+
         if self.compilation_config is None:
             self.compilation_config = CompilationConfig()
         if envs.VLLM_USE_V1 and not self.model_config.enforce_eager:
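As a hedged illustration (not part of the commit), a setup along these lines on a Turing GPU would now take the 'ieee' fallback and log the warning added above; the model choice and sampling parameters are arbitrary:

    from vllm import LLM, SamplingParams

    # float32 model weights + chunked prefill on a compute capability
    # (7, 5) device: the prefix-prefill Triton kernels are launched with
    # IN_PRECISION='ieee' and the warning from VllmConfig is printed once.
    llm = LLM(model="facebook/opt-125m",
              dtype="float32",
              enable_chunked_prefill=True)
    print(llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16)))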
@@ -1055,6 +1055,7 @@ class EngineArgs:
             msg = "Chunked prefill is not supported for embedding models"
             raise ValueError(msg)

+
         speculative_config = SpeculativeConfig.maybe_create_spec_config(
             target_model_config=model_config,
             target_parallel_config=parallel_config,