From 5ae2a91c04a437b369be193628c19d4c03edd9ef Mon Sep 17 00:00:00 2001
From: kingbri
Date: Sat, 9 Dec 2023 21:52:17 -0500
Subject: [PATCH] Tree: Use unwrap and coalesce for optional handling

Python doesn't have proper handling of optionals. The only ways to
handle them are checking whether the value is None in an if statement
or unwrapping it with the "or" keyword.

Previously, I used the "or" method to unwrap, but this caused issues
because falsy values also fell back to the default. This is especially
the case with booleans, where an explicit "False" changed to "True".

Instead, add two new functions: unwrap and coalesce. Together they
properly implement None coalescing in a functional way.
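
For example (illustrative values), an option that a user explicitly
sets to False now survives the fallback:

    kwargs = {"gpu_split_auto": False}           # illustrative config values

    kwargs.get("gpu_split_auto") or True         # True  - explicit False is lost
    unwrap(kwargs.get("gpu_split_auto"), True)   # False - only None falls back
    coalesce(None, 0, -1)                        # 0     - first non-None value wins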

Signed-off-by: kingbri
---
 OAI/types/common.py |  3 +-
 OAI/utils.py        |  7 ++--
 main.py             | 46 ++++++++++++-------------
 model.py            | 84 +++++++++++++++++++++++----------------
 utils.py            | 11 ++++++
 5 files changed, 83 insertions(+), 68 deletions(-)

diff --git a/OAI/types/common.py b/OAI/types/common.py
index cbdc44e..5584590 100644
--- a/OAI/types/common.py
+++ b/OAI/types/common.py
@@ -1,5 +1,6 @@
 from pydantic import BaseModel, Field
 from typing import List, Dict, Optional, Union
+from utils import coalesce
 
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
@@ -83,7 +84,7 @@ class CommonCompletionRequest(BaseModel):
             "min_p": self.min_p,
             "tfs": self.tfs,
             "repetition_penalty": self.repetition_penalty,
-            "repetition_range": self.repetition_range or self.repetition_penalty_range or -1,
+            "repetition_range": coalesce(self.repetition_range, self.repetition_penalty_range, -1),
             "repetition_decay": self.repetition_decay,
             "mirostat": self.mirostat_mode == 2,
             "mirostat_tau": self.mirostat_tau,
diff --git a/OAI/utils.py b/OAI/utils.py
index bb0c93b..1bde11a 100644
--- a/OAI/utils.py
+++ b/OAI/utils.py
@@ -12,6 +12,7 @@ from OAI.types.lora import LoraList, LoraCard
 from OAI.types.model import ModelList, ModelCard
 from packaging import version
 from typing import Optional, List, Dict
+from utils import unwrap
 
 # Check fastchat
 try:
@@ -30,7 +31,7 @@ def create_completion_response(text: str, prompt_tokens: int, completion_tokens:
 
     response = CompletionResponse(
         choices = [choice],
-        model = model_name or "",
+        model = unwrap(model_name, ""),
         usage = UsageStats(prompt_tokens = prompt_tokens,
                            completion_tokens = completion_tokens,
                            total_tokens = prompt_tokens + completion_tokens)
@@ -51,7 +52,7 @@ def create_chat_completion_response(text: str, prompt_tokens: int, completion_to
 
     response = ChatCompletionResponse(
         choices = [choice],
-        model = model_name or "",
+        model = unwrap(model_name, ""),
        usage = UsageStats(prompt_tokens = prompt_tokens,
                            completion_tokens = completion_tokens,
                            total_tokens = prompt_tokens + completion_tokens)
@@ -80,7 +81,7 @@ def create_chat_completion_stream_chunk(const_id: str,
     chunk = ChatCompletionStreamChunk(
         id = const_id,
         choices = [choice],
-        model = model_name or ""
+        model = unwrap(model_name, "")
     )
 
     return chunk
diff --git a/main.py b/main.py
index 5627f53..8f71248 100644
--- a/main.py
+++ b/main.py
@@ -28,7 +28,7 @@ from OAI.utils import (
     create_chat_completion_stream_chunk
 )
 from typing import Optional
-from utils import get_generator_error, get_sse_packet, load_progress
+from utils import get_generator_error, get_sse_packet, load_progress, unwrap
 from uuid import uuid4
 
 app = FastAPI()
@@ -54,17 +54,17 @@ app.add_middleware(
 @app.get("/v1/models", dependencies=[Depends(check_api_key)])
 @app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
 async def list_models():
-    model_config = config.get("model") or {}
+    model_config = unwrap(config.get("model"), {})
     if "model_dir" in model_config:
         model_path = pathlib.Path(model_config["model_dir"])
     else:
         model_path = pathlib.Path("models")
 
-    draft_config = model_config.get("draft") or {}
+    draft_config = unwrap(model_config.get("draft"), {})
     draft_model_dir = draft_config.get("draft_model_dir")
 
     models = get_model_list(model_path.resolve(), draft_model_dir)
-    if model_config.get("use_dummy_models") or False:
+    if unwrap(model_config.get("use_dummy_models"), False):
         models.data.insert(0, ModelCard(id = "gpt-3.5-turbo"))
 
     return models
@@ -89,19 +89,19 @@ async def load_model(request: Request, data: ModelLoadRequest):
     if not data.name:
         raise HTTPException(400, "model_name not found.")
 
-    model_config = config.get("model") or {}
-    model_path = pathlib.Path(model_config.get("model_dir") or "models")
+    model_config = unwrap(config.get("model"), {})
+    model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
     model_path = model_path / data.name
 
     load_data = data.dict()
 
     # TODO: Add API exception if draft directory isn't found
-    draft_config = model_config.get("draft") or {}
+    draft_config = unwrap(model_config.get("draft"), {})
     if data.draft:
         if not data.draft.draft_model_name:
             raise HTTPException(400, "draft_model_name was not found inside the draft object.")
 
-        load_data["draft"]["draft_model_dir"] = draft_config.get("draft_model_dir") or "models"
+        load_data["draft"]["draft_model_dir"] = unwrap(draft_config.get("draft_model_dir"), "models")
 
     if not model_path.exists():
         raise HTTPException(400, "model_path does not exist. Check model_name?")
@@ -167,9 +167,9 @@ async def unload_model():
 @app.get("/v1/loras", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 @app.get("/v1/lora/list", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def get_all_loras():
-    model_config = config.get("model") or {}
-    lora_config = model_config.get("lora") or {}
-    lora_path = pathlib.Path(lora_config.get("lora_dir") or "loras")
+    model_config = unwrap(config.get("model"), {})
+    lora_config = unwrap(model_config.get("lora"), {})
+    lora_path = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
 
     loras = get_lora_list(lora_path.resolve())
 
@@ -196,9 +196,9 @@ async def load_model(data: LoraLoadRequest):
     if not data.loras:
         raise HTTPException(400, "List of loras to load is not found.")
 
-    model_config = config.get("model") or {}
-    lora_config = model_config.get("lora") or {}
-    lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
+    model_config = unwrap(config.get("model"), {})
+    lora_config = unwrap(model_config.get("lora"), {})
+    lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
     if not lora_dir.exists():
         raise HTTPException(400, "A parent lora directory does not exist. Check your config.yml?")
 
@@ -208,8 +208,8 @@ async def load_model(data: LoraLoadRequest):
     result = model_container.load_loras(lora_dir, **data.dict())
 
     return LoraLoadResponse(
-        success = result.get("success") or [],
-        failure = result.get("failure") or []
+        success = unwrap(result.get("success"), []),
+        failure = unwrap(result.get("failure"), [])
     )
 
 # Unload lora endpoint
@@ -234,7 +234,7 @@ async def encode_tokens(data: TokenEncodeRequest):
 @app.post("/v1/token/decode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def decode_tokens(data: TokenDecodeRequest):
     message = model_container.get_tokens(None, data.tokens, **data.get_params())
-    response = TokenDecodeResponse(text = message or "")
+    response = TokenDecodeResponse(text = unwrap(message, ""))
 
     return response
 
@@ -337,7 +337,7 @@ if __name__ == "__main__":
     # Load from YAML config. Possibly add a config -> kwargs conversion function
     try:
         with open('config.yml', 'r', encoding = "utf8") as config_file:
-            config = yaml.safe_load(config_file) or {}
+            config = unwrap(yaml.safe_load(config_file), {})
     except Exception as e:
         print(
             "The YAML config couldn't load because of the following error:",
@@ -348,10 +348,10 @@ if __name__ == "__main__":
 
     # If an initial model name is specified, create a container and load the model
-    model_config = config.get("model") or {}
+    model_config = unwrap(config.get("model"), {})
     if "model_name" in model_config:
         # TODO: Move this to model_container
-        model_path = pathlib.Path(model_config.get("model_dir") or "models")
+        model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
         model_path = model_path / model_config.get("model_name")
 
         model_container = ModelContainer(model_path.resolve(), False, **model_config)
 
@@ -366,12 +366,12 @@ if __name__ == "__main__":
                 loading_bar.next()
 
     # Load loras
-    lora_config = model_config.get("lora") or {}
+    lora_config = unwrap(model_config.get("lora"), {})
     if "loras" in lora_config:
-        lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
+        lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
         model_container.load_loras(lora_dir.resolve(), **lora_config)
 
-    network_config = config.get("network") or {}
+    network_config = unwrap(config.get("network"), {})
     uvicorn.run(
         app,
         host=network_config.get("host", "127.0.0.1"),
diff --git a/model.py b/model.py
index 7190cd0..582b008 100644
--- a/model.py
+++ b/model.py
@@ -13,6 +13,7 @@ from exllamav2.generator import(
     ExLlamaV2Sampler
 )
 from typing import List, Optional, Union
+from utils import coalesce, unwrap
 
 # Bytes to reserve on first device when loading with auto split
 auto_split_reserve_bytes = 96 * 1024**2
@@ -30,7 +31,7 @@ class ModelContainer:
 
     cache_fp8: bool = False
     gpu_split_auto: bool = True
-    gpu_split: list or None = None
+    gpu_split: Optional[list] = None
 
     active_loras: List[ExLlamaV2Lora] = []
 
@@ -68,7 +69,7 @@ class ModelContainer:
 
         self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8"
         self.gpu_split = kwargs.get("gpu_split")
-        self.gpu_split_auto = kwargs.get("gpu_split_auto") or True
+        self.gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
 
         self.config = ExLlamaV2Config()
         self.config.model_dir = str(model_directory.resolve())
@@ -78,14 +79,14 @@ class ModelContainer:
         base_seq_len = self.config.max_seq_len
 
         # Then override the max_seq_len if present
-        self.config.max_seq_len = kwargs.get("max_seq_len") or 4096
-        self.config.scale_pos_emb = kwargs.get("rope_scale") or 1.0
+        self.config.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
+        self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0)
 
         # Automatically calculate rope alpha
-        self.config.scale_alpha_value = kwargs.get("rope_alpha") or self.calculate_rope_alpha(base_seq_len)
+        self.config.scale_alpha_value = unwrap(kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len))
 
         # Turn off flash attention?
-        self.config.no_flash_attn = kwargs.get("no_flash_attn") or False
+        self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attn"), False)
 
         # low_mem is currently broken in exllamav2. Don't use it until it's fixed.
         """
@@ -93,11 +94,11 @@ class ModelContainer:
             self.config.set_low_mem()
         """
 
-        chunk_size = min(kwargs.get("chunk_size") or 2048, self.config.max_seq_len)
+        chunk_size = min(unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len)
         self.config.max_input_len = chunk_size
         self.config.max_attn_size = chunk_size ** 2
 
-        draft_args = kwargs.get("draft") or {}
+        draft_args = unwrap(kwargs.get("draft"), {})
         draft_model_name = draft_args.get("draft_model_name")
         enable_draft = draft_args and draft_model_name
 
@@ -109,14 +110,14 @@ class ModelContainer:
 
         if enable_draft:
             self.draft_config = ExLlamaV2Config()
-            draft_model_path = pathlib.Path(draft_args.get("draft_model_dir") or "models")
+            draft_model_path = pathlib.Path(unwrap(draft_args.get("draft_model_dir"), "models"))
             draft_model_path = draft_model_path / draft_model_name
 
             self.draft_config.model_dir = str(draft_model_path.resolve())
             self.draft_config.prepare()
 
-            self.draft_config.scale_pos_emb = draft_args.get("draft_rope_scale") or 1.0
-            self.draft_config.scale_alpha_value = draft_args.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len)
+            self.draft_config.scale_pos_emb = unwrap(draft_args.get("draft_rope_scale"), 1.0)
+            self.draft_config.scale_alpha_value = unwrap(draft_args.get("draft_rope_alpha"), self.calculate_rope_alpha(self.draft_config.max_seq_len))
             self.draft_config.max_seq_len = self.config.max_seq_len
 
         if "chunk_size" in kwargs:
@@ -151,13 +152,13 @@ class ModelContainer:
         Load loras
         """
 
-        loras = kwargs.get("loras") or []
+        loras = unwrap(kwargs.get("loras"), [])
        success: List[str] = []
         failure: List[str] = []
 
         for lora in loras:
-            lora_name = lora.get("name") or None
-            lora_scaling = lora.get("scaling") or 1.0
+            lora_name = lora.get("name")
+            lora_scaling = unwrap(lora.get("scaling"), 1.0)
 
             if lora_name is None:
                 print("One of your loras does not have a name. Please check your config.yml! Skipping lora load.")
@@ -265,13 +266,13 @@ class ModelContainer:
             # Assume token encoding
             return self.tokenizer.encode(
                 text,
-                add_bos = kwargs.get("add_bos_token") or True,
-                encode_special_tokens = kwargs.get("encode_special_tokens") or True
+                add_bos = unwrap(kwargs.get("add_bos_token"), True),
+                encode_special_tokens = unwrap(kwargs.get("encode_special_tokens"), True)
             )
 
         if ids:
             # Assume token decoding
             ids = torch.tensor([ids])
-            return self.tokenizer.decode(ids, decode_special_tokens = kwargs.get("decode_special_tokens") or True)[0]
+            return self.tokenizer.decode(ids, decode_special_tokens = unwrap(kwargs.get("decode_special_tokens"), True))[0]
 
     def generate(self, prompt: str, **kwargs):
@@ -311,10 +312,10 @@ class ModelContainer:
 
         """
 
-        token_healing = kwargs.get("token_healing") or False
-        max_tokens = kwargs.get("max_tokens") or 150
-        stream_interval = kwargs.get("stream_interval") or 0
-        generate_window = min(kwargs.get("generate_window") or 512, max_tokens)
+        token_healing = unwrap(kwargs.get("token_healing"), False)
+        max_tokens = unwrap(kwargs.get("max_tokens"), 150)
+        stream_interval = unwrap(kwargs.get("stream_interval"), 0)
+        generate_window = min(unwrap(kwargs.get("generate_window"), 512), max_tokens)
 
         # Sampler settings
 
@@ -322,42 +323,43 @@ class ModelContainer:
 
         # Warn of unsupported settings if the setting is enabled
-        if (kwargs.get("mirostat") or False) and not hasattr(gen_settings, "mirostat"):
+        if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(gen_settings, "mirostat"):
             print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling")
 
-        if (kwargs.get("min_p") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
+        if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
             print(" !! Warning: Currently installed ExLlamaV2 does not support min-P sampling")
 
-        if (kwargs.get("tfs") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
+        if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
             print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)")
 
-        if (kwargs.get("temperature_last") or False) and not hasattr(gen_settings, "temperature_last"):
+        if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(gen_settings, "temperature_last"):
             print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last")
 
         #Apply settings
-        gen_settings.temperature = kwargs.get("temperature") or 1.0
-        gen_settings.temperature_last = kwargs.get("temperature_last") or False
-        gen_settings.top_k = kwargs.get("top_k") or 0
-        gen_settings.top_p = kwargs.get("top_p") or 1.0
-        gen_settings.min_p = kwargs.get("min_p") or 0.0
-        gen_settings.tfs = kwargs.get("tfs") or 1.0
-        gen_settings.typical = kwargs.get("typical") or 1.0
-        gen_settings.mirostat = kwargs.get("mirostat") or False
+        gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
+        gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False)
+        gen_settings.top_k = unwrap(kwargs.get("top_k"), 0)
+        gen_settings.top_p = unwrap(kwargs.get("top_p"), 1.0)
+        gen_settings.min_p = unwrap(kwargs.get("min_p"), 0.0)
+        gen_settings.tfs = unwrap(kwargs.get("tfs"), 1.0)
+        gen_settings.typical = unwrap(kwargs.get("typical"), 1.0)
+        gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
 
         # Default tau and eta fallbacks don't matter if mirostat is off
-        gen_settings.mirostat_tau = kwargs.get("mirostat_tau") or 1.5
-        gen_settings.mirostat_eta = kwargs.get("mirostat_eta") or 0.1
-        gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty") or 1.0
-        gen_settings.token_repetition_range = kwargs.get("repetition_range") or self.config.max_seq_len
+        gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
+        gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
+        gen_settings.token_repetition_penalty = unwrap(kwargs.get("repetition_penalty"), 1.0)
+        gen_settings.token_repetition_range = unwrap(kwargs.get("repetition_range"), self.config.max_seq_len)
 
         # Always make sure the fallback is 0 if range < 0
         # It's technically fine to use -1, but this just validates the passed fallback
+        # Always default to 0 if something goes wrong
         fallback_decay = 0 if gen_settings.token_repetition_range <= 0 else gen_settings.token_repetition_range
-        gen_settings.token_repetition_decay = kwargs.get("repetition_decay") or fallback_decay or 0
+        gen_settings.token_repetition_decay = coalesce(kwargs.get("repetition_decay"), fallback_decay, 0)
 
-        stop_conditions: List[Union[str, int]] = kwargs.get("stop") or []
-        ban_eos_token = kwargs.get("ban_eos_token") or False
+        stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
+        ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
 
         # Ban the EOS token if specified. If not, append to stop conditions as well.
@@ -383,7 +385,7 @@ class ModelContainer:
 
         ids = self.tokenizer.encode(
             prompt,
-            add_bos = kwargs.get("add_bos_token") or True,
+            add_bos = unwrap(kwargs.get("add_bos_token"), True),
             encode_special_tokens = True
         )
         context_len = len(ids[0])
diff --git a/utils.py b/utils.py
index 623b27b..a243373 100644
--- a/utils.py
+++ b/utils.py
@@ -30,3 +30,14 @@ def get_generator_error(message: str):
 
 def get_sse_packet(json_data: str):
     return f"data: {json_data}\n\n"
+
+# Unwrap function for Optionals
+def unwrap(wrapped, default = None):
+    if wrapped is None:
+        return default
+    else:
+        return wrapped
+
+# Coalesce function for multiple unwraps
+def coalesce(*args):
+    return next((arg for arg in args if arg is not None), None)