API: Back to async
According to FastAPI docs, if you're using a generic function, running it in async will make it more performant (which makes sense since running def functions for routes will automatically run the caller through a threadpool). Tested and everything works fine. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
b0c295dd2f
commit
d2c6ae2d35
2 changed files with 24 additions and 20 deletions
|
|
@ -76,7 +76,9 @@ def load_auth_keys(disable_from_config: bool):
|
|||
)
|
||||
|
||||
|
||||
def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)):
|
||||
async def check_api_key(
|
||||
x_api_key: str = Header(None), authorization: str = Header(None)
|
||||
):
|
||||
"""Check if the API key is valid."""
|
||||
|
||||
# Allow request if auth is disabled
|
||||
|
|
@ -102,7 +104,9 @@ def check_api_key(x_api_key: str = Header(None), authorization: str = Header(Non
|
|||
raise HTTPException(401, "Please provide an API key")
|
||||
|
||||
|
||||
def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)):
|
||||
async def check_admin_key(
|
||||
x_admin_key: str = Header(None), authorization: str = Header(None)
|
||||
):
|
||||
"""Check if the admin key is valid."""
|
||||
|
||||
# Allow request if auth is disabled
|
||||
|
|
|
|||
36
main.py
36
main.py
|
|
@ -92,7 +92,7 @@ app = FastAPI(
|
|||
MODEL_CONTAINER: Optional[ExllamaV2Container] = None
|
||||
|
||||
|
||||
def _check_model_container():
|
||||
async def _check_model_container():
|
||||
if MODEL_CONTAINER is None or not (
|
||||
MODEL_CONTAINER.model_is_loading or MODEL_CONTAINER.model_loaded
|
||||
):
|
||||
|
|
@ -116,7 +116,7 @@ app.add_middleware(
|
|||
# Model list endpoint
|
||||
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
|
||||
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
||||
def list_models():
|
||||
async def list_models():
|
||||
"""Lists all models in the model directory."""
|
||||
model_config = get_model_config()
|
||||
model_dir = unwrap(model_config.get("model_dir"), "models")
|
||||
|
|
@ -140,7 +140,7 @@ def list_models():
|
|||
"/v1/internal/model/info",
|
||||
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
||||
)
|
||||
def get_current_model():
|
||||
async def get_current_model():
|
||||
"""Returns the currently loaded model."""
|
||||
model_name = MODEL_CONTAINER.get_model_path().name
|
||||
prompt_template = MODEL_CONTAINER.prompt_template
|
||||
|
|
@ -173,7 +173,7 @@ def get_current_model():
|
|||
|
||||
|
||||
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
|
||||
def list_draft_models():
|
||||
async def list_draft_models():
|
||||
"""Lists all draft models in the model directory."""
|
||||
draft_model_dir = unwrap(get_draft_model_config().get("draft_model_dir"), "models")
|
||||
draft_model_path = pathlib.Path(draft_model_dir)
|
||||
|
|
@ -225,7 +225,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
|
||||
# Unload the existing model
|
||||
if MODEL_CONTAINER and MODEL_CONTAINER.model:
|
||||
unload_model()
|
||||
await unload_model()
|
||||
|
||||
MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
|
||||
|
||||
|
|
@ -235,7 +235,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
try:
|
||||
for module, modules in load_status:
|
||||
if await request.is_disconnected():
|
||||
unload_model()
|
||||
await unload_model()
|
||||
break
|
||||
|
||||
if module == 0:
|
||||
|
|
@ -293,7 +293,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
"/v1/model/unload",
|
||||
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
||||
)
|
||||
def unload_model():
|
||||
async def unload_model():
|
||||
"""Unloads the currently loaded model."""
|
||||
global MODEL_CONTAINER
|
||||
|
||||
|
|
@ -303,7 +303,7 @@ def unload_model():
|
|||
|
||||
@app.get("/v1/templates", dependencies=[Depends(check_api_key)])
|
||||
@app.get("/v1/template/list", dependencies=[Depends(check_api_key)])
|
||||
def get_templates():
|
||||
async def get_templates():
|
||||
templates = get_all_templates()
|
||||
template_strings = list(map(lambda template: template.stem, templates))
|
||||
return TemplateList(data=template_strings)
|
||||
|
|
@ -313,7 +313,7 @@ def get_templates():
|
|||
"/v1/template/switch",
|
||||
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
||||
)
|
||||
def switch_template(data: TemplateSwitchRequest):
|
||||
async def switch_template(data: TemplateSwitchRequest):
|
||||
"""Switch the currently loaded template"""
|
||||
if not data.name:
|
||||
raise HTTPException(400, "New template name not found.")
|
||||
|
|
@ -329,7 +329,7 @@ def switch_template(data: TemplateSwitchRequest):
|
|||
"/v1/template/unload",
|
||||
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
||||
)
|
||||
def unload_template():
|
||||
async def unload_template():
|
||||
"""Unloads the currently selected template"""
|
||||
|
||||
MODEL_CONTAINER.prompt_template = None
|
||||
|
|
@ -338,7 +338,7 @@ def unload_template():
|
|||
# Sampler override endpoints
|
||||
@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
|
||||
@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
|
||||
def list_sampler_overrides():
|
||||
async def list_sampler_overrides():
|
||||
"""API wrapper to list all currently applied sampler overrides"""
|
||||
|
||||
return get_sampler_overrides()
|
||||
|
|
@ -348,7 +348,7 @@ def list_sampler_overrides():
|
|||
"/v1/sampling/override/switch",
|
||||
dependencies=[Depends(check_admin_key)],
|
||||
)
|
||||
def switch_sampler_override(data: SamplerOverrideSwitchRequest):
|
||||
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
|
||||
"""Switch the currently loaded override preset"""
|
||||
|
||||
if data.preset:
|
||||
|
|
@ -370,7 +370,7 @@ def switch_sampler_override(data: SamplerOverrideSwitchRequest):
|
|||
"/v1/sampling/override/unload",
|
||||
dependencies=[Depends(check_admin_key)],
|
||||
)
|
||||
def unload_sampler_override():
|
||||
async def unload_sampler_override():
|
||||
"""Unloads the currently selected override preset"""
|
||||
|
||||
set_overrides_from_dict({})
|
||||
|
|
@ -379,7 +379,7 @@ def unload_sampler_override():
|
|||
# Lora list endpoint
|
||||
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
|
||||
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
||||
def get_all_loras():
|
||||
async def get_all_loras():
|
||||
"""Lists all LoRAs in the lora directory."""
|
||||
lora_path = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
|
||||
loras = get_lora_list(lora_path.resolve())
|
||||
|
|
@ -392,7 +392,7 @@ def get_all_loras():
|
|||
"/v1/lora",
|
||||
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
||||
)
|
||||
def get_active_loras():
|
||||
async def get_active_loras():
|
||||
"""Returns the currently loaded loras."""
|
||||
active_loras = LoraList(
|
||||
data=list(
|
||||
|
|
@ -455,7 +455,7 @@ async def load_lora(data: LoraLoadRequest):
|
|||
"/v1/lora/unload",
|
||||
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
||||
)
|
||||
def unload_loras():
|
||||
async def unload_loras():
|
||||
"""Unloads the currently loaded loras."""
|
||||
MODEL_CONTAINER.unload(loras_only=True)
|
||||
|
||||
|
|
@ -465,7 +465,7 @@ def unload_loras():
|
|||
"/v1/token/encode",
|
||||
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
||||
)
|
||||
def encode_tokens(data: TokenEncodeRequest):
|
||||
async def encode_tokens(data: TokenEncodeRequest):
|
||||
"""Encodes a string into tokens."""
|
||||
raw_tokens = MODEL_CONTAINER.encode_tokens(data.text, **data.get_params())
|
||||
tokens = unwrap(raw_tokens, [])
|
||||
|
|
@ -479,7 +479,7 @@ def encode_tokens(data: TokenEncodeRequest):
|
|||
"/v1/token/decode",
|
||||
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
||||
)
|
||||
def decode_tokens(data: TokenDecodeRequest):
|
||||
async def decode_tokens(data: TokenDecodeRequest):
|
||||
"""Decodes tokens into a string."""
|
||||
message = MODEL_CONTAINER.decode_tokens(data.tokens, **data.get_params())
|
||||
response = TokenDecodeResponse(text=unwrap(message, ""))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue