[Bugfix] Support eos_token_id from config.json (#5954)

This commit is contained in:
Cyrus Leung 2024-06-29 19:19:02 +08:00 committed by GitHub
parent 329df38f1a
commit 51e971d39e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 67 additions and 11 deletions

View File

@@ -0,0 +1,31 @@
"""
This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by
:meth:`vllm.LLMEngine._get_eos_token_id`.
"""
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer import get_tokenizer
def test_get_llama3_eos_token():
    """The generation config for this model declares a *list* of EOS ids,
    while the tokenizer only exposes a single one — so reading EOS solely
    from the tokenizer would miss valid stop tokens.
    """
    model = "meta-llama/Meta-Llama-3-8B-Instruct"

    assert get_tokenizer(model).eos_token_id == 128009

    gen_cfg = try_get_generation_config(model, trust_remote_code=False)
    assert gen_cfg is not None
    assert gen_cfg.eos_token_id == [128001, 128009]
def test_get_blip2_eos_token():
    """The generation config for this model overrides the tokenizer's EOS
    id with a different single value, so the tokenizer alone is not a
    reliable source for EOS.
    """
    model = "Salesforce/blip2-opt-2.7b"

    assert get_tokenizer(model).eos_token_id == 2

    gen_cfg = try_get_generation_config(model, trust_remote_code=False)
    assert gen_cfg is not None
    assert gen_cfg.eos_token_id == 50118

View File

@@ -1,10 +1,10 @@
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union from typing import Set, Type, TypeVar, Union
from transformers import GenerationConfig, PreTrainedTokenizer from transformers import PreTrainedTokenizer
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, ObservabilityConfig, LoRAConfig, ModelConfig, ObservabilityConfig,
@@ -34,6 +34,7 @@ from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
SequenceStatus) SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer) init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group) get_tokenizer_group)
@@ -46,16 +47,18 @@ logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5 _LOCAL_LOGGING_INTERVAL_SEC = 5
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Load the model's generation config as a dict of non-default values.

    Delegates the lookup (including the config.json fallback) to
    ``try_get_generation_config``; returns an empty dict when no
    generation config is available for the model.
    """
    config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )

    if config is None:
        return {}

    # Only the values that differ from the transformers defaults.
    return config.to_diff_dict()
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)

View File

@@ -1,7 +1,7 @@
import contextlib import contextlib
from typing import Dict, Optional, Type from typing import Dict, Optional, Type
from transformers import PretrainedConfig from transformers import GenerationConfig, PretrainedConfig
from vllm.envs import VLLM_USE_MODELSCOPE from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
@@ -80,3 +80,25 @@ def get_hf_text_config(config: PretrainedConfig):
return config.text_config return config.text_config
else: else:
return config return config
def try_get_generation_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
) -> Optional[GenerationConfig]:
    """Best-effort lookup of a model's generation config.

    Tries the dedicated generation config first; if the model ships none,
    derives one from its model config via
    ``GenerationConfig.from_model_config``. Returns ``None`` when neither
    source is available.
    """
    # Preferred source: a generation config published with the model.
    try:
        return GenerationConfig.from_pretrained(
            model,
            revision=revision,
        )
    except OSError:  # Not found
        pass

    # Fallback: build generation defaults from the model config
    # (this is where eos_token_id from config.json comes in).
    try:
        hf_config = get_config(
            model,
            trust_remote_code=trust_remote_code,
            revision=revision,
        )
        return GenerationConfig.from_model_config(hf_config)
    except OSError:  # Not found
        return None