[Bugfix] Support eos_token_id from config.json (#5954)
This commit is contained in:
parent
329df38f1a
commit
51e971d39e
31
tests/tokenization/test_get_eos.py
Normal file
31
tests/tokenization/test_get_eos.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
"""
This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by
:meth:`vllm.LLMEngine._get_eos_token_id`.
"""
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer import get_tokenizer
|
def test_get_llama3_eos_token():
    """The tokenizer reports a single ``eos_token_id`` (128009), while the
    model's generation config lists two stop ids ([128001, 128009]) — so the
    tokenizer alone is not a sufficient source of EOS tokens.
    """
    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

    tokenizer = get_tokenizer(model_name)
    assert tokenizer.eos_token_id == 128009

    generation_config = try_get_generation_config(model_name,
                                                  trust_remote_code=False)
    assert generation_config is not None
    assert generation_config.eos_token_id == [128001, 128009]
|
def test_get_blip2_eos_token():
    """The tokenizer's ``eos_token_id`` (2) disagrees with the one in the
    model's generation config (50118) — another case where the tokenizer
    alone is not a sufficient source of the EOS token.
    """
    model_name = "Salesforce/blip2-opt-2.7b"

    tokenizer = get_tokenizer(model_name)
    assert tokenizer.eos_token_id == 2

    generation_config = try_get_generation_config(model_name,
                                                  trust_remote_code=False)
    assert generation_config is not None
    assert generation_config.eos_token_id == 50118
@ -1,10 +1,10 @@
|
|||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional
|
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
|
||||||
from typing import Sequence as GenericSequence
|
from typing import Sequence as GenericSequence
|
||||||
from typing import Set, Type, TypeVar, Union
|
from typing import Set, Type, TypeVar, Union
|
||||||
|
|
||||||
from transformers import GenerationConfig, PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
|
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
|
||||||
LoRAConfig, ModelConfig, ObservabilityConfig,
|
LoRAConfig, ModelConfig, ObservabilityConfig,
|
||||||
@ -34,6 +34,7 @@ from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
|
|||||||
SequenceStatus)
|
SequenceStatus)
|
||||||
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
|
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
|
||||||
init_tracer)
|
init_tracer)
|
||||||
|
from vllm.transformers_utils.config import try_get_generation_config
|
||||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||||
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
|
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
|
||||||
get_tokenizer_group)
|
get_tokenizer_group)
|
||||||
@ -46,16 +47,18 @@ logger = init_logger(__name__)
|
|||||||
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
||||||
|
|
||||||
|
|
||||||
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Load the model's generation config as a plain dict.

    Delegates lookup (including the config.json fallback) to
    :func:`try_get_generation_config`; returns an empty dict when no
    generation config can be found, so callers can treat the result
    uniformly as keyword overrides.
    """
    config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )

    if config is None:
        return {}

    # to_diff_dict() keeps only values that differ from the defaults.
    return config.to_diff_dict()
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
|
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
from typing import Dict, Optional, Type
|
from typing import Dict, Optional, Type
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
from transformers import GenerationConfig, PretrainedConfig
|
||||||
|
|
||||||
from vllm.envs import VLLM_USE_MODELSCOPE
|
from vllm.envs import VLLM_USE_MODELSCOPE
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -80,3 +80,25 @@ def get_hf_text_config(config: PretrainedConfig):
|
|||||||
return config.text_config
|
return config.text_config
|
||||||
else:
|
else:
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def try_get_generation_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
) -> Optional[GenerationConfig]:
    """Best-effort fetch of a model's :class:`GenerationConfig`.

    Tries ``generation_config.json`` first; if that is missing, derives a
    generation config from the model's ``config.json`` (which may also carry
    ``eos_token_id``). Returns ``None`` when neither source is available.
    """
    try:
        return GenerationConfig.from_pretrained(
            model,
            revision=revision,
        )
    except OSError:  # Not found
        try:
            config = get_config(
                model,
                trust_remote_code=trust_remote_code,
                revision=revision,
            )
            return GenerationConfig.from_model_config(config)
        except OSError:  # Not found
            return None
|
Loading…
x
Reference in New Issue
Block a user