[Misc] Update Mistral-3.1 example (#16147)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-04-07 11:57:37 +08:00 (committed via GitHub)
Parent: 3749e28774
Commit: 0a57386721

@@ -13,9 +13,14 @@ from vllm.sampling_params import SamplingParams
 # - Server:
 #
 # ```bash
+# # Mistral format
 # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
 # --tokenizer-mode mistral --config-format mistral --load-format mistral \
 # --limit-mm-per-prompt 'image=4' --max-model-len 16384
+#
+# # HF format
+# vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
+# --limit-mm-per-prompt 'image=4' --max-model-len 16384
 # ```
 #
 # - Client:
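
Both `vllm serve` commands above expose the same OpenAI-compatible endpoint; only the checkpoint loading differs. As a companion to the curl snippet in the file's client section, here is a minimal Python client sketch; the `openai` package, the localhost URL, and the placeholder image URL are assumptions, not part of this commit:

```python
# Minimal client sketch for a server started by either `vllm serve`
# command above. Assumes the server listens on localhost:8000 and that
# the `openai` package is installed; the example file itself uses curl.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="token")

response = client.chat.completions.create(
    model="mistralai/Mistral-Small-3.1-24B-Instruct-2503",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image in detail please."},
            # Any reachable image URL works; this one is a placeholder.
            {"type": "image_url",
             "image_url": {"url": "https://example.com/image.jpg"}},
        ],
    }],
)
print(response.choices[0].message.content)
```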
@@ -44,19 +49,22 @@ from vllm.sampling_params import SamplingParams
 # python demo.py simple
 # python demo.py advanced
 #
+# Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
 # These scripts have been tested on 2x L40 GPUs


 def run_simple_demo(args: argparse.Namespace):
     model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
     sampling_params = SamplingParams(max_tokens=8192)

+    # Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
     llm = LLM(
         model=model_name,
-        tokenizer_mode="mistral",
-        config_format="mistral",
-        load_format="mistral",
+        tokenizer_mode="mistral" if args.format == "mistral" else "auto",
+        config_format="mistral" if args.format == "mistral" else "auto",
+        load_format="mistral" if args.format == "mistral" else "auto",
         max_model_len=4096,
         max_num_seqs=2,
+        tensor_parallel_size=2,
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
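
The three `... if args.format == "mistral" else "auto"` expressions above always move together. A small illustrative refactor of the same pattern; the helper name is hypothetical and not part of the commit:

```python
# Hypothetical helper, not from the commit: bundle the three engine
# kwargs that the example toggles together based on --format.
def format_kwargs(fmt: str) -> dict:
    mode = "mistral" if fmt == "mistral" else "auto"
    return {"tokenizer_mode": mode, "config_format": mode, "load_format": mode}

# Usage sketch:
#   llm = LLM(model=model_name, **format_kwargs(args.format), ...)
```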
@@ -88,17 +96,18 @@ def run_simple_demo(args: argparse.Namespace):
 def run_advanced_demo(args: argparse.Namespace):
     model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
-    max_img_per_msg = 5
+    max_img_per_msg = 3
     max_tokens_per_img = 4096
     sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)

     llm = LLM(
         model=model_name,
-        tokenizer_mode="mistral",
-        config_format="mistral",
-        load_format="mistral",
+        tokenizer_mode="mistral" if args.format == "mistral" else "auto",
+        config_format="mistral" if args.format == "mistral" else "auto",
+        load_format="mistral" if args.format == "mistral" else "auto",
         limit_mm_per_prompt={"image": max_img_per_msg},
         max_model_len=max_img_per_msg * max_tokens_per_img,
+        tensor_parallel_size=2,
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
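
The context length in the advanced demo is derived rather than hard-coded, so lowering `max_img_per_msg` from 5 to 3 also shrinks `max_model_len`. A worked check using only values from this diff:

```python
# Worked check of the advanced demo's derived context budget.
max_img_per_msg = 3        # was 5 before this commit
max_tokens_per_img = 4096  # unchanged
max_model_len = max_img_per_msg * max_tokens_per_img
assert max_model_len == 12288  # previously 5 * 4096 == 20480
```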
@@ -166,6 +175,11 @@ def main():
         help="Specify the demo mode: 'simple' or 'advanced'",
     )
+    parser.add_argument('--format',
+                        choices=["mistral", "hf"],
+                        default="mistral",
+                        help='Specify the format of the model to load.')
     parser.add_argument(
         '--disable-mm-preprocessor-cache',
         action='store_true',
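
The new flag defaults to the Mistral-native format and accepts only the two checkpoint layouts. A standalone parsing check; the parser here is rebuilt in isolation and is not the example's full `main()`:

```python
# Standalone check of the new --format flag (definition copied from the
# diff above; this is not the example's complete argument parser).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--format',
                    choices=["mistral", "hf"],
                    default="mistral",
                    help='Specify the format of the model to load.')

print(parser.parse_args([]).format)                  # "mistral" (default)
print(parser.parse_args(["--format", "hf"]).format)  # "hf"
```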