[Bugfix] Fix try-catch conditions to import correct Flash Attention Backend in Draft Model (#9101)

TJian 2024-10-05 22:00:04 -07:00 committed by GitHub
parent f4dd830e09
commit 23fea8714a


@@ -6,11 +6,16 @@ from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.sampler import SamplerOutput

 try:
-    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-except ModuleNotFoundError:
-    # vllm_flash_attn is not installed, use the identical ROCm FA metadata
-    from vllm.attention.backends.rocm_flash_attn import (
-        ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+    try:
+        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
+    except (ModuleNotFoundError, ImportError):
+        # vllm_flash_attn is not installed, try the ROCm FA metadata
+        from vllm.attention.backends.rocm_flash_attn import (
+            ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+except (ModuleNotFoundError, ImportError) as err:
+    raise RuntimeError(
+        "Draft model speculative decoding currently only supports"
+        " CUDA and ROCm flash attention backend.") from err

 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
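
For reference, here is how the guarded import reads once the hunk is applied (indentation reconstructed from the diff; a sketch rather than the verbatim file). The inner except now also catches ImportError, which covers the case where vllm_flash_attn is present but fails to load (for example a broken compiled extension), not just the missing-module case, and the new outer try turns a total failure into an explicit RuntimeError instead of an unhandled import error.

try:
    try:
        # Prefer the CUDA flash attention metadata from vllm_flash_attn.
        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
    except (ModuleNotFoundError, ImportError):
        # vllm_flash_attn is missing or failed to import; fall back to the
        # structurally identical ROCm flash attention metadata.
        from vllm.attention.backends.rocm_flash_attn import (
            ROCmFlashAttentionMetadata as FlashAttentionMetadata)
except (ModuleNotFoundError, ImportError) as err:
    # Neither backend could be imported: fail loudly with a clear message.
    raise RuntimeError(
        "Draft model speculative decoding currently only supports"
        " CUDA and ROCm flash attention backend.") from err

Since ModuleNotFoundError is a subclass of ImportError, catching ImportError alone would be sufficient; listing both simply makes the intended fallback explicit.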