Tree: Update to cleanup globals

Use the module singleton pattern to share global state. This can also
be a modified version of the Global Object Pattern. The main reason
this pattern is used is for ease of use when handling global state
rather than adding extra dependencies for a DI parameter.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-03-09 22:31:47 -05:00 committed by Brian Dashore
parent b373b25235
commit 5a2de30066
7 changed files with 68 additions and 82 deletions

View file

@ -3,7 +3,7 @@ from pydantic import BaseModel, Field, ConfigDict
from time import time
from typing import List, Optional
from common.gen_logging import LogPreferences
from common.gen_logging import GenLogPreferences
class ModelCardParameters(BaseModel):
@ -30,7 +30,7 @@ class ModelCard(BaseModel):
object: str = "model"
created: int = Field(default_factory=lambda: int(time()))
owned_by: str = "tabbyAPI"
logging: Optional[LogPreferences] = None
logging: Optional[GenLogPreferences] = None
parameters: Optional[ModelCardParameters] = None
@ -53,7 +53,6 @@ class DraftModelLoadRequest(BaseModel):
)
# TODO: Unify this with ModelCardParams
class ModelLoadRequest(BaseModel):
"""Represents a model load request."""

View file

@ -32,6 +32,7 @@ class AuthKeys(BaseModel):
return False
# Global auth constants
AUTH_KEYS: Optional[AuthKeys] = None
DISABLE_AUTH: bool = False

View file

@ -4,10 +4,11 @@ from loguru import logger
from common.utils import unwrap
# Global config dictionary constant
GLOBAL_CONFIG: dict = {}
def read_config_from_file(config_path: pathlib.Path):
def from_file(config_path: pathlib.Path):
"""Sets the global config from a given file path"""
global GLOBAL_CONFIG
@ -23,74 +24,77 @@ def read_config_from_file(config_path: pathlib.Path):
GLOBAL_CONFIG = {}
def override_config_from_args(args: dict):
def from_args(args: dict):
"""Overrides the config based on a dict representation of args"""
config_override = unwrap(args.get("options", {}).get("config"))
if config_override:
logger.info("Attempting to override config.yml from args.")
read_config_from_file(pathlib.Path(config_override))
from_file(pathlib.Path(config_override))
return
# Network config
network_override = args.get("network")
if network_override:
network_config = get_network_config()
GLOBAL_CONFIG["network"] = {**network_config, **network_override}
cur_network_config = network_config()
GLOBAL_CONFIG["network"] = {**cur_network_config, **network_override}
# Model config
model_override = args.get("model")
if model_override:
model_config = get_model_config()
GLOBAL_CONFIG["model"] = {**model_config, **model_override}
cur_model_config = model_config()
GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override}
# Logging config
logging_override = args.get("logging")
if logging_override:
logging_config = get_gen_logging_config()
# Generation Logging config
gen_logging_override = args.get("logging")
if gen_logging_override:
cur_gen_logging_config = gen_logging_config()
GLOBAL_CONFIG["logging"] = {
**logging_config,
**{k.replace("log_", ""): logging_override[k] for k in logging_override},
**cur_gen_logging_config,
**{
k.replace("log_", ""): gen_logging_override[k]
for k in gen_logging_override
},
}
developer_override = args.get("developer")
if developer_override:
developer_config = get_developer_config()
GLOBAL_CONFIG["developer"] = {**developer_config, **developer_override}
cur_developer_config = developer_config()
GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}
def get_sampling_config():
def sampling_config():
"""Returns the sampling parameter config from the global config"""
return unwrap(GLOBAL_CONFIG.get("sampling"), {})
def get_model_config():
def model_config():
"""Returns the model config from the global config"""
return unwrap(GLOBAL_CONFIG.get("model"), {})
def get_draft_model_config():
def draft_model_config():
"""Returns the draft model config from the global config"""
model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
return unwrap(model_config.get("draft"), {})
def get_lora_config():
def lora_config():
"""Returns the lora config from the global config"""
model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
return unwrap(model_config.get("lora"), {})
def get_network_config():
def network_config():
"""Returns the network config from the global config"""
return unwrap(GLOBAL_CONFIG.get("network"), {})
def get_gen_logging_config():
def gen_logging_config():
"""Returns the generation logging config from the global config"""
return unwrap(GLOBAL_CONFIG.get("logging"), {})
def get_developer_config():
def developer_config():
"""Returns the developer specific config from the global config"""
return unwrap(GLOBAL_CONFIG.get("developer"), {})

View file

@ -6,15 +6,15 @@ from loguru import logger
from typing import Dict, Optional
class LogPreferences(BaseModel):
class GenLogPreferences(BaseModel):
"""Logging preference config."""
prompt: bool = False
generation_params: bool = False
# Global reference to logging preferences
PREFERENCES = LogPreferences()
# Global logging preferences constant
PREFERENCES = GenLogPreferences()
def update_from_dict(options_dict: Dict[str, bool]):
@ -26,7 +26,7 @@ def update_from_dict(options_dict: Dict[str, bool]):
if value is None:
value = False
PREFERENCES = LogPreferences.model_validate(options_dict)
PREFERENCES = GenLogPreferences.model_validate(options_dict)
def broadcast_status():

View file

@ -13,6 +13,7 @@ from common.logger import get_loading_progress_bar
from common.utils import load_progress
# Global model container
container: Optional[ExllamaV2Container] = None

View file

@ -274,32 +274,28 @@ class BaseSamplerRequest(BaseModel):
# Global for default overrides
DEFAULT_OVERRIDES = {}
overrides = {}
def get_sampler_overrides():
return DEFAULT_OVERRIDES
def set_overrides_from_dict(new_overrides: dict):
def overrides_from_dict(new_overrides: dict):
"""Wrapper function to update sampler overrides"""
global DEFAULT_OVERRIDES
global overrides
if isinstance(new_overrides, dict):
DEFAULT_OVERRIDES = prune_dict(new_overrides)
overrides = prune_dict(new_overrides)
else:
raise TypeError("New sampler overrides must be a dict!")
def set_overrides_from_file(preset_name: str):
def overrides_from_file(preset_name: str):
"""Fetches an override preset from a file"""
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
if preset_path.exists():
with open(preset_path, "r", encoding="utf8") as raw_preset:
preset = yaml.safe_load(raw_preset)
set_overrides_from_dict(preset)
overrides_from_dict(preset)
logger.info("Applied sampler overrides from file.")
else:
@ -316,13 +312,13 @@ def set_overrides_from_file(preset_name: str):
def get_default_sampler_value(key, fallback=None):
"""Gets an overridden default sampler value"""
return unwrap(DEFAULT_OVERRIDES.get(key, {}).get("override"), fallback)
return unwrap(overrides.get(key, {}).get("override"), fallback)
def apply_forced_sampler_overrides(params: BaseSamplerRequest):
"""Forcefully applies overrides if specified by the user"""
for var, value in DEFAULT_OVERRIDES.items():
for var, value in overrides.items():
override = value.get("override")
force = unwrap(value.get("force"), False)
if force and override:

67
main.py
View file

@ -18,33 +18,16 @@ from fastapi.middleware.cors import CORSMiddleware
from functools import partial
from loguru import logger
from common.logger import UVICORN_LOG_CONFIG, setup_logger
import common.gen_logging as gen_logging
from backends.exllamav2.utils import check_exllama_version
from common import model
from common import config, model, gen_logging, sampling
from common.args import convert_args_to_dict, init_argparser
from common.auth import check_admin_key, check_api_key, load_auth_keys
from common.config import (
get_developer_config,
get_sampling_config,
override_config_from_args,
read_config_from_file,
get_gen_logging_config,
get_model_config,
get_draft_model_config,
get_lora_config,
get_network_config,
)
from common.generators import (
call_with_semaphore,
generate_with_semaphore,
release_semaphore,
)
from common.sampling import (
get_sampler_overrides,
set_overrides_from_file,
set_overrides_from_dict,
)
from common.logger import UVICORN_LOG_CONFIG, setup_logger
from common.templating import (
get_all_templates,
get_prompt_from_template,
@ -119,11 +102,11 @@ async def check_model_container():
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models():
"""Lists all models in the model directory."""
model_config = get_model_config()
model_config = config.model_config()
model_dir = unwrap(model_config.get("model_dir"), "models")
model_path = pathlib.Path(model_dir)
draft_model_dir = get_draft_model_config().get("draft_model_dir")
draft_model_dir = config.draft_model_config().get("draft_model_dir")
models = get_model_list(model_path.resolve(), draft_model_dir)
if unwrap(model_config.get("use_dummy_models"), False):
@ -170,7 +153,9 @@ async def get_current_model():
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models():
"""Lists all draft models in the model directory."""
draft_model_dir = unwrap(get_draft_model_config().get("draft_model_dir"), "models")
draft_model_dir = unwrap(
config.draft_model_config().get("draft_model_dir"), "models"
)
draft_model_path = pathlib.Path(draft_model_dir)
models = get_model_list(draft_model_path.resolve())
@ -187,7 +172,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
if not data.name:
raise HTTPException(400, "A model name was not provided.")
model_path = pathlib.Path(unwrap(get_model_config().get("model_dir"), "models"))
model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
model_path = model_path / data.name
load_data = data.model_dump()
@ -199,7 +184,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
)
load_data["draft"]["draft_model_dir"] = unwrap(
get_draft_model_config().get("draft_model_dir"), "models"
config.draft_model_config().get("draft_model_dir"), "models"
)
if not model_path.exists():
@ -309,7 +294,7 @@ async def unload_template():
async def list_sampler_overrides():
"""API wrapper to list all currently applied sampler overrides"""
return get_sampler_overrides()
return sampling.overrides
@app.post(
@ -321,13 +306,13 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
if data.preset:
try:
set_overrides_from_file(data.preset)
sampling.overrides_from_file(data.preset)
except FileNotFoundError as e:
raise HTTPException(
400, "Sampler override preset does not exist. Check the name?"
) from e
elif data.overrides:
set_overrides_from_dict(data.overrides)
sampling.overrides_from_dict(data.overrides)
else:
raise HTTPException(
400, "A sampler override preset or dictionary wasn't provided."
@ -341,7 +326,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
async def unload_sampler_override():
"""Unloads the currently selected override preset"""
set_overrides_from_dict({})
sampling.overrides_from_dict({})
# Lora list endpoint
@ -349,7 +334,7 @@ async def unload_sampler_override():
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def get_all_loras():
"""Lists all LoRAs in the lora directory."""
lora_path = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
loras = get_lora_list(lora_path.resolve())
return loras
@ -387,7 +372,7 @@ async def load_lora(data: LoraLoadRequest):
if not data.loras:
raise HTTPException(400, "List of loras to load is not found.")
lora_dir = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
if not lora_dir.exists():
raise HTTPException(
400,
@ -468,7 +453,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
data.prompt = "\n".join(data.prompt)
disable_request_streaming = unwrap(
get_developer_config().get("disable_request_streaming"), False
config.developer_config().get("disable_request_streaming"), False
)
if data.stream and not disable_request_streaming:
@ -565,7 +550,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
) from exc
disable_request_streaming = unwrap(
get_developer_config().get("disable_request_streaming"), False
config.developer_config().get("disable_request_streaming"), False
)
if data.stream and not disable_request_streaming:
@ -656,16 +641,16 @@ async def entrypoint(args: Optional[dict] = None):
signal.signal(signal.SIGTERM, signal_handler)
# Load from YAML config
read_config_from_file(pathlib.Path("config.yml"))
config.from_file(pathlib.Path("config.yml"))
# Parse and override config from args
if args is None:
parser = init_argparser()
args = convert_args_to_dict(parser.parse_args(), parser)
override_config_from_args(args)
config.from_args(args)
developer_config = get_developer_config()
developer_config = config.developer_config()
# Check exllamav2 version and give a descriptive error if it's too old
# Skip if launching unsafely
@ -683,7 +668,7 @@ async def entrypoint(args: Optional[dict] = None):
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
logger.warning("Enabled the experimental CUDA malloc backend.")
network_config = get_network_config()
network_config = config.network_config()
host = unwrap(network_config.get("host"), "127.0.0.1")
port = unwrap(network_config.get("port"), 5000)
@ -711,24 +696,24 @@ async def entrypoint(args: Optional[dict] = None):
load_auth_keys(unwrap(network_config.get("disable_auth"), False))
# Override the generation log options if given
log_config = get_gen_logging_config()
log_config = config.gen_logging_config()
if log_config:
gen_logging.update_from_dict(log_config)
gen_logging.broadcast_status()
# Set sampler parameter overrides if provided
sampling_config = get_sampling_config()
sampling_config = config.sampling_config()
sampling_override_preset = sampling_config.get("override_preset")
if sampling_override_preset:
try:
set_overrides_from_file(sampling_override_preset)
sampling.overrides_from_file(sampling_override_preset)
except FileNotFoundError as e:
logger.warning(str(e))
# If an initial model name is specified, create a container
# and load the model
model_config = get_model_config()
model_config = config.model_config()
model_name = model_config.get("model_name")
if model_name:
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
@ -737,7 +722,7 @@ async def entrypoint(args: Optional[dict] = None):
await model.load_model(model_path.resolve(), **model_config)
# Load loras after loading the model
lora_config = get_lora_config()
lora_config = config.lora_config()
if lora_config.get("loras"):
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
model.container.load_loras(lora_dir.resolve(), **lora_config)