Tree: Update to cleanup globals
Use the module singleton pattern to share global state. This can also be a modified version of the Global Object Pattern. The main reason this pattern is used is for ease of use when handling global state rather than adding extra dependencies for a DI parameter. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
b373b25235
commit
5a2de30066
7 changed files with 68 additions and 82 deletions
67
main.py
67
main.py
|
|
@ -18,33 +18,16 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
from functools import partial
|
||||
from loguru import logger
|
||||
|
||||
from common.logger import UVICORN_LOG_CONFIG, setup_logger
|
||||
import common.gen_logging as gen_logging
|
||||
from backends.exllamav2.utils import check_exllama_version
|
||||
from common import model
|
||||
from common import config, model, gen_logging, sampling
|
||||
from common.args import convert_args_to_dict, init_argparser
|
||||
from common.auth import check_admin_key, check_api_key, load_auth_keys
|
||||
from common.config import (
|
||||
get_developer_config,
|
||||
get_sampling_config,
|
||||
override_config_from_args,
|
||||
read_config_from_file,
|
||||
get_gen_logging_config,
|
||||
get_model_config,
|
||||
get_draft_model_config,
|
||||
get_lora_config,
|
||||
get_network_config,
|
||||
)
|
||||
from common.generators import (
|
||||
call_with_semaphore,
|
||||
generate_with_semaphore,
|
||||
release_semaphore,
|
||||
)
|
||||
from common.sampling import (
|
||||
get_sampler_overrides,
|
||||
set_overrides_from_file,
|
||||
set_overrides_from_dict,
|
||||
)
|
||||
from common.logger import UVICORN_LOG_CONFIG, setup_logger
|
||||
from common.templating import (
|
||||
get_all_templates,
|
||||
get_prompt_from_template,
|
||||
|
|
@ -119,11 +102,11 @@ async def check_model_container():
|
|||
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_models():
|
||||
"""Lists all models in the model directory."""
|
||||
model_config = get_model_config()
|
||||
model_config = config.model_config()
|
||||
model_dir = unwrap(model_config.get("model_dir"), "models")
|
||||
model_path = pathlib.Path(model_dir)
|
||||
|
||||
draft_model_dir = get_draft_model_config().get("draft_model_dir")
|
||||
draft_model_dir = config.draft_model_config().get("draft_model_dir")
|
||||
|
||||
models = get_model_list(model_path.resolve(), draft_model_dir)
|
||||
if unwrap(model_config.get("use_dummy_models"), False):
|
||||
|
|
@ -170,7 +153,9 @@ async def get_current_model():
|
|||
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_draft_models():
|
||||
"""Lists all draft models in the model directory."""
|
||||
draft_model_dir = unwrap(get_draft_model_config().get("draft_model_dir"), "models")
|
||||
draft_model_dir = unwrap(
|
||||
config.draft_model_config().get("draft_model_dir"), "models"
|
||||
)
|
||||
draft_model_path = pathlib.Path(draft_model_dir)
|
||||
|
||||
models = get_model_list(draft_model_path.resolve())
|
||||
|
|
@ -187,7 +172,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
if not data.name:
|
||||
raise HTTPException(400, "A model name was not provided.")
|
||||
|
||||
model_path = pathlib.Path(unwrap(get_model_config().get("model_dir"), "models"))
|
||||
model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
|
||||
model_path = model_path / data.name
|
||||
|
||||
load_data = data.model_dump()
|
||||
|
|
@ -199,7 +184,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
)
|
||||
|
||||
load_data["draft"]["draft_model_dir"] = unwrap(
|
||||
get_draft_model_config().get("draft_model_dir"), "models"
|
||||
config.draft_model_config().get("draft_model_dir"), "models"
|
||||
)
|
||||
|
||||
if not model_path.exists():
|
||||
|
|
@ -309,7 +294,7 @@ async def unload_template():
|
|||
async def list_sampler_overrides():
|
||||
"""API wrapper to list all currently applied sampler overrides"""
|
||||
|
||||
return get_sampler_overrides()
|
||||
return sampling.overrides
|
||||
|
||||
|
||||
@app.post(
|
||||
|
|
@ -321,13 +306,13 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
|
|||
|
||||
if data.preset:
|
||||
try:
|
||||
set_overrides_from_file(data.preset)
|
||||
sampling.overrides_from_file(data.preset)
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(
|
||||
400, "Sampler override preset does not exist. Check the name?"
|
||||
) from e
|
||||
elif data.overrides:
|
||||
set_overrides_from_dict(data.overrides)
|
||||
sampling.overrides_from_dict(data.overrides)
|
||||
else:
|
||||
raise HTTPException(
|
||||
400, "A sampler override preset or dictionary wasn't provided."
|
||||
|
|
@ -341,7 +326,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
|
|||
async def unload_sampler_override():
|
||||
"""Unloads the currently selected override preset"""
|
||||
|
||||
set_overrides_from_dict({})
|
||||
sampling.overrides_from_dict({})
|
||||
|
||||
|
||||
# Lora list endpoint
|
||||
|
|
@ -349,7 +334,7 @@ async def unload_sampler_override():
|
|||
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
||||
async def get_all_loras():
|
||||
"""Lists all LoRAs in the lora directory."""
|
||||
lora_path = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
|
||||
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
|
||||
loras = get_lora_list(lora_path.resolve())
|
||||
|
||||
return loras
|
||||
|
|
@ -387,7 +372,7 @@ async def load_lora(data: LoraLoadRequest):
|
|||
if not data.loras:
|
||||
raise HTTPException(400, "List of loras to load is not found.")
|
||||
|
||||
lora_dir = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
|
||||
lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
|
||||
if not lora_dir.exists():
|
||||
raise HTTPException(
|
||||
400,
|
||||
|
|
@ -468,7 +453,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
|
|||
data.prompt = "\n".join(data.prompt)
|
||||
|
||||
disable_request_streaming = unwrap(
|
||||
get_developer_config().get("disable_request_streaming"), False
|
||||
config.developer_config().get("disable_request_streaming"), False
|
||||
)
|
||||
|
||||
if data.stream and not disable_request_streaming:
|
||||
|
|
@ -565,7 +550,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
|
|||
) from exc
|
||||
|
||||
disable_request_streaming = unwrap(
|
||||
get_developer_config().get("disable_request_streaming"), False
|
||||
config.developer_config().get("disable_request_streaming"), False
|
||||
)
|
||||
|
||||
if data.stream and not disable_request_streaming:
|
||||
|
|
@ -656,16 +641,16 @@ async def entrypoint(args: Optional[dict] = None):
|
|||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Load from YAML config
|
||||
read_config_from_file(pathlib.Path("config.yml"))
|
||||
config.from_file(pathlib.Path("config.yml"))
|
||||
|
||||
# Parse and override config from args
|
||||
if args is None:
|
||||
parser = init_argparser()
|
||||
args = convert_args_to_dict(parser.parse_args(), parser)
|
||||
|
||||
override_config_from_args(args)
|
||||
config.from_args(args)
|
||||
|
||||
developer_config = get_developer_config()
|
||||
developer_config = config.developer_config()
|
||||
|
||||
# Check exllamav2 version and give a descriptive error if it's too old
|
||||
# Skip if launching unsafely
|
||||
|
|
@ -683,7 +668,7 @@ async def entrypoint(args: Optional[dict] = None):
|
|||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
|
||||
logger.warning("Enabled the experimental CUDA malloc backend.")
|
||||
|
||||
network_config = get_network_config()
|
||||
network_config = config.network_config()
|
||||
|
||||
host = unwrap(network_config.get("host"), "127.0.0.1")
|
||||
port = unwrap(network_config.get("port"), 5000)
|
||||
|
|
@ -711,24 +696,24 @@ async def entrypoint(args: Optional[dict] = None):
|
|||
load_auth_keys(unwrap(network_config.get("disable_auth"), False))
|
||||
|
||||
# Override the generation log options if given
|
||||
log_config = get_gen_logging_config()
|
||||
log_config = config.gen_logging_config()
|
||||
if log_config:
|
||||
gen_logging.update_from_dict(log_config)
|
||||
|
||||
gen_logging.broadcast_status()
|
||||
|
||||
# Set sampler parameter overrides if provided
|
||||
sampling_config = get_sampling_config()
|
||||
sampling_config = config.sampling_config()
|
||||
sampling_override_preset = sampling_config.get("override_preset")
|
||||
if sampling_override_preset:
|
||||
try:
|
||||
set_overrides_from_file(sampling_override_preset)
|
||||
sampling.overrides_from_file(sampling_override_preset)
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(str(e))
|
||||
|
||||
# If an initial model name is specified, create a container
|
||||
# and load the model
|
||||
model_config = get_model_config()
|
||||
model_config = config.model_config()
|
||||
model_name = model_config.get("model_name")
|
||||
if model_name:
|
||||
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
|
||||
|
|
@ -737,7 +722,7 @@ async def entrypoint(args: Optional[dict] = None):
|
|||
await model.load_model(model_path.resolve(), **model_config)
|
||||
|
||||
# Load loras after loading the model
|
||||
lora_config = get_lora_config()
|
||||
lora_config = config.lora_config()
|
||||
if lora_config.get("loras"):
|
||||
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
||||
model.container.load_loras(lora_dir.resolve(), **lora_config)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue