[Misc] refactor Structured Outputs example (#16322)
Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com>
This commit is contained in:
parent
cb391d85dc
commit
1bff42c4b7
@ -1,4 +1,11 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""
|
||||||
|
This file demonstrates how to use guided decoding
|
||||||
|
to generate structured outputs using vLLM. It shows how to apply
|
||||||
|
different guided decoding techniques such as Choice, Regex, JSON schema,
|
||||||
|
and Grammar to produce structured and formatted results
|
||||||
|
based on specific prompts.
|
||||||
|
"""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
@ -7,26 +14,21 @@ from pydantic import BaseModel
|
|||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.sampling_params import GuidedDecodingParams
|
from vllm.sampling_params import GuidedDecodingParams
|
||||||
|
|
||||||
llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)
|
|
||||||
|
|
||||||
# Guided decoding by Choice (list of possible options)
|
# Guided decoding by Choice (list of possible options)
|
||||||
guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative"])
|
guided_decoding_params_choice = GuidedDecodingParams(
|
||||||
sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
|
choice=["Positive", "Negative"])
|
||||||
outputs = llm.generate(
|
sampling_params_choice = SamplingParams(
|
||||||
prompts="Classify this sentiment: vLLM is wonderful!",
|
guided_decoding=guided_decoding_params_choice)
|
||||||
sampling_params=sampling_params,
|
prompt_choice = "Classify this sentiment: vLLM is wonderful!"
|
||||||
)
|
|
||||||
print(outputs[0].outputs[0].text)
|
|
||||||
|
|
||||||
# Guided decoding by Regex
|
# Guided decoding by Regex
|
||||||
guided_decoding_params = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
|
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
|
||||||
sampling_params = SamplingParams(guided_decoding=guided_decoding_params,
|
sampling_params_regex = SamplingParams(
|
||||||
stop=["\n"])
|
guided_decoding=guided_decoding_params_regex, stop=["\n"])
|
||||||
prompt = ("Generate an email address for Alan Turing, who works in Enigma."
|
prompt_regex = (
|
||||||
"End in .com and new line. Example result:"
|
"Generate an email address for Alan Turing, who works in Enigma."
|
||||||
"alan.turing@enigma.com\n")
|
"End in .com and new line. Example result:"
|
||||||
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
|
"alan.turing@enigma.com\n")
|
||||||
print(outputs[0].outputs[0].text)
|
|
||||||
|
|
||||||
|
|
||||||
# Guided decoding by JSON using Pydantic schema
|
# Guided decoding by JSON using Pydantic schema
|
||||||
@ -44,16 +46,11 @@ class CarDescription(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
json_schema = CarDescription.model_json_schema()
|
json_schema = CarDescription.model_json_schema()
|
||||||
|
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
|
||||||
guided_decoding_params = GuidedDecodingParams(json=json_schema)
|
sampling_params_json = SamplingParams(
|
||||||
sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
|
guided_decoding=guided_decoding_params_json)
|
||||||
prompt = ("Generate a JSON with the brand, model and car_type of"
|
prompt_json = ("Generate a JSON with the brand, model and car_type of"
|
||||||
"the most iconic car from the 90's")
|
"the most iconic car from the 90's")
|
||||||
outputs = llm.generate(
|
|
||||||
prompts=prompt,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
)
|
|
||||||
print(outputs[0].outputs[0].text)
|
|
||||||
|
|
||||||
# Guided decoding by Grammar
|
# Guided decoding by Grammar
|
||||||
simplified_sql_grammar = """
|
simplified_sql_grammar = """
|
||||||
@ -64,12 +61,39 @@ table ::= "table_1 " | "table_2 "
|
|||||||
condition ::= column "= " number
|
condition ::= column "= " number
|
||||||
number ::= "1 " | "2 "
|
number ::= "1 " | "2 "
|
||||||
"""
|
"""
|
||||||
guided_decoding_params = GuidedDecodingParams(grammar=simplified_sql_grammar)
|
guided_decoding_params_grammar = GuidedDecodingParams(
|
||||||
sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
|
grammar=simplified_sql_grammar)
|
||||||
prompt = ("Generate an SQL query to show the 'username' and 'email'"
|
sampling_params_grammar = SamplingParams(
|
||||||
"from the 'users' table.")
|
guided_decoding=guided_decoding_params_grammar)
|
||||||
outputs = llm.generate(
|
prompt_grammar = ("Generate an SQL query to show the 'username' and 'email'"
|
||||||
prompts=prompt,
|
"from the 'users' table.")
|
||||||
sampling_params=sampling_params,
|
|
||||||
)
|
|
||||||
print(outputs[0].outputs[0].text)
|
def format_output(title: str, output: str) -> None:
    """Print a titled result framed by two 50-dash rules.

    Args:
        title: Label printed before the colon.
        output: Text printed after the label.
    """
    rule = "-" * 50
    print(f"{rule}\n{title}: {output}\n{rule}")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_output(prompt: str, sampling_params: SamplingParams,
                    llm: LLM) -> str:
    """Generate once for a single prompt and return the resulting text.

    Args:
        prompt: Prompt passed to ``llm.generate`` as ``prompts``.
        sampling_params: Sampling configuration passed to ``llm.generate``.
        llm: Engine used to run the generation.

    Returns:
        The ``text`` of the first completion of the first returned output.
    """
    outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
    # One prompt is submitted, so only the first result and its first
    # completion are read.
    return outputs[0].outputs[0].text
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Run the Choice, Regex, JSON and Grammar examples and print each result.

    A single ``LLM`` instance is created and reused for all four runs.
    """
    llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)

    # (title, prompt, sampling params) for each example, in the order run.
    examples = (
        ("Guided decoding by Choice", prompt_choice, sampling_params_choice),
        ("Guided decoding by Regex", prompt_regex, sampling_params_regex),
        ("Guided decoding by JSON", prompt_json, sampling_params_json),
        ("Guided decoding by Grammar", prompt_grammar,
         sampling_params_grammar),
    )
    for title, prompt, sampling_params in examples:
        format_output(title, generate_output(prompt, sampling_params, llm))
|
||||||
|
|
||||||
|
|
||||||
|
# Run the examples only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user