# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of guided decoding
to generate structured outputs using vLLM. It shows how to apply
different guided decoding techniques such as Choice, Regex, JSON schema,
and Grammar to produce structured and formatted results
based on specific prompts.
"""

from enum import Enum

from pydantic import BaseModel

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Guided decoding by Choice (list of possible options)
guided_decoding_params_choice = GuidedDecodingParams(
    choice=["Positive", "Negative"])
sampling_params_choice = SamplingParams(
    guided_decoding=guided_decoding_params_choice)
prompt_choice = "Classify this sentiment: vLLM is wonderful!"

# Guided decoding by Regex
# NOTE(review): the example address in the prompt below contains a '.' in its
# local part, which `\w` does not match -- confirm whether the pattern should
# allow dots.
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
sampling_params_regex = SamplingParams(
    guided_decoding=guided_decoding_params_regex, stop=["\n"])
# Adjacent string literals are concatenated with nothing in between, so each
# fragment carries its own trailing space to keep the sentences separated.
prompt_regex = (
    "Generate an email address for Alan Turing, who works in Enigma. "
    "End in .com and new line. Example result: "
    "alan.turing@enigma.com\n")


# Guided decoding by JSON using Pydantic schema
class CarType(str, Enum):
    """Closed set of car body types allowed in the generated JSON."""

    sedan = "sedan"
    suv = "SUV"
    truck = "Truck"
    coupe = "Coupe"


class CarDescription(BaseModel):
    """Pydantic model whose JSON schema is used to guide the output."""

    brand: str
    model: str
    car_type: CarType


json_schema = CarDescription.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams(
    guided_decoding=guided_decoding_params_json)
# The trailing space after "of" keeps the two fragments from fusing.
prompt_json = ("Generate a JSON with the brand, model and car_type of "
               "the most iconic car from the 90's")

# Guided decoding by Grammar
simplified_sql_grammar = """
root ::= select_statement
select_statement ::= "SELECT " column " from " table " where " condition
column ::= "col_1 " | "col_2 "
table ::= "table_1 " | "table_2 "
condition ::= column "= " number
number ::= "1 " | "2 "
"""
guided_decoding_params_grammar = GuidedDecodingParams(
    grammar=simplified_sql_grammar)
sampling_params_grammar = SamplingParams(
    guided_decoding=guided_decoding_params_grammar)
# The trailing space after 'email' keeps the two fragments from fusing.
prompt_grammar = ("Generate an SQL query to show the 'username' and 'email' "
                  "from the 'users' table.")


def format_output(title: str, output: str) -> None:
    """Print ``title`` and ``output`` between two 50-dash separator lines.

    Args:
        title: Label printed before the colon.
        output: Generated text printed after the colon.
    """
    print(f"{'-' * 50}\n{title}: {output}\n{'-' * 50}")


def generate_output(prompt: str, sampling_params: SamplingParams,
                    llm: LLM) -> str:
    """Generate a completion for ``prompt`` and return its text.

    Args:
        prompt: The prompt passed to ``llm.generate``.
        sampling_params: Sampling configuration passed to ``llm.generate``.
        llm: The model instance used for generation.

    Returns:
        The ``text`` of the first completion of the first returned output.
    """
    outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
    return outputs[0].outputs[0].text


def main() -> None:
    """Create the model once, then run and print each guided decoding example."""
    llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)

    choice_output = generate_output(prompt_choice, sampling_params_choice, llm)
    format_output("Guided decoding by Choice", choice_output)

    regex_output = generate_output(prompt_regex, sampling_params_regex, llm)
    format_output("Guided decoding by Regex", regex_output)

    json_output = generate_output(prompt_json, sampling_params_json, llm)
    format_output("Guided decoding by JSON", json_output)

    grammar_output = generate_output(prompt_grammar, sampling_params_grammar,
                                     llm)
    format_output("Guided decoding by Grammar", grammar_output)


if __name__ == "__main__":
    main()