diff --git a/requirements/common.txt b/requirements/common.txt
index 2d52858a..14084b79 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -18,7 +18,7 @@ pillow  # Required for image processing
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0  # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.11, < 0.11
-llguidance >= 0.7.2, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
+llguidance >= 0.7.9, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
 outlines == 0.1.11
 lark == 1.2.2
 xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"
diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py
index d99ae59d..6bdfa0fa 100644
--- a/tests/v1/entrypoints/llm/test_struct_output_generate.py
+++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -13,7 +13,7 @@ from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
-GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
+GUIDED_DECODING_BACKENDS_V1 = ["xgrammar", "guidance"]
 MODELS_TO_TEST = [
     "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
 ]
@@ -30,12 +30,13 @@ def test_guided_json_completion(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=sample_json_schema,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
     outputs = llm.generate(prompts=[
         f"Give an example JSON for an employee profile "
         f"that fits this schema: {sample_json_schema}"
     ] * 2,
@@ -111,13 +112,14 @@ def test_guided_json_object(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=100,
-                                     n=2,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json_object=True,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=100,
+        n=2,
+        guided_decoding=GuidedDecodingParams(json_object=True))
 
     outputs = llm.generate(
         prompts=("Generate a JSON object with curly braces for a person with "
@@ -137,12 +139,20 @@ def test_guided_json_object(
 
         # Parse to verify it is valid JSON
         parsed_json = json.loads(generated_text)
-        assert isinstance(parsed_json, dict)
+        allowed_types: tuple[type, ...] = (dict, )
+        if guided_decoding_backend == "xgrammar":
+            # TODO - we are currently too permissive with xgrammar and
+            # allow any valid json (typically comes back as a list or
+            # object). We can fix this by specifying a jsonschema of
+            # {"type": "object"}, but we need this fix in a release
+            # first: https://github.com/mlc-ai/xgrammar/pull/264
+            allowed_types = (dict, list)
+        assert isinstance(parsed_json, allowed_types)
 
 
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
+                         GUIDED_DECODING_BACKENDS_V1 + ["auto"])
 @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
 def test_guided_json_unsupported_schema(
     monkeypatch: pytest.MonkeyPatch,
@@ -151,21 +161,43 @@ def test_guided_json_unsupported_schema(
     unsupported_json_schema: dict[str, Any],
     guided_decoding_backend: str,
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=unsupported_json_schema,
-                                         backend=guided_decoding_backend))
-    with pytest.raises(ValueError,
-                       match="The provided JSON schema contains features "
-                       "not supported by xgrammar."):
-        llm.generate(prompts=[
-            f"Give an example JSON for an employee profile "
-            f"that fits this schema: {unsupported_json_schema}"
-        ] * 2,
-                     sampling_params=sampling_params,
-                     use_tqdm=True)
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
+    if guided_decoding_backend == "xgrammar":
+        with pytest.raises(ValueError,
+                           match="The provided JSON schema contains features "
+                           "not supported by xgrammar."):
+            llm.generate(prompts=[
+                f"Give an example JSON for an employee profile "
+                f"that fits this schema: {unsupported_json_schema}"
+            ] * 2,
+                         sampling_params=sampling_params,
+                         use_tqdm=True)
+    else:
+        # This should work for both "guidance" and "auto".
+
+        outputs = llm.generate(
+            prompts=("Give an example JSON object for a grade "
+                     "that fits this schema: "
+                     f"{unsupported_json_schema}"),
+            sampling_params=sampling_params,
+            use_tqdm=True)
+        assert outputs is not None
+        for output in outputs:
+            assert output is not None
+            assert isinstance(output, RequestOutput)
+            generated_text = output.outputs[0].text
+            assert generated_text is not None
+            print(generated_text)
+
+            # Parse to verify it is valid JSON
+            parsed_json = json.loads(generated_text)
+            assert isinstance(parsed_json, dict)
 
 
 @pytest.mark.skip_global_cleanup
@@ -179,13 +211,14 @@ def test_guided_grammar_ebnf(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar=sample_sql_ebnf,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
     outputs = llm.generate(
         prompts=("Generate a sql statement that selects col_1 from "
                  "table_1 where it is equal to 1"),
@@ -222,13 +255,14 @@ def test_guided_grammar_lark(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar=sample_sql_lark,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
     outputs = llm.generate(
         prompts=("Generate a sql statement that selects col_1 from "
                  "table_1 where it is equal to 1"),
@@ -269,16 +303,15 @@ def test_guided_grammar_ebnf_invalid(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar="not a grammar",
-                                         backend=guided_decoding_backend))
-    with pytest.raises(ValueError,
-                       match="Failed to convert the grammar "
-                       "from Lark to EBNF."):
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
+    with pytest.raises(ValueError, match="Failed to convert the grammar "):
         llm.generate(
             prompts=("Generate a sql statement that selects col_1 from "
                      "table_1 where it is equal to 1"),
@@ -298,12 +331,13 @@ def test_guided_regex(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         regex=sample_regex,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(regex=sample_regex))
     outputs = llm.generate(
         prompts=[
             f"Give an example IPv4 address with this regex: {sample_regex}"
@@ -335,12 +369,13 @@ def test_guided_choice_completion(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         choice=sample_guided_choice,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
     outputs = llm.generate(
         prompts="The best language for type-safe systems programming is ",
         sampling_params=sampling_params,
diff --git a/vllm/config.py b/vllm/config.py
index a2e83af3..7390ec59 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2800,12 +2800,17 @@ class DecodingConfig:
         return hash_str
 
     def __post_init__(self):
-        valid_guided_backends = [
-            'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance'
+        v0_valid_guided_backends = [
+            'outlines', 'lm-format-enforcer', 'xgrammar'
         ]
+        v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto']
 
         backend = GuidedDecodingParams(
             backend=self.guided_decoding_backend).backend_name
+        if envs.VLLM_USE_V1:
+            valid_guided_backends = v1_valid_guided_backends
+        else:
+            valid_guided_backends = v0_valid_guided_backends
         if backend not in valid_guided_backends:
             raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
                              f" must be one of {valid_guided_backends}")
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 38a47a84..80fcbec6 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -391,16 +391,13 @@ class EngineArgs:
             default='xgrammar',
             help='Which engine will be used for guided decoding'
             ' (JSON schema / regex etc) by default. Currently support '
-            'https://github.com/outlines-dev/outlines, '
-            'https://github.com/mlc-ai/xgrammar, and '
-            'https://github.com/noamgat/lm-format-enforcer.'
-            ' Can be overridden per request via guided_decoding_backend'
-            ' parameter.\n'
-            'Backend-specific options can be supplied in a comma-separated '
-            'list following a colon after the backend name. Valid backends and '
-            'all available options are: [xgrammar:no-fallback, '
-            'xgrammar:disable-any-whitespace, '
-            'outlines:no-fallback, lm-format-enforcer:no-fallback]')
+            'https://github.com/mlc-ai/xgrammar and '
+            'https://github.com/guidance-ai/llguidance. '
+            'Valid backend values are "xgrammar", "guidance", and "auto". '
+            'With "auto", we will make opinionated choices based on request '
+            'contents and what the backend libraries currently support, so '
+            'the behavior is subject to change in each release. '
+            'The default is xgrammar.')
         parser.add_argument(
             '--logits-processor-pattern',
             type=nullable_str,
@@ -1539,9 +1536,9 @@ class EngineArgs:
                                  recommend_to_remove=False)
             return False
 
-        # Only support Xgrammar for guided decoding so far.
+        # Xgrammar and Guidance are supported.
         SUPPORTED_GUIDED_DECODING = [
-            "xgrammar", "xgrammar:disable-any-whitespace"
+            "xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
         ]
         if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
             _raise_or_fallback(feature_name="--guided-decoding-backend",
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 8ba06336..ffd12d5f 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -4,7 +4,6 @@ import time
 from collections.abc import Mapping
 from typing import Optional, Union
 
-import vllm.platforms
 from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                          PromptType, SingletonInputsAdapter)
@@ -20,7 +19,10 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.structured_output.utils import validate_structured_output_request
+from vllm.v1.structured_output.backend_guidance import (
+    validate_guidance_grammar)
+from vllm.v1.structured_output.utils import (
+    validate_structured_output_request_xgrammar)
 
 
 class Processor:
@@ -120,7 +122,9 @@ class Processor:
         if not params.guided_decoding or not self.decoding_config:
             return
 
-        supported_backends = ["xgrammar", "xgrammar:disable-any-whitespace"]
+        supported_backends = [
+            "xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
+        ]
         engine_level_backend = self.decoding_config.guided_decoding_backend
         if engine_level_backend not in supported_backends:
             raise ValueError(f"Only {supported_backends} structured output is "
@@ -134,10 +138,31 @@ class Processor:
         else:
             params.guided_decoding.backend = engine_level_backend
 
-        if vllm.platforms.current_platform.is_tpu():
-            raise ValueError("Structured output is not supported on TPU.")
+        # Request content validation
 
-        validate_structured_output_request(params)
+        if engine_level_backend == "xgrammar":
+            # xgrammar with no fallback
+            validate_structured_output_request_xgrammar(params)
+            params.guided_decoding.backend = "xgrammar"
+        elif engine_level_backend == "auto":
+            # "auto" is an opt-in to opinionated behavior where we try to
+            # choose a backend based on request contents. This is not the
+            # default as it is less predictable and subject to change
+            # between releases as feature support changes.
+            try:
+                validate_structured_output_request_xgrammar(params)
+                params.guided_decoding.backend = "xgrammar"
+            except ValueError:
+                # The request includes some jsonschema feature(s) that
+                # are not supported in xgrammar. Fall back to guidance.
+                params.guided_decoding.backend = "guidance"
+
+        if params.guided_decoding.backend == "guidance":
+            # TODO ideally we would have the LLTokenizer here as Lark syntax
+            # allows <|special_token|> and similar, see
+            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
+            # Without a tokenizer these are disallowed in grammars.
+            validate_guidance_grammar(params, tokenizer=None)
 
     def process_inputs(
         self,
diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py
index 0fdc45c2..6c6a8a7b 100644
--- a/vllm/v1/structured_output/__init__.py
+++ b/vllm/v1/structured_output/__init__.py
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Optional
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.v1.structured_output.backend_guidance import GuidanceBackend
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                      StructuredOutputGrammar)
 
@@ -50,6 +51,8 @@ class StructuredOutputManager:
                 XgrammarBackend)
 
             self.backend = XgrammarBackend(self.vllm_config)
+        elif backend_name == "guidance":
+            self.backend = GuidanceBackend(self.vllm_config)
         else:
             raise ValueError(
                 f"Unsupported structured output backend: {backend_name}")
diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py
new file mode 100644
index 00000000..1e274ad0
--- /dev/null
+++ b/vllm/v1/structured_output/backend_guidance.py
@@ -0,0 +1,164 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+import torch
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.sampling_params import SamplingParams
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
+from vllm.utils import LazyLoader
+from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
+                                                     StructuredOutputGrammar,
+                                                     StructuredOutputOptions)
+from vllm.v1.structured_output.request import get_structured_output_key
+
+if TYPE_CHECKING:
+    import llguidance
+    import llguidance.hf as llguidance_hf
+    import llguidance.torch as llguidance_torch
+else:
+    llguidance = LazyLoader("llguidance", globals(), "llguidance")
+    llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf")
+    llguidance_torch = LazyLoader("llguidance.torch", globals(),
+                                  "llguidance.torch")
+
+logger = init_logger(__name__)
+
+
+class GuidanceBackend(StructuredOutputBackend):
+
+    def __init__(self, vllm_config: VllmConfig):
+        self.vllm_config = vllm_config
+        tokenizer_group = init_tokenizer_from_configs(
+            model_config=vllm_config.model_config,
+            scheduler_config=vllm_config.scheduler_config,
+            parallel_config=vllm_config.parallel_config,
+            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
+        tokenizer_group.ping()
+        self.vllm_config = vllm_config
+        self.vocab_size = vllm_config.model_config.get_vocab_size()
+
+        tokenizer = tokenizer_group.get_lora_tokenizer(None)
+        self.ll_tokenizer = llguidance_hf.from_tokenizer(tokenizer, None)
+
+    def compile_grammar(self, request_type: StructuredOutputOptions,
+                        grammar_spec: str) -> StructuredOutputGrammar:
+        self.serialized_grammar = serialize_guidance_grammar(
+            request_type, grammar_spec)
+
+        ll_matcher = llguidance.LLMatcher(
+            self.ll_tokenizer,
+            self.serialized_grammar,
+            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
+        )
+
+        r = GuidanceGrammar(
+            ll_matcher=ll_matcher,
+            ll_tokenizer=self.ll_tokenizer,
+            vocab_size=self.vocab_size,
+        )
+
+        r.check_error()
+        return r
+
+    def allocate_token_bitmask(self, max_num_seqs: int):
+        return llguidance_torch.allocate_token_bitmask(
+            max_num_seqs, self.ll_tokenizer.vocab_size)
+
+
+@dataclass
+class GuidanceGrammar(StructuredOutputGrammar):
+    ll_matcher: llguidance.LLMatcher
+    ll_tokenizer: llguidance.LLTokenizer
+    vocab_size: int
+    printed_error: bool = False
+    terminated: bool = False
+
+    def check_error(self):
+        if not self.printed_error:
+            err = self.ll_matcher.get_error()
+            if err:
+                self.printed_error = True
+                logger.warning("LLMatcher error: %s", err)
+
+    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
+        """Accepts a list of tokens and advances the parser.
+
+        Returns True if the parser was advanced successfully.
+        Returns False if the parser failed to advance.
+        """
+
+        if self.ll_tokenizer.eos_token in tokens:
+            self.terminated = True
+
+        if self.ll_matcher.is_stopped():
+            return True
+
+        # TODO - Add jump decoding support in the future:
+        # self.ll_matcher.compute_ff_bytes() - this should always work
+        # self.ll_matcher.compute_ff_tokens() - this only works for
+        # "canonical" tokenizers
+        # For conversion between the two, see
+        # https://github.com/guidance-ai/llguidance/blob/main/docs/fast_forward.md
+
+        r = self.ll_matcher.consume_tokens(tokens)
+
+        self.check_error()
+
+        return r
+
+    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
+        # this will automatically return [EOS] mask if the matcher is stopped
+        # or otherwise in an error state
+        llguidance_torch.fill_next_token_bitmask(self.ll_matcher, bitmask, idx)
+        self.check_error()
+
+    def is_terminated(self) -> bool:
+        return self.terminated
+
+    def reset(self):
+        # TODO - this method may not be needed anymore
+        self.ll_matcher.reset()
+
+
+def serialize_guidance_grammar(request_type: StructuredOutputOptions,
+                               grammar_spec: str) -> str:
+    if request_type == StructuredOutputOptions.JSON:
+        # TODO: make whitespace_flexible configurable
+        return llguidance.LLMatcher.grammar_from_json_schema(
+            grammar_spec, defaults={
+                "whitespace_flexible": True,
+            })
+    elif request_type == StructuredOutputOptions.JSON_OBJECT:
+        return llguidance.LLMatcher.grammar_from_json_schema(
+            '{"type": "object"}', defaults={
+                "whitespace_flexible": True,
+            })
+    else:
+        if request_type == StructuredOutputOptions.REGEX:
+            tp = "regex"
+        elif request_type == StructuredOutputOptions.GRAMMAR:
+            tp = "grammar"
+        elif request_type == StructuredOutputOptions.CHOICE:
+            tp = "choice"
+        else:
+            logger.error("Validation should have already occurred. "
+                         "Please file an issue.")
+            raise ValueError("grammar is not of valid supported types. "
+                             f"({request_type!s})")
+        return llguidance.grammar_from(tp, grammar_spec)
+
+
+def validate_guidance_grammar(
+        sampling_params: SamplingParams,
+        tokenizer: Optional[llguidance.LLTokenizer] = None) -> None:
+    tp, grm = get_structured_output_key(sampling_params)
+    guidance_grm = serialize_guidance_grammar(tp, grm)
+    err = llguidance.LLMatcher.validate_grammar(guidance_grm,
+                                                tokenizer=tokenizer)
+    if err:
+        raise ValueError(f"Grammar error: {err}")
diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py
index 718fa583..9e54b8bf 100644
--- a/vllm/v1/structured_output/request.py
+++ b/vllm/v1/structured_output/request.py
@@ -53,25 +53,30 @@ class StructuredOutputRequest:
 
     @functools.cached_property
     def structured_output_key(self) -> StructuredOutputKey:
-        params = self.sampling_params.guided_decoding
-        assert params is not None, "params can't be None."
-        if params.json is not None:
-            if not isinstance(params.json, str):
-                json_str = json.dumps(params.json)
-            else:
-                json_str = params.json
-            return (StructuredOutputOptions.JSON, json_str)
-        elif params.json_object:
-            return (StructuredOutputOptions.JSON_OBJECT, "")
-        elif params.regex is not None:
-            return (StructuredOutputOptions.REGEX, params.regex)
-        elif params.choice is not None:
-            if not isinstance(params.choice, str):
-                json_str = json.dumps(params.choice)
-            else:
-                json_str = params.choice
-            return (StructuredOutputOptions.CHOICE, json_str)
-        elif params.grammar is not None:
-            return (StructuredOutputOptions.GRAMMAR, params.grammar)
+        return get_structured_output_key(self.sampling_params)
+
+
+def get_structured_output_key(
+        sampling_params: SamplingParams) -> StructuredOutputKey:
+    params = sampling_params.guided_decoding
+    assert params is not None, "params can't be None."
+    if params.json is not None:
+        if not isinstance(params.json, str):
+            json_str = json.dumps(params.json)
         else:
-            raise ValueError("No valid structured output parameter found")
+            json_str = params.json
+        return (StructuredOutputOptions.JSON, json_str)
+    elif params.json_object:
+        return (StructuredOutputOptions.JSON_OBJECT, "")
+    elif params.regex is not None:
+        return (StructuredOutputOptions.REGEX, params.regex)
+    elif params.choice is not None:
+        if not isinstance(params.choice, str):
+            json_str = json.dumps(params.choice)
+        else:
+            json_str = params.choice
+        return (StructuredOutputOptions.CHOICE, json_str)
+    elif params.grammar is not None:
+        return (StructuredOutputOptions.GRAMMAR, params.grammar)
+    else:
+        raise ValueError("No valid structured output parameter found")
diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py
index b373d31e..694e46f7 100644
--- a/vllm/v1/structured_output/utils.py
+++ b/vllm/v1/structured_output/utils.py
@@ -239,7 +239,7 @@ def choice_as_grammar(choice: list[str]) -> str:
     return grammar
 
 
-def validate_structured_output_request(
+def validate_structured_output_request_xgrammar(
         sampling_params: SamplingParams) -> None:
     """Validate that the request is supported by structured output.
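Usage note (not part of the patch): a minimal sketch of how the new llguidance-based backend could be exercised once this lands, mirroring the updated V1 tests above. The model name, schema, and prompt are illustrative only; as in the tests, VLLM_USE_V1=1 is assumed.

    import os

    from vllm import LLM, SamplingParams
    from vllm.sampling_params import GuidedDecodingParams

    # The tests set this env var before constructing the engine.
    os.environ["VLLM_USE_V1"] = "1"

    # Engine-level backend choice: "xgrammar", "guidance", or "auto".
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
              max_model_len=1024,
              guided_decoding_backend="guidance")

    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=100,
        guided_decoding=GuidedDecodingParams(json={"type": "object"}))

    outputs = llm.generate(
        prompts="Give an example JSON object for a person with name and age: ",
        sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)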