API: Add fallback if model isn't loaded
Most endpoints require the model to be loaded, so add a Depends check.

Signed-off-by: kingbri <bdashore3@proton.me>
commit d8d61fa19b (parent c0525c042e)

1 changed file with 10 additions and 12 deletions
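The change hinges on FastAPI's route-level dependencies: a shared _check_model_container function raises an HTTPException before any handler body runs, replacing the per-endpoint "is a model loaded?" checks. Below is a minimal, self-contained sketch of that pattern; everything apart from FastAPI's own API (DummyContainer, the /v1/model payload) is an illustrative stand-in, not code from main.py.

```python
# Minimal sketch of the dependency pattern introduced in this commit.
# DummyContainer and the handler body are illustrative stand-ins, not
# the real ModelContainer/handlers from main.py.
from typing import Optional

from fastapi import Depends, FastAPI, HTTPException

app = FastAPI()


class DummyContainer:
    """Stand-in for a model container; holds a loaded model or None."""

    def __init__(self, model=None):
        self.model = model


model_container: Optional[DummyContainer] = None


def _check_model_container():
    # Raising inside a dependency short-circuits the request with a 400
    # before the endpoint body ever touches model_container.
    if model_container is None or model_container.model is None:
        raise HTTPException(400, "No models are loaded.")


@app.get("/v1/model", dependencies=[Depends(_check_model_container)])
async def get_current_model():
    # Safe to dereference here: the dependency guarantees a loaded model.
    return {"id": str(model_container.model)}
```

With nothing loaded, any request to /v1/model comes back as a 400 with a JSON body of {"detail": "No models are loaded."}; once model_container holds a loaded model, the handler runs as usual.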
main.py

@@ -32,6 +32,10 @@ app = FastAPI()
 model_container: Optional[ModelContainer] = None
 config: dict = {}
 
+def _check_model_container():
+    if model_container is None or model_container.model is None:
+        raise HTTPException(400, "No models are loaded.")
+
 # Model list endpoint
 @app.get("/v1/models", dependencies=[Depends(check_api_key)])
 @app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
@@ -47,11 +51,8 @@ async def list_models():
     return models.json()
 
 # Currently loaded model endpoint
-@app.get("/v1/model", dependencies=[Depends(check_api_key)])
+@app.get("/v1/model", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def get_current_model():
-    if model_container is None or model_container.model is None:
-        return HTTPException(400, "No models are loaded.")
-
     model_card = ModelCard(id=model_container.get_model_path().name)
     return model_card.json()
 
@@ -98,18 +99,15 @@ async def load_model(data: ModelLoadRequest):
     return EventSourceResponse(generator())
 
 # Unload model endpoint
-@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key)])
+@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
 async def unload_model():
     global model_container
 
-    if model_container is None:
-        raise HTTPException(400, "No models are loaded.")
-
     model_container.unload()
     model_container = None
 
 # Encode tokens endpoint
-@app.post("/v1/token/encode", dependencies=[Depends(check_api_key)])
+@app.post("/v1/token/encode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def encode_tokens(data: TokenEncodeRequest):
     tokens = model_container.get_tokens(data.text, None, **data.get_params())[0].tolist()
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
@@ -117,7 +115,7 @@ async def encode_tokens(data: TokenEncodeRequest):
     return response.json()
 
 # Decode tokens endpoint
-@app.post("/v1/token/decode", dependencies=[Depends(check_api_key)])
+@app.post("/v1/token/decode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def decode_tokens(data: TokenDecodeRequest):
     message = model_container.get_tokens(None, data.tokens, **data.get_params())
     response = TokenDecodeResponse(text=message)
@@ -125,7 +123,7 @@ async def decode_tokens(data: TokenDecodeRequest):
     return response.json()
 
 # Completions endpoint
-@app.post("/v1/completions", dependencies=[Depends(check_api_key)])
+@app.post("/v1/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def generate_completion(request: Request, data: CompletionRequest):
     model_path = model_container.get_model_path()
 
@@ -151,7 +149,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
     return response.json()
 
 # Chat completions endpoint
-@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
+@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
     model_path = model_container.get_model_path()
 
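A side effect of routing the check through a dependency is that it removes the bug in the old get_current_model handler, which returned the HTTPException instead of raising it, so clients would not actually receive a 400. A quick, self-contained check of the new behavior follows; it uses stub names only, is independent of main.py, and assumes httpx is installed for FastAPI's TestClient.

```python
# Self-contained check that a failing dependency yields a 400 with a JSON
# detail body. Stub app only; not the real main.py.
from fastapi import Depends, FastAPI, HTTPException
from fastapi.testclient import TestClient  # requires httpx

app = FastAPI()
model_loaded = False  # pretend no model has been loaded yet


def _check_model_container():
    if not model_loaded:
        raise HTTPException(400, "No models are loaded.")


@app.get("/v1/model", dependencies=[Depends(_check_model_container)])
async def get_current_model():
    return {"id": "stub-model"}


client = TestClient(app)
response = client.get("/v1/model")
assert response.status_code == 400
assert response.json() == {"detail": "No models are loaded."}
```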