Config: Cleanup and organize functions

Remove access of private attributes and use safer functions. Also
move generalized functions into utils files.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-09-14 21:48:39 -04:00
parent 0903f852db
commit a09dd802c2
3 changed files with 47 additions and 15 deletions

View file

@ -1,20 +1,10 @@
"""Argparser for overriding config values"""
import argparse
from typing import Any
from pydantic import BaseModel
from common.config_models import TabbyConfigModel
def is_list_type(type_hint):
if hasattr(type_hint, "__origin__") and type_hint.__origin__ is list:
return True
if hasattr(type_hint, "__args__"):
# Recursively check for lists inside type arguments
return any(is_list_type(arg) for arg in type_hint.__args__)
return False
from common.utils import is_list_type
def add_field_to_group(group, field_name, field_type, field) -> None:
@ -26,6 +16,7 @@ def add_field_to_group(group, field_name, field_type, field) -> None:
"help": field.description if field.description else "No description available",
}
# If the inner type contains a list, specify argparse as such
if is_list_type(field_type):
kwargs["nargs"] = "+"
@ -63,7 +54,7 @@ def init_argparser() -> argparse.ArgumentParser:
def convert_args_to_dict(
args: argparse.Namespace, parser: argparse.ArgumentParser
) -> dict[str, dict[str, Any]]:
) -> dict:
"""Broad conversion of surface level arg groups to dictionaries"""
arg_groups = {}

View file

@ -6,6 +6,8 @@ CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
class ConfigOverrideConfig(BaseModel):
"""Model for overriding a provided config file."""
# TODO: convert this to a pathlib.path?
config: Optional[str] = Field(
None, description=("Path to an overriding config.yml file")
@ -13,6 +15,9 @@ class ConfigOverrideConfig(BaseModel):
class UtilityActions(BaseModel):
"""Model used for arg actions."""
# YAML export options
export_config: Optional[str] = Field(
None, description="generate a template config file"
)
@ -20,6 +25,7 @@ class UtilityActions(BaseModel):
"config_sample.yml", description="path to export configuration file to"
)
# OpenAPI JSON export options
export_openapi: Optional[bool] = Field(
False, description="export openapi schema files"
)
@ -29,6 +35,8 @@ class UtilityActions(BaseModel):
class NetworkConfig(BaseModel):
"""Model for network configuration."""
host: Optional[str] = Field("127.0.0.1", description=("The IP to host on"))
port: Optional[int] = Field(5000, description=("The port to host on"))
disable_auth: Optional[bool] = Field(
@ -47,6 +55,8 @@ class NetworkConfig(BaseModel):
class LoggingConfig(BaseModel):
"""Model for logging configuration."""
log_prompt: Optional[bool] = Field(False, description=("Enable prompt logging"))
log_generation_params: Optional[bool] = Field(
False, description=("Enable generation parameter logging")
@ -55,6 +65,8 @@ class LoggingConfig(BaseModel):
class ModelConfig(BaseModel):
"""Model for LLM configuration."""
# TODO: convert this to a pathlib.path?
model_dir: str = Field(
"models",
@ -201,6 +213,8 @@ class ModelConfig(BaseModel):
class DraftModelConfig(BaseModel):
"""Model for draft LLM model configuration."""
# TODO: convert this to a pathlib.path?
draft_model_dir: Optional[str] = Field(
"models",
@ -239,6 +253,8 @@ class DraftModelConfig(BaseModel):
class LoraInstanceModel(BaseModel):
"""Model representing an instance of a Lora."""
name: str = Field(..., description=("Name of the LoRA model"))
scaling: float = Field(
1.0,
@ -248,6 +264,8 @@ class LoraInstanceModel(BaseModel):
class LoraConfig(BaseModel):
"""Model for lora configuration."""
# TODO: convert this to a pathlib.path?
lora_dir: Optional[str] = Field(
"loras", description=("Directory to look for LoRAs (default: 'loras')")
@ -262,12 +280,16 @@ class LoraConfig(BaseModel):
class SamplingConfig(BaseModel):
"""Model for sampling (overrides) config."""
override_preset: Optional[str] = Field(
None, description=("Select a sampler override preset")
)
class DeveloperConfig(BaseModel):
"""Model for developer settings configuration."""
unsafe_launch: Optional[bool] = Field(
False, description=("Skip Exllamav2 version check")
)
@ -290,6 +312,8 @@ class DeveloperConfig(BaseModel):
class EmbeddingsConfig(BaseModel):
"""Model for embeddings configuration."""
# TODO: convert this to a pathlib.path?
embedding_model_dir: Optional[str] = Field(
"models",
@ -310,6 +334,8 @@ class EmbeddingsConfig(BaseModel):
class TabbyConfigModel(BaseModel):
"""Base model for a TabbyConfig."""
config: ConfigOverrideConfig = Field(
default_factory=ConfigOverrideConfig.model_construct
)
@ -331,6 +357,8 @@ class TabbyConfigModel(BaseModel):
def generate_config_file(filename="config_sample.yml", indentation=2):
"""Creates a config.yml file from Pydantic models."""
schema = TabbyConfigModel.model_json_schema()
def dump_def(id: str, indent=2):
@ -356,6 +384,3 @@ def generate_config_file(filename="config_sample.yml", indentation=2):
with open(filename, "w") as f:
f.write(yaml)
# generate_config_file("test.yml")

View file

@ -1,5 +1,7 @@
"""Common utility functions"""
from typing import get_args, get_origin
def unwrap(wrapped, default=None):
"""Unwrap function for Optionals."""
@ -43,3 +45,17 @@ def flat_map(input_list):
"""Flattens a list of lists into a single list."""
return [item for sublist in input_list for item in sublist]
def is_list_type(type_hint):
"""Checks if a type contains a list."""
if get_origin(type_hint) is list:
return True
# Recursively check for lists inside type arguments
type_args = get_args(type_hint)
if type_args:
return any(is_list_type(arg) for arg in type_args)
return False