Merge pull request #185 from SecretiveShell/refactor-config-loading
Refactor config loading
commit ec7f64d530
10 changed files with 160 additions and 176 deletions
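
The refactor replaces the module-level getter functions from common/config.py (network_config(), model_config(), and so on) with attribute access on a shared TabbyConfig instance defined in the new common/tabby_config.py. A rough before/after sketch of a typical call site, taken from the send_tracebacks change further down; unwrap is assumed to return its first argument unless it is None:

    # Before: each lookup re-read a section out of the GLOBAL_CONFIG dict
    send_trace = unwrap(config.network_config().get("send_tracebacks"), False)

    # After: sections are plain dict attributes on one shared config object
    send_trace = unwrap(config.network.get("send_tracebacks"), False)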

107  common/config.py (deleted)
@@ -1,107 +0,0 @@
-import yaml
-import pathlib
-from loguru import logger
-
-from common.utils import unwrap
-
-# Global config dictionary constant
-GLOBAL_CONFIG: dict = {}
-
-
-def from_file(config_path: pathlib.Path):
-    """Sets the global config from a given file path"""
-    global GLOBAL_CONFIG
-
-    try:
-        with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
-            GLOBAL_CONFIG = unwrap(yaml.safe_load(config_file), {})
-    except Exception as exc:
-        logger.error(
-            "The YAML config couldn't load because of the following error: "
-            f"\n\n{exc}"
-            "\n\nTabbyAPI will start anyway and not parse this config file."
-        )
-        GLOBAL_CONFIG = {}
-
-
-def from_args(args: dict):
-    """Overrides the config based on a dict representation of args"""
-
-    config_override = unwrap(args.get("options", {}).get("config"))
-    if config_override:
-        logger.info("Attempting to override config.yml from args.")
-        from_file(pathlib.Path(config_override))
-        return
-
-    # Network config
-    network_override = args.get("network")
-    if network_override:
-        cur_network_config = network_config()
-        GLOBAL_CONFIG["network"] = {**cur_network_config, **network_override}
-
-    # Model config
-    model_override = args.get("model")
-    if model_override:
-        cur_model_config = model_config()
-        GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override}
-
-    # Generation Logging config
-    logging_override = args.get("logging")
-    if logging_override:
-        cur_logging_config = logging_config()
-        GLOBAL_CONFIG["logging"] = {
-            **cur_logging_config,
-            **{k.replace("log_", ""): logging_override[k] for k in logging_override},
-        }
-
-    developer_override = args.get("developer")
-    if developer_override:
-        cur_developer_config = developer_config()
-        GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}
-
-    embeddings_override = args.get("embeddings")
-    if embeddings_override:
-        cur_embeddings_config = embeddings_config()
-        GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override}
-
-
-def sampling_config():
-    """Returns the sampling parameter config from the global config"""
-    return unwrap(GLOBAL_CONFIG.get("sampling"), {})
-
-
-def model_config():
-    """Returns the model config from the global config"""
-    return unwrap(GLOBAL_CONFIG.get("model"), {})
-
-
-def draft_model_config():
-    """Returns the draft model config from the global config"""
-    model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
-    return unwrap(model_config.get("draft"), {})
-
-
-def lora_config():
-    """Returns the lora config from the global config"""
-    model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
-    return unwrap(model_config.get("lora"), {})
-
-
-def network_config():
-    """Returns the network config from the global config"""
-    return unwrap(GLOBAL_CONFIG.get("network"), {})
-
-
-def logging_config():
-    """Returns the logging config from the global config"""
-    return unwrap(GLOBAL_CONFIG.get("logging"), {})
-
-
-def developer_config():
-    """Returns the developer specific config from the global config"""
-    return unwrap(GLOBAL_CONFIG.get("developer"), {})
-
-
-def embeddings_config():
-    """Returns the embeddings config from the global config"""
-    return unwrap(GLOBAL_CONFIG.get("embeddings"), {})

@@ -10,8 +10,8 @@ from loguru import logger
 from rich.progress import Progress
 from typing import List, Optional

-from common.config import lora_config, model_config
 from common.logger import get_progress_bar
+from common.tabby_config import config
 from common.utils import unwrap


@@ -76,9 +76,9 @@ def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str
     """Gets the download folder for the repo."""

     if repo_type == "lora":
-        download_path = pathlib.Path(lora_config().get("lora_dir") or "loras")
+        download_path = pathlib.Path(config.lora.get("lora_dir") or "loras")
     else:
-        download_path = pathlib.Path(model_config().get("model_dir") or "models")
+        download_path = pathlib.Path(config.model.get("model_dir") or "models")

     download_path = download_path / (folder_name or repo_id.split("/")[-1])
     return download_path

@@ -10,9 +10,9 @@ from fastapi import HTTPException
 from loguru import logger
 from typing import Optional

-from common import config
 from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
+from common.tabby_config import config
 from common.utils import unwrap
 from endpoints.utils import do_export_openapi

@@ -153,8 +153,7 @@ async def unload_embedding_model():
 def get_config_default(key: str, model_type: str = "model"):
     """Fetches a default value from model config if allowed by the user."""

-    model_config = config.model_config()
-    default_keys = unwrap(model_config.get("use_as_default"), [])
+    default_keys = unwrap(config.model.get("use_as_default"), [])

     # Add extra keys to defaults
     default_keys.append("embeddings_device")
@@ -162,13 +161,11 @@ def get_config_default(key: str, model_type: str = "model"):
     if key in default_keys:
         # Is this a draft model load parameter?
         if model_type == "draft":
-            draft_config = config.draft_model_config()
-            return draft_config.get(key)
+            return config.draft_model.get(key)
         elif model_type == "embedding":
-            embeddings_config = config.embeddings_config()
-            return embeddings_config.get(key)
+            return config.embeddings.get(key)
         else:
-            return model_config.get(key)
+            return config.model.get(key)


 async def check_model_container():
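
get_config_default only falls back to the config file for keys that the user lists under use_as_default (plus embeddings_device, which is always appended). A hypothetical illustration of the behaviour after this change; the config values below are invented for the example:

    config.model = {"use_as_default": ["max_seq_len"], "max_seq_len": 4096}
    config.draft_model = {"max_seq_len": 2048}

    get_config_default("max_seq_len")                      # 4096, opted in via use_as_default
    get_config_default("max_seq_len", model_type="draft")  # 2048, read from the draft section
    get_config_default("gpu_split_auto")                   # None, key was never opted in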

@@ -10,7 +10,7 @@ from pydantic import BaseModel
 from typing import Optional
 from uuid import uuid4

-from common import config
+from common.tabby_config import config
 from common.utils import unwrap


@@ -39,7 +39,7 @@ def handle_request_error(message: str, exc_info: bool = True):
     """Log a request error to the console."""

     trace = traceback.format_exc()
-    send_trace = unwrap(config.network_config().get("send_tracebacks"), False)
+    send_trace = unwrap(config.network.get("send_tracebacks"), False)

     error_message = TabbyRequestErrorMessage(
         message=message, trace=trace if send_trace else None
@@ -134,7 +134,7 @@ def get_global_depends():

     depends = [Depends(add_request_id)]

-    if config.logging_config().get("requests"):
+    if config.logging.get("requests"):
         depends.append(Depends(log_request))

     return depends
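
Nearly every call site in this diff wraps the dict lookup in unwrap to supply a fallback. The helper itself is not part of this diff; a minimal sketch of its assumed behaviour, matching how it is used here:

    def unwrap(value, default=None):
        """Return value unless it is None, in which case return the default."""
        return value if value is not None else default

    unwrap(config.network.get("port"), 5000)  # 5000 when the key is missing or null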

88  common/tabby_config.py (new file)
@@ -0,0 +1,88 @@
+import yaml
+import pathlib
+from loguru import logger
+from typing import Optional
+
+from common.utils import unwrap, merge_dicts
+
+
+class TabbyConfig:
+    network: dict = {}
+    logging: dict = {}
+    model: dict = {}
+    draft_model: dict = {}
+    lora: dict = {}
+    sampling: dict = {}
+    developer: dict = {}
+    embeddings: dict = {}
+
+    def load(self, arguments: Optional[dict] = None):
+        """load the global application config"""
+
+        # config is applied in order of items in the list
+        configs = [
+            self._from_file(pathlib.Path("config.yml")),
+            self._from_args(unwrap(arguments, {})),
+        ]
+
+        merged_config = merge_dicts(*configs)
+
+        self.network = unwrap(merged_config.get("network"), {})
+        self.logging = unwrap(merged_config.get("logging"), {})
+        self.model = unwrap(merged_config.get("model"), {})
+        self.draft_model = unwrap(merged_config.get("draft"), {})
+        self.lora = unwrap(merged_config.get("lora"), {})
+        self.sampling = unwrap(merged_config.get("sampling"), {})
+        self.developer = unwrap(merged_config.get("developer"), {})
+        self.embeddings = unwrap(merged_config.get("embeddings"), {})
+
+    def _from_file(self, config_path: pathlib.Path):
+        """loads config from a given file path"""
+
+        # try loading from file
+        try:
+            with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
+                return unwrap(yaml.safe_load(config_file), {})
+        except FileNotFoundError:
+            logger.info(f"The '{config_path.name}' file cannot be found")
+        except Exception as exc:
+            logger.error(
+                f"The YAML config from '{config_path.name}' couldn't load because of "
+                f"the following error:\n\n{exc}"
+            )
+
+        # if no config file was loaded
+        return {}
+
+    def _from_args(self, args: dict):
+        """loads config from the provided arguments"""
+        config = {}
+
+        config_override = unwrap(args.get("options", {}).get("config"))
+        if config_override:
+            logger.info("Config file override detected in args.")
+            config = self._from_file(pathlib.Path(config_override))
+            return config  # Return early if loading from file
+
+        for key in ["network", "model", "logging", "developer", "embeddings"]:
+            override = args.get(key)
+            if override:
+                if key == "logging":
+                    # Strip the "log_" prefix from logging keys if present
+                    override = {k.replace("log_", ""): v for k, v in override.items()}
+                config[key] = override
+
+        return config
+
+    def _from_environment(self):
+        """loads configuration from environment variables"""
+
+        # TODO: load config from environment variables
+        # this means that we can have host default to 0.0.0.0 in docker for example
+        # this would also mean that docker containers no longer require a non
+        # default config file to be used
+        pass
+
+
+# Create an empty instance of the config class
+config: TabbyConfig = TabbyConfig()
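
A short sketch of how the new singleton is meant to be consumed, pieced together from the call sites elsewhere in this diff; the override dict passed to load() is invented for the example:

    from common.tabby_config import config
    from common.utils import unwrap

    # Read config.yml, then merge CLI-style overrides on top (later sources win)
    config.load({"network": {"port": 8080}})

    # Sections are plain dicts, so lookups keep the usual unwrap(...) fallback
    host = unwrap(config.network.get("host"), "127.0.0.1")
    port = unwrap(config.network.get("port"), 5000)  # 8080 after the override above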

@@ -20,6 +20,25 @@ def prune_dict(input_dict):
     return {k: v for k, v in input_dict.items() if v is not None}


+def merge_dict(dict1, dict2):
+    """Merge 2 dictionaries"""
+    for key, value in dict2.items():
+        if isinstance(value, dict) and key in dict1 and isinstance(dict1[key], dict):
+            merge_dict(dict1[key], value)
+        else:
+            dict1[key] = value
+    return dict1
+
+
+def merge_dicts(*dicts):
+    """Merge an arbitrary amount of dictionaries"""
+    result = {}
+    for dictionary in dicts:
+        result = merge_dict(result, dictionary)
+
+    return result
+
+
 def flat_map(input_list):
     """Flattens a list of lists into a single list."""

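
merge_dicts is what lets a later config source override an earlier one without wiping whole sections: nested dicts are merged key by key rather than replaced. A small illustration with made-up values:

    file_cfg = {"network": {"host": "0.0.0.0", "port": 5000}, "model": {"model_dir": "models"}}
    args_cfg = {"network": {"port": 8080}}

    merged = merge_dicts(file_cfg, args_cfg)
    # {"network": {"host": "0.0.0.0", "port": 8080}, "model": {"model_dir": "models"}}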

@@ -3,10 +3,11 @@ from fastapi import APIRouter, Depends, HTTPException, Request
 from sse_starlette import EventSourceResponse
 from sys import maxsize

-from common import config, model
+from common import model
 from common.auth import check_api_key
 from common.model import check_embeddings_container, check_model_container
 from common.networking import handle_request_error, run_with_request_disconnect
+from common.tabby_config import config
 from common.utils import unwrap
 from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
 from endpoints.OAI.types.chat_completion import (
@@ -64,7 +65,7 @@ async def completion_request(
         data.prompt = "\n".join(data.prompt)

     disable_request_streaming = unwrap(
-        config.developer_config().get("disable_request_streaming"), False
+        config.developer.get("disable_request_streaming"), False
     )

     # Set an empty JSON schema if the request wants a JSON response
@@ -128,7 +129,7 @@ async def chat_completion_request(
         data.json_schema = {"type": "object"}

     disable_request_streaming = unwrap(
-        config.developer_config().get("disable_request_streaming"), False
+        config.developer.get("disable_request_streaming"), False
     )

     if data.stream and not disable_request_streaming:

@@ -4,11 +4,12 @@ from sys import maxsize
 from fastapi import APIRouter, Depends, HTTPException, Request
 from sse_starlette import EventSourceResponse

-from common import config, model, sampling
+from common import model, sampling
 from common.auth import check_admin_key, check_api_key, get_key_permission
 from common.downloader import hf_repo_download
 from common.model import check_embeddings_container, check_model_container
 from common.networking import handle_request_error, run_with_request_disconnect
+from common.tabby_config import config
 from common.templating import PromptTemplate, get_all_templates
 from common.utils import unwrap
 from endpoints.core.types.auth import AuthPermissionResponse
@@ -61,18 +62,17 @@ async def list_models(request: Request) -> ModelList:
     Requires an admin key to see all models.
     """

-    model_config = config.model_config()
-    model_dir = unwrap(model_config.get("model_dir"), "models")
+    model_dir = unwrap(config.model.get("model_dir"), "models")
     model_path = pathlib.Path(model_dir)

-    draft_model_dir = config.draft_model_config().get("draft_model_dir")
+    draft_model_dir = config.draft_model.get("draft_model_dir")

     if get_key_permission(request) == "admin":
         models = get_model_list(model_path.resolve(), draft_model_dir)
     else:
         models = await get_current_model_list()

-    if unwrap(model_config.get("use_dummy_models"), False):
+    if unwrap(config.model.get("use_dummy_models"), False):
         models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))

     return models
@@ -98,9 +98,7 @@ async def list_draft_models(request: Request) -> ModelList:
     """

     if get_key_permission(request) == "admin":
-        draft_model_dir = unwrap(
-            config.draft_model_config().get("draft_model_dir"), "models"
-        )
+        draft_model_dir = unwrap(config.draft_model.get("draft_model_dir"), "models")
         draft_model_path = pathlib.Path(draft_model_dir)

         models = get_model_list(draft_model_path.resolve())
@@ -124,7 +122,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:

         raise HTTPException(400, error_message)

-    model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
+    model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
     model_path = model_path / data.name

     draft_model_path = None
@@ -137,9 +135,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:

             raise HTTPException(400, error_message)

-        draft_model_path = unwrap(
-            config.draft_model_config().get("draft_model_dir"), "models"
-        )
+        draft_model_path = unwrap(config.draft_model.get("draft_model_dir"), "models")

     if not model_path.exists():
         error_message = handle_request_error(
@@ -196,7 +192,7 @@ async def list_all_loras(request: Request) -> LoraList:
     """

     if get_key_permission(request) == "admin":
-        lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
+        lora_path = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
         loras = get_lora_list(lora_path.resolve())
     else:
         loras = get_active_loras()
@@ -231,7 +227,7 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:

         raise HTTPException(400, error_message)

-    lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
+    lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
     if not lora_dir.exists():
         error_message = handle_request_error(
             "A parent lora directory does not exist for load. Check your config.yml?",
@@ -271,7 +267,7 @@ async def list_embedding_models(request: Request) -> ModelList:

     if get_key_permission(request) == "admin":
         embedding_model_dir = unwrap(
-            config.embeddings_config().get("embedding_model_dir"), "models"
+            config.embeddings.get("embedding_model_dir"), "models"
         )
         embedding_model_path = pathlib.Path(embedding_model_dir)

@@ -307,7 +303,7 @@ async def load_embedding_model(
         raise HTTPException(400, error_message)

     embedding_model_dir = pathlib.Path(
-        unwrap(config.model_config().get("embedding_model_dir"), "models")
+        unwrap(config.embeddings.get("embedding_model_dir"), "models")
     )
     embedding_model_path = embedding_model_dir / data.name


@@ -5,9 +5,9 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from loguru import logger

-from common import config
 from common.logger import UVICORN_LOG_CONFIG
 from common.networking import get_global_depends
+from common.tabby_config import config
 from common.utils import unwrap
 from endpoints.Kobold import router as KoboldRouter
 from endpoints.OAI import router as OAIRouter
@@ -36,7 +36,7 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None):
         allow_headers=["*"],
     )

-    api_servers = unwrap(config.network_config().get("api_servers"), [])
+    api_servers = unwrap(config.network.get("api_servers"), [])

     # Map for API id to server router
     router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}
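
The api_servers list from the network section decides which routers are mounted. The hunk ends before the wiring itself, so the loop below is an assumption about how router_mapping is presumably consumed rather than code from this PR:

    api_servers = unwrap(config.network.get("api_servers"), [])
    router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}

    for server in api_servers:
        router = router_mapping.get(server.lower())
        if router:
            app.include_router(router)  # assumed FastAPI wiring, not shown in the diff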

58  main.py
@@ -9,12 +9,13 @@ import signal
 from loguru import logger
 from typing import Optional

-from common import config, gen_logging, sampling, model
+from common import gen_logging, sampling, model
 from common.args import convert_args_to_dict, init_argparser
 from common.auth import load_auth_keys
 from common.logger import setup_logger
 from common.networking import is_port_in_use
 from common.signals import signal_handler
+from common.tabby_config import config
 from common.utils import unwrap
 from endpoints.server import export_openapi, start_api
 from endpoints.utils import do_export_openapi
@@ -26,10 +27,8 @@ if not do_export_openapi:
 async def entrypoint_async():
     """Async entry function for program startup"""

-    network_config = config.network_config()
-
-    host = unwrap(network_config.get("host"), "127.0.0.1")
-    port = unwrap(network_config.get("port"), 5000)
+    host = unwrap(config.network.get("host"), "127.0.0.1")
+    port = unwrap(config.network.get("port"), 5000)

     # Check if the port is available and attempt to bind a fallback
     if is_port_in_use(port):
@@ -51,18 +50,16 @@ async def entrypoint_async():
         port = fallback_port

     # Initialize auth keys
-    load_auth_keys(unwrap(network_config.get("disable_auth"), False))
+    load_auth_keys(unwrap(config.network.get("disable_auth"), False))

     # Override the generation log options if given
-    log_config = config.logging_config()
-    if log_config:
-        gen_logging.update_from_dict(log_config)
+    if config.logging:
+        gen_logging.update_from_dict(config.logging)

     gen_logging.broadcast_status()

     # Set sampler parameter overrides if provided
-    sampling_config = config.sampling_config()
-    sampling_override_preset = sampling_config.get("override_preset")
+    sampling_override_preset = config.sampling.get("override_preset")
     if sampling_override_preset:
         try:
             sampling.overrides_from_file(sampling_override_preset)
@@ -71,32 +68,29 @@ async def entrypoint_async():

     # If an initial model name is specified, create a container
     # and load the model
-    model_config = config.model_config()
-    model_name = model_config.get("model_name")
+    model_name = config.model.get("model_name")
     if model_name:
-        model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
+        model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
         model_path = model_path / model_name

-        await model.load_model(model_path.resolve(), **model_config)
+        await model.load_model(model_path.resolve(), **config.model)

     # Load loras after loading the model
-    lora_config = config.lora_config()
-    if lora_config.get("loras"):
-        lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
-        await model.container.load_loras(lora_dir.resolve(), **lora_config)
+    if config.lora.get("loras"):
+        lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
+        await model.container.load_loras(lora_dir.resolve(), **config.lora)

     # If an initial embedding model name is specified, create a separate container
     # and load the model
-    embedding_config = config.embeddings_config()
-    embedding_model_name = embedding_config.get("embedding_model_name")
+    embedding_model_name = config.embeddings.get("embedding_model_name")
     if embedding_model_name:
         embedding_model_path = pathlib.Path(
-            unwrap(embedding_config.get("embedding_model_dir"), "models")
+            unwrap(config.embeddings.get("embedding_model_dir"), "models")
         )
         embedding_model_path = embedding_model_path / embedding_model_name

         try:
-            await model.load_embedding_model(embedding_model_path, **embedding_config)
+            await model.load_embedding_model(embedding_model_path, **config.embeddings)
         except ImportError as ex:
             logger.error(ex.msg)

@@ -110,15 +104,13 @@ def entrypoint(arguments: Optional[dict] = None):
     signal.signal(signal.SIGINT, signal_handler)
     signal.signal(signal.SIGTERM, signal_handler)

-    # Load from YAML config
-    config.from_file(pathlib.Path("config.yml"))
-
     # Parse and override config from args
     if arguments is None:
         parser = init_argparser()
         arguments = convert_args_to_dict(parser.parse_args(), parser)

-    config.from_args(arguments)
+    # load config
+    config.load(arguments)

     if do_export_openapi:
         openapi_json = export_openapi()
@@ -129,12 +121,10 @@ def entrypoint(arguments: Optional[dict] = None):

         return

-    developer_config = config.developer_config()
-
     # Check exllamav2 version and give a descriptive error if it's too old
     # Skip if launching unsafely
-
-    if unwrap(developer_config.get("unsafe_launch"), False):
+    print(f"MAIN.PY {config=}")
+    if unwrap(config.developer.get("unsafe_launch"), False):
         logger.warning(
             "UNSAFE: Skipping ExllamaV2 version check.\n"
             "If you aren't a developer, please keep this off!"
@@ -143,12 +133,12 @@ def entrypoint(arguments: Optional[dict] = None):
         check_exllama_version()

     # Enable CUDA malloc backend
-    if unwrap(developer_config.get("cuda_malloc_backend"), False):
+    if unwrap(config.developer.get("cuda_malloc_backend"), False):
         os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
         logger.warning("EXPERIMENTAL: Enabled the pytorch CUDA malloc backend.")

     # Use Uvloop/Winloop
-    if unwrap(developer_config.get("uvloop"), False):
+    if unwrap(config.developer.get("uvloop"), False):
         if platform.system() == "Windows":
             from winloop import install
         else:
@@ -160,7 +150,7 @@ def entrypoint(arguments: Optional[dict] = None):
         logger.warning("EXPERIMENTAL: Running program with Uvloop/Winloop.")

     # Set the process priority
-    if unwrap(developer_config.get("realtime_process_priority"), False):
+    if unwrap(config.developer.get("realtime_process_priority"), False):
         import psutil

         current_process = psutil.Process(os.getpid())
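
One consequence visible in the main.py hunks: because each section is now a plain dict, it can be splatted straight into the loaders. A condensed sketch of the pattern; load_model's signature is not part of this diff and is assumed to accept keyword overrides:

    # inside entrypoint_async()
    model_name = config.model.get("model_name")
    if model_name:
        model_dir = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
        # every key of the model section becomes a keyword argument to the loader
        await model.load_model((model_dir / model_name).resolve(), **config.model)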