import sys
from contextlib import nullcontext

from vllm_test_utils import BlameResult, blame

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory


def run_normal():
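    """Generate without guided decoding as a baseline; this path must not
    trigger an import of outlines."""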
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # Create an LLM without guided decoding as a baseline.
    llm = LLM(model="facebook/opt-125m",
              enforce_eager=True,
              gpu_memory_utilization=0.3)
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    # Destroy the LLM object and free up the GPU memory.
    del llm
    cleanup_dist_env_and_memory()


def run_lmfe(sample_regex):
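    """Generate with guided decoding enabled via the lm-format-enforcer
    backend, constraining the output to `sample_regex`."""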
    # Create an LLM with guided decoding enabled.
    llm = LLM(model="facebook/opt-125m",
              enforce_eager=True,
              guided_decoding_backend="lm-format-enforcer",
              gpu_memory_utilization=0.6)
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    outputs = llm.generate(
        prompts=[
            f"Give an example IPv4 address with this regex: {sample_regex}"
        ] * 2,
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options_request=dict(guided_regex=sample_regex))

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


def test_lazy_outlines(sample_regex):
    """If users don't use guided decoding, outlines should not be imported."""
    # Make sure outlines is not imported.
    module_name = "outlines"

    # In CI, we only check at the end whether the module was imported.
    # If it is indeed imported, we can rerun the test with `use_blame=True`,
    # which will trace every function call to find the first import location
    # and help find the root cause.
    # We don't run it in CI by default because it is slow.
    use_blame = False
    context = blame(
        lambda: module_name in sys.modules) if use_blame else nullcontext()
    with context as result:
        run_normal()
        run_lmfe(sample_regex)
    if use_blame:
        assert isinstance(result, BlameResult)
        print(f"the first import location is:\n{result.trace_stack}")
    assert module_name not in sys.modules, (
        f"Module {module_name} is imported. To see the first"
        f" import location, run the test with `use_blame=True`.")
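

# Minimal manual-run sketch (an assumption, not part of the test itself):
# under pytest, `sample_regex` is supplied as a fixture, so the IPv4-style
# regex below is only an illustrative value matching the prompt used in
# run_lmfe().
if __name__ == "__main__":
    test_lazy_outlines(r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
                       r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")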