diff --git a/common/config.py b/common/config.py
deleted file mode 100644
index b1b251b..0000000
--- a/common/config.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import yaml
-import pathlib
-from loguru import logger
-from typing import Any
-
-from common.utils import unwrap, merge_dicts
-
-# Global config dictionary constant
-GLOBAL_CONFIG: dict = {}
-
-
-def load(arguments: dict[str, Any]):
-    """load the global application config"""
-    global GLOBAL_CONFIG
-
-    # config is applied in order of items in the list
-    configs = [
-        from_file(pathlib.Path("config.yml")),
-        from_environment(),
-        from_args(arguments),
-    ]
-
-    GLOBAL_CONFIG = merge_dicts(*configs)
-
-
-def from_file(config_path: pathlib.Path) -> dict[str, Any]:
-    """loads config from a given file path"""
-
-    # try loading from file
-    try:
-        with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
-            return unwrap(yaml.safe_load(config_file), {})
-    except FileNotFoundError:
-        logger.info("The config.yml file cannot be found")
-    except Exception as exc:
-        logger.error(
-            f"The YAML config couldn't load because of the following error:\n\n{exc}"
-        )
-
-    # if no config file was loaded
-    return {}
-
-
-def from_args(args: dict[str, Any]) -> dict[str, Any]:
-    """loads config from the provided arguments"""
-    config = {}
-
-    config_override = unwrap(args.get("options", {}).get("config"))
-    if config_override:
-        logger.info("Config file override detected in args.")
-        config = from_file(pathlib.Path(config_override))
-        return config  # Return early if loading from file
-
-    for key in ["network", "model", "logging", "developer", "embeddings"]:
-        override = args.get(key)
-        if override:
-            if key == "logging":
-                # Strip the "log_" prefix from logging keys if present
-                override = {k.replace("log_", ""): v for k, v in override.items()}
-            config[key] = override
-
-    return config
-
-
-def from_environment() -> dict[str, Any]:
-    """loads configuration from environment variables"""
-
-    # TODO: load config from environment variables
-    # this means that we can have host default to 0.0.0.0 in docker for example
-    # this would also mean that docker containers no longer require a non
-    # default config file to be used
-    return {}
-
-
-# refactor the get_config functions
-def get_config(config: dict[str, any], topic: str) -> callable:
-    return lambda: unwrap(config.get(topic), {})
-
-
-# each of these is a function
-model_config = get_config(GLOBAL_CONFIG, "model")
-sampling_config = get_config(GLOBAL_CONFIG, "sampling")
-draft_model_config = get_config(model_config(), "draft")
-lora_config = get_config(model_config(), "lora")
-network_config = get_config(GLOBAL_CONFIG, "network")
-logging_config = get_config(GLOBAL_CONFIG, "logging")
-developer_config = get_config(GLOBAL_CONFIG, "developer")
-embeddings_config = get_config(GLOBAL_CONFIG, "embeddings")
diff --git a/common/downloader.py b/common/downloader.py
index b9e1b72..b0a8d93 100644
--- a/common/downloader.py
+++ b/common/downloader.py
@@ -10,8 +10,8 @@ from loguru import logger
 from rich.progress import Progress
 from typing import List, Optional
 
-from common.config import lora_config, model_config
 from common.logger import get_progress_bar
+from common.tabby_config import config
 from common.utils import unwrap
 
 
@@ -76,9 +76,9 @@ def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str
     """Gets the download folder for the repo."""
 
     if repo_type == "lora":
-        download_path = pathlib.Path(lora_config().get("lora_dir") or "loras")
+        download_path = pathlib.Path(config.lora.get("lora_dir") or "loras")
     else:
-        download_path = pathlib.Path(model_config().get("model_dir") or "models")
+        download_path = pathlib.Path(config.model.get("model_dir") or "models")
 
     download_path = download_path / (folder_name or repo_id.split("/")[-1])
     return download_path
diff --git a/common/model.py b/common/model.py
index 97bac05..a9ddfff 100644
--- a/common/model.py
+++ b/common/model.py
@@ -10,9 +10,9 @@ from fastapi import HTTPException
 from loguru import logger
 from typing import Optional
 
-from common import config
 from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
+from common.tabby_config import config
 from common.utils import unwrap
 from endpoints.utils import do_export_openapi
 
@@ -153,8 +153,7 @@ async def unload_embedding_model():
 
 def get_config_default(key: str, model_type: str = "model"):
     """Fetches a default value from model config if allowed by the user."""
-    model_config = config.model_config()
-    default_keys = unwrap(model_config.get("use_as_default"), [])
+    default_keys = unwrap(config.model.get("use_as_default"), [])
 
     # Add extra keys to defaults
     default_keys.append("embeddings_device")
@@ -162,13 +161,11 @@ def get_config_default(key: str, model_type: str = "model"):
     if key in default_keys:
         # Is this a draft model load parameter?
         if model_type == "draft":
-            draft_config = config.draft_model_config()
-            return draft_config.get(key)
+            return config.draft_model.get(key)
         elif model_type == "embedding":
-            embeddings_config = config.embeddings_config()
-            return embeddings_config.get(key)
+            return config.embeddings.get(key)
         else:
-            return model_config.get(key)
+            return config.model.get(key)
 
 
 async def check_model_container():
diff --git a/common/networking.py b/common/networking.py
index 7c088a9..be6f1ab 100644
--- a/common/networking.py
+++ b/common/networking.py
@@ -10,7 +10,7 @@ from pydantic import BaseModel
 from typing import Optional
 from uuid import uuid4
 
-from common import config
+from common.tabby_config import config
 from common.utils import unwrap
 
 
@@ -39,7 +39,7 @@ def handle_request_error(message: str, exc_info: bool = True):
     """Log a request error to the console."""
     trace = traceback.format_exc()
 
-    send_trace = unwrap(config.network_config().get("send_tracebacks"), False)
+    send_trace = unwrap(config.network.get("send_tracebacks"), False)
 
     error_message = TabbyRequestErrorMessage(
         message=message, trace=trace if send_trace else None
@@ -134,7 +134,7 @@ def get_global_depends():
 
     depends = [Depends(add_request_id)]
 
-    if config.logging_config().get("requests"):
+    if config.logging.get("requests"):
         depends.append(Depends(log_request))
 
     return depends
diff --git a/common/tabby_config.py b/common/tabby_config.py
new file mode 100644
index 0000000..c6119cb
--- /dev/null
+++ b/common/tabby_config.py
@@ -0,0 +1,96 @@
+import yaml
+import pathlib
+from loguru import logger
+from typing import Optional
+
+from common.utils import unwrap, merge_dicts
+
+
+class TabbyConfig:
+    network: dict = {}
+    logging: dict = {}
+    model: dict = {}
+    draft_model: dict = {}
+    lora: dict = {}
+    sampling: dict = {}
+    developer: dict = {}
+    embeddings: dict = {}
+
+    def load(self, arguments: Optional[dict] = None):
+        """load the global application config"""
+
+        # config is applied in order of items in the list
+        configs = [
+            self._from_file(pathlib.Path("config.yml")),
+            self._from_args(unwrap(arguments, {})),
+        ]
+
+        merged_config = merge_dicts(*configs)
+
+        self.network = unwrap(merged_config.get("network"), {})
+        self.logging = unwrap(merged_config.get("logging"), {})
+        self.model = unwrap(merged_config.get("model"), {})
+        self.draft_model = unwrap(self.model.get("draft"), {})
+        self.lora = unwrap(self.model.get("lora"), {})
+        self.sampling = unwrap(merged_config.get("sampling"), {})
+        self.developer = unwrap(merged_config.get("developer"), {})
+        self.embeddings = unwrap(merged_config.get("embeddings"), {})
+
+    def _from_file(self, config_path: pathlib.Path):
+        """loads config from a given file path"""
+
+        # try loading from file
+        try:
+            with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
+                return unwrap(yaml.safe_load(config_file), {})
+        except FileNotFoundError:
+            logger.info("The config.yml file cannot be found")
+        except Exception as exc:
+            logger.error(
+                "The YAML config couldn't load because of "
+                f"the following error:\n\n{exc}"
+            )
+
+        # if no config file was loaded
+        return {}
+
+    def _from_args(self, args: dict):
+        """loads config from the provided arguments"""
+        config = {}
+
+        config_override = unwrap(args.get("options", {}).get("config"))
+        if config_override:
+            logger.info("Config file override detected in args.")
+            config = self._from_file(pathlib.Path(config_override))
+            return config  # Return early if loading from file
+
+        for key in ["network", "model", "logging", "developer", "embeddings"]:
+            override = args.get(key)
+            if override:
+                if key == "logging":
+                    # Strip the "log_" prefix from logging keys if present
+                    override = {k.replace("log_", ""): v for k, v in override.items()}
+                config[key] = override
+
+        return config
+
+    def _from_environment(self):
+        """loads configuration from environment variables"""
+
+        # TODO: load config from environment variables
+        # this means that we can have host default to 0.0.0.0 in docker for example
+        # this would also mean that docker containers no longer require a non
+        # default config file to be used
+        return {}
+
+
+# Create an empty instance of the shared var to make sure nothing breaks
+config: TabbyConfig = TabbyConfig()
+
+
+def load_config(arguments: dict):
+    """Load the shared config instance on startup."""
+
+    # Re-populate the module-level instance in place so modules that
+    # already imported `config` see the loaded values
+    config.load(arguments)
diff --git a/common/utils.py b/common/utils.py
index 5133ed8..d5723a0 100644
--- a/common/utils.py
+++ b/common/utils.py
@@ -36,6 +36,8 @@ def merge_dicts(*dicts):
     for dictionary in dicts:
         result = merge_dict(result, dictionary)
 
+    return result
+
 
 def flat_map(input_list):
     """Flattens a list of lists into a single list."""
diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index 66bc759..b888f19 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -3,10 +3,11 @@ from fastapi import APIRouter, Depends, HTTPException, Request
 from sse_starlette import EventSourceResponse
 from sys import maxsize
 
-from common import config, model
+from common import model
 from common.auth import check_api_key
 from common.model import check_embeddings_container, check_model_container
 from common.networking import handle_request_error, run_with_request_disconnect
+from common.tabby_config import config
 from common.utils import unwrap
 from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
 from endpoints.OAI.types.chat_completion import (
@@ -58,7 +59,7 @@ async def completion_request(
         data.prompt = "\n".join(data.prompt)
 
     disable_request_streaming = unwrap(
-        config.developer_config().get("disable_request_streaming"), False
+        config.developer.get("disable_request_streaming"), False
     )
 
     # Set an empty JSON schema if the request wants a JSON response
@@ -117,7 +118,7 @@ async def chat_completion_request(
         data.json_schema = {"type": "object"}
 
     disable_request_streaming = unwrap(
-        config.developer_config().get("disable_request_streaming"), False
+        config.developer.get("disable_request_streaming"), False
     )
 
     if data.stream and not disable_request_streaming:
diff --git a/endpoints/core/router.py b/endpoints/core/router.py
index 8850524..1f9d194 100644
--- a/endpoints/core/router.py
+++ b/endpoints/core/router.py
@@ -4,11 +4,12 @@ from sys import maxsize
 from fastapi import APIRouter, Depends, HTTPException, Request
 from sse_starlette import EventSourceResponse
 
-from common import config, model, sampling
+from common import model, sampling
 from common.auth import check_admin_key, check_api_key, get_key_permission
 from common.downloader import hf_repo_download
 from common.model import check_embeddings_container, check_model_container
 from common.networking import handle_request_error, run_with_request_disconnect
+from common.tabby_config import config
 from common.templating import PromptTemplate, get_all_templates
 from common.utils import unwrap
 from endpoints.core.types.auth import AuthPermissionResponse
@@ -61,18 +62,17 @@ async def list_models(request: Request) -> ModelList:
 
     Requires an admin key to see all models.
     """
-    model_config = config.model_config()
-    model_dir = unwrap(model_config.get("model_dir"), "models")
+    model_dir = unwrap(config.model.get("model_dir"), "models")
     model_path = pathlib.Path(model_dir)
 
-    draft_model_dir = config.draft_model_config().get("draft_model_dir")
+    draft_model_dir = config.draft_model.get("draft_model_dir")
 
     if get_key_permission(request) == "admin":
         models = get_model_list(model_path.resolve(), draft_model_dir)
     else:
         models = await get_current_model_list()
 
-    if unwrap(model_config.get("use_dummy_models"), False):
+    if unwrap(config.model.get("use_dummy_models"), False):
         models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
 
     return models
@@ -98,9 +98,7 @@ async def list_draft_models(request: Request) -> ModelList:
     """
 
     if get_key_permission(request) == "admin":
-        draft_model_dir = unwrap(
-            config.draft_model_config().get("draft_model_dir"), "models"
-        )
+        draft_model_dir = unwrap(config.draft_model.get("draft_model_dir"), "models")
         draft_model_path = pathlib.Path(draft_model_dir)
 
         models = get_model_list(draft_model_path.resolve())
@@ -124,7 +122,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
 
         raise HTTPException(400, error_message)
 
-    model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
+    model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
     model_path = model_path / data.name
 
     draft_model_path = None
@@ -137,9 +135,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
 
             raise HTTPException(400, error_message)
 
-        draft_model_path = unwrap(
-            config.draft_model_config().get("draft_model_dir"), "models"
-        )
+        draft_model_path = unwrap(config.draft_model.get("draft_model_dir"), "models")
 
     if not model_path.exists():
         error_message = handle_request_error(
@@ -196,7 +192,7 @@ async def list_all_loras(request: Request) -> LoraList:
     """
 
     if get_key_permission(request) == "admin":
-        lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
+        lora_path = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
         loras = get_lora_list(lora_path.resolve())
     else:
         loras = get_active_loras()
@@ -231,7 +227,7 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
 
         raise HTTPException(400, error_message)
 
-    lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
+    lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
     if not lora_dir.exists():
         error_message = handle_request_error(
             "A parent lora directory does not exist for load. Check your config.yml?",
@@ -271,7 +267,7 @@ async def list_embedding_models(request: Request) -> ModelList:
 
     if get_key_permission(request) == "admin":
         embedding_model_dir = unwrap(
-            config.embeddings_config().get("embedding_model_dir"), "models"
+            config.embeddings.get("embedding_model_dir"), "models"
         )
         embedding_model_path = pathlib.Path(embedding_model_dir)
 
@@ -307,7 +303,7 @@ async def load_embedding_model(
         raise HTTPException(400, error_message)
 
     embedding_model_dir = pathlib.Path(
-        unwrap(config.model_config().get("embedding_model_dir"), "models")
+        unwrap(config.embeddings.get("embedding_model_dir"), "models")
     )
     embedding_model_path = embedding_model_dir / data.name
 
diff --git a/endpoints/server.py b/endpoints/server.py
index 4b04f6e..0f6a19b 100644
--- a/endpoints/server.py
+++ b/endpoints/server.py
@@ -5,9 +5,9 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from loguru import logger
 
-from common import config
 from common.logger import UVICORN_LOG_CONFIG
 from common.networking import get_global_depends
+from common.tabby_config import config
 from common.utils import unwrap
 from endpoints.Kobold import router as KoboldRouter
 from endpoints.OAI import router as OAIRouter
@@ -36,7 +36,7 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None):
         allow_headers=["*"],
     )
 
-    api_servers = unwrap(config.network_config().get("api_servers"), [])
+    api_servers = unwrap(config.network.get("api_servers"), [])
 
     # Map for API id to server router
     router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}
diff --git a/main.py b/main.py
index e25a9ff..140b89b 100644
--- a/main.py
+++ b/main.py
@@ -9,12 +9,13 @@ import signal
 from loguru import logger
 from typing import Optional
 
-from common import config, gen_logging, sampling, model
+from common import gen_logging, sampling, model
 from common.args import convert_args_to_dict, init_argparser
 from common.auth import load_auth_keys
 from common.logger import setup_logger
 from common.networking import is_port_in_use
 from common.signals import signal_handler
+from common.tabby_config import config, load_config
 from common.utils import unwrap
 from endpoints.server import export_openapi, start_api
 from endpoints.utils import do_export_openapi
@@ -26,10 +27,8 @@ if not do_export_openapi:
 async def entrypoint_async():
     """Async entry function for program startup"""
 
-    network_config = config.network_config()
-
-    host = unwrap(network_config.get("host"), "127.0.0.1")
-    port = unwrap(network_config.get("port"), 5000)
+    host = unwrap(config.network.get("host"), "127.0.0.1")
+    port = unwrap(config.network.get("port"), 5000)
 
     # Check if the port is available and attempt to bind a fallback
    if is_port_in_use(port):
@@ -51,18 +50,16 @@ async def entrypoint_async():
             port = fallback_port
 
     # Initialize auth keys
-    load_auth_keys(unwrap(network_config.get("disable_auth"), False))
+    load_auth_keys(unwrap(config.network.get("disable_auth"), False))
 
     # Override the generation log options if given
-    log_config = config.logging_config()
-    if log_config:
-        gen_logging.update_from_dict(log_config)
+    if config.logging:
+        gen_logging.update_from_dict(config.logging)
 
     gen_logging.broadcast_status()
 
     # Set sampler parameter overrides if provided
-    sampling_config = config.sampling_config()
-    sampling_override_preset = sampling_config.get("override_preset")
+    sampling_override_preset = config.sampling.get("override_preset")
     if sampling_override_preset:
         try:
             sampling.overrides_from_file(sampling_override_preset)
@@ -71,32 +68,29 @@ async def entrypoint_async():
 
     # If an initial model name is specified, create a container
     # and load the model
-    model_config = config.model_config()
-    model_name = model_config.get("model_name")
+    model_name = config.model.get("model_name")
     if model_name:
-        model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
+        model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
         model_path = model_path / model_name
 
-        await model.load_model(model_path.resolve(), **model_config)
+        await model.load_model(model_path.resolve(), **config.model)
 
         # Load loras after loading the model
-        lora_config = config.lora_config()
-        if lora_config.get("loras"):
-            lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
-            await model.container.load_loras(lora_dir.resolve(), **lora_config)
+        if config.lora.get("loras"):
+            lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
+            await model.container.load_loras(lora_dir.resolve(), **config.lora)
 
     # If an initial embedding model name is specified, create a separate container
     # and load the model
-    embedding_config = config.embeddings_config()
-    embedding_model_name = embedding_config.get("embedding_model_name")
+    embedding_model_name = config.embeddings.get("embedding_model_name")
     if embedding_model_name:
         embedding_model_path = pathlib.Path(
-            unwrap(embedding_config.get("embedding_model_dir"), "models")
+            unwrap(config.embeddings.get("embedding_model_dir"), "models")
         )
         embedding_model_path = embedding_model_path / embedding_model_name
 
         try:
-            await model.load_embedding_model(embedding_model_path, **embedding_config)
+            await model.load_embedding_model(embedding_model_path, **config.embeddings)
         except ImportError as ex:
             logger.error(ex.msg)
@@ -116,7 +110,7 @@ def entrypoint(arguments: Optional[dict] = None):
         arguments = convert_args_to_dict(parser.parse_args(), parser)
 
     # load config
-    config.load(arguments)
+    load_config(arguments)
 
     if do_export_openapi:
         openapi_json = export_openapi()
@@ -127,12 +121,10 @@ def entrypoint(arguments: Optional[dict] = None):
 
         return
 
-    developer_config = config.developer_config()
-
     # Check exllamav2 version and give a descriptive error if it's too old
     # Skip if launching unsafely
-    if unwrap(developer_config.get("unsafe_launch"), False):
+    if unwrap(config.developer.get("unsafe_launch"), False):
         logger.warning(
             "UNSAFE: Skipping ExllamaV2 version check.\n"
             "If you aren't a developer, please keep this off!"
@@ -141,12 +133,12 @@ def entrypoint(arguments: Optional[dict] = None):
         check_exllama_version()
 
     # Enable CUDA malloc backend
-    if unwrap(developer_config.get("cuda_malloc_backend"), False):
+    if unwrap(config.developer.get("cuda_malloc_backend"), False):
         os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
         logger.warning("EXPERIMENTAL: Enabled the pytorch CUDA malloc backend.")
 
     # Use Uvloop/Winloop
-    if unwrap(developer_config.get("uvloop"), False):
+    if unwrap(config.developer.get("uvloop"), False):
         if platform.system() == "Windows":
             from winloop import install
         else:
@@ -158,7 +150,7 @@ def entrypoint(arguments: Optional[dict] = None):
         logger.warning("EXPERIMENTAL: Running program with Uvloop/Winloop.")
 
     # Set the process priority
-    if unwrap(developer_config.get("realtime_process_priority"), False):
+    if unwrap(config.developer.get("realtime_process_priority"), False):
         import psutil
 
         current_process = psutil.Process(os.getpid())
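
Usage sketch (illustrative, not part of the patch): after this refactor, callers load the shared instance once at startup and then read sections as plain dict attributes instead of invoking the old *_config() closures. The inline arguments dict below is hypothetical; in main.py it comes from convert_args_to_dict(parser.parse_args(), parser).

    # Hypothetical caller, assuming the tabby_config.py shown above
    from common.tabby_config import config, load_config
    from common.utils import unwrap

    load_config({"network": {"host": "0.0.0.0"}})  # illustrative args dict

    # Old pattern: unwrap(config.network_config().get("host"), "127.0.0.1")
    host = unwrap(config.network.get("host"), "127.0.0.1")
    port = unwrap(config.network.get("port"), 5000)
    model_dir = unwrap(config.model.get("model_dir"), "models")

Because load_config() repopulates the module-level instance in place, modules that imported config before startup still observe the loaded values.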