API: Transform multimodal into an actual class
Migrate the add method into the class itself. Also, a BaseModel isn't needed here since this isn't a serialized class.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
8ffc636dce
commit
c652a6e030
2 changed files with 14 additions and 22 deletions
|
|
@ -1,7 +1,6 @@
|
|||
from typing import List
|
||||
from backends.exllamav2.vision import get_image_embedding
|
||||
from common import model
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
|
||||
from common.optional_dependencies import dependencies
|
||||
|
|
@ -10,27 +9,22 @@ if dependencies.exllamav2:
|
|||
from exllamav2 import ExLlamaV2VisionTower
|
||||
|
||||
|
||||
class MultimodalEmbeddingWrapper:
    """Common multimodal embedding wrapper.

    Accumulates the embeddings created through ``add`` along with the
    ``text_alias`` of each one, in insertion order.

    This is a plain class rather than a pydantic ``BaseModel``, so mutable
    defaults declared at class level would be a single object shared by
    every instance. The lists are therefore created per instance in
    ``__init__``.
    """

    # Name of the embedding backend in use; None until the first add() call
    type: str
    # Embedding objects, in the order they were added
    content: List
    # text_alias of each entry in `content`, index-aligned with it
    text_alias: List[str]

    def __init__(self):
        self.type = None
        # Fresh lists per instance so wrappers never share state
        self.content = []
        self.text_alias = []

    async def add(self, url: str):
        """Create an embedding for ``url`` and append it to this wrapper.

        On success the embedding is appended to ``content`` and its
        ``text_alias`` to ``text_alias``. If no supported vision model is
        found, an error is logged and the wrapper is left unchanged.
        """

        # Determine the type of vision embedding to use
        if not self.type:
            if isinstance(model.container.vision_model, ExLlamaV2VisionTower):
                self.type = "ExLlamaV2MMEmbedding"

        if self.type == "ExLlamaV2MMEmbedding":
            embedding = await get_image_embedding(url)
            self.content.append(embedding)
            self.text_alias.append(embedding.text_alias)
        else:
            logger.error("No valid vision model to create embedding")


async def add_image_embedding(
    embeddings: MultimodalEmbeddingWrapper, url: str
) -> MultimodalEmbeddingWrapper:
    """Functional form of ``MultimodalEmbeddingWrapper.add``.

    Kept so callers using the function-style API keep working; it delegates
    to ``embeddings.add(url)`` and returns the same wrapper instance.
    """

    await embeddings.add(url)
    return embeddings
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from jinja2 import TemplateError
|
|||
from loguru import logger
|
||||
|
||||
from common import model
|
||||
from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
|
||||
from common.multimodal import MultimodalEmbeddingWrapper
|
||||
from common.networking import (
|
||||
get_generator_error,
|
||||
handle_request_disconnect,
|
||||
|
|
@ -483,9 +483,7 @@ async def preprocess_vision_request(messages: List[ChatCompletionMessage]):
|
|||
if content.type == "text":
|
||||
concatenated_content += content.text
|
||||
elif content.type == "image_url":
|
||||
embeddings = await add_image_embedding(
|
||||
embeddings, content.image_url.url
|
||||
)
|
||||
await embeddings.add(content.image_url.url)
|
||||
concatenated_content += embeddings.text_alias[-1]
|
||||
|
||||
message.content = concatenated_content
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue