[Misc] refactor argument parsing in examples (#16635)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
parent b590adfdc1
commit 6ae996a873
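The refactor applies the same pattern to every example touched below: argument-parser setup that previously lived at module level or inside the `if __name__ == "__main__":` block moves into a dedicated `parse_args()` (or `create_parser()`) helper, and the `__main__` block shrinks to calling that helper and handing the result to `main()`. As an illustration only, here is a minimal sketch of the before/after shape using a hypothetical example script (the option name and description are illustrative and not taken from any one file in this diff):

    from vllm.utils import FlexibleArgumentParser


    def parse_args():
        # All CLI wiring lives in one helper that main() never touches.
        parser = FlexibleArgumentParser(description="Hypothetical example")
        parser.add_argument("--num-prompts", type=int, default=1)
        return parser.parse_args()


    def main(args):
        # main() receives already-parsed arguments.
        print(f"Running with {args.num_prompts} prompt(s)")


    if __name__ == "__main__":
        args = parse_args()
        main(args)

Assuming this shape, each example stays importable without side effects and the entry point is uniform across the examples directory.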
@@ -187,6 +187,33 @@ model_example_map = {
 }
 
 
+def parse_args():
+    parser = FlexibleArgumentParser(
+        description='Demo on using vLLM for offline inference with '
+        'audio language models')
+    parser.add_argument('--model-type',
+                        '-m',
+                        type=str,
+                        default="ultravox",
+                        choices=model_example_map.keys(),
+                        help='Huggingface "model_type".')
+    parser.add_argument('--num-prompts',
+                        type=int,
+                        default=1,
+                        help='Number of prompts to run.')
+    parser.add_argument("--num-audios",
+                        type=int,
+                        default=1,
+                        choices=[0, 1, 2],
+                        help="Number of audio items per prompt.")
+    parser.add_argument("--seed",
+                        type=int,
+                        default=None,
+                        help="Set the seed when initializing `vllm.LLM`.")
+
+    return parser.parse_args()
+
+
 def main(args):
     model = args.model_type
     if model not in model_example_map:

@@ -240,28 +267,5 @@ def main(args):
 
 
 if __name__ == "__main__":
-    parser = FlexibleArgumentParser(
-        description='Demo on using vLLM for offline inference with '
-        'audio language models')
-    parser.add_argument('--model-type',
-                        '-m',
-                        type=str,
-                        default="ultravox",
-                        choices=model_example_map.keys(),
-                        help='Huggingface "model_type".')
-    parser.add_argument('--num-prompts',
-                        type=int,
-                        default=1,
-                        help='Number of prompts to run.')
-    parser.add_argument("--num-audios",
-                        type=int,
-                        default=1,
-                        choices=[0, 1, 2],
-                        help="Number of audio items per prompt.")
-    parser.add_argument("--seed",
-                        type=int,
-                        default=None,
-                        help="Set the seed when initializing `vllm.LLM`.")
-
-    args = parser.parse_args()
+    args = parse_args()
     main(args)
@@ -12,9 +12,12 @@ prompts = [
 # Create a sampling params object.
 sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
 
+def main():
     # Create an LLM.
     llm = LLM(model="facebook/opt-125m")
-    # Generate texts from the prompts. The output is a list of RequestOutput objects
+    # Generate texts from the prompts.
+    # The output is a list of RequestOutput objects
     # that contain the prompt, generated text, and other information.
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.

@@ -25,3 +28,7 @@ for output in outputs:
         print(f"Prompt: {prompt!r}")
         print(f"Output: {generated_text!r}")
         print("-" * 60)
+
+
+if __name__ == "__main__":
+    main()
@@ -4,6 +4,24 @@ from vllm import LLM, EngineArgs
 from vllm.utils import FlexibleArgumentParser
 
 
+def create_parser():
+    parser = FlexibleArgumentParser()
+    # Add engine args
+    engine_group = parser.add_argument_group("Engine arguments")
+    EngineArgs.add_cli_args(engine_group)
+    engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
+    # Add sampling params
+    sampling_group = parser.add_argument_group("Sampling parameters")
+    sampling_group.add_argument("--max-tokens", type=int)
+    sampling_group.add_argument("--temperature", type=float)
+    sampling_group.add_argument("--top-p", type=float)
+    sampling_group.add_argument("--top-k", type=int)
+    # Add example params
+    parser.add_argument("--chat-template-path", type=str)
+
+    return parser
+
+
 def main(args: dict):
     # Pop arguments not used by LLM
     max_tokens = args.pop("max_tokens")

@@ -82,18 +100,6 @@ def main(args: dict):
 
 
 if __name__ == "__main__":
-    parser = FlexibleArgumentParser()
-    # Add engine args
-    engine_group = parser.add_argument_group("Engine arguments")
-    EngineArgs.add_cli_args(engine_group)
-    engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
-    # Add sampling params
-    sampling_group = parser.add_argument_group("Sampling parameters")
-    sampling_group.add_argument("--max-tokens", type=int)
-    sampling_group.add_argument("--temperature", type=float)
-    sampling_group.add_argument("--top-p", type=float)
-    sampling_group.add_argument("--top-k", type=int)
-    # Add example params
-    parser.add_argument("--chat-template-path", type=str)
+    parser = create_parser()
     args: dict = vars(parser.parse_args())
     main(args)
@@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs
 from vllm.utils import FlexibleArgumentParser
 
 
+def parse_args():
+    parser = FlexibleArgumentParser()
+    parser = EngineArgs.add_cli_args(parser)
+    # Set example specific arguments
+    parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach",
+                        task="classify",
+                        enforce_eager=True)
+    return parser.parse_args()
+
+
 def main(args: Namespace):
     # Sample prompts.
     prompts = [

@@ -34,11 +44,5 @@ def main(args: Namespace):
 
 
 if __name__ == "__main__":
-    parser = FlexibleArgumentParser()
-    parser = EngineArgs.add_cli_args(parser)
-    # Set example specific arguments
-    parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach",
-                        task="classify",
-                        enforce_eager=True)
-    args = parser.parse_args()
+    args = parse_args()
     main(args)
@@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs
 from vllm.utils import FlexibleArgumentParser
 
 
+def parse_args():
+    parser = FlexibleArgumentParser()
+    parser = EngineArgs.add_cli_args(parser)
+    # Set example specific arguments
+    parser.set_defaults(model="intfloat/e5-mistral-7b-instruct",
+                        task="embed",
+                        enforce_eager=True)
+    return parser.parse_args()
+
+
 def main(args: Namespace):
     # Sample prompts.
     prompts = [

@@ -34,11 +44,5 @@ def main(args: Namespace):
 
 
 if __name__ == "__main__":
-    parser = FlexibleArgumentParser()
-    parser = EngineArgs.add_cli_args(parser)
-    # Set example specific arguments
-    parser.set_defaults(model="intfloat/e5-mistral-7b-instruct",
-                        task="embed",
-                        enforce_eager=True)
-    args = parser.parse_args()
+    args = parse_args()
     main(args)
@@ -4,6 +4,22 @@ from vllm import LLM, EngineArgs
 from vllm.utils import FlexibleArgumentParser
 
 
+def create_parser():
+    parser = FlexibleArgumentParser()
+    # Add engine args
+    engine_group = parser.add_argument_group("Engine arguments")
+    EngineArgs.add_cli_args(engine_group)
+    engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
+    # Add sampling params
+    sampling_group = parser.add_argument_group("Sampling parameters")
+    sampling_group.add_argument("--max-tokens", type=int)
+    sampling_group.add_argument("--temperature", type=float)
+    sampling_group.add_argument("--top-p", type=float)
+    sampling_group.add_argument("--top-k", type=int)
+
+    return parser
+
+
 def main(args: dict):
     # Pop arguments not used by LLM
     max_tokens = args.pop("max_tokens")

@@ -35,23 +51,15 @@ def main(args: dict):
     ]
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
+    print("-" * 50)
     for output in outputs:
         prompt = output.prompt
         generated_text = output.outputs[0].text
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 50)
 
 
 if __name__ == "__main__":
-    parser = FlexibleArgumentParser()
-    # Add engine args
-    engine_group = parser.add_argument_group("Engine arguments")
-    EngineArgs.add_cli_args(engine_group)
-    engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
-    # Add sampling params
-    sampling_group = parser.add_argument_group("Sampling parameters")
-    sampling_group.add_argument("--max-tokens", type=int)
-    sampling_group.add_argument("--temperature", type=float)
-    sampling_group.add_argument("--top-p", type=float)
-    sampling_group.add_argument("--top-k", type=int)
+    parser = create_parser()
     args: dict = vars(parser.parse_args())
     main(args)
@@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs
 from vllm.utils import FlexibleArgumentParser
 
 
+def parse_args():
+    parser = FlexibleArgumentParser()
+    parser = EngineArgs.add_cli_args(parser)
+    # Set example specific arguments
+    parser.set_defaults(model="BAAI/bge-reranker-v2-m3",
+                        task="score",
+                        enforce_eager=True)
+    return parser.parse_args()
+
+
 def main(args: Namespace):
     # Sample prompts.
     text_1 = "What is the capital of France?"

@@ -30,11 +40,5 @@ def main(args: Namespace):
 
 
 if __name__ == "__main__":
-    parser = FlexibleArgumentParser()
-    parser = EngineArgs.add_cli_args(parser)
-    # Set example specific arguments
-    parser.set_defaults(model="BAAI/bge-reranker-v2-m3",
-                        task="score",
-                        enforce_eager=True)
-    args = parser.parse_args()
+    args = parse_args()
     main(args)
@@ -34,6 +34,40 @@ from vllm import LLM, SamplingParams
 from vllm.utils import get_open_port
 
 
+def parse_args():
+    import argparse
+    parser = argparse.ArgumentParser(description="Data Parallel Inference")
+    parser.add_argument("--model",
+                        type=str,
+                        default="ibm-research/PowerMoE-3b",
+                        help="Model name or path")
+    parser.add_argument("--dp-size",
+                        type=int,
+                        default=2,
+                        help="Data parallel size")
+    parser.add_argument("--tp-size",
+                        type=int,
+                        default=2,
+                        help="Tensor parallel size")
+    parser.add_argument("--node-size",
+                        type=int,
+                        default=1,
+                        help="Total number of nodes")
+    parser.add_argument("--node-rank",
+                        type=int,
+                        default=0,
+                        help="Rank of the current node")
+    parser.add_argument("--master-addr",
+                        type=str,
+                        default="",
+                        help="Master node IP address")
+    parser.add_argument("--master-port",
+                        type=int,
+                        default=0,
+                        help="Master node port")
+    return parser.parse_args()
+
+
 def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
          dp_master_port, GPUs_per_dp_rank):
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)

@@ -95,37 +129,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
 
 
 if __name__ == "__main__":
-    import argparse
-    parser = argparse.ArgumentParser(description="Data Parallel Inference")
-    parser.add_argument("--model",
-                        type=str,
-                        default="ibm-research/PowerMoE-3b",
-                        help="Model name or path")
-    parser.add_argument("--dp-size",
-                        type=int,
-                        default=2,
-                        help="Data parallel size")
-    parser.add_argument("--tp-size",
-                        type=int,
-                        default=2,
-                        help="Tensor parallel size")
-    parser.add_argument("--node-size",
-                        type=int,
-                        default=1,
-                        help="Total number of nodes")
-    parser.add_argument("--node-rank",
-                        type=int,
-                        default=0,
-                        help="Rank of the current node")
-    parser.add_argument("--master-addr",
-                        type=str,
-                        default="",
-                        help="Master node IP address")
-    parser.add_argument("--master-port",
-                        type=int,
-                        default=0,
-                        help="Master node port")
-    args = parser.parse_args()
+    args = parse_args()
 
     dp_size = args.dp_size
     tp_size = args.tp_size
@@ -27,7 +27,7 @@ def load_prompts(dataset_path, num_prompts):
     return prompts[:num_prompts]
 
 
-def main():
+def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--dataset",

@@ -45,7 +45,12 @@ def main():
     parser.add_argument("--enable_chunked_prefill", action='store_true')
     parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
     parser.add_argument("--temp", type=float, default=0)
-    args = parser.parse_args()
+    return parser.parse_args()
+
+
+def main():
+
+    args = parse_args()
 
     model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
     eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
@@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs
 from vllm.utils import FlexibleArgumentParser
 
 
+def parse_args():
+    parser = FlexibleArgumentParser()
+    parser = EngineArgs.add_cli_args(parser)
+    # Set example specific arguments
+    parser.set_defaults(model="jinaai/jina-embeddings-v3",
+                        task="embed",
+                        trust_remote_code=True)
+    return parser.parse_args()
+
+
 def main(args: Namespace):
     # Sample prompts.
     prompts = [

@@ -40,11 +50,5 @@ def main(args: Namespace):
 
 
 if __name__ == "__main__":
-    parser = FlexibleArgumentParser()
-    parser = EngineArgs.add_cli_args(parser)
-    # Set example specific arguments
-    parser.set_defaults(model="jinaai/jina-embeddings-v3",
-                        task="embed",
-                        trust_remote_code=True)
-    args = parser.parse_args()
+    args = parse_args()
     main(args)
@@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs, PoolingParams
 from vllm.utils import FlexibleArgumentParser
 
 
+def parse_args():
+    parser = FlexibleArgumentParser()
+    parser = EngineArgs.add_cli_args(parser)
+    # Set example specific arguments
+    parser.set_defaults(model="jinaai/jina-embeddings-v3",
+                        task="embed",
+                        trust_remote_code=True)
+    return parser.parse_args()
+
+
 def main(args: Namespace):
     # Sample prompts.
     prompts = [

@@ -38,11 +48,5 @@ def main(args: Namespace):
 
 
 if __name__ == "__main__":
-    parser = FlexibleArgumentParser()
-    parser = EngineArgs.add_cli_args(parser)
-    # Set example specific arguments
-    parser.set_defaults(model="jinaai/jina-embeddings-v3",
-                        task="embed",
-                        trust_remote_code=True)
-    args = parser.parse_args()
+    args = parse_args()
     main(args)
@@ -8,17 +8,8 @@ from vllm import LLM, SamplingParams
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                          TokensPrompt, zip_enc_dec_prompts)
 
-dtype = "float"
-
-# Create a BART encoder/decoder model instance
-llm = LLM(
-    model="facebook/bart-large-cnn",
-    dtype=dtype,
-)
-
-# Get BART tokenizer
-tokenizer = llm.llm_engine.get_tokenizer_group()
 
+def create_prompts(tokenizer):
     # Test prompts
     #
     # This section shows all of the valid ways to prompt an

@@ -37,6 +28,7 @@ single_text_prompt_raw = text_prompt_raw  # Pass a string directly
     single_text_prompt = text_prompt  # Pass a TextPrompt
     single_tokens_prompt = tokens_prompt  # Pass a TokensPrompt
 
+    # ruff: noqa: E501
     # - Pass explicit encoder and decoder input prompts within one data structure.
     # Encoder and decoder prompts can both independently be text or tokens, with
     # no requirement that they be the same prompt type. Some example prompt-type

@@ -70,25 +62,24 @@ zipped_prompt_list = zip_enc_dec_prompts(
 
     # - Let's put all of the above example prompts together into one list
     # which we will pass to the encoder/decoder LLM.
-prompts = [
+    return [
         single_text_prompt_raw, single_text_prompt, single_tokens_prompt,
         enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3
     ] + zipped_prompt_list
 
 
 # Create a sampling params object.
-sampling_params = SamplingParams(
+def create_sampling_params():
+    return SamplingParams(
         temperature=0,
         top_p=1.0,
         min_tokens=0,
         max_tokens=20,
     )
 
-# Generate output tokens from the prompts. The output is a list of
-# RequestOutput objects that contain the prompt, generated
-# text, and other information.
-outputs = llm.generate(prompts, sampling_params)
-
 # Print the outputs.
+def print_outputs(outputs):
     print("-" * 50)
     for i, output in enumerate(outputs):
         prompt = output.prompt

@@ -99,3 +90,30 @@ for i, output in enumerate(outputs):
               f"Decoder prompt: {prompt!r}\n"
               f"Generated text: {generated_text!r}")
         print("-" * 50)
+
+
+def main():
+    dtype = "float"
+
+    # Create a BART encoder/decoder model instance
+    llm = LLM(
+        model="facebook/bart-large-cnn",
+        dtype=dtype,
+    )
+
+    # Get BART tokenizer
+    tokenizer = llm.llm_engine.get_tokenizer_group()
+
+    prompts = create_prompts(tokenizer)
+    sampling_params = create_sampling_params()
+
+    # Generate output tokens from the prompts. The output is a list of
+    # RequestOutput objects that contain the prompt, generated
+    # text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+
+    print_outputs(outputs)
+
+
+if __name__ == "__main__":
+    main()
|
@ -126,6 +126,23 @@ model_example_map = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description='Demo on using vLLM for offline inference with '
|
||||||
|
'vision language models for text generation')
|
||||||
|
parser.add_argument('--model-type',
|
||||||
|
'-m',
|
||||||
|
type=str,
|
||||||
|
default="mllama",
|
||||||
|
choices=model_example_map.keys(),
|
||||||
|
help='Huggingface "model_type".')
|
||||||
|
parser.add_argument("--seed",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Set the seed when initializing `vllm.LLM`.")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
model = args.model_type
|
model = args.model_type
|
||||||
if model not in model_example_map:
|
if model not in model_example_map:
|
||||||
@ -171,19 +188,5 @@ def main(args):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser(
|
args = parse_args()
|
||||||
description='Demo on using vLLM for offline inference with '
|
|
||||||
'vision language models for text generation')
|
|
||||||
parser.add_argument('--model-type',
|
|
||||||
'-m',
|
|
||||||
type=str,
|
|
||||||
default="mllama",
|
|
||||||
choices=model_example_map.keys(),
|
|
||||||
help='Huggingface "model_type".')
|
|
||||||
parser.add_argument("--seed",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="Set the seed when initializing `vllm.LLM`.")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
main(args)
|
main(args)
|
||||||
|
@@ -168,7 +168,7 @@ def run_advanced_demo(args: argparse.Namespace):
     print("-" * 50)
 
 
-def main():
+def parse_args():
     parser = argparse.ArgumentParser(
         description="Run a demo in simple or advanced mode.")
 

@@ -187,8 +187,11 @@ def main():
         '--disable-mm-preprocessor-cache',
         action='store_true',
         help='If True, disables caching of multi-modal preprocessor/mapper.')
+    return parser.parse_args()
 
-    args = parser.parse_args()
+
+def main():
+    args = parse_args()
 
     if args.mode == "simple":
         print("Running simple demo...")
@@ -34,8 +34,7 @@ def time_generation(llm: LLM, prompts: list[str],
     print("-" * 50)
 
 
-if __name__ == "__main__":
-
+def main():
     template = (
         "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n### Instruction:\n{}"

@@ -66,3 +65,7 @@ if __name__ == "__main__":
     )
 
     time_generation(llm, prompts, sampling_params, "With speculation")
+
+
+if __name__ == "__main__":
+    main()
|
@ -417,6 +417,38 @@ def run_model(input_data,
|
|||||||
return pred_imgs
|
return pred_imgs
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_file",
|
||||||
|
type=str,
|
||||||
|
default="./India_900498_S2Hand.tif",
|
||||||
|
help="Path to the file.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
type=str,
|
||||||
|
default="output",
|
||||||
|
help="Path to the directory where to save outputs.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--input_indices",
|
||||||
|
default=[1, 2, 3, 8, 11, 12],
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
help=
|
||||||
|
"0-based indices of the six Prithvi channels to be selected from the "
|
||||||
|
"input. By default selects [1,2,3,8,11,12] for S2L1C data.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rgb_outputs",
|
||||||
|
action="store_true",
|
||||||
|
help="If present, output files will only contain RGB channels. "
|
||||||
|
"Otherwise, all bands will be saved.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
data_file: str,
|
data_file: str,
|
||||||
output_dir: str,
|
output_dir: str,
|
||||||
@ -496,35 +528,7 @@ def main(
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
|
|
||||||
|
|
||||||
parser.add_argument(
|
args = parse_args()
|
||||||
"--data_file",
|
|
||||||
type=str,
|
|
||||||
default="./India_900498_S2Hand.tif",
|
|
||||||
help="Path to the file.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output_dir",
|
|
||||||
type=str,
|
|
||||||
default="output",
|
|
||||||
help="Path to the directory where to save outputs.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--input_indices",
|
|
||||||
default=[1, 2, 3, 8, 11, 12],
|
|
||||||
type=int,
|
|
||||||
nargs="+",
|
|
||||||
help=
|
|
||||||
"0-based indices of the six Prithvi channels to be selected from the "
|
|
||||||
"input. By default selects [1,2,3,8,11,12] for S2L1C data.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--rgb_outputs",
|
|
||||||
action="store_true",
|
|
||||||
help="If present, output files will only contain RGB channels. "
|
|
||||||
"Otherwise, all bands will be saved.",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
main(**vars(args))
|
main(**vars(args))
|
||||||
|
@@ -359,7 +359,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
           f" in folder {context.save_chrome_traces_folder}")
 
 
-if __name__ == "__main__":
+def parse_args():
     parser = FlexibleArgumentParser(description="""
 Profile a model
 

@@ -449,7 +449,10 @@ Profile a model
 
     EngineArgs.add_cli_args(parser)
 
-    args = parser.parse_args()
+    return parser.parse_args()
+
+
+def main(args):
     context = ProfileContext(
         engine_args=EngineArgs.from_cli_args(args),
         **{

@@ -458,3 +461,8 @@ Profile a model
             if k in inspect.signature(ProfileContext).parameters
         })
     run_profile(context, csv_output=args.csv, json_output=args.json)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
|
@ -29,6 +29,8 @@ from pathlib import Path
|
|||||||
from vllm import LLM, EngineArgs
|
from vllm import LLM, EngineArgs
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
parser = FlexibleArgumentParser()
|
parser = FlexibleArgumentParser()
|
||||||
EngineArgs.add_cli_args(parser)
|
EngineArgs.add_cli_args(parser)
|
||||||
parser.add_argument("--output",
|
parser.add_argument("--output",
|
||||||
@ -43,6 +45,7 @@ parser.add_argument("--max-file-size",
|
|||||||
type=str,
|
type=str,
|
||||||
default=5 * 1024**3,
|
default=5 * 1024**3,
|
||||||
help="max size (in bytes) of each safetensors file")
|
help="max size (in bytes) of each safetensors file")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
@ -87,5 +90,5 @@ def main(args):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parser.parse_args()
|
args = parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -18,8 +18,8 @@ prompts = [
|
|||||||
# Create a sampling params object.
|
# Create a sampling params object.
|
||||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
|
||||||
|
def main():
|
||||||
# Create an LLM.
|
# Create an LLM.
|
||||||
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
|
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
|
||||||
|
|
||||||
@ -42,3 +42,7 @@ if __name__ == "__main__":
|
|||||||
# Add a buffer to wait for profiler in the background process
|
# Add a buffer to wait for profiler in the background process
|
||||||
# (in case MP is on) to finish writing profiling output.
|
# (in case MP is on) to finish writing profiling output.
|
||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
@ -1097,6 +1097,59 @@ def time_counter(enable: bool):
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description='Demo on using vLLM for offline inference with '
|
||||||
|
'vision language models for text generation')
|
||||||
|
parser.add_argument('--model-type',
|
||||||
|
'-m',
|
||||||
|
type=str,
|
||||||
|
default="llava",
|
||||||
|
choices=model_example_map.keys(),
|
||||||
|
help='Huggingface "model_type".')
|
||||||
|
parser.add_argument('--num-prompts',
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help='Number of prompts to run.')
|
||||||
|
parser.add_argument('--modality',
|
||||||
|
type=str,
|
||||||
|
default="image",
|
||||||
|
choices=['image', 'video'],
|
||||||
|
help='Modality of the input.')
|
||||||
|
parser.add_argument('--num-frames',
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help='Number of frames to extract from the video.')
|
||||||
|
parser.add_argument("--seed",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Set the seed when initializing `vllm.LLM`.")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--image-repeat-prob',
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help='Simulates the hit-ratio for multi-modal preprocessor cache'
|
||||||
|
' (if enabled)')
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--disable-mm-preprocessor-cache',
|
||||||
|
action='store_true',
|
||||||
|
help='If True, disables caching of multi-modal preprocessor/mapper.')
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--time-generate',
|
||||||
|
action='store_true',
|
||||||
|
help='If True, then print the total generate() call time')
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--use-different-prompt-per-request',
|
||||||
|
action='store_true',
|
||||||
|
help='If True, then use different prompt (with the same multi-modal '
|
||||||
|
'data) for each request.')
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
model = args.model_type
|
model = args.model_type
|
||||||
if model not in model_example_map:
|
if model not in model_example_map:
|
||||||
@ -1175,55 +1228,5 @@ def main(args):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser(
|
args = parse_args()
|
||||||
description='Demo on using vLLM for offline inference with '
|
|
||||||
'vision language models for text generation')
|
|
||||||
parser.add_argument('--model-type',
|
|
||||||
'-m',
|
|
||||||
type=str,
|
|
||||||
default="llava",
|
|
||||||
choices=model_example_map.keys(),
|
|
||||||
help='Huggingface "model_type".')
|
|
||||||
parser.add_argument('--num-prompts',
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
help='Number of prompts to run.')
|
|
||||||
parser.add_argument('--modality',
|
|
||||||
type=str,
|
|
||||||
default="image",
|
|
||||||
choices=['image', 'video'],
|
|
||||||
help='Modality of the input.')
|
|
||||||
parser.add_argument('--num-frames',
|
|
||||||
type=int,
|
|
||||||
default=16,
|
|
||||||
help='Number of frames to extract from the video.')
|
|
||||||
parser.add_argument("--seed",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="Set the seed when initializing `vllm.LLM`.")
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'--image-repeat-prob',
|
|
||||||
type=float,
|
|
||||||
default=None,
|
|
||||||
help='Simulates the hit-ratio for multi-modal preprocessor cache'
|
|
||||||
' (if enabled)')
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'--disable-mm-preprocessor-cache',
|
|
||||||
action='store_true',
|
|
||||||
help='If True, disables caching of multi-modal preprocessor/mapper.')
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'--time-generate',
|
|
||||||
action='store_true',
|
|
||||||
help='If True, then print the total generate() call time')
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'--use-different-prompt-per-request',
|
|
||||||
action='store_true',
|
|
||||||
help='If True, then use different prompt (with the same multi-modal '
|
|
||||||
'data) for each request.')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -156,16 +156,13 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
|
|||||||
print("-" * 50)
|
print("-" * 50)
|
||||||
|
|
||||||
|
|
||||||
def main(args: Namespace):
|
|
||||||
run_encode(args.model_name, args.modality, args.seed)
|
|
||||||
|
|
||||||
|
|
||||||
model_example_map = {
|
model_example_map = {
|
||||||
"e5_v": run_e5_v,
|
"e5_v": run_e5_v,
|
||||||
"vlm2vec": run_vlm2vec,
|
"vlm2vec": run_vlm2vec,
|
||||||
}
|
}
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
def parse_args():
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description='Demo on using vLLM for offline inference with '
|
description='Demo on using vLLM for offline inference with '
|
||||||
'vision language models for multimodal embedding')
|
'vision language models for multimodal embedding')
|
||||||
@ -184,6 +181,13 @@ if __name__ == "__main__":
|
|||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="Set the seed when initializing `vllm.LLM`.")
|
help="Set the seed when initializing `vllm.LLM`.")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
def main(args: Namespace):
|
||||||
|
run_encode(args.model_name, args.modality, args.seed)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -767,22 +767,7 @@ def run_chat(model: str, question: str, image_urls: list[str],
|
|||||||
print("-" * 50)
|
print("-" * 50)
|
||||||
|
|
||||||
|
|
||||||
def main(args: Namespace):
|
def parse_args():
|
||||||
model = args.model_type
|
|
||||||
method = args.method
|
|
||||||
seed = args.seed
|
|
||||||
|
|
||||||
image_urls = IMAGE_URLS[:args.num_images]
|
|
||||||
|
|
||||||
if method == "generate":
|
|
||||||
run_generate(model, QUESTION, image_urls, seed)
|
|
||||||
elif method == "chat":
|
|
||||||
run_chat(model, QUESTION, image_urls, seed)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid method: {method}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description='Demo on using vLLM for offline inference with '
|
description='Demo on using vLLM for offline inference with '
|
||||||
'vision language models that support multi-image input for text '
|
'vision language models that support multi-image input for text '
|
||||||
@ -808,6 +793,24 @@ if __name__ == "__main__":
|
|||||||
choices=list(range(1, 13)), # 12 is the max number of images
|
choices=list(range(1, 13)), # 12 is the max number of images
|
||||||
default=2,
|
default=2,
|
||||||
help="Number of images to use for the demo.")
|
help="Number of images to use for the demo.")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
def main(args: Namespace):
|
||||||
|
model = args.model_type
|
||||||
|
method = args.method
|
||||||
|
seed = args.seed
|
||||||
|
|
||||||
|
image_urls = IMAGE_URLS[:args.num_images]
|
||||||
|
|
||||||
|
if method == "generate":
|
||||||
|
run_generate(model, QUESTION, image_urls, seed)
|
||||||
|
elif method == "chat":
|
||||||
|
run_chat(model, QUESTION, image_urls, seed)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid method: {method}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -58,6 +58,16 @@ def get_response(response: requests.Response) -> list[str]:
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--host", type=str, default="localhost")
|
||||||
|
parser.add_argument("--port", type=int, default=8000)
|
||||||
|
parser.add_argument("--n", type=int, default=1)
|
||||||
|
parser.add_argument("--prompt", type=str, default="San Francisco is a")
|
||||||
|
parser.add_argument("--stream", action="store_true")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main(args: Namespace):
|
def main(args: Namespace):
|
||||||
prompt = args.prompt
|
prompt = args.prompt
|
||||||
api_url = f"http://{args.host}:{args.port}/generate"
|
api_url = f"http://{args.host}:{args.port}/generate"
|
||||||
@ -82,11 +92,5 @@ def main(args: Namespace):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
args = parse_args()
|
||||||
parser.add_argument("--host", type=str, default="localhost")
|
|
||||||
parser.add_argument("--port", type=int, default=8000)
|
|
||||||
parser.add_argument("--n", type=int, default=1)
|
|
||||||
parser.add_argument("--prompt", type=str, default="San Francisco is a")
|
|
||||||
parser.add_argument("--stream", action="store_true")
|
|
||||||
args = parser.parse_args()
|
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -1,11 +1,75 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""Example for starting a Gradio OpenAI Chatbot Webserver
|
||||||
|
Start vLLM API server:
|
||||||
|
vllm serve meta-llama/Llama-2-7b-chat-hf
|
||||||
|
|
||||||
|
Start Gradio OpenAI Chatbot Webserver:
|
||||||
|
python examples/online_serving/gradio_openai_chatbot_webserver.py \
|
||||||
|
-m meta-llama/Llama-2-7b-chat-hf
|
||||||
|
|
||||||
|
Note that `pip install --upgrade gradio` is needed to run this example.
|
||||||
|
More details: https://github.com/gradio-app/gradio
|
||||||
|
|
||||||
|
If your antivirus software blocks the download of frpc for gradio,
|
||||||
|
you can install it manually by following these steps:
|
||||||
|
|
||||||
|
1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64
|
||||||
|
2. Rename the downloaded file to: frpc_linux_amd64_v0.3
|
||||||
|
3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc
|
||||||
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
# Argument parser setup
|
|
||||||
|
def create_openai_client(api_key, base_url):
|
||||||
|
return OpenAI(api_key=api_key, base_url=base_url)
|
||||||
|
|
||||||
|
|
||||||
|
def format_history_to_openai(history):
|
||||||
|
history_openai_format = [{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a great AI assistant."
|
||||||
|
}]
|
||||||
|
for human, assistant in history:
|
||||||
|
history_openai_format.append({"role": "user", "content": human})
|
||||||
|
history_openai_format.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": assistant
|
||||||
|
})
|
||||||
|
return history_openai_format
|
||||||
|
|
||||||
|
|
||||||
|
def predict(message, history, client, model_name, temp, stop_token_ids):
|
||||||
|
# Format history to OpenAI chat format
|
||||||
|
history_openai_format = format_history_to_openai(history)
|
||||||
|
history_openai_format.append({"role": "user", "content": message})
|
||||||
|
|
||||||
|
# Send request to OpenAI API (vLLM server)
|
||||||
|
stream = client.chat.completions.create(
|
||||||
|
model=model_name,
|
||||||
|
messages=history_openai_format,
|
||||||
|
temperature=temp,
|
||||||
|
stream=True,
|
||||||
|
extra_body={
|
||||||
|
'repetition_penalty':
|
||||||
|
1,
|
||||||
|
'stop_token_ids':
|
||||||
|
[int(id.strip())
|
||||||
|
for id in stop_token_ids.split(',')] if stop_token_ids else []
|
||||||
|
})
|
||||||
|
|
||||||
|
# Collect all chunks and concatenate them into a full message
|
||||||
|
full_message = ""
|
||||||
|
for chunk in stream:
|
||||||
|
full_message += (chunk.choices[0].delta.content or "")
|
||||||
|
|
||||||
|
# Return the full message as a single response
|
||||||
|
return full_message
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description='Chatbot Interface with Customizable Parameters')
|
description='Chatbot Interface with Customizable Parameters')
|
||||||
parser.add_argument('--model-url',
|
parser.add_argument('--model-url',
|
||||||
@ -27,58 +91,39 @@ parser.add_argument('--stop-token-ids',
|
|||||||
help='Comma-separated stop token IDs')
|
help='Comma-separated stop token IDs')
|
||||||
parser.add_argument("--host", type=str, default=None)
|
parser.add_argument("--host", type=str, default=None)
|
||||||
parser.add_argument("--port", type=int, default=8001)
|
parser.add_argument("--port", type=int, default=8001)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def build_gradio_interface(client, model_name, temp, stop_token_ids):
|
||||||
|
|
||||||
|
def chat_predict(message, history):
|
||||||
|
return predict(message, history, client, model_name, temp,
|
||||||
|
stop_token_ids)
|
||||||
|
|
||||||
|
return gr.ChatInterface(fn=chat_predict,
|
||||||
|
title="Chatbot Interface",
|
||||||
|
description="A simple chatbot powered by vLLM")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
# Parse the arguments
|
# Parse the arguments
|
||||||
args = parser.parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# Set OpenAI's API key and API base to use vLLM's API server.
|
# Set OpenAI's API key and API base to use vLLM's API server
|
||||||
openai_api_key = "EMPTY"
|
openai_api_key = "EMPTY"
|
||||||
openai_api_base = args.model_url
|
openai_api_base = args.model_url
|
||||||
|
|
||||||
# Create an OpenAI client to interact with the API server
|
# Create an OpenAI client
|
||||||
client = OpenAI(
|
client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)
|
||||||
api_key=openai_api_key,
|
|
||||||
base_url=openai_api_base,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Define the Gradio chatbot interface using the predict function
|
||||||
|
gradio_interface = build_gradio_interface(client, args.model, args.temp,
|
||||||
|
args.stop_token_ids)
|
||||||
|
|
||||||
def predict(message, history):
|
gradio_interface.queue().launch(server_name=args.host,
|
||||||
# Convert chat history to OpenAI format
|
|
||||||
history_openai_format = [{
|
|
||||||
"role": "system",
|
|
||||||
"content": "You are a great ai assistant."
|
|
||||||
}]
|
|
||||||
for human, assistant in history:
|
|
||||||
history_openai_format.append({"role": "user", "content": human})
|
|
||||||
history_openai_format.append({
|
|
||||||
"role": "assistant",
|
|
||||||
"content": assistant
|
|
||||||
})
|
|
||||||
history_openai_format.append({"role": "user", "content": message})
|
|
||||||
|
|
||||||
# Create a chat completion request and send it to the API server
|
|
||||||
stream = client.chat.completions.create(
|
|
||||||
model=args.model, # Model name to use
|
|
||||||
messages=history_openai_format, # Chat history
|
|
||||||
temperature=args.temp, # Temperature for text generation
|
|
||||||
stream=True, # Stream response
|
|
||||||
extra_body={
|
|
||||||
'repetition_penalty':
|
|
||||||
1,
|
|
||||||
'stop_token_ids': [
|
|
||||||
int(id.strip()) for id in args.stop_token_ids.split(',')
|
|
||||||
if id.strip()
|
|
||||||
] if args.stop_token_ids else []
|
|
||||||
})
|
|
||||||
|
|
||||||
# Read and return generated text from response stream
|
|
||||||
partial_message = ""
|
|
||||||
for chunk in stream:
|
|
||||||
partial_message += (chunk.choices[0].delta.content or "")
|
|
||||||
yield partial_message
|
|
||||||
|
|
||||||
|
|
||||||
# Create and launch a chat interface with Gradio
|
|
||||||
gr.ChatInterface(predict).queue().launch(server_name=args.host,
|
|
||||||
server_port=args.port,
|
server_port=args.port,
|
||||||
share=True)
|
share=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
@ -1,5 +1,22 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""Example for starting a Gradio Webserver
|
||||||
|
Start vLLM API server:
|
||||||
|
python -m vllm.entrypoints.api_server \
|
||||||
|
--model meta-llama/Llama-2-7b-chat-hf
|
||||||
|
|
||||||
|
Start Webserver:
|
||||||
|
python examples/online_serving/gradio_webserver.py
|
||||||
|
|
||||||
|
Note that `pip install --upgrade gradio` is needed to run this example.
|
||||||
|
More details: https://github.com/gradio-app/gradio
|
||||||
|
|
||||||
|
If your antivirus software blocks the download of frpc for gradio,
|
||||||
|
you can install it manually by following these steps:
|
||||||
|
|
||||||
|
1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64
|
||||||
|
2. Rename the downloaded file to: frpc_linux_amd64_v0.3
|
||||||
|
3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc
|
||||||
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@ -39,16 +56,23 @@ def build_demo():
|
|||||||
return demo
|
return demo
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--host", type=str, default=None)
|
parser.add_argument("--host", type=str, default=None)
|
||||||
parser.add_argument("--port", type=int, default=8001)
|
parser.add_argument("--port", type=int, default=8001)
|
||||||
parser.add_argument("--model-url",
|
parser.add_argument("--model-url",
|
||||||
type=str,
|
type=str,
|
||||||
default="http://localhost:8000/generate")
|
default="http://localhost:8000/generate")
|
||||||
args = parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
demo = build_demo()
|
demo = build_demo()
|
||||||
demo.queue().launch(server_name=args.host,
|
demo.queue().launch(server_name=args.host,
|
||||||
server_port=args.port,
|
server_port=args.port,
|
||||||
share=True)
|
share=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parse_args()
|
||||||
|
main(args)
|
||||||