diff --git a/main.py b/main.py
index 24be8ae..8da2738 100644
--- a/main.py
+++ b/main.py
@@ -32,6 +32,10 @@ app = FastAPI()
 model_container: Optional[ModelContainer] = None
 config: dict = {}
 
+def _check_model_container():
+    if model_container is None or model_container.model is None:
+        raise HTTPException(400, "No models are loaded.")
+
 # Model list endpoint
 @app.get("/v1/models", dependencies=[Depends(check_api_key)])
 @app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
@@ -47,11 +51,8 @@ async def list_models():
     return models.json()
 
 # Currently loaded model endpoint
-@app.get("/v1/model", dependencies=[Depends(check_api_key)])
+@app.get("/v1/model", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def get_current_model():
-    if model_container is None or model_container.model is None:
-        return HTTPException(400, "No models are loaded.")
-
     model_card = ModelCard(id=model_container.get_model_path().name)
     return model_card.json()
 
@@ -98,18 +99,15 @@ async def load_model(data: ModelLoadRequest):
     return EventSourceResponse(generator())
 
 # Unload model endpoint
-@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key)])
+@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
 async def unload_model():
     global model_container
 
-    if model_container is None:
-        raise HTTPException(400, "No models are loaded.")
-
     model_container.unload()
     model_container = None
 
 # Encode tokens endpoint
-@app.post("/v1/token/encode", dependencies=[Depends(check_api_key)])
+@app.post("/v1/token/encode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def encode_tokens(data: TokenEncodeRequest):
     tokens = model_container.get_tokens(data.text, None, **data.get_params())[0].tolist()
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
@@ -117,7 +115,7 @@ async def encode_tokens(data: TokenEncodeRequest):
     return response.json()
 
 # Decode tokens endpoint
-@app.post("/v1/token/decode", dependencies=[Depends(check_api_key)])
+@app.post("/v1/token/decode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def decode_tokens(data: TokenDecodeRequest):
     message = model_container.get_tokens(None, data.tokens, **data.get_params())
     response = TokenDecodeResponse(text=message)
@@ -125,7 +123,7 @@ async def decode_tokens(data: TokenDecodeRequest):
     return response.json()
 
 # Completions endpoint
-@app.post("/v1/completions", dependencies=[Depends(check_api_key)])
+@app.post("/v1/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def generate_completion(request: Request, data: CompletionRequest):
     model_path = model_container.get_model_path()
 
@@ -151,7 +149,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
     return response.json()
 
 # Chat completions endpoint
-@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
+@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
     model_path = model_container.get_model_path()
 