[Bugfix] Fix try-catch conditions to import correct Flash Attention Backend in Draft Model (#9101)

TJian 2024-10-05 22:00:04 -07:00 committed by GitHub
parent f4dd830e09
commit 23fea8714a


@@ -5,12 +5,17 @@ import torch
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.sampler import SamplerOutput

-try:
-    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-except ModuleNotFoundError:
-    # vllm_flash_attn is not installed, use the identical ROCm FA metadata
-    from vllm.attention.backends.rocm_flash_attn import (
-        ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+try:
+    try:
+        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
+    except (ModuleNotFoundError, ImportError):
+        # vllm_flash_attn is not installed, try the ROCm FA metadata
+        from vllm.attention.backends.rocm_flash_attn import (
+            ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+except (ModuleNotFoundError, ImportError) as err:
+    raise RuntimeError(
+        "Draft model speculative decoding currently only supports "
+        "CUDA and ROCm flash attention backend.") from err

 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
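
Below is a minimal, standalone sketch (not vLLM code; the module name broken_backend and the setup are hypothetical) of why the except clause needs to catch ImportError in addition to ModuleNotFoundError: a "from X import Y" statement raises a plain ImportError when the module itself is importable but the requested name (or a native extension it depends on) fails to load, and catching only ModuleNotFoundError lets that case escape the fallback path.

import sys
import types

# Register a module that imports fine but defines no symbols, simulating a
# partially installed flash attention package (hypothetical name).
sys.modules["broken_backend"] = types.ModuleType("broken_backend")

try:
    from broken_backend import FlashAttentionMetadata
except ModuleNotFoundError:
    print("not reached: the module exists, so no ModuleNotFoundError is raised")
except ImportError as err:
    # This is the case the widened except clause now handles, so the ROCm
    # fallback (or the explicit RuntimeError) runs instead of an unhandled
    # import failure surfacing later.
    print(f"caught ImportError: {err}")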