API: Back to async

According to the FastAPI docs, declaring a route handler as an async
function makes it more performant (which makes sense, since plain def
route functions are automatically run through a threadpool by the
framework).

Tested and everything works fine.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-03-04 22:59:10 -05:00 committed by Brian Dashore
parent b0c295dd2f
commit d2c6ae2d35
2 changed files with 24 additions and 20 deletions

View file

@@ -76,7 +76,9 @@ def load_auth_keys(disable_from_config: bool):
)
def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)):
async def check_api_key(
x_api_key: str = Header(None), authorization: str = Header(None)
):
"""Check if the API key is valid."""
# Allow request if auth is disabled
@@ -102,7 +104,9 @@ def check_api_key(x_api_key: str = Header(None), authorization: str = Header(Non
raise HTTPException(401, "Please provide an API key")
def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)):
async def check_admin_key(
x_admin_key: str = Header(None), authorization: str = Header(None)
):
"""Check if the admin key is valid."""
# Allow request if auth is disabled

36
main.py
View file

@@ -92,7 +92,7 @@ app = FastAPI(
MODEL_CONTAINER: Optional[ExllamaV2Container] = None
def _check_model_container():
async def _check_model_container():
if MODEL_CONTAINER is None or not (
MODEL_CONTAINER.model_is_loading or MODEL_CONTAINER.model_loaded
):
@@ -116,7 +116,7 @@ app.add_middleware(
# Model list endpoint
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
def list_models():
async def list_models():
"""Lists all models in the model directory."""
model_config = get_model_config()
model_dir = unwrap(model_config.get("model_dir"), "models")
@@ -140,7 +140,7 @@ def list_models():
"/v1/internal/model/info",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
def get_current_model():
async def get_current_model():
"""Returns the currently loaded model."""
model_name = MODEL_CONTAINER.get_model_path().name
prompt_template = MODEL_CONTAINER.prompt_template
@@ -173,7 +173,7 @@ def get_current_model():
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
def list_draft_models():
async def list_draft_models():
"""Lists all draft models in the model directory."""
draft_model_dir = unwrap(get_draft_model_config().get("draft_model_dir"), "models")
draft_model_path = pathlib.Path(draft_model_dir)
@@ -225,7 +225,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
# Unload the existing model
if MODEL_CONTAINER and MODEL_CONTAINER.model:
unload_model()
await unload_model()
MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
@@ -235,7 +235,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
try:
for module, modules in load_status:
if await request.is_disconnected():
unload_model()
await unload_model()
break
if module == 0:
@@ -293,7 +293,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
"/v1/model/unload",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
def unload_model():
async def unload_model():
"""Unloads the currently loaded model."""
global MODEL_CONTAINER
@@ -303,7 +303,7 @@ def unload_model():
@app.get("/v1/templates", dependencies=[Depends(check_api_key)])
@app.get("/v1/template/list", dependencies=[Depends(check_api_key)])
def get_templates():
async def get_templates():
templates = get_all_templates()
template_strings = list(map(lambda template: template.stem, templates))
return TemplateList(data=template_strings)
@@ -313,7 +313,7 @@ def get_templates():
"/v1/template/switch",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
def switch_template(data: TemplateSwitchRequest):
async def switch_template(data: TemplateSwitchRequest):
"""Switch the currently loaded template"""
if not data.name:
raise HTTPException(400, "New template name not found.")
@@ -329,7 +329,7 @@ def switch_template(data: TemplateSwitchRequest):
"/v1/template/unload",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
def unload_template():
async def unload_template():
"""Unloads the currently selected template"""
MODEL_CONTAINER.prompt_template = None
@@ -338,7 +338,7 @@ def unload_template():
# Sampler override endpoints
@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
def list_sampler_overrides():
async def list_sampler_overrides():
"""API wrapper to list all currently applied sampler overrides"""
return get_sampler_overrides()
@@ -348,7 +348,7 @@ def list_sampler_overrides():
"/v1/sampling/override/switch",
dependencies=[Depends(check_admin_key)],
)
def switch_sampler_override(data: SamplerOverrideSwitchRequest):
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
"""Switch the currently loaded override preset"""
if data.preset:
@@ -370,7 +370,7 @@ def switch_sampler_override(data: SamplerOverrideSwitchRequest):
"/v1/sampling/override/unload",
dependencies=[Depends(check_admin_key)],
)
def unload_sampler_override():
async def unload_sampler_override():
"""Unloads the currently selected override preset"""
set_overrides_from_dict({})
@@ -379,7 +379,7 @@ def unload_sampler_override():
# Lora list endpoint
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
def get_all_loras():
async def get_all_loras():
"""Lists all LoRAs in the lora directory."""
lora_path = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
loras = get_lora_list(lora_path.resolve())
@@ -392,7 +392,7 @@ def get_all_loras():
"/v1/lora",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
def get_active_loras():
async def get_active_loras():
"""Returns the currently loaded loras."""
active_loras = LoraList(
data=list(
@@ -455,7 +455,7 @@ async def load_lora(data: LoraLoadRequest):
"/v1/lora/unload",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
def unload_loras():
async def unload_loras():
"""Unloads the currently loaded loras."""
MODEL_CONTAINER.unload(loras_only=True)
@@ -465,7 +465,7 @@ def unload_loras():
"/v1/token/encode",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
def encode_tokens(data: TokenEncodeRequest):
async def encode_tokens(data: TokenEncodeRequest):
"""Encodes a string into tokens."""
raw_tokens = MODEL_CONTAINER.encode_tokens(data.text, **data.get_params())
tokens = unwrap(raw_tokens, [])
@@ -479,7 +479,7 @@ def encode_tokens(data: TokenEncodeRequest):
"/v1/token/decode",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
def decode_tokens(data: TokenDecodeRequest):
async def decode_tokens(data: TokenDecodeRequest):
"""Decodes tokens into a string."""
message = MODEL_CONTAINER.decode_tokens(data.tokens, **data.get_params())
response = TokenDecodeResponse(text=unwrap(message, ""))