[Bugfix] Only require XGrammar on x86 (#10865)

Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
Michael Goin 2024-12-03 13:32:21 -05:00 committed by GitHub
parent 2f2cdc745a
commit 7090c27bb2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 36 additions and 3 deletions

View File

@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.9, < 0.11 lm-format-enforcer >= 0.10.9, < 0.11
outlines >= 0.0.43, < 0.1 outlines >= 0.0.43, < 0.1
xgrammar xgrammar >= 0.1.5; platform_machine == "x86_64"
typing_extensions >= 4.10 typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser # used for parsing partial JSON outputs partial-json-parser # used for parsing partial JSON outputs

View File

@ -3,6 +3,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
@ -25,6 +26,12 @@ def maybe_backend_fallback(
guided_params.backend = "xgrammar" guided_params.backend = "xgrammar"
if guided_params.backend == "xgrammar": if guided_params.backend == "xgrammar":
# xgrammar only has x86 wheels for linux, fallback to outlines
if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
logger.warning("xgrammar is only supported on x86 CPUs. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
# xgrammar doesn't support regex or choice, fallback to outlines # xgrammar doesn't support regex or choice, fallback to outlines
if guided_params.regex is not None or guided_params.choice is not None: if guided_params.regex is not None or guided_params.choice is not None:
logger.warning( logger.warning(

View File

@ -1,5 +1,5 @@
from .interface import _Backend # noqa: F401 from .interface import _Backend # noqa: F401
from .interface import Platform, PlatformEnum, UnspecifiedPlatform from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform
current_platform: Platform current_platform: Platform
@ -120,4 +120,4 @@ elif is_openvino:
else: else:
current_platform = UnspecifiedPlatform() current_platform = UnspecifiedPlatform()
__all__ = ['Platform', 'PlatformEnum', 'current_platform'] __all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum']

View File

@ -1,4 +1,5 @@
import enum import enum
import platform
import random import random
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
@ -37,6 +38,14 @@ class PlatformEnum(enum.Enum):
UNSPECIFIED = enum.auto() UNSPECIFIED = enum.auto()
class CpuArchEnum(enum.Enum):
X86 = enum.auto()
ARM = enum.auto()
POWERPC = enum.auto()
OTHER = enum.auto()
UNKNOWN = enum.auto()
class DeviceCapability(NamedTuple): class DeviceCapability(NamedTuple):
major: int major: int
minor: int minor: int
@ -184,6 +193,23 @@ class Platform:
f"{quant} quantization is currently not supported in " f"{quant} quantization is currently not supported in "
f"{cls.device_name}.") f"{cls.device_name}.")
@classmethod
def get_cpu_architecture(cls) -> CpuArchEnum:
"""
Determine the CPU architecture of the current system.
Returns CpuArchEnum indicating the architecture type.
"""
machine = platform.machine().lower()
if machine in ("x86_64", "amd64", "i386", "i686"):
return CpuArchEnum.X86
elif machine.startswith("arm") or machine.startswith("aarch"):
return CpuArchEnum.ARM
elif machine.startswith("ppc"):
return CpuArchEnum.POWERPC
return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN
class UnspecifiedPlatform(Platform): class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED _enum = PlatformEnum.UNSPECIFIED