OAI: Allow /v1/encode endpoint to handle vision requests
* More robust checks for OAI chat completion message lists on /v1/encode endpoint
* Added TODO to support other aspects of chat completions
* Fix oversight where embeddings was not defined in advance on /v1/chat/completions endpoint
This commit is contained in:
parent c42655336b
commit 5611365c07

4 changed files with 36 additions and 5 deletions
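With this change, /v1/encode accepts either a plain string or an OAI-style chat completion message list, and vision content is resolved through the same preprocessing as /v1/chat/completions. A minimal client sketch follows; the request/response field names come from TokenEncodeRequest and TokenEncodeResponse in this diff, while the image_url content-part layout and the local host/port are assumptions, not part of this commit:

    import requests

    # Hypothetical client call; host/port assumed from a default local setup.
    payload = {
        "text": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is in this image?"},
                    # OAI-convention image part (assumed layout); handled by
                    # preprocess_vision_request when the model has vision enabled.
                    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                ],
            }
        ],
        "add_bos_token": True,
    }

    resp = requests.post("http://127.0.0.1:5000/v1/encode", json=payload)
    resp.raise_for_status()
    print(resp.json())  # -> {"tokens": [...], "length": <token count>}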
@@ -862,7 +862,9 @@ class ExllamaV2Container:
         async with self.load_condition:
             self.load_condition.notify_all()
 
-    def encode_tokens(self, text: str, **kwargs):
+    def encode_tokens(
+        self, text: str, embeddings: MultimodalEmbeddingWrapper, **kwargs
+    ):
         """Wrapper to encode tokens from a text string."""
 
         return (
@@ -870,6 +872,7 @@ class ExllamaV2Container:
                 text,
                 add_bos=unwrap(kwargs.get("add_bos_token"), True),
                 encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
+                embeddings=embeddings.content,
             )
             .flatten()
             .tolist()
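Note the signature change: encode_tokens now takes the multimodal wrapper positionally, so text-only callers pass an empty one. A sketch of the new call shape, assuming `container` is an ExllamaV2Container instance (a hypothetical variable; the other names come from this diff):

    from common.multimodal import MultimodalEmbeddingWrapper

    embeddings = MultimodalEmbeddingWrapper()  # empty wrapper for text-only input
    token_ids = container.encode_tokens(
        "Hello there",
        embeddings,  # wrapper.content is forwarded as embeddings= to the tokenizer
        add_bos_token=True,
        encode_special_tokens=True,
    )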
@@ -1,4 +1,5 @@
 import asyncio
+from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import APIRouter, Depends, HTTPException, Request
 from sse_starlette import EventSourceResponse
 from sys import maxsize
@@ -124,6 +125,8 @@ async def chat_completion_request(
 
     model_path = model.container.model_dir
 
+    embeddings = MultimodalEmbeddingWrapper()
+
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
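The two added lines above matter because `embeddings` is used later on both branches: if it were only assigned inside the vision path, a plain-string request would reference an unbound local. An illustrative reduction of the pitfall (the helpers `preprocess` and `encode` are hypothetical stand-ins, not code from this repo):

    from common.multimodal import MultimodalEmbeddingWrapper

    def build_prompt(messages, preprocess, encode):
        embeddings = MultimodalEmbeddingWrapper()  # bound up front, as in this commit
        if isinstance(messages, str):
            prompt = messages  # vision preprocessing never runs on this path
        else:
            prompt, embeddings = preprocess(messages)
        # Without the up-front assignment, the string path would raise
        # UnboundLocalError on the next line.
        return encode(prompt, embeddings)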
@@ -1,6 +1,7 @@
 import asyncio
 import pathlib
 from sys import maxsize
+from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import APIRouter, Depends, HTTPException, Request, Response
 from sse_starlette import EventSourceResponse
 
@@ -357,10 +358,27 @@ async def unload_embedding_model():
 )
 async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
     """Encodes a string or chat completion messages into tokens."""
 
+    embeddings = MultimodalEmbeddingWrapper()
+
     if isinstance(data.text, str):
         text = data.text
-    else:
+    elif isinstance(data.text, list) and "oai" in config.network.api_servers:
+        # TODO: Support additional chat completion args for encode
+        # i.e. add_generation_prompt, template selection, tool args, template kwargs
+        if model.container.prompt_template is None:
+            error_message = handle_request_error(
+                "Tokenization of chat completion requests is disabled "
+                "because a prompt template is not set.",
+                exc_info=False,
+            ).error.message
+
+            raise HTTPException(422, error_message)
+
+        from endpoints.OAI.utils.chat_completion import preprocess_vision_request
+
+        if model.container.use_vision:
+            data.text, embeddings = await preprocess_vision_request(data.text)
+
         special_tokens_dict = model.container.get_special_tokens(
             unwrap(data.add_bos_token, True)
         )
@@ -371,9 +389,16 @@ async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
             **special_tokens_dict,
         }
 
-        text, _ = model.container.prompt_template.render(template_vars)
+        text = await model.container.prompt_template.render(template_vars)
+    else:
+        error_message = handle_request_error(
+            "OAI API server must be enabled to handle chat completion message inputs.",
+            exc_info=False,
+        ).error.message
 
-    raw_tokens = model.container.encode_tokens(text, **data.get_params())
+        raise HTTPException(422, error_message)
+
+    raw_tokens = model.container.encode_tokens(text, embeddings, **data.get_params())
     tokens = unwrap(raw_tokens, [])
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
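In short, the endpoint now distinguishes three cases. A sketch of the accepted request bodies and the new 422 paths, derived from the branch logic above:

    # Accepted: plain string, tokenized directly.
    body_string = {"text": "Hello world", "add_bos_token": True}

    # Accepted: OAI-style message list, rendered through the prompt template
    # first (requires the "oai" API server to be enabled and the loaded model
    # to have a prompt template set).
    body_messages = {"text": [{"role": "user", "content": "Hello world"}]}

    # HTTP 422: message list while no prompt template is set.
    # HTTP 422: message list while the "oai" API server is disabled.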
@@ -23,7 +23,7 @@ class CommonTokenRequest(BaseModel):
 class TokenEncodeRequest(CommonTokenRequest):
     """Represents a tokenization request."""
 
-    text: Union[str, List[Dict[str, str]]]
+    text: Union[str, List[Dict]]
 
 
 class TokenEncodeResponse(BaseModel):
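The widening from List[Dict[str, str]] to List[Dict] is what lets vision messages validate: their "content" value is a list of parts rather than a plain string. A small illustration (the part layout is assumed from the OAI chat convention):

    from typing import Dict, List, Union

    OldText = Union[str, List[Dict[str, str]]]  # rejected structured content
    NewText = Union[str, List[Dict]]            # accepts it

    # "content" here is a list, so this message only validates under NewText.
    vision_message = {"role": "user", "content": [{"type": "text", "text": "hi"}]}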