Move the OAI embeddings module to endpoints/OAI/utils/embeddings.py and
simplify the /v1/embeddings route handler in endpoints/OAI/router.py: the
handler now forwards data.input, data.encoding_format and data.model to
embeddings() and returns its result, annotated as EmbeddingsResponse.

NOTE(review): this patch was restored from a copy whose line breaks and
indentation had been collapsed onto one line. The line counts of every hunk
match its @@ header, but the exact position of the two blank context lines
around "router = APIRouter()" in the third hunk is an assumption; confirm
against the original commit before applying.

diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index 039042a..ffb678b 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -1,6 +1,5 @@
 import asyncio
 from fastapi import APIRouter, Depends, HTTPException, Request
-from fastapi.responses import JSONResponse
 from sse_starlette import EventSourceResponse
 from sys import maxsize
 
@@ -9,7 +8,6 @@ from common.auth import check_api_key
 from common.model import check_model_container
 from common.networking import handle_request_error, run_with_request_disconnect
 from common.utils import unwrap
-import endpoints.OAI.embeddings as OAIembeddings
 from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
 from endpoints.OAI.types.chat_completion import (
     ChatCompletionRequest,
@@ -25,6 +23,7 @@ from endpoints.OAI.utils.completion import (
     generate_completion,
     stream_generate_completion,
 )
+from endpoints.OAI.utils.embeddings import embeddings
 
 router = APIRouter()
 
@@ -134,14 +133,8 @@ async def chat_completion_request(
 @router.post(
     "/v1/embeddings",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
-    response_model=EmbeddingsResponse,
 )
-async def handle_embeddings(request: EmbeddingsRequest):
-    input = request.input
-    if not input:
-        raise JSONResponse(
-            status_code=400, content={"error": "Missing required argument input"}
-        )
-    model = request.model if request.model else None
-    response = await OAIembeddings.embeddings(input, request.encoding_format, model)
-    return JSONResponse(response)
+async def handle_embeddings(data: EmbeddingsRequest) -> EmbeddingsResponse:
+    response = await embeddings(data.input, data.encoding_format, data.model)
+
+    return response
diff --git a/endpoints/OAI/embeddings.py b/endpoints/OAI/utils/embeddings.py
similarity index 100%
rename from endpoints/OAI/embeddings.py
rename to endpoints/OAI/utils/embeddings.py