diff --git a/examples/offline_inference/structured_outputs.py b/examples/offline_inference/structured_outputs.py
index 5ec4dbe6..363b500e 100644
--- a/examples/offline_inference/structured_outputs.py
+++ b/examples/offline_inference/structured_outputs.py
@@ -1,4 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
+"""
+This file demonstrates example usage of guided decoding
+to generate structured outputs using vLLM. It shows how to apply
+different guided decoding techniques such as Choice, Regex, JSON schema,
+and Grammar to produce structured and formatted results
+based on specific prompts.
+"""
 
 from enum import Enum
 
@@ -7,26 +14,21 @@ from pydantic import BaseModel
 
 from vllm import LLM, SamplingParams
 from vllm.sampling_params import GuidedDecodingParams
 
-llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)
-
 # Guided decoding by Choice (list of possible options)
-guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative"])
-sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
-outputs = llm.generate(
-    prompts="Classify this sentiment: vLLM is wonderful!",
-    sampling_params=sampling_params,
-)
-print(outputs[0].outputs[0].text)
+guided_decoding_params_choice = GuidedDecodingParams(
+    choice=["Positive", "Negative"])
+sampling_params_choice = SamplingParams(
+    guided_decoding=guided_decoding_params_choice)
+prompt_choice = "Classify this sentiment: vLLM is wonderful!"
 
 # Guided decoding by Regex
-guided_decoding_params = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
-sampling_params = SamplingParams(guided_decoding=guided_decoding_params,
-                                 stop=["\n"])
-prompt = ("Generate an email address for Alan Turing, who works in Enigma."
-          "End in .com and new line. Example result:"
-          "alan.turing@enigma.com\n")
-outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
-print(outputs[0].outputs[0].text)
+guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
+sampling_params_regex = SamplingParams(
+    guided_decoding=guided_decoding_params_regex, stop=["\n"])
+prompt_regex = (
+    "Generate an email address for Alan Turing, who works in Enigma. "
+    "End in .com and new line. Example result: "
+    "alan.turing@enigma.com\n")
 
 # Guided decoding by JSON using Pydantic schema
@@ -44,16 +46,11 @@ class CarDescription(BaseModel):
 
 json_schema = CarDescription.model_json_schema()
 
-
-guided_decoding_params = GuidedDecodingParams(json=json_schema)
-sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
-prompt = ("Generate a JSON with the brand, model and car_type of"
-          "the most iconic car from the 90's")
-outputs = llm.generate(
-    prompts=prompt,
-    sampling_params=sampling_params,
-)
-print(outputs[0].outputs[0].text)
+guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
+sampling_params_json = SamplingParams(
+    guided_decoding=guided_decoding_params_json)
+prompt_json = ("Generate a JSON with the brand, model and car_type of "
+               "the most iconic car from the 90's")
 
 # Guided decoding by Grammar
 simplified_sql_grammar = """
@@ -64,12 +61,39 @@ table ::= "table_1 " | "table_2 "
 condition ::= column "= " number
 number ::= "1 " | "2 "
 """
-guided_decoding_params = GuidedDecodingParams(grammar=simplified_sql_grammar)
-sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
-prompt = ("Generate an SQL query to show the 'username' and 'email'"
-          "from the 'users' table.")
-outputs = llm.generate(
-    prompts=prompt,
-    sampling_params=sampling_params,
-)
-print(outputs[0].outputs[0].text)
+guided_decoding_params_grammar = GuidedDecodingParams(
+    grammar=simplified_sql_grammar)
+sampling_params_grammar = SamplingParams(
+    guided_decoding=guided_decoding_params_grammar)
+prompt_grammar = ("Generate an SQL query to show the 'username' and 'email' "
+                  "from the 'users' table.")
+
+
+def format_output(title: str, output: str):
+    print(f"{'-' * 50}\n{title}: {output}\n{'-' * 50}")
+
+
+def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM):
+    outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
+    return outputs[0].outputs[0].text
+
+
+def main():
+    llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)
+
+    choice_output = generate_output(prompt_choice, sampling_params_choice, llm)
+    format_output("Guided decoding by Choice", choice_output)
+
+    regex_output = generate_output(prompt_regex, sampling_params_regex, llm)
+    format_output("Guided decoding by Regex", regex_output)
+
+    json_output = generate_output(prompt_json, sampling_params_json, llm)
+    format_output("Guided decoding by JSON", json_output)
+
+    grammar_output = generate_output(prompt_grammar, sampling_params_grammar,
+                                     llm)
+    format_output("Guided decoding by Grammar", grammar_output)
+
+
+if __name__ == "__main__":
+    main()
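
A note for anyone extending this example: because the JSON output is constrained by the schema from CarDescription, the generated text can be parsed straight back into the Pydantic model for typed access. Below is a minimal sketch of such a follow-up step, not part of the patch; it assumes it runs inside main() after json_output is produced, reusing the CarDescription model defined in the example file. The ValidationError guard is there because a generation truncated by the small max_model_len=100 can still yield incomplete JSON even under guided decoding.

    from pydantic import ValidationError

    try:
        # Parse the constrained output back into the schema's model
        car = CarDescription.model_validate_json(json_output)
        print(car.brand, car.model, car.car_type)
    except ValidationError as err:
        # Guided decoding enforces the schema token-by-token, but a
        # truncated generation can end before the JSON object closes.
        print(f"Schema validation failed: {err}")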