[Feature] Add load generation config from model (#11164)
Signed-off-by: liuyanyi <wolfsonliu@163.com>
Signed-off-by: Yanyi Liu <wolfsonliu@163.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent 98356735ac
commit 5aef49806d

examples/offline_inference_with_default_generation_config.py (new file, 30 lines)
@@ -0,0 +1,30 @@
from vllm import LLM

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create an LLM with built-in default generation config.
# generation_config is None by default, which keeps the behavior consistent
# with previous versions. To use the default generation config shipped with
# the model, set generation_config="auto".
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", generation_config="auto")

# Load the default sampling parameters from the model.
sampling_params = llm.get_default_sampling_params()
# Modify the sampling parameters if needed.
sampling_params.temperature = 0.5

# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
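The same model-default behavior is exposed to the OpenAI-compatible server through the new --generation-config flag added further down in this diff. A minimal hedged sketch of client-side usage (the serve command, port, and API key are assumptions, not part of the commit):

# Assumed server launch:
#   vllm serve Qwen/Qwen2.5-0.5B-Instruct --generation-config auto
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
# temperature, top_p, top_k, min_p and repetition_penalty are omitted here,
# so the server falls back to the values in the model's generation_config.json
# instead of vLLM's built-in defaults.
completion = client.chat.completions.create(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    messages=[{"role": "user", "content": "what is 1+1?"}],
)
print(completion.choices[0].message.content)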
@@ -1,6 +1,7 @@
 import asyncio
 from contextlib import suppress
 from dataclasses import dataclass
+from typing import Optional
 from unittest.mock import MagicMock

 from vllm.config import MultiModalConfig
@@ -31,6 +32,10 @@ class MockModelConfig:
     multimodal_config = MultiModalConfig()
     hf_config = MockHFConfig()
     logits_processor_pattern = None
+    diff_sampling_param: Optional[dict] = None
+
+    def get_diff_sampling_param(self):
+        return self.diff_sampling_param or {}


 @dataclass
@@ -94,3 +99,59 @@ def test_serving_chat_should_set_correct_max_tokens():
         asyncio.run(serving_chat.create_chat_completion(req))

     assert mock_engine.generate.call_args.args[1].max_tokens == 10
+
+
+def test_serving_chat_could_load_correct_generation_config():
+
+    mock_model_config = MockModelConfig()
+    mock_model_config.diff_sampling_param = {
+        "temperature": 0.5,
+        "repetition_penalty": 1.05
+    }
+
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+
+    # Initialize the serving chat
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     mock_model_config,
+                                     BASE_MODEL_PATHS,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     lora_modules=None,
+                                     prompt_adapters=None,
+                                     request_logger=None)
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        guided_decoding_backend="outlines",
+    )
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].temperature == 0.5
+    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
+
+    # A value set explicitly on the request overrides the default
+    req.temperature = 0.1
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].temperature == 0.1
+    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
+
+    # temperature == 0.0 is a valid explicit value and must not be treated
+    # as unset
+    req.temperature = 0.0
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].temperature == 0.0
+    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@@ -27,7 +27,8 @@ from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
-    get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
+    get_sentence_transformer_tokenizer_config, is_encoder_decoder,
+    try_get_generation_config, uses_mrope)
 from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
                         get_cpu_memory, print_warning_once, random_uuid,
                         resolve_obj_by_qualname)
@@ -160,6 +161,7 @@ class ModelConfig:
             logits processor qualified names that can be passed with the
             `logits_processors` extra completion argument. Defaults to None,
             which allows no processors.
+        generation_config: Configuration parameter file for generation.
     """

     def compute_hash(self) -> str:
@@ -218,7 +220,8 @@ class ModelConfig:
                  disable_mm_preprocessor_cache: bool = False,
                  override_neuron_config: Optional[Dict[str, Any]] = None,
                  override_pooler_config: Optional["PoolerConfig"] = None,
-                 logits_processor_pattern: Optional[str] = None) -> None:
+                 logits_processor_pattern: Optional[str] = None,
+                 generation_config: Optional[str] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -348,6 +351,8 @@ class ModelConfig:
         self.pooler_config = self._init_pooler_config(override_pooler_config)
         self.logits_processor_pattern = logits_processor_pattern

+        self.generation_config = generation_config
+
         self._verify_quantization()
         self._verify_cuda_graph()
         self._verify_bnb_config()
@@ -813,6 +818,56 @@ class ModelConfig:

         return self.multimodal_config

+    def try_get_generation_config(self) -> Dict[str, Any]:
+        if self.generation_config is None or self.generation_config == "auto":
+            config = try_get_generation_config(
+                self.model,
+                trust_remote_code=self.trust_remote_code,
+                revision=self.revision,
+            )
+        else:
+            config = try_get_generation_config(
+                self.generation_config,
+                trust_remote_code=self.trust_remote_code,
+            )
+
+        if config is None:
+            return {}
+
+        return config.to_diff_dict()
+
+    def get_diff_sampling_param(self) -> Dict[str, Any]:
+        """
+        This method returns a dictionary containing the parameters
+        that differ from the default sampling parameters, but only
+        if `generation_config` is set. If `generation_config` is not
+        set, an empty dictionary is returned.
+
+        Returns:
+            Dict[str, Any]: A dictionary with the differing sampling
+            parameters if `generation_config` is set, otherwise an
+            empty dictionary.
+        """
+        if self.generation_config is None:
+            # When generation_config is not set
+            return {}
+        config = self.try_get_generation_config()
+        available_params = [
+            "repetition_penalty",
+            "temperature",
+            "top_k",
+            "top_p",
+            "min_p",
+        ]
+        if any(p in config for p in available_params):
+            diff_sampling_param = {
+                p: config.get(p)
+                for p in available_params if config.get(p) is not None
+            }
+        else:
+            diff_sampling_param = {}
+        return diff_sampling_param
+
     @property
     def is_encoder_decoder(self) -> bool:
         """Extract the HF encoder/decoder model flag."""
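get_diff_sampling_param() above whitelists only five sampling keys from the model's generation config; everything else in that file is ignored. A standalone hedged sketch of the filtering (the config values are made up, not taken from the commit):

# Suppose the model ships a generation_config.json like this (illustrative):
generation_config = {"temperature": 0.7, "top_p": 0.8, "top_k": 20,
                     "do_sample": True, "max_new_tokens": 2048}
available_params = ["repetition_penalty", "temperature", "top_k", "top_p", "min_p"]
diff_sampling_param = {
    p: generation_config[p]
    for p in available_params if generation_config.get(p) is not None
}
print(diff_sampling_param)  # {'temperature': 0.7, 'top_k': 20, 'top_p': 0.8}
# In the real method, keys outside the whitelist (do_sample, max_new_tokens)
# are dropped, and an empty dict is returned when generation_config is None.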
@@ -197,6 +197,8 @@ class EngineArgs:

     kv_transfer_config: Optional[KVTransferConfig] = None

+    generation_config: Optional[str] = None
+
     def __post_init__(self):
         if not self.tokenizer:
             self.tokenizer = self.model
@@ -942,6 +944,16 @@ class EngineArgs:
             default="auto",
             help='The worker class to use for distributed execution.')

+        parser.add_argument(
+            "--generation-config",
+            type=nullable_str,
+            default=None,
+            help="The folder path to the generation config. "
+            "Defaults to None, which uses vLLM's built-in default generation "
+            "config. If set to 'auto', the generation config is loaded from "
+            "the model. If set to a folder path, the generation config is "
+            "loaded from that folder.")
+
         return parser

     @classmethod
@@ -985,7 +997,8 @@ class EngineArgs:
             disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
             override_neuron_config=self.override_neuron_config,
             override_pooler_config=self.override_pooler_config,
-            logits_processor_pattern=self.logits_processor_pattern)
+            logits_processor_pattern=self.logits_processor_pattern,
+            generation_config=self.generation_config)

     def create_load_config(self) -> LoadConfig:
         return LoadConfig(
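The flag flows from the CLI through EngineArgs into the ModelConfig constructed above, so the same behavior is reachable programmatically. A minimal hedged sketch (create_model_config is assumed to be the EngineArgs helper that builds the ModelConfig shown in the hunk above; the model name is also an assumption):

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="Qwen/Qwen2.5-0.5B-Instruct",
                         generation_config="auto")
model_config = engine_args.create_model_config()
# Prints only the whitelisted sampling keys found in the model's
# generation config (empty if the model does not define any).
print(model_config.get_diff_sampling_param())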
@@ -5,8 +5,8 @@ from collections import deque
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
-from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
-                    Iterable, List, Mapping, NamedTuple, Optional)
+from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
+                    List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Type, Union, cast, overload

@@ -52,7 +52,6 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
                            SequenceGroupOutput, SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
-from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import (
@@ -65,20 +64,6 @@ from vllm.version import __version__ as VLLM_VERSION
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5


-def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
-    config = try_get_generation_config(
-        model_config.model,
-        trust_remote_code=model_config.trust_remote_code,
-        revision=model_config.revision,
-    )
-
-    if config is None:
-        return {}
-
-    return config.to_diff_dict()
-
-
 _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
 _O = TypeVar("_O", RequestOutput, PoolingRequestOutput)

@@ -274,8 +259,8 @@ class LLMEngine:
             return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

         self.seq_counter = Counter()
-        self.generation_config_fields = _load_generation_config_dict(
-            self.model_config)
+        self.generation_config_fields = (
+            self.model_config.try_get_generation_config())

         self.input_preprocessor = InputPreprocessor(self.model_config,
                                                     self.tokenizer,
@@ -258,6 +258,13 @@ class LLM:
         else:
             tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)

+    def get_default_sampling_params(self) -> SamplingParams:
+        diff_sampling_param = (
+            self.llm_engine.model_config.get_diff_sampling_param())
+        if diff_sampling_param:
+            return SamplingParams.from_optional(**diff_sampling_param)
+        return SamplingParams()
+
     @overload
     def generate(
         self,
@@ -441,7 +448,7 @@ class LLM:

         if sampling_params is None:
             # Use default sampling params.
-            sampling_params = SamplingParams()
+            sampling_params = self.get_default_sampling_params()

         self._validate_and_add_requests(
             prompts=parsed_prompts,
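With the change above, calling generate() without explicit SamplingParams now picks up the model's defaults whenever generation_config="auto" (or a config path) was passed to the LLM. A minimal hedged sketch reusing the model from the example file:

from vllm import LLM

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", generation_config="auto")
# No sampling_params argument: generate() now calls
# get_default_sampling_params() internally, so values from the model's
# generation config apply instead of vLLM's built-in defaults.
outputs = llm.generate(["The capital of France is"])
print(outputs[0].outputs[0].text)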
@@ -211,8 +211,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
-    temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
     tools: Optional[List[ChatCompletionToolsParam]] = None
     tool_choice: Optional[Union[Literal["none"], Literal["auto"],
                                 ChatCompletionNamedToolChoiceParam]] = "none"
@@ -224,9 +224,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
     # doc: begin-chat-completion-sampling-params
     best_of: Optional[int] = None
     use_beam_search: bool = False
-    top_k: int = -1
-    min_p: float = 0.0
-    repetition_penalty: float = 1.0
+    top_k: Optional[int] = None
+    min_p: Optional[float] = None
+    repetition_penalty: Optional[float] = None
     length_penalty: float = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     include_stop_str_in_output: bool = False
@@ -348,15 +348,32 @@ class ChatCompletionRequest(OpenAIBaseModel):

     # doc: end-chat-completion-extra-params

-    def to_beam_search_params(self,
-                              default_max_tokens: int) -> BeamSearchParams:
+    # Default sampling parameters for chat completion requests
+    _DEFAULT_SAMPLING_PARAMS: dict = {
+        "repetition_penalty": 1.0,
+        "temperature": 1.0,
+        "top_p": 1.0,
+        "top_k": -1,
+        "min_p": 0.0,
+    }
+
+    def to_beam_search_params(
+            self,
+            default_max_tokens: int,
+            default_sampling_params: Optional[dict] = None
+    ) -> BeamSearchParams:
         # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = self.max_completion_tokens or self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens

+        if default_sampling_params is None:
+            default_sampling_params = {}
         n = self.n if self.n is not None else 1
-        temperature = self.temperature if self.temperature is not None else 0.0
+
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get(
+                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])

         return BeamSearchParams(
             beam_width=n,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output)
|
||||
|
||||
def to_sampling_params(
|
||||
self, default_max_tokens: int,
|
||||
logits_processor_pattern: Optional[str]) -> SamplingParams:
|
||||
self,
|
||||
default_max_tokens: int,
|
||||
logits_processor_pattern: Optional[str],
|
||||
default_sampling_params: Optional[dict] = None) -> SamplingParams:
|
||||
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
|
||||
max_tokens = self.max_completion_tokens or self.max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
if default_sampling_params is None:
|
||||
default_sampling_params = {}
|
||||
# Default parameters
|
||||
if (repetition_penalty := self.repetition_penalty) is None:
|
||||
repetition_penalty = default_sampling_params.get(
|
||||
"repetition_penalty",
|
||||
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
|
||||
)
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get(
|
||||
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
|
||||
if (top_p := self.top_p) is None:
|
||||
top_p = default_sampling_params.get(
|
||||
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
|
||||
if (top_k := self.top_k) is None:
|
||||
top_k = default_sampling_params.get(
|
||||
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
|
||||
if (min_p := self.min_p) is None:
|
||||
min_p = default_sampling_params.get(
|
||||
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
|
||||
|
||||
prompt_logprobs = self.prompt_logprobs
|
||||
if prompt_logprobs is None and self.echo:
|
||||
prompt_logprobs = self.top_logprobs
|
||||
@@ -403,11 +443,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
             best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
-            repetition_penalty=self.repetition_penalty,
-            temperature=self.temperature,
-            top_p=self.top_p,
-            top_k=self.top_k,
-            min_p=self.min_p,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            min_p=min_p,
             seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
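Taken together, to_sampling_params() now resolves every sampling parameter with a three-level fallback: an explicit value on the request wins, otherwise the model's generation config default applies, otherwise the hard-coded _DEFAULT_SAMPLING_PARAMS value. A standalone hedged sketch of that resolution for a single field (the helper name and values are illustrative, not committed code):

_DEFAULT_SAMPLING_PARAMS = {"temperature": 1.0}

def resolve_temperature(request_value, default_sampling_params):
    # 1) explicit request value, 2) model generation config default,
    # 3) vLLM's built-in default.
    if request_value is not None:
        return request_value
    return default_sampling_params.get("temperature",
                                       _DEFAULT_SAMPLING_PARAMS["temperature"])

print(resolve_temperature(None, {"temperature": 0.5}))  # 0.5, from the model
print(resolve_temperature(0.1, {"temperature": 0.5}))   # 0.1, request overrides
print(resolve_temperature(None, {}))                    # 1.0, built-in default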
@@ -584,15 +624,15 @@ class CompletionRequest(OpenAIBaseModel):
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
-    temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
     user: Optional[str] = None

     # doc: begin-completion-sampling-params
     use_beam_search: bool = False
-    top_k: int = -1
-    min_p: float = 0.0
-    repetition_penalty: float = 1.0
+    top_k: Optional[int] = None
+    min_p: Optional[float] = None
+    repetition_penalty: Optional[float] = None
     length_penalty: float = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     include_stop_str_in_output: bool = False
@@ -669,14 +709,30 @@ class CompletionRequest(OpenAIBaseModel):

     # doc: end-completion-extra-params

-    def to_beam_search_params(self,
-                              default_max_tokens: int) -> BeamSearchParams:
+    # Default sampling parameters for completion requests
+    _DEFAULT_SAMPLING_PARAMS: dict = {
+        "repetition_penalty": 1.0,
+        "temperature": 1.0,
+        "top_p": 1.0,
+        "top_k": -1,
+        "min_p": 0.0,
+    }
+
+    def to_beam_search_params(
+            self,
+            default_max_tokens: int,
+            default_sampling_params: Optional[dict] = None
+    ) -> BeamSearchParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens

+        if default_sampling_params is None:
+            default_sampling_params = {}
         n = self.n if self.n is not None else 1
-        temperature = self.temperature if self.temperature is not None else 0.0
+
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get("temperature", 1.0)

         return BeamSearchParams(
             beam_width=n,
@@ -687,12 +743,35 @@ class CompletionRequest(OpenAIBaseModel):
             include_stop_str_in_output=self.include_stop_str_in_output)

     def to_sampling_params(
-            self, default_max_tokens: int,
-            logits_processor_pattern: Optional[str]) -> SamplingParams:
+            self,
+            default_max_tokens: int,
+            logits_processor_pattern: Optional[str],
+            default_sampling_params: Optional[dict] = None) -> SamplingParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens

+        if default_sampling_params is None:
+            default_sampling_params = {}
+        # Default parameters
+        if (repetition_penalty := self.repetition_penalty) is None:
+            repetition_penalty = default_sampling_params.get(
+                "repetition_penalty",
+                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
+            )
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get(
+                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
+        if (top_p := self.top_p) is None:
+            top_p = default_sampling_params.get(
+                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
+        if (top_k := self.top_k) is None:
+            top_k = default_sampling_params.get(
+                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
+        if (min_p := self.min_p) is None:
+            min_p = default_sampling_params.get(
+                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
+
         prompt_logprobs = self.prompt_logprobs
         if prompt_logprobs is None and self.echo:
             prompt_logprobs = self.logprobs
@@ -718,11 +797,11 @@ class CompletionRequest(OpenAIBaseModel):
             best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
-            repetition_penalty=self.repetition_penalty,
-            temperature=self.temperature,
-            top_p=self.top_p,
-            top_k=self.top_k,
-            min_p=self.min_p,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            min_p=min_p,
             seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
@@ -91,6 +91,10 @@ class OpenAIServingChat(OpenAIServing):
                 "been registered") from e

         self.enable_prompt_tokens_details = enable_prompt_tokens_details
+        diff_sampling_param = self.model_config.get_diff_sampling_param()
+        if diff_sampling_param:
+            logger.info("Overwriting default chat sampling param with: %s",
+                        diff_sampling_param)

     async def create_chat_completion(
         self,
@@ -191,13 +195,17 @@ class OpenAIServingChat(OpenAIServing):
         sampling_params: Union[SamplingParams, BeamSearchParams]
         default_max_tokens = self.max_model_len - len(
             engine_prompt["prompt_token_ids"])
+        # Build default sampling params
+        default_sampling_params = (
+            self.model_config.get_diff_sampling_param())
         if request.use_beam_search:
             sampling_params = request.to_beam_search_params(
-                default_max_tokens)
+                default_max_tokens, default_sampling_params)
         else:
             sampling_params = request.to_sampling_params(
                 default_max_tokens,
-                self.model_config.logits_processor_pattern)
+                self.model_config.logits_processor_pattern,
+                default_sampling_params)

         self._log_inputs(request_id,
                          request_prompts[i],
@@ -55,6 +55,11 @@ class OpenAIServingCompletion(OpenAIServing):
             prompt_adapters=prompt_adapters,
             request_logger=request_logger,
             return_tokens_as_token_ids=return_tokens_as_token_ids)
+        diff_sampling_param = self.model_config.get_diff_sampling_param()
+        if diff_sampling_param:
+            logger.info(
+                "Overwriting default completion sampling param with: %s",
+                diff_sampling_param)

     async def create_completion(
         self,
@@ -118,13 +123,17 @@ class OpenAIServingCompletion(OpenAIServing):
                 sampling_params: Union[SamplingParams, BeamSearchParams]
                 default_max_tokens = self.max_model_len - len(
                     engine_prompt["prompt_token_ids"])
+                # Build default sampling params
+                default_sampling_params = (
+                    self.model_config.get_diff_sampling_param())
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
-                        default_max_tokens)
+                        default_max_tokens, default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
                         default_max_tokens,
-                        self.model_config.logits_processor_pattern)
+                        self.model_config.logits_processor_pattern,
+                        default_sampling_params)

                 request_id_item = f"{request_id}-{i}"

@@ -1,5 +1,5 @@
 import time
-from typing import Any, Dict, Mapping, Optional, Tuple, Union
+from typing import Mapping, Optional, Tuple, Union

 from vllm.config import CacheConfig, LoRAConfig, ModelConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
@@ -12,7 +12,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
-from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
 from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
@@ -34,8 +33,8 @@ class Processor:
         self.lora_config = lora_config
         self.tokenizer = tokenizer

-        self.generation_config_fields = _load_generation_config_dict(
-            model_config)
+        self.generation_config_fields = model_config.try_get_generation_config(
+        )
         self.input_preprocessor = InputPreprocessor(model_config,
                                                     self.tokenizer,
                                                     mm_registry)
@@ -181,16 +180,3 @@ class Processor:
         # TODO: Find out how many placeholder tokens are there so we can
         # check that chunked prefill does not truncate them
         # max_batch_len = self.scheduler_config.max_num_batched_tokens
-
-
-def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
-    config = try_get_generation_config(
-        model_config.model,
-        trust_remote_code=model_config.trust_remote_code,
-        revision=model_config.revision,
-    )
-
-    if config is None:
-        return {}
-
-    return config.to_diff_dict()