From 5a2de300662f70918a57719035016e7fe2b505f2 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 9 Mar 2024 22:31:47 -0500 Subject: [PATCH] Tree: Update to cleanup globals Use the module singleton pattern to share global state. This can also be a modified version of the Global Object Pattern. The main reason this pattern is used is for ease of use when handling global state rather than adding extra dependencies for a DI parameter. Signed-off-by: kingbri --- OAI/types/model.py | 5 ++-- common/auth.py | 1 + common/config.py | 48 +++++++++++++++++-------------- common/gen_logging.py | 8 +++--- common/model.py | 1 + common/sampling.py | 20 ++++++------- main.py | 67 +++++++++++++++++-------------------------- 7 files changed, 68 insertions(+), 82 deletions(-) diff --git a/OAI/types/model.py b/OAI/types/model.py index 03f8fd3..5e308c6 100644 --- a/OAI/types/model.py +++ b/OAI/types/model.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field, ConfigDict from time import time from typing import List, Optional -from common.gen_logging import LogPreferences +from common.gen_logging import GenLogPreferences class ModelCardParameters(BaseModel): @@ -30,7 +30,7 @@ class ModelCard(BaseModel): object: str = "model" created: int = Field(default_factory=lambda: int(time())) owned_by: str = "tabbyAPI" - logging: Optional[LogPreferences] = None + logging: Optional[GenLogPreferences] = None parameters: Optional[ModelCardParameters] = None @@ -53,7 +53,6 @@ class DraftModelLoadRequest(BaseModel): ) -# TODO: Unify this with ModelCardParams class ModelLoadRequest(BaseModel): """Represents a model load request.""" diff --git a/common/auth.py b/common/auth.py index 8c00d9a..2e154e6 100644 --- a/common/auth.py +++ b/common/auth.py @@ -32,6 +32,7 @@ class AuthKeys(BaseModel): return False +# Global auth constants AUTH_KEYS: Optional[AuthKeys] = None DISABLE_AUTH: bool = False diff --git a/common/config.py b/common/config.py index fc18c71..86aedac 100644 --- a/common/config.py +++ b/common/config.py @@ -4,10 +4,11 @@ from loguru import logger from common.utils import unwrap +# Global config dictionary constant GLOBAL_CONFIG: dict = {} -def read_config_from_file(config_path: pathlib.Path): +def from_file(config_path: pathlib.Path): """Sets the global config from a given file path""" global GLOBAL_CONFIG @@ -23,74 +24,77 @@ def read_config_from_file(config_path: pathlib.Path): GLOBAL_CONFIG = {} -def override_config_from_args(args: dict): +def from_args(args: dict): """Overrides the config based on a dict representation of args""" config_override = unwrap(args.get("options", {}).get("config")) if config_override: logger.info("Attempting to override config.yml from args.") - read_config_from_file(pathlib.Path(config_override)) + from_file(pathlib.Path(config_override)) return # Network config network_override = args.get("network") if network_override: - network_config = get_network_config() - GLOBAL_CONFIG["network"] = {**network_config, **network_override} + cur_network_config = network_config() + GLOBAL_CONFIG["network"] = {**cur_network_config, **network_override} # Model config model_override = args.get("model") if model_override: - model_config = get_model_config() - GLOBAL_CONFIG["model"] = {**model_config, **model_override} + cur_model_config = model_config() + GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override} - # Logging config - logging_override = args.get("logging") - if logging_override: - logging_config = get_gen_logging_config() + # Generation Logging config + gen_logging_override = args.get("logging") + if gen_logging_override: + cur_gen_logging_config = gen_logging_config() GLOBAL_CONFIG["logging"] = { - **logging_config, - **{k.replace("log_", ""): logging_override[k] for k in logging_override}, + **cur_gen_logging_config, + **{ + k.replace("log_", ""): gen_logging_override[k] + for k in gen_logging_override + }, } developer_override = args.get("developer") if developer_override: - developer_config = get_developer_config() - GLOBAL_CONFIG["developer"] = {**developer_config, **developer_override} + cur_developer_config = developer_config() + GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override} -def get_sampling_config(): +def sampling_config(): """Returns the sampling parameter config from the global config""" return unwrap(GLOBAL_CONFIG.get("sampling"), {}) -def get_model_config(): +def model_config(): """Returns the model config from the global config""" return unwrap(GLOBAL_CONFIG.get("model"), {}) -def get_draft_model_config(): +def draft_model_config(): """Returns the draft model config from the global config""" model_config = unwrap(GLOBAL_CONFIG.get("model"), {}) return unwrap(model_config.get("draft"), {}) -def get_lora_config(): +def lora_config(): """Returns the lora config from the global config""" model_config = unwrap(GLOBAL_CONFIG.get("model"), {}) return unwrap(model_config.get("lora"), {}) -def get_network_config(): +def network_config(): """Returns the network config from the global config""" return unwrap(GLOBAL_CONFIG.get("network"), {}) -def get_gen_logging_config(): +def gen_logging_config(): """Returns the generation logging config from the global config""" return unwrap(GLOBAL_CONFIG.get("logging"), {}) -def get_developer_config(): +def developer_config(): """Returns the developer specific config from the global config""" return unwrap(GLOBAL_CONFIG.get("developer"), {}) diff --git a/common/gen_logging.py b/common/gen_logging.py index 1fe84e8..db8acf1 100644 --- a/common/gen_logging.py +++ b/common/gen_logging.py @@ -6,15 +6,15 @@ from loguru import logger from typing import Dict, Optional -class LogPreferences(BaseModel): +class GenLogPreferences(BaseModel): """Logging preference config.""" prompt: bool = False generation_params: bool = False -# Global reference to logging preferences -PREFERENCES = LogPreferences() +# Global logging preferences constant +PREFERENCES = GenLogPreferences() def update_from_dict(options_dict: Dict[str, bool]): @@ -26,7 +26,7 @@ def update_from_dict(options_dict: Dict[str, bool]): if value is None: value = False - PREFERENCES = LogPreferences.model_validate(options_dict) + PREFERENCES = GenLogPreferences.model_validate(options_dict) def broadcast_status(): diff --git a/common/model.py b/common/model.py index e626237..60484e7 100644 --- a/common/model.py +++ b/common/model.py @@ -13,6 +13,7 @@ from common.logger import get_loading_progress_bar from common.utils import load_progress +# Global model container container: Optional[ExllamaV2Container] = None diff --git a/common/sampling.py b/common/sampling.py index 4acbbbb..aa7e7d0 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -274,32 +274,28 @@ class BaseSamplerRequest(BaseModel): # Global for default overrides -DEFAULT_OVERRIDES = {} +overrides = {} -def get_sampler_overrides(): - return DEFAULT_OVERRIDES - - -def set_overrides_from_dict(new_overrides: dict): +def overrides_from_dict(new_overrides: dict): """Wrapper function to update sampler overrides""" - global DEFAULT_OVERRIDES + global overrides if isinstance(new_overrides, dict): - DEFAULT_OVERRIDES = prune_dict(new_overrides) + overrides = prune_dict(new_overrides) else: raise TypeError("New sampler overrides must be a dict!") -def set_overrides_from_file(preset_name: str): +def overrides_from_file(preset_name: str): """Fetches an override preset from a file""" preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml") if preset_path.exists(): with open(preset_path, "r", encoding="utf8") as raw_preset: preset = yaml.safe_load(raw_preset) - set_overrides_from_dict(preset) + overrides_from_dict(preset) logger.info("Applied sampler overrides from file.") else: @@ -316,13 +312,13 @@ def set_overrides_from_file(preset_name: str): def get_default_sampler_value(key, fallback=None): """Gets an overridden default sampler value""" - return unwrap(DEFAULT_OVERRIDES.get(key, {}).get("override"), fallback) + return unwrap(overrides.get(key, {}).get("override"), fallback) def apply_forced_sampler_overrides(params: BaseSamplerRequest): """Forcefully applies overrides if specified by the user""" - for var, value in DEFAULT_OVERRIDES.items(): + for var, value in overrides.items(): override = value.get("override") force = unwrap(value.get("force"), False) if force and override: diff --git a/main.py b/main.py index 878fe70..a28e034 100644 --- a/main.py +++ b/main.py @@ -18,33 +18,16 @@ from fastapi.middleware.cors import CORSMiddleware from functools import partial from loguru import logger -from common.logger import UVICORN_LOG_CONFIG, setup_logger -import common.gen_logging as gen_logging from backends.exllamav2.utils import check_exllama_version -from common import model +from common import config, model, gen_logging, sampling from common.args import convert_args_to_dict, init_argparser from common.auth import check_admin_key, check_api_key, load_auth_keys -from common.config import ( - get_developer_config, - get_sampling_config, - override_config_from_args, - read_config_from_file, - get_gen_logging_config, - get_model_config, - get_draft_model_config, - get_lora_config, - get_network_config, -) from common.generators import ( call_with_semaphore, generate_with_semaphore, release_semaphore, ) -from common.sampling import ( - get_sampler_overrides, - set_overrides_from_file, - set_overrides_from_dict, -) +from common.logger import UVICORN_LOG_CONFIG, setup_logger from common.templating import ( get_all_templates, get_prompt_from_template, @@ -119,11 +102,11 @@ async def check_model_container(): @app.get("/v1/model/list", dependencies=[Depends(check_api_key)]) async def list_models(): """Lists all models in the model directory.""" - model_config = get_model_config() + model_config = config.model_config() model_dir = unwrap(model_config.get("model_dir"), "models") model_path = pathlib.Path(model_dir) - draft_model_dir = get_draft_model_config().get("draft_model_dir") + draft_model_dir = config.draft_model_config().get("draft_model_dir") models = get_model_list(model_path.resolve(), draft_model_dir) if unwrap(model_config.get("use_dummy_models"), False): @@ -170,7 +153,9 @@ async def get_current_model(): @app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) async def list_draft_models(): """Lists all draft models in the model directory.""" - draft_model_dir = unwrap(get_draft_model_config().get("draft_model_dir"), "models") + draft_model_dir = unwrap( + config.draft_model_config().get("draft_model_dir"), "models" + ) draft_model_path = pathlib.Path(draft_model_dir) models = get_model_list(draft_model_path.resolve()) @@ -187,7 +172,7 @@ async def load_model(request: Request, data: ModelLoadRequest): if not data.name: raise HTTPException(400, "A model name was not provided.") - model_path = pathlib.Path(unwrap(get_model_config().get("model_dir"), "models")) + model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models")) model_path = model_path / data.name load_data = data.model_dump() @@ -199,7 +184,7 @@ async def load_model(request: Request, data: ModelLoadRequest): ) load_data["draft"]["draft_model_dir"] = unwrap( - get_draft_model_config().get("draft_model_dir"), "models" + config.draft_model_config().get("draft_model_dir"), "models" ) if not model_path.exists(): @@ -309,7 +294,7 @@ async def unload_template(): async def list_sampler_overrides(): """API wrapper to list all currently applied sampler overrides""" - return get_sampler_overrides() + return sampling.overrides @app.post( @@ -321,13 +306,13 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest): if data.preset: try: - set_overrides_from_file(data.preset) + sampling.overrides_from_file(data.preset) except FileNotFoundError as e: raise HTTPException( 400, "Sampler override preset does not exist. Check the name?" ) from e elif data.overrides: - set_overrides_from_dict(data.overrides) + sampling.overrides_from_dict(data.overrides) else: raise HTTPException( 400, "A sampler override preset or dictionary wasn't provided." @@ -341,7 +326,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest): async def unload_sampler_override(): """Unloads the currently selected override preset""" - set_overrides_from_dict({}) + sampling.overrides_from_dict({}) # Lora list endpoint @@ -349,7 +334,7 @@ async def unload_sampler_override(): @app.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) async def get_all_loras(): """Lists all LoRAs in the lora directory.""" - lora_path = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras")) + lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) loras = get_lora_list(lora_path.resolve()) return loras @@ -387,7 +372,7 @@ async def load_lora(data: LoraLoadRequest): if not data.loras: raise HTTPException(400, "List of loras to load is not found.") - lora_dir = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras")) + lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) if not lora_dir.exists(): raise HTTPException( 400, @@ -468,7 +453,7 @@ async def generate_completion(request: Request, data: CompletionRequest): data.prompt = "\n".join(data.prompt) disable_request_streaming = unwrap( - get_developer_config().get("disable_request_streaming"), False + config.developer_config().get("disable_request_streaming"), False ) if data.stream and not disable_request_streaming: @@ -565,7 +550,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest ) from exc disable_request_streaming = unwrap( - get_developer_config().get("disable_request_streaming"), False + config.developer_config().get("disable_request_streaming"), False ) if data.stream and not disable_request_streaming: @@ -656,16 +641,16 @@ async def entrypoint(args: Optional[dict] = None): signal.signal(signal.SIGTERM, signal_handler) # Load from YAML config - read_config_from_file(pathlib.Path("config.yml")) + config.from_file(pathlib.Path("config.yml")) # Parse and override config from args if args is None: parser = init_argparser() args = convert_args_to_dict(parser.parse_args(), parser) - override_config_from_args(args) + config.from_args(args) - developer_config = get_developer_config() + developer_config = config.developer_config() # Check exllamav2 version and give a descriptive error if it's too old # Skip if launching unsafely @@ -683,7 +668,7 @@ async def entrypoint(args: Optional[dict] = None): os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync" logger.warning("Enabled the experimental CUDA malloc backend.") - network_config = get_network_config() + network_config = config.network_config() host = unwrap(network_config.get("host"), "127.0.0.1") port = unwrap(network_config.get("port"), 5000) @@ -711,24 +696,24 @@ async def entrypoint(args: Optional[dict] = None): load_auth_keys(unwrap(network_config.get("disable_auth"), False)) # Override the generation log options if given - log_config = get_gen_logging_config() + log_config = config.gen_logging_config() if log_config: gen_logging.update_from_dict(log_config) gen_logging.broadcast_status() # Set sampler parameter overrides if provided - sampling_config = get_sampling_config() + sampling_config = config.sampling_config() sampling_override_preset = sampling_config.get("override_preset") if sampling_override_preset: try: - set_overrides_from_file(sampling_override_preset) + sampling.overrides_from_file(sampling_override_preset) except FileNotFoundError as e: logger.warning(str(e)) # If an initial model name is specified, create a container # and load the model - model_config = get_model_config() + model_config = config.model_config() model_name = model_config.get("model_name") if model_name: model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models")) @@ -737,7 +722,7 @@ async def entrypoint(args: Optional[dict] = None): await model.load_model(model_path.resolve(), **model_config) # Load loras after loading the model - lora_config = get_lora_config() + lora_config = config.lora_config() if lora_config.get("loras"): lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras")) model.container.load_loras(lora_dir.resolve(), **lora_config)