[Bugfix] Fix chunked prefill with model dtype float32 on Turing Devices (#9850)
Signed-off-by: Wallas Santos <wallashss@ibm.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
parent d04b13a380
commit c27df94e1f
@@ -98,4 +98,5 @@ markers = [
     "quant_model: run this model test under Quantized category",
     "distributed_2_gpus: run this test only in distributed tests for 2 GPUs",
     "skip_v1: do not run this test with v1",
+    "optional: optional tests that are automatically skipped, include --optional to run them",
 ]
@@ -1030,3 +1030,22 @@ def dummy_gemma2_embedding_path():
     with open(json_path, "w") as f:
         json.dump(config, f)
     return _dummy_gemma2_embedding_path
+
+
+# Add the flag `--optional` to allow run tests
+# that are marked with @pytest.mark.optional
+def pytest_addoption(parser):
+    parser.addoption("--optional",
+                     action="store_true",
+                     default=False,
+                     help="run optional test")
+
+
+def pytest_collection_modifyitems(config, items):
+    if config.getoption("--optional"):
+        # --optional given in cli: do not skip optional tests
+        return
+    skip_optional = pytest.mark.skip(reason="need --optional option to run")
+    for item in items:
+        if "optional" in item.keywords:
+            item.add_marker(skip_optional)
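As a usage note (not part of the diff): a test carrying the new marker is collected but skipped unless pytest is invoked with `--optional`. A minimal, hypothetical example:

```python
import pytest


# Hypothetical test: skipped by default, selected when pytest runs
# with the --optional flag added in conftest.py above.
@pytest.mark.optional
def test_slow_float32_path():
    assert 2 + 2 == 4
```

Without the flag it is reported as skipped with reason "need --optional option to run"; with `pytest --optional` it runs normally.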
@@ -40,6 +40,13 @@ def test_contexted_kv_attention(
     kv_cache_dtype: str,
     device: str,
 ) -> None:
+
+    if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
+            89):
+        pytest.skip(
+            'Triton limitation: fp8e4nv data type is not supported on CUDA'
+            ' arch < 89')
+
     current_platform.seed_everything(0)
     torch.set_default_device(device)
 
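For readers unfamiliar with the gate above, a hedged sketch of the comparison it performs, assuming vLLM's usual `major * 10 + minor` encoding of compute capability (the helper name below is illustrative, not vLLM API):

```python
# Illustrative only: 89 corresponds to SM 8.9 (Ada). Turing (7, 5) and
# Ampere (8, 0 / 8, 6) fall below it, so the fp8 kv-cache cases are skipped.
def capability_at_least(threshold: int, major: int, minor: int) -> bool:
    return major * 10 + minor >= threshold


assert capability_at_least(89, 7, 5) is False  # Turing: skip fp8 cases
assert capability_at_least(89, 8, 9) is True   # Ada: run fp8 cases
```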
@@ -235,6 +242,13 @@ def test_contexted_kv_attention_alibi(
     kv_cache_dtype: str,
     device: str,
 ) -> None:
+
+    if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
+            89):
+        pytest.skip(
+            'Triton limitation: fp8e4nv data type is not supported on CUDA'
+            ' arch < 89')
+
     current_platform.seed_everything(0)
     torch.set_default_device(device)
 
@@ -462,3 +476,52 @@ def test_contexted_kv_attention_alibi(
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
     atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
     torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
+
+
+# These tests are optional to only run when explicitly invoked
+#
+# pytest -v -s --optional \
+#     tests/kernels/test_prefix_prefill.py::test_contexted_kv_attention_f32
+#
+# These tests are useful to test model dtype float32 on Turing devices.
+# We skip them to not increase the time when running tests on CI
+@pytest.mark.optional
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", [torch.float32])
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
+@torch.inference_mode()
+def test_contexted_kv_attention_f32(
+    num_heads: int,
+    num_queries_per_kv: int,
+    head_size: int,
+    sliding_window: int,
+    dtype: torch.dtype,
+    kv_cache_dtype: str,
+    device: str,
+) -> None:
+    test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size,
+                                sliding_window, dtype, kv_cache_dtype, device)
+
+
+@pytest.mark.optional
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", [torch.float32])
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_contexted_kv_attention_alibi_f32(
+    num_heads: int,
+    num_queries_per_kv: int,
+    head_size: int,
+    dtype: torch.dtype,
+    kv_cache_dtype: str,
+    device: str,
+) -> None:
+    test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
+                                      dtype, kv_cache_dtype, device)
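As the comment block notes, these parametrizations only run on request; `pytest -v -s --optional tests/kernels/test_prefix_prefill.py` runs both float32 variants together with the rest of the file.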
@@ -7,6 +7,13 @@ import triton.language as tl
 
 from vllm.platforms import current_platform
 
+# Static kernels parameters
+BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
+NUM_WARPS = 8
+
+# To check compatibility
+IS_TURING = current_platform.get_device_capability() == (7, 5)
+
 if triton.__version__ >= "2.1.0":
 
     @triton.jit
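A brief note on the new module-level constants (assuming the standard compute capabilities for each architecture): on Turing (SM 7.5) `has_device_capability(80)` is False, so `BASE_BLOCK` is 64 and `IS_TURING` is True; on Ampere and newer (SM 8.0+) `BASE_BLOCK` is 128 and `IS_TURING` is False.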
@@ -50,6 +57,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_d,
         stride_v_cache_bl,
         num_queries_per_kv: int,
+        IN_PRECISION: tl.constexpr,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,  # head size
         BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
@@ -130,7 +138,7 @@ if triton.__version__ >= "2.1.0":
                 k = k_load
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # [M,N]
-            qk += tl.dot(q, k)
+            qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
             qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                           float("-inf"))
             qk *= sm_scale
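This change (and the analogous ones below) folds the accumulation into `tl.dot` via `acc=` and threads through an `input_precision` constexpr instead of the older `allow_tf32` flag. A minimal, self-contained Triton sketch of those two keywords, assuming a Triton version (3.x) whose `tl.dot` accepts `input_precision` and a CUDA device; the kernel and names are illustrative, not part of vLLM:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def tiny_matmul(a_ptr, b_ptr, c_ptr, M: tl.constexpr, N: tl.constexpr,
                K: tl.constexpr, IN_PRECISION: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    acc = tl.zeros([M, N], dtype=tl.float32)
    # acc= fuses the accumulation; input_precision='ieee' forces plain fp32
    # multiplies instead of a TF32 tensor-core path (None keeps the default).
    acc = tl.dot(a, b, acc=acc, input_precision=IN_PRECISION)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)


a = torch.randn(16, 16, device="cuda", dtype=torch.float32)
b = torch.randn(16, 16, device="cuda", dtype=torch.float32)
c = torch.empty(16, 16, device="cuda", dtype=torch.float32)
tiny_matmul[(1, )](a, b, c, 16, 16, 16, IN_PRECISION='ieee')
torch.testing.assert_close(c, a @ b, atol=1e-4, rtol=1e-4)
```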
@@ -178,7 +186,7 @@ if triton.__version__ >= "2.1.0":
                 v = v_load
             p = p.to(v.dtype)
 
-            acc += tl.dot(p, v)
+            acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
             # # update m_i and l_i
             l_i = l_i_new
             m_i = m_i_new
@@ -204,7 +212,7 @@ if triton.__version__ >= "2.1.0":
                         other=0.0)
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-            qk += tl.dot(q, k)
+            qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
             qk *= sm_scale
             # apply causal mask
             qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
@@ -238,7 +246,7 @@ if triton.__version__ >= "2.1.0":
                         other=0.0)
             p = p.to(v.dtype)
 
-            acc += tl.dot(p, v)
+            acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
             # update m_i and l_i
             l_i = l_i_new
             m_i = m_i_new
@@ -485,6 +493,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_d,
         stride_v_cache_bl,
         num_queries_per_kv: int,
+        IN_PRECISION: tl.constexpr,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,  # head size
         BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
@@ -560,7 +569,7 @@ if triton.__version__ >= "2.1.0":
                 k = k_load
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-            qk += tl.dot(q, k)
+            qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
             qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                           float("-inf"))
             qk *= sm_scale
@@ -600,7 +609,7 @@ if triton.__version__ >= "2.1.0":
                 v = v_load
             p = p.to(v.dtype)
 
-            acc += tl.dot(p, v, allow_tf32=False)
+            acc = tl.dot(p, v, acc=acc, input_precision='ieee')
             # update m_i and l_i
             l_i = l_i_new
             m_i = m_i_new
@@ -635,7 +644,7 @@ if triton.__version__ >= "2.1.0":
                         other=0.0)
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-            qk += tl.dot(q, k, allow_tf32=False)
+            qk = tl.dot(q, k, acc=qk, input_precision='ieee')
             qk *= sm_scale
             qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                           float("-inf"))
@@ -673,7 +682,7 @@ if triton.__version__ >= "2.1.0":
                         other=0.0)
             p = p.to(v.dtype)
 
-            acc += tl.dot(p, v, allow_tf32=False)
+            acc = tl.dot(p, v, acc=acc, input_precision='ieee')
             # update m_i and l_i
             l_i = l_i_new
             m_i = m_i_new
@@ -709,13 +718,17 @@ if triton.__version__ >= "2.1.0":
                               alibi_slopes=None,
                               sliding_window=None):
 
-        BLOCK = 128 if current_platform.has_device_capability(80) else 64
-        NUM_WARPS = 8
-
+        q_dtype_is_f32 = q.dtype is torch.float32
         # need to reduce num. blocks when using fp32
         # due to increased use of GPU shared memory
-        if q.dtype is torch.float32:
-            BLOCK = BLOCK // 2
+        # if q.dtype is torch.float32:
+        BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
+
+        # Turing does have tensor core for float32 multiplication
+        # use ieee as fallback for triton kernels work. There is also
+        # warning on vllm/config.py to inform users this fallback
+        # implementation
+        IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None
 
         # Conversion of FP8 Tensor from uint8 storage to
         # appropriate torch.dtype for interpretation by Triton
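Restated as a standalone helper to make the two decisions above explicit (the function name and signature are illustrative, not vLLM code):

```python
def select_prefill_params(q_dtype_is_f32: bool, is_turing: bool,
                          base_block: int):
    # fp32 tiles need roughly twice the shared memory, so halve the block size
    block = base_block // 2 if q_dtype_is_f32 else base_block
    # Turing cannot run fp32 through tensor cores; 'ieee' makes tl.dot use
    # plain fp32 multiplies, while None keeps Triton's default behaviour
    in_precision = 'ieee' if is_turing and q_dtype_is_f32 else None
    return block, in_precision


assert select_prefill_params(True, True, 64) == (32, 'ieee')    # fp32 on T4
assert select_prefill_params(False, False, 128) == (128, None)  # fp16 on A100
```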
@@ -799,6 +812,7 @@ if triton.__version__ >= "2.1.0":
             v_cache.stride(
                 3),  #[num_blocks, num_kv_heads, head_size, block_size]
             num_queries_per_kv=num_queries_per_kv,
+            IN_PRECISION=IN_PRECISION,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_DMODEL_PADDED=Lk_padded,
@@ -850,6 +864,7 @@ if triton.__version__ >= "2.1.0":
             v_cache.stride(
                 3),  #[num_blocks, num_kv_heads, head_size, block_size]
             num_queries_per_kv=num_queries_per_kv,
+            IN_PRECISION=IN_PRECISION,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_DMODEL_PADDED=Lk_padded,
@@ -2388,6 +2388,16 @@ class VllmConfig:
             self.quant_config = VllmConfig._get_quantization_config(
                 self.model_config, self.load_config)
 
+        if self.scheduler_config is not None and \
+            self.model_config is not None and \
+            self.scheduler_config.chunked_prefill_enabled and \
+            self.model_config.dtype == torch.float32 and \
+            current_platform.get_device_capability() == (7, 5):
+            print_warning_once(
+                "Turing devices tensor cores do not support float32 matmul. "
+                "To workaround this limitation, vLLM will set 'ieee' input "
+                "precision for chunked prefill triton kernels.")
+
         if self.compilation_config is None:
             self.compilation_config = CompilationConfig()
         if envs.VLLM_USE_V1 and not self.model_config.enforce_eager:
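For illustration (not part of the diff), a configuration that would hit this branch on a Turing GPU; the kwargs are standard vLLM engine arguments, shown only to make the condition concrete:

```python
from vllm import LLM

# On a T4 (compute capability 7.5) this combination triggers the
# print_warning_once message above, and the chunked-prefill Triton
# kernels fall back to 'ieee' input precision.
llm = LLM(model="facebook/opt-125m",
          dtype="float32",
          enable_chunked_prefill=True)
```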
@@ -1055,6 +1055,7 @@ class EngineArgs:
             msg = "Chunked prefill is not supported for embedding models"
             raise ValueError(msg)
 
+
         speculative_config = SpeculativeConfig.maybe_create_spec_config(
             target_model_config=model_config,
             target_parallel_config=parallel_config,