vllm/rocm_patch/flashpy_xformers-0.0.23.rocm.patch
TJian f375ec8440
[ROCm] Upgrade xformers version for ROCm & update doc (#2079)
Co-authored-by: miloice <jeffaw99@hotmail.com>
2023-12-13 00:56:05 -08:00

153 lines
5.6 KiB
Diff

--- flash_ori.py 2023-12-13 05:43:31.530752623 +0000
+++ flash_patch.py 2023-12-13 06:00:45.962403104 +0000
@@ -36,44 +36,44 @@
FLASH_VERSION = "0.0.0"
try:
- try:
- from ... import _C_flashattention # type: ignore[attr-defined]
- from ..._cpp_lib import _build_metadata
-
- if _build_metadata is not None:
- FLASH_VERSION = _build_metadata.flash_version
- except ImportError:
- import flash_attn
- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
-
- FLASH_VERSION = flash_attn.__version__
- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
- if (
- flash_ver_parsed != (2, 3, 6)
- and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
- ):
- raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
+ #try:
+ # from ... import _C_flashattention # type: ignore[attr-defined]
+ # from ..._cpp_lib import _build_metadata
+
+ # if _build_metadata is not None:
+ # FLASH_VERSION = _build_metadata.flash_version
+ #except ImportError:
+ import flash_attn
+ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+
+ FLASH_VERSION = flash_attn.__version__
+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
+ # if (
+ # flash_ver_parsed != (2, 3, 6)
+ # and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
+ # ):
+ # raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
# create library so that flash-attn goes through the PyTorch Dispatcher
- _flash_lib = torch.library.Library("xformers_flash", "DEF")
-
- _flash_lib.define(
- "flash_fwd(Tensor query, Tensor key, Tensor value, "
- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
- "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, "
- "bool is_causal, int window_left, "
- "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
- )
+ #_flash_lib = torch.library.Library("xformers_flash", "DEF")
- _flash_lib.define(
- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
- "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, bool is_causal, "
- "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
- )
+ #_flash_lib.define(
+ # "flash_fwd(Tensor query, Tensor key, Tensor value, "
+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
+ # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, "
+ # "bool is_causal, int window_left, "
+ # "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+ #)
+
+ #_flash_lib.define(
+ # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
+ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+ # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, bool is_causal, "
+ # "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+ #)
def _flash_fwd(
query,
@@ -111,8 +111,8 @@
p,
softmax_scale,
is_causal,
- window_left, # window_size_left
- window_right, # window_size_right
+ # window_left, # window_size_left
+ # window_right, # window_size_right
return_softmax,
None, # rng
)
@@ -134,15 +134,15 @@
out,
cu_seq_lens_q,
cu_seq_lens_k,
- seqused_k,
+ # seqused_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
False,
is_causal,
- window_left,
- window_right,
+ # window_left,
+ # window_right,
return_softmax,
None,
)
@@ -184,8 +184,8 @@
p,
softmax_scale,
is_causal,
- window_left,
- window_right,
+ # window_left,
+ # window_right,
None,
rng_state,
)
@@ -208,15 +208,15 @@
softmax_scale,
False, # zero_tensors
is_causal,
- window_left,
- window_right,
+ # window_left,
+ # window_right,
None,
rng_state,
)
return dq, dk, dv
- _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
- _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
+ #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
+ #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
except ImportError:
pass
@@ -400,7 +400,7 @@
implementation.
"""
- OPERATOR = get_operator("xformers_flash", "flash_fwd")
+ OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd")
SUPPORTED_DEVICES: Set[str] = {"cuda"}
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}