Tree: Use unwrap and coalesce for optional handling

Python doesn't have proper handling of optionals. The only ways to
handle them are checking whether a value is None with an if statement
or using the "or" keyword to unwrap it.

Previously, I used the "or" method to unwrap, but this caused issues
because any falsy value fell back to the default. This is especially
a problem with booleans, where an explicit "False" was changed to
"True".
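
For example (a minimal sketch; the dict and key are made up for
illustration, not taken from the codebase):

    settings = {"stream": False}

    # "or" falls back on any falsy value, not just None, so an
    # explicit False is silently replaced with the default:
    stream = settings.get("stream") or True   # -> True, not False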

Instead, add two new functions, unwrap and coalesce, which properly
implement a functional way of "None" coalescing.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2023-12-09 21:52:17 -05:00
parent 7380a3b79a
commit 5ae2a91c04
5 changed files with 83 additions and 68 deletions

@@ -1,5 +1,6 @@
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Union
from utils import coalesce
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
@@ -83,7 +84,7 @@ class CommonCompletionRequest(BaseModel):
"min_p": self.min_p,
"tfs": self.tfs,
"repetition_penalty": self.repetition_penalty,
"repetition_range": self.repetition_range or self.repetition_penalty_range or -1,
"repetition_range": coalesce(self.repetition_range, self.repetition_penalty_range, -1),
"repetition_decay": self.repetition_decay,
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,

@ -12,6 +12,7 @@ from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard
from packaging import version
from typing import Optional, List, Dict
from utils import unwrap
# Check fastchat
try:
@@ -30,7 +31,7 @@ def create_completion_response(text: str, prompt_tokens: int, completion_tokens:
response = CompletionResponse(
choices = [choice],
model = model_name or "",
model = unwrap(model_name, ""),
usage = UsageStats(prompt_tokens = prompt_tokens,
completion_tokens = completion_tokens,
total_tokens = prompt_tokens + completion_tokens)
@@ -51,7 +52,7 @@ def create_chat_completion_response(text: str, prompt_tokens: int, completion_to
response = ChatCompletionResponse(
choices = [choice],
model = model_name or "",
model = unwrap(model_name, ""),
usage = UsageStats(prompt_tokens = prompt_tokens,
completion_tokens = completion_tokens,
total_tokens = prompt_tokens + completion_tokens)
@@ -80,7 +81,7 @@ def create_chat_completion_stream_chunk(const_id: str,
chunk = ChatCompletionStreamChunk(
id = const_id,
choices = [choice],
model = model_name or ""
model = unwrap(model_name, "")
)
return chunk

main.py

@@ -28,7 +28,7 @@ from OAI.utils import (
create_chat_completion_stream_chunk
)
from typing import Optional
from utils import get_generator_error, get_sse_packet, load_progress
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
from uuid import uuid4
app = FastAPI()
@@ -54,17 +54,17 @@ app.add_middleware(
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models():
model_config = config.get("model") or {}
model_config = unwrap(config.get("model"), {})
if "model_dir" in model_config:
model_path = pathlib.Path(model_config["model_dir"])
else:
model_path = pathlib.Path("models")
draft_config = model_config.get("draft") or {}
draft_config = unwrap(model_config.get("draft"), {})
draft_model_dir = draft_config.get("draft_model_dir")
models = get_model_list(model_path.resolve(), draft_model_dir)
if model_config.get("use_dummy_models") or False:
if unwrap(model_config.get("use_dummy_models"), False):
models.data.insert(0, ModelCard(id = "gpt-3.5-turbo"))
return models
@@ -89,19 +89,19 @@ async def load_model(request: Request, data: ModelLoadRequest):
if not data.name:
raise HTTPException(400, "model_name not found.")
model_config = config.get("model") or {}
model_path = pathlib.Path(model_config.get("model_dir") or "models")
model_config = unwrap(config.get("model"), {})
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
model_path = model_path / data.name
load_data = data.dict()
# TODO: Add API exception if draft directory isn't found
draft_config = model_config.get("draft") or {}
draft_config = unwrap(model_config.get("draft"), {})
if data.draft:
if not data.draft.draft_model_name:
raise HTTPException(400, "draft_model_name was not found inside the draft object.")
load_data["draft"]["draft_model_dir"] = draft_config.get("draft_model_dir") or "models"
load_data["draft"]["draft_model_dir"] = unwrap(draft_config.get("draft_model_dir"), "models")
if not model_path.exists():
raise HTTPException(400, "model_path does not exist. Check model_name?")
@@ -167,9 +167,9 @@ async def unload_model():
@app.get("/v1/loras", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def get_all_loras():
model_config = config.get("model") or {}
lora_config = model_config.get("lora") or {}
lora_path = pathlib.Path(lora_config.get("lora_dir") or "loras")
model_config = unwrap(config.get("model"), {})
lora_config = unwrap(model_config.get("lora"), {})
lora_path = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
loras = get_lora_list(lora_path.resolve())
@@ -196,9 +196,9 @@ async def load_model(data: LoraLoadRequest):
if not data.loras:
raise HTTPException(400, "List of loras to load is not found.")
model_config = config.get("model") or {}
lora_config = model_config.get("lora") or {}
lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
model_config = unwrap(config.get("model"), {})
lora_config = unwrap(model_config.get("lora"), {})
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
if not lora_dir.exists():
raise HTTPException(400, "A parent lora directory does not exist. Check your config.yml?")
@@ -208,8 +208,8 @@ async def load_model(data: LoraLoadRequest):
result = model_container.load_loras(lora_dir, **data.dict())
return LoraLoadResponse(
success = result.get("success") or [],
failure = result.get("failure") or []
success = unwrap(result.get("success"), []),
failure = unwrap(result.get("failure"), [])
)
# Unload lora endpoint
@@ -234,7 +234,7 @@ async def encode_tokens(data: TokenEncodeRequest):
@app.post("/v1/token/decode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def decode_tokens(data: TokenDecodeRequest):
message = model_container.get_tokens(None, data.tokens, **data.get_params())
response = TokenDecodeResponse(text = message or "")
response = TokenDecodeResponse(text = unwrap(message, ""))
return response
@@ -337,7 +337,7 @@ if __name__ == "__main__":
# Load from YAML config. Possibly add a config -> kwargs conversion function
try:
with open('config.yml', 'r', encoding = "utf8") as config_file:
config = yaml.safe_load(config_file) or {}
config = unwrap(yaml.safe_load(config_file), {})
except Exception as e:
print(
"The YAML config couldn't load because of the following error:",
@@ -348,10 +348,10 @@ if __name__ == "__main__":
# If an initial model name is specified, create a container and load the model
model_config = config.get("model") or {}
model_config = unwrap(config.get("model"), {})
if "model_name" in model_config:
# TODO: Move this to model_container
model_path = pathlib.Path(model_config.get("model_dir") or "models")
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
model_path = model_path / model_config.get("model_name")
model_container = ModelContainer(model_path.resolve(), False, **model_config)
@@ -366,12 +366,12 @@ if __name__ == "__main__":
loading_bar.next()
# Load loras
lora_config = model_config.get("lora") or {}
lora_config = unwrap(model_config.get("lora"), {})
if "loras" in lora_config:
lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
model_container.load_loras(lora_dir.resolve(), **lora_config)
network_config = config.get("network") or {}
network_config = unwrap(config.get("network"), {})
uvicorn.run(
app,
host=network_config.get("host", "127.0.0.1"),

@@ -13,6 +13,7 @@ from exllamav2.generator import(
ExLlamaV2Sampler
)
from typing import List, Optional, Union
from utils import coalesce, unwrap
# Bytes to reserve on first device when loading with auto split
auto_split_reserve_bytes = 96 * 1024**2
@@ -30,7 +31,7 @@ class ModelContainer:
cache_fp8: bool = False
gpu_split_auto: bool = True
gpu_split: list or None = None
gpu_split: Optional[list] = None
active_loras: List[ExLlamaV2Lora] = []
@@ -68,7 +69,7 @@ class ModelContainer:
self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8"
self.gpu_split = kwargs.get("gpu_split")
self.gpu_split_auto = kwargs.get("gpu_split_auto") or True
self.gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
self.config = ExLlamaV2Config()
self.config.model_dir = str(model_directory.resolve())
@@ -78,14 +79,14 @@ class ModelContainer:
base_seq_len = self.config.max_seq_len
# Then override the max_seq_len if present
self.config.max_seq_len = kwargs.get("max_seq_len") or 4096
self.config.scale_pos_emb = kwargs.get("rope_scale") or 1.0
self.config.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0)
# Automatically calculate rope alpha
self.config.scale_alpha_value = kwargs.get("rope_alpha") or self.calculate_rope_alpha(base_seq_len)
self.config.scale_alpha_value = unwrap(kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len))
# Turn off flash attention?
self.config.no_flash_attn = kwargs.get("no_flash_attn") or False
self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attn"), False)
# low_mem is currently broken in exllamav2. Don't use it until it's fixed.
"""
@@ -93,11 +94,11 @@ class ModelContainer:
self.config.set_low_mem()
"""
chunk_size = min(kwargs.get("chunk_size") or 2048, self.config.max_seq_len)
chunk_size = min(unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len)
self.config.max_input_len = chunk_size
self.config.max_attn_size = chunk_size ** 2
draft_args = kwargs.get("draft") or {}
draft_args = unwrap(kwargs.get("draft"), {})
draft_model_name = draft_args.get("draft_model_name")
enable_draft = draft_args and draft_model_name
@@ -109,14 +110,14 @@ class ModelContainer:
if enable_draft:
self.draft_config = ExLlamaV2Config()
draft_model_path = pathlib.Path(draft_args.get("draft_model_dir") or "models")
draft_model_path = pathlib.Path(unwrap(draft_args.get("draft_model_dir"), "models"))
draft_model_path = draft_model_path / draft_model_name
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()
self.draft_config.scale_pos_emb = draft_args.get("draft_rope_scale") or 1.0
self.draft_config.scale_alpha_value = draft_args.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len)
self.draft_config.scale_pos_emb = unwrap(draft_args.get("draft_rope_scale"), 1.0)
self.draft_config.scale_alpha_value = unwrap(draft_args.get("draft_rope_alpha"), self.calculate_rope_alpha(self.draft_config.max_seq_len))
self.draft_config.max_seq_len = self.config.max_seq_len
if "chunk_size" in kwargs:
@@ -151,13 +152,13 @@ class ModelContainer:
Load loras
"""
loras = kwargs.get("loras") or []
loras = unwrap(kwargs.get("loras"), [])
success: List[str] = []
failure: List[str] = []
for lora in loras:
lora_name = lora.get("name") or None
lora_scaling = lora.get("scaling") or 1.0
lora_name = lora.get("name")
lora_scaling = unwrap(lora.get("scaling"), 1.0)
if lora_name is None:
print("One of your loras does not have a name. Please check your config.yml! Skipping lora load.")
@@ -265,13 +266,13 @@ class ModelContainer:
# Assume token encoding
return self.tokenizer.encode(
text,
add_bos = kwargs.get("add_bos_token") or True,
encode_special_tokens = kwargs.get("encode_special_tokens") or True
add_bos = unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens = unwrap(kwargs.get("encode_special_tokens"), True)
)
if ids:
# Assume token decoding
ids = torch.tensor([ids])
return self.tokenizer.decode(ids, decode_special_tokens = kwargs.get("decode_special_tokens") or True)[0]
return self.tokenizer.decode(ids, decode_special_tokens = unwrap(kwargs.get("decode_special_tokens"), True))[0]
def generate(self, prompt: str, **kwargs):
@@ -311,10 +312,10 @@ class ModelContainer:
"""
token_healing = kwargs.get("token_healing") or False
max_tokens = kwargs.get("max_tokens") or 150
stream_interval = kwargs.get("stream_interval") or 0
generate_window = min(kwargs.get("generate_window") or 512, max_tokens)
token_healing = unwrap(kwargs.get("token_healing"), False)
max_tokens = unwrap(kwargs.get("max_tokens"), 150)
stream_interval = unwrap(kwargs.get("stream_interval"), 0)
generate_window = min(unwrap(kwargs.get("generate_window"), 512), max_tokens)
# Sampler settings
@@ -322,42 +323,43 @@ class ModelContainer:
# Warn of unsupported settings if the setting is enabled
if (kwargs.get("mirostat") or False) and not hasattr(gen_settings, "mirostat"):
if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(gen_settings, "mirostat"):
print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling")
if (kwargs.get("min_p") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
print(" !! Warning: Currently installed ExLlamaV2 does not support min-P sampling")
if (kwargs.get("tfs") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)")
if (kwargs.get("temperature_last") or False) and not hasattr(gen_settings, "temperature_last"):
if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(gen_settings, "temperature_last"):
print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last")
#Apply settings
gen_settings.temperature = kwargs.get("temperature") or 1.0
gen_settings.temperature_last = kwargs.get("temperature_last") or False
gen_settings.top_k = kwargs.get("top_k") or 0
gen_settings.top_p = kwargs.get("top_p") or 1.0
gen_settings.min_p = kwargs.get("min_p") or 0.0
gen_settings.tfs = kwargs.get("tfs") or 1.0
gen_settings.typical = kwargs.get("typical") or 1.0
gen_settings.mirostat = kwargs.get("mirostat") or False
gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False)
gen_settings.top_k = unwrap(kwargs.get("top_k"), 0)
gen_settings.top_p = unwrap(kwargs.get("top_p"), 1.0)
gen_settings.min_p = unwrap(kwargs.get("min_p"), 0.0)
gen_settings.tfs = unwrap(kwargs.get("tfs"), 1.0)
gen_settings.typical = unwrap(kwargs.get("typical"), 1.0)
gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
# Default tau and eta fallbacks don't matter if mirostat is off
gen_settings.mirostat_tau = kwargs.get("mirostat_tau") or 1.5
gen_settings.mirostat_eta = kwargs.get("mirostat_eta") or 0.1
gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty") or 1.0
gen_settings.token_repetition_range = kwargs.get("repetition_range") or self.config.max_seq_len
gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
gen_settings.token_repetition_penalty = unwrap(kwargs.get("repetition_penalty"), 1.0)
gen_settings.token_repetition_range = unwrap(kwargs.get("repetition_range"), self.config.max_seq_len)
# Always make sure the fallback is 0 if range < 0
# It's technically fine to use -1, but this just validates the passed fallback
# Always default to 0 if something goes wrong
fallback_decay = 0 if gen_settings.token_repetition_range <= 0 else gen_settings.token_repetition_range
gen_settings.token_repetition_decay = kwargs.get("repetition_decay") or fallback_decay or 0
gen_settings.token_repetition_decay = coalesce(kwargs.get("repetition_decay"), fallback_decay, 0)
stop_conditions: List[Union[str, int]] = kwargs.get("stop") or []
ban_eos_token = kwargs.get("ban_eos_token") or False
stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
# Ban the EOS token if specified. If not, append to stop conditions as well.
@@ -383,7 +385,7 @@ class ModelContainer:
ids = self.tokenizer.encode(
prompt,
add_bos = kwargs.get("add_bos_token") or True,
add_bos = unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens = True
)
context_len = len(ids[0])

@@ -30,3 +30,14 @@ def get_generator_error(message: str):
def get_sse_packet(json_data: str):
return f"data: {json_data}\n\n"
# Unwrap function for Optionals
def unwrap(wrapped, default = None):
    if wrapped is None:
        return default
    else:
        return wrapped
# Coalesce function for multiple unwraps
def coalesce(*args):
    return next((arg for arg in args if arg is not None), None)
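
A quick usage sketch of the two helpers above (the values are made up
for illustration): unwrap takes a single wrapped value plus a default,
while coalesce returns the first of its arguments that is not None.

    from utils import unwrap, coalesce

    unwrap(None, "models")    # -> "models" (falls back only when the value is None)
    unwrap(False, True)       # -> False    (falsy, but kept because it is not None)
    coalesce(None, None, -1)  # -> -1       (first non-None argument)
    coalesce(None, 0, 5)      # -> 0        (unlike "or", 0 is not skipped)
    coalesce(None, None)      # -> None     (no non-None argument)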