# SPDX-License-Identifier: Apache-2.0
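# Offline examples of guided (structured) decoding with vLLM: the model's
# output is constrained to a fixed choice, a regex, a JSON schema, or a grammar.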
from enum import Enum

from pydantic import BaseModel

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)

# Guided decoding by Choice (list of possible options)
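# The output is constrained to be exactly one of the listed options.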
guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative"])
sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
outputs = llm.generate(
    prompts="Classify this sentiment: vLLM is wonderful!",
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)

# Guided decoding by Regex
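# The output is constrained to match the regex; stop=["\n"] ends generation
# at the newline the pattern requires.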
guided_decoding_params = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
sampling_params = SamplingParams(guided_decoding=guided_decoding_params,
                                 stop=["\n"])
prompt = ("Generate an email address for Alan Turing, who works in Enigma. "
          "End in .com and new line. Example result: "
          "alan.turing@enigma.com\n")
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)


# Guided decoding by JSON using Pydantic schema
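# A Pydantic model is converted to a JSON schema, and the output is
# constrained to a JSON object that conforms to it.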
class CarType(str, Enum):
    sedan = "sedan"
    suv = "SUV"
    truck = "Truck"
    coupe = "Coupe"


class CarDescription(BaseModel):
    brand: str
    model: str
    car_type: CarType


json_schema = CarDescription.model_json_schema()
guided_decoding_params = GuidedDecodingParams(json=json_schema)
sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
prompt = ("Generate a JSON with the brand, model and car_type of "
          "the most iconic car from the 90's")
outputs = llm.generate(
    prompts=prompt,
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)

# Guided decoding by Grammar
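# The output is constrained to strings derivable from the grammar below.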
simplified_sql_grammar = """
root ::= select_statement
select_statement ::= "SELECT " column " from " table " where " condition
column ::= "col_1 " | "col_2 "
table ::= "table_1 " | "table_2 "
condition ::= column "= " number
number ::= "1 " | "2 "
"""
guided_decoding_params = GuidedDecodingParams(grammar=simplified_sql_grammar)
sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
prompt = ("Generate an SQL query to show the 'username' and 'email' "
          "from the 'users' table.")
outputs = llm.generate(
    prompts=prompt,
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)