[Misc] Manage HTTP connections in one place (#6600)
This commit is contained in:
parent
c051bfe4eb
commit
97234be0ec
@ -16,6 +16,7 @@ from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.config import TokenizerPoolConfig
|
||||
from vllm.connections import global_http_connection
|
||||
from vllm.distributed import (destroy_distributed_environment,
|
||||
destroy_model_parallel)
|
||||
from vllm.inputs import TextPrompt
|
||||
@ -74,6 +75,13 @@ IMAGE_ASSETS = _ImageAssets()
|
||||
"""Singleton instance of :class:`_ImageAssets`."""
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Disable HTTP client reuse on the global connection for every test.

    pytest_asyncio may run each test on a different event loop, so the
    async client has to be created anew rather than carried over.
    """
    global_http_connection.reuse_client = False
|
||||
|
||||
|
||||
def cleanup():
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
|
@ -2,9 +2,8 @@ from typing import Dict, List
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64
|
||||
from vllm.multimodal.utils import encode_image_base64, fetch_image
|
||||
|
||||
from ...utils import VLLM_PATH, RemoteOpenAIServer
|
||||
|
||||
@ -42,11 +41,10 @@ def client(server):
|
||||
return server.get_async_client()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def base64_encoded_image() -> Dict[str, str]:
|
||||
@pytest.fixture(scope="session")
|
||||
def base64_encoded_image() -> Dict[str, str]:
|
||||
return {
|
||||
image_url:
|
||||
encode_image_base64(await ImageFetchAiohttp.fetch_image(image_url))
|
||||
image_url: encode_image_base64(fetch_image(image_url))
|
||||
for image_url in TEST_IMAGE_URLS
|
||||
}
|
||||
|
||||
|
@ -7,7 +7,7 @@ import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from vllm.multimodal.utils import ImageFetchAiohttp, fetch_image
|
||||
from vllm.multimodal.utils import async_fetch_image, fetch_image
|
||||
|
||||
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
|
||||
TEST_IMAGE_URLS = [
|
||||
@ -37,15 +37,15 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
|
||||
return (np.asarray(a) == np.asarray(b.convert(a.mode))).all()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
||||
async def test_fetch_image_http(image_url: str):
|
||||
image_sync = fetch_image(image_url)
|
||||
image_async = await ImageFetchAiohttp.fetch_image(image_url)
|
||||
image_async = await async_fetch_image(image_url)
|
||||
assert _image_equals(image_sync, image_async)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
||||
@pytest.mark.parametrize("suffix", get_supported_suffixes())
|
||||
async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
|
||||
@ -78,5 +78,5 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
|
||||
else:
|
||||
pass # Lossy format; only check that image can be opened
|
||||
|
||||
data_image_async = await ImageFetchAiohttp.fetch_image(data_url)
|
||||
data_image_async = await async_fetch_image(data_url)
|
||||
assert _image_equals(data_image_sync, data_image_async)
|
||||
|
@ -1,11 +1,12 @@
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import Literal
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from vllm.connections import global_http_connection
|
||||
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
|
||||
|
||||
from .base import get_cache_dir
|
||||
|
||||
|
||||
@ -22,11 +23,9 @@ def get_air_example_data_2_asset(filename: str) -> Image.Image:
|
||||
if not image_path.exists():
|
||||
base_url = "https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava"
|
||||
|
||||
with requests.get(f"{base_url}/{filename}", stream=True) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
with image_path.open("wb") as f:
|
||||
shutil.copyfileobj(response.raw, f)
|
||||
global_http_connection.download_file(f"{base_url}/{filename}",
|
||||
image_path,
|
||||
timeout=VLLM_IMAGE_FETCH_TIMEOUT)
|
||||
|
||||
return Image.open(image_path)
|
||||
|
||||
|
167
vllm/connections.py
Normal file
167
vllm/connections.py
Normal file
@ -0,0 +1,167 @@
|
||||
from pathlib import Path
from typing import Any, Mapping, Optional
from urllib.parse import urlparse

import aiohttp
import requests

from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
|
||||
class HTTPConnection:
    """Helper class to send HTTP requests.

    Holds one synchronous ``requests.Session`` and one asynchronous
    ``aiohttp.ClientSession``, both created lazily, and offers matching
    sync/async helpers for fetching bytes, text, JSON and files.
    """

    def __init__(self, *, reuse_client: bool = True) -> None:
        """
        Args:
            reuse_client: If ``True``, the sync and async clients are created
                once and reused. If ``False``, every ``get_sync_client`` /
                ``get_async_client`` call builds a new client.
        """
        super().__init__()

        self.reuse_client = reuse_client

        self._sync_client: Optional[requests.Session] = None
        self._async_client: Optional[aiohttp.ClientSession] = None

    def get_sync_client(self) -> requests.Session:
        """Return the synchronous client, creating it when needed.

        NOTE: when ``reuse_client`` is ``False`` the previously stored
        session is replaced without being closed, because a caller may still
        be reading a response obtained through it.
        """
        if self._sync_client is None or not self.reuse_client:
            self._sync_client = requests.Session()

        return self._sync_client

    # NOTE: We intentionally use an async function even though it is not
    # required, so that the client is only accessible inside an async
    # event loop
    async def get_async_client(self) -> aiohttp.ClientSession:
        """Return the asynchronous client, creating it when needed.

        NOTE: when ``reuse_client`` is ``False`` the previously stored
        session is replaced without being closed.
        """
        if self._async_client is None or not self.reuse_client:
            self._async_client = aiohttp.ClientSession()

        return self._async_client

    def _validate_http_url(self, url: str):
        """Raise ``ValueError`` unless ``url`` uses the http(s) scheme."""
        parsed_url = urlparse(url)

        if parsed_url.scheme not in ("http", "https"):
            raise ValueError("Invalid HTTP URL: A valid HTTP URL "
                             "must have scheme 'http' or 'https'.")

    def _headers(self, **extras: str) -> Mapping[str, str]:
        """Build the request headers.

        ``extras`` are merged after the default ``User-Agent`` entry, so an
        extra with the same key replaces it.
        """
        return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras}

    def get_response(
        self,
        url: str,
        *,
        stream: bool = False,
        timeout: Optional[float] = None,
        extra_headers: Optional[Mapping[str, str]] = None,
    ):
        """Send a synchronous GET request and return the raw response.

        Args:
            url: Target URL; must use the ``http`` or ``https`` scheme.
            stream: Forwarded to ``requests``; when ``True`` the body is not
                downloaded until it is read.
            timeout: Forwarded to ``requests`` (``None`` means no timeout).
            extra_headers: Additional headers merged over the defaults.

        Raises:
            ValueError: If ``url`` is not an HTTP(S) URL.
        """
        self._validate_http_url(url)

        client = self.get_sync_client()
        extra_headers = extra_headers or {}

        return client.get(url,
                          headers=self._headers(**extra_headers),
                          stream=stream,
                          timeout=timeout)

    async def get_async_response(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
        extra_headers: Optional[Mapping[str, str]] = None,
    ):
        """Start an asynchronous GET request.

        The un-awaited request object is returned so that callers can use it
        as ``async with await self.get_async_response(...) as r``.

        Args:
            url: Target URL; must use the ``http`` or ``https`` scheme.
            timeout: Total time budget for the request in seconds
                (``None`` means no total timeout).
            extra_headers: Additional headers merged over the defaults.

        Raises:
            ValueError: If ``url`` is not an HTTP(S) URL.
        """
        self._validate_http_url(url)

        client = await self.get_async_client()
        extra_headers = extra_headers or {}

        # aiohttp expects a ClientTimeout object here; passing a bare number
        # is deprecated. ``total=None`` keeps the "no timeout" meaning.
        return client.get(url,
                          headers=self._headers(**extra_headers),
                          timeout=aiohttp.ClientTimeout(total=timeout))

    def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes:
        """Fetch ``url`` and return the response body as bytes.

        Raises an HTTP error (via ``raise_for_status``) on a 4xx/5xx status.
        """
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.content

    async def async_get_bytes(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> bytes:
        """Asynchronous counterpart of :meth:`get_bytes`."""
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.read()

    def get_text(self, url: str, *, timeout: Optional[float] = None) -> str:
        """Fetch ``url`` and return the response body decoded as text.

        Raises an HTTP error (via ``raise_for_status``) on a 4xx/5xx status.
        """
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.text

    async def async_get_text(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> str:
        """Asynchronous counterpart of :meth:`get_text`."""
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.text()

    def get_json(self, url: str, *, timeout: Optional[float] = None) -> Any:
        """Fetch ``url`` and return the response body decoded from JSON.

        The result is whatever the JSON document decodes to (dict, list,
        str, number, ...), not necessarily a string.

        Raises an HTTP error (via ``raise_for_status``) on a 4xx/5xx status.
        """
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.json()

    async def async_get_json(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> Any:
        """Asynchronous counterpart of :meth:`get_json`."""
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.json()

    def download_file(
        self,
        url: str,
        save_path: Path,
        *,
        timeout: Optional[float] = None,
        chunk_size: int = 128,
    ) -> Path:
        """Download ``url`` to ``save_path`` and return ``save_path``.

        Args:
            url: Target URL; must use the ``http`` or ``https`` scheme.
            save_path: Destination file; opened in binary write mode.
            timeout: Forwarded to ``requests`` (``None`` means no timeout).
            chunk_size: Number of bytes read from the response per write.

        Raises an HTTP error (via ``raise_for_status``) on a 4xx/5xx status.
        """
        # stream=True so the body is written to disk chunk by chunk instead
        # of being fully buffered in memory before iter_content() runs.
        with self.get_response(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()

            with save_path.open("wb") as f:
                for chunk in r.iter_content(chunk_size):
                    f.write(chunk)

        return save_path

    async def async_download_file(
        self,
        url: str,
        save_path: Path,
        *,
        timeout: Optional[float] = None,
        chunk_size: int = 128,
    ) -> Path:
        """Asynchronous counterpart of :meth:`download_file`.

        Only the network reads are asynchronous; the file writes use the
        regular blocking file API.
        """
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            with save_path.open("wb") as f:
                async for chunk in r.content.iter_chunked(chunk_size):
                    f.write(chunk)

        return save_path
|
||||
|
||||
|
||||
# Shared module-level instance. It is built with the default
# ``reuse_client=True``, so its sync and async clients are created on first
# use and then reused.
global_http_connection = HTTPConnection()
|
||||
"""The global :class:`HTTPConnection` instance used by vLLM."""
|
@ -1,26 +1,12 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
from typing import Union
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from vllm.connections import global_http_connection
|
||||
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
|
||||
from vllm.multimodal.base import MultiModalDataDict
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
|
||||
def _validate_remote_url(url: str, *, name: str):
|
||||
parsed_url = urlparse(url)
|
||||
if parsed_url.scheme not in ["http", "https"]:
|
||||
raise ValueError(f"Invalid '{name}': A valid '{name}' "
|
||||
"must have scheme 'http' or 'https'.")
|
||||
|
||||
|
||||
def _get_request_headers():
|
||||
return {"User-Agent": f"vLLM/{VLLM_VERSION}"}
|
||||
|
||||
|
||||
def _load_image_from_bytes(b: bytes):
|
||||
@ -42,13 +28,8 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
|
||||
By default, the image is converted into RGB format.
|
||||
"""
|
||||
if image_url.startswith('http'):
|
||||
_validate_remote_url(image_url, name="image_url")
|
||||
|
||||
headers = _get_request_headers()
|
||||
|
||||
with requests.get(url=image_url, headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
image_raw = response.content
|
||||
image_raw = global_http_connection.get_bytes(
|
||||
image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
|
||||
image = _load_image_from_bytes(image_raw)
|
||||
|
||||
elif image_url.startswith('data:image'):
|
||||
@ -60,55 +41,30 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
|
||||
return image.convert(image_mode)
|
||||
|
||||
|
||||
class ImageFetchAiohttp:
|
||||
aiohttp_client: Optional[aiohttp.ClientSession] = None
|
||||
async def async_fetch_image(image_url: str,
|
||||
*,
|
||||
image_mode: str = "RGB") -> Image.Image:
|
||||
"""
|
||||
Asynchronously load a PIL image from a HTTP or base64 data URL.
|
||||
|
||||
@classmethod
|
||||
def get_aiohttp_client(cls) -> aiohttp.ClientSession:
|
||||
if cls.aiohttp_client is None:
|
||||
timeout = aiohttp.ClientTimeout(total=VLLM_IMAGE_FETCH_TIMEOUT)
|
||||
connector = aiohttp.TCPConnector()
|
||||
cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout,
|
||||
connector=connector)
|
||||
By default, the image is converted into RGB format.
|
||||
"""
|
||||
if image_url.startswith('http'):
|
||||
image_raw = await global_http_connection.async_get_bytes(
|
||||
image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
|
||||
image = _load_image_from_bytes(image_raw)
|
||||
|
||||
return cls.aiohttp_client
|
||||
elif image_url.startswith('data:image'):
|
||||
image = _load_image_from_data_url(image_url)
|
||||
else:
|
||||
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
|
||||
"with either 'data:image' or 'http'.")
|
||||
|
||||
@classmethod
|
||||
async def fetch_image(
|
||||
cls,
|
||||
image_url: str,
|
||||
*,
|
||||
image_mode: str = "RGB",
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Asynchronously load a PIL image from a HTTP or base64 data URL.
|
||||
|
||||
By default, the image is converted into RGB format.
|
||||
"""
|
||||
|
||||
if image_url.startswith('http'):
|
||||
_validate_remote_url(image_url, name="image_url")
|
||||
|
||||
client = cls.get_aiohttp_client()
|
||||
headers = _get_request_headers()
|
||||
|
||||
async with client.get(url=image_url, headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
image_raw = await response.read()
|
||||
image = _load_image_from_bytes(image_raw)
|
||||
|
||||
elif image_url.startswith('data:image'):
|
||||
image = _load_image_from_data_url(image_url)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid 'image_url': A valid 'image_url' must start "
|
||||
"with either 'data:image' or 'http'.")
|
||||
|
||||
return image.convert(image_mode)
|
||||
return image.convert(image_mode)
|
||||
|
||||
|
||||
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
|
||||
image = await ImageFetchAiohttp.fetch_image(image_url)
|
||||
image = await async_fetch_image(image_url)
|
||||
return {"image": image}
|
||||
|
||||
|
||||
|
@ -16,6 +16,7 @@ import requests
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import global_http_connection
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
_config_home = envs.VLLM_CONFIG_ROOT
|
||||
@ -204,7 +205,8 @@ class UsageMessage:
|
||||
|
||||
def _send_to_server(self, data):
|
||||
try:
|
||||
requests.post(_USAGE_STATS_SERVER, json=data)
|
||||
global_http_client = global_http_connection.get_sync_client()
|
||||
global_http_client.post(_USAGE_STATS_SERVER, json=data)
|
||||
except requests.exceptions.RequestException:
|
||||
# silently ignore unless we are using debug log
|
||||
logging.debug("Failed to send usage data to server")
|
||||
|
Loading…
x
Reference in New Issue
Block a user