Common: Refactor get_image to common functions
This commit is contained in:
parent
d357f100d0
commit
4605c0f6bd
2 changed files with 53 additions and 50 deletions
|
|
@ -1,19 +1,10 @@
|
||||||
"""Vision utilities for ExLlamaV2."""
|
"""Vision utilities for ExLlamaV2."""
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import base64
|
|
||||||
import io
|
|
||||||
import re
|
|
||||||
from async_lru import alru_cache
|
from async_lru import alru_cache
|
||||||
from fastapi import HTTPException
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from common import model
|
from common import model
|
||||||
from common.networking import (
|
|
||||||
handle_request_error,
|
|
||||||
)
|
|
||||||
from common.optional_dependencies import dependencies
|
from common.optional_dependencies import dependencies
|
||||||
from common.tabby_config import config
|
from common.image_util import get_image
|
||||||
|
|
||||||
# Since this is used outside the Exl2 backend, the dependency
|
# Since this is used outside the Exl2 backend, the dependency
|
||||||
# may be optional
|
# may be optional
|
||||||
|
|
@ -21,46 +12,6 @@ if dependencies.exllamav2:
|
||||||
from exllamav2.generator import ExLlamaV2MMEmbedding
|
from exllamav2.generator import ExLlamaV2MMEmbedding
|
||||||
|
|
||||||
|
|
||||||
async def get_image(url: str) -> Image:
|
|
||||||
if url.startswith("data:image"):
|
|
||||||
# Handle base64 image
|
|
||||||
match = re.match(r"^data:image\/[a-zA-Z0-9]+;base64,(.*)$", url)
|
|
||||||
if match:
|
|
||||||
base64_image = match.group(1)
|
|
||||||
bytes_image = base64.b64decode(base64_image)
|
|
||||||
else:
|
|
||||||
error_message = handle_request_error(
|
|
||||||
"Failed to read base64 image input.",
|
|
||||||
exc_info=False,
|
|
||||||
).error.message
|
|
||||||
|
|
||||||
raise HTTPException(400, error_message)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Handle image URL
|
|
||||||
if config.network.disable_fetch_requests:
|
|
||||||
error_message = handle_request_error(
|
|
||||||
f"Failed to fetch image from {url} as fetch requests are disabled.",
|
|
||||||
exc_info=False,
|
|
||||||
).error.message
|
|
||||||
|
|
||||||
raise HTTPException(400, error_message)
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(url) as response:
|
|
||||||
if response.status == 200:
|
|
||||||
bytes_image = await response.read()
|
|
||||||
else:
|
|
||||||
error_message = handle_request_error(
|
|
||||||
f"Failed to fetch image from {url}.",
|
|
||||||
exc_info=False,
|
|
||||||
).error.message
|
|
||||||
|
|
||||||
raise HTTPException(400, error_message)
|
|
||||||
|
|
||||||
return Image.open(io.BytesIO(bytes_image))
|
|
||||||
|
|
||||||
|
|
||||||
# Fetch the return type on runtime
|
# Fetch the return type on runtime
|
||||||
@alru_cache(20)
|
@alru_cache(20)
|
||||||
async def get_image_embedding(url: str) -> "ExLlamaV2MMEmbedding":
|
async def get_image_embedding(url: str) -> "ExLlamaV2MMEmbedding":
|
||||||
|
|
|
||||||
52
common/image_util.py
Normal file
52
common/image_util.py
Normal file
|
|
@ -0,0 +1,52 @@
|
||||||
|
import aiohttp
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import re
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from common.networking import (
|
||||||
|
handle_request_error,
|
||||||
|
)
|
||||||
|
from common.tabby_config import config
|
||||||
|
|
||||||
|
|
||||||
|
async def get_image(url: str) -> Image:
|
||||||
|
if url.startswith("data:image"):
|
||||||
|
# Handle base64 image
|
||||||
|
match = re.match(r"^data:image\/[a-zA-Z0-9]+;base64,(.*)$", url)
|
||||||
|
if match:
|
||||||
|
base64_image = match.group(1)
|
||||||
|
bytes_image = base64.b64decode(base64_image)
|
||||||
|
else:
|
||||||
|
error_message = handle_request_error(
|
||||||
|
"Failed to read base64 image input.",
|
||||||
|
exc_info=False,
|
||||||
|
).error.message
|
||||||
|
|
||||||
|
raise HTTPException(400, error_message)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Handle image URL
|
||||||
|
if config.network.disable_fetch_requests:
|
||||||
|
error_message = handle_request_error(
|
||||||
|
f"Failed to fetch image from {url} as fetch requests are disabled.",
|
||||||
|
exc_info=False,
|
||||||
|
).error.message
|
||||||
|
|
||||||
|
raise HTTPException(400, error_message)
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(url) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
bytes_image = await response.read()
|
||||||
|
else:
|
||||||
|
error_message = handle_request_error(
|
||||||
|
f"Failed to fetch image from {url}.",
|
||||||
|
exc_info=False,
|
||||||
|
).error.message
|
||||||
|
|
||||||
|
raise HTTPException(400, error_message)
|
||||||
|
|
||||||
|
return Image.open(io.BytesIO(bytes_image))
|
||||||
Loading…
Add table
Add a link
Reference in a new issue