youkaichao 4fd9375028
[2/N][torch.compile] make compilation cfg part of vllm cfg (#10383)
Signed-off-by: youkaichao <youkaichao@gmail.com>
2024-11-16 18:02:14 -08:00


import os
from typing import TYPE_CHECKING

import torch

from vllm.plugins import set_torch_compile_backend

from .interface import Platform, PlatformEnum

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None

# TPU execution runs through torch.compile with the OpenXLA backend.
set_torch_compile_backend("openxla")


class TpuPlatform(Platform):
    _enum = PlatformEnum.TPU

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        from vllm.config import CompilationLevel
        compilation_config = vllm_config.compilation_config
        # Default to DYNAMO_ONCE unless the user explicitly chose a level
        # via the VLLM_TORCH_COMPILE_LEVEL environment variable.
        if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
            compilation_config.level = CompilationLevel.DYNAMO_ONCE
        # PIECEWISE compilation relies on Inductor, which TPU cannot use.
        assert compilation_config.level < CompilationLevel.PIECEWISE, \
            "TPU does not support Inductor."