diff --git a/OAI/types/models.py b/OAI/types/model.py similarity index 100% rename from OAI/types/models.py rename to OAI/types/model.py diff --git a/OAI/utils.py b/OAI/utils.py index c609057..c889605 100644 --- a/OAI/utils.py +++ b/OAI/utils.py @@ -1,7 +1,7 @@ import pathlib from OAI.types.completions import CompletionResponse, CompletionRespChoice from OAI.types.common import UsageStats -from OAI.types.models import ModelList, ModelCard +from OAI.types.model import ModelList, ModelCard from typing import Optional def create_completion_response(text: str, index: int, model_name: Optional[str]): diff --git a/main.py b/main.py index c0bb54c..ef39427 100644 --- a/main.py +++ b/main.py @@ -6,8 +6,8 @@ from fastapi import FastAPI, Request, HTTPException, Depends from model import ModelContainer from progress.bar import IncrementalBar from sse_starlette import EventSourceResponse -from OAI.types.completions import CompletionRequest, CompletionResponse -from OAI.types.models import ModelCard, ModelList, ModelLoadRequest, ModelLoadResponse +from OAI.types.completions import CompletionRequest +from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse from OAI.utils import create_completion_response, get_model_list from typing import Optional from utils import load_progress @@ -34,7 +34,7 @@ async def get_current_model(): model_card = ModelCard(id=model_container.get_model_path().name) return model_card.model_dump_json() -@app.post("/v1/model/load", response_class=ModelLoadResponse, dependencies=[Depends(check_admin_key)]) +@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) async def load_model(data: ModelLoadRequest): if model_container and model_container.model: raise HTTPException(400, "A model is already loaded! Please unload it first.") @@ -80,7 +80,7 @@ async def unload_model(): model_container.unload() model_container = None -@app.post("/v1/completions", response_class=CompletionResponse, dependencies=[Depends(check_api_key)]) +@app.post("/v1/completions", dependencies=[Depends(check_api_key)]) async def generate_completion(request: Request, data: CompletionRequest): if data.stream: async def generator():