# SPDX-License-Identifier: Apache-2.0

import json
import pickle

import pytest
import torch
from transformers import AutoTokenizer

from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor,
    get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
    JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
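
# The backends above are interchangeable grammar engines selected through the
# same GuidedDecodingParams interface; per the two lists, only outlines and
# xgrammar are exercised with reasoning support in the tests below.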


# Initialize the tokenizers for the models here to avoid repeated loading.
@pytest.fixture(scope="module")
def zephyr_7B_tokenizer():
    return AutoTokenizer.from_pretrained(MODEL_NAME)
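
# Note: scope="module" caches each fixture for the whole test module, so each
# tokenizer is downloaded and constructed once rather than once per test.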


@pytest.fixture(scope="module")
def deepseek_r1_qwen_tokenizer():
    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)


def test_guided_logits_processors(zephyr_7B_tokenizer, sample_regex,
                                  sample_json_schema):
    """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
    regex_LP = RegexLogitsProcessor(sample_regex,
                                    zephyr_7B_tokenizer,
                                    reasoner=None)
    json_LP = JSONLogitsProcessor(sample_json_schema,
                                  zephyr_7B_tokenizer,
                                  whitespace_pattern=None,
                                  reasoner=None)

    token_ids = zephyr_7B_tokenizer.encode(
        f"Give an example IPv4 address with this regex: {sample_regex}")
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
    regex_LP(token_ids, tensor)
    # Masking must preserve the shape while changing the values.
    assert tensor.shape == original_tensor.shape
    assert not torch.allclose(tensor, original_tensor)

    token_ids = zephyr_7B_tokenizer.encode(
        f"Give an employee profile that fits this schema: {sample_json_schema}"
    )
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
    json_LP(token_ids, tensor)
    assert tensor.shape == original_tensor.shape
    assert not torch.allclose(tensor, original_tensor)
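
# For illustration only (not executed): a logits processor is called once per
# decoding step with the token ids generated so far and the raw logits, and it
# returns a same-shape tensor with disallowed tokens masked out. A hypothetical
# greedy loop would look like:
#
#     logits = model(prompt_ids)             # hypothetical forward pass
#     logits = regex_LP(prompt_ids, logits)  # mask tokens breaking the regex
#     next_token = int(torch.argmax(logits))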


@pytest.mark.asyncio
@pytest.mark.parametrize("backend", GUIDED_DECODING_BACKENDS)
@pytest.mark.parametrize("is_local", [True, False])
async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
                                                 sample_regex,
                                                 sample_json_schema,
                                                 zephyr_7B_tokenizer):
    """Black-box check that every backend yields a working logits processor."""
    config = ModelConfig(
        MODEL_NAME,
        task="generate",
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="bfloat16",
    )
    token_ids = zephyr_7B_tokenizer.encode(
        f"Give an example IPv4 address with this regex: {sample_regex}")
    regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)

    regex_lp = get_local_guided_decoding_logits_processor(
        regex_request, zephyr_7B_tokenizer, config) if is_local else \
        await get_guided_decoding_logits_processor(
            regex_request, zephyr_7B_tokenizer, config)
    assert regex_lp is not None
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
    tensor = regex_lp(token_ids, tensor)
    assert tensor.shape == original_tensor.shape
    assert not torch.allclose(tensor, original_tensor)

    token_ids = zephyr_7B_tokenizer.encode(
        f"Give an employee profile that fits this schema: {sample_json_schema}"
    )
    json_request = GuidedDecodingParams(json=sample_json_schema,
                                        backend=backend)
    json_lp = await get_guided_decoding_logits_processor(
        json_request, zephyr_7B_tokenizer, config)
    assert json_lp is not None
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
    tensor = json_lp(token_ids, tensor)
    assert tensor.shape == original_tensor.shape
    assert not torch.allclose(tensor, original_tensor)
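
# The is_local parametrization above exercises both entry points: the
# synchronous get_local_guided_decoding_logits_processor and the awaitable
# get_guided_decoding_logits_processor used from the async engine; both are
# expected to produce equivalent processors.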


@pytest.mark.asyncio
@pytest.mark.parametrize("backend",
                         GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT)
@pytest.mark.parametrize("is_local", [True, False])
@pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"])
async def test_guided_logits_processor_with_reasoning(
        backend: str, is_local: bool, reasoning_backend: str, sample_regex,
        sample_json_schema, deepseek_r1_qwen_tokenizer):
    """While the model is inside a <think> block, guidance must be disabled;
    once the block is closed, guidance must apply again."""
    config = ModelConfig(
        REASONING_MODEL_NAME,
        task="generate",
        tokenizer=REASONING_MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="bfloat16",
    )
    token_ids = deepseek_r1_qwen_tokenizer.encode(
        f"Give an example IPv4 address with this regex: {sample_regex}."
        "<think>here is the thinking process")
    regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)

    regex_lp = get_local_guided_decoding_logits_processor(
        regex_request, deepseek_r1_qwen_tokenizer, config,
        reasoning_backend) if is_local else \
        await get_guided_decoding_logits_processor(
            regex_request, deepseek_r1_qwen_tokenizer, config,
            reasoning_backend)
    assert regex_lp is not None
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
    tensor = regex_lp(token_ids, tensor)
    assert tensor.shape == original_tensor.shape
    # Still thinking, so the logits must be left untouched.
    assert torch.allclose(tensor, original_tensor)

    token_ids = deepseek_r1_qwen_tokenizer.encode(
        f"Give an employee profile that fits this schema: {sample_json_schema}."
        "<think>here is the thinking process")
    json_request = GuidedDecodingParams(json=sample_json_schema,
                                        backend=backend)
    json_lp = get_local_guided_decoding_logits_processor(
        json_request, deepseek_r1_qwen_tokenizer, config,
        reasoning_backend) if is_local else \
        await get_guided_decoding_logits_processor(
            json_request, deepseek_r1_qwen_tokenizer, config,
            reasoning_backend)
    assert json_lp is not None
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
    tensor = json_lp(token_ids, tensor)
    assert tensor.shape == original_tensor.shape
    # Still thinking, so the logits must be left untouched.
    assert torch.allclose(tensor, original_tensor)

    # Thinking is over (</think> seen), so the tensor should change.
    token_ids = deepseek_r1_qwen_tokenizer.encode(
        f"Give an employee profile that fits this schema: {sample_json_schema}."
        "<think>here is the thinking process</think> Then")
    json_request = GuidedDecodingParams(json=sample_json_schema,
                                        backend=backend)
    json_lp = get_local_guided_decoding_logits_processor(
        json_request, deepseek_r1_qwen_tokenizer, config,
        reasoning_backend) if is_local else \
        await get_guided_decoding_logits_processor(
            json_request, deepseek_r1_qwen_tokenizer, config,
            reasoning_backend)
    assert json_lp is not None
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
    tensor = json_lp(token_ids, tensor)
    assert tensor.shape == original_tensor.shape
    assert not torch.allclose(tensor, original_tensor)
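
# In short: the reasoning backend watches for its end-of-thinking marker
# (</think> for deepseek_r1). Before the marker appears in the token stream
# the processor passes logits through untouched; afterwards, grammar
# constraints are enforced exactly as in the non-reasoning tests.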


def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
    """Mutually exclusive guided-decoding options must raise ValueError."""
    with pytest.raises(ValueError,
                       match="You can only use one kind of guided"):
        GuidedDecodingParams(json=sample_json_schema, regex=sample_regex)

    with pytest.raises(ValueError,
                       match="You can only use one kind of guided"):
        GuidedDecodingParams(json=sample_json_schema, json_object=True)

    with pytest.raises(ValueError,
                       match="You can only use one kind of guided"):
        GuidedDecodingParams(json=sample_json_schema, choice=["a", "b"])

    with pytest.raises(ValueError,
                       match="You can only use one kind of guided"):
        GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")


def test_guided_decoding_backend_options():
    """Test backend-specific options"""
    params = GuidedDecodingParams(
        backend="xgrammar:option-1,option-2,option-3")
    assert params.backend_options() == ["option-1", "option-2", "option-3"]

    no_fallback = GuidedDecodingParams(backend="xgrammar:option-1,no-fallback")
    assert no_fallback.no_fallback()
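
# The backend string follows a "name:opt1,opt2,..." convention: everything
# after the colon is split on commas into options. no-fallback is one such
# option, disabling the fallback to a different backend when the requested
# one cannot handle a request.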


def test_pickle_xgrammar_tokenizer_data():
    """TokenizerData must survive a pickle round trip with metadata intact."""
    try:
        import xgrammar as xgr
    except ImportError:
        pytest.skip("Could not import xgrammar to run test")

    from vllm.model_executor.guided_decoding.xgrammar_decoding import (
        TokenizerData)
    tokenizer_data = TokenizerData(
        metadata=
        '{"vocab_type":2,"vocab_size":151665,"add_prefix_space":false,"stop_token_ids":[151645]}',
        encoded_vocab=['!', '"', '#', '$', '%'],
    )
    pickled = pickle.dumps(tokenizer_data)

    assert pickled is not None

    depickled: TokenizerData = pickle.loads(pickled)

    assert depickled is not None
    # vocab_type 2 in the metadata corresponds to xgrammar's BYTE_LEVEL.
    assert json.loads(
        depickled.metadata)['vocab_type'] == xgr.VocabType.BYTE_LEVEL.value
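
# Pickling matters because TokenizerData may be shipped across process
# boundaries (for example, to engine worker processes); a non-picklable field
# would only surface as a failure in multi-process serving, which this
# round-trip test guards against.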