Tree: Update to clean up globals
Use the module singleton pattern to share global state. This can also be seen as a modified version of the Global Object Pattern. The main reason for using this pattern is ease of use when handling global state, rather than adding an extra dependency-injection parameter. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
b373b25235
commit
5a2de30066
7 changed files with 68 additions and 82 deletions
|
|
@ -3,7 +3,7 @@ from pydantic import BaseModel, Field, ConfigDict
|
|||
from time import time
|
||||
from typing import List, Optional
|
||||
|
||||
from common.gen_logging import LogPreferences
|
||||
from common.gen_logging import GenLogPreferences
|
||||
|
||||
|
||||
class ModelCardParameters(BaseModel):
|
||||
|
|
@ -30,7 +30,7 @@ class ModelCard(BaseModel):
|
|||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time()))
|
||||
owned_by: str = "tabbyAPI"
|
||||
logging: Optional[LogPreferences] = None
|
||||
logging: Optional[GenLogPreferences] = None
|
||||
parameters: Optional[ModelCardParameters] = None
|
||||
|
||||
|
||||
|
|
@ -53,7 +53,6 @@ class DraftModelLoadRequest(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
# TODO: Unify this with ModelCardParams
|
||||
class ModelLoadRequest(BaseModel):
|
||||
"""Represents a model load request."""
|
||||
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class AuthKeys(BaseModel):
|
|||
return False
|
||||
|
||||
|
||||
# Global auth constants
|
||||
AUTH_KEYS: Optional[AuthKeys] = None
|
||||
DISABLE_AUTH: bool = False
|
||||
|
||||
|
|
|
|||
|
|
@ -4,10 +4,11 @@ from loguru import logger
|
|||
|
||||
from common.utils import unwrap
|
||||
|
||||
# Global config dictionary constant
|
||||
GLOBAL_CONFIG: dict = {}
|
||||
|
||||
|
||||
def read_config_from_file(config_path: pathlib.Path):
|
||||
def from_file(config_path: pathlib.Path):
|
||||
"""Sets the global config from a given file path"""
|
||||
global GLOBAL_CONFIG
|
||||
|
||||
|
|
@ -23,74 +24,77 @@ def read_config_from_file(config_path: pathlib.Path):
|
|||
GLOBAL_CONFIG = {}
|
||||
|
||||
|
||||
def override_config_from_args(args: dict):
|
||||
def from_args(args: dict):
|
||||
"""Overrides the config based on a dict representation of args"""
|
||||
|
||||
config_override = unwrap(args.get("options", {}).get("config"))
|
||||
if config_override:
|
||||
logger.info("Attempting to override config.yml from args.")
|
||||
read_config_from_file(pathlib.Path(config_override))
|
||||
from_file(pathlib.Path(config_override))
|
||||
return
|
||||
|
||||
# Network config
|
||||
network_override = args.get("network")
|
||||
if network_override:
|
||||
network_config = get_network_config()
|
||||
GLOBAL_CONFIG["network"] = {**network_config, **network_override}
|
||||
cur_network_config = network_config()
|
||||
GLOBAL_CONFIG["network"] = {**cur_network_config, **network_override}
|
||||
|
||||
# Model config
|
||||
model_override = args.get("model")
|
||||
if model_override:
|
||||
model_config = get_model_config()
|
||||
GLOBAL_CONFIG["model"] = {**model_config, **model_override}
|
||||
cur_model_config = model_config()
|
||||
GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override}
|
||||
|
||||
# Logging config
|
||||
logging_override = args.get("logging")
|
||||
if logging_override:
|
||||
logging_config = get_gen_logging_config()
|
||||
# Generation Logging config
|
||||
gen_logging_override = args.get("logging")
|
||||
if gen_logging_override:
|
||||
cur_gen_logging_config = gen_logging_config()
|
||||
GLOBAL_CONFIG["logging"] = {
|
||||
**logging_config,
|
||||
**{k.replace("log_", ""): logging_override[k] for k in logging_override},
|
||||
**cur_gen_logging_config,
|
||||
**{
|
||||
k.replace("log_", ""): gen_logging_override[k]
|
||||
for k in gen_logging_override
|
||||
},
|
||||
}
|
||||
|
||||
developer_override = args.get("developer")
|
||||
if developer_override:
|
||||
developer_config = get_developer_config()
|
||||
GLOBAL_CONFIG["developer"] = {**developer_config, **developer_override}
|
||||
cur_developer_config = developer_config()
|
||||
GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}
|
||||
|
||||
|
||||
def get_sampling_config():
|
||||
def sampling_config():
|
||||
"""Returns the sampling parameter config from the global config"""
|
||||
return unwrap(GLOBAL_CONFIG.get("sampling"), {})
|
||||
|
||||
|
||||
def get_model_config():
|
||||
def model_config():
|
||||
"""Returns the model config from the global config"""
|
||||
return unwrap(GLOBAL_CONFIG.get("model"), {})
|
||||
|
||||
|
||||
def get_draft_model_config():
|
||||
def draft_model_config():
|
||||
"""Returns the draft model config from the global config"""
|
||||
model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
|
||||
return unwrap(model_config.get("draft"), {})
|
||||
|
||||
|
||||
def get_lora_config():
|
||||
def lora_config():
|
||||
"""Returns the lora config from the global config"""
|
||||
model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
|
||||
return unwrap(model_config.get("lora"), {})
|
||||
|
||||
|
||||
def get_network_config():
|
||||
def network_config():
|
||||
"""Returns the network config from the global config"""
|
||||
return unwrap(GLOBAL_CONFIG.get("network"), {})
|
||||
|
||||
|
||||
def get_gen_logging_config():
|
||||
def gen_logging_config():
|
||||
"""Returns the generation logging config from the global config"""
|
||||
return unwrap(GLOBAL_CONFIG.get("logging"), {})
|
||||
|
||||
|
||||
def get_developer_config():
|
||||
def developer_config():
|
||||
"""Returns the developer specific config from the global config"""
|
||||
return unwrap(GLOBAL_CONFIG.get("developer"), {})
|
||||
|
|
|
|||
|
|
@ -6,15 +6,15 @@ from loguru import logger
|
|||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class LogPreferences(BaseModel):
|
||||
class GenLogPreferences(BaseModel):
|
||||
"""Logging preference config."""
|
||||
|
||||
prompt: bool = False
|
||||
generation_params: bool = False
|
||||
|
||||
|
||||
# Global reference to logging preferences
|
||||
PREFERENCES = LogPreferences()
|
||||
# Global logging preferences constant
|
||||
PREFERENCES = GenLogPreferences()
|
||||
|
||||
|
||||
def update_from_dict(options_dict: Dict[str, bool]):
|
||||
|
|
@ -26,7 +26,7 @@ def update_from_dict(options_dict: Dict[str, bool]):
|
|||
if value is None:
|
||||
value = False
|
||||
|
||||
PREFERENCES = LogPreferences.model_validate(options_dict)
|
||||
PREFERENCES = GenLogPreferences.model_validate(options_dict)
|
||||
|
||||
|
||||
def broadcast_status():
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from common.logger import get_loading_progress_bar
|
|||
from common.utils import load_progress
|
||||
|
||||
|
||||
# Global model container
|
||||
container: Optional[ExllamaV2Container] = None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -274,32 +274,28 @@ class BaseSamplerRequest(BaseModel):
|
|||
|
||||
|
||||
# Global for default overrides
|
||||
DEFAULT_OVERRIDES = {}
|
||||
overrides = {}
|
||||
|
||||
|
||||
def get_sampler_overrides():
|
||||
return DEFAULT_OVERRIDES
|
||||
|
||||
|
||||
def set_overrides_from_dict(new_overrides: dict):
|
||||
def overrides_from_dict(new_overrides: dict):
|
||||
"""Wrapper function to update sampler overrides"""
|
||||
|
||||
global DEFAULT_OVERRIDES
|
||||
global overrides
|
||||
|
||||
if isinstance(new_overrides, dict):
|
||||
DEFAULT_OVERRIDES = prune_dict(new_overrides)
|
||||
overrides = prune_dict(new_overrides)
|
||||
else:
|
||||
raise TypeError("New sampler overrides must be a dict!")
|
||||
|
||||
|
||||
def set_overrides_from_file(preset_name: str):
|
||||
def overrides_from_file(preset_name: str):
|
||||
"""Fetches an override preset from a file"""
|
||||
|
||||
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
|
||||
if preset_path.exists():
|
||||
with open(preset_path, "r", encoding="utf8") as raw_preset:
|
||||
preset = yaml.safe_load(raw_preset)
|
||||
set_overrides_from_dict(preset)
|
||||
overrides_from_dict(preset)
|
||||
|
||||
logger.info("Applied sampler overrides from file.")
|
||||
else:
|
||||
|
|
@ -316,13 +312,13 @@ def set_overrides_from_file(preset_name: str):
|
|||
def get_default_sampler_value(key, fallback=None):
|
||||
"""Gets an overridden default sampler value"""
|
||||
|
||||
return unwrap(DEFAULT_OVERRIDES.get(key, {}).get("override"), fallback)
|
||||
return unwrap(overrides.get(key, {}).get("override"), fallback)
|
||||
|
||||
|
||||
def apply_forced_sampler_overrides(params: BaseSamplerRequest):
|
||||
"""Forcefully applies overrides if specified by the user"""
|
||||
|
||||
for var, value in DEFAULT_OVERRIDES.items():
|
||||
for var, value in overrides.items():
|
||||
override = value.get("override")
|
||||
force = unwrap(value.get("force"), False)
|
||||
if force and override:
|
||||
|
|
|
|||
67
main.py
67
main.py
|
|
@ -18,33 +18,16 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
from functools import partial
|
||||
from loguru import logger
|
||||
|
||||
from common.logger import UVICORN_LOG_CONFIG, setup_logger
|
||||
import common.gen_logging as gen_logging
|
||||
from backends.exllamav2.utils import check_exllama_version
|
||||
from common import model
|
||||
from common import config, model, gen_logging, sampling
|
||||
from common.args import convert_args_to_dict, init_argparser
|
||||
from common.auth import check_admin_key, check_api_key, load_auth_keys
|
||||
from common.config import (
|
||||
get_developer_config,
|
||||
get_sampling_config,
|
||||
override_config_from_args,
|
||||
read_config_from_file,
|
||||
get_gen_logging_config,
|
||||
get_model_config,
|
||||
get_draft_model_config,
|
||||
get_lora_config,
|
||||
get_network_config,
|
||||
)
|
||||
from common.generators import (
|
||||
call_with_semaphore,
|
||||
generate_with_semaphore,
|
||||
release_semaphore,
|
||||
)
|
||||
from common.sampling import (
|
||||
get_sampler_overrides,
|
||||
set_overrides_from_file,
|
||||
set_overrides_from_dict,
|
||||
)
|
||||
from common.logger import UVICORN_LOG_CONFIG, setup_logger
|
||||
from common.templating import (
|
||||
get_all_templates,
|
||||
get_prompt_from_template,
|
||||
|
|
@ -119,11 +102,11 @@ async def check_model_container():
|
|||
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_models():
|
||||
"""Lists all models in the model directory."""
|
||||
model_config = get_model_config()
|
||||
model_config = config.model_config()
|
||||
model_dir = unwrap(model_config.get("model_dir"), "models")
|
||||
model_path = pathlib.Path(model_dir)
|
||||
|
||||
draft_model_dir = get_draft_model_config().get("draft_model_dir")
|
||||
draft_model_dir = config.draft_model_config().get("draft_model_dir")
|
||||
|
||||
models = get_model_list(model_path.resolve(), draft_model_dir)
|
||||
if unwrap(model_config.get("use_dummy_models"), False):
|
||||
|
|
@ -170,7 +153,9 @@ async def get_current_model():
|
|||
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_draft_models():
|
||||
"""Lists all draft models in the model directory."""
|
||||
draft_model_dir = unwrap(get_draft_model_config().get("draft_model_dir"), "models")
|
||||
draft_model_dir = unwrap(
|
||||
config.draft_model_config().get("draft_model_dir"), "models"
|
||||
)
|
||||
draft_model_path = pathlib.Path(draft_model_dir)
|
||||
|
||||
models = get_model_list(draft_model_path.resolve())
|
||||
|
|
@ -187,7 +172,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
if not data.name:
|
||||
raise HTTPException(400, "A model name was not provided.")
|
||||
|
||||
model_path = pathlib.Path(unwrap(get_model_config().get("model_dir"), "models"))
|
||||
model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
|
||||
model_path = model_path / data.name
|
||||
|
||||
load_data = data.model_dump()
|
||||
|
|
@ -199,7 +184,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
)
|
||||
|
||||
load_data["draft"]["draft_model_dir"] = unwrap(
|
||||
get_draft_model_config().get("draft_model_dir"), "models"
|
||||
config.draft_model_config().get("draft_model_dir"), "models"
|
||||
)
|
||||
|
||||
if not model_path.exists():
|
||||
|
|
@ -309,7 +294,7 @@ async def unload_template():
|
|||
async def list_sampler_overrides():
|
||||
"""API wrapper to list all currently applied sampler overrides"""
|
||||
|
||||
return get_sampler_overrides()
|
||||
return sampling.overrides
|
||||
|
||||
|
||||
@app.post(
|
||||
|
|
@ -321,13 +306,13 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
|
|||
|
||||
if data.preset:
|
||||
try:
|
||||
set_overrides_from_file(data.preset)
|
||||
sampling.overrides_from_file(data.preset)
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(
|
||||
400, "Sampler override preset does not exist. Check the name?"
|
||||
) from e
|
||||
elif data.overrides:
|
||||
set_overrides_from_dict(data.overrides)
|
||||
sampling.overrides_from_dict(data.overrides)
|
||||
else:
|
||||
raise HTTPException(
|
||||
400, "A sampler override preset or dictionary wasn't provided."
|
||||
|
|
@ -341,7 +326,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
|
|||
async def unload_sampler_override():
|
||||
"""Unloads the currently selected override preset"""
|
||||
|
||||
set_overrides_from_dict({})
|
||||
sampling.overrides_from_dict({})
|
||||
|
||||
|
||||
# Lora list endpoint
|
||||
|
|
@ -349,7 +334,7 @@ async def unload_sampler_override():
|
|||
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
||||
async def get_all_loras():
|
||||
"""Lists all LoRAs in the lora directory."""
|
||||
lora_path = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
|
||||
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
|
||||
loras = get_lora_list(lora_path.resolve())
|
||||
|
||||
return loras
|
||||
|
|
@ -387,7 +372,7 @@ async def load_lora(data: LoraLoadRequest):
|
|||
if not data.loras:
|
||||
raise HTTPException(400, "List of loras to load is not found.")
|
||||
|
||||
lora_dir = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
|
||||
lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
|
||||
if not lora_dir.exists():
|
||||
raise HTTPException(
|
||||
400,
|
||||
|
|
@ -468,7 +453,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
|
|||
data.prompt = "\n".join(data.prompt)
|
||||
|
||||
disable_request_streaming = unwrap(
|
||||
get_developer_config().get("disable_request_streaming"), False
|
||||
config.developer_config().get("disable_request_streaming"), False
|
||||
)
|
||||
|
||||
if data.stream and not disable_request_streaming:
|
||||
|
|
@ -565,7 +550,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
|
|||
) from exc
|
||||
|
||||
disable_request_streaming = unwrap(
|
||||
get_developer_config().get("disable_request_streaming"), False
|
||||
config.developer_config().get("disable_request_streaming"), False
|
||||
)
|
||||
|
||||
if data.stream and not disable_request_streaming:
|
||||
|
|
@ -656,16 +641,16 @@ async def entrypoint(args: Optional[dict] = None):
|
|||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Load from YAML config
|
||||
read_config_from_file(pathlib.Path("config.yml"))
|
||||
config.from_file(pathlib.Path("config.yml"))
|
||||
|
||||
# Parse and override config from args
|
||||
if args is None:
|
||||
parser = init_argparser()
|
||||
args = convert_args_to_dict(parser.parse_args(), parser)
|
||||
|
||||
override_config_from_args(args)
|
||||
config.from_args(args)
|
||||
|
||||
developer_config = get_developer_config()
|
||||
developer_config = config.developer_config()
|
||||
|
||||
# Check exllamav2 version and give a descriptive error if it's too old
|
||||
# Skip if launching unsafely
|
||||
|
|
@ -683,7 +668,7 @@ async def entrypoint(args: Optional[dict] = None):
|
|||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
|
||||
logger.warning("Enabled the experimental CUDA malloc backend.")
|
||||
|
||||
network_config = get_network_config()
|
||||
network_config = config.network_config()
|
||||
|
||||
host = unwrap(network_config.get("host"), "127.0.0.1")
|
||||
port = unwrap(network_config.get("port"), 5000)
|
||||
|
|
@ -711,24 +696,24 @@ async def entrypoint(args: Optional[dict] = None):
|
|||
load_auth_keys(unwrap(network_config.get("disable_auth"), False))
|
||||
|
||||
# Override the generation log options if given
|
||||
log_config = get_gen_logging_config()
|
||||
log_config = config.gen_logging_config()
|
||||
if log_config:
|
||||
gen_logging.update_from_dict(log_config)
|
||||
|
||||
gen_logging.broadcast_status()
|
||||
|
||||
# Set sampler parameter overrides if provided
|
||||
sampling_config = get_sampling_config()
|
||||
sampling_config = config.sampling_config()
|
||||
sampling_override_preset = sampling_config.get("override_preset")
|
||||
if sampling_override_preset:
|
||||
try:
|
||||
set_overrides_from_file(sampling_override_preset)
|
||||
sampling.overrides_from_file(sampling_override_preset)
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(str(e))
|
||||
|
||||
# If an initial model name is specified, create a container
|
||||
# and load the model
|
||||
model_config = get_model_config()
|
||||
model_config = config.model_config()
|
||||
model_name = model_config.get("model_name")
|
||||
if model_name:
|
||||
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
|
||||
|
|
@ -737,7 +722,7 @@ async def entrypoint(args: Optional[dict] = None):
|
|||
await model.load_model(model_path.resolve(), **model_config)
|
||||
|
||||
# Load loras after loading the model
|
||||
lora_config = get_lora_config()
|
||||
lora_config = config.lora_config()
|
||||
if lora_config.get("loras"):
|
||||
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
||||
model.container.load_loras(lora_dir.resolve(), **lora_config)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue