[Hardware][Intel GPU] add XPU bf16 support (#12392)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
parent
f8ece6e17f
commit
f256ebe4df
@ -36,7 +36,7 @@ VLLM_TARGET_DEVICE=xpu python setup.py install
|
|||||||
|
|
||||||
:::{note}
|
:::{note}
|
||||||
- FP16 is the default data type in the current XPU backend. The BF16 data
|
- FP16 is the default data type in the current XPU backend. The BF16 data
|
||||||
type will be supported in the future.
|
type is supported on Intel Data Center GPUs, but is not yet supported on Intel Arc GPUs.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
## Set up using Docker
|
## Set up using Docker
|
||||||
|
@ -66,8 +66,13 @@ class XPUPlatform(Platform):
|
|||||||
# check and update model config
|
# check and update model config
|
||||||
model_config = vllm_config.model_config
|
model_config = vllm_config.model_config
|
||||||
if model_config.dtype == torch.bfloat16:
|
if model_config.dtype == torch.bfloat16:
|
||||||
|
bf16_supported = cls.device_support_bf16()
|
||||||
|
if not bf16_supported:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"bfloat16 is not fully supported on XPU, casting to float16.")
|
"bfloat16 is only supported on Intel Data Center GPU, "
|
||||||
|
"Intel Arc GPU is not supported yet. Your device is %s,"
|
||||||
|
"which is not supported. will fallback to float16",
|
||||||
|
cls.get_device_name())
|
||||||
model_config.dtype = torch.float16
|
model_config.dtype = torch.float16
|
||||||
if not model_config.enforce_eager:
|
if not model_config.enforce_eager:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -116,3 +121,15 @@ class XPUPlatform(Platform):
|
|||||||
) -> float:
|
) -> float:
|
||||||
torch.xpu.reset_peak_memory_stats(device)
|
torch.xpu.reset_peak_memory_stats(device)
|
||||||
return torch.xpu.max_memory_allocated(device)
|
return torch.xpu.max_memory_allocated(device)
|
||||||
|
|
||||||
|
@classmethod
def device_support_bf16(cls) -> bool:
    """Return whether the current XPU device is treated as bf16-capable.

    The decision is based only on the device name reported by
    ``cls.get_device_name()``, matched case-insensitively:

    * a name containing ``"arc"`` -> ``False``
    * a name containing ``"data center gpu"`` -> ``True``
    * anything else -> ``False``, after logging a warning

    The ``"arc"`` check runs first, so it wins if both substrings occur.

    Returns:
        ``True`` only for names that contain ``"data center gpu"`` and
        do not contain ``"arc"``.
    """
    device_name = cls.get_device_name()
    # Match on a lower-cased copy; keep the original for the log message so
    # the warning shows the name exactly as the device reports it.
    normalized = device_name.lower()
    if "arc" in normalized:
        return False
    if "data center gpu" in normalized:
        return True
    logger.warning("Unknown device name %s, always use float16",
                   device_name)
    return False
|
||||||
|
Loading…
x
Reference in New Issue
Block a user