Bugfix: fix broken download of models from ModelScope (#5233)
Co-authored-by: mulin.lyh <mulin.lyh@taobao.com>
parent 89c920785f
commit 4efff036f0
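In short: with this fix, setting VLLM_USE_MODELSCOPE=True makes vLLM resolve and download the model from ModelScope instead of the Hugging Face Hub. A minimal usage sketch (the model name is taken from the regression test below; setting the variable before importing vLLM is assumed to matter, since the flag is pulled in when vllm.transformers_utils.config is imported):

import os

# Assumed: must be set before vLLM is imported, because the diff below
# imports VLLM_USE_MODELSCOPE at module level.
os.environ["VLLM_USE_MODELSCOPE"] = "True"

from vllm import LLM, SamplingParams

# The model id is now resolved against modelscope.cn rather than the HF Hub.
llm = LLM(model="qwen/Qwen1.5-0.5B-Chat")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, top_p=0.95))
print(outputs[0].outputs[0].text)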
@@ -53,6 +53,27 @@ def test_gc():
     assert allocated < 50 * 1024 * 1024


+def test_model_from_modelscope(monkeypatch):
+    # model: https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary
+    MODELSCOPE_MODEL_NAME = "qwen/Qwen1.5-0.5B-Chat"
+    monkeypatch.setenv("VLLM_USE_MODELSCOPE", "True")
+    try:
+        llm = LLM(model=MODELSCOPE_MODEL_NAME)
+
+        prompts = [
+            "Hello, my name is",
+            "The president of the United States is",
+            "The capital of France is",
+            "The future of AI is",
+        ]
+        sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+        outputs = llm.generate(prompts, sampling_params)
+        assert len(outputs) == 4
+    finally:
+        monkeypatch.delenv("VLLM_USE_MODELSCOPE", raising=False)
+
+
 if __name__ == "__main__":
     import pytest
     pytest.main([__file__])
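The new test can be selected by name; a sketch of a direct invocation, assuming the hunk belongs to tests/test_regression.py (the filename is not shown in this diff):

import pytest

# -k filters to the new ModelScope test; the file path is an assumption.
pytest.main(["-k", "test_model_from_modelscope", "tests/test_regression.py"])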
@@ -113,6 +113,10 @@ class ModelConfig:
         self.revision = revision
         self.code_revision = code_revision
         self.rope_scaling = rope_scaling
-        self.tokenizer_revision = tokenizer_revision
+        # The tokenizer version is consistent with the model version by default.
+        if tokenizer_revision is None:
+            self.tokenizer_revision = revision
+        else:
+            self.tokenizer_revision = tokenizer_revision
         self.quantization = quantization
         self.quantization_param_path = quantization_param_path
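The effect of this hunk: when no tokenizer revision is given, the tokenizer is pinned to the same revision as the model weights. A standalone sketch of the same fallback (the function name is hypothetical, not part of ModelConfig):

from typing import Optional

def resolve_tokenizer_revision(revision: Optional[str],
                               tokenizer_revision: Optional[str]
                               ) -> Optional[str]:
    # The tokenizer version follows the model version by default.
    if tokenizer_revision is None:
        return revision
    return tokenizer_revision

assert resolve_tokenizer_revision("v1.0", None) == "v1.0"
assert resolve_tokenizer_revision("v1.0", "v2.0") == "v2.0"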
@@ -1,7 +1,8 @@
 from typing import Dict, Optional

-from transformers import AutoConfig, PretrainedConfig
+from transformers import PretrainedConfig

+from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                              JAISConfig, MPTConfig, RWConfig)
@@ -24,6 +25,10 @@ def get_config(model: str,
                code_revision: Optional[str] = None,
                rope_scaling: Optional[dict] = None) -> PretrainedConfig:
     try:
+        if VLLM_USE_MODELSCOPE:
+            from modelscope import AutoConfig
+        else:
+            from transformers import AutoConfig
         config = AutoConfig.from_pretrained(
             model,
             trust_remote_code=trust_remote_code,
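This is the heart of the fix: AutoConfig is now chosen at call time instead of being imported from transformers unconditionally at module top, so when VLLM_USE_MODELSCOPE is set the config (and the weights fetched afterwards) come from modelscope.cn. A standalone sketch of the same pattern (get_config_for_hub is a hypothetical name, not vLLM's API):

from transformers import PretrainedConfig

def get_config_for_hub(model: str,
                       use_modelscope: bool,
                       trust_remote_code: bool = False) -> PretrainedConfig:
    # Deferring the import lets one flag pick the hub; modelscope ships an
    # AutoConfig wrapper that downloads from modelscope.cn (see diff above).
    if use_modelscope:
        from modelscope import AutoConfig
    else:
        from transformers import AutoConfig
    return AutoConfig.from_pretrained(model,
                                      trust_remote_code=trust_remote_code)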