Tree: Use unwrap and coalesce for optional handling

Python doesn't have proper handling of optionals. The only way to
handle them is checking via an if statement if the value is None or
by using the "or" keyword to unwrap optionals.

Previously, I used the "or" method to unwrap, but this caused issues
due to falsy values falling back to the default. This is especially
the case with booleans were "False" changed to "True".

Instead, add two new functions: unwrap and coalesce. Both function
to properly implement a functional way of "None" coalescing.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2023-12-09 21:52:17 -05:00
parent 7380a3b79a
commit 5ae2a91c04
5 changed files with 83 additions and 68 deletions

46
main.py
View file

@ -28,7 +28,7 @@ from OAI.utils import (
create_chat_completion_stream_chunk
)
from typing import Optional
from utils import get_generator_error, get_sse_packet, load_progress
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
from uuid import uuid4
app = FastAPI()
@ -54,17 +54,17 @@ app.add_middleware(
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models():
model_config = config.get("model") or {}
model_config = unwrap(config.get("model"), {})
if "model_dir" in model_config:
model_path = pathlib.Path(model_config["model_dir"])
else:
model_path = pathlib.Path("models")
draft_config = model_config.get("draft") or {}
draft_config = unwrap(model_config.get("draft"), {})
draft_model_dir = draft_config.get("draft_model_dir")
models = get_model_list(model_path.resolve(), draft_model_dir)
if model_config.get("use_dummy_models") or False:
if unwrap(model_config.get("use_dummy_models"), False):
models.data.insert(0, ModelCard(id = "gpt-3.5-turbo"))
return models
@ -89,19 +89,19 @@ async def load_model(request: Request, data: ModelLoadRequest):
if not data.name:
raise HTTPException(400, "model_name not found.")
model_config = config.get("model") or {}
model_path = pathlib.Path(model_config.get("model_dir") or "models")
model_config = unwrap(config.get("model"), {})
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
model_path = model_path / data.name
load_data = data.dict()
# TODO: Add API exception if draft directory isn't found
draft_config = model_config.get("draft") or {}
draft_config = unwrap(model_config.get("draft"), {})
if data.draft:
if not data.draft.draft_model_name:
raise HTTPException(400, "draft_model_name was not found inside the draft object.")
load_data["draft"]["draft_model_dir"] = draft_config.get("draft_model_dir") or "models"
load_data["draft"]["draft_model_dir"] = unwrap(draft_config.get("draft_model_dir"), "models")
if not model_path.exists():
raise HTTPException(400, "model_path does not exist. Check model_name?")
@ -167,9 +167,9 @@ async def unload_model():
@app.get("/v1/loras", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def get_all_loras():
model_config = config.get("model") or {}
lora_config = model_config.get("lora") or {}
lora_path = pathlib.Path(lora_config.get("lora_dir") or "loras")
model_config = unwrap(config.get("model"), {})
lora_config = unwrap(model_config.get("lora"), {})
lora_path = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
loras = get_lora_list(lora_path.resolve())
@ -196,9 +196,9 @@ async def load_model(data: LoraLoadRequest):
if not data.loras:
raise HTTPException(400, "List of loras to load is not found.")
model_config = config.get("model") or {}
lora_config = model_config.get("lora") or {}
lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
model_config = unwrap(config.get("model"), {})
lora_config = unwrap(model_config.get("lora"), {})
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
if not lora_dir.exists():
raise HTTPException(400, "A parent lora directory does not exist. Check your config.yml?")
@ -208,8 +208,8 @@ async def load_model(data: LoraLoadRequest):
result = model_container.load_loras(lora_dir, **data.dict())
return LoraLoadResponse(
success = result.get("success") or [],
failure = result.get("failure") or []
success = unwrap(result.get("success"), []),
failure = unwrap(result.get("failure"), [])
)
# Unload lora endpoint
@ -234,7 +234,7 @@ async def encode_tokens(data: TokenEncodeRequest):
@app.post("/v1/token/decode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def decode_tokens(data: TokenDecodeRequest):
message = model_container.get_tokens(None, data.tokens, **data.get_params())
response = TokenDecodeResponse(text = message or "")
response = TokenDecodeResponse(text = unwrap(message, ""))
return response
@ -337,7 +337,7 @@ if __name__ == "__main__":
# Load from YAML config. Possibly add a config -> kwargs conversion function
try:
with open('config.yml', 'r', encoding = "utf8") as config_file:
config = yaml.safe_load(config_file) or {}
config = unwrap(yaml.safe_load(config_file), {})
except Exception as e:
print(
"The YAML config couldn't load because of the following error:",
@ -348,10 +348,10 @@ if __name__ == "__main__":
# If an initial model name is specified, create a container and load the model
model_config = config.get("model") or {}
model_config = unwrap(config.get("model"), {})
if "model_name" in model_config:
# TODO: Move this to model_container
model_path = pathlib.Path(model_config.get("model_dir") or "models")
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
model_path = model_path / model_config.get("model_name")
model_container = ModelContainer(model_path.resolve(), False, **model_config)
@ -366,12 +366,12 @@ if __name__ == "__main__":
loading_bar.next()
# Load loras
lora_config = model_config.get("lora") or {}
lora_config = unwrap(model_config.get("lora"), {})
if "loras" in lora_config:
lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
model_container.load_loras(lora_dir.resolve(), **lora_config)
network_config = config.get("network") or {}
network_config = unwrap(config.get("network"), {})
uvicorn.run(
app,
host=network_config.get("host", "127.0.0.1"),