diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 3c6634f..bc9142a 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -862,7 +862,9 @@ class ExllamaV2Container:
         async with self.load_condition:
             self.load_condition.notify_all()
 
-    def encode_tokens(self, text: str, **kwargs):
+    def encode_tokens(
+        self, text: str, embeddings: MultimodalEmbeddingWrapper, **kwargs
+    ):
         """Wrapper to encode tokens from a text string."""
 
         return (
@@ -870,6 +872,7 @@ class ExllamaV2Container:
                 text,
                 add_bos=unwrap(kwargs.get("add_bos_token"), True),
                 encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
+                embeddings=embeddings.content,
             )
             .flatten()
             .tolist()
diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index c018f5c..acb35f9 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -1,4 +1,5 @@
 import asyncio
+from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import APIRouter, Depends, HTTPException, Request
 from sse_starlette import EventSourceResponse
 from sys import maxsize
@@ -124,6 +125,8 @@ async def chat_completion_request(
 
     model_path = model.container.model_dir
 
+    embeddings = MultimodalEmbeddingWrapper()
+
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
diff --git a/endpoints/core/router.py b/endpoints/core/router.py
index f2b4247..0a48a2e 100644
--- a/endpoints/core/router.py
+++ b/endpoints/core/router.py
@@ -1,6 +1,7 @@
 import asyncio
 import pathlib
 from sys import maxsize
+from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import APIRouter, Depends, HTTPException, Request, Response
 from sse_starlette import EventSourceResponse
 
@@ -357,10 +358,27 @@ async def unload_embedding_model():
 )
 async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
     """Encodes a string or chat completion messages into tokens."""
+    embeddings = MultimodalEmbeddingWrapper()
 
     if isinstance(data.text, str):
         text = data.text
-    else:
+    elif isinstance(data.text, list) and "oai" in config.network.api_servers:
+        # TODO: Support additional chat completion args for encode
+        # i.e. add_generation_prompt, template selection, tool args, template kwargs
+        if model.container.prompt_template is None:
+            error_message = handle_request_error(
+                "Tokenization of chat completion requests is disabled "
+                "because a prompt template is not set.",
+                exc_info=False,
+            ).error.message
+
+            raise HTTPException(422, error_message)
+
+        from endpoints.OAI.utils.chat_completion import preprocess_vision_request
+
+        if model.container.use_vision:
+            data.text, embeddings = await preprocess_vision_request(data.text)
+
         special_tokens_dict = model.container.get_special_tokens(
             unwrap(data.add_bos_token, True)
         )
@@ -371,9 +389,16 @@ async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
             **special_tokens_dict,
         }
 
-        text, _ = model.container.prompt_template.render(template_vars)
+        text = await model.container.prompt_template.render(template_vars)
+    else:
+        error_message = handle_request_error(
+            "OAI API server must be enabled to handle chat completion message inputs.",
+            exc_info=False,
+        ).error.message
 
-    raw_tokens = model.container.encode_tokens(text, **data.get_params())
+        raise HTTPException(422, error_message)
+
+    raw_tokens = model.container.encode_tokens(text, embeddings, **data.get_params())
     tokens = unwrap(raw_tokens, [])
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
 
diff --git a/endpoints/core/types/token.py b/endpoints/core/types/token.py
index 945adbf..2c205ab 100644
--- a/endpoints/core/types/token.py
+++ b/endpoints/core/types/token.py
@@ -23,7 +23,7 @@ class CommonTokenRequest(BaseModel):
 class TokenEncodeRequest(CommonTokenRequest):
     """Represents a tokenization request."""
 
-    text: Union[str, List[Dict[str, str]]]
+    text: Union[str, List[Dict]]
 
 
 class TokenEncodeResponse(BaseModel):
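
Not part of the patch, just a client-side usage sketch: with TokenEncodeRequest.text widened to Union[str, List[Dict]], the encode endpoint can accept OAI-style chat messages (including multimodal content parts when vision is enabled) instead of only a plain string. The endpoint path, host, port, auth header, and image URL below are assumptions for illustration; only the text and add_bos_token fields come from the types changed above.

# Hypothetical client sketch; assumes the encode route is /v1/token/encode
# and that the server authenticates via an x-api-key header.
import requests

payload = {
    "text": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                # Illustrative image part; requires a vision-capable model
                # and a set prompt template on the server side.
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
    "add_bos_token": True,
}

resp = requests.post(
    "http://127.0.0.1:5000/v1/token/encode",
    json=payload,
    headers={"x-api-key": "your-key"},
)
resp.raise_for_status()
body = resp.json()
# TokenEncodeResponse carries the token ids and their count.
print(f"{body['length']} tokens: {body['tokens']}")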