From 362b8d5818bbe3fa46128f596d952bfb6fd9d349 Mon Sep 17 00:00:00 2001 From: Jake <84923604+SecretiveShell@users.noreply.github.com> Date: Thu, 5 Sep 2024 18:04:56 +0100 Subject: [PATCH] config is now backed by pydantic (WIP) - add models for config options - add function to regenerate config.yml - replace references to config with pydantic compatible references - remove unnecessary unwrap() statements TODO: - auto generate env vars - auto generate argparse - test loading a model --- common/config_models.py | 248 ++++++++++++++++++++++++++++++++++ common/downloader.py | 4 +- common/gen_logging.py | 32 +---- common/model.py | 2 +- common/networking.py | 4 +- common/tabby_config.py | 29 ++-- endpoints/OAI/router.py | 8 +- endpoints/core/router.py | 24 ++-- endpoints/core/types/model.py | 4 +- endpoints/server.py | 2 +- main.py | 34 ++--- 11 files changed, 297 insertions(+), 94 deletions(-) create mode 100644 common/config_models.py diff --git a/common/config_models.py b/common/config_models.py new file mode 100644 index 0000000..9bbf5f1 --- /dev/null +++ b/common/config_models.py @@ -0,0 +1,248 @@ +from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing import List, Optional, Union, get_type_hints + +from common.utils import unwrap + + +class config_config_model(BaseModel): + config: Optional[str] = Field( + None, description="Path to an overriding config.yml file" + ) + + +class network_config_model(BaseModel): + host: Optional[str] = Field("127.0.0.1", description="The IP to host on") + port: Optional[int] = Field(5000, description="The port to host on") + disable_auth: Optional[bool] = Field( + False, description="Disable HTTP token authentication with requests" + ) + send_tracebacks: Optional[bool] = Field( + False, description="Decide whether to send error tracebacks over the API" + ) + api_servers: Optional[List[str]] = Field( + [ + "OAI", + ], + description="API servers to enable. Options: (OAI, Kobold)", + ) + + +class logging_config_model(BaseModel): + log_prompt: Optional[bool] = Field(False, description="Enable prompt logging") + log_generation_params: Optional[bool] = Field( + False, description="Enable generation parameter logging" + ) + log_requests: Optional[bool] = Field(False, description="Enable request logging") + + +class model_config_model(BaseModel): + model_dir: str = Field( + "models", + description="Overrides the directory to look for models (default: models). Windows users, do NOT put this path in quotes.", + ) + use_dummy_models: Optional[bool] = Field( + False, + description="Sends dummy model names when the models endpoint is queried. Enable this if looking for specific OAI models.", + ) + model_name: Optional[str] = Field( + None, + description="An initial model to load. Make sure the model is located in the model directory! REQUIRED: This must be filled out to load a model on startup.", + ) + use_as_default: List[str] = Field( + default_factory=list, + description="Names of args to use as a default fallback for API load requests (default: []). Example: ['max_seq_len', 'cache_mode']", + ) + max_seq_len: Optional[int] = Field( + None, + description="Max sequence length. Fetched from the model's base sequence length in config.json by default.", + ) + override_base_seq_len: Optional[int] = Field( + None, + description="Overrides base model context length. WARNING: Only use this if the model's base sequence length is incorrect.", + ) + tensor_parallel: Optional[bool] = Field( + False, + description="Load model with tensor parallelism. 
Fallback to autosplit if GPU split isn't provided.", + ) + gpu_split_auto: Optional[bool] = Field( + True, + description="Automatically allocate resources to GPUs (default: True). Not parsed for single GPU users.", + ) + autosplit_reserve: List[int] = Field( + [96], + description="Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0). Represented as an array of MB per GPU.", + ) + gpu_split: List[float] = Field( + default_factory=list, + description="An integer array of GBs of VRAM to split between GPUs (default: []). Used with tensor parallelism.", + ) + rope_scale: Optional[float] = Field( + 1.0, + description="Rope scale (default: 1.0). Same as compress_pos_emb. Only use if the model was trained on long context with rope.", + ) + rope_alpha: Optional[Union[float, str]] = Field( + 1.0, + description="Rope alpha (default: 1.0). Same as alpha_value. Set to 'auto' to auto-calculate.", + ) + cache_mode: Optional[str] = Field( + "FP16", + description="Enable different cache modes for VRAM savings (default: FP16). Possible values: FP16, Q8, Q6, Q4.", + ) + cache_size: Optional[int] = Field( + None, + description="Size of the prompt cache to allocate (default: max_seq_len). Must be a multiple of 256.", + ) + chunk_size: Optional[int] = Field( + 2048, + description="Chunk size for prompt ingestion (default: 2048). A lower value reduces VRAM usage but decreases ingestion speed.", + ) + max_batch_size: Optional[int] = Field( + None, + description="Set the maximum number of prompts to process at one time (default: None/Automatic). Automatically calculated if left blank.", + ) + prompt_template: Optional[str] = Field( + None, + description="Set the prompt template for this model. If empty, attempts to look for the model's chat template.", + ) + num_experts_per_token: Optional[int] = Field( + None, + description="Number of experts to use per token. Fetched from the model's config.json. For MoE models only.", + ) + fasttensors: Optional[bool] = Field( + False, + description="Enables fasttensors to possibly increase model loading speeds (default: False).", + ) + + +class draft_model_config_model(BaseModel): + draft_model_dir: Optional[str] = Field( + "models", + description="Overrides the directory to look for draft models (default: models)", + ) + draft_model_name: Optional[str] = Field( + None, + description="An initial draft model to load. Ensure the model is in the model directory.", + ) + draft_rope_scale: Optional[float] = Field( + 1.0, + description="Rope scale for draft models (default: 1.0). Same as compress_pos_emb. Use if the draft model was trained on long context with rope.", + ) + draft_rope_alpha: Optional[float] = Field( + None, + description="Rope alpha for draft models (default: None). Same as alpha_value. Leave blank to auto-calculate the alpha value.", + ) + draft_cache_mode: Optional[str] = Field( + "FP16", + description="Cache mode for draft models to save VRAM (default: FP16). 
+        "Possible values: FP16, Q8, Q6, Q4.",
+    )
+
+
+class lora_instance_model(BaseModel):
+    name: str = Field(..., description="Name of the LoRA model")
+    scaling: float = Field(
+        1.0, description="Scaling factor for the LoRA model (default: 1.0)"
+    )
+
+
+class lora_config_model(BaseModel):
+    lora_dir: Optional[str] = Field(
+        "loras", description="Directory to look for LoRAs (default: 'loras')"
+    )
+    loras: Optional[List[lora_instance_model]] = Field(
+        None,
+        description="List of LoRAs to load and associated scaling factors (default scaling: 1.0)",
+    )
+
+
+class sampling_config_model(BaseModel):
+    override_preset: Optional[str] = Field(
+        None, description="Select a sampler override preset"
+    )
+
+
+class developer_config_model(BaseModel):
+    unsafe_launch: Optional[bool] = Field(
+        False, description="Skip Exllamav2 version check"
+    )
+    disable_request_streaming: Optional[bool] = Field(
+        False, description="Disables API request streaming"
+    )
+    cuda_malloc_backend: Optional[bool] = Field(
+        False, description="Runs with the pytorch CUDA malloc backend"
+    )
+    uvloop: Optional[bool] = Field(
+        False, description="Run asyncio using Uvloop or Winloop"
+    )
+    realtime_process_priority: Optional[bool] = Field(
+        False,
+        description="Set process to use a higher priority. For realtime process priority, run as administrator or sudo. Otherwise, the priority will be set to high.",
+    )
+
+
+class embeddings_config_model(BaseModel):
+    embedding_model_dir: Optional[str] = Field(
+        "models",
+        description="Overrides directory to look for embedding models (default: models)",
+    )
+    embeddings_device: Optional[str] = Field(
+        "cpu",
+        description="Device to load embedding models on (default: cpu). Possible values: cpu, auto, cuda. If using an AMD GPU, set this value to 'cuda'.",
+    )
+    embedding_model_name: Optional[str] = Field(
+        None, description="The embeddings model to load"
+    )
+
+
+class tabby_config_model(BaseModel):
+    config: config_config_model = Field(default_factory=config_config_model)
+    network: network_config_model = Field(default_factory=network_config_model)
+    logging: logging_config_model = Field(default_factory=logging_config_model)
+    model: model_config_model = Field(default_factory=model_config_model)
+    draft_model: draft_model_config_model = Field(
+        default_factory=draft_model_config_model
+    )
+    lora: lora_config_model = Field(default_factory=lora_config_model)
+    sampling: sampling_config_model = Field(default_factory=sampling_config_model)
+    developer: developer_config_model = Field(default_factory=developer_config_model)
+    embeddings: embeddings_config_model = Field(default_factory=embeddings_config_model)
+
+    @model_validator(mode="before")
+    def set_defaults(cls, values):
+        for field_name, field_value in values.items():
+            if field_value is None:
+                default_instance = cls.__annotations__[field_name]().dict()
+                values[field_name] = cls.__annotations__[field_name](**default_instance)
+        return values
+
+    model_config = ConfigDict(validate_assignment=True)
+
+
+def generate_config_file(filename="config_sample.yml", indentation=2):
+    schema = tabby_config_model.model_json_schema()
+
+    def dump_def(id: str, indent=2):
+        yaml = ""
+        indent = " " * indentation * indent
+        id = id.split("/")[-1]
+
+        section = schema["$defs"][id]["properties"]
+        for property in section.keys():  # get type
+            comment = section[property]["description"]
+            yaml += f"{indent}# {comment}\n"
+
+            value = unwrap(section[property].get("default"), "")
+            yaml += f"{indent}{property}: {value}\n\n"
+
+        return yaml + "\n"
+
+    
yaml = "" + for section in schema["properties"].keys(): + yaml += f"{section}:\n" + yaml += dump_def(schema["properties"][section]["$ref"]) + yaml += "\n" + + with open(filename, "w") as f: + f.write(yaml) + + +# generate_config_file("test.yml") diff --git a/common/downloader.py b/common/downloader.py index b0a8d93..6813e0d 100644 --- a/common/downloader.py +++ b/common/downloader.py @@ -76,9 +76,9 @@ def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str """Gets the download folder for the repo.""" if repo_type == "lora": - download_path = pathlib.Path(config.lora.get("lora_dir") or "loras") + download_path = pathlib.Path(config.lora.lora_dir) else: - download_path = pathlib.Path(config.model.get("model_dir") or "models") + download_path = pathlib.Path(config.model.model_dir) download_path = download_path / (folder_name or repo_id.split("/")[-1]) return download_path diff --git a/common/gen_logging.py b/common/gen_logging.py index 9995818..3252bb2 100644 --- a/common/gen_logging.py +++ b/common/gen_logging.py @@ -6,37 +6,19 @@ from pydantic import BaseModel from loguru import logger from typing import Dict, Optional - -class GenLogPreferences(BaseModel): - """Logging preference config.""" - - prompt: bool = False - generation_params: bool = False - +from common.tabby_config import config # Global logging preferences constant -PREFERENCES = GenLogPreferences() - - -def update_from_dict(options_dict: Dict[str, bool]): - """Wrapper to set the logging config for generations""" - global PREFERENCES - - # Force bools on the dict - for value in options_dict.values(): - if value is None: - value = False - - PREFERENCES = GenLogPreferences.model_validate(options_dict) +PREFERENCES = config.logging def broadcast_status(): """Broadcasts the current logging status""" enabled = [] - if PREFERENCES.prompt: + if PREFERENCES.log_prompt: enabled.append("prompts") - if PREFERENCES.generation_params: + if PREFERENCES.log_generation_params: enabled.append("generation params") if len(enabled) > 0: @@ -47,13 +29,13 @@ def broadcast_status(): def log_generation_params(**kwargs): """Logs generation parameters to console.""" - if PREFERENCES.generation_params: + if PREFERENCES.log_generation_params: logger.info(f"Generation options: {kwargs}\n") def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]): """Logs the prompt to console.""" - if PREFERENCES.prompt: + if PREFERENCES.log_prompt: formatted_prompt = "\n" + prompt logger.info( f"Prompt (ID: {request_id}): {formatted_prompt if prompt else 'Empty'}\n" @@ -66,7 +48,7 @@ def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]): def log_response(request_id: str, response: str): """Logs the response to console.""" - if PREFERENCES.prompt: + if PREFERENCES.log_prompt: formatted_response = "\n" + response logger.info( f"Response (ID: {request_id}): " diff --git a/common/model.py b/common/model.py index a9ddfff..4da1d90 100644 --- a/common/model.py +++ b/common/model.py @@ -153,7 +153,7 @@ async def unload_embedding_model(): def get_config_default(key: str, model_type: str = "model"): """Fetches a default value from model config if allowed by the user.""" - default_keys = unwrap(config.model.get("use_as_default"), []) + default_keys = unwrap(config.model.use_as_default, []) # Add extra keys to defaults default_keys.append("embeddings_device") diff --git a/common/networking.py b/common/networking.py index be6f1ab..e081272 100644 --- a/common/networking.py +++ b/common/networking.py @@ -39,7 +39,7 @@ 
def handle_request_error(message: str, exc_info: bool = True): """Log a request error to the console.""" trace = traceback.format_exc() - send_trace = unwrap(config.network.get("send_tracebacks"), False) + send_trace = config.network.send_tracebacks error_message = TabbyRequestErrorMessage( message=message, trace=trace if send_trace else None @@ -134,7 +134,7 @@ def get_global_depends(): depends = [Depends(add_request_id)] - if config.logging.get("requests"): + if config.logging.log_requests: depends.append(Depends(log_request)) return depends diff --git a/common/tabby_config.py b/common/tabby_config.py index c0c9e58..5aac0b8 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -4,21 +4,11 @@ from loguru import logger from typing import Optional from common.utils import unwrap, merge_dicts +from common.config_models import tabby_config_model +import common.config_models -class TabbyConfig: - network: dict = {} - logging: dict = {} - model: dict = {} - draft_model: dict = {} - lora: dict = {} - sampling: dict = {} - developer: dict = {} - embeddings: dict = {} - - def __init__(self): - pass - +class TabbyConfig(tabby_config_model): def load_config(self, arguments: Optional[dict] = None): """load the global application config""" @@ -30,14 +20,11 @@ class TabbyConfig: merged_config = merge_dicts(*configs) - self.network = unwrap(merged_config.get("network"), {}) - self.logging = unwrap(merged_config.get("logging"), {}) - self.model = unwrap(merged_config.get("model"), {}) - self.draft_model = unwrap(merged_config.get("draft"), {}) - self.lora = unwrap(merged_config.get("draft"), {}) - self.sampling = unwrap(merged_config.get("sampling"), {}) - self.developer = unwrap(merged_config.get("developer"), {}) - self.embeddings = unwrap(merged_config.get("embeddings"), {}) + for field in tabby_config_model.model_fields.keys(): + value = unwrap(merged_config.get(field), {}) + model = getattr(common.config_models, f"{field}_config_model") + + setattr(self, field, model.parse_obj(value)) def _from_file(self, config_path: pathlib.Path): """loads config from a given file path""" diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index b888f19..7cf08d7 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -58,9 +58,7 @@ async def completion_request( if isinstance(data.prompt, list): data.prompt = "\n".join(data.prompt) - disable_request_streaming = unwrap( - config.developer.get("disable_request_streaming"), False - ) + disable_request_streaming = config.developer.disable_request_streaming # Set an empty JSON schema if the request wants a JSON response if data.response_format.type == "json": @@ -117,9 +115,7 @@ async def chat_completion_request( if data.response_format.type == "json": data.json_schema = {"type": "object"} - disable_request_streaming = unwrap( - config.developer.get("disable_request_streaming"), False - ) + disable_request_streaming = config.developer.disable_request_streaming if data.stream and not disable_request_streaming: return EventSourceResponse( diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 1f9d194..29a615c 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -62,17 +62,17 @@ async def list_models(request: Request) -> ModelList: Requires an admin key to see all models. 
""" - model_dir = unwrap(config.model.get("model_dir"), "models") + model_dir = config.model.model_dir model_path = pathlib.Path(model_dir) - draft_model_dir = config.draft_model.get("draft_model_dir") + draft_model_dir = config.draft_model.draft_model_dir if get_key_permission(request) == "admin": models = get_model_list(model_path.resolve(), draft_model_dir) else: models = await get_current_model_list() - if unwrap(config.model.get("use_dummy_models"), False): + if config.model.use_dummy_models: models.data.insert(0, ModelCard(id="gpt-3.5-turbo")) return models @@ -98,7 +98,7 @@ async def list_draft_models(request: Request) -> ModelList: """ if get_key_permission(request) == "admin": - draft_model_dir = unwrap(config.draft_model.get("draft_model_dir"), "models") + draft_model_dir = config.draft_model.draft_model_dir draft_model_path = pathlib.Path(draft_model_dir) models = get_model_list(draft_model_path.resolve()) @@ -122,7 +122,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: raise HTTPException(400, error_message) - model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models")) + model_path = pathlib.Path(config.model.model_dir) model_path = model_path / data.name draft_model_path = None @@ -135,7 +135,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: raise HTTPException(400, error_message) - draft_model_path = unwrap(config.draft_model.get("draft_model_dir"), "models") + draft_model_path = config.draft_model.draft_model_dir if not model_path.exists(): error_message = handle_request_error( @@ -192,7 +192,7 @@ async def list_all_loras(request: Request) -> LoraList: """ if get_key_permission(request) == "admin": - lora_path = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras")) + lora_path = pathlib.Path(config.lora.lora_dir) loras = get_lora_list(lora_path.resolve()) else: loras = get_active_loras() @@ -227,7 +227,7 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse: raise HTTPException(400, error_message) - lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras")) + lora_dir = pathlib.Path(config.lora.lora_dir) if not lora_dir.exists(): error_message = handle_request_error( "A parent lora directory does not exist for load. 
Check your config.yml?", @@ -266,9 +266,7 @@ async def list_embedding_models(request: Request) -> ModelList: """ if get_key_permission(request) == "admin": - embedding_model_dir = unwrap( - config.embeddings.get("embedding_model_dir"), "models" - ) + embedding_model_dir = config.embeddings.embedding_model_dir embedding_model_path = pathlib.Path(embedding_model_dir) models = get_model_list(embedding_model_path.resolve()) @@ -302,9 +300,7 @@ async def load_embedding_model( raise HTTPException(400, error_message) - embedding_model_dir = pathlib.Path( - unwrap(config.embeddings.get("embedding_model_dir"), "models") - ) + embedding_model_dir = pathlib.Path(config.embeddings.embedding_model_dir) embedding_model_path = embedding_model_dir / data.name if not embedding_model_path.exists(): diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index 154a906..6966359 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, ConfigDict from time import time from typing import List, Literal, Optional, Union -from common.gen_logging import GenLogPreferences +from common.config_models import logging_config_model from common.model import get_config_default @@ -33,7 +33,7 @@ class ModelCard(BaseModel): object: str = "model" created: int = Field(default_factory=lambda: int(time())) owned_by: str = "tabbyAPI" - logging: Optional[GenLogPreferences] = None + logging: Optional[logging_config_model] = None parameters: Optional[ModelCardParameters] = None diff --git a/endpoints/server.py b/endpoints/server.py index 0f6a19b..e1c81b5 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -36,7 +36,7 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None): allow_headers=["*"], ) - api_servers = unwrap(config.network.get("api_servers"), []) + api_servers = config.network.api_servers # Map for API id to server router router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter} diff --git a/main.py b/main.py index f017ecc..89cc6bf 100644 --- a/main.py +++ b/main.py @@ -27,8 +27,8 @@ if not do_export_openapi: async def entrypoint_async(): """Async entry function for program startup""" - host = unwrap(config.network.get("host"), "127.0.0.1") - port = unwrap(config.network.get("port"), 5000) + host = config.network.host + port = config.network.port # Check if the port is available and attempt to bind a fallback if is_port_in_use(port): @@ -50,16 +50,12 @@ async def entrypoint_async(): port = fallback_port # Initialize auth keys - load_auth_keys(unwrap(config.network.get("disable_auth"), False)) - - # Override the generation log options if given - if config.logging: - gen_logging.update_from_dict(config.logging) + load_auth_keys(config.network.disable_auth) gen_logging.broadcast_status() # Set sampler parameter overrides if provided - sampling_override_preset = config.sampling.get("override_preset") + sampling_override_preset = config.sampling.override_preset if sampling_override_preset: try: sampling.overrides_from_file(sampling_override_preset) @@ -68,25 +64,23 @@ async def entrypoint_async(): # If an initial model name is specified, create a container # and load the model - model_name = config.model.get("model_name") + model_name = config.model.model_name if model_name: - model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models")) + model_path = pathlib.Path(config.model.model_dir) model_path = model_path / model_name await model.load_model(model_path.resolve(), **config.model) # Load 
loras after loading the model - if config.lora.get("loras"): - lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras")) + if config.lora.loras: + lora_dir = pathlib.Path(config.lora.lora_dir) await model.container.load_loras(lora_dir.resolve(), **config.lora) # If an initial embedding model name is specified, create a separate container # and load the model - embedding_model_name = config.embeddings.get("embedding_model_name") + embedding_model_name = config.embeddings.embedding_model_name if embedding_model_name: - embedding_model_path = pathlib.Path( - unwrap(config.embeddings.get("embedding_model_dir"), "models") - ) + embedding_model_path = pathlib.Path(config.embeddings.embedding_model_dir) embedding_model_path = embedding_model_path / embedding_model_name try: @@ -124,7 +118,7 @@ def entrypoint(arguments: Optional[dict] = None): # Check exllamav2 version and give a descriptive error if it's too old # Skip if launching unsafely print(f"MAIN.PY {config=}") - if unwrap(config.developer.get("unsafe_launch"), False): + if config.developer.unsafe_launch: logger.warning( "UNSAFE: Skipping ExllamaV2 version check.\n" "If you aren't a developer, please keep this off!" @@ -133,12 +127,12 @@ def entrypoint(arguments: Optional[dict] = None): check_exllama_version() # Enable CUDA malloc backend - if unwrap(config.developer.get("cuda_malloc_backend"), False): + if config.developer.cuda_malloc_backend: os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync" logger.warning("EXPERIMENTAL: Enabled the pytorch CUDA malloc backend.") # Use Uvloop/Winloop - if unwrap(config.developer.get("uvloop"), False): + if config.developer.uvloop: if platform.system() == "Windows": from winloop import install else: @@ -150,7 +144,7 @@ def entrypoint(arguments: Optional[dict] = None): logger.warning("EXPERIMENTAL: Running program with Uvloop/Winloop.") # Set the process priority - if unwrap(config.developer.get("realtime_process_priority"), False): + if config.developer.realtime_process_priority: import psutil current_process = psutil.Process(os.getpid())
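
Usage sketch (illustrative, not part of the patch): how the pydantic-backed config introduced in common/config_models.py is meant to be consumed. Only the keys a user actually supplies need to be present; pydantic fills in the declared defaults, and the old unwrap(config.network.get(...), fallback) pattern seen in the removed lines becomes plain attribute access. The model name below is made up.

from common.config_models import tabby_config_model

raw = {
    "network": {"port": 8080},
    "model": {"model_name": "my-model-exl2"},  # hypothetical model name
}

config = tabby_config_model(**raw)

# Old style:  unwrap(config.network.get("port"), 5000)
# New style:  typed attributes with defaults already applied
assert config.network.port == 8080
assert config.network.host == "127.0.0.1"
assert config.model.model_dir == "models"
assert config.developer.uvloop is False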
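
Regeneration sketch (also illustrative): config_models.py ends with a commented-out generate_config_file("test.yml") call; running the helper builds a sample config from tabby_config_model's JSON schema, emitting one YAML section per top-level field with each Field description rendered as a comment above its default (blank for default_factory fields). The exact output depends on how pydantic serializes the schema, so the shape shown in the comment is approximate.

from common.config_models import generate_config_file

# Writes a commented sample config derived from the pydantic schema.
generate_config_file("config_sample.yml")

# Expected shape (abridged, roughly):
#
# network:
#     # The IP to host on
#     host: 127.0.0.1
#
#     # The port to host on
#     port: 5000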