[Bugfix] Support eos_token_id
from config.json
(#5954)
This commit is contained in:
parent
329df38f1a
commit
51e971d39e
31
tests/tokenization/test_get_eos.py
Normal file
31
tests/tokenization/test_get_eos.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""
|
||||
This test file includes some cases where it is inappropriate to
|
||||
only get the `eos_token_id` from the tokenizer as defined by
|
||||
:meth:`vllm.LLMEngine._get_eos_token_id`.
|
||||
"""
|
||||
from vllm.transformers_utils.config import try_get_generation_config
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
def test_get_llama3_eos_token():
    """Llama-3-Instruct is a case where the tokenizer alone is not enough:
    its tokenizer reports a single EOS id, while the generation config
    lists two stop ids, so EOS handling must consult the generation config.
    """
    model = "meta-llama/Meta-Llama-3-8B-Instruct"

    tok = get_tokenizer(model)
    assert tok.eos_token_id == 128009

    gen_cfg = try_get_generation_config(model, trust_remote_code=False)
    assert gen_cfg is not None
    assert gen_cfg.eos_token_id == [128001, 128009]
|
||||
|
||||
|
||||
def test_get_blip2_eos_token():
    """BLIP-2 is a case where the tokenizer's EOS id (2) disagrees with the
    generation config's EOS id (50118), so relying on the tokenizer alone
    would stop generation on the wrong token.
    """
    model = "Salesforce/blip2-opt-2.7b"

    tok = get_tokenizer(model)
    assert tok.eos_token_id == 2

    gen_cfg = try_get_generation_config(model, trust_remote_code=False)
    assert gen_cfg is not None
    assert gen_cfg.eos_token_id == 50118
|
@ -1,10 +1,10 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set, Type, TypeVar, Union
|
||||
|
||||
from transformers import GenerationConfig, PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
|
||||
LoRAConfig, ModelConfig, ObservabilityConfig,
|
||||
@ -34,6 +34,7 @@ from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
|
||||
SequenceStatus)
|
||||
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
|
||||
init_tracer)
|
||||
from vllm.transformers_utils.config import try_get_generation_config
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
|
||||
get_tokenizer_group)
|
||||
@ -46,16 +47,18 @@ logger = init_logger(__name__)
|
||||
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
||||
|
||||
|
||||
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Load the model's generation config as a dict of non-default values.

    Delegates to :func:`try_get_generation_config`, which falls back to
    deriving a generation config from ``config.json`` when no dedicated
    ``generation_config.json`` exists — this is what lets ``eos_token_id``
    defined only in ``config.json`` be picked up.

    Args:
        model_config: The engine's model configuration; its ``model``,
            ``trust_remote_code``, and ``revision`` fields are forwarded.

    Returns:
        The generation config's ``to_diff_dict()`` output, or an empty
        dict when no generation config could be found.
    """
    config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )

    if config is None:
        return {}

    return config.to_diff_dict()
|
||||
|
||||
|
||||
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import contextlib
|
||||
from typing import Dict, Optional, Type
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from transformers import GenerationConfig, PretrainedConfig
|
||||
|
||||
from vllm.envs import VLLM_USE_MODELSCOPE
|
||||
from vllm.logger import init_logger
|
||||
@ -80,3 +80,25 @@ def get_hf_text_config(config: PretrainedConfig):
|
||||
return config.text_config
|
||||
else:
|
||||
return config
|
||||
|
||||
|
||||
def try_get_generation_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
) -> Optional[GenerationConfig]:
    """Best-effort lookup of a model's generation config.

    First attempts to load the dedicated ``generation_config.json``; if that
    file is absent, falls back to synthesizing a generation config from the
    model's ``config.json``.

    Args:
        model: Model name or local path.
        trust_remote_code: Forwarded to the config loader for the fallback.
        revision: Optional model revision to load from.

    Returns:
        The loaded/derived ``GenerationConfig``, or ``None`` when neither
        source is available.
    """
    try:
        return GenerationConfig.from_pretrained(model, revision=revision)
    except OSError:
        # No generation_config.json — fall through to config.json below.
        pass

    try:
        base_config = get_config(
            model,
            trust_remote_code=trust_remote_code,
            revision=revision,
        )
        return GenerationConfig.from_model_config(base_config)
    except OSError:
        # config.json not found either.
        return None
|
||||
|
Loading…
x
Reference in New Issue
Block a user