[Bugfix] Fix try-catch conditions to import correct Flash Attention Backend in Draft Model (#9101)
commit 23fea8714a
parent f4dd830e09
@@ -5,12 +5,17 @@ import torch
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.sampler import SamplerOutput
 
 try:
-    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-except ModuleNotFoundError:
-    # vllm_flash_attn is not installed, use the identical ROCm FA metadata
-    from vllm.attention.backends.rocm_flash_attn import (
-        ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+    try:
+        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
+    except (ModuleNotFoundError, ImportError):
+        # vllm_flash_attn is not installed, try the ROCm FA metadata
+        from vllm.attention.backends.rocm_flash_attn import (
+            ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+except (ModuleNotFoundError, ImportError) as err:
+    raise RuntimeError(
+        "Draft model speculative decoding currently only supports "
+        "CUDA and ROCm flash attention backend.") from err
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
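For context on why the except clause was widened: ModuleNotFoundError is a subclass of ImportError, so the old `except ModuleNotFoundError:` let a plain ImportError (module found but failing to load, e.g. a broken compiled extension or missing symbol) escape past the ROCm fallback. The sketch below is a self-contained illustration of the widened fallback shape; load_backend and the simulated error are hypothetical and not part of vLLM.

# Illustration only: mirrors the nested try/except shape of the patch.
def load_backend() -> str:
    try:
        try:
            # Simulate "module found, but the import still fails", which the
            # old `except ModuleNotFoundError:` did not catch.
            raise ImportError("flash_attn found but failed to load")
        except (ModuleNotFoundError, ImportError):
            # CUDA import failed for any reason: fall back to the ROCm backend.
            return "rocm_flash_attn"
    except (ModuleNotFoundError, ImportError) as err:
        # Both backends unavailable: surface one clear error for the user.
        raise RuntimeError("no CUDA or ROCm flash attention backend") from err

if __name__ == "__main__":
    assert issubclass(ModuleNotFoundError, ImportError)
    print(load_backend())  # -> rocm_flash_attn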