[ROCm][Hardware][AMD] Adding Navi21 to fallback to naive attention if Triton is not used (#4658)

alexeykondrat 2024-05-18 01:09:11 -04:00 committed by GitHub
parent 86b45ae065
commit c0724fc915

@@ -231,8 +231,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             self.attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
         else:
-            # if not using triton, navi3x not use flash-attn either
-            if torch.cuda.get_device_capability()[0] == 11:
+            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
+            # either
+            if torch.cuda.get_device_capability()[0] != 9:
                 self.use_naive_attn = True
             else:
                 try:
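
For illustration, below is a minimal standalone sketch of the check this hunk introduces, assuming a ROCm build of PyTorch where torch.cuda.get_device_capability() reports the GFX major version (9 for CDNA GPUs such as MI200/MI300, 10 for Navi10/Navi21, 11 for Navi3x); the helper name is hypothetical and not part of the vLLM code:

    import torch


    def should_fall_back_to_naive_attn() -> bool:
        """Hypothetical helper mirroring the check in the hunk above.

        When the Triton flash-attention kernel is not used, only GFX9
        (CDNA) GPUs keep the flash-attn path; Navi10/Navi21 (major 10)
        and Navi3x (major 11) fall back to naive attention instead.
        """
        major, _minor = torch.cuda.get_device_capability()
        return major != 9

Checking for != 9 rather than enumerating == 10 and == 11 makes the naive-attention fallback the default for any non-CDNA architecture, which matches the updated comment in the hunk.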