Tree: Use unwrap and coalesce for optional handling
Python doesn't have proper handling of optionals. The only way to handle them is checking via an if statement if the value is None or by using the "or" keyword to unwrap optionals. Previously, I used the "or" method to unwrap, but this caused issues due to falsy values falling back to the default. This is especially the case with booleans were "False" changed to "True". Instead, add two new functions: unwrap and coalesce. Both function to properly implement a functional way of "None" coalescing. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
7380a3b79a
commit
5ae2a91c04
5 changed files with 83 additions and 68 deletions
46
main.py
46
main.py
|
|
@ -28,7 +28,7 @@ from OAI.utils import (
|
|||
create_chat_completion_stream_chunk
|
||||
)
|
||||
from typing import Optional
|
||||
from utils import get_generator_error, get_sse_packet, load_progress
|
||||
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
|
||||
from uuid import uuid4
|
||||
|
||||
app = FastAPI()
|
||||
|
|
@ -54,17 +54,17 @@ app.add_middleware(
|
|||
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
|
||||
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_models():
|
||||
model_config = config.get("model") or {}
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
if "model_dir" in model_config:
|
||||
model_path = pathlib.Path(model_config["model_dir"])
|
||||
else:
|
||||
model_path = pathlib.Path("models")
|
||||
|
||||
draft_config = model_config.get("draft") or {}
|
||||
draft_config = unwrap(model_config.get("draft"), {})
|
||||
draft_model_dir = draft_config.get("draft_model_dir")
|
||||
|
||||
models = get_model_list(model_path.resolve(), draft_model_dir)
|
||||
if model_config.get("use_dummy_models") or False:
|
||||
if unwrap(model_config.get("use_dummy_models"), False):
|
||||
models.data.insert(0, ModelCard(id = "gpt-3.5-turbo"))
|
||||
|
||||
return models
|
||||
|
|
@ -89,19 +89,19 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
if not data.name:
|
||||
raise HTTPException(400, "model_name not found.")
|
||||
|
||||
model_config = config.get("model") or {}
|
||||
model_path = pathlib.Path(model_config.get("model_dir") or "models")
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
|
||||
model_path = model_path / data.name
|
||||
|
||||
load_data = data.dict()
|
||||
|
||||
# TODO: Add API exception if draft directory isn't found
|
||||
draft_config = model_config.get("draft") or {}
|
||||
draft_config = unwrap(model_config.get("draft"), {})
|
||||
if data.draft:
|
||||
if not data.draft.draft_model_name:
|
||||
raise HTTPException(400, "draft_model_name was not found inside the draft object.")
|
||||
|
||||
load_data["draft"]["draft_model_dir"] = draft_config.get("draft_model_dir") or "models"
|
||||
load_data["draft"]["draft_model_dir"] = unwrap(draft_config.get("draft_model_dir"), "models")
|
||||
|
||||
if not model_path.exists():
|
||||
raise HTTPException(400, "model_path does not exist. Check model_name?")
|
||||
|
|
@ -167,9 +167,9 @@ async def unload_model():
|
|||
@app.get("/v1/loras", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
||||
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
||||
async def get_all_loras():
|
||||
model_config = config.get("model") or {}
|
||||
lora_config = model_config.get("lora") or {}
|
||||
lora_path = pathlib.Path(lora_config.get("lora_dir") or "loras")
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
lora_config = unwrap(model_config.get("lora"), {})
|
||||
lora_path = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
||||
|
||||
loras = get_lora_list(lora_path.resolve())
|
||||
|
||||
|
|
@ -196,9 +196,9 @@ async def load_model(data: LoraLoadRequest):
|
|||
if not data.loras:
|
||||
raise HTTPException(400, "List of loras to load is not found.")
|
||||
|
||||
model_config = config.get("model") or {}
|
||||
lora_config = model_config.get("lora") or {}
|
||||
lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
lora_config = unwrap(model_config.get("lora"), {})
|
||||
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
||||
if not lora_dir.exists():
|
||||
raise HTTPException(400, "A parent lora directory does not exist. Check your config.yml?")
|
||||
|
||||
|
|
@ -208,8 +208,8 @@ async def load_model(data: LoraLoadRequest):
|
|||
|
||||
result = model_container.load_loras(lora_dir, **data.dict())
|
||||
return LoraLoadResponse(
|
||||
success = result.get("success") or [],
|
||||
failure = result.get("failure") or []
|
||||
success = unwrap(result.get("success"), []),
|
||||
failure = unwrap(result.get("failure"), [])
|
||||
)
|
||||
|
||||
# Unload lora endpoint
|
||||
|
|
@ -234,7 +234,7 @@ async def encode_tokens(data: TokenEncodeRequest):
|
|||
@app.post("/v1/token/decode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
||||
async def decode_tokens(data: TokenDecodeRequest):
|
||||
message = model_container.get_tokens(None, data.tokens, **data.get_params())
|
||||
response = TokenDecodeResponse(text = message or "")
|
||||
response = TokenDecodeResponse(text = unwrap(message, ""))
|
||||
|
||||
return response
|
||||
|
||||
|
|
@ -337,7 +337,7 @@ if __name__ == "__main__":
|
|||
# Load from YAML config. Possibly add a config -> kwargs conversion function
|
||||
try:
|
||||
with open('config.yml', 'r', encoding = "utf8") as config_file:
|
||||
config = yaml.safe_load(config_file) or {}
|
||||
config = unwrap(yaml.safe_load(config_file), {})
|
||||
except Exception as e:
|
||||
print(
|
||||
"The YAML config couldn't load because of the following error:",
|
||||
|
|
@ -348,10 +348,10 @@ if __name__ == "__main__":
|
|||
|
||||
# If an initial model name is specified, create a container and load the model
|
||||
|
||||
model_config = config.get("model") or {}
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
if "model_name" in model_config:
|
||||
# TODO: Move this to model_container
|
||||
model_path = pathlib.Path(model_config.get("model_dir") or "models")
|
||||
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
|
||||
model_path = model_path / model_config.get("model_name")
|
||||
|
||||
model_container = ModelContainer(model_path.resolve(), False, **model_config)
|
||||
|
|
@ -366,12 +366,12 @@ if __name__ == "__main__":
|
|||
loading_bar.next()
|
||||
|
||||
# Load loras
|
||||
lora_config = model_config.get("lora") or {}
|
||||
lora_config = unwrap(model_config.get("lora"), {})
|
||||
if "loras" in lora_config:
|
||||
lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
|
||||
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
||||
model_container.load_loras(lora_dir.resolve(), **lora_config)
|
||||
|
||||
network_config = config.get("network") or {}
|
||||
network_config = unwrap(config.get("network"), {})
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=network_config.get("host", "127.0.0.1"),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue