[ROCm][Hardware][AMD] Adding Navi21 to fallback to naive attention if Triton is not used (#4658)
parent 86b45ae065
commit c0724fc915
@@ -231,8 +231,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             self.attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
         else:
-            # if not using triton, navi3x not use flash-attn either
-            if torch.cuda.get_device_capability()[0] == 11:
+            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
+            # either
+            if torch.cuda.get_device_capability()[0] != 9:
                 self.use_naive_attn = True
             else:
                 try:
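For context, the guard keys off the GFX major version that ROCm builds of PyTorch report through torch.cuda.get_device_capability(): gfx9 parts (CDNA, e.g. MI200/MI300) keep the flash-attn path, while everything else, which now includes Navi21/Navi10 (gfx10) as well as Navi3x (gfx11), falls back to naive attention when Triton is not used. Below is a minimal sketch of that selection logic; the helper name choose_rocm_attention_backend and the returned labels are illustrative only and not part of the diff.

import torch


def choose_rocm_attention_backend(use_triton_flash_attn: bool) -> str:
    # Hypothetical helper mirroring the selection logic in the diff above.
    # On ROCm, torch.cuda.get_device_capability() reports the GFX
    # major/minor version, e.g. gfx90a -> (9, 0), gfx1030 / Navi21 -> (10, 3),
    # gfx1100 / Navi31 -> (11, 0).
    if use_triton_flash_attn:
        return "triton"
    if torch.cuda.get_device_capability()[0] != 9:
        # Navi10/Navi21 (gfx10) and Navi3x (gfx11): no flash-attn build
        # is available, so fall back to naive attention.
        return "naive"
    # gfx9 (CDNA: MI100/MI200/MI300) can use the flash-attn path.
    return "flash-attn"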