[Misc] Manage HTTP connections in one place (#6600)

This commit is contained in:
Cyrus Leung 2024-07-23 12:32:02 +08:00 committed by GitHub
parent c051bfe4eb
commit 97234be0ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 215 additions and 85 deletions

View File

@ -16,6 +16,7 @@ from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import TokenizerPoolConfig
from vllm.connections import global_http_connection
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.inputs import TextPrompt
@ -74,6 +75,13 @@ IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""
@pytest.fixture(autouse=True)
def init_test_http_connection() -> None:
    """Turn off client reuse on the global HTTP connection.

    Declared with ``autouse=True``, so tests do not need to request it
    explicitly. The flag is set on the shared ``global_http_connection``
    object and is not restored afterwards.
    """
    # pytest_asyncio may use a different event loop per test
    # so we need to make sure the async client is created anew
    global_http_connection.reuse_client = False
def cleanup():
destroy_model_parallel()
destroy_distributed_environment()

View File

@ -2,9 +2,8 @@ from typing import Dict, List
import openai
import pytest
import pytest_asyncio
from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64
from vllm.multimodal.utils import encode_image_base64, fetch_image
from ...utils import VLLM_PATH, RemoteOpenAIServer
@ -42,11 +41,10 @@ def client(server):
return server.get_async_client()
@pytest_asyncio.fixture(scope="session")
async def base64_encoded_image() -> Dict[str, str]:
@pytest.fixture(scope="session")
def base64_encoded_image() -> Dict[str, str]:
return {
image_url:
encode_image_base64(await ImageFetchAiohttp.fetch_image(image_url))
image_url: encode_image_base64(fetch_image(image_url))
for image_url in TEST_IMAGE_URLS
}

View File

@ -7,7 +7,7 @@ import numpy as np
import pytest
from PIL import Image
from vllm.multimodal.utils import ImageFetchAiohttp, fetch_image
from vllm.multimodal.utils import async_fetch_image, fetch_image
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
@ -37,15 +37,15 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
return (np.asarray(a) == np.asarray(b.convert(a.mode))).all()
@pytest.mark.asyncio(scope="module")
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_http(image_url: str):
image_sync = fetch_image(image_url)
image_async = await ImageFetchAiohttp.fetch_image(image_url)
image_async = await async_fetch_image(image_url)
assert _image_equals(image_sync, image_async)
@pytest.mark.asyncio(scope="module")
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
@ -78,5 +78,5 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
else:
pass # Lossy format; only check that image can be opened
data_image_async = await ImageFetchAiohttp.fetch_image(data_url)
data_image_async = await async_fetch_image(data_url)
assert _image_equals(data_image_sync, data_image_async)

View File

@ -1,11 +1,12 @@
import shutil
from dataclasses import dataclass
from functools import lru_cache
from typing import Literal
import requests
from PIL import Image
from vllm.connections import global_http_connection
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
from .base import get_cache_dir
@ -22,11 +23,9 @@ def get_air_example_data_2_asset(filename: str) -> Image.Image:
if not image_path.exists():
base_url = "https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava"
with requests.get(f"{base_url}/{filename}", stream=True) as response:
response.raise_for_status()
with image_path.open("wb") as f:
shutil.copyfileobj(response.raw, f)
global_http_connection.download_file(f"{base_url}/{filename}",
image_path,
timeout=VLLM_IMAGE_FETCH_TIMEOUT)
return Image.open(image_path)

167
vllm/connections.py Normal file
View File

@ -0,0 +1,167 @@
from pathlib import Path
from typing import Any, Mapping, Optional
from urllib.parse import urlparse

import aiohttp
import requests

from vllm.version import __version__ as VLLM_VERSION
class HTTPConnection:
    """Helper class to send HTTP requests.

    Holds at most one synchronous ``requests.Session`` and one asynchronous
    ``aiohttp.ClientSession`` and exposes matching sync/async helpers for
    fetching bytes, text, JSON, or downloading to a file. Every helper
    validates that the URL uses the ``http`` or ``https`` scheme and sends a
    ``User-Agent: vLLM/<version>`` header.

    Args:
        reuse_client: When true (the default) the clients are created lazily
            once and then reused. When false, a fresh client is created on
            every ``get_sync_client`` / ``get_async_client`` call.
    """

    def __init__(self, *, reuse_client: bool = True) -> None:
        super().__init__()

        self.reuse_client = reuse_client

        self._sync_client: Optional["requests.Session"] = None
        self._async_client: Optional["aiohttp.ClientSession"] = None

    def get_sync_client(self) -> "requests.Session":
        """Return the synchronous session, creating one if needed.

        NOTE: with ``reuse_client=False`` the previously stored session is
        replaced without being closed; closing it is left to the caller or
        to garbage collection.
        """
        if self._sync_client is None or not self.reuse_client:
            self._sync_client = requests.Session()

        return self._sync_client

    # NOTE: We intentionally use an async function even though it is not
    # required, so that the client is only accessible inside async event loop
    async def get_async_client(self) -> "aiohttp.ClientSession":
        """Return the asynchronous session, creating one if needed.

        NOTE: with ``reuse_client=False`` the previously stored session is
        replaced without being closed.
        """
        if self._async_client is None or not self.reuse_client:
            self._async_client = aiohttp.ClientSession()

        return self._async_client

    def _validate_http_url(self, url: str) -> None:
        """Raise ``ValueError`` unless ``url`` has scheme http or https."""
        parsed_url = urlparse(url)

        if parsed_url.scheme not in ("http", "https"):
            raise ValueError("Invalid HTTP URL: A valid HTTP URL "
                             "must have scheme 'http' or 'https'.")

    def _headers(self, **extras: str) -> Mapping[str, str]:
        """Build the request headers; ``extras`` override the defaults."""
        return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras}

    def get_response(
        self,
        url: str,
        *,
        stream: bool = False,
        timeout: Optional[float] = None,
        extra_headers: Optional[Mapping[str, str]] = None,
    ) -> "requests.Response":
        """Send a synchronous GET request and return the raw response.

        The status code is not checked here; callers are expected to call
        ``raise_for_status()`` themselves.

        Raises:
            ValueError: If ``url`` is not an http(s) URL.
        """
        self._validate_http_url(url)

        client = self.get_sync_client()
        extra_headers = extra_headers or {}

        return client.get(url,
                          headers=self._headers(**extra_headers),
                          stream=stream,
                          timeout=timeout)

    async def get_async_response(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
        extra_headers: Optional[Mapping[str, str]] = None,
    ):
        """Start an asynchronous GET request.

        Returns the object produced by ``ClientSession.get``, which callers
        in this class enter with ``async with`` to obtain the response.

        Raises:
            ValueError: If ``url`` is not an http(s) URL.
        """
        self._validate_http_url(url)

        client = await self.get_async_client()
        extra_headers = extra_headers or {}

        # aiohttp documents ``timeout`` as a ClientTimeout; a bare number is
        # only accepted for backward compatibility. ``total=None`` keeps the
        # "no timeout" meaning of passing ``None``.
        request_timeout = aiohttp.ClientTimeout(total=timeout)

        return client.get(url,
                          headers=self._headers(**extra_headers),
                          timeout=request_timeout)

    def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes:
        """Fetch ``url`` and return the response body as bytes."""
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.content

    async def async_get_bytes(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> bytes:
        """Asynchronously fetch ``url`` and return the body as bytes."""
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.read()

    def get_text(self, url: str, *, timeout: Optional[float] = None) -> str:
        """Fetch ``url`` and return the response body as text."""
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.text

    async def async_get_text(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> str:
        """Asynchronously fetch ``url`` and return the body as text."""
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.text()

    def get_json(self, url: str, *, timeout: Optional[float] = None) -> Any:
        """Fetch ``url`` and return the body decoded from JSON."""
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.json()

    async def async_get_json(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> Any:
        """Asynchronously fetch ``url`` and return the JSON-decoded body."""
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.json()

    def download_file(
        self,
        url: str,
        save_path: Path,
        *,
        timeout: Optional[float] = None,
        chunk_size: int = 128,
    ) -> Path:
        """Download ``url`` to ``save_path`` and return ``save_path``.

        The body is written in pieces of ``chunk_size`` bytes.
        """
        # stream=True so the body is read chunk by chunk instead of being
        # loaded into memory in full before the first write.
        with self.get_response(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()

            with save_path.open("wb") as f:
                for chunk in r.iter_content(chunk_size):
                    f.write(chunk)

        return save_path

    async def async_download_file(
        self,
        url: str,
        save_path: Path,
        *,
        timeout: Optional[float] = None,
        chunk_size: int = 128,
    ) -> Path:
        """Asynchronously download ``url`` to ``save_path``.

        The body is read in pieces of ``chunk_size`` bytes; the file writes
        themselves are ordinary blocking writes. Returns ``save_path``.
        """
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            with save_path.open("wb") as f:
                async for chunk in r.content.iter_chunked(chunk_size):
                    f.write(chunk)

        return save_path
# Module-level instance created at import time with the default
# ``reuse_client=True``.
global_http_connection = HTTPConnection()
"""The global :class:`HTTPConnection` instance used by vLLM."""

View File

@ -1,26 +1,12 @@
import base64
from io import BytesIO
from typing import Optional, Union
from urllib.parse import urlparse
from typing import Union
import aiohttp
import requests
from PIL import Image
from vllm.connections import global_http_connection
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
from vllm.multimodal.base import MultiModalDataDict
from vllm.version import __version__ as VLLM_VERSION
def _validate_remote_url(url: str, *, name: str):
parsed_url = urlparse(url)
if parsed_url.scheme not in ["http", "https"]:
raise ValueError(f"Invalid '{name}': A valid '{name}' "
"must have scheme 'http' or 'https'.")
def _get_request_headers():
return {"User-Agent": f"vLLM/{VLLM_VERSION}"}
def _load_image_from_bytes(b: bytes):
@ -42,13 +28,8 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
By default, the image is converted into RGB format.
"""
if image_url.startswith('http'):
_validate_remote_url(image_url, name="image_url")
headers = _get_request_headers()
with requests.get(url=image_url, headers=headers) as response:
response.raise_for_status()
image_raw = response.content
image_raw = global_http_connection.get_bytes(
image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
image = _load_image_from_bytes(image_raw)
elif image_url.startswith('data:image'):
@ -60,55 +41,30 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
return image.convert(image_mode)
class ImageFetchAiohttp:
aiohttp_client: Optional[aiohttp.ClientSession] = None
async def async_fetch_image(image_url: str,
*,
image_mode: str = "RGB") -> Image.Image:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
@classmethod
def get_aiohttp_client(cls) -> aiohttp.ClientSession:
if cls.aiohttp_client is None:
timeout = aiohttp.ClientTimeout(total=VLLM_IMAGE_FETCH_TIMEOUT)
connector = aiohttp.TCPConnector()
cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout,
connector=connector)
By default, the image is converted into RGB format.
"""
if image_url.startswith('http'):
image_raw = await global_http_connection.async_get_bytes(
image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
image = _load_image_from_bytes(image_raw)
return cls.aiohttp_client
elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
@classmethod
async def fetch_image(
cls,
image_url: str,
*,
image_mode: str = "RGB",
) -> Image.Image:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
if image_url.startswith('http'):
_validate_remote_url(image_url, name="image_url")
client = cls.get_aiohttp_client()
headers = _get_request_headers()
async with client.get(url=image_url, headers=headers) as response:
response.raise_for_status()
image_raw = await response.read()
image = _load_image_from_bytes(image_raw)
elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
else:
raise ValueError(
"Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
return image.convert(image_mode)
return image.convert(image_mode)
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = await ImageFetchAiohttp.fetch_image(image_url)
image = await async_fetch_image(image_url)
return {"image": image}

View File

@ -16,6 +16,7 @@ import requests
import torch
import vllm.envs as envs
from vllm.connections import global_http_connection
from vllm.version import __version__ as VLLM_VERSION
_config_home = envs.VLLM_CONFIG_ROOT
@ -204,7 +205,8 @@ class UsageMessage:
def _send_to_server(self, data):
try:
requests.post(_USAGE_STATS_SERVER, json=data)
global_http_client = global_http_connection.get_sync_client()
global_http_client.post(_USAGE_STATS_SERVER, json=data)
except requests.exceptions.RequestException:
# silently ignore unless we are using debug log
logging.debug("Failed to send usage data to server")