Tree: Refactor code organization

Move common functions into their own folder and refactor the backends
to use their own folder as well.

Also clean up imports and alphabetize the import statements themselves.

Finally, move colab and docker into their own folders as well.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-01-18 00:42:52 -05:00 committed by Brian Dashore
parent ee99349a78
commit 78f920eeda
22 changed files with 41 additions and 42 deletions

127
common/args.py Normal file
View file

@ -0,0 +1,127 @@
"""Argparser for overriding config values"""
import argparse
def str_to_bool(value):
    """Parse a textual flag into a boolean.

    Matching is case-insensitive. "true", "t", "1", "yes" and "y" map to
    True; "false", "f", "0", "no" and "n" map to False.

    Raises:
        ValueError: if the text is not one of the recognised spellings.
    """

    lowered = value.lower()

    if lowered in {"true", "t", "1", "yes", "y"}:
        return True

    if lowered in {"false", "f", "0", "no", "n"}:
        return False

    raise ValueError(f"{value} is not a valid boolean value")
def init_argparser():
    """Build the command-line parser used to override config values.

    Returns:
        An argparse.ArgumentParser with the network, model, logging and
        config arguments registered, in that order.
    """

    epilog_text = (
        "These args are only for a subset of the config. "
        + "Please edit config.yml for all options!"
    )
    parser = argparse.ArgumentParser(epilog=epilog_text)

    # Keep the original registration order of the argument groups
    registrars = (
        add_network_args,
        add_model_args,
        add_logging_args,
        add_config_args,
    )
    for register in registrars:
        register(parser)

    return parser
def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser):
    """Group parsed argument values by the title of their argument group.

    Args:
        args: Namespace produced by parsing with ``parser``.
        parser: The parser whose action groups define the grouping.

    Returns:
        A dict mapping each group title to a dict of ``dest -> value``.
        Values that are None (or missing from ``args``) are left out, so a
        group with nothing set maps to an empty dict.
    """

    grouped = {}

    for action_group in parser._action_groups:
        provided = {}

        for action in action_group._group_actions:
            parsed_value = getattr(args, action.dest, None)
            if parsed_value is None:
                continue
            provided[action.dest] = parsed_value

        grouped[action_group.title] = provided

    return grouped
def add_config_args(parser: argparse.ArgumentParser):
    """Register the ``--config`` argument directly on the parser.

    Args:
        parser: Parser that receives the argument.
    """

    config_options = {
        "type": str,
        "help": "Path to an overriding config.yml file",
    }
    parser.add_argument("--config", **config_options)
def add_network_args(parser: argparse.ArgumentParser):
    """Register networking arguments under a "network" argument group.

    Adds ``--host`` (str), ``--port`` (int) and ``--disable-auth`` (parsed
    with ``str_to_bool``).

    Args:
        parser: Parser that receives the argument group.
    """

    network_group = parser.add_argument_group("network")
    network_group.add_argument("--host", type=str, help="The IP to host on")
    network_group.add_argument("--port", type=int, help="The port to host on")
    network_group.add_argument(
        "--disable-auth",
        type=str_to_bool,
        # Help text spelling corrected ("authentication")
        help="Disable HTTP token authentication with requests",
    )
def add_model_args(parser: argparse.ArgumentParser):
    """Register model arguments under a "model" argument group.

    Boolean-like flags are parsed with ``str_to_bool``; numeric options use
    ``int`` or ``float`` converters as declared on each argument.

    Args:
        parser: Parser that receives the argument group.
    """

    model_group = parser.add_argument_group("model")
    model_group.add_argument(
        "--model-dir", type=str, help="Overrides the directory to look for models"
    )
    model_group.add_argument("--model-name", type=str, help="An initial model to load")
    model_group.add_argument(
        "--max-seq-len", type=int, help="Override the maximum model sequence length"
    )
    model_group.add_argument(
        "--override-base-seq-len",
        type=str_to_bool,
        help="Overrides base model context length",
    )
    model_group.add_argument(
        "--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb"
    )
    model_group.add_argument("--rope-alpha", type=float, help="Sets rope_alpha for NTK")
    model_group.add_argument(
        "--prompt-template",
        type=str,
        help="Set the prompt template for chat completions",
    )
    model_group.add_argument(
        "--gpu-split-auto",
        type=str_to_bool,
        help="Automatically allocate resources to GPUs",
    )
    model_group.add_argument(
        "--gpu-split",
        type=float,
        nargs="+",
        # Values are parsed as floats, so the help no longer claims "integer"
        help="An array of GBs of vram to split between GPUs "
        + "(decimal values are accepted). "
        + "Ignored if gpu_split_auto is true",
    )
    model_group.add_argument(
        "--num-experts-per-token",
        type=int,
        help="Number of experts to use per token in MoE models",
    )
    model_group.add_argument(
        "--use-cfg",
        type=str_to_bool,
        help="Enables CFG support",
    )
def add_logging_args(parser: argparse.ArgumentParser):
    """Register generation-logging arguments under a "logging" group.

    Both flags are parsed with ``str_to_bool``.

    Args:
        parser: Parser that receives the argument group.
    """

    logging_group = parser.add_argument_group("logging")

    flag_help_pairs = (
        ("--log-prompt", "Enable prompt logging"),
        ("--log-generation-params", "Enable generation parameter logging"),
    )
    for flag, help_text in flag_help_pairs:
        logging_group.add_argument(flag, type=str_to_bool, help=help_text)

127
common/auth.py Normal file
View file

@ -0,0 +1,127 @@
"""
This method of authorization is pretty insecure, but since TabbyAPI is a local
application, it should be fine.
"""
import secrets
import yaml
from fastapi import Header, HTTPException
from pydantic import BaseModel
from typing import Optional
from common.logger import init_logger
# Module-level logger for authentication messages
logger = init_logger(__name__)
class AuthKeys(BaseModel):
    """
    This class represents the authentication keys for the application.
    It contains two types of keys: 'api_key' and 'admin_key'.
    The 'api_key' is used for general API calls, while the 'admin_key'
    is used for administrative tasks. The class also provides a method
    to verify if a given key matches the stored 'api_key' or 'admin_key'.
    """

    api_key: str
    admin_key: str

    def verify_key(self, test_key: str, key_type: str):
        """Verify if a given key matches the stored key.

        Comparisons use ``secrets.compare_digest`` so the time taken does
        not depend on how many leading characters match.

        Args:
            test_key: The key supplied by the caller.
            key_type: Either "admin_key" or "api_key".

        Returns:
            True when ``test_key`` is accepted for ``key_type``. An admin
            key is also accepted where an API key is required. Any other
            ``key_type`` (or a non-string ``test_key``) yields False.
        """

        if not isinstance(test_key, str):
            return False

        candidate = test_key.encode("utf8")

        def matches(stored_key: str) -> bool:
            # Compare bytes: compare_digest rejects non-ASCII str arguments
            return secrets.compare_digest(candidate, stored_key.encode("utf8"))

        if key_type == "admin_key":
            return matches(self.admin_key)

        if key_type == "api_key":
            # Admin keys are valid for all API calls.
            # Evaluate both comparisons before combining them so the result
            # is not short-circuited on the first match.
            api_match = matches(self.api_key)
            admin_match = matches(self.admin_key)
            return api_match or admin_match

        return False
# Keys used by the check_* functions; stays None until load_auth_keys sets it
AUTH_KEYS: Optional[AuthKeys] = None
# When True, check_api_key and check_admin_key allow every request
DISABLE_AUTH: bool = False
def load_auth_keys(disable_from_config: bool):
    """Load the authentication keys from api_tokens.yml. If the file does not
    exist, generate new keys and save them to api_tokens.yml.

    Args:
        disable_from_config: When True, authentication is disabled, a
            warning is logged and no keys are loaded or generated.

    Side effects:
        Sets the module-level ``DISABLE_AUTH`` and ``AUTH_KEYS``, may write
        api_tokens.yml, and logs the active keys.
    """

    global AUTH_KEYS
    global DISABLE_AUTH

    DISABLE_AUTH = disable_from_config
    if disable_from_config:
        logger.warning(
            "Disabling authentication makes your instance vulnerable. "
            "Set the `disable_auth` flag to False in config.yml if you "
            "want to share this instance with others."
        )
        return

    try:
        with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
            auth_keys_dict = yaml.safe_load(auth_file)
            AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
    except FileNotFoundError:
        # Only a missing file triggers key generation. Other I/O errors
        # (e.g. permissions) propagate instead of overwriting existing keys.
        new_auth_keys = AuthKeys(
            api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16)
        )
        AUTH_KEYS = new_auth_keys

        with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
            yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)

    logger.info(
        f"Your API key is: {AUTH_KEYS.api_key}\n"
        f"Your admin key is: {AUTH_KEYS.admin_key}\n\n"
        "If these keys get compromised, make sure to delete api_tokens.yml "
        "and restart the server. Have fun!"
    )
def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)):
    """Validate the API key supplied with a request.

    ``x_api_key`` is checked first; otherwise ``authorization`` must be of
    the form "Bearer <key>". Admin keys are accepted as well (see
    ``AuthKeys.verify_key``).

    Returns:
        None when auth is disabled, otherwise the header value that passed.

    Raises:
        HTTPException: 401 when no key is given or the key is invalid.
    """

    # Allow request if auth is disabled
    if DISABLE_AUTH:
        return

    if x_api_key:
        if AUTH_KEYS.verify_key(x_api_key, "api_key"):
            return x_api_key
        raise HTTPException(401, "Invalid API key")

    if not authorization:
        raise HTTPException(401, "Please provide an API key")

    pieces = authorization.split(" ")
    if len(pieces) < 2:
        raise HTTPException(401, "Invalid API key")

    scheme, token = pieces[0], pieces[1]
    is_bearer = scheme.lower() == "bearer"
    if not (is_bearer and AUTH_KEYS.verify_key(token, "api_key")):
        raise HTTPException(401, "Invalid API key")

    return authorization
def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)):
    """Validate the admin key supplied with a request.

    ``x_admin_key`` is checked first; otherwise ``authorization`` must be
    of the form "Bearer <key>". Only the admin key is accepted.

    Returns:
        None when auth is disabled, otherwise the header value that passed.

    Raises:
        HTTPException: 401 when no key is given or the key is invalid.
    """

    # Allow request if auth is disabled
    if DISABLE_AUTH:
        return

    if x_admin_key:
        if AUTH_KEYS.verify_key(x_admin_key, "admin_key"):
            return x_admin_key
        raise HTTPException(401, "Invalid admin key")

    if not authorization:
        raise HTTPException(401, "Please provide an admin key")

    pieces = authorization.split(" ")
    if len(pieces) < 2:
        raise HTTPException(401, "Invalid admin key")

    scheme, token = pieces[0], pieces[1]
    is_bearer = scheme.lower() == "bearer"
    if not (is_bearer and AUTH_KEYS.verify_key(token, "admin_key")):
        raise HTTPException(401, "Invalid admin key")

    return authorization

83
common/config.py Normal file
View file

@ -0,0 +1,83 @@
import yaml
import pathlib
from common.logger import init_logger
from common.utils import unwrap
# Module-level logger for config loading messages
logger = init_logger(__name__)
# Parsed config shared by the functions in this module; rebound by
# read_config_from_file and updated in place by override_config_from_args
GLOBAL_CONFIG: dict = {}
def read_config_from_file(config_path: pathlib.Path):
    """Load a YAML file into the module-level ``GLOBAL_CONFIG``.

    Any exception while resolving, opening or parsing the file is logged
    and the config falls back to an empty dict, so startup can continue.

    Args:
        config_path: Location of the YAML config file.
    """

    global GLOBAL_CONFIG

    try:
        resolved_path = str(config_path.resolve())
        with open(resolved_path, "r", encoding="utf8") as config_file:
            parsed_yaml = yaml.safe_load(config_file)
        # An empty file parses to None; treat it as an empty config
        GLOBAL_CONFIG = unwrap(parsed_yaml, {})
    except Exception as exc:
        failure_message = (
            "The YAML config couldn't load because of the following error: "
            f"\n\n{exc}"
            "\n\nTabbyAPI will start anyway and not parse this config file."
        )
        logger.error(failure_message)
        GLOBAL_CONFIG = {}
def override_config_from_args(args: dict):
    """Overrides the config based on a dict representation of args.

    Args:
        args: Mapping of argument-group title to a dict of parsed values
            (the shape produced by grouping argparse results by group).

    Behavior:
        If a ``config`` path is present in the parser's default option
        group, that file replaces the global config and no other overrides
        are applied. Otherwise the "network", "model" and "logging" groups
        are merged over the matching sections of ``GLOBAL_CONFIG``.
    """

    # argparse titles its default option group "options" on Python 3.10+
    # and "optional arguments" on earlier versions; accept either.
    default_group = args.get("options") or args.get("optional arguments") or {}
    config_override = default_group.get("config")
    if config_override:
        logger.info("Attempting to override config.yml from args.")
        read_config_from_file(pathlib.Path(config_override))
        return

    # Network config
    network_override = args.get("network")
    if network_override:
        network_config = get_network_config()
        GLOBAL_CONFIG["network"] = {**network_config, **network_override}

    # Model config
    model_override = args.get("model")
    if model_override:
        model_config = get_model_config()
        GLOBAL_CONFIG["model"] = {**model_config, **model_override}

    # Logging config
    logging_override = args.get("logging")
    if logging_override:
        logging_config = get_gen_logging_config()
        prefix = "log_"
        # Strip only a leading "log_" so the rest of the key is untouched
        renamed_override = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in logging_override.items()
        }
        GLOBAL_CONFIG["logging"] = {**logging_config, **renamed_override}
def get_model_config():
    """Return the "model" section of the global config, or {} if unset."""

    model_section = GLOBAL_CONFIG.get("model")
    return unwrap(model_section, {})
def get_draft_model_config():
    """Return the "draft" sub-section of the model config, or {} if unset."""

    model_section = unwrap(GLOBAL_CONFIG.get("model"), {})
    draft_section = model_section.get("draft")
    return unwrap(draft_section, {})
def get_lora_config():
    """Return the "lora" sub-section of the model config, or {} if unset."""

    model_section = unwrap(GLOBAL_CONFIG.get("model"), {})
    lora_section = model_section.get("lora")
    return unwrap(lora_section, {})
def get_network_config():
    """Return the "network" section of the global config, or {} if unset."""

    network_section = GLOBAL_CONFIG.get("network")
    return unwrap(network_section, {})
def get_gen_logging_config():
    """Return the "logging" section of the global config, or {} if unset."""

    logging_section = GLOBAL_CONFIG.get("logging")
    return unwrap(logging_section, {})

71
common/gen_logging.py Normal file
View file

@ -0,0 +1,71 @@
"""
Functions for logging generation events.
"""
from pydantic import BaseModel
from typing import Dict, Optional
from common.logger import init_logger
# Module-level logger used to emit generation log messages
logger = init_logger(__name__)
class LogPreferences(BaseModel):
    """Switches that control which generation details are logged."""

    # Gates log_prompt and log_response
    prompt: bool = False

    # Gates log_generation_params
    generation_params: bool = False
# Global reference to logging preferences.
# Starts with everything disabled; replaced by update_from_dict.
PREFERENCES = LogPreferences()
def update_from_dict(options_dict: Dict[str, bool]):
    """Wrapper to set the logging config for generations.

    Args:
        options_dict: Preference names mapped to their values. A value of
            None is treated as False. The dict itself is not modified.

    Side effects:
        Replaces the module-level ``PREFERENCES`` with a model validated
        from the cleaned options.
    """

    global PREFERENCES

    # Force bools on the dict. A new mapping is built because rebinding a
    # loop variable would leave the None values in place.
    coerced_options = {
        key: (False if value is None else value)
        for key, value in options_dict.items()
    }

    PREFERENCES = LogPreferences.model_validate(coerced_options)
def broadcast_status():
    """Log which generation logging categories are currently enabled."""

    category_states = (
        (PREFERENCES.prompt, "prompts"),
        (PREFERENCES.generation_params, "generation params"),
    )
    enabled = [label for is_enabled, label in category_states if is_enabled]

    if not enabled:
        logger.info("Generation logging is disabled")
        return

    logger.info("Generation logging is enabled for: " + ", ".join(enabled))
def log_generation_params(**kwargs):
    """Log the given keyword arguments when parameter logging is enabled."""

    if not PREFERENCES.generation_params:
        return

    logger.info(f"Generation options: {kwargs}\n")
def log_prompt(prompt: str, negative_prompt: Optional[str]):
    """Log the prompt, and the negative prompt if given, when enabled.

    An empty prompt is reported as "Empty". Nothing is logged unless the
    ``prompt`` preference is on.
    """

    if not PREFERENCES.prompt:
        return

    prompt_on_new_line = "\n" + prompt
    shown_prompt = prompt_on_new_line if prompt else "Empty"
    logger.info(f"Prompt: {shown_prompt}\n")

    if not negative_prompt:
        return

    negative_on_new_line = "\n" + negative_prompt
    logger.info(f"Negative Prompt: {negative_on_new_line}\n")
def log_response(response: str):
    """Log the response when the ``prompt`` preference is on.

    An empty response is reported as "Empty".
    """

    if not PREFERENCES.prompt:
        return

    response_on_new_line = "\n" + response
    shown_response = response_on_new_line if response else "Empty"
    logger.info(f"Response: {shown_response}\n")

27
common/generators.py Normal file
View file

@ -0,0 +1,27 @@
"""Generator functions for the tabbyAPI."""
import inspect
from asyncio import Semaphore
from functools import partialmethod
from typing import AsyncGenerator
# Single-slot semaphore shared by the helpers in this module
generate_semaphore = Semaphore(1)


# Async generation that blocks on a semaphore
async def generate_with_semaphore(generator: AsyncGenerator):
    """Generate with a semaphore.

    Args:
        generator: Zero-argument callable returning either an async
            iterable or a regular iterable of results.

    Yields:
        Each result produced by ``generator()``. The semaphore is held for
        as long as iteration is in progress.
    """

    async with generate_semaphore:
        stream = generator()

        # Inspect the returned object rather than the callable so wrapped
        # callables are classified by what they actually produce.
        if hasattr(stream, "__aiter__"):
            async for result in stream:
                yield result
        else:
            for result in stream:
                yield result
# Block a function with semaphore
async def call_with_semaphore(callback: partialmethod):
    """Call ``callback`` while holding the generation semaphore.

    Args:
        callback: Zero-argument callable; may be a coroutine function.

    Returns:
        The callback's return value, awaited first when the callback is a
        coroutine function.
    """

    # Acquire the semaphore for both kinds of callback so coroutine
    # callbacks are serialized as well.
    # NOTE(review): a coroutine callback that itself acquires
    # generate_semaphore would now wait on itself - confirm no caller does.
    async with generate_semaphore:
        if inspect.iscoroutinefunction(callback):
            return await callback()
        return callback()

71
common/logger.py Normal file
View file

@ -0,0 +1,71 @@
"""
Logging utility.
https://github.com/PygmalionAI/aphrodite-engine/blob/main/aphrodite/common/logger.py
"""
import logging
import sys
import colorlog
# Log line layout: colorized level name followed by the message
_FORMAT = "%(log_color)s%(levelname)s: %(message)s"
# Passed to the formatter as datefmt; note _FORMAT contains no %(asctime)s
_DATE_FORMAT = "%m-%d %H:%M:%S"
class ColoredFormatter(colorlog.ColoredFormatter):
    """Formatter that repeats the line prefix on every line of a message.

    Continuation lines of a multi-line message are given the same prefix as
    the first line so they align.
    """

    def __init__(self, fmt, datefmt=None, log_colors=None, reset=True, style="%"):
        super().__init__(
            fmt,
            datefmt=datefmt,
            log_colors=log_colors,
            reset=reset,
            style=style,
        )

    def format(self, record):
        """Format ``record`` and prefix each continuation line."""

        rendered = super().format(record)

        if record.message == "":
            return rendered

        # Everything before the first occurrence of the message text is
        # treated as the prefix to repeat after each newline.
        prefix, _, _ = rendered.partition(record.message)
        return rendered.replace("\n", "\r\n" + prefix)
# Logger configured by _setup_logger (the "aphrodite" name matches the
# project this module credits in its header)
_root_logger = logging.getLogger("aphrodite")
# Shared stream handler; created by _setup_logger and attached by init_logger
_default_handler = None
def _setup_logger():
    """Configure the root logger and the shared stdout handler.

    Creates the handler on first use, then (re)applies the colored
    formatter and stops propagation to the parent logger.
    """

    global _default_handler

    _root_logger.setLevel(logging.DEBUG)

    if _default_handler is None:
        _default_handler = stream_handler = logging.StreamHandler(sys.stdout)
        stream_handler.flush = sys.stdout.flush  # type: ignore
        stream_handler.setLevel(logging.INFO)
        _root_logger.addHandler(stream_handler)

    level_colors = {
        "DEBUG": "cyan",
        "INFO": "green",
        "WARNING": "yellow",
        "ERROR": "red",
        "CRITICAL": "red,bg_white",
    }
    _default_handler.setFormatter(
        ColoredFormatter(
            _FORMAT,
            datefmt=_DATE_FORMAT,
            log_colors=level_colors,
            reset=True,
        )
    )

    # Setting this will avoid the message
    # being propagated to the parent logger.
    _root_logger.propagate = False
# The logger is initialized when the module is imported.
# Python runs a module's top-level code once per process (later imports
# reuse the cached module), so this setup call is not repeated.
_setup_logger()
def init_logger(name: str):
    """Return a named logger wired to the shared handler.

    The logger is set to DEBUG, gets the module's default handler attached
    and does not propagate to its parent.

    Args:
        name: Logger name, typically the calling module's ``__name__``.
    """

    named_logger = logging.getLogger(name)
    named_logger.propagate = False
    named_logger.addHandler(_default_handler)
    named_logger.setLevel(logging.DEBUG)
    return named_logger

104
common/templating.py Normal file
View file

@ -0,0 +1,104 @@
"""Small replication of AutoTokenizer's chat template system for efficiency"""
import json
import pathlib
from functools import lru_cache
from importlib.metadata import version as package_version
from jinja2 import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment
from packaging import version
from pydantic import BaseModel
from typing import Optional, Dict
class PromptTemplate(BaseModel):
    """A named chat-completion prompt template."""

    # Identifier for the template
    name: str

    # Jinja source text, compiled when a prompt is rendered
    template: str
def get_prompt_from_template(
    messages,
    prompt_template: PromptTemplate,
    add_generation_prompt: bool,
    special_tokens: Optional[Dict[str, str]] = None,
):
    """Get a prompt from a template and a list of messages.

    Args:
        messages: Passed to the template as ``messages``.
        prompt_template: Template whose ``template`` text is compiled and
            rendered.
        add_generation_prompt: Passed to the template unchanged.
        special_tokens: Optional extra template variables; omitted or None
            means no extra variables.

    Returns:
        The rendered prompt string.

    Raises:
        ImportError: if the installed jinja2 is older than 3.0.0.
    """

    if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
        raise ImportError(
            "Parsing these chat completion messages requires jinja2 3.0.0 "
            f"or greater. Current version: {package_version('jinja2')}\n"
            "Please upgrade jinja by running the following command: "
            "pip install --upgrade jinja2"
        )

    compiled_template = _compile_template(prompt_template.template)

    # special_tokens defaults to None, which cannot be **-unpacked
    extra_variables = special_tokens or {}

    return compiled_template.render(
        messages=messages,
        add_generation_prompt=add_generation_prompt,
        **extra_variables,
    )
# Inspired from
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
@lru_cache
def _compile_template(template: str):
    """Compile Jinja2 source text into a template, cached per source text.

    The template runs in an immutable sandboxed environment and can call
    ``raise_exception(message)`` to raise a TemplateError.
    """

    # Exception handler
    def raise_exception(message):
        raise TemplateError(message)

    environment = ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
    )
    environment.globals["raise_exception"] = raise_exception

    return environment.from_string(template)
def get_all_templates():
    """Return an iterator over the ``*.jinja`` files in ``templates/``.

    The directory is relative to the current working directory.
    """

    return pathlib.Path("templates").glob("*.jinja")
def find_template_from_model(model_path: pathlib.Path):
    """Find a template whose name appears in the model's directory name.

    Comparison is case-insensitive.

    Args:
        model_path: Path to the model; only its final component is used.

    Returns:
        The lowercased stem of the first matching template file, or None
        when no template matches.
    """

    lowered_model_name = model_path.name.lower()
    lowered_stems = (filepath.stem.lower() for filepath in get_all_templates())

    # First stem contained in the model name wins
    return next(
        (stem for stem in lowered_stems if stem in lowered_model_name),
        None,
    )
def get_template_from_file(prompt_template_name: str):
    """Load ``templates/<name>.jinja`` as a PromptTemplate.

    Args:
        prompt_template_name: Template file name without the extension.

    Returns:
        A PromptTemplate holding the file's text, or None if the file does
        not exist.
    """

    template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
    if not template_path.exists():
        return None

    with open(template_path, "r", encoding="utf8") as raw_template:
        template_text = raw_template.read()

    return PromptTemplate(name=prompt_template_name, template=template_text)
# Get a template from a JSON file
# Requires a key and template name
def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
    """Read a template string stored under ``key`` in a JSON file.

    Args:
        json_path: JSON file to read.
        key: Top-level key whose value is the template text.
        name: Name given to the resulting PromptTemplate.

    Returns:
        A PromptTemplate, or None when the file does not exist or the key
        is missing or empty.
    """

    if not json_path.exists():
        return None

    with open(json_path, "r", encoding="utf8") as config_file:
        model_config = json.load(config_file)

    chat_template = model_config.get(key)
    if not chat_template:
        return None

    return PromptTemplate(name=name, template=chat_template)

58
common/utils.py Normal file
View file

@ -0,0 +1,58 @@
"""Common utilities for the tabbyAPI"""
import traceback
from typing import Optional
from pydantic import BaseModel
from common.logger import init_logger
# Module-level logger; used by get_generator_error to report errors
logger = init_logger(__name__)
def load_progress(module, modules):
    """Generator that yields its two arguments once, as a tuple."""

    progress = (module, modules)
    yield progress
class TabbyGeneratorErrorMessage(BaseModel):
    """Body of a generator error: a message and an optional trace."""

    # Error text
    message: str

    # Formatted traceback text, when one is supplied
    trace: Optional[str] = None
class TabbyGeneratorError(BaseModel):
    """Envelope that nests the error body under an ``error`` key."""

    # The wrapped message and trace
    error: TabbyGeneratorErrorMessage
def get_generator_error(message: str):
    """Build an SSE packet describing an error and log the message.

    The packet carries ``message`` together with the output of
    ``traceback.format_exc()`` at the time of the call.

    Args:
        message: Error text to log and send.

    Returns:
        The serialized error wrapped as an SSE data packet.
    """

    current_trace = traceback.format_exc()
    generator_error = TabbyGeneratorError(
        error=TabbyGeneratorErrorMessage(message=message, trace=current_trace)
    )

    # Log and send the exception
    logger.error(generator_error.error.message)

    serialized_error = generator_error.model_dump_json()
    return get_sse_packet(serialized_error)
def get_sse_packet(json_data: str):
    """Wrap ``json_data`` as a ``data:`` line followed by a blank line."""

    return "data: {}\n\n".format(json_data)
def unwrap(wrapped, default=None):
    """Return ``wrapped`` unless it is None, in which case return ``default``.

    Only None triggers the fallback; other falsy values are returned as-is.
    """

    return default if wrapped is None else wrapped
def coalesce(*args):
    """Return the first argument that is not None, or None if there is none."""

    for candidate in args:
        if candidate is not None:
            return candidate

    return None