85 lines
2.5 KiB
Python
85 lines
2.5 KiB
Python
import argparse
|
|
import os
|
|
import subprocess
|
|
|
|
import torch
|
|
|
|
from vllm import LLM
|
|
from vllm.sequence import MultiModalData
|
|
|
|
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
|
|
|
|
|
|
def run_llava_pixel_values():
    """Run the LLaVA demo using the ``pixel_values`` image input type.

    Creates the engine for ``llava-hf/llava-1.5-7b-hf``, loads the saved
    input from ``images/stop_sign_pixel_values.pt``, submits a single
    prompt together with that input, and prints the text of the first
    completion of every returned output.
    """
    engine = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )

    # 576 "<image>" placeholders (same number as image_feature_size above),
    # followed by the actual question.
    image_placeholders = "<image>" * 576
    question = "\nUSER: What is the content of this image?\nASSISTANT:"
    prompt = image_placeholders + question

    # This should be provided by another online or offline component.
    # NOTE(review): torch.load relies on pickle; only load files that come
    # from a trusted source.
    pixel_values = torch.load("images/stop_sign_pixel_values.pt")

    image_data = MultiModalData(type=MultiModalData.Type.IMAGE,
                                data=pixel_values)
    results = engine.generate(prompt, multi_modal_data=image_data)
    for request_output in results:
        print(request_output.outputs[0].text)
|
|
|
|
|
|
def run_llava_image_features():
    """Run the LLaVA demo using the ``image_features`` image input type.

    Creates the engine for ``llava-hf/llava-1.5-7b-hf``, loads the saved
    input from ``images/stop_sign_image_features.pt``, submits a single
    prompt together with that input, and prints the text of the first
    completion of every returned output.
    """
    engine = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="image_features",
        image_token_id=32000,
        image_input_shape="1,576,1024",
        image_feature_size=576,
    )

    # 576 "<image>" placeholders (same number as image_feature_size above),
    # followed by the actual question.
    image_placeholders = "<image>" * 576
    question = "\nUSER: What is the content of this image?\nASSISTANT:"
    prompt = image_placeholders + question

    # This should be provided by another online or offline component.
    # NOTE(review): torch.load relies on pickle; only load files that come
    # from a trusted source.
    image_features = torch.load("images/stop_sign_image_features.pt")

    image_data = MultiModalData(type=MultiModalData.Type.IMAGE,
                                data=image_features)
    results = engine.generate(prompt, multi_modal_data=image_data)
    for request_output in results:
        print(request_output.outputs[0].text)
|
|
|
|
|
|
def main(args):
    """Run the demo variant selected by ``args.type``.

    ``"pixel_values"`` runs the pixel-values demo; any other value runs
    the image-features demo.
    """
    run_demo = (run_llava_pixel_values
                if args.type == "pixel_values" else run_llava_image_features)
    run_demo()
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line interface: pick which image input type to demonstrate.
    parser = argparse.ArgumentParser(description="Demo on Llava")
    parser.add_argument("--type",
                        type=str,
                        choices=["pixel_values", "image_features"],
                        default="pixel_values",
                        help="image input type")
    args = parser.parse_args()

    # Download from s3
    s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
    local_directory = "images"

    # Make sure the local directory exists or create it
    os.makedirs(local_directory, exist_ok=True)

    # Use AWS CLI to sync the directory.
    # `--no-sign-request` sends unsigned requests so the download also works
    # on machines without AWS credentials configured; this relies on the
    # bucket allowing anonymous reads. check_call raises CalledProcessError
    # if the sync fails, so the demo never runs against missing assets.
    subprocess.check_call([
        "aws",
        "s3",
        "sync",
        s3_bucket_path,
        local_directory,
        "--no-sign-request",
    ])
    main(args)
|