[Misc] refactor argument parsing in examples (#16635)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
parent b590adfdc1
commit 6ae996a873
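The pattern applied throughout this commit: argument-parser construction moves out of the `if __name__ == "__main__":` block into a dedicated `parse_args()` (or `create_parser()`) helper, and `main()` receives the already-parsed arguments. A minimal sketch of that shape, using the standard `argparse.ArgumentParser` instead of vLLM's `FlexibleArgumentParser` and a made-up `--num-prompts` flag purely for illustration:

import argparse


def parse_args():
    # Build the parser in a dedicated helper instead of at module level.
    parser = argparse.ArgumentParser(description="Refactored example")
    parser.add_argument("--num-prompts", type=int, default=1,
                        help="Number of prompts to run.")
    return parser.parse_args()


def main(args):
    print(f"Running {args.num_prompts} prompt(s)")


if __name__ == "__main__":
    args = parse_args()
    main(args)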
@@ -187,6 +187,33 @@ model_example_map = {
}


def parse_args():
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'audio language models')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="ultravox",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=1,
                        help='Number of prompts to run.')
    parser.add_argument("--num-audios",
                        type=int,
                        default=1,
                        choices=[0, 1, 2],
                        help="Number of audio items per prompt.")
    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help="Set the seed when initializing `vllm.LLM`.")

    return parser.parse_args()


def main(args):
    model = args.model_type
    if model not in model_example_map:
@@ -240,28 +267,5 @@ def main(args):


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'audio language models')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="ultravox",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=1,
                        help='Number of prompts to run.')
    parser.add_argument("--num-audios",
                        type=int,
                        default=1,
                        choices=[0, 1, 2],
                        help="Number of audio items per prompt.")
    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help="Set the seed when initializing `vllm.LLM`.")

    args = parser.parse_args()
    args = parse_args()
    main(args)
@@ -12,16 +12,23 @@ prompts = [
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}")
    print(f"Output: {generated_text!r}")
    print("-" * 60)

def main():
    # Create an LLM.
    llm = LLM(model="facebook/opt-125m")
    # Generate texts from the prompts.
    # The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}")
        print(f"Output: {generated_text!r}")
        print("-" * 60)


if __name__ == "__main__":
    main()
@@ -4,6 +4,24 @@ from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def create_parser():
    parser = FlexibleArgumentParser()
    # Add engine args
    engine_group = parser.add_argument_group("Engine arguments")
    EngineArgs.add_cli_args(engine_group)
    engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
    # Add sampling params
    sampling_group = parser.add_argument_group("Sampling parameters")
    sampling_group.add_argument("--max-tokens", type=int)
    sampling_group.add_argument("--temperature", type=float)
    sampling_group.add_argument("--top-p", type=float)
    sampling_group.add_argument("--top-k", type=int)
    # Add example params
    parser.add_argument("--chat-template-path", type=str)

    return parser


def main(args: dict):
    # Pop arguments not used by LLM
    max_tokens = args.pop("max_tokens")
@@ -82,18 +100,6 @@ def main(args: dict):


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    # Add engine args
    engine_group = parser.add_argument_group("Engine arguments")
    EngineArgs.add_cli_args(engine_group)
    engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
    # Add sampling params
    sampling_group = parser.add_argument_group("Sampling parameters")
    sampling_group.add_argument("--max-tokens", type=int)
    sampling_group.add_argument("--temperature", type=float)
    sampling_group.add_argument("--top-p", type=float)
    sampling_group.add_argument("--top-k", type=int)
    # Add example params
    parser.add_argument("--chat-template-path", type=str)
    parser = create_parser()
    args: dict = vars(parser.parse_args())
    main(args)
@@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach",
                        task="classify",
                        enforce_eager=True)
    return parser.parse_args()


def main(args: Namespace):
    # Sample prompts.
    prompts = [
@@ -34,11 +44,5 @@ def main(args: Namespace):


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach",
                        task="classify",
                        enforce_eager=True)
    args = parser.parse_args()
    args = parse_args()
    main(args)
@@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(model="intfloat/e5-mistral-7b-instruct",
                        task="embed",
                        enforce_eager=True)
    return parser.parse_args()


def main(args: Namespace):
    # Sample prompts.
    prompts = [
@@ -34,11 +44,5 @@ def main(args: Namespace):


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(model="intfloat/e5-mistral-7b-instruct",
                        task="embed",
                        enforce_eager=True)
    args = parser.parse_args()
    args = parse_args()
    main(args)
@@ -4,6 +4,22 @@ from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def create_parser():
    parser = FlexibleArgumentParser()
    # Add engine args
    engine_group = parser.add_argument_group("Engine arguments")
    EngineArgs.add_cli_args(engine_group)
    engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
    # Add sampling params
    sampling_group = parser.add_argument_group("Sampling parameters")
    sampling_group.add_argument("--max-tokens", type=int)
    sampling_group.add_argument("--temperature", type=float)
    sampling_group.add_argument("--top-p", type=float)
    sampling_group.add_argument("--top-k", type=int)

    return parser


def main(args: dict):
    # Pop arguments not used by LLM
    max_tokens = args.pop("max_tokens")
@@ -35,23 +51,15 @@ def main(args: dict):
    ]
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    # Add engine args
    engine_group = parser.add_argument_group("Engine arguments")
    EngineArgs.add_cli_args(engine_group)
    engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
    # Add sampling params
    sampling_group = parser.add_argument_group("Sampling parameters")
    sampling_group.add_argument("--max-tokens", type=int)
    sampling_group.add_argument("--temperature", type=float)
    sampling_group.add_argument("--top-p", type=float)
    sampling_group.add_argument("--top-k", type=int)
    parser = create_parser()
    args: dict = vars(parser.parse_args())
    main(args)
@@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(model="BAAI/bge-reranker-v2-m3",
                        task="score",
                        enforce_eager=True)
    return parser.parse_args()


def main(args: Namespace):
    # Sample prompts.
    text_1 = "What is the capital of France?"
@@ -30,11 +40,5 @@ def main(args: Namespace):


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(model="BAAI/bge-reranker-v2-m3",
                        task="score",
                        enforce_eager=True)
    args = parser.parse_args()
    args = parse_args()
    main(args)
@@ -34,6 +34,40 @@ from vllm import LLM, SamplingParams
from vllm.utils import get_open_port


def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description="Data Parallel Inference")
    parser.add_argument("--model",
                        type=str,
                        default="ibm-research/PowerMoE-3b",
                        help="Model name or path")
    parser.add_argument("--dp-size",
                        type=int,
                        default=2,
                        help="Data parallel size")
    parser.add_argument("--tp-size",
                        type=int,
                        default=2,
                        help="Tensor parallel size")
    parser.add_argument("--node-size",
                        type=int,
                        default=1,
                        help="Total number of nodes")
    parser.add_argument("--node-rank",
                        type=int,
                        default=0,
                        help="Rank of the current node")
    parser.add_argument("--master-addr",
                        type=str,
                        default="",
                        help="Master node IP address")
    parser.add_argument("--master-port",
                        type=int,
                        default=0,
                        help="Master node port")
    return parser.parse_args()


def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         dp_master_port, GPUs_per_dp_rank):
    os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
@@ -95,37 +129,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Data Parallel Inference")
    parser.add_argument("--model",
                        type=str,
                        default="ibm-research/PowerMoE-3b",
                        help="Model name or path")
    parser.add_argument("--dp-size",
                        type=int,
                        default=2,
                        help="Data parallel size")
    parser.add_argument("--tp-size",
                        type=int,
                        default=2,
                        help="Tensor parallel size")
    parser.add_argument("--node-size",
                        type=int,
                        default=1,
                        help="Total number of nodes")
    parser.add_argument("--node-rank",
                        type=int,
                        default=0,
                        help="Rank of the current node")
    parser.add_argument("--master-addr",
                        type=str,
                        default="",
                        help="Master node IP address")
    parser.add_argument("--master-port",
                        type=int,
                        default=0,
                        help="Master node port")
    args = parser.parse_args()

    args = parse_args()

    dp_size = args.dp_size
    tp_size = args.tp_size
@@ -27,7 +27,7 @@ def load_prompts(dataset_path, num_prompts):
    return prompts[:num_prompts]


def main():
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
@@ -45,7 +45,12 @@ def main():
    parser.add_argument("--enable_chunked_prefill", action='store_true')
    parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
    parser.add_argument("--temp", type=float, default=0)
    args = parser.parse_args()
    return parser.parse_args()


def main():

    args = parse_args()

    model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
    eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
@@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(model="jinaai/jina-embeddings-v3",
                        task="embed",
                        trust_remote_code=True)
    return parser.parse_args()


def main(args: Namespace):
    # Sample prompts.
    prompts = [
@@ -40,11 +50,5 @@ def main(args: Namespace):


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(model="jinaai/jina-embeddings-v3",
                        task="embed",
                        trust_remote_code=True)
    args = parser.parse_args()
    args = parse_args()
    main(args)
@@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs, PoolingParams
from vllm.utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(model="jinaai/jina-embeddings-v3",
                        task="embed",
                        trust_remote_code=True)
    return parser.parse_args()


def main(args: Namespace):
    # Sample prompts.
    prompts = [
@@ -38,11 +48,5 @@ def main(args: Namespace):


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(model="jinaai/jina-embeddings-v3",
                        task="embed",
                        trust_remote_code=True)
    args = parser.parse_args()
    args = parse_args()
    main(args)
@@ -8,94 +8,112 @@ from vllm import LLM, SamplingParams
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                         TokensPrompt, zip_enc_dec_prompts)

dtype = "float"

# Create a BART encoder/decoder model instance
llm = LLM(
    model="facebook/bart-large-cnn",
    dtype=dtype,
)
def create_prompts(tokenizer):
    # Test prompts
    #
    # This section shows all of the valid ways to prompt an
    # encoder/decoder model.
    #
    # - Helpers for building prompts
    text_prompt_raw = "Hello, my name is"
    text_prompt = TextPrompt(prompt="The president of the United States is")
    tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode(
        prompt="The capital of France is"))
    # - Pass a single prompt to encoder/decoder model
    # (implicitly encoder input prompt);
    # decoder input prompt is assumed to be None

# Get BART tokenizer
tokenizer = llm.llm_engine.get_tokenizer_group()
    single_text_prompt_raw = text_prompt_raw  # Pass a string directly
    single_text_prompt = text_prompt  # Pass a TextPrompt
    single_tokens_prompt = tokens_prompt  # Pass a TokensPrompt

# Test prompts
#
# This section shows all of the valid ways to prompt an
# encoder/decoder model.
#
# - Helpers for building prompts
text_prompt_raw = "Hello, my name is"
text_prompt = TextPrompt(prompt="The president of the United States is")
tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode(
    prompt="The capital of France is"))
# - Pass a single prompt to encoder/decoder model
# (implicitly encoder input prompt);
# decoder input prompt is assumed to be None
    # ruff: noqa: E501
    # - Pass explicit encoder and decoder input prompts within one data structure.
    # Encoder and decoder prompts can both independently be text or tokens, with
    # no requirement that they be the same prompt type. Some example prompt-type
    # combinations are shown below, note that these are not exhaustive.

single_text_prompt_raw = text_prompt_raw  # Pass a string directly
single_text_prompt = text_prompt  # Pass a TextPrompt
single_tokens_prompt = tokens_prompt  # Pass a TokensPrompt
    enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
        # Pass encoder prompt string directly, &
        # pass decoder prompt tokens
        encoder_prompt=single_text_prompt_raw,
        decoder_prompt=single_tokens_prompt,
    )
    enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
        # Pass TextPrompt to encoder, and
        # pass decoder prompt string directly
        encoder_prompt=single_text_prompt,
        decoder_prompt=single_text_prompt_raw,
    )
    enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
        # Pass encoder prompt tokens directly, and
        # pass TextPrompt to decoder
        encoder_prompt=single_tokens_prompt,
        decoder_prompt=single_text_prompt,
    )

# - Pass explicit encoder and decoder input prompts within one data structure.
# Encoder and decoder prompts can both independently be text or tokens, with
# no requirement that they be the same prompt type. Some example prompt-type
# combinations are shown below, note that these are not exhaustive.
    # - Finally, here's a useful helper function for zipping encoder and
    # decoder prompts together into a list of ExplicitEncoderDecoderPrompt
    # instances
    zipped_prompt_list = zip_enc_dec_prompts(
        ['An encoder prompt', 'Another encoder prompt'],
        ['A decoder prompt', 'Another decoder prompt'])

enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
    # Pass encoder prompt string directly, &
    # pass decoder prompt tokens
    encoder_prompt=single_text_prompt_raw,
    decoder_prompt=single_tokens_prompt,
)
enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
    # Pass TextPrompt to encoder, and
    # pass decoder prompt string directly
    encoder_prompt=single_text_prompt,
    decoder_prompt=single_text_prompt_raw,
)
enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
    # Pass encoder prompt tokens directly, and
    # pass TextPrompt to decoder
    encoder_prompt=single_tokens_prompt,
    decoder_prompt=single_text_prompt,
)
    # - Let's put all of the above example prompts together into one list
    # which we will pass to the encoder/decoder LLM.
    return [
        single_text_prompt_raw, single_text_prompt, single_tokens_prompt,
        enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3
    ] + zipped_prompt_list

# - Finally, here's a useful helper function for zipping encoder and
# decoder prompts together into a list of ExplicitEncoderDecoderPrompt
# instances
zipped_prompt_list = zip_enc_dec_prompts(
    ['An encoder prompt', 'Another encoder prompt'],
    ['A decoder prompt', 'Another decoder prompt'])

# - Let's put all of the above example prompts together into one list
# which we will pass to the encoder/decoder LLM.
prompts = [
    single_text_prompt_raw, single_text_prompt, single_tokens_prompt,
    enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3
] + zipped_prompt_list

# Create a sampling params object.
sampling_params = SamplingParams(
    temperature=0,
    top_p=1.0,
    min_tokens=0,
    max_tokens=20,
)
def create_sampling_params():
    return SamplingParams(
        temperature=0,
        top_p=1.0,
        min_tokens=0,
        max_tokens=20,
    )

# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
print("-" * 50)
for i, output in enumerate(outputs):
    prompt = output.prompt
    encoder_prompt = output.encoder_prompt
    generated_text = output.outputs[0].text
    print(f"Output {i+1}:")
    print(f"Encoder prompt: {encoder_prompt!r}\n"
          f"Decoder prompt: {prompt!r}\n"
          f"Generated text: {generated_text!r}")
def print_outputs(outputs):
    print("-" * 50)
    for i, output in enumerate(outputs):
        prompt = output.prompt
        encoder_prompt = output.encoder_prompt
        generated_text = output.outputs[0].text
        print(f"Output {i+1}:")
        print(f"Encoder prompt: {encoder_prompt!r}\n"
              f"Decoder prompt: {prompt!r}\n"
              f"Generated text: {generated_text!r}")
    print("-" * 50)


def main():
    dtype = "float"

    # Create a BART encoder/decoder model instance
    llm = LLM(
        model="facebook/bart-large-cnn",
        dtype=dtype,
    )

    # Get BART tokenizer
    tokenizer = llm.llm_engine.get_tokenizer_group()

    prompts = create_prompts(tokenizer)
    sampling_params = create_sampling_params()

    # Generate output tokens from the prompts. The output is a list of
    # RequestOutput objects that contain the prompt, generated
    # text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    print_outputs(outputs)


if __name__ == "__main__":
    main()
@@ -126,6 +126,23 @@ model_example_map = {
}


def parse_args():
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models for text generation')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="mllama",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help="Set the seed when initializing `vllm.LLM`.")
    return parser.parse_args()


def main(args):
    model = args.model_type
    if model not in model_example_map:
@@ -171,19 +188,5 @@ def main(args):


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models for text generation')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="mllama",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help="Set the seed when initializing `vllm.LLM`.")

    args = parser.parse_args()
    args = parse_args()
    main(args)
@@ -168,7 +168,7 @@ def run_advanced_demo(args: argparse.Namespace):
    print("-" * 50)


def main():
def parse_args():
    parser = argparse.ArgumentParser(
        description="Run a demo in simple or advanced mode.")

@@ -187,8 +187,11 @@ def main():
        '--disable-mm-preprocessor-cache',
        action='store_true',
        help='If True, disables caching of multi-modal preprocessor/mapper.')
    return parser.parse_args()

    args = parser.parse_args()


def main():
    args = parse_args()

    if args.mode == "simple":
        print("Running simple demo...")
@@ -34,8 +34,7 @@ def time_generation(llm: LLM, prompts: list[str],
    print("-" * 50)


if __name__ == "__main__":

def main():
    template = (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n### Instruction:\n{}"
@@ -66,3 +65,7 @@ if __name__ == "__main__":
    )

    time_generation(llm, prompts, sampling_params, "With speculation")


if __name__ == "__main__":
    main()
@@ -417,6 +417,38 @@ def run_model(input_data,
    return pred_imgs


def parse_args():
    parser = argparse.ArgumentParser("MAE run inference", add_help=False)

    parser.add_argument(
        "--data_file",
        type=str,
        default="./India_900498_S2Hand.tif",
        help="Path to the file.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Path to the directory where to save outputs.",
    )
    parser.add_argument(
        "--input_indices",
        default=[1, 2, 3, 8, 11, 12],
        type=int,
        nargs="+",
        help=
        "0-based indices of the six Prithvi channels to be selected from the "
        "input. By default selects [1,2,3,8,11,12] for S2L1C data.",
    )
    parser.add_argument(
        "--rgb_outputs",
        action="store_true",
        help="If present, output files will only contain RGB channels. "
        "Otherwise, all bands will be saved.",
    )


def main(
    data_file: str,
    output_dir: str,
@@ -496,35 +528,7 @@ def main(


if __name__ == "__main__":
    parser = argparse.ArgumentParser("MAE run inference", add_help=False)

    parser.add_argument(
        "--data_file",
        type=str,
        default="./India_900498_S2Hand.tif",
        help="Path to the file.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Path to the directory where to save outputs.",
    )
    parser.add_argument(
        "--input_indices",
        default=[1, 2, 3, 8, 11, 12],
        type=int,
        nargs="+",
        help=
        "0-based indices of the six Prithvi channels to be selected from the "
        "input. By default selects [1,2,3,8,11,12] for S2L1C data.",
    )
    parser.add_argument(
        "--rgb_outputs",
        action="store_true",
        help="If present, output files will only contain RGB channels. "
        "Otherwise, all bands will be saved.",
    )
    args = parser.parse_args()
    args = parse_args()

    main(**vars(args))
@@ -359,7 +359,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
              f" in folder {context.save_chrome_traces_folder}")


if __name__ == "__main__":
def parse_args():
    parser = FlexibleArgumentParser(description="""
Profile a model

@@ -449,7 +449,10 @@ Profile a model

    EngineArgs.add_cli_args(parser)

    args = parser.parse_args()
    return parser.parse_args()


def main(args):
    context = ProfileContext(
        engine_args=EngineArgs.from_cli_args(args),
        **{
@@ -458,3 +461,8 @@ Profile a model
            if k in inspect.signature(ProfileContext).parameters
        })
    run_profile(context, csv_output=args.csv, json_output=args.json)


if __name__ == "__main__":
    args = parse_args()
    main(args)
@@ -29,20 +29,23 @@ from pathlib import Path
from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser()
EngineArgs.add_cli_args(parser)
parser.add_argument("--output",
                    "-o",
                    required=True,
                    type=str,
                    help="path to output checkpoint")
parser.add_argument("--file-pattern",
                    type=str,
                    help="string pattern of saved filenames")
parser.add_argument("--max-file-size",
                    type=str,
                    default=5 * 1024**3,
                    help="max size (in bytes) of each safetensors file")

def parse_args():
    parser = FlexibleArgumentParser()
    EngineArgs.add_cli_args(parser)
    parser.add_argument("--output",
                        "-o",
                        required=True,
                        type=str,
                        help="path to output checkpoint")
    parser.add_argument("--file-pattern",
                        type=str,
                        help="string pattern of saved filenames")
    parser.add_argument("--max-file-size",
                        type=str,
                        default=5 * 1024**3,
                        help="max size (in bytes) of each safetensors file")
    return parser.parse_args()


def main(args):
@@ -87,5 +90,5 @@ def main(args):


if __name__ == "__main__":
    args = parser.parse_args()
    args = parse_args()
    main(args)
@@ -18,8 +18,8 @@ prompts = [
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

if __name__ == "__main__":

def main():
    # Create an LLM.
    llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)

@@ -42,3 +42,7 @@ if __name__ == "__main__":
    # Add a buffer to wait for profiler in the background process
    # (in case MP is on) to finish writing profiling output.
    time.sleep(10)


if __name__ == "__main__":
    main()
@@ -1097,6 +1097,59 @@ def time_counter(enable: bool):
    yield


def parse_args():
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models for text generation')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="llava",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=4,
                        help='Number of prompts to run.')
    parser.add_argument('--modality',
                        type=str,
                        default="image",
                        choices=['image', 'video'],
                        help='Modality of the input.')
    parser.add_argument('--num-frames',
                        type=int,
                        default=16,
                        help='Number of frames to extract from the video.')
    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help="Set the seed when initializing `vllm.LLM`.")

    parser.add_argument(
        '--image-repeat-prob',
        type=float,
        default=None,
        help='Simulates the hit-ratio for multi-modal preprocessor cache'
        ' (if enabled)')

    parser.add_argument(
        '--disable-mm-preprocessor-cache',
        action='store_true',
        help='If True, disables caching of multi-modal preprocessor/mapper.')

    parser.add_argument(
        '--time-generate',
        action='store_true',
        help='If True, then print the total generate() call time')

    parser.add_argument(
        '--use-different-prompt-per-request',
        action='store_true',
        help='If True, then use different prompt (with the same multi-modal '
        'data) for each request.')
    return parser.parse_args()


def main(args):
    model = args.model_type
    if model not in model_example_map:
@@ -1175,55 +1228,5 @@ def main(args):


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models for text generation')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="llava",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=4,
                        help='Number of prompts to run.')
    parser.add_argument('--modality',
                        type=str,
                        default="image",
                        choices=['image', 'video'],
                        help='Modality of the input.')
    parser.add_argument('--num-frames',
                        type=int,
                        default=16,
                        help='Number of frames to extract from the video.')
    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help="Set the seed when initializing `vllm.LLM`.")

    parser.add_argument(
        '--image-repeat-prob',
        type=float,
        default=None,
        help='Simulates the hit-ratio for multi-modal preprocessor cache'
        ' (if enabled)')

    parser.add_argument(
        '--disable-mm-preprocessor-cache',
        action='store_true',
        help='If True, disables caching of multi-modal preprocessor/mapper.')

    parser.add_argument(
        '--time-generate',
        action='store_true',
        help='If True, then print the total generate() call time')

    parser.add_argument(
        '--use-different-prompt-per-request',
        action='store_true',
        help='If True, then use different prompt (with the same multi-modal '
        'data) for each request.')

    args = parser.parse_args()
    args = parse_args()
    main(args)
@@ -156,16 +156,13 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
    print("-" * 50)


def main(args: Namespace):
    run_encode(args.model_name, args.modality, args.seed)


model_example_map = {
    "e5_v": run_e5_v,
    "vlm2vec": run_vlm2vec,
}

if __name__ == "__main__":

def parse_args():
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models for multimodal embedding')
@@ -184,6 +181,13 @@ if __name__ == "__main__":
                        type=int,
                        default=None,
                        help="Set the seed when initializing `vllm.LLM`.")
    return parser.parse_args()

    args = parser.parse_args()


def main(args: Namespace):
    run_encode(args.model_name, args.modality, args.seed)


if __name__ == "__main__":
    args = parse_args()
    main(args)
@@ -767,22 +767,7 @@ def run_chat(model: str, question: str, image_urls: list[str],
    print("-" * 50)


def main(args: Namespace):
    model = args.model_type
    method = args.method
    seed = args.seed

    image_urls = IMAGE_URLS[:args.num_images]

    if method == "generate":
        run_generate(model, QUESTION, image_urls, seed)
    elif method == "chat":
        run_chat(model, QUESTION, image_urls, seed)
    else:
        raise ValueError(f"Invalid method: {method}")


if __name__ == "__main__":
def parse_args():
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models that support multi-image input for text '
@@ -808,6 +793,24 @@ if __name__ == "__main__":
                        choices=list(range(1, 13)),  # 12 is the max number of images
                        default=2,
                        help="Number of images to use for the demo.")
    return parser.parse_args()

    args = parser.parse_args()


def main(args: Namespace):
    model = args.model_type
    method = args.method
    seed = args.seed

    image_urls = IMAGE_URLS[:args.num_images]

    if method == "generate":
        run_generate(model, QUESTION, image_urls, seed)
    elif method == "chat":
        run_chat(model, QUESTION, image_urls, seed)
    else:
        raise ValueError(f"Invalid method: {method}")


if __name__ == "__main__":
    args = parse_args()
    main(args)
@@ -58,6 +58,16 @@ def get_response(response: requests.Response) -> list[str]:
    return output


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--n", type=int, default=1)
    parser.add_argument("--prompt", type=str, default="San Francisco is a")
    parser.add_argument("--stream", action="store_true")
    return parser.parse_args()


def main(args: Namespace):
    prompt = args.prompt
    api_url = f"http://{args.host}:{args.port}/generate"
@@ -82,11 +92,5 @@ def main(args: Namespace):


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--n", type=int, default=1)
    parser.add_argument("--prompt", type=str, default="San Francisco is a")
    parser.add_argument("--stream", action="store_true")
    args = parser.parse_args()
    args = parse_args()
    main(args)
@@ -1,52 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
"""Example for starting a Gradio OpenAI Chatbot Webserver
Start vLLM API server:
    vllm serve meta-llama/Llama-2-7b-chat-hf

Start Gradio OpenAI Chatbot Webserver:
    python examples/online_serving/gradio_openai_chatbot_webserver.py \
        -m meta-llama/Llama-2-7b-chat-hf

Note that `pip install --upgrade gradio` is needed to run this example.
More details: https://github.com/gradio-app/gradio

If your antivirus software blocks the download of frpc for gradio,
you can install it manually by following these steps:

1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64
2. Rename the downloaded file to: frpc_linux_amd64_v0.3
3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc
"""
import argparse

import gradio as gr
from openai import OpenAI

# Argument parser setup
parser = argparse.ArgumentParser(
    description='Chatbot Interface with Customizable Parameters')
parser.add_argument('--model-url',
                    type=str,
                    default='http://localhost:8000/v1',
                    help='Model URL')
parser.add_argument('-m',
                    '--model',
                    type=str,
                    required=True,
                    help='Model name for the chatbot')
parser.add_argument('--temp',
                    type=float,
                    default=0.8,
                    help='Temperature for text generation')
parser.add_argument('--stop-token-ids',
                    type=str,
                    default='',
                    help='Comma-separated stop token IDs')
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001)

# Parse the arguments
args = parser.parse_args()

# Set OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = args.model_url

# Create an OpenAI client to interact with the API server
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
def create_openai_client(api_key, base_url):
    return OpenAI(api_key=api_key, base_url=base_url)


def predict(message, history):
    # Convert chat history to OpenAI format
def format_history_to_openai(history):
    history_openai_format = [{
        "role": "system",
        "content": "You are a great ai assistant."
        "content": "You are a great AI assistant."
    }]
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
@@ -54,31 +38,92 @@ def predict(message, history):
            "role": "assistant",
            "content": assistant
        })
    return history_openai_format


def predict(message, history, client, model_name, temp, stop_token_ids):
    # Format history to OpenAI chat format
    history_openai_format = format_history_to_openai(history)
    history_openai_format.append({"role": "user", "content": message})

    # Create a chat completion request and send it to the API server
    # Send request to OpenAI API (vLLM server)
    stream = client.chat.completions.create(
        model=args.model,  # Model name to use
        messages=history_openai_format,  # Chat history
        temperature=args.temp,  # Temperature for text generation
        stream=True,  # Stream response
        model=model_name,
        messages=history_openai_format,
        temperature=temp,
        stream=True,
        extra_body={
            'repetition_penalty':
            1,
            'stop_token_ids': [
                int(id.strip()) for id in args.stop_token_ids.split(',')
                if id.strip()
            ] if args.stop_token_ids else []
            'stop_token_ids':
            [int(id.strip())
             for id in stop_token_ids.split(',')] if stop_token_ids else []
        })

    # Read and return generated text from response stream
    partial_message = ""
    # Collect all chunks and concatenate them into a full message
    full_message = ""
    for chunk in stream:
        partial_message += (chunk.choices[0].delta.content or "")
        yield partial_message
        full_message += (chunk.choices[0].delta.content or "")

    # Return the full message as a single response
    return full_message


# Create and launch a chat interface with Gradio
gr.ChatInterface(predict).queue().launch(server_name=args.host,
                                         server_port=args.port,
                                         share=True)
def parse_args():
    parser = argparse.ArgumentParser(
        description='Chatbot Interface with Customizable Parameters')
    parser.add_argument('--model-url',
                        type=str,
                        default='http://localhost:8000/v1',
                        help='Model URL')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        required=True,
                        help='Model name for the chatbot')
    parser.add_argument('--temp',
                        type=float,
                        default=0.8,
                        help='Temperature for text generation')
    parser.add_argument('--stop-token-ids',
                        type=str,
                        default='',
                        help='Comma-separated stop token IDs')
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8001)
    return parser.parse_args()


def build_gradio_interface(client, model_name, temp, stop_token_ids):

    def chat_predict(message, history):
        return predict(message, history, client, model_name, temp,
                       stop_token_ids)

    return gr.ChatInterface(fn=chat_predict,
                            title="Chatbot Interface",
                            description="A simple chatbot powered by vLLM")


def main():
    # Parse the arguments
    args = parse_args()

    # Set OpenAI's API key and API base to use vLLM's API server
    openai_api_key = "EMPTY"
    openai_api_base = args.model_url

    # Create an OpenAI client
    client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)

    # Define the Gradio chatbot interface using the predict function
    gradio_interface = build_gradio_interface(client, args.model, args.temp,
                                              args.stop_token_ids)

    gradio_interface.queue().launch(server_name=args.host,
                                    server_port=args.port,
                                    share=True)


if __name__ == "__main__":
    main()
@@ -1,5 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
"""Example for starting a Gradio Webserver
Start vLLM API server:
    python -m vllm.entrypoints.api_server \
        --model meta-llama/Llama-2-7b-chat-hf

Start Webserver:
    python examples/online_serving/gradio_webserver.py

Note that `pip install --upgrade gradio` is needed to run this example.
More details: https://github.com/gradio-app/gradio

If your antivirus software blocks the download of frpc for gradio,
you can install it manually by following these steps:

1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64
2. Rename the downloaded file to: frpc_linux_amd64_v0.3
3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc
"""
import argparse
import json

@@ -39,16 +56,23 @@ def build_demo():
    return demo


if __name__ == "__main__":
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--model-url",
                        type=str,
                        default="http://localhost:8000/generate")
    args = parser.parse_args()
    return parser.parse_args()


def main(args):
    demo = build_demo()
    demo.queue().launch(server_name=args.host,
                        server_port=args.port,
                        share=True)


if __name__ == "__main__":
    args = parse_args()
    main(args)