Config: Isolate to a separate file
Reduce the dependency on globals in main to simplify the code a bit. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
0d2e726e82
commit
c9126c3145
2 changed files with 67 additions and 34 deletions
55
main.py
55
main.py
|
|
@ -1,7 +1,6 @@
|
|||
"""The main tabbyAPI module. Contains the FastAPI server and endpoints."""
|
||||
import pathlib
|
||||
import uvicorn
|
||||
import yaml
|
||||
from asyncio import CancelledError
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
|
@ -14,6 +13,14 @@ from progress.bar import IncrementalBar
|
|||
|
||||
import gen_logging
|
||||
from auth import check_admin_key, check_api_key, load_auth_keys
|
||||
from config import (
|
||||
read_config_from_file,
|
||||
get_gen_logging_config,
|
||||
get_model_config,
|
||||
get_draft_model_config,
|
||||
get_lora_config,
|
||||
get_network_config
|
||||
)
|
||||
from generators import call_with_semaphore, generate_with_semaphore
|
||||
from model import ModelContainer
|
||||
from OAI.types.completion import CompletionRequest
|
||||
|
|
@ -48,7 +55,6 @@ app = FastAPI()
|
|||
|
||||
# Globally scoped variables. Undefined until initialized in main
|
||||
MODEL_CONTAINER: Optional[ModelContainer] = None
|
||||
config: dict = {}
|
||||
|
||||
|
||||
def _check_model_container():
|
||||
|
|
@ -71,12 +77,11 @@ app.add_middleware(
|
|||
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_models():
|
||||
"""Lists all models in the model directory."""
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
model_config = get_model_config()
|
||||
model_dir = unwrap(model_config.get("model_dir"), "models")
|
||||
model_path = pathlib.Path(model_dir)
|
||||
|
||||
draft_config = unwrap(model_config.get("draft"), {})
|
||||
draft_model_dir = draft_config.get("draft_model_dir")
|
||||
draft_model_dir = get_draft_model_config().get("draft_model_dir")
|
||||
|
||||
models = get_model_list(model_path.resolve(), draft_model_dir)
|
||||
if unwrap(model_config.get("use_dummy_models"), False):
|
||||
|
|
@ -127,9 +132,7 @@ async def get_current_model():
|
|||
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_draft_models():
|
||||
"""Lists all draft models in the model directory."""
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
draft_config = unwrap(model_config.get("draft"), {})
|
||||
draft_model_dir = unwrap(draft_config.get("draft_model_dir"), "models")
|
||||
draft_model_dir = unwrap(get_draft_model_config().get("draft_model_dir"), "models")
|
||||
draft_model_path = pathlib.Path(draft_model_dir)
|
||||
|
||||
models = get_model_list(draft_model_path.resolve())
|
||||
|
|
@ -149,13 +152,11 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
if not data.name:
|
||||
raise HTTPException(400, "model_name not found.")
|
||||
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
|
||||
model_path = pathlib.Path(unwrap(get_model_config().get("model_dir"), "models"))
|
||||
model_path = model_path / data.name
|
||||
|
||||
load_data = data.model_dump()
|
||||
|
||||
draft_config = unwrap(model_config.get("draft"), {})
|
||||
if data.draft:
|
||||
if not data.draft.draft_model_name:
|
||||
raise HTTPException(
|
||||
|
|
@ -163,7 +164,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
)
|
||||
|
||||
load_data["draft"]["draft_model_dir"] = unwrap(
|
||||
draft_config.get("draft_model_dir"), "models"
|
||||
get_draft_model_config().get("draft_model_dir"), "models"
|
||||
)
|
||||
|
||||
if not model_path.exists():
|
||||
|
|
@ -240,10 +241,7 @@ async def unload_model():
|
|||
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
||||
async def get_all_loras():
|
||||
"""Lists all LoRAs in the lora directory."""
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
lora_config = unwrap(model_config.get("lora"), {})
|
||||
lora_path = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
||||
|
||||
lora_path = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
|
||||
loras = get_lora_list(lora_path.resolve())
|
||||
|
||||
return loras
|
||||
|
|
@ -281,9 +279,7 @@ async def load_lora(data: LoraLoadRequest):
|
|||
if not data.loras:
|
||||
raise HTTPException(400, "List of loras to load is not found.")
|
||||
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
lora_config = unwrap(model_config.get("lora"), {})
|
||||
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
||||
lora_dir = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
|
||||
if not lora_dir.exists():
|
||||
raise HTTPException(
|
||||
400,
|
||||
|
|
@ -478,25 +474,16 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Load from YAML config. Possibly add a config -> kwargs conversion function
|
||||
try:
|
||||
with open("config.yml", "r", encoding="utf8") as config_file:
|
||||
config = unwrap(yaml.safe_load(config_file), {})
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"The YAML config couldn't load because of the following error: "
|
||||
f"\n\n{exc}"
|
||||
"\n\nTabbyAPI will start anyway and not parse this config file."
|
||||
)
|
||||
config = {}
|
||||
# Load from YAML config
|
||||
read_config_from_file(pathlib.Path("config.yml"))
|
||||
|
||||
network_config = unwrap(config.get("network"), {})
|
||||
network_config = get_network_config()
|
||||
|
||||
# Initialize auth keys
|
||||
load_auth_keys(unwrap(network_config.get("disable_auth"), False))
|
||||
|
||||
# Override the generation log options if given
|
||||
log_config = unwrap(config.get("logging"), {})
|
||||
log_config = get_gen_logging_config()
|
||||
if log_config:
|
||||
gen_logging.update_from_dict(log_config)
|
||||
|
||||
|
|
@ -504,7 +491,7 @@ if __name__ == "__main__":
|
|||
|
||||
# If an initial model name is specified, create a container
|
||||
# and load the model
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
model_config = get_model_config()
|
||||
if "model_name" in model_config:
|
||||
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
|
||||
model_path = model_path / model_config.get("model_name")
|
||||
|
|
@ -521,7 +508,7 @@ if __name__ == "__main__":
|
|||
loading_bar.next()
|
||||
|
||||
# Load loras
|
||||
lora_config = unwrap(model_config.get("lora"), {})
|
||||
lora_config = get_lora_config()
|
||||
if "loras" in lora_config:
|
||||
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
||||
MODEL_CONTAINER.load_loras(lora_dir.resolve(), **lora_config)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue