[Feature] Add load generation config from model (#11164)

Signed-off-by: liuyanyi <wolfsonliu@163.com>
Signed-off-by: Yanyi Liu <wolfsonliu@163.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Yanyi Liu 2024-12-19 18:50:38 +08:00 committed by GitHub
parent 98356735ac
commit 5aef49806d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 307 additions and 74 deletions

View File

@ -0,0 +1,30 @@
from vllm import LLM
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create an LLM with built-in default generation config.
# The generation config is set to None by default to keep
# the behavior consistent with the previous version.
# If you want to use the default generation config from the model,
# you should set the generation_config to "auto".
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", generation_config="auto")
# Load the default sampling parameters from the model.
sampling_params = llm.get_default_sampling_params()
# Modify the sampling parameters if needed.
sampling_params.temperature = 0.5
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

View File

@ -1,6 +1,7 @@
import asyncio
from contextlib import suppress
from dataclasses import dataclass
from typing import Optional
from unittest.mock import MagicMock
from vllm.config import MultiModalConfig
@ -31,6 +32,10 @@ class MockModelConfig:
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processor_pattern = None
diff_sampling_param: Optional[dict] = None
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
@dataclass
@ -94,3 +99,59 @@ def test_serving_chat_should_set_correct_max_tokens():
asyncio.run(serving_chat.create_chat_completion(req))
assert mock_engine.generate.call_args.args[1].max_tokens == 10
def test_serving_chat_could_load_correct_generation_config():
mock_model_config = MockModelConfig()
mock_model_config.diff_sampling_param = {
"temperature": 0.5,
"repetition_penalty": 1.05
}
mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
# Initialize the serving chat
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
BASE_MODEL_PATHS,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
lora_modules=None,
prompt_adapters=None,
request_logger=None)
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": "what is 1+1?"
}],
guided_decoding_backend="outlines",
)
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
assert mock_engine.generate.call_args.args[1].temperature == 0.5
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
# Test the param when user set it
req.temperature = 0.1
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
assert mock_engine.generate.call_args.args[1].temperature == 0.1
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
# Test When temperature==0.0
req.temperature = 0.0
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
assert mock_engine.generate.call_args.args[1].temperature == 0.0
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05

View File

@ -27,7 +27,8 @@ from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
try_get_generation_config, uses_mrope)
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, print_warning_once, random_uuid,
resolve_obj_by_qualname)
@ -160,6 +161,7 @@ class ModelConfig:
logits processor qualified names that can be passed with the
`logits_processors` extra completion argument. Defaults to None,
which allows no processors.
generation_config: Configuration parameter file for generation.
"""
def compute_hash(self) -> str:
@ -218,7 +220,8 @@ class ModelConfig:
disable_mm_preprocessor_cache: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None) -> None:
logits_processor_pattern: Optional[str] = None,
generation_config: Optional[str] = None) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
@ -348,6 +351,8 @@ class ModelConfig:
self.pooler_config = self._init_pooler_config(override_pooler_config)
self.logits_processor_pattern = logits_processor_pattern
self.generation_config = generation_config
self._verify_quantization()
self._verify_cuda_graph()
self._verify_bnb_config()
@ -813,6 +818,56 @@ class ModelConfig:
return self.multimodal_config
def try_get_generation_config(self) -> Dict[str, Any]:
if self.generation_config is None or self.generation_config == "auto":
config = try_get_generation_config(
self.model,
trust_remote_code=self.trust_remote_code,
revision=self.revision,
)
else:
config = try_get_generation_config(
self.generation_config,
trust_remote_code=self.trust_remote_code,
)
if config is None:
return {}
return config.to_diff_dict()
def get_diff_sampling_param(self) -> Dict[str, Any]:
"""
This method returns a dictionary containing the parameters
that differ from the default sampling parameters, but only
if `generation_config` is set. If `generation_config` is not
set, an empty dictionary is returned.
Returns:
Dict[str, Any]: A dictionary with the differing sampling
parameters if `generation_config` is set, otherwise an
empty dictionary.
"""
if self.generation_config is None:
# When generation_config is not set
return {}
config = self.try_get_generation_config()
available_params = [
"repetition_penalty",
"temperature",
"top_k",
"top_p",
"min_p",
]
if any(p in config for p in available_params):
diff_sampling_param = {
p: config.get(p)
for p in available_params if config.get(p) is not None
}
else:
diff_sampling_param = {}
return diff_sampling_param
@property
def is_encoder_decoder(self) -> bool:
"""Extract the HF encoder/decoder model flag."""

View File

@ -197,6 +197,8 @@ class EngineArgs:
kv_transfer_config: Optional[KVTransferConfig] = None
generation_config: Optional[str] = None
def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model
@ -942,6 +944,16 @@ class EngineArgs:
default="auto",
help='The worker class to use for distributed execution.')
parser.add_argument(
"--generation-config",
type=nullable_str,
default=None,
help="The folder path to the generation config. "
"Defaults to None, will use the default generation config in vLLM. "
"If set to 'auto', the generation config will be automatically "
"loaded from model. If set to a folder path, the generation config "
"will be loaded from the specified folder path.")
return parser
@classmethod
@ -985,7 +997,8 @@ class EngineArgs:
disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern)
logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config)
def create_load_config(self) -> LoadConfig:
return LoadConfig(

View File

@ -5,8 +5,8 @@ from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload
@ -52,7 +52,6 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
SequenceGroupOutput, SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
@ -65,20 +64,6 @@ from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
config = try_get_generation_config(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.revision,
)
if config is None:
return {}
return config.to_diff_dict()
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
@ -274,8 +259,8 @@ class LLMEngine:
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
self.model_config)
self.generation_config_fields = (
self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,

View File

@ -258,6 +258,13 @@ class LLM:
else:
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
def get_default_sampling_params(self) -> SamplingParams:
diff_sampling_param = (
self.llm_engine.model_config.get_diff_sampling_param())
if diff_sampling_param:
return SamplingParams.from_optional(**diff_sampling_param)
return SamplingParams()
@overload
def generate(
self,
@ -441,7 +448,7 @@ class LLM:
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
sampling_params = self.get_default_sampling_params()
self._validate_and_add_requests(
prompts=parsed_prompts,

View File

@ -211,8 +211,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
temperature: Optional[float] = None
top_p: Optional[float] = None
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"], Literal["auto"],
ChatCompletionNamedToolChoiceParam]] = "none"
@ -224,9 +224,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None
use_beam_search: bool = False
top_k: int = -1
min_p: float = 0.0
repetition_penalty: float = 1.0
top_k: Optional[int] = None
min_p: Optional[float] = None
repetition_penalty: Optional[float] = None
length_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
@ -348,15 +348,32 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
def to_beam_search_params(self,
default_max_tokens: int) -> BeamSearchParams:
# Default sampling parameters for chat completion requests
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_p": 1.0,
"top_k": -1,
"min_p": 0.0,
}
def to_beam_search_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None
) -> BeamSearchParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1
temperature = self.temperature if self.temperature is not None else 0.0
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
return BeamSearchParams(
beam_width=n,
@ -367,13 +384,36 @@ class ChatCompletionRequest(OpenAIBaseModel):
include_stop_str_in_output=self.include_stop_str_in_output)
def to_sampling_params(
self, default_max_tokens: int,
logits_processor_pattern: Optional[str]) -> SamplingParams:
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
"repetition_penalty",
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
)
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
if (min_p := self.min_p) is None:
min_p = default_sampling_params.get(
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
@ -403,11 +443,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
@ -584,15 +624,15 @@ class CompletionRequest(OpenAIBaseModel):
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
suffix: Optional[str] = None
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
temperature: Optional[float] = None
top_p: Optional[float] = None
user: Optional[str] = None
# doc: begin-completion-sampling-params
use_beam_search: bool = False
top_k: int = -1
min_p: float = 0.0
repetition_penalty: float = 1.0
top_k: Optional[int] = None
min_p: Optional[float] = None
repetition_penalty: Optional[float] = None
length_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
@ -669,14 +709,30 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params
def to_beam_search_params(self,
default_max_tokens: int) -> BeamSearchParams:
# Default sampling parameters for completion requests
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_p": 1.0,
"top_k": -1,
"min_p": 0.0,
}
def to_beam_search_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None
) -> BeamSearchParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1
temperature = self.temperature if self.temperature is not None else 0.0
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get("temperature", 1.0)
return BeamSearchParams(
beam_width=n,
@ -687,12 +743,35 @@ class CompletionRequest(OpenAIBaseModel):
include_stop_str_in_output=self.include_stop_str_in_output)
def to_sampling_params(
self, default_max_tokens: int,
logits_processor_pattern: Optional[str]) -> SamplingParams:
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
"repetition_penalty",
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
)
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
if (min_p := self.min_p) is None:
min_p = default_sampling_params.get(
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.logprobs
@ -718,11 +797,11 @@ class CompletionRequest(OpenAIBaseModel):
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,

View File

@ -91,6 +91,10 @@ class OpenAIServingChat(OpenAIServing):
"been registered") from e
self.enable_prompt_tokens_details = enable_prompt_tokens_details
diff_sampling_param = self.model_config.get_diff_sampling_param()
if diff_sampling_param:
logger.info("Overwriting default chat sampling param with: %s",
diff_sampling_param)
async def create_chat_completion(
self,
@ -191,13 +195,17 @@ class OpenAIServingChat(OpenAIServing):
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
# Build default sampling params
default_sampling_params = (
self.model_config.get_diff_sampling_param())
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens)
default_max_tokens, default_sampling_params)
else:
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern)
self.model_config.logits_processor_pattern,
default_sampling_params)
self._log_inputs(request_id,
request_prompts[i],

View File

@ -55,6 +55,11 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_adapters=prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
diff_sampling_param = self.model_config.get_diff_sampling_param()
if diff_sampling_param:
logger.info(
"Overwriting default completion sampling param with: %s",
diff_sampling_param)
async def create_completion(
self,
@ -118,13 +123,17 @@ class OpenAIServingCompletion(OpenAIServing):
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
# Build default sampling params
default_sampling_params = (
self.model_config.get_diff_sampling_param())
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens)
default_max_tokens, default_sampling_params)
else:
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern)
self.model_config.logits_processor_pattern,
default_sampling_params)
request_id_item = f"{request_id}-{i}"

View File

@ -1,5 +1,5 @@
import time
from typing import Any, Dict, Mapping, Optional, Tuple, Union
from typing import Mapping, Optional, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, ModelConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
@ -12,7 +12,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
@ -34,8 +33,8 @@ class Processor:
self.lora_config = lora_config
self.tokenizer = tokenizer
self.generation_config_fields = _load_generation_config_dict(
model_config)
self.generation_config_fields = model_config.try_get_generation_config(
)
self.input_preprocessor = InputPreprocessor(model_config,
self.tokenizer,
mm_registry)
@ -181,16 +180,3 @@ class Processor:
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
config = try_get_generation_config(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.revision,
)
if config is None:
return {}
return config.to_diff_dict()