API: Don't use response_class

The response_class arg expects a Starlette Response subclass, not a
Pydantic model, so it caused errors in these routes. It also isn't
needed: FastAPI serializes a returned Pydantic model on its own.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2023-11-14 21:56:15 -05:00
parent b625bface9
commit 4670a77c26
3 changed files with 5 additions and 5 deletions
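
For context on the change: FastAPI's response_class must be a Starlette Response
subclass (e.g. JSONResponse), so passing a Pydantic model such as CompletionResponse
there fails. A route can simply return the Pydantic model, or declare its schema with
response_model. A minimal, self-contained sketch of that pattern (the fields and route
here are illustrative, not the project's actual code):

    from fastapi import FastAPI
    from pydantic import BaseModel

    app = FastAPI()

    # Illustrative stand-in for OAI.types.completions.CompletionResponse
    class CompletionResponse(BaseModel):
        id: str
        text: str

    # response_model (not response_class) documents the returned schema;
    # FastAPI serializes the returned Pydantic model to JSON automatically.
    @app.post("/v1/completions-demo", response_model=CompletionResponse)
    async def generate_completion_demo() -> CompletionResponse:
        return CompletionResponse(id="cmpl-1", text="hello")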

View file

@@ -1,7 +1,7 @@
import pathlib
from OAI.types.completions import CompletionResponse, CompletionRespChoice
from OAI.types.common import UsageStats
-from OAI.types.models import ModelList, ModelCard
+from OAI.types.model import ModelList, ModelCard
from typing import Optional
def create_completion_response(text: str, index: int, model_name: Optional[str]):

View file

@@ -6,8 +6,8 @@ from fastapi import FastAPI, Request, HTTPException, Depends
from model import ModelContainer
from progress.bar import IncrementalBar
from sse_starlette import EventSourceResponse
-from OAI.types.completions import CompletionRequest, CompletionResponse
-from OAI.types.models import ModelCard, ModelList, ModelLoadRequest, ModelLoadResponse
+from OAI.types.completions import CompletionRequest
+from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
from OAI.utils import create_completion_response, get_model_list
from typing import Optional
from utils import load_progress
@@ -34,7 +34,7 @@ async def get_current_model():
model_card = ModelCard(id=model_container.get_model_path().name)
return model_card.model_dump_json()
@app.post("/v1/model/load", response_class=ModelLoadResponse, dependencies=[Depends(check_admin_key)])
@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(data: ModelLoadRequest):
if model_container and model_container.model:
raise HTTPException(400, "A model is already loaded! Please unload it first.")
@@ -80,7 +80,7 @@ async def unload_model():
model_container.unload()
model_container = None
@app.post("/v1/completions", response_class=CompletionResponse, dependencies=[Depends(check_api_key)])
@app.post("/v1/completions", dependencies=[Depends(check_api_key)])
async def generate_completion(request: Request, data: CompletionRequest):
if data.stream:
async def generator():
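
The streaming branch above hands an async generator to sse_starlette's EventSourceResponse
rather than relying on any response_class. A rough, self-contained sketch of that pattern
(the demo route and token list are made up for illustration and are not the project's code):

    from fastapi import FastAPI
    from sse_starlette import EventSourceResponse

    app = FastAPI()

    @app.post("/v1/stream-demo")  # hypothetical demo route
    async def stream_demo():
        async def generator():
            # Hypothetical token stream; the real endpoint would pull
            # chunks from the loaded model container instead.
            for token in ("Hello", " ", "world"):
                # EventSourceResponse emits each yielded item as a
                # server-sent event; dict keys map to SSE fields.
                yield {"data": token}

        return EventSourceResponse(generator())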