Support for guided decoding for offline LLM (#6878)
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent 825b044863
commit 654bc5ca49
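For orientation, the offline guided decoding API added by this commit can be exercised roughly as follows. This is a minimal sketch based on the tests below; the model name mirrors the test module, and the regex is a simplified stand-in rather than the fixture's full IPv4 pattern:

# Sketch only, not part of the commit.
from vllm import LLM, SamplingParams

llm = LLM(model="HuggingFaceH4/zephyr-7b-beta", max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# guided_options_request accepts exactly one guided decoding option,
# e.g. guided_regex, guided_json, guided_choice, or guided_grammar.
outputs = llm.generate(
    prompts="Give an example IPv4 address: ",
    sampling_params=sampling_params,
    guided_options_request=dict(guided_regex=r"\d{1,3}(\.\d{1,3}){3}"))
print(outputs[0].outputs[0].text)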
@@ -111,6 +111,7 @@ autodoc_mock_imports = [
     "tqdm",
     "tensorizer",
     "pynvml",
+    "outlines",
 ]

 for mock_target in autodoc_mock_imports:
@@ -1,6 +1,26 @@
 import pytest


+@pytest.fixture
+def sample_prompts():
+    return [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+
+@pytest.fixture
+def sample_token_ids():
+    return [
+        [0],
+        [0, 1],
+        [0, 2, 1],
+        [0, 3, 1, 2],
+    ]
+
+
 @pytest.fixture
 def sample_regex():
     return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
@@ -66,4 +86,4 @@ column: "col_1" | "col_2"
 table: "table_1" | "table_2"
 condition: column "=" number
 number: "1" | "2"
-""")
+""")
tests/entrypoints/llm/test_guided_generate.py (new file, 142 lines)
@@ -0,0 +1,142 @@
import json
import re
import weakref

import jsonschema
import pytest

from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

from ...conftest import cleanup

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


@pytest.fixture(scope="module")
def llm():
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(model=MODEL_NAME, max_model_len=1024)

    with llm.deprecate_legacy_api():
        yield weakref.proxy(llm)
        del llm
    cleanup()


@pytest.mark.skip_global_cleanup
def test_guided_regex(sample_regex, llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )
    outputs = llm.generate(
        prompts=[
            f"Give an example IPv4 address with this regex: {sample_regex}"
        ] * 2,
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options_request=dict(guided_regex=sample_regex))

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert re.fullmatch(sample_regex, generated_text) is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_json_completion(sample_json_schema, llm):
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
    )
    outputs = llm.generate(
        prompts=[
            f"Give an example JSON for an employee profile "
            f"that fits this schema: {sample_json_schema}"
        ] * 2,
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options_request=dict(guided_json=sample_json_schema))

    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=sample_json_schema)


@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(sample_guided_choice, llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )
    outputs = llm.generate(
        prompts="The best language for type-safe systems programming is ",
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options_request=dict(guided_choice=sample_guided_choice))

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert generated_text in sample_guided_choice
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_grammar(sample_sql_statements, llm):

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=1000,
    )
    outputs = llm.generate(
        prompts=("Generate a sql state that select col_1 from "
                 "table_1 where it is equals to 1"),
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options_request=dict(guided_grammar=sample_sql_statements))

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        # use Lark to parse the output, and make sure it's a valid parse tree
        from lark import Lark
        parser = Lark(sample_sql_statements)
        parser.parse(generated_text)

        # remove spaces for comparison b/c we removed them in the grammar
        ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
            " ", "")

        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@@ -10,6 +10,9 @@ from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
                          parse_and_batch_prompt)
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.model_executor.guided_decoding import (
+    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
+from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -262,6 +265,8 @@ class LLM:
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        guided_options_request: Optional[Union[LLMGuidedOptions,
+                                               GuidedDecodingRequest]] = None
     ) -> List[RequestOutput]:
         """Generates the completions for the input prompts.

@@ -303,6 +308,14 @@
         else:
             inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

+        if isinstance(guided_options_request, dict):
+            if len(guided_options_request) > 1:
+                raise ValueError(
+                    "You can only use one guided decoding but multiple is "
+                    f"specified: {guided_options_request}")
+            guided_options_request = GuidedDecodingRequest(
+                **guided_options_request)
+
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()
@@ -311,7 +324,8 @@
             inputs=inputs,
             params=sampling_params,
             lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+            guided_options=guided_options_request)

         outputs = self._run_engine(use_tqdm=use_tqdm)
         return LLMEngine.validate_outputs(outputs, RequestOutput)
@@ -508,6 +522,7 @@ class LLM:
                       Sequence[PoolingParams]],
         lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
         prompt_adapter_request: Optional[PromptAdapterRequest],
+        guided_options: Optional[GuidedDecodingRequest] = None,
     ) -> None:
         if isinstance(inputs, (str, dict)):
             # Convert a single prompt to a list.
@@ -523,6 +538,15 @@
                 raise ValueError("The lengths of prompts and lora_request "
                                  "must be the same.")

+        if isinstance(params, list):
+            params = [
+                self._add_guided_processor(param, guided_options)
+                if isinstance(param, SamplingParams) else param
+                for param in params
+            ]
+        elif isinstance(params, SamplingParams):
+            params = self._add_guided_processor(params, guided_options)
+
         # Add requests to the engine.
         for i, request_inputs in enumerate(inputs):
             self._add_request(
@@ -548,6 +572,24 @@
                 lora_request=lora_request,
                 prompt_adapter_request=prompt_adapter_request)

+    def _add_guided_processor(
+            self,
+            params: SamplingParams,
+            guided_options: Optional[GuidedDecodingRequest] = None):
+        if guided_options:
+            if guided_options.guided_decoding_backend is None:
+                decoding_config = self.llm_engine.get_decoding_config()
+                guided_options.guided_decoding_backend = (
+                    decoding_config.guided_decoding_backend)
+            guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa
+                guided_options.guided_decoding_backend, guided_options,
+                self.get_tokenizer())
+            if guided_logits_processor:
+                if params.logits_processors is None:
+                    params.logits_processors = []
+                params.logits_processors.append(guided_logits_processor)
+        return params
+
     def _run_engine(
             self, *, use_tqdm: bool
     ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
@@ -1,6 +1,7 @@
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
 import time
+from argparse import Namespace
 from typing import Any, Dict, List, Literal, Optional, Union

 import torch
@@ -14,6 +15,23 @@ from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import LogitsProcessor, SamplingParams
 from vllm.utils import random_uuid

+# torch is mocked during docs generation,
+# so we have to provide the values as literals
+_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
+
+try:
+    from sphinx.ext.autodoc.mock import _MockModule
+
+    if isinstance(torch, _MockModule):
+        _LONG_INFO = _MOCK_LONG_INFO
+    else:
+        _LONG_INFO = torch.iinfo(torch.long)
+except ModuleNotFoundError:
+    _LONG_INFO = torch.iinfo(torch.long)
+
+assert _LONG_INFO.min == _MOCK_LONG_INFO.min
+assert _LONG_INFO.max == _MOCK_LONG_INFO.max
+

 class OpenAIBaseModel(BaseModel):
     # OpenAI API does not allow extra fields
@@ -108,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     n: Optional[int] = 1
     presence_penalty: Optional[float] = 0.0
     response_format: Optional[ResponseFormat] = None
-    seed: Optional[int] = Field(None,
-                                ge=torch.iinfo(torch.long).min,
-                                le=torch.iinfo(torch.long).max)
+    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
@@ -327,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel):
     max_tokens: Optional[int] = 16
     n: int = 1
     presence_penalty: Optional[float] = 0.0
-    seed: Optional[int] = Field(None,
-                                ge=torch.iinfo(torch.long).min,
-                                le=torch.iinfo(torch.long).max)
+    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
@@ -3,9 +3,10 @@ from typing import Optional, Union
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
     CompletionRequest)
-from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
-    get_lm_format_enforcer_guided_decoding_logits_processor)
+from vllm.model_executor.guided_decoding.guided_fields import (
+    GuidedDecodingRequest)
 from vllm.model_executor.guided_decoding.outlines_decoding import (
+    get_local_outlines_guided_decoding_logits_processor,
     get_outlines_guided_decoding_logits_processor)
 from vllm.sampling_params import LogitsProcessor

@@ -20,6 +21,8 @@ async def get_guided_decoding_logits_processor(
         return await get_outlines_guided_decoding_logits_processor(
             request, tokenizer)
     if guided_decoding_backend == 'lm-format-enforcer':
+        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
+            get_lm_format_enforcer_guided_decoding_logits_processor)
         return await get_lm_format_enforcer_guided_decoding_logits_processor(
             request, tokenizer)

@@ -28,6 +31,25 @@ async def get_guided_decoding_logits_processor(
         "Must be one of 'outlines, 'lm-format-enforcer'")


+def get_local_guided_decoding_logits_processor(
+        guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
+        tokenizer) -> Optional[LogitsProcessor]:
+    # request = _adapt_request_for_tool_use(request)
+
+    if guided_decoding_backend == 'outlines':
+        return get_local_outlines_guided_decoding_logits_processor(
+            guided_options, tokenizer)
+    if guided_decoding_backend == 'lm-format-enforcer':
+        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
+            get_local_lm_format_enforcer_guided_decoding_logits_processor)
+        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
+            guided_options, tokenizer)
+
+    raise ValueError(
+        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
+        "Must be one of 'outlines, 'lm-format-enforcer'")
+
+
 def _adapt_request_for_tool_use(request: Union[CompletionRequest,
                                                ChatCompletionRequest]):
     # the legacy completion API does not support tool use
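For reference, the new local entry point added above can also be driven directly. A minimal sketch, mirroring what LLM._add_guided_processor does; attach_guided_processor is a hypothetical helper name and the tokenizer is assumed to be a Hugging Face tokenizer (not part of the commit):

# Sketch only, not part of the commit.
from vllm import SamplingParams
from vllm.model_executor.guided_decoding import (
    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)


def attach_guided_processor(params: SamplingParams, tokenizer) -> SamplingParams:
    # Build a choice-constrained logits processor via the 'outlines' backend
    # and append it to the request's sampling params.
    opts = GuidedDecodingRequest(guided_choice=["Rust", "Python", "C++"])
    processor = get_local_guided_decoding_logits_processor(
        "outlines", opts, tokenizer)
    if processor:
        if params.logits_processors is None:
            params.logits_processors = []
        params.logits_processors.append(processor)
    return params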
vllm/model_executor/guided_decoding/guided_fields.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, TypedDict, Union

from pydantic import BaseModel


class LLMGuidedOptions(TypedDict, total=False):
    guided_json: Union[Dict, BaseModel, str]
    guided_regex: str
    guided_choice: List[str]
    guided_grammar: str
    guided_decoding_backend: str
    guided_whitespace_pattern: str
    guided_json_object: bool


@dataclass
class GuidedDecodingRequest:
    """One of the fields will be used to retrieve the logit processor."""
    guided_json: Optional[Union[Dict, BaseModel, str]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    guided_grammar: Optional[str] = None
    guided_decoding_backend: Optional[str] = None
    guided_whitespace_pattern: Optional[str] = None
    guided_json_object: Optional[bool] = None

    def __post_init__(self):
        """Validate that some fields are mutually exclusive."""
        guide_count = sum([
            self.guided_json is not None, self.guided_regex is not None,
            self.guided_choice is not None, self.guided_grammar is not None,
            self.guided_json_object is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding but multiple are "
                f"specified: {self.__dict__}")
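As a quick illustration of the mutual exclusion enforced by __post_init__ above (a sketch, not part of the commit):

# Sketch only: at most one guided decoding kind may be set.
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest)

# A single kind of guide is accepted.
request = GuidedDecodingRequest(guided_regex=r"\d+")

# Two kinds at once raise a ValueError from __post_init__.
try:
    GuidedDecodingRequest(guided_regex=r"\d+", guided_choice=["a", "b"])
except ValueError as exc:
    print(exc)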
@@ -12,7 +12,10 @@ from transformers import PreTrainedTokenizerBase

 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest)
+from vllm.model_executor.guided_decoding.guided_fields import (
+    GuidedDecodingRequest)
 from vllm.model_executor.guided_decoding.outlines_decoding import (
+    get_local_outlines_guided_decoding_logits_processor,
     get_outlines_guided_decoding_logits_processor)
 from vllm.sampling_params import LogitsProcessor

@@ -54,6 +57,42 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
     return logits_processor


+def get_local_lm_format_enforcer_guided_decoding_logits_processor(
+        guided_options: GuidedDecodingRequest,
+        tokenizer) -> Optional[LogitsProcessor]:
+    """
+    Given an OpenAI-compatible request, check for guided decoding parameters
+    and get the necessary logits processor for the given guide.
+    We cache logit processors by (guide, tokenizer), and on cache hit
+    we make a shallow copy to reuse the same underlying FSM.
+    """
+
+    tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
+        tokenizer)
+    character_level_parser: CharacterLevelParser
+    if guided_options.guided_json:
+        schema = _normalize_json_schema_object(guided_options.guided_json)
+        character_level_parser = JsonSchemaParser(schema)
+    elif guided_options.guided_choice:
+        character_level_parser = UnionParser(
+            [StringParser(choice) for choice in guided_options.guided_choice])
+    elif guided_options.guided_regex:
+        character_level_parser = RegexParser(guided_options.guided_regex)
+    elif guided_options.guided_grammar:
+        # CFG grammar not supported by LMFE, revert to outlines
+        return get_local_outlines_guided_decoding_logits_processor(
+            guided_options, tokenizer)
+    elif guided_options.guided_json_object:
+        # None means any json object
+        character_level_parser = JsonSchemaParser(None)
+    else:
+        return None
+
+    logits_processor = build_vllm_logits_processor(tokenizer_data,
+                                                   character_level_parser)
+    return logits_processor
+
+
 def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
     if isinstance(schema, str):
         return json_loads(schema)
@@ -10,6 +10,8 @@ from transformers import PreTrainedTokenizerBase

 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest)
+from vllm.model_executor.guided_decoding.guided_fields import (
+    GuidedDecodingRequest)
 from vllm.model_executor.guided_decoding.outlines_logits_processors import (
     CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)

@@ -77,8 +79,27 @@ async def get_outlines_guided_decoding_logits_processor(
         mode, request.guided_whitespace_pattern)


+def get_local_outlines_guided_decoding_logits_processor(
+    guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase
+) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
+           None]:
+    """
+    Given an OpenAI-compatible request, check for guided decoding parameters
+    and get the necessary logits processor for the given guide.
+    We cache logit processors by (guide, tokenizer), and on cache hit
+    we make a shallow copy to reuse the same underlying FSM.
+    """
+    guide, mode = _get_guide_and_mode(guided_options)
+    if not guide or not mode:
+        return None
+
+    return _get_logits_processor(guide, tokenizer, mode,
+                                 guided_options.guided_whitespace_pattern)
+
+
 def _get_guide_and_mode(
-    request: Union[CompletionRequest, ChatCompletionRequest]
+    request: Union[CompletionRequest, ChatCompletionRequest,
+                   GuidedDecodingRequest]
 ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:

     if request.guided_json:
@@ -102,7 +123,8 @@ def _get_guide_and_mode(
         return choices_regex, GuidedDecodingMode.CHOICE
     elif request.guided_grammar:
         return request.guided_grammar, GuidedDecodingMode.GRAMMAR
-    elif (request.response_format is not None
+    elif (not isinstance(request, GuidedDecodingRequest)
+          and request.response_format is not None
           and request.response_format.type == "json_object"):
         return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
     else: