Tree: Refactor code organization
Move common functions into their own folder and refactor the backends to use their own folders as well. Also clean up imports and alphabetize the import statements themselves. Finally, move colab and docker into their own folders. Signed-off-by: kingbri <bdashore3@proton.me>
parent ee99349a78
commit 78f920eeda
22 changed files with 41 additions and 42 deletions
127 common/args.py Normal file
@@ -0,0 +1,127 @@
"""Argparser for overriding config values"""
import argparse


def str_to_bool(value):
    """Converts a string into a boolean value"""

    if value.lower() in {"false", "f", "0", "no", "n"}:
        return False
    elif value.lower() in {"true", "t", "1", "yes", "y"}:
        return True
    raise ValueError(f"{value} is not a valid boolean value")


def init_argparser():
    """Creates an argument parser that any function can use"""

    parser = argparse.ArgumentParser(
        epilog="These args are only for a subset of the config. "
        + "Please edit config.yml for all options!"
    )
    add_network_args(parser)
    add_model_args(parser)
    add_logging_args(parser)
    add_config_args(parser)

    return parser


def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser):
    """Broad conversion of surface level arg groups to dictionaries"""

    arg_groups = {}
    for group in parser._action_groups:
        group_dict = {}
        for arg in group._group_actions:
            value = getattr(args, arg.dest, None)
            if value is not None:
                group_dict[arg.dest] = value

        arg_groups[group.title] = group_dict

    return arg_groups


def add_config_args(parser: argparse.ArgumentParser):
    """Adds config arguments"""

    parser.add_argument(
        "--config", type=str, help="Path to an overriding config.yml file"
    )


def add_network_args(parser: argparse.ArgumentParser):
    """Adds networking arguments"""

    network_group = parser.add_argument_group("network")
    network_group.add_argument("--host", type=str, help="The IP to host on")
    network_group.add_argument("--port", type=int, help="The port to host on")
    network_group.add_argument(
        "--disable-auth",
        type=str_to_bool,
        help="Disable HTTP token authentication with requests",
    )


def add_model_args(parser: argparse.ArgumentParser):
    """Adds model arguments"""

    model_group = parser.add_argument_group("model")
    model_group.add_argument(
        "--model-dir", type=str, help="Overrides the directory to look for models"
    )
    model_group.add_argument("--model-name", type=str, help="An initial model to load")
    model_group.add_argument(
        "--max-seq-len", type=int, help="Override the maximum model sequence length"
    )
    model_group.add_argument(
        "--override-base-seq-len",
        type=str_to_bool,
        help="Overrides base model context length",
    )
    model_group.add_argument(
        "--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb"
    )
    model_group.add_argument("--rope-alpha", type=float, help="Sets rope_alpha for NTK")
    model_group.add_argument(
        "--prompt-template",
        type=str,
        help="Set the prompt template for chat completions",
    )
    model_group.add_argument(
        "--gpu-split-auto",
        type=str_to_bool,
        help="Automatically allocate resources to GPUs",
    )
    model_group.add_argument(
        "--gpu-split",
        type=float,
        nargs="+",
        help="An array of VRAM (in GB) to split between GPUs. "
        + "Ignored if gpu_split_auto is true",
    )
    model_group.add_argument(
        "--num-experts-per-token",
        type=int,
        help="Number of experts to use per token in MoE models",
    )
    model_group.add_argument(
        "--use-cfg",
        type=str_to_bool,
        help="Enables CFG support",
    )


def add_logging_args(parser: argparse.ArgumentParser):
    """Adds logging arguments"""

    logging_group = parser.add_argument_group("logging")
    logging_group.add_argument(
        "--log-prompt", type=str_to_bool, help="Enable prompt logging"
    )
    logging_group.add_argument(
        "--log-generation-params",
        type=str_to_bool,
        help="Enable generation parameter logging",
    )
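A quick hypothetical usage sketch (not part of this diff) showing how the parser and group-to-dict conversion compose; it assumes the script runs from the repo root so `common` is importable:

# Sketch only: flags and values here are illustrative
from common.args import convert_args_to_dict, init_argparser

parser = init_argparser()
args = parser.parse_args(["--host", "127.0.0.1", "--port", "5000"])
arg_groups = convert_args_to_dict(args, parser)
# arg_groups["network"] == {"host": "127.0.0.1", "port": 5000}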
127 common/auth.py Normal file
@@ -0,0 +1,127 @@
"""
This method of authorization is pretty insecure, but since TabbyAPI is a local
application, it should be fine.
"""
import secrets
import yaml
from fastapi import Header, HTTPException
from pydantic import BaseModel
from typing import Optional

from common.logger import init_logger

logger = init_logger(__name__)


class AuthKeys(BaseModel):
    """
    This class represents the authentication keys for the application.
    It contains two types of keys: 'api_key' and 'admin_key'.
    The 'api_key' is used for general API calls, while the 'admin_key'
    is used for administrative tasks. The class also provides a method
    to verify if a given key matches the stored 'api_key' or 'admin_key'.
    """

    api_key: str
    admin_key: str

    def verify_key(self, test_key: str, key_type: str):
        """Verify if a given key matches the stored key."""
        if key_type == "admin_key":
            return test_key == self.admin_key
        if key_type == "api_key":
            # Admin keys are valid for all API calls
            return test_key == self.api_key or test_key == self.admin_key
        return False


AUTH_KEYS: Optional[AuthKeys] = None
DISABLE_AUTH: bool = False


def load_auth_keys(disable_from_config: bool):
    """Load the authentication keys from api_tokens.yml. If the file does not
    exist, generate new keys and save them to api_tokens.yml."""
    global AUTH_KEYS
    global DISABLE_AUTH

    DISABLE_AUTH = disable_from_config
    if disable_from_config:
        logger.warning(
            "Disabling authentication makes your instance vulnerable. "
            "Set the `disable_auth` flag to False in config.yml if you "
            "want to share this instance with others."
        )

        return

    try:
        with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
            auth_keys_dict = yaml.safe_load(auth_file)
            AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
    except OSError:
        new_auth_keys = AuthKeys(
            api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16)
        )
        AUTH_KEYS = new_auth_keys

        with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
            yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)

    logger.info(
        f"Your API key is: {AUTH_KEYS.api_key}\n"
        f"Your admin key is: {AUTH_KEYS.admin_key}\n\n"
        "If these keys get compromised, make sure to delete api_tokens.yml "
        "and restart the server. Have fun!"
    )


def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)):
    """Check if the API key is valid."""

    # Allow request if auth is disabled
    if DISABLE_AUTH:
        return

    if x_api_key:
        if not AUTH_KEYS.verify_key(x_api_key, "api_key"):
            raise HTTPException(401, "Invalid API key")
        return x_api_key

    if authorization:
        split_key = authorization.split(" ")
        if len(split_key) < 2:
            raise HTTPException(401, "Invalid API key")
        if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key(
            split_key[1], "api_key"
        ):
            raise HTTPException(401, "Invalid API key")

        return authorization

    raise HTTPException(401, "Please provide an API key")


def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)):
    """Check if the admin key is valid."""

    # Allow request if auth is disabled
    if DISABLE_AUTH:
        return

    if x_admin_key:
        if not AUTH_KEYS.verify_key(x_admin_key, "admin_key"):
            raise HTTPException(401, "Invalid admin key")
        return x_admin_key

    if authorization:
        split_key = authorization.split(" ")
        if len(split_key) < 2:
            raise HTTPException(401, "Invalid admin key")
        if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key(
            split_key[1], "admin_key"
        ):
            raise HTTPException(401, "Invalid admin key")
        return authorization

    raise HTTPException(401, "Please provide an admin key")
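A hypothetical wiring sketch (not in this diff): the two checks use FastAPI's Header defaults, so they drop in as route dependencies. The route path and payload below are placeholders:

from fastapi import Depends, FastAPI

from common.auth import check_api_key, load_auth_keys

app = FastAPI()
load_auth_keys(disable_from_config=False)


@app.get("/v1/models", dependencies=[Depends(check_api_key)])
async def list_models():
    # Placeholder response body for illustration
    return {"data": []}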
83 common/config.py Normal file
@@ -0,0 +1,83 @@
import pathlib
import yaml

from common.logger import init_logger
from common.utils import unwrap

logger = init_logger(__name__)

GLOBAL_CONFIG: dict = {}


def read_config_from_file(config_path: pathlib.Path):
    """Sets the global config from a given file path"""
    global GLOBAL_CONFIG

    try:
        with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
            GLOBAL_CONFIG = unwrap(yaml.safe_load(config_file), {})
    except Exception as exc:
        logger.error(
            "The YAML config couldn't load because of the following error: "
            f"\n\n{exc}"
            "\n\nTabbyAPI will start anyway and not parse this config file."
        )
        GLOBAL_CONFIG = {}


def override_config_from_args(args: dict):
    """Overrides the config based on a dict representation of args"""

    config_override = unwrap(args.get("options", {}).get("config"))
    if config_override:
        logger.info("Attempting to override config.yml from args.")
        read_config_from_file(pathlib.Path(config_override))
        return

    # Network config
    network_override = args.get("network")
    if network_override:
        network_config = get_network_config()
        GLOBAL_CONFIG["network"] = {**network_config, **network_override}

    # Model config
    model_override = args.get("model")
    if model_override:
        model_config = get_model_config()
        GLOBAL_CONFIG["model"] = {**model_config, **model_override}

    # Logging config
    logging_override = args.get("logging")
    if logging_override:
        logging_config = get_gen_logging_config()
        GLOBAL_CONFIG["logging"] = {
            **logging_config,
            **{k.replace("log_", ""): logging_override[k] for k in logging_override},
        }


def get_model_config():
    """Returns the model config from the global config"""
    return unwrap(GLOBAL_CONFIG.get("model"), {})


def get_draft_model_config():
    """Returns the draft model config from the global config"""
    model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
    return unwrap(model_config.get("draft"), {})


def get_lora_config():
    """Returns the lora config from the global config"""
    model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
    return unwrap(model_config.get("lora"), {})


def get_network_config():
    """Returns the network config from the global config"""
    return unwrap(GLOBAL_CONFIG.get("network"), {})


def get_gen_logging_config():
    """Returns the generation logging config from the global config"""
    return unwrap(GLOBAL_CONFIG.get("logging"), {})
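A hypothetical startup sketch (not in this diff) showing how the args and config modules are meant to compose: CLI flags override values read from config.yml.

import pathlib

from common import config
from common.args import convert_args_to_dict, init_argparser

parser = init_argparser()
args_dict = convert_args_to_dict(parser.parse_args(), parser)

config.read_config_from_file(pathlib.Path("config.yml"))
config.override_config_from_args(args_dict)
host = config.get_network_config().get("host")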
71 common/gen_logging.py Normal file
@@ -0,0 +1,71 @@
"""
Functions for logging generation events.
"""
from pydantic import BaseModel
from typing import Dict, Optional

from common.logger import init_logger

logger = init_logger(__name__)


class LogPreferences(BaseModel):
    """Logging preference config."""

    prompt: bool = False
    generation_params: bool = False


# Global reference to logging preferences
PREFERENCES = LogPreferences()


def update_from_dict(options_dict: Dict[str, bool]):
    """Wrapper to set the logging config for generations"""
    global PREFERENCES

    # Force bools on the dict. Assign back by key; rebinding the
    # loop variable alone would not mutate the dict.
    for key, value in options_dict.items():
        if value is None:
            options_dict[key] = False

    PREFERENCES = LogPreferences.model_validate(options_dict)


def broadcast_status():
    """Broadcasts the current logging status"""
    enabled = []
    if PREFERENCES.prompt:
        enabled.append("prompts")

    if PREFERENCES.generation_params:
        enabled.append("generation params")

    if len(enabled) > 0:
        logger.info("Generation logging is enabled for: " + ", ".join(enabled))
    else:
        logger.info("Generation logging is disabled")


def log_generation_params(**kwargs):
    """Logs generation parameters to console."""
    if PREFERENCES.generation_params:
        logger.info(f"Generation options: {kwargs}\n")


def log_prompt(prompt: str, negative_prompt: Optional[str]):
    """Logs the prompt to console."""
    if PREFERENCES.prompt:
        formatted_prompt = "\n" + prompt
        logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n")

        if negative_prompt:
            formatted_negative_prompt = "\n" + negative_prompt
            logger.info(f"Negative Prompt: {formatted_negative_prompt}\n")


def log_response(response: str):
    """Logs the response to console."""
    if PREFERENCES.prompt:
        formatted_response = "\n" + response
        logger.info(f"Response: {formatted_response if response else 'Empty'}\n")
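A minimal usage sketch (not in this diff); note that a None value is coerced to False before validation:

from common import gen_logging

gen_logging.update_from_dict({"prompt": True, "generation_params": None})
gen_logging.broadcast_status()  # "Generation logging is enabled for: prompts"
gen_logging.log_prompt("Hello!", negative_prompt=None)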
27 common/generators.py Normal file
@@ -0,0 +1,27 @@
"""Generator functions for the tabbyAPI."""
import inspect
from asyncio import Semaphore
from functools import partialmethod
from typing import AsyncGenerator

generate_semaphore = Semaphore(1)


# Async generation that blocks on a semaphore
async def generate_with_semaphore(generator: AsyncGenerator):
    """Generate with a semaphore."""
    async with generate_semaphore:
        # Call the check on the generator; testing the bare function
        # object is always truthy
        if inspect.isasyncgenfunction(generator):
            async for result in generator():
                yield result
        else:
            for result in generator():
                yield result


# Block a function with semaphore
async def call_with_semaphore(callback: partialmethod):
    """Call a callback, blocking sync callbacks on the semaphore."""
    if inspect.iscoroutinefunction(callback):
        return await callback()
    async with generate_semaphore:
        return callback()
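A hypothetical sketch (not in this diff): because both helpers share `generate_semaphore`, concurrent requests are serialized one generation at a time. The stream below is a stand-in for a real backend:

import asyncio

from common.generators import generate_with_semaphore


async def fake_stream():
    # Placeholder async generator standing in for a backend token stream
    for token in ("Hello", " ", "world"):
        yield token


async def main():
    async for chunk in generate_with_semaphore(fake_stream):
        print(chunk, end="")


asyncio.run(main())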
71 common/logger.py Normal file
@@ -0,0 +1,71 @@
"""
Logging utility.
https://github.com/PygmalionAI/aphrodite-engine/blob/main/aphrodite/common/logger.py
"""

import logging
import sys
import colorlog

_FORMAT = "%(log_color)s%(levelname)s: %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"


class ColoredFormatter(colorlog.ColoredFormatter):
    """Adds logging prefix to newlines to align multi-line messages."""

    def __init__(self, fmt, datefmt=None, log_colors=None, reset=True, style="%"):
        super().__init__(
            fmt, datefmt=datefmt, log_colors=log_colors, reset=reset, style=style
        )

    def format(self, record):
        msg = super().format(record)
        if record.message != "":
            parts = msg.split(record.message)
            msg = msg.replace("\n", "\r\n" + parts[0])
        return msg


_root_logger = logging.getLogger("aphrodite")
_default_handler = None


def _setup_logger():
    _root_logger.setLevel(logging.DEBUG)
    global _default_handler
    if _default_handler is None:
        _default_handler = logging.StreamHandler(sys.stdout)
        _default_handler.flush = sys.stdout.flush  # type: ignore
        _default_handler.setLevel(logging.INFO)
        _root_logger.addHandler(_default_handler)
    fmt = ColoredFormatter(
        _FORMAT,
        datefmt=_DATE_FORMAT,
        log_colors={
            "DEBUG": "cyan",
            "INFO": "green",
            "WARNING": "yellow",
            "ERROR": "red",
            "CRITICAL": "red,bg_white",
        },
        reset=True,
    )
    _default_handler.setFormatter(fmt)
    # Setting this will avoid the message
    # being propagated to the parent logger.
    _root_logger.propagate = False


# The logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
_setup_logger()


def init_logger(name: str):
    """Creates a named logger that shares the default colored handler."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(_default_handler)
    logger.propagate = False
    return logger
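A minimal usage sketch (not in this diff):

from common.logger import init_logger

logger = init_logger(__name__)
# ColoredFormatter re-prefixes continuation lines, so both lines
# below stay aligned under the level prefix
logger.info("first line\nsecond line")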
104 common/templating.py Normal file
@@ -0,0 +1,104 @@
"""Small replication of AutoTokenizer's chat template system for efficiency"""
import json
import pathlib
from functools import lru_cache
from importlib.metadata import version as package_version
from jinja2 import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment
from packaging import version
from pydantic import BaseModel
from typing import Dict, Optional


class PromptTemplate(BaseModel):
    """A template for chat completion prompts."""

    name: str
    template: str


def get_prompt_from_template(
    messages,
    prompt_template: PromptTemplate,
    add_generation_prompt: bool,
    special_tokens: Optional[Dict[str, str]] = None,
):
    """Get a prompt from a template and a list of messages."""
    if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
        raise ImportError(
            "Parsing these chat completion messages requires jinja2 3.0.0 "
            f"or greater. Current version: {package_version('jinja2')}\n"
            "Please upgrade jinja by running the following command: "
            "pip install --upgrade jinja2"
        )

    compiled_template = _compile_template(prompt_template.template)
    return compiled_template.render(
        messages=messages,
        add_generation_prompt=add_generation_prompt,
        **(special_tokens or {}),  # Guard: special_tokens defaults to None
    )


# Inspired from
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
@lru_cache
def _compile_template(template: str):
    """Compiles a Jinja2 template"""

    # Exception handler
    def raise_exception(message):
        raise TemplateError(message)

    jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    jinja_env.globals["raise_exception"] = raise_exception

    jinja_template = jinja_env.from_string(template)
    return jinja_template


def get_all_templates():
    """Fetches all templates from the templates directory"""

    template_directory = pathlib.Path("templates")
    return template_directory.glob("*.jinja")


def find_template_from_model(model_path: pathlib.Path):
    """Find a matching template name from a model path."""
    model_name = model_path.name
    template_files = get_all_templates()
    for filepath in template_files:
        template_name = filepath.stem.lower()

        # Check if the template name is present in the model name
        if template_name in model_name.lower():
            return template_name

    return None


def get_template_from_file(prompt_template_name: str):
    """Get a template from a jinja file."""
    template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
    if template_path.exists():
        with open(template_path, "r", encoding="utf8") as raw_template:
            return PromptTemplate(
                name=prompt_template_name, template=raw_template.read()
            )

    return None


def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
    """Get a template from a JSON file. Requires a key and template name"""
    if json_path.exists():
        with open(json_path, "r", encoding="utf8") as config_file:
            model_config = json.load(config_file)
            chat_template = model_config.get(key)
            if chat_template:
                return PromptTemplate(name=name, template=chat_template)

    return None
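A hypothetical rendering sketch (not in this diff); the inline template string is invented for illustration:

from common.templating import PromptTemplate, get_prompt_from_template

template = PromptTemplate(
    name="demo",
    template=(
        "{% for message in messages %}"
        "{{ message['role'] }}: {{ message['content'] }}\n"
        "{% endfor %}"
        "{% if add_generation_prompt %}assistant:{% endif %}"
    ),
)
prompt = get_prompt_from_template(
    [{"role": "user", "content": "Hi"}], template, add_generation_prompt=True
)
# prompt == "user: Hi\nassistant:"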
58 common/utils.py Normal file
@@ -0,0 +1,58 @@
"""Common utilities for the tabbyAPI"""
import traceback
from typing import Optional

from pydantic import BaseModel

from common.logger import init_logger

logger = init_logger(__name__)


def load_progress(module, modules):
    """Wrapper callback for load progress."""
    yield module, modules


class TabbyGeneratorErrorMessage(BaseModel):
    """Common error types."""

    message: str
    trace: Optional[str] = None


class TabbyGeneratorError(BaseModel):
    """Common error types."""

    error: TabbyGeneratorErrorMessage


def get_generator_error(message: str):
    """Get a generator error."""
    error_message = TabbyGeneratorErrorMessage(
        message=message, trace=traceback.format_exc()
    )

    generator_error = TabbyGeneratorError(error=error_message)

    # Log and send the exception
    logger.error(generator_error.error.message)
    return get_sse_packet(generator_error.model_dump_json())


def get_sse_packet(json_data: str):
    """Get an SSE packet."""
    return f"data: {json_data}\n\n"


def unwrap(wrapped, default=None):
    """Unwrap function for Optionals."""
    if wrapped is None:
        return default

    return wrapped


def coalesce(*args):
    """Coalesce function for multiple unwraps."""
    return next((arg for arg in args if arg is not None), None)
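A minimal usage sketch (not in this diff):

from common.utils import coalesce, get_sse_packet, unwrap

unwrap(None, 42)          # -> 42
unwrap("value", 42)       # -> "value"
coalesce(None, None, 3)   # -> 3
get_sse_packet('{"ok": true}')  # -> 'data: {"ok": true}\n\n'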