[Bugfix] Support eos_token_id from config.json (#5954)
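Previously the engine read generation defaults only from generation_config.json and treated a missing file as "no defaults", so models that declare their eos_token_id in config.json instead (such as BLIP-2, tested below) lost it; the tokenizer alone can also disagree with the generation config (Llama 3 lists two stop tokens there). The fix introduces a try_get_generation_config helper that falls back to building a GenerationConfig from the model's config.json via GenerationConfig.from_model_config.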

Cyrus Leung 2024-06-29 19:19:02 +08:00 committed by GitHub
parent 329df38f1a
commit 51e971d39e
3 changed files with 67 additions and 11 deletions

New test file (eos_token_id resolution):

@@ -0,0 +1,31 @@
+"""
+This test file includes some cases where it is inappropriate to
+only get the `eos_token_id` from the tokenizer as defined by
+:meth:`vllm.LLMEngine._get_eos_token_id`.
+"""
+from vllm.transformers_utils.config import try_get_generation_config
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+
+def test_get_llama3_eos_token():
+    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+    tokenizer = get_tokenizer(model_name)
+    assert tokenizer.eos_token_id == 128009
+
+    generation_config = try_get_generation_config(model_name,
+                                                  trust_remote_code=False)
+    assert generation_config is not None
+    assert generation_config.eos_token_id == [128001, 128009]
+
+
+def test_get_blip2_eos_token():
+    model_name = "Salesforce/blip2-opt-2.7b"
+
+    tokenizer = get_tokenizer(model_name)
+    assert tokenizer.eos_token_id == 2
+
+    generation_config = try_get_generation_config(model_name,
+                                                  trust_remote_code=False)
+    assert generation_config is not None
+    assert generation_config.eos_token_id == 50118
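A quick way to see the mismatch the tests above pin down, using plain transformers rather than vLLM's wrappers (running this needs network access and, for Llama 3, accepted model terms):

from transformers import AutoTokenizer, GenerationConfig

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
generation_config = GenerationConfig.from_pretrained(model_name)

# The tokenizer reports a single id, while generation_config.json lists
# both <|end_of_text|> (128001) and <|eot_id|> (128009) as stop tokens.
print(tokenizer.eos_token_id)          # 128009
print(generation_config.eos_token_id)  # [128001, 128009]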

vllm/engine/llm_engine.py:

@@ -1,10 +1,10 @@
 import time
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Set, Type, TypeVar, Union

-from transformers import GenerationConfig, PreTrainedTokenizer
+from transformers import PreTrainedTokenizer

 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
                          LoRAConfig, ModelConfig, ObservabilityConfig,
@@ -34,6 +34,7 @@ from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
                            SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
+from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                      get_tokenizer_group)
@@ -46,16 +47,18 @@ logger = init_logger(__name__)

 _LOCAL_LOGGING_INTERVAL_SEC = 5


-def _load_generation_config_dict(model_config: ModelConfig):
-    try:
-        return GenerationConfig.from_pretrained(
-            model_config.model,
-            revision=model_config.revision,
-        ).to_diff_dict()
-    except OSError:
-        # Not found.
-        return {}
+def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
+    config = try_get_generation_config(
+        model_config.model,
+        trust_remote_code=model_config.trust_remote_code,
+        revision=model_config.revision,
+    )
+
+    if config is None:
+        return {}
+
+    return config.to_diff_dict()


 _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
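Returning to_diff_dict() keeps the helper's old contract: only fields that differ from the GenerationConfig defaults survive, so downstream consumers see just the values the model author actually set. A minimal sketch of that behavior (the exact output may also carry bookkeeping such as the transformers version):

from transformers import GenerationConfig

config = GenerationConfig(eos_token_id=[128001, 128009])
print(config.to_diff_dict())
# {'eos_token_id': [128001, 128009], ...}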
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)

vllm/transformers_utils/config.py:

@@ -1,7 +1,7 @@
 import contextlib
 from typing import Dict, Optional, Type

-from transformers import PretrainedConfig
+from transformers import GenerationConfig, PretrainedConfig

 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
@@ -80,3 +80,25 @@ def get_hf_text_config(config: PretrainedConfig):
         return config.text_config
     else:
         return config
+
+
+def try_get_generation_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+) -> Optional[GenerationConfig]:
+    try:
+        return GenerationConfig.from_pretrained(
+            model,
+            revision=revision,
+        )
+    except OSError:  # Not found
+        try:
+            config = get_config(
+                model,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+            )
+            return GenerationConfig.from_model_config(config)
+        except OSError:  # Not found
+            return None
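Exercising the new helper end to end, per the BLIP-2 test above (assumes vLLM is installed and the model repo is reachable):

from vllm.transformers_utils.config import try_get_generation_config

generation_config = try_get_generation_config("Salesforce/blip2-opt-2.7b",
                                              trust_remote_code=False)
if generation_config is None:
    print("no generation defaults found")
else:
    # 50118 for BLIP-2's OPT decoder, even though the tokenizer reports 2.
    print(generation_config.eos_token_id)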