OAI: Allow /v1/encode endpoint to handle vision requests

* More robust checks for OAI chat completion message lists on /v1/encode endpoint
* Added TODO to support other aspects of chat completions
* Fix oversight where embeddings was not defined in advance on /v1/chat/completions endpoint
This commit is contained in:
DocShotgun 2024-11-19 11:14:37 -08:00
parent c42655336b
commit 5611365c07
4 changed files with 36 additions and 5 deletions

View file

@ -862,7 +862,9 @@ class ExllamaV2Container:
async with self.load_condition:
self.load_condition.notify_all()
def encode_tokens(self, text: str, **kwargs):
def encode_tokens(
self, text: str, embeddings: MultimodalEmbeddingWrapper, **kwargs
):
"""Wrapper to encode tokens from a text string."""
return (
@ -870,6 +872,7 @@ class ExllamaV2Container:
text,
add_bos=unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
embeddings=embeddings.content,
)
.flatten()
.tolist()

View file

@ -1,4 +1,5 @@
import asyncio
from common.multimodal import MultimodalEmbeddingWrapper
from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from sys import maxsize
@ -124,6 +125,8 @@ async def chat_completion_request(
model_path = model.container.model_dir
embeddings = MultimodalEmbeddingWrapper()
if isinstance(data.messages, str):
prompt = data.messages
else:

View file

@ -1,6 +1,7 @@
import asyncio
import pathlib
from sys import maxsize
from common.multimodal import MultimodalEmbeddingWrapper
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sse_starlette import EventSourceResponse
@ -357,10 +358,27 @@ async def unload_embedding_model():
)
async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
"""Encodes a string or chat completion messages into tokens."""
embeddings = MultimodalEmbeddingWrapper()
if isinstance(data.text, str):
text = data.text
else:
elif isinstance(data.text, list) and "oai" in config.network.api_servers:
# TODO: Support additional chat completion args for encode
# i.e. add_generation_prompt, template selection, tool args, template kwargs
if model.container.prompt_template is None:
error_message = handle_request_error(
"Tokenization of chat completion requests is disabled "
"because a prompt template is not set.",
exc_info=False,
).error.message
raise HTTPException(422, error_message)
from endpoints.OAI.utils.chat_completion import preprocess_vision_request
if model.container.use_vision:
data.text, embeddings = await preprocess_vision_request(data.text)
special_tokens_dict = model.container.get_special_tokens(
unwrap(data.add_bos_token, True)
)
@ -371,9 +389,16 @@ async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
**special_tokens_dict,
}
text, _ = model.container.prompt_template.render(template_vars)
text = await model.container.prompt_template.render(template_vars)
else:
error_message = handle_request_error(
"OAI API server must be enabled to handle chat completion message inputs.",
exc_info=False,
).error.message
raw_tokens = model.container.encode_tokens(text, **data.get_params())
raise HTTPException(422, error_message)
raw_tokens = model.container.encode_tokens(text, embeddings, **data.get_params())
tokens = unwrap(raw_tokens, [])
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))

View file

@ -23,7 +23,7 @@ class CommonTokenRequest(BaseModel):
class TokenEncodeRequest(CommonTokenRequest):
"""Represents a tokenization request."""
text: Union[str, List[Dict[str, str]]]
text: Union[str, List[Dict]]
class TokenEncodeResponse(BaseModel):