[Misc] refactor Structured Outputs example (#16322)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
This commit is contained in:
Reid 2025-04-10 07:32:42 +08:00 committed by GitHub
parent cb391d85dc
commit 1bff42c4b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,4 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of guided decoding
to generate structured outputs using vLLM. It shows how to apply
different guided decoding techniques such as Choice, Regex, JSON schema,
and Grammar to produce structured and formatted results
based on specific prompts.
"""
from enum import Enum
@ -7,26 +14,21 @@ from pydantic import BaseModel
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)
# Guided decoding by Choice (list of possible options)
guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative"])
sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
outputs = llm.generate(
prompts="Classify this sentiment: vLLM is wonderful!",
sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)
guided_decoding_params_choice = GuidedDecodingParams(
choice=["Positive", "Negative"])
sampling_params_choice = SamplingParams(
guided_decoding=guided_decoding_params_choice)
prompt_choice = "Classify this sentiment: vLLM is wonderful!"
# Guided decoding by Regex
guided_decoding_params = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
sampling_params = SamplingParams(guided_decoding=guided_decoding_params,
stop=["\n"])
prompt = ("Generate an email address for Alan Turing, who works in Enigma."
"End in .com and new line. Example result:"
"alan.turing@enigma.com\n")
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
sampling_params_regex = SamplingParams(
guided_decoding=guided_decoding_params_regex, stop=["\n"])
prompt_regex = (
"Generate an email address for Alan Turing, who works in Enigma."
"End in .com and new line. Example result:"
"alan.turing@enigma.com\n")
# Guided decoding by JSON using Pydantic schema
@ -44,16 +46,11 @@ class CarDescription(BaseModel):
json_schema = CarDescription.model_json_schema()
guided_decoding_params = GuidedDecodingParams(json=json_schema)
sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
prompt = ("Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's")
outputs = llm.generate(
prompts=prompt,
sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams(
guided_decoding=guided_decoding_params_json)
prompt_json = ("Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's")
# Guided decoding by Grammar
simplified_sql_grammar = """
@ -64,12 +61,39 @@ table ::= "table_1 " | "table_2 "
condition ::= column "= " number
number ::= "1 " | "2 "
"""
guided_decoding_params = GuidedDecodingParams(grammar=simplified_sql_grammar)
sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
prompt = ("Generate an SQL query to show the 'username' and 'email'"
"from the 'users' table.")
outputs = llm.generate(
prompts=prompt,
sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)
guided_decoding_params_grammar = GuidedDecodingParams(
grammar=simplified_sql_grammar)
sampling_params_grammar = SamplingParams(
guided_decoding=guided_decoding_params_grammar)
prompt_grammar = ("Generate an SQL query to show the 'username' and 'email'"
"from the 'users' table.")
def format_output(title: str, output: str):
print(f"{'-' * 50}\n{title}: {output}\n{'-' * 50}")
def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM):
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
return outputs[0].outputs[0].text
def main():
llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)
choice_output = generate_output(prompt_choice, sampling_params_choice, llm)
format_output("Guided decoding by Choice", choice_output)
regex_output = generate_output(prompt_regex, sampling_params_regex, llm)
format_output("Guided decoding by Regex", regex_output)
json_output = generate_output(prompt_json, sampling_params_json, llm)
format_output("Guided decoding by JSON", json_output)
grammar_output = generate_output(prompt_grammar, sampling_params_grammar,
llm)
format_output("Guided decoding by Grammar", grammar_output)
if __name__ == "__main__":
main()