[Bugfix] Fix try-catch conditions to import correct Flash Attention Backend in Draft Model (#9101)

Author: TJian
Date: 2024-10-05 22:00:04 -07:00 (committed by GitHub)
Parent: f4dd830e09
Commit: 23fea8714a


@@ -6,11 +6,16 @@ from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.sampler import SamplerOutput
 try:
-    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-except ModuleNotFoundError:
-    # vllm_flash_attn is not installed, use the identical ROCm FA metadata
-    from vllm.attention.backends.rocm_flash_attn import (
-        ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+    try:
+        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
+    except (ModuleNotFoundError, ImportError):
+        # vllm_flash_attn is not installed, try the ROCm FA metadata
+        from vllm.attention.backends.rocm_flash_attn import (
+            ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+except (ModuleNotFoundError, ImportError) as err:
+    raise RuntimeError(
+        "Draft model speculative decoding currently only supports"
+        "CUDA and ROCm flash attention backend.") from err
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
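
For context on why the except clauses were broadened: ModuleNotFoundError is a subclass of ImportError, so the old `except ModuleNotFoundError:` clause did not cover the case where the flash-attn module is importable but the requested symbol (or a compiled extension it needs) fails to load. A minimal, self-contained sketch of that distinction (illustrative only, not part of the commit):

# importlib exists, but "missing_symbol" does not, so the failed "from"
# import raises a plain ImportError rather than ModuleNotFoundError; an
# "except ModuleNotFoundError" clause alone would let it propagate.
try:
    from importlib import missing_symbol  # hypothetical, nonexistent name
except ModuleNotFoundError:
    print("not reached: this is a plain ImportError")
except ImportError as exc:
    print(f"caught {type(exc).__name__}")  # -> caught ImportError

# The subclass relationship that makes catching ImportError the broader net:
assert issubclass(ModuleNotFoundError, ImportError)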