Config: Isolate to a separate file

Reduce dependency of globals in main to simplify code a bit.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2023-12-23 23:02:37 -05:00
parent 0d2e726e82
commit c9126c3145
2 changed files with 67 additions and 34 deletions

55
main.py
View file

@ -1,7 +1,6 @@
"""The main tabbyAPI module. Contains the FastAPI server and endpoints."""
import pathlib
import uvicorn
import yaml
from asyncio import CancelledError
from typing import Optional
from uuid import uuid4
@ -14,6 +13,14 @@ from progress.bar import IncrementalBar
import gen_logging
from auth import check_admin_key, check_api_key, load_auth_keys
from config import (
read_config_from_file,
get_gen_logging_config,
get_model_config,
get_draft_model_config,
get_lora_config,
get_network_config
)
from generators import call_with_semaphore, generate_with_semaphore
from model import ModelContainer
from OAI.types.completion import CompletionRequest
@ -48,7 +55,6 @@ app = FastAPI()
# Globally scoped variables. Undefined until initalized in main
MODEL_CONTAINER: Optional[ModelContainer] = None
config: dict = {}
def _check_model_container():
@ -71,12 +77,11 @@ app.add_middleware(
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models():
"""Lists all models in the model directory."""
model_config = unwrap(config.get("model"), {})
model_config = get_model_config()
model_dir = unwrap(model_config.get("model_dir"), "models")
model_path = pathlib.Path(model_dir)
draft_config = unwrap(model_config.get("draft"), {})
draft_model_dir = draft_config.get("draft_model_dir")
draft_model_dir = get_draft_model_config().get("draft_model_dir")
models = get_model_list(model_path.resolve(), draft_model_dir)
if unwrap(model_config.get("use_dummy_models"), False):
@ -127,9 +132,7 @@ async def get_current_model():
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models():
"""Lists all draft models in the model directory."""
model_config = unwrap(config.get("model"), {})
draft_config = unwrap(model_config.get("draft"), {})
draft_model_dir = unwrap(draft_config.get("draft_model_dir"), "models")
draft_model_dir = unwrap(get_draft_model_config().get("draft_model_dir"), "models")
draft_model_path = pathlib.Path(draft_model_dir)
models = get_model_list(draft_model_path.resolve())
@ -149,13 +152,11 @@ async def load_model(request: Request, data: ModelLoadRequest):
if not data.name:
raise HTTPException(400, "model_name not found.")
model_config = unwrap(config.get("model"), {})
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
model_path = pathlib.Path(unwrap(get_model_config().get("model_dir"), "models"))
model_path = model_path / data.name
load_data = data.model_dump()
draft_config = unwrap(model_config.get("draft"), {})
if data.draft:
if not data.draft.draft_model_name:
raise HTTPException(
@ -163,7 +164,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
)
load_data["draft"]["draft_model_dir"] = unwrap(
draft_config.get("draft_model_dir"), "models"
get_draft_model_config().get("draft_model_dir"), "models"
)
if not model_path.exists():
@ -240,10 +241,7 @@ async def unload_model():
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def get_all_loras():
"""Lists all LoRAs in the lora directory."""
model_config = unwrap(config.get("model"), {})
lora_config = unwrap(model_config.get("lora"), {})
lora_path = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
lora_path = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
loras = get_lora_list(lora_path.resolve())
return loras
@ -281,9 +279,7 @@ async def load_lora(data: LoraLoadRequest):
if not data.loras:
raise HTTPException(400, "List of loras to load is not found.")
model_config = unwrap(config.get("model"), {})
lora_config = unwrap(model_config.get("lora"), {})
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
lora_dir = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
if not lora_dir.exists():
raise HTTPException(
400,
@ -478,25 +474,16 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
if __name__ == "__main__":
# Load from YAML config. Possibly add a config -> kwargs conversion function
try:
with open("config.yml", "r", encoding="utf8") as config_file:
config = unwrap(yaml.safe_load(config_file), {})
except Exception as exc:
logger.error(
"The YAML config couldn't load because of the following error: "
f"\n\n{exc}"
"\n\nTabbyAPI will start anyway and not parse this config file."
)
config = {}
# Load from YAML config
read_config_from_file(pathlib.Path("config.yml"))
network_config = unwrap(config.get("network"), {})
network_config = get_network_config()
# Initialize auth keys
load_auth_keys(unwrap(network_config.get("disable_auth"), False))
# Override the generation log options if given
log_config = unwrap(config.get("logging"), {})
log_config = get_gen_logging_config()
if log_config:
gen_logging.update_from_dict(log_config)
@ -504,7 +491,7 @@ if __name__ == "__main__":
# If an initial model name is specified, create a container
# and load the model
model_config = unwrap(config.get("model"), {})
model_config = get_model_config()
if "model_name" in model_config:
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
model_path = model_path / model_config.get("model_name")
@ -521,7 +508,7 @@ if __name__ == "__main__":
loading_bar.next()
# Load loras
lora_config = unwrap(model_config.get("lora"), {})
lora_config = get_lora_config()
if "loras" in lora_config:
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
MODEL_CONTAINER.load_loras(lora_dir.resolve(), **lora_config)