Tree: Use unwrap and coalesce for optional handling
Python doesn't have proper handling of optionals. The only way to handle them is checking via an if statement if the value is None or by using the "or" keyword to unwrap optionals. Previously, I used the "or" method to unwrap, but this caused issues due to falsy values falling back to the default. This is especially the case with booleans were "False" changed to "True". Instead, add two new functions: unwrap and coalesce. Both function to properly implement a functional way of "None" coalescing. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
7380a3b79a
commit
5ae2a91c04
5 changed files with 83 additions and 68 deletions
|
|
@ -1,5 +1,6 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional, Union
|
||||
from utils import coalesce
|
||||
|
||||
class LogProbs(BaseModel):
|
||||
text_offset: List[int] = Field(default_factory=list)
|
||||
|
|
@ -83,7 +84,7 @@ class CommonCompletionRequest(BaseModel):
|
|||
"min_p": self.min_p,
|
||||
"tfs": self.tfs,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
"repetition_range": self.repetition_range or self.repetition_penalty_range or -1,
|
||||
"repetition_range": coalesce(self.repetition_range, self.repetition_penalty_range, -1),
|
||||
"repetition_decay": self.repetition_decay,
|
||||
"mirostat": self.mirostat_mode == 2,
|
||||
"mirostat_tau": self.mirostat_tau,
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from OAI.types.lora import LoraList, LoraCard
|
|||
from OAI.types.model import ModelList, ModelCard
|
||||
from packaging import version
|
||||
from typing import Optional, List, Dict
|
||||
from utils import unwrap
|
||||
|
||||
# Check fastchat
|
||||
try:
|
||||
|
|
@ -30,7 +31,7 @@ def create_completion_response(text: str, prompt_tokens: int, completion_tokens:
|
|||
|
||||
response = CompletionResponse(
|
||||
choices = [choice],
|
||||
model = model_name or "",
|
||||
model = unwrap(model_name, ""),
|
||||
usage = UsageStats(prompt_tokens = prompt_tokens,
|
||||
completion_tokens = completion_tokens,
|
||||
total_tokens = prompt_tokens + completion_tokens)
|
||||
|
|
@ -51,7 +52,7 @@ def create_chat_completion_response(text: str, prompt_tokens: int, completion_to
|
|||
|
||||
response = ChatCompletionResponse(
|
||||
choices = [choice],
|
||||
model = model_name or "",
|
||||
model = unwrap(model_name, ""),
|
||||
usage = UsageStats(prompt_tokens = prompt_tokens,
|
||||
completion_tokens = completion_tokens,
|
||||
total_tokens = prompt_tokens + completion_tokens)
|
||||
|
|
@ -80,7 +81,7 @@ def create_chat_completion_stream_chunk(const_id: str,
|
|||
chunk = ChatCompletionStreamChunk(
|
||||
id = const_id,
|
||||
choices = [choice],
|
||||
model = model_name or ""
|
||||
model = unwrap(model_name, "")
|
||||
)
|
||||
|
||||
return chunk
|
||||
|
|
|
|||
46
main.py
46
main.py
|
|
@ -28,7 +28,7 @@ from OAI.utils import (
|
|||
create_chat_completion_stream_chunk
|
||||
)
|
||||
from typing import Optional
|
||||
from utils import get_generator_error, get_sse_packet, load_progress
|
||||
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
|
||||
from uuid import uuid4
|
||||
|
||||
app = FastAPI()
|
||||
|
|
@ -54,17 +54,17 @@ app.add_middleware(
|
|||
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
|
||||
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_models():
|
||||
model_config = config.get("model") or {}
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
if "model_dir" in model_config:
|
||||
model_path = pathlib.Path(model_config["model_dir"])
|
||||
else:
|
||||
model_path = pathlib.Path("models")
|
||||
|
||||
draft_config = model_config.get("draft") or {}
|
||||
draft_config = unwrap(model_config.get("draft"), {})
|
||||
draft_model_dir = draft_config.get("draft_model_dir")
|
||||
|
||||
models = get_model_list(model_path.resolve(), draft_model_dir)
|
||||
if model_config.get("use_dummy_models") or False:
|
||||
if unwrap(model_config.get("use_dummy_models"), False):
|
||||
models.data.insert(0, ModelCard(id = "gpt-3.5-turbo"))
|
||||
|
||||
return models
|
||||
|
|
@ -89,19 +89,19 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
if not data.name:
|
||||
raise HTTPException(400, "model_name not found.")
|
||||
|
||||
model_config = config.get("model") or {}
|
||||
model_path = pathlib.Path(model_config.get("model_dir") or "models")
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
|
||||
model_path = model_path / data.name
|
||||
|
||||
load_data = data.dict()
|
||||
|
||||
# TODO: Add API exception if draft directory isn't found
|
||||
draft_config = model_config.get("draft") or {}
|
||||
draft_config = unwrap(model_config.get("draft"), {})
|
||||
if data.draft:
|
||||
if not data.draft.draft_model_name:
|
||||
raise HTTPException(400, "draft_model_name was not found inside the draft object.")
|
||||
|
||||
load_data["draft"]["draft_model_dir"] = draft_config.get("draft_model_dir") or "models"
|
||||
load_data["draft"]["draft_model_dir"] = unwrap(draft_config.get("draft_model_dir"), "models")
|
||||
|
||||
if not model_path.exists():
|
||||
raise HTTPException(400, "model_path does not exist. Check model_name?")
|
||||
|
|
@ -167,9 +167,9 @@ async def unload_model():
|
|||
@app.get("/v1/loras", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
||||
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
||||
async def get_all_loras():
|
||||
model_config = config.get("model") or {}
|
||||
lora_config = model_config.get("lora") or {}
|
||||
lora_path = pathlib.Path(lora_config.get("lora_dir") or "loras")
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
lora_config = unwrap(model_config.get("lora"), {})
|
||||
lora_path = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
||||
|
||||
loras = get_lora_list(lora_path.resolve())
|
||||
|
||||
|
|
@ -196,9 +196,9 @@ async def load_model(data: LoraLoadRequest):
|
|||
if not data.loras:
|
||||
raise HTTPException(400, "List of loras to load is not found.")
|
||||
|
||||
model_config = config.get("model") or {}
|
||||
lora_config = model_config.get("lora") or {}
|
||||
lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
lora_config = unwrap(model_config.get("lora"), {})
|
||||
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
||||
if not lora_dir.exists():
|
||||
raise HTTPException(400, "A parent lora directory does not exist. Check your config.yml?")
|
||||
|
||||
|
|
@ -208,8 +208,8 @@ async def load_model(data: LoraLoadRequest):
|
|||
|
||||
result = model_container.load_loras(lora_dir, **data.dict())
|
||||
return LoraLoadResponse(
|
||||
success = result.get("success") or [],
|
||||
failure = result.get("failure") or []
|
||||
success = unwrap(result.get("success"), []),
|
||||
failure = unwrap(result.get("failure"), [])
|
||||
)
|
||||
|
||||
# Unload lora endpoint
|
||||
|
|
@ -234,7 +234,7 @@ async def encode_tokens(data: TokenEncodeRequest):
|
|||
@app.post("/v1/token/decode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
||||
async def decode_tokens(data: TokenDecodeRequest):
|
||||
message = model_container.get_tokens(None, data.tokens, **data.get_params())
|
||||
response = TokenDecodeResponse(text = message or "")
|
||||
response = TokenDecodeResponse(text = unwrap(message, ""))
|
||||
|
||||
return response
|
||||
|
||||
|
|
@ -337,7 +337,7 @@ if __name__ == "__main__":
|
|||
# Load from YAML config. Possibly add a config -> kwargs conversion function
|
||||
try:
|
||||
with open('config.yml', 'r', encoding = "utf8") as config_file:
|
||||
config = yaml.safe_load(config_file) or {}
|
||||
config = unwrap(yaml.safe_load(config_file), {})
|
||||
except Exception as e:
|
||||
print(
|
||||
"The YAML config couldn't load because of the following error:",
|
||||
|
|
@ -348,10 +348,10 @@ if __name__ == "__main__":
|
|||
|
||||
# If an initial model name is specified, create a container and load the model
|
||||
|
||||
model_config = config.get("model") or {}
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
if "model_name" in model_config:
|
||||
# TODO: Move this to model_container
|
||||
model_path = pathlib.Path(model_config.get("model_dir") or "models")
|
||||
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
|
||||
model_path = model_path / model_config.get("model_name")
|
||||
|
||||
model_container = ModelContainer(model_path.resolve(), False, **model_config)
|
||||
|
|
@ -366,12 +366,12 @@ if __name__ == "__main__":
|
|||
loading_bar.next()
|
||||
|
||||
# Load loras
|
||||
lora_config = model_config.get("lora") or {}
|
||||
lora_config = unwrap(model_config.get("lora"), {})
|
||||
if "loras" in lora_config:
|
||||
lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
|
||||
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
||||
model_container.load_loras(lora_dir.resolve(), **lora_config)
|
||||
|
||||
network_config = config.get("network") or {}
|
||||
network_config = unwrap(config.get("network"), {})
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=network_config.get("host", "127.0.0.1"),
|
||||
|
|
|
|||
84
model.py
84
model.py
|
|
@ -13,6 +13,7 @@ from exllamav2.generator import(
|
|||
ExLlamaV2Sampler
|
||||
)
|
||||
from typing import List, Optional, Union
|
||||
from utils import coalesce, unwrap
|
||||
|
||||
# Bytes to reserve on first device when loading with auto split
|
||||
auto_split_reserve_bytes = 96 * 1024**2
|
||||
|
|
@ -30,7 +31,7 @@ class ModelContainer:
|
|||
|
||||
cache_fp8: bool = False
|
||||
gpu_split_auto: bool = True
|
||||
gpu_split: list or None = None
|
||||
gpu_split: Optional[list] = None
|
||||
|
||||
active_loras: List[ExLlamaV2Lora] = []
|
||||
|
||||
|
|
@ -68,7 +69,7 @@ class ModelContainer:
|
|||
|
||||
self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8"
|
||||
self.gpu_split = kwargs.get("gpu_split")
|
||||
self.gpu_split_auto = kwargs.get("gpu_split_auto") or True
|
||||
self.gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
|
||||
|
||||
self.config = ExLlamaV2Config()
|
||||
self.config.model_dir = str(model_directory.resolve())
|
||||
|
|
@ -78,14 +79,14 @@ class ModelContainer:
|
|||
base_seq_len = self.config.max_seq_len
|
||||
|
||||
# Then override the max_seq_len if present
|
||||
self.config.max_seq_len = kwargs.get("max_seq_len") or 4096
|
||||
self.config.scale_pos_emb = kwargs.get("rope_scale") or 1.0
|
||||
self.config.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
|
||||
self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0)
|
||||
|
||||
# Automatically calculate rope alpha
|
||||
self.config.scale_alpha_value = kwargs.get("rope_alpha") or self.calculate_rope_alpha(base_seq_len)
|
||||
self.config.scale_alpha_value = unwrap(kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len))
|
||||
|
||||
# Turn off flash attention?
|
||||
self.config.no_flash_attn = kwargs.get("no_flash_attn") or False
|
||||
self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attn"), False)
|
||||
|
||||
# low_mem is currently broken in exllamav2. Don't use it until it's fixed.
|
||||
"""
|
||||
|
|
@ -93,11 +94,11 @@ class ModelContainer:
|
|||
self.config.set_low_mem()
|
||||
"""
|
||||
|
||||
chunk_size = min(kwargs.get("chunk_size") or 2048, self.config.max_seq_len)
|
||||
chunk_size = min(unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len)
|
||||
self.config.max_input_len = chunk_size
|
||||
self.config.max_attn_size = chunk_size ** 2
|
||||
|
||||
draft_args = kwargs.get("draft") or {}
|
||||
draft_args = unwrap(kwargs.get("draft"), {})
|
||||
draft_model_name = draft_args.get("draft_model_name")
|
||||
enable_draft = draft_args and draft_model_name
|
||||
|
||||
|
|
@ -109,14 +110,14 @@ class ModelContainer:
|
|||
if enable_draft:
|
||||
|
||||
self.draft_config = ExLlamaV2Config()
|
||||
draft_model_path = pathlib.Path(draft_args.get("draft_model_dir") or "models")
|
||||
draft_model_path = pathlib.Path(unwrap(draft_args.get("draft_model_dir"), "models"))
|
||||
draft_model_path = draft_model_path / draft_model_name
|
||||
|
||||
self.draft_config.model_dir = str(draft_model_path.resolve())
|
||||
self.draft_config.prepare()
|
||||
|
||||
self.draft_config.scale_pos_emb = draft_args.get("draft_rope_scale") or 1.0
|
||||
self.draft_config.scale_alpha_value = draft_args.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len)
|
||||
self.draft_config.scale_pos_emb = unwrap(draft_args.get("draft_rope_scale"), 1.0)
|
||||
self.draft_config.scale_alpha_value = unwrap(draft_args.get("draft_rope_alpha"), self.calculate_rope_alpha(self.draft_config.max_seq_len))
|
||||
self.draft_config.max_seq_len = self.config.max_seq_len
|
||||
|
||||
if "chunk_size" in kwargs:
|
||||
|
|
@ -151,13 +152,13 @@ class ModelContainer:
|
|||
Load loras
|
||||
"""
|
||||
|
||||
loras = kwargs.get("loras") or []
|
||||
loras = unwrap(kwargs.get("loras"), [])
|
||||
success: List[str] = []
|
||||
failure: List[str] = []
|
||||
|
||||
for lora in loras:
|
||||
lora_name = lora.get("name") or None
|
||||
lora_scaling = lora.get("scaling") or 1.0
|
||||
lora_name = lora.get("name")
|
||||
lora_scaling = unwrap(lora.get("scaling"), 1.0)
|
||||
|
||||
if lora_name is None:
|
||||
print("One of your loras does not have a name. Please check your config.yml! Skipping lora load.")
|
||||
|
|
@ -265,13 +266,13 @@ class ModelContainer:
|
|||
# Assume token encoding
|
||||
return self.tokenizer.encode(
|
||||
text,
|
||||
add_bos = kwargs.get("add_bos_token") or True,
|
||||
encode_special_tokens = kwargs.get("encode_special_tokens") or True
|
||||
add_bos = unwrap(kwargs.get("add_bos_token"), True),
|
||||
encode_special_tokens = unwrap(kwargs.get("encode_special_tokens"), True)
|
||||
)
|
||||
if ids:
|
||||
# Assume token decoding
|
||||
ids = torch.tensor([ids])
|
||||
return self.tokenizer.decode(ids, decode_special_tokens = kwargs.get("decode_special_tokens") or True)[0]
|
||||
return self.tokenizer.decode(ids, decode_special_tokens = unwrap(kwargs.get("decode_special_tokens"), True))[0]
|
||||
|
||||
|
||||
def generate(self, prompt: str, **kwargs):
|
||||
|
|
@ -311,10 +312,10 @@ class ModelContainer:
|
|||
|
||||
"""
|
||||
|
||||
token_healing = kwargs.get("token_healing") or False
|
||||
max_tokens = kwargs.get("max_tokens") or 150
|
||||
stream_interval = kwargs.get("stream_interval") or 0
|
||||
generate_window = min(kwargs.get("generate_window") or 512, max_tokens)
|
||||
token_healing = unwrap(kwargs.get("token_healing"), False)
|
||||
max_tokens = unwrap(kwargs.get("max_tokens"), 150)
|
||||
stream_interval = unwrap(kwargs.get("stream_interval"), 0)
|
||||
generate_window = min(unwrap(kwargs.get("generate_window"), 512), max_tokens)
|
||||
|
||||
# Sampler settings
|
||||
|
||||
|
|
@ -322,42 +323,43 @@ class ModelContainer:
|
|||
|
||||
# Warn of unsupported settings if the setting is enabled
|
||||
|
||||
if (kwargs.get("mirostat") or False) and not hasattr(gen_settings, "mirostat"):
|
||||
if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(gen_settings, "mirostat"):
|
||||
print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling")
|
||||
|
||||
if (kwargs.get("min_p") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
|
||||
if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
|
||||
print(" !! Warning: Currently installed ExLlamaV2 does not support min-P sampling")
|
||||
|
||||
if (kwargs.get("tfs") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
|
||||
if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
|
||||
print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)")
|
||||
|
||||
if (kwargs.get("temperature_last") or False) and not hasattr(gen_settings, "temperature_last"):
|
||||
if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(gen_settings, "temperature_last"):
|
||||
print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last")
|
||||
|
||||
#Apply settings
|
||||
|
||||
gen_settings.temperature = kwargs.get("temperature") or 1.0
|
||||
gen_settings.temperature_last = kwargs.get("temperature_last") or False
|
||||
gen_settings.top_k = kwargs.get("top_k") or 0
|
||||
gen_settings.top_p = kwargs.get("top_p") or 1.0
|
||||
gen_settings.min_p = kwargs.get("min_p") or 0.0
|
||||
gen_settings.tfs = kwargs.get("tfs") or 1.0
|
||||
gen_settings.typical = kwargs.get("typical") or 1.0
|
||||
gen_settings.mirostat = kwargs.get("mirostat") or False
|
||||
gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
|
||||
gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False)
|
||||
gen_settings.top_k = unwrap(kwargs.get("top_k"), 0)
|
||||
gen_settings.top_p = unwrap(kwargs.get("top_p"), 1.0)
|
||||
gen_settings.min_p = unwrap(kwargs.get("min_p"), 0.0)
|
||||
gen_settings.tfs = unwrap(kwargs.get("tfs"), 1.0)
|
||||
gen_settings.typical = unwrap(kwargs.get("typical"), 1.0)
|
||||
gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
|
||||
|
||||
# Default tau and eta fallbacks don't matter if mirostat is off
|
||||
gen_settings.mirostat_tau = kwargs.get("mirostat_tau") or 1.5
|
||||
gen_settings.mirostat_eta = kwargs.get("mirostat_eta") or 0.1
|
||||
gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty") or 1.0
|
||||
gen_settings.token_repetition_range = kwargs.get("repetition_range") or self.config.max_seq_len
|
||||
gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
|
||||
gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
|
||||
gen_settings.token_repetition_penalty = unwrap(kwargs.get("repetition_penalty"), 1.0)
|
||||
gen_settings.token_repetition_range = unwrap(kwargs.get("repetition_range"), self.config.max_seq_len)
|
||||
|
||||
# Always make sure the fallback is 0 if range < 0
|
||||
# It's technically fine to use -1, but this just validates the passed fallback
|
||||
# Always default to 0 if something goes wrong
|
||||
fallback_decay = 0 if gen_settings.token_repetition_range <= 0 else gen_settings.token_repetition_range
|
||||
gen_settings.token_repetition_decay = kwargs.get("repetition_decay") or fallback_decay or 0
|
||||
gen_settings.token_repetition_decay = coalesce(kwargs.get("repetition_decay"), fallback_decay, 0)
|
||||
|
||||
stop_conditions: List[Union[str, int]] = kwargs.get("stop") or []
|
||||
ban_eos_token = kwargs.get("ban_eos_token") or False
|
||||
stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
|
||||
ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
|
||||
|
||||
|
||||
# Ban the EOS token if specified. If not, append to stop conditions as well.
|
||||
|
|
@ -383,7 +385,7 @@ class ModelContainer:
|
|||
|
||||
ids = self.tokenizer.encode(
|
||||
prompt,
|
||||
add_bos = kwargs.get("add_bos_token") or True,
|
||||
add_bos = unwrap(kwargs.get("add_bos_token"), True),
|
||||
encode_special_tokens = True
|
||||
)
|
||||
context_len = len(ids[0])
|
||||
|
|
|
|||
11
utils.py
11
utils.py
|
|
@ -30,3 +30,14 @@ def get_generator_error(message: str):
|
|||
|
||||
def get_sse_packet(json_data: str):
|
||||
return f"data: {json_data}\n\n"
|
||||
|
||||
# Unwrap function for Optionals
|
||||
def unwrap(wrapped, default = None):
|
||||
if wrapped is None:
|
||||
return default
|
||||
else:
|
||||
return wrapped
|
||||
|
||||
# Coalesce function for multiple unwraps
|
||||
def coalesce(*args):
|
||||
return next((arg for arg in args if arg is not None), None)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue