Tree: Update to clean up globals

Use the module singleton pattern to share global state. This can also
be viewed as a modified version of the Global Object pattern. The main
reason for using this pattern is ease of handling global state without
adding an extra dependency-injection parameter to every call.

Signed-off-by: kingbri <bdashore3@proton.me>
Authored by kingbri on 2024-03-09 22:31:47 -05:00; committed by Brian Dashore
parent b373b25235
commit 5a2de30066
7 changed files with 68 additions and 82 deletions
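
For context on what this looks like in practice: the config.* calls
introduced below (config.from_file, config.from_args,
config.model_config(), and so on) all read state held at the top level
of a module. A minimal sketch of that module-singleton style follows;
the function names mirror the diff, but the bodies and the use of
PyYAML are assumptions, not the project's actual common/config.py.

# config.py (illustrative sketch, not the project's real module)
import pathlib

import yaml

# Module-level state: Python caches imported modules, so every
# "from common import config" sees this same dict.
_config: dict = {}


def from_file(config_path: pathlib.Path) -> None:
    """Load the shared config from a YAML file."""
    global _config
    if config_path.exists():
        with open(config_path, "r", encoding="utf8") as config_file:
            _config = yaml.safe_load(config_file) or {}


def from_args(args: dict) -> None:
    """Merge parsed CLI args into the shared config.

    Assumes args are already grouped by section, e.g.
    {"network": {"port": 5000}}.
    """
    for section, values in args.items():
        if values:
            _config.setdefault(section, {}).update(values)


def _section(name: str) -> dict:
    """Fetch one config section, defaulting to an empty dict."""
    return _config.get(name) or {}


def model_config() -> dict:
    return _section("model")


def network_config() -> dict:
    return _section("network")


def sampling_config() -> dict:
    return _section("sampling")

Callers just "from common import config" and read whichever section
they need, which is what lets main.py drop the long get_*_config import
list below.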

main.py (67 changed lines)

@@ -18,33 +18,16 @@ from fastapi.middleware.cors import CORSMiddleware
from functools import partial
from loguru import logger
from common.logger import UVICORN_LOG_CONFIG, setup_logger
-import common.gen_logging as gen_logging
from backends.exllamav2.utils import check_exllama_version
-from common import model
+from common import config, model, gen_logging, sampling
from common.args import convert_args_to_dict, init_argparser
from common.auth import check_admin_key, check_api_key, load_auth_keys
-from common.config import (
-    get_developer_config,
-    get_sampling_config,
-    override_config_from_args,
-    read_config_from_file,
-    get_gen_logging_config,
-    get_model_config,
-    get_draft_model_config,
-    get_lora_config,
-    get_network_config,
-)
from common.generators import (
call_with_semaphore,
generate_with_semaphore,
release_semaphore,
)
-from common.sampling import (
-    get_sampler_overrides,
-    set_overrides_from_file,
-    set_overrides_from_dict,
-)
from common.logger import UVICORN_LOG_CONFIG, setup_logger
from common.templating import (
get_all_templates,
get_prompt_from_template,
@@ -119,11 +102,11 @@ async def check_model_container():
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models():
"""Lists all models in the model directory."""
-model_config = get_model_config()
+model_config = config.model_config()
model_dir = unwrap(model_config.get("model_dir"), "models")
model_path = pathlib.Path(model_dir)
-draft_model_dir = get_draft_model_config().get("draft_model_dir")
+draft_model_dir = config.draft_model_config().get("draft_model_dir")
models = get_model_list(model_path.resolve(), draft_model_dir)
if unwrap(model_config.get("use_dummy_models"), False):
@@ -170,7 +153,9 @@ async def get_current_model():
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models():
"""Lists all draft models in the model directory."""
-draft_model_dir = unwrap(get_draft_model_config().get("draft_model_dir"), "models")
+draft_model_dir = unwrap(
+    config.draft_model_config().get("draft_model_dir"), "models"
+)
draft_model_path = pathlib.Path(draft_model_dir)
models = get_model_list(draft_model_path.resolve())
@@ -187,7 +172,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
if not data.name:
raise HTTPException(400, "A model name was not provided.")
-model_path = pathlib.Path(unwrap(get_model_config().get("model_dir"), "models"))
+model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
model_path = model_path / data.name
load_data = data.model_dump()
@@ -199,7 +184,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
)
load_data["draft"]["draft_model_dir"] = unwrap(
-    get_draft_model_config().get("draft_model_dir"), "models"
+    config.draft_model_config().get("draft_model_dir"), "models"
)
if not model_path.exists():
@@ -309,7 +294,7 @@ async def unload_template():
async def list_sampler_overrides():
"""API wrapper to list all currently applied sampler overrides"""
-return get_sampler_overrides()
+return sampling.overrides
@app.post(
@@ -321,13 +306,13 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
if data.preset:
try:
-set_overrides_from_file(data.preset)
+sampling.overrides_from_file(data.preset)
except FileNotFoundError as e:
raise HTTPException(
400, "Sampler override preset does not exist. Check the name?"
) from e
elif data.overrides:
-set_overrides_from_dict(data.overrides)
+sampling.overrides_from_dict(data.overrides)
else:
raise HTTPException(
400, "A sampler override preset or dictionary wasn't provided."
@@ -341,7 +326,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
async def unload_sampler_override():
"""Unloads the currently selected override preset"""
-set_overrides_from_dict({})
+sampling.overrides_from_dict({})
# Lora list endpoint
@@ -349,7 +334,7 @@ async def unload_sampler_override():
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def get_all_loras():
"""Lists all LoRAs in the lora directory."""
-lora_path = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
+lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
loras = get_lora_list(lora_path.resolve())
return loras
@@ -387,7 +372,7 @@ async def load_lora(data: LoraLoadRequest):
if not data.loras:
raise HTTPException(400, "List of loras to load is not found.")
-lora_dir = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
+lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
if not lora_dir.exists():
raise HTTPException(
400,
@@ -468,7 +453,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
data.prompt = "\n".join(data.prompt)
disable_request_streaming = unwrap(
-    get_developer_config().get("disable_request_streaming"), False
+    config.developer_config().get("disable_request_streaming"), False
)
if data.stream and not disable_request_streaming:
@@ -565,7 +550,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
) from exc
disable_request_streaming = unwrap(
-    get_developer_config().get("disable_request_streaming"), False
+    config.developer_config().get("disable_request_streaming"), False
)
if data.stream and not disable_request_streaming:
@@ -656,16 +641,16 @@ async def entrypoint(args: Optional[dict] = None):
signal.signal(signal.SIGTERM, signal_handler)
# Load from YAML config
-read_config_from_file(pathlib.Path("config.yml"))
+config.from_file(pathlib.Path("config.yml"))
# Parse and override config from args
if args is None:
parser = init_argparser()
args = convert_args_to_dict(parser.parse_args(), parser)
-override_config_from_args(args)
+config.from_args(args)
-developer_config = get_developer_config()
+developer_config = config.developer_config()
# Check exllamav2 version and give a descriptive error if it's too old
# Skip if launching unsafely
@@ -683,7 +668,7 @@ async def entrypoint(args: Optional[dict] = None):
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
logger.warning("Enabled the experimental CUDA malloc backend.")
-network_config = get_network_config()
+network_config = config.network_config()
host = unwrap(network_config.get("host"), "127.0.0.1")
port = unwrap(network_config.get("port"), 5000)
@@ -711,24 +696,24 @@ async def entrypoint(args: Optional[dict] = None):
load_auth_keys(unwrap(network_config.get("disable_auth"), False))
# Override the generation log options if given
-log_config = get_gen_logging_config()
+log_config = config.gen_logging_config()
if log_config:
gen_logging.update_from_dict(log_config)
gen_logging.broadcast_status()
# Set sampler parameter overrides if provided
-sampling_config = get_sampling_config()
+sampling_config = config.sampling_config()
sampling_override_preset = sampling_config.get("override_preset")
if sampling_override_preset:
try:
-set_overrides_from_file(sampling_override_preset)
+sampling.overrides_from_file(sampling_override_preset)
except FileNotFoundError as e:
logger.warning(str(e))
# If an initial model name is specified, create a container
# and load the model
-model_config = get_model_config()
+model_config = config.model_config()
model_name = model_config.get("model_name")
if model_name:
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
@@ -737,7 +722,7 @@ async def entrypoint(args: Optional[dict] = None):
await model.load_model(model_path.resolve(), **model_config)
# Load loras after loading the model
-lora_config = get_lora_config()
+lora_config = config.lora_config()
if lora_config.get("loras"):
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
model.container.load_loras(lora_dir.resolve(), **lora_config)
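
The sampling.* calls follow the same pattern: sampling.overrides is
module-level state, and overrides_from_file / overrides_from_dict
mutate it in place. A rough sketch of that shape follows; the names are
taken from the diff, while the loading logic and the preset directory
are assumptions rather than the real common/sampling.py.

# sampling.py (illustrative sketch, not the project's real module)
import pathlib

import yaml

# Shared override state; an empty dict means "no overrides applied".
overrides: dict = {}


def overrides_from_dict(new_overrides: dict) -> None:
    """Replace the active sampler overrides (pass {} to clear them)."""
    global overrides
    overrides = new_overrides or {}


def overrides_from_file(preset_name: str) -> None:
    """Load an override preset by name; raise if the file is missing."""
    # "sampler_overrides" is a placeholder directory, not confirmed by the diff
    preset_path = pathlib.Path("sampler_overrides") / f"{preset_name}.yml"
    if not preset_path.exists():
        raise FileNotFoundError(
            f"Sampler override preset '{preset_name}' was not found"
        )
    with open(preset_path, "r", encoding="utf8") as preset_file:
        overrides_from_dict(yaml.safe_load(preset_file) or {})

Because the endpoints only read these module attributes, nothing has to
be threaded through as a dependency-injection parameter, which is the
trade-off the commit message describes.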