Common: Refactor get_image to common functions

This commit is contained in:
turboderp 2025-06-15 19:20:36 +02:00
parent d357f100d0
commit 4605c0f6bd
2 changed files with 53 additions and 50 deletions

View file

@ -1,19 +1,10 @@
"""Vision utilities for ExLlamaV2."""
import aiohttp
import base64
import io
import re
from async_lru import alru_cache
from fastapi import HTTPException
from PIL import Image
from common import model
from common.networking import (
handle_request_error,
)
from common.optional_dependencies import dependencies
from common.tabby_config import config
from common.image_util import get_image
# Since this is used outside the Exl2 backend, the dependency
# may be optional
@ -21,46 +12,6 @@ if dependencies.exllamav2:
from exllamav2.generator import ExLlamaV2MMEmbedding
async def get_image(url: str) -> Image:
if url.startswith("data:image"):
# Handle base64 image
match = re.match(r"^data:image\/[a-zA-Z0-9]+;base64,(.*)$", url)
if match:
base64_image = match.group(1)
bytes_image = base64.b64decode(base64_image)
else:
error_message = handle_request_error(
"Failed to read base64 image input.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
else:
# Handle image URL
if config.network.disable_fetch_requests:
error_message = handle_request_error(
f"Failed to fetch image from {url} as fetch requests are disabled.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status == 200:
bytes_image = await response.read()
else:
error_message = handle_request_error(
f"Failed to fetch image from {url}.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
return Image.open(io.BytesIO(bytes_image))
# Fetch the return type on runtime
@alru_cache(20)
async def get_image_embedding(url: str) -> "ExLlamaV2MMEmbedding":

52
common/image_util.py Normal file
View file

@ -0,0 +1,52 @@
import aiohttp
import base64
import io
import re
from fastapi import HTTPException
from PIL import Image
from common.networking import (
handle_request_error,
)
from common.tabby_config import config
async def get_image(url: str) -> Image:
if url.startswith("data:image"):
# Handle base64 image
match = re.match(r"^data:image\/[a-zA-Z0-9]+;base64,(.*)$", url)
if match:
base64_image = match.group(1)
bytes_image = base64.b64decode(base64_image)
else:
error_message = handle_request_error(
"Failed to read base64 image input.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
else:
# Handle image URL
if config.network.disable_fetch_requests:
error_message = handle_request_error(
f"Failed to fetch image from {url} as fetch requests are disabled.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status == 200:
bytes_image = await response.read()
else:
error_message = handle_request_error(
f"Failed to fetch image from {url}.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
return Image.open(io.BytesIO(bytes_image))